Introduction
PyTorch Geometric Temporal is a temporal graph neural network extension library for PyTorch Geometric. It builds on open-source deep learning and graph processing libraries. PyTorch Geometric Temporal consists of state-of-the-art deep learning and parametric learning methods for processing spatio-temporal signals. It is the first open-source library for temporal deep learning on geometric structures and provides constant time difference graph neural networks on dynamic and static graphs. It does this by representing the data as discrete-time graph snapshots. The implemented methods come from a wide range of data mining (WWW, KDD), artificial intelligence and machine learning (AAAI, ICONIP, ICLR) conferences, workshops, and prominent journals.
Citing
If you find PyTorch Geometric Temporal useful in your research, please consider adding the following citation:
@inproceedings{rozemberczki2021pytorch,
    author = {Benedek Rozemberczki and Paul Scherer and Yixuan He and George Panagopoulos and Alexander Riedel and Maria Astefanoaei and Oliver Kiss and Ferenc Beres and Guzman Lopez and Nicolas Collignon and Rik Sarkar},
    title = {{PyTorch Geometric Temporal: Spatiotemporal Signal Processing with Neural Machine Learning Models}},
    year = {2021},
    booktitle = {Proceedings of the 30th ACM International Conference on Information and Knowledge Management},
    pages = {4564–4573},
}
We briefly overview the fundamental concepts and features of PyTorch Geometric Temporal through simple examples.
Data Structures
PyTorch Geometric Temporal is designed to provide easy-to-use data iterators that are parametrized with spatiotemporal data. These iterators serve snapshots formed from a single graph, or from multiple graphs batched together with the block diagonal batching trick.
Temporal Signal Iterators
PyTorch Geometric Temporal offers data iterators for spatio-temporal datasets which contain the temporal snapshots. There are three types of data iterators:
StaticGraphTemporalSignal - Designed for temporal signals defined on a static graph.
DynamicGraphTemporalSignal - Designed for temporal signals defined on a dynamic graph.
DynamicGraphStaticSignal - Designed for static signals defined on a dynamic graph.
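As a rough illustration of the first iterator type, the sketch below uses plain Python lists rather than the library classes. The class name ToyStaticGraphSignal and its fields are hypothetical. It only mimics the idea that the edges stay fixed while the features and targets change from one snapshot to the next.

```python
class ToyStaticGraphSignal:
    """Hypothetical stand-in for a static-graph temporal signal iterator."""

    def __init__(self, edge_index, features, targets):
        assert len(features) == len(targets), "one target per feature snapshot"
        self.edge_index = edge_index  # shared by every time step
        self.features = features      # one node-feature list per time step
        self.targets = targets        # one node-target list per time step

    def __len__(self):
        return len(self.features)

    def __iter__(self):
        # Serve snapshots in chronological order.
        for x, y in zip(self.features, self.targets):
            yield {"edge_index": self.edge_index, "x": x, "y": y}


# Two nodes joined by one undirected edge, observed at three time steps.
signal = ToyStaticGraphSignal(
    edge_index=[[0, 1], [1, 0]],
    features=[[[0.1], [0.2]], [[0.2], [0.3]], [[0.3], [0.4]]],
    targets=[[0.2, 0.3], [0.3, 0.4], [0.4, 0.5]],
)
snapshots = list(signal)
```

A dynamic-graph variant of this toy class would store one edge_index per time step instead of a single shared one.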
Temporal Data Snapshots
A temporal data snapshot is a PyTorch Geometric Data object. Please take a look at this readme for the details. The returned temporal snapshot has the following attributes:
edge_index - A PyTorch LongTensor of edge indices used for node feature aggregation (optional).
edge_attr - A PyTorch FloatTensor of edge features used for weighting the node feature aggregation (optional).
x - A PyTorch FloatTensor of vertex features (optional).
y - A PyTorch FloatTensor or LongTensor of vertex targets (optional).
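To make the roles of edge_index and edge_attr concrete, here is a minimal pure-Python sketch of weighted neighbor aggregation. It assumes the PyTorch Geometric convention that the first row of edge_index holds source nodes and the second row holds target nodes. The function name weighted_aggregate is hypothetical, and real layers perform this on tensors with additional normalization.

```python
def weighted_aggregate(x, edge_index, edge_attr=None):
    """Sum neighbor features along edges, optionally scaled by edge weights."""
    sources, targets = edge_index
    if edge_attr is None:
        edge_attr = [1.0] * len(sources)  # unweighted graph
    num_features = len(x[0])
    out = [[0.0] * num_features for _ in x]
    for s, t, w in zip(sources, targets, edge_attr):
        for f in range(num_features):
            out[t][f] += w * x[s][f]
    return out


x = [[1.0], [2.0], [4.0]]            # one feature per node
edge_index = [[0, 1, 2], [1, 2, 1]]  # edges 0->1, 1->2, 2->1
edge_attr = [0.5, 1.0, 0.25]         # one weight per edge
aggregated = weighted_aggregate(x, edge_index, edge_attr)
```

Node 1 receives 0.5 * 1.0 from node 0 and 0.25 * 4.0 from node 2, node 2 receives 1.0 * 2.0 from node 1, and node 0 has no incoming edges.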
Temporal Signal Iterators with Batches
PyTorch Geometric Temporal offers data iterators for batched spatiotemporal datasets which contain the batched temporal snapshots. There are three types of batched data iterators:
StaticGraphTemporalSignalBatch - Designed for temporal signals defined on a batch of static graphs.
DynamicGraphTemporalSignalBatch - Designed for temporal signals defined on a batch of dynamic graphs.
DynamicGraphStaticSignalBatch - Designed for static signals defined on a batch of dynamic graphs.
Temporal Batch Snapshots
A temporal batch snapshot is a PyTorch Geometric Batch object. Please take a look at this readme for the details. The returned temporal batch snapshot has the following attributes:
edge_index - A PyTorch LongTensor of edge indices used for node feature aggregation (optional).
edge_attr - A PyTorch FloatTensor of edge features used for weighting the node feature aggregation (optional).
x - A PyTorch FloatTensor of vertex features (optional).
y - A PyTorch FloatTensor or LongTensor of vertex targets (optional).
batch - A PyTorch LongTensor of batch indices (optional).
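The block diagonal batching trick and the batch attribute can be sketched in a few lines of plain Python. The helper block_diagonal_batch is hypothetical and only illustrates the idea: node ids of each graph are shifted so that the graphs occupy disjoint index ranges, and the batch vector records which graph each node came from.

```python
def block_diagonal_batch(graphs):
    """Merge (num_nodes, edge_index) pairs into one disconnected graph."""
    sources, targets, batch = [], [], []
    offset = 0
    for graph_id, (num_nodes, edge_index) in enumerate(graphs):
        # Shift node ids so that the graphs do not share any indices.
        sources.extend(s + offset for s in edge_index[0])
        targets.extend(t + offset for t in edge_index[1])
        batch.extend([graph_id] * num_nodes)
        offset += num_nodes
    return [sources, targets], batch


graphs = [
    (2, [[0, 1], [1, 0]]),        # graph 0 with nodes 0-1
    (3, [[0, 1, 2], [1, 2, 0]]),  # graph 1 with nodes 0-2
]
batched_edge_index, batch = block_diagonal_batch(graphs)
```

Because no edge connects nodes of different graphs, the adjacency matrix of the merged graph is block diagonal, and message passing on it never mixes information across graphs.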
Benchmark Datasets
We released and included a number of datasets which can be used for comparing the performance of temporal graph neural network algorithms. The related machine learning tasks are node-level and graph-level supervised learning.
Newly Released Datasets
In order to benchmark graph neural networks we released the following datasets:
Integrated Datasets
We also integrated existing datasets for performance evaluation:
The Hungarian Chickenpox Dataset can be loaded by the following code snippet. The dataset returned by the public get_dataset method is a StaticGraphTemporalSignal object.
from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader
loader = ChickenpoxDatasetLoader()
dataset = loader.get_dataset()
Spatiotemporal Signal Splitting
We provide functions to create temporal splits of the data iterators. These functions return train and test data iterators which split the original iterator using a fixed train-test ratio. Snapshots from the earlier time periods form the training dataset and snapshots from the later periods form the test dataset. This way temporal forecasts can be evaluated in a realistic scenario, where the model never sees data from the future. The function temporal_signal_split takes either a StaticGraphTemporalSignal or a DynamicGraphTemporalSignal object and returns two iterators according to the split ratio specified by train_ratio.
from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader
from torch_geometric_temporal.signal import temporal_signal_split
loader = ChickenpoxDatasetLoader()
dataset = loader.get_dataset()
train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.8)
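The chronological split can be pictured with a plain list of time steps. The function toy_temporal_split below is a hypothetical sketch of the idea, not the library implementation: it cuts the ordered sequence at a single point and never shuffles.

```python
def toy_temporal_split(snapshots, train_ratio=0.8):
    """Split a chronologically ordered list without shuffling."""
    cutoff = int(train_ratio * len(snapshots))
    return snapshots[:cutoff], snapshots[cutoff:]


weeks = list(range(10))  # ten weekly snapshots, oldest first
train_weeks, test_weeks = toy_temporal_split(weeks, train_ratio=0.8)
```

Every training snapshot precedes every test snapshot, which is what distinguishes a temporal split from a random one.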
Index-Batching & DDP
Index-batching is a technique that reduces the memory cost of training ST-GNNs on spatiotemporal data with no impact on accuracy, enabling greater scalability and, for the first time, training on the full PeMS dataset without graph partitioning. Leveraging the reduced memory footprint, this technique also enables GPU-index-batching, a technique that performs preprocessing entirely in GPU memory and uses a single CPU-to-GPU memory copy in place of batch-level CPU-to-GPU transfers throughout training. We implemented GPU-index-batching and index-batching for the following existing datasets and added two new datasets (highlighted in bold) to PyTorch Geometric Temporal (PGT):
PeMs-Bay
WindmillLarge
HungaryChickenpox
PeMSAllLA
PeMS
Using index-batching requires minimal modifications to the existing PGT workflow. Simply initialize the DatasetLoader object with the flag index=True and then call loader.get_index_dataset(). For example, the following is a sample training loop with PeMS-Bay and DCRNN:
from torch_geometric_temporal.dataset import PemsBayDatasetLoader

# BatchedDCRNN, optimizer, batch_size and masked_mae_loss are assumed to be defined elsewhere
model = BatchedDCRNN(2, 2, K=3)
loader = PemsBayDatasetLoader(index=True)
train_dataloader, _, _, edges, edge_weights, means, stds = loader.get_index_dataset(batch_size=batch_size)

for batch in train_dataloader:
    X_batch, y_batch = batch
    # Forward pass
    outputs = model(X_batch, edges, edge_weights)
    # Calculate loss on the original (unstandardized) scale
    loss = masked_mae_loss((outputs * stds) + means, (y_batch * stds) + means)
    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
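The loss in the loop above rescales outputs and targets back to the original units before calling masked_mae_loss, which is not defined in this section. The sketch below shows one plausible reading on plain Python lists. It assumes that "masked" means skipping targets equal to a null value such as zero, a common way to encode missing sensor readings. The names toy_masked_mae and unstandardize are hypothetical.

```python
def unstandardize(values, mean, std):
    """Undo z-score scaling: value * std + mean."""
    return [v * std + mean for v in values]


def toy_masked_mae(predictions, targets, null_value=0.0):
    """Mean absolute error over entries whose target is not null_value."""
    errors = [abs(p - t) for p, t in zip(predictions, targets) if t != null_value]
    return sum(errors) / len(errors) if errors else 0.0


mean, std = 50.0, 10.0
outputs = [0.5, -1.0, 0.0]   # standardized predictions
y_batch = [1.0, -5.0, 0.0]   # standardized targets: 60, 0 (missing), 50 after rescaling
loss = toy_masked_mae(
    unstandardize(outputs, mean, std),
    unstandardize(y_batch, mean, std),
)
```

Only the first and third entries contribute, with absolute errors of 5.0 and 0.0, so the masked error is reported in the original units rather than in standard deviations.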
Index-batching uses a sequence-to-sequence batch format, where the data is of shape (batch_size, seq_length, num_graph_nodes, num_features). In the future, we hope to integrate index-batching into all existing PGT datasets. Examples can be found in the PGT GitHub repository.
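The memory saving behind index-batching can be illustrated with nested Python lists. This is a simplified sketch under assumed names (build_index_batch, valid_starts), not the PGT implementation: the signal is stored once, only window start positions are kept, and the input and target windows of a batch are sliced out on demand.

```python
def build_index_batch(series, starts, seq_length):
    """Materialize one (x, y) batch from window start indices."""
    x = [series[s:s + seq_length] for s in starts]
    y = [series[s + seq_length:s + 2 * seq_length] for s in starts]
    return x, y


num_steps, num_nodes, num_features, seq_length = 8, 3, 2, 2
# A single copy of the signal, indexed as series[time][node][feature].
series = [[[float(t)] * num_features for _ in range(num_nodes)] for t in range(num_steps)]
# Only the valid start positions are stored, not the windows themselves.
valid_starts = list(range(num_steps - 2 * seq_length + 1))

x_batch, y_batch = build_index_batch(series, valid_starts[:4], seq_length)
batch_shape = (len(x_batch), len(x_batch[0]), len(x_batch[0][0]), len(x_batch[0][0][0]))

stored_steps = num_steps                                  # index-batching keeps 8 time steps
materialized_steps = len(valid_starts) * 2 * seq_length   # precomputed windows would keep 20
```

Here batch_shape follows the (batch_size, seq_length, num_graph_nodes, num_features) layout, and the gap between stored_steps and materialized_steps grows with the sequence length.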
Distributed Data Parallel Training
Using Dask-DDP, PGT now supports distributed data parallel (DDP) training with the following datasets:
PeMs-Bay
PeMSAllLA
PeMS
DDP training requires minimal modifications to the existing training loop. For example, to modify the index-batching training loop to utilize DDP, we 1) pass world_size and ddp_rank to the get_index_dataset method and 2) wrap the model in the PyTorch DDP wrapper (note that a Dask cluster must be initialized).
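from torch.nn.parallel import DistributedDataParallel as DDP
from torch_geometric_temporal.dataset import PemsBayDatasetLoader

# BatchedDCRNN, optimizer, batch_size, world_size, worker_rank and masked_mae_loss are assumed to be defined elsewhere
model = BatchedDCRNN(2, 2, K=3)
model = DDP(model)  # wrap the model for distributed data parallel training
loader = PemsBayDatasetLoader(index=True)
train_dataloader, _, _, edges, edge_weights, means, stds = loader.get_index_dataset(world_size=world_size, ddp_rank=worker_rank, batch_size=batch_size)

for batch in train_dataloader:
    X_batch, y_batch = batch
    # Forward pass
    outputs = model(X_batch, edges, edge_weights)
    # Calculate loss on the original (unstandardized) scale
    loss = masked_mae_loss((outputs * stds) + means, (y_batch * stds) + means)
    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()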
A simple script (and further instructions) for multi-GPU/multi-node DDP and Dask initialization is available in the PGT GitHub repository.
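How PGT assigns samples to workers is not described here. As a generic illustration of why world_size and ddp_rank are needed, the sketch below uses a strided assignment in which each rank takes every world_size-th sample. The function shard_indices is hypothetical.

```python
def shard_indices(num_samples, world_size, rank):
    """Give each worker a strided, non-overlapping subset of sample indices."""
    return list(range(rank, num_samples, world_size))


world_size = 4
shards = [shard_indices(10, world_size, rank) for rank in range(world_size)]
```

Together the shards cover every sample exactly once, so each worker processes roughly 1 / world_size of the data per epoch while DDP keeps the model replicas synchronized by averaging gradients.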
Applications
In the following we overview two case studies where PyTorch Geometric Temporal can be used to solve real-world machine learning problems. One of them is about epidemiological forecasting, the other one is about predicting web traffic.
Epidemiological Forecasting
We are using the Hungarian Chickenpox Cases dataset in this case study. We will train a regressor to predict the weekly cases reported by the counties using a recurrent graph convolutional network. First, we will load the dataset and create an appropriate spatio-temporal split.
from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader
from torch_geometric_temporal.signal import temporal_signal_split
loader = ChickenpoxDatasetLoader()
dataset = loader.get_dataset()
train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.2)
In the next steps we will define the recurrent graph neural network architecture used for solving the supervised task. The constructor defines a DCRNN layer and a feedforward layer. It is important to note that the final non-linearity is not integrated into the recurrent graph convolutional operation. This design principle is used consistently and it was taken from PyTorch Geometric. Because of this, we defined a ReLU non-linearity between the recurrent and linear layers manually. The final linear layer is not followed by a non-linearity as we solve a regression problem with zero-mean targets.
import torch
torch.manual_seed(1)

import torch.nn.functional as F
from torch_geometric_temporal.nn.recurrent import DCRNN


class RecurrentGCN(torch.nn.Module):
    def __init__(self, node_features):
        super(RecurrentGCN, self).__init__()
        self.recurrent = DCRNN(node_features, 32, 1)
        self.linear = torch.nn.Linear(32, 1)

    def forward(self, x, edge_index, edge_weight):
        h = self.recurrent(x, edge_index, edge_weight)
        h = F.relu(h)
        h = self.linear(h)
        return h
Let us define a model (we have 4 node features) and train it on the training split (first 20% of the temporal snapshots) for 200 epochs. We accumulate the loss over every temporal snapshot and backpropagate once per epoch. We will use the Adam optimizer with a learning rate of 0.01. The tqdm function is used for measuring the runtime needed for each training epoch.
from tqdm import tqdm

model = RecurrentGCN(node_features=4)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()

for epoch in tqdm(range(200)):
    cost = 0
    for time, snapshot in enumerate(train_dataset):
        y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
        cost = cost + torch.mean((y_hat.squeeze() - snapshot.y)**2)
    cost = cost / (time+1)
    cost.backward()
    optimizer.step()
    optimizer.zero_grad()
Using the holdout we will evaluate the performance of the trained recurrent graph convolutional network and calculate the mean squared error across all the spatial units and time periods.
model.eval()
cost = 0
for time, snapshot in enumerate(test_dataset):
    y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
    cost = cost + torch.mean((y_hat.squeeze() - snapshot.y)**2)
cost = cost / (time+1)
cost = cost.item()
print("MSE: {:.4f}".format(cost))
>>> MSE: 0.7418
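The evaluation loop averages one mean squared error per snapshot. The small worked example below, with made-up numbers and a hypothetical helper snapshot_mse, shows that this equals the error pooled over all spatial units and time periods whenever every snapshot has the same number of nodes.

```python
def snapshot_mse(y_hat, y):
    """Mean squared error over the nodes of a single snapshot."""
    return sum((p - t) ** 2 for p, t in zip(y_hat, y)) / len(y)


predictions = [[1.0, 2.0], [0.0, 4.0]]  # two snapshots, two nodes each
targets = [[1.0, 4.0], [2.0, 5.0]]

# Average of the per-snapshot errors, as in the evaluation loop.
per_snapshot = [snapshot_mse(p, t) for p, t in zip(predictions, targets)]
mse = sum(per_snapshot) / len(per_snapshot)

# Error pooled over every node and time step at once.
flat_pred = [p for snap in predictions for p in snap]
flat_true = [t for snap in targets for t in snap]
pooled_mse = snapshot_mse(flat_pred, flat_true)
```

With a static graph the two quantities coincide. On dynamic graphs with varying node counts they can differ, because the per-snapshot average weights each time step equally.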
Web Traffic Prediction
We are using the Wikipedia Maths dataset in this case study. We will train a recurrent graph convolutional network to predict the daily views on Wikipedia pages. First, we will load the dataset and use 14 lagged traffic variables. Next, we create an appropriate spatio-temporal split using 50% of the days for training the model.
from torch_geometric_temporal.dataset import WikiMathsDatasetLoader
from torch_geometric_temporal.signal import temporal_signal_split
loader = WikiMathsDatasetLoader()
dataset = loader.get_dataset(lags=14)
train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.5)
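Lagged variables turn a time series into supervised examples: the features of a node at time t are its previous lags values and the target is the value at t. The sketch below shows this on a toy series with plain lists. The function make_lagged_snapshots is hypothetical and only illustrates the idea behind the lags argument.

```python
def make_lagged_snapshots(series, lags):
    """Build (features, targets) from series[t][n], the value of node n at time t."""
    features, targets = [], []
    for t in range(len(series) - lags):
        window = series[t:t + lags]
        # Transpose so that each node gets a feature vector of length `lags`.
        features.append([list(node_values) for node_values in zip(*window)])
        targets.append(series[t + lags])
    return features, targets


# Five days of traffic for two pages.
series = [[1, 10], [2, 20], [3, 30], [4, 40], [5, 50]]
features, targets = make_lagged_snapshots(series, lags=3)
```

With five days and three lags only two snapshots can be formed, since the first three days have no complete history to predict from.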
In the next steps we will define the recurrent graph neural network architecture used for solving the supervised task. The constructor defines a GConvGRU layer and a feedforward layer. It is important to note again that the non-linearity is not integrated into the recurrent graph convolutional operation. The convolutional model has a fixed number of filters (which can be parametrized) and considers 2nd order neighborhoods.
import torch
torch.manual_seed(1)

import torch.nn.functional as F
from torch_geometric_temporal.nn.recurrent import GConvGRU


class RecurrentGCN(torch.nn.Module):
    def __init__(self, node_features, filters):
        super(RecurrentGCN, self).__init__()
        self.recurrent = GConvGRU(node_features, filters, 2)
        self.linear = torch.nn.Linear(filters, 1)

    def forward(self, x, edge_index, edge_weight):
        h = self.recurrent(x, edge_index, edge_weight)
        h = F.relu(h)
        h = self.linear(h)
        return h
Let us define a model (we have 14 node features) and train it on the training split (first 50% of the temporal snapshots) for 50 epochs. We backpropagate the loss from every temporal snapshot individually. We will use the Adam optimizer with a learning rate of 0.01. The tqdm function is used for measuring the runtime needed for each training epoch.
from tqdm import tqdm

model = RecurrentGCN(node_features=14, filters=32)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()

for epoch in tqdm(range(50)):
    for time, snapshot in enumerate(train_dataset):
        y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
        cost = torch.mean((y_hat.squeeze() - snapshot.y)**2)
        cost.backward()
        optimizer.step()
        optimizer.zero_grad()
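The two case studies update the parameters on different schedules: the chickenpox model takes one step per epoch on the loss averaged over all snapshots, while the loop above takes one step per snapshot. The sketch below contrasts the two schedules on a one-parameter model with hand-written gradients. It uses plain gradient descent rather than Adam, and all names are hypothetical.

```python
def gradient(w, x, y):
    """Derivative of the squared error (w * x - y) ** 2 with respect to w."""
    return 2 * (w * x - y) * x


def train_accumulated(data, w, lr, epochs):
    """One update per epoch, using the gradient averaged over all snapshots."""
    steps = 0
    for _ in range(epochs):
        grad = sum(gradient(w, x, y) for x, y in data) / len(data)
        w -= lr * grad
        steps += 1
    return w, steps


def train_per_snapshot(data, w, lr, epochs):
    """One update after every snapshot."""
    steps = 0
    for _ in range(epochs):
        for x, y in data:
            w -= lr * gradient(w, x, y)
            steps += 1
    return w, steps


data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]  # generated by y = 2 * x
w_acc, steps_acc = train_accumulated(data, w=0.0, lr=0.05, epochs=20)
w_snap, steps_snap = train_per_snapshot(data, w=0.0, lr=0.05, epochs=20)
```

Both schedules recover a weight close to 2 on this toy problem, but the per-snapshot variant performs three times as many updates for the same number of epochs, which is one reason the web traffic example can train for fewer epochs.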
Using the holdout traffic data we will evaluate the performance of the trained recurrent graph convolutional network and calculate the mean squared error across all of the web pages and days.
model.eval()
cost = 0
for time, snapshot in enumerate(test_dataset):
    y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
    cost = cost + torch.mean((y_hat.squeeze() - snapshot.y)**2)
cost = cost / (time+1)
cost = cost.item()
print("MSE: {:.4f}".format(cost))
>>> MSE: 0.5264