PyTorch Geometric Temporal
Discrete Recurrent Graph Convolutional Layers
class GConvGRU(in_channels: int, out_channels: int, K: int, normalization: str = 'sym', bias: bool = True)
An implementation of the Chebyshev Graph Convolutional Gated Recurrent Unit Cell. For details see this paper: "Structured Sequence Modeling with Graph Convolutional Recurrent Networks."
- Parameters
in_channels (int) – Number of input features.
out_channels (int) – Number of output features.
K (int) – Chebyshev filter size \(K\).
normalization (str, optional) – The normalization scheme for the graph Laplacian (default: "sym"):
1. None: No normalization \(\mathbf{L} = \mathbf{D} - \mathbf{A}\)
2. "sym": Symmetric normalization \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}\)
3. "rw": Random-walk normalization \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}\)
You need to pass lambda_max to the forward() method of this operator in case the normalization is non-symmetric. lambda_max should be a torch.Tensor of size [num_graphs] in a mini-batch scenario and a scalar/zero-dimensional tensor when operating on single graphs. You can pre-compute lambda_max via the torch_geometric.transforms.LaplacianLambdaMax transform.
bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)
forward(X: torch.FloatTensor, edge_index: torch.LongTensor, edge_weight: Optional[torch.FloatTensor] = None, H: Optional[torch.FloatTensor] = None) → torch.FloatTensor
Making a forward pass. If edge weights are not present the forward pass defaults to an unweighted graph. If the hidden state matrix is not present when the forward pass is called it is initialized with zeros.
- Arg types:
X (PyTorch Float Tensor) - Node features.
edge_index (PyTorch Long Tensor) - Graph edge indices.
edge_weight (PyTorch Float Tensor, optional) - Edge weight vector.
H (PyTorch Float Tensor, optional) - Hidden state matrix for all nodes.
- Return types:
H (PyTorch Float Tensor) - Hidden state matrix for all nodes.
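Example (a minimal usage sketch; the toy graph, feature sizes, and the torch_geometric_temporal.nn.recurrent import path are illustrative assumptions):

    import torch
    from torch_geometric_temporal.nn.recurrent import GConvGRU

    cell = GConvGRU(in_channels=4, out_channels=8, K=2)
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])  # toy graph on 10 nodes
    H = None  # left as None so forward() zero-initializes the hidden state
    for _ in range(5):  # iterate over five graph snapshots
        X = torch.randn(10, 4)  # node features for the current snapshot
        H = cell(X, edge_index, H=H)  # H: [10, 8], carried to the next step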
class GConvLSTM(in_channels: int, out_channels: int, K: int, normalization: str = 'sym', bias: bool = True)
An implementation of the Chebyshev Graph Convolutional Long Short Term Memory Cell. For details see this paper: "Structured Sequence Modeling with Graph Convolutional Recurrent Networks."
- Parameters
in_channels (int) – Number of input features.
out_channels (int) – Number of output features.
K (int) – Chebyshev filter size \(K\).
normalization (str, optional) – The normalization scheme for the graph Laplacian (default: "sym"):
1. None: No normalization \(\mathbf{L} = \mathbf{D} - \mathbf{A}\)
2. "sym": Symmetric normalization \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}\)
3. "rw": Random-walk normalization \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}\)
You need to pass lambda_max to the forward() method of this operator in case the normalization is non-symmetric. lambda_max should be a torch.Tensor of size [num_graphs] in a mini-batch scenario and a scalar/zero-dimensional tensor when operating on single graphs. You can pre-compute lambda_max via the torch_geometric.transforms.LaplacianLambdaMax transform.
bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)
forward(X: torch.FloatTensor, edge_index: torch.LongTensor, edge_weight: Optional[torch.FloatTensor] = None, H: Optional[torch.FloatTensor] = None, C: Optional[torch.FloatTensor] = None) → torch.FloatTensor
Making a forward pass. If edge weights are not present the forward pass defaults to an unweighted graph. If the hidden state and cell state matrices are not present when the forward pass is called these are initialized with zeros.
- Arg types:
X (PyTorch Float Tensor) - Node features.
edge_index (PyTorch Long Tensor) - Graph edge indices.
edge_weight (PyTorch Float Tensor, optional) - Edge weight vector.
H (PyTorch Float Tensor, optional) - Hidden state matrix for all nodes.
C (PyTorch Float Tensor, optional) - Cell state matrix for all nodes.
- Return types:
H (PyTorch Float Tensor) - Hidden state matrix for all nodes.
C (PyTorch Float Tensor) - Cell state matrix for all nodes.
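Example (a minimal sketch under the same toy-graph assumptions as above; both the hidden and the cell state are threaded through successive calls):

    import torch
    from torch_geometric_temporal.nn.recurrent import GConvLSTM

    cell = GConvLSTM(in_channels=4, out_channels=8, K=2)
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
    H, C = None, None  # zero-initialized inside forward() when not given
    for _ in range(5):
        X = torch.randn(10, 4)
        H, C = cell(X, edge_index, H=H, C=C)  # each of shape [10, 8]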
class GCLSTM(in_channels: int, out_channels: int, K: int, normalization: str = 'sym', bias: bool = True)
An implementation of the Integrated Graph Convolutional Long Short Term Memory Cell. For details see this paper: "GC-LSTM: Graph Convolution Embedded LSTM for Dynamic Link Prediction."
- Parameters
in_channels (int) – Number of input features.
out_channels (int) – Number of output features.
K (int) – Chebyshev filter size \(K\).
normalization (str, optional) – The normalization scheme for the graph Laplacian (default: "sym"):
1. None: No normalization \(\mathbf{L} = \mathbf{D} - \mathbf{A}\)
2. "sym": Symmetric normalization \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}\)
3. "rw": Random-walk normalization \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}\)
You need to pass lambda_max to the forward() method of this operator in case the normalization is non-symmetric. lambda_max should be a torch.Tensor of size [num_graphs] in a mini-batch scenario and a scalar/zero-dimensional tensor when operating on single graphs. You can pre-compute lambda_max via the torch_geometric.transforms.LaplacianLambdaMax transform.
bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)
forward(X: torch.FloatTensor, edge_index: torch.LongTensor, edge_weight: Optional[torch.FloatTensor] = None, H: Optional[torch.FloatTensor] = None, C: Optional[torch.FloatTensor] = None) → torch.FloatTensor
Making a forward pass. If edge weights are not present the forward pass defaults to an unweighted graph. If the hidden state and cell state matrices are not present when the forward pass is called these are initialized with zeros.
- Arg types:
X (PyTorch Float Tensor) - Node features.
edge_index (PyTorch Long Tensor) - Graph edge indices.
edge_weight (PyTorch Float Tensor, optional) - Edge weight vector.
H (PyTorch Float Tensor, optional) - Hidden state matrix for all nodes.
C (PyTorch Float Tensor, optional) - Cell state matrix for all nodes.
- Return types:
H (PyTorch Float Tensor) - Hidden state matrix for all nodes.
C (PyTorch Float Tensor) - Cell state matrix for all nodes.
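Example (a minimal sketch with assumed toy edge weights; the call pattern mirrors GConvLSTM):

    import torch
    from torch_geometric_temporal.nn.recurrent import GCLSTM

    cell = GCLSTM(in_channels=4, out_channels=8, K=1)
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
    edge_weight = torch.tensor([0.5, 1.0, 0.8])  # one weight per edge
    H, C = None, None
    for _ in range(5):
        X = torch.randn(10, 4)
        H, C = cell(X, edge_index, edge_weight, H, C)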
class LRGCN(in_channels: int, out_channels: int, num_relations: int, num_bases: int)
An implementation of the Long Short Term Memory Relational Graph Convolution Layer. For details see this paper: "Predicting Path Failure In Time-Evolving Graphs."
- Parameters
in_channels (int) – Number of input features.
out_channels (int) – Number of output features.
num_relations (int) – Number of relations.
num_bases (int) – Number of bases.
forward(X: torch.FloatTensor, edge_index: torch.LongTensor, edge_type: torch.LongTensor, H: Optional[torch.FloatTensor] = None, C: Optional[torch.FloatTensor] = None) → torch.FloatTensor
Making a forward pass. If the hidden state and cell state matrices are not present when the forward pass is called these are initialized with zeros.
- Arg types:
X (PyTorch Float Tensor) - Node features.
edge_index (PyTorch Long Tensor) - Graph edge indices.
edge_type (PyTorch Long Tensor) - Edge type vector.
H (PyTorch Float Tensor, optional) - Hidden state matrix for all nodes.
C (PyTorch Float Tensor, optional) - Cell state matrix for all nodes.
- Return types:
H (PyTorch Float Tensor) - Hidden state matrix for all nodes.
C (PyTorch Float Tensor) - Cell state matrix for all nodes.
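Example (a minimal sketch; the toy graph and the relation ids are illustrative assumptions):

    import torch
    from torch_geometric_temporal.nn.recurrent import LRGCN

    cell = LRGCN(in_channels=4, out_channels=8, num_relations=3, num_bases=2)
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
    edge_type = torch.tensor([0, 2, 1])  # relation id per edge, in [0, num_relations)
    H, C = None, None
    for _ in range(5):
        X = torch.randn(10, 4)
        H, C = cell(X, edge_index, edge_type, H=H, C=C)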
class DyGrEncoder(conv_out_channels: int, conv_num_layers: int, conv_aggr: str, lstm_out_channels: int, lstm_num_layers: int)
An implementation of the integrated Gated Graph Convolution Long Short Term Memory Layer. For details see this paper: "Predictive Temporal Embedding of Dynamic Graphs."
- Parameters
conv_out_channels (int) – Number of output channels for the gated graph convolution.
conv_num_layers (int) – Number of gated graph convolution layers.
conv_aggr (str) – Aggregation scheme for the gated graph convolution ("add", "mean" or "max").
lstm_out_channels (int) – Number of LSTM channels.
lstm_num_layers (int) – Number of LSTM layers.
forward(X: torch.FloatTensor, edge_index: torch.LongTensor, edge_weight: Optional[torch.FloatTensor] = None, H: Optional[torch.FloatTensor] = None, C: Optional[torch.FloatTensor] = None) → torch.FloatTensor
Making a forward pass. If the hidden state and cell state matrices are not present when the forward pass is called these are initialized with zeros.
- Arg types:
X (PyTorch Float Tensor) - Node features.
edge_index (PyTorch Long Tensor) - Graph edge indices.
edge_weight (PyTorch Float Tensor, optional) - Edge weight vector.
H (PyTorch Float Tensor, optional) - Hidden state matrix for all nodes.
C (PyTorch Float Tensor, optional) - Cell state matrix for all nodes.
- Return types:
H_tilde (PyTorch Float Tensor) - Output matrix for all nodes.
H (PyTorch Float Tensor) - Hidden state matrix for all nodes.
C (PyTorch Float Tensor) - Cell state matrix for all nodes.
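Example (a minimal sketch; the channel sizes and toy graph are illustrative assumptions):

    import torch
    from torch_geometric_temporal.nn.recurrent import DyGrEncoder

    encoder = DyGrEncoder(conv_out_channels=8, conv_num_layers=1, conv_aggr="add",
                          lstm_out_channels=16, lstm_num_layers=1)
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
    H, C = None, None
    for _ in range(5):
        X = torch.randn(10, 4)  # feature dim should not exceed conv_out_channels
        H_tilde, H, C = encoder(X, edge_index, H=H, C=C)  # H_tilde: per-node LSTM output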
class EvolveGCNH(num_of_nodes: int, in_channels: int, improved: bool = False, cached: bool = False, normalize: bool = True, add_self_loops: bool = True)
An implementation of the Evolving Graph Convolutional Hidden Layer. For details see this paper: "EvolveGCN: Evolving Graph Convolutional Networks for Dynamic Graphs."
- Parameters
num_of_nodes (int) – Number of vertices.
in_channels (int) – Number of filters.
improved (bool, optional) – If set to True, the layer computes \(\mathbf{\hat{A}}\) as \(\mathbf{A} + 2\mathbf{I}\). (default: False)
cached (bool, optional) – If set to True, the layer will cache the computation of \(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2}\) on first execution, and will use the cached version for further executions. This parameter should only be set to True in transductive learning scenarios. (default: False)
normalize (bool, optional) – Whether to add self-loops and apply symmetric normalization. (default: True)
add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)
forward(X: torch.FloatTensor, edge_index: torch.LongTensor, edge_weight: Optional[torch.FloatTensor] = None) → torch.FloatTensor
Making a forward pass.
- Arg types:
X (PyTorch Float Tensor) - Node embedding.
edge_index (PyTorch Long Tensor) - Graph edge indices.
edge_weight (PyTorch Float Tensor, optional) - Edge weight vector.
- Return types:
X (PyTorch Float Tensor) - Output matrix for all nodes.
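Example (a minimal sketch; unlike the recurrent cells above, the layer evolves its GCN weights internally across calls, so no hidden state is passed explicitly):

    import torch
    from torch_geometric_temporal.nn.recurrent import EvolveGCNH

    layer = EvolveGCNH(num_of_nodes=10, in_channels=4)
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
    for _ in range(5):
        X = torch.randn(10, 4)  # fresh snapshot features
        out = layer(X, edge_index)  # [10, 4]: output keeps the in_channels dimension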
class EvolveGCNO(in_channels: int, improved: bool = False, cached: bool = False, normalize: bool = True, add_self_loops: bool = True)
An implementation of the Evolving Graph Convolutional Layer without a hidden layer. For details see this paper: "EvolveGCN: Evolving Graph Convolutional Networks for Dynamic Graphs."
- Parameters
in_channels (int) – Number of filters.
improved (bool, optional) – If set to True, the layer computes \(\mathbf{\hat{A}}\) as \(\mathbf{A} + 2\mathbf{I}\). (default: False)
cached (bool, optional) – If set to True, the layer will cache the computation of \(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2}\) on first execution, and will use the cached version for further executions. This parameter should only be set to True in transductive learning scenarios. (default: False)
normalize (bool, optional) – Whether to add self-loops and apply symmetric normalization. (default: True)
add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)
forward(X: torch.FloatTensor, edge_index: torch.LongTensor, edge_weight: Optional[torch.FloatTensor] = None) → torch.FloatTensor
Making a forward pass.
- Arg types:
X (PyTorch Float Tensor) - Node embedding.
edge_index (PyTorch Long Tensor) - Graph edge indices.
edge_weight (PyTorch Float Tensor, optional) - Edge weight vector.
- Return types:
X (PyTorch Float Tensor) - Output matrix for all nodes.
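Example (a minimal sketch; usage mirrors EvolveGCNH except that num_of_nodes is not required):

    import torch
    from torch_geometric_temporal.nn.recurrent import EvolveGCNO

    layer = EvolveGCNO(in_channels=4)
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
    for _ in range(5):
        out = layer(torch.randn(10, 4), edge_index)  # [10, 4], weights evolved internally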
class DCRNN(in_channels: int, out_channels: int, K: int, bias: bool = True)
An implementation of the Diffusion Convolutional Gated Recurrent Unit. For details see: "Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting"
- Parameters
in_channels (int) – Number of input features.
out_channels (int) – Number of output features.
K (int) – Filter size \(K\).
bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)
forward(X: torch.FloatTensor, edge_index: torch.LongTensor, edge_weight: Optional[torch.FloatTensor] = None, H: Optional[torch.FloatTensor] = None) → torch.FloatTensor
Making a forward pass. If edge weights are not present the forward pass defaults to an unweighted graph. If the hidden state matrix is not present when the forward pass is called it is initialized with zeros.
- Arg types:
X (PyTorch Float Tensor) - Node features.
edge_index (PyTorch Long Tensor) - Graph edge indices.
edge_weight (PyTorch Float Tensor, optional) - Edge weight vector.
H (PyTorch Float Tensor, optional) - Hidden state matrix for all nodes.
- Return types:
H (PyTorch Float Tensor) - Hidden state matrix for all nodes.
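Example (a minimal sketch; the toy graph and sizes are illustrative assumptions):

    import torch
    from torch_geometric_temporal.nn.recurrent import DCRNN

    cell = DCRNN(in_channels=4, out_channels=8, K=2)
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
    edge_weight = torch.ones(edge_index.size(1))
    H = None
    for _ in range(5):
        X = torch.randn(10, 4)
        H = cell(X, edge_index, edge_weight, H)  # H: [10, 8]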
Discrete Temporal Graph Convolutional Layers
class STConv(num_nodes: int, in_channels: int, hidden_channels: int, out_channels: int, kernel_size: int, K: int, normalization: str = 'sym', bias: bool = True)
Spatio-temporal convolution block using ChebConv Graph Convolutions. For details see: "Spatio-Temporal Graph Convolutional Networks: A Deep Learning Framework for Traffic Forecasting"
NB. The ST-Conv block contains two temporal convolutions (TemporalConv) with kernel size k. Hence for an input sequence of length m, the output sequence will be length m-2(k-1).
- Parameters
in_channels (int) – Number of input features.
hidden_channels (int) – Number of hidden units output by the graph convolution block.
out_channels (int) – Number of output features.
kernel_size (int) – Size of the temporal convolution kernel.
K (int) – Chebyshev filter size \(K\).
normalization (str, optional) – The normalization scheme for the graph Laplacian (default: "sym"):
1. None: No normalization \(\mathbf{L} = \mathbf{D} - \mathbf{A}\)
2. "sym": Symmetric normalization \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}\)
3. "rw": Random-walk normalization \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}\)
You need to pass lambda_max to the forward() method of this operator in case the normalization is non-symmetric. lambda_max should be a torch.Tensor of size [num_graphs] in a mini-batch scenario and a scalar/zero-dimensional tensor when operating on single graphs. You can pre-compute lambda_max via the torch_geometric.transforms.LaplacianLambdaMax transform.
bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)
forward(X, edge_index, edge_weight)
Forward pass. If edge weights are not present the forward pass defaults to an unweighted graph.
- Arg types:
X (PyTorch Float Tensor) - Sequence of node features of shape (batch_size, input_time_steps, num_nodes, in_channels).
edge_index (PyTorch Long Tensor) - Graph edge indices.
edge_weight (PyTorch Float Tensor, optional) - Edge weight vector.
- Return types:
Out (PyTorch Float Tensor) - Sequence of node features.
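Example (a minimal sketch; the import path is an assumption that has moved between releases, and the toy shapes are illustrative):

    import torch
    from torch_geometric_temporal.nn.attention import STConv  # older releases: nn.convolutional

    block = STConv(num_nodes=10, in_channels=4, hidden_channels=16,
                   out_channels=8, kernel_size=3, K=2)
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
    X = torch.randn(2, 9, 10, 4)  # (batch_size, input_time_steps, num_nodes, in_channels)
    out = block(X, edge_index)  # time dimension shrinks to 9 - 2*(3-1) = 5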
Auxiliary Graph Convolutional Layers
class TemporalConv(in_channels, out_channels, kernel_size=3)
Temporal convolution block applied to nodes in the STGCN layer. For details see: "Spatio-Temporal Graph Convolutional Networks: A Deep Learning Framework for Traffic Forecasting"
Based on the temporal convolution introduced in "Convolutional Sequence to Sequence Learning" (https://arxiv.org/abs/1705.03122).
NB. Given an input sequence of length m and a kernel size of k the output sequence will have length m-(k-1).
- Parameters
in_channels (int) – Number of input features.
out_channels (int) – Number of output features.
kernel_size (int) – Convolutional kernel size. (default: 3)
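Example (a minimal sketch; import path assumed as for STConv, toy shapes illustrative):

    import torch
    from torch_geometric_temporal.nn.attention import TemporalConv  # older releases: nn.convolutional

    tconv = TemporalConv(in_channels=4, out_channels=8, kernel_size=3)
    X = torch.randn(2, 9, 10, 4)  # (batch_size, input_time_steps, num_nodes, in_channels)
    out = tconv(X)  # (2, 7, 10, 8): the time dimension shrinks by k - 1 = 2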
class DConv(in_channels, out_channels, K, bias=True)
An implementation of the Diffusion Convolution Layer. For details see: "Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting"
- Parameters
in_channels (int) – Number of input features.
out_channels (int) – Number of output features.
K (int) – Filter size \(K\).
bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)
forward(x, edge_index, edge_weight)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
message(x_j, norm)
Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take any argument as input which was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.
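Example (a minimal sketch; the import path is an assumption, since DConv is defined alongside DCRNN):

    import torch
    from torch_geometric_temporal.nn.recurrent.dcrnn import DConv  # path is an assumption

    conv = DConv(in_channels=4, out_channels=8, K=2)
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
    edge_weight = torch.ones(edge_index.size(1))  # edge_weight is required by forward()
    out = conv(torch.randn(10, 4), edge_index, edge_weight)  # [10, 8]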