PyTorch Geometric Temporal

Recurrent Graph Convolutional Layers

class GConvGRU(in_channels: int, out_channels: int, K: int, normalization: str = 'sym', bias: bool = True)[source]

An implementation of the Chebyshev Graph Convolutional Gated Recurrent Unit Cell. For details see this paper: “Structured Sequence Modeling with Graph Convolutional Recurrent Networks.”

Parameters
  • in_channels (int) – Number of input features.

  • out_channels (int) – Number of output features.

  • K (int) – Chebyshev filter size \(K\).

  • normalization (str, optional) –

    The normalization scheme for the graph Laplacian (default: "sym"):

    1. None: No normalization \(\mathbf{L} = \mathbf{D} - \mathbf{A}\)

    2. "sym": Symmetric normalization \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}\)

    3. "rw": Random-walk normalization \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}\)

    You need to pass lambda_max to the forward() method of this operator in case the normalization is non-symmetric. lambda_max should be a torch.Tensor of size [num_graphs] in a mini-batch scenario and a scalar/zero-dimensional tensor when operating on single graphs. You can pre-compute lambda_max via the torch_geometric.transforms.LaplacianLambdaMax transform.

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

forward(X: torch.FloatTensor, edge_index: torch.LongTensor, edge_weight: Optional[torch.FloatTensor] = None, H: Optional[torch.FloatTensor] = None, lambda_max: Optional[torch.Tensor] = None) → torch.FloatTensor[source]

Making a forward pass. If edge weights are not present the forward pass defaults to an unweighted graph. If the hidden state matrix is not present when the forward pass is called it is initialized with zeros.

Arg types:
  • X (PyTorch Float Tensor) - Node features.

  • edge_index (PyTorch Long Tensor) - Graph edge indices.

  • edge_weight (PyTorch Float Tensor, optional) - Edge weight vector.

  • H (PyTorch Float Tensor, optional) - Hidden state matrix for all nodes.

  • lambda_max (PyTorch Tensor, optional but mandatory if normalization is not sym) - Largest eigenvalue of Laplacian.

Return types:
  • H (PyTorch Float Tensor) - Hidden state matrix for all nodes.

training: bool
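
A minimal usage sketch (the toy graph, feature sizes and variable names below are illustrative, not part of the API): the cell is applied once per snapshot while the returned hidden state is fed back in.

    import torch
    from torch_geometric_temporal.nn.recurrent import GConvGRU

    # Toy example: 4 nodes, a directed cycle, 8 input features per node.
    num_nodes, in_channels, out_channels = 4, 8, 16
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 3, 0]], dtype=torch.long)
    snapshots = [torch.randn(num_nodes, in_channels) for _ in range(5)]

    cell = GConvGRU(in_channels=in_channels, out_channels=out_channels, K=2)

    H = None  # initialized to zeros inside the first forward pass
    for X in snapshots:
        H = cell(X, edge_index, H=H)  # H: (num_nodes, out_channels)
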
class GConvLSTM(in_channels: int, out_channels: int, K: int, normalization: str = 'sym', bias: bool = True)[source]

An implementation of the Chebyshev Graph Convolutional Long Short Term Memory Cell. For details see this paper: “Structured Sequence Modeling with Graph Convolutional Recurrent Networks.”

Parameters
  • in_channels (int) – Number of input features.

  • out_channels (int) – Number of output features.

  • K (int) – Chebyshev filter size \(K\).

  • normalization (str, optional) –

    The normalization scheme for the graph Laplacian (default: "sym"):

    1. None: No normalization \(\mathbf{L} = \mathbf{D} - \mathbf{A}\)

    2. "sym": Symmetric normalization \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}\)

    3. "rw": Random-walk normalization \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}\)

    You need to pass lambda_max to the forward() method of this operator in case the normalization is non-symmetric. lambda_max should be a torch.Tensor of size [num_graphs] in a mini-batch scenario and a scalar/zero-dimensional tensor when operating on single graphs. You can pre-compute lambda_max via the torch_geometric.transforms.LaplacianLambdaMax transform.

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

forward(X: torch.FloatTensor, edge_index: torch.LongTensor, edge_weight: Optional[torch.FloatTensor] = None, H: Optional[torch.FloatTensor] = None, C: Optional[torch.FloatTensor] = None, lambda_max: Optional[torch.Tensor] = None) → torch.FloatTensor[source]

Making a forward pass. If edge weights are not present the forward pass defaults to an unweighted graph. If the hidden state and cell state matrices are not present when the forward pass is called these are initialized with zeros.

Arg types:
  • X (PyTorch Float Tensor) - Node features.

  • edge_index (PyTorch Long Tensor) - Graph edge indices.

  • edge_weight (PyTorch Float Tensor, optional) - Edge weight vector.

  • H (PyTorch Float Tensor, optional) - Hidden state matrix for all nodes.

  • C (PyTorch Float Tensor, optional) - Cell state matrix for all nodes.

  • lambda_max (PyTorch Tensor, optional but mandatory if normalization is not sym) - Largest eigenvalue of Laplacian.

Return types:
  • H (PyTorch Float Tensor) - Hidden state matrix for all nodes.

  • C (PyTorch Float Tensor) - Cell state matrix for all nodes.

training: bool
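
A minimal sketch of threading both recurrent states through a snapshot sequence (toy graph and sizes assumed):

    import torch
    from torch_geometric_temporal.nn.recurrent import GConvLSTM

    num_nodes, in_channels, out_channels = 4, 8, 16
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 3, 0]], dtype=torch.long)

    cell = GConvLSTM(in_channels, out_channels, K=2)

    H, C = None, None  # both initialized to zeros on the first call
    for X in (torch.randn(num_nodes, in_channels) for _ in range(5)):
        H, C = cell(X, edge_index, H=H, C=C)  # each: (num_nodes, out_channels)
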
class GCLSTM(in_channels: int, out_channels: int, K: int, normalization: str = 'sym', bias: bool = True)[source]

An implementation of the Integrated Graph Convolutional Long Short Term Memory Cell. For details see this paper: “GC-LSTM: Graph Convolution Embedded LSTM for Dynamic Link Prediction.”

Parameters
  • in_channels (int) – Number of input features.

  • out_channels (int) – Number of output features.

  • K (int) – Chebyshev filter size \(K\).

  • normalization (str, optional) –

    The normalization scheme for the graph Laplacian (default: "sym"):

    1. None: No normalization \(\mathbf{L} = \mathbf{D} - \mathbf{A}\)

    2. "sym": Symmetric normalization \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}\)

    3. "rw": Random-walk normalization \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}\)

    You need to pass lambda_max to the forward() method of this operator in case the normalization is non-symmetric. lambda_max should be a torch.Tensor of size [num_graphs] in a mini-batch scenario and a scalar/zero-dimensional tensor when operating on single graphs. You can pre-compute lambda_max via the torch_geometric.transforms.LaplacianLambdaMax transform.

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

forward(X: torch.FloatTensor, edge_index: torch.LongTensor, edge_weight: Optional[torch.FloatTensor] = None, H: Optional[torch.FloatTensor] = None, C: Optional[torch.FloatTensor] = None, lambda_max: Optional[torch.Tensor] = None) → torch.FloatTensor[source]

Making a forward pass. If edge weights are not present the forward pass defaults to an unweighted graph. If the hidden state and cell state matrices are not present when the forward pass is called these are initialized with zeros.

Arg types:
  • X (PyTorch Float Tensor) - Node features.

  • edge_index (PyTorch Long Tensor) - Graph edge indices.

  • edge_weight (PyTorch Float Tensor, optional) - Edge weight vector.

  • H (PyTorch Float Tensor, optional) - Hidden state matrix for all nodes.

  • C (PyTorch Float Tensor, optional) - Cell state matrix for all nodes.

  • lambda_max (PyTorch Tensor, optional but mandatory if normalization is not sym) - Largest eigenvalue of Laplacian.

Return types:
  • H (PyTorch Float Tensor) - Hidden state matrix for all nodes.

  • C (PyTorch Float Tensor) - Cell state matrix for all nodes.

training: bool
class LRGCN(in_channels: int, out_channels: int, num_relations: int, num_bases: int)[source]

An implementation of the Long Short Term Memory Relational Graph Convolution Layer. For details see this paper: “Predicting Path Failure In Time-Evolving Graphs.”

Parameters
  • in_channels (int) – Number of input features.

  • out_channels (int) – Number of output features.

  • num_relations (int) – Number of relations.

  • num_bases (int) – Number of bases.

forward(X: torch.FloatTensor, edge_index: torch.LongTensor, edge_type: torch.LongTensor, H: Optional[torch.FloatTensor] = None, C: Optional[torch.FloatTensor] = None) → torch.FloatTensor[source]

Making a forward pass. If the hidden state and cell state matrices are not present when the forward pass is called these are initialized with zeros.

Arg types:
  • X (PyTorch Float Tensor) - Node features.

  • edge_index (PyTorch Long Tensor) - Graph edge indices.

  • edge_type (PyTorch Long Tensor) - Edge type vector.

  • H (PyTorch Float Tensor, optional) - Hidden state matrix for all nodes.

  • C (PyTorch Float Tensor, optional) - Cell state matrix for all nodes.

Return types:
  • H (PyTorch Float Tensor) - Hidden state matrix for all nodes.

  • C (PyTorch Float Tensor) - Cell state matrix for all nodes.

training: bool
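
A minimal sketch (toy graph with two hypothetical relation types; edge_type carries one relation id per edge):

    import torch
    from torch_geometric_temporal.nn.recurrent import LRGCN

    num_nodes, in_channels, out_channels = 4, 8, 16
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 3, 0]], dtype=torch.long)
    edge_type = torch.tensor([0, 1, 0, 1])  # relation id per edge

    cell = LRGCN(in_channels, out_channels, num_relations=2, num_bases=2)

    H, C = None, None
    for X in (torch.randn(num_nodes, in_channels) for _ in range(5)):
        H, C = cell(X, edge_index, edge_type, H=H, C=C)
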
class DyGrEncoder(conv_out_channels: int, conv_num_layers: int, conv_aggr: str, lstm_out_channels: int, lstm_num_layers: int)[source]

An implementation of the integrated Gated Graph Convolution Long Short Term Memory Layer. For details see this paper: “Predictive Temporal Embedding of Dynamic Graphs.”

Parameters
  • conv_out_channels (int) – Number of output channels for the GGCN.

  • conv_num_layers (int) – Number of Gated Graph Convolutions.

  • conv_aggr (str) – Aggregation scheme to use ("add", "mean", "max").

  • lstm_out_channels (int) – Number of LSTM hidden channels.

  • lstm_num_layers (int) – Number of stacked LSTM layers.

forward(X: torch.FloatTensor, edge_index: torch.LongTensor, edge_weight: Optional[torch.FloatTensor] = None, H: Optional[torch.FloatTensor] = None, C: Optional[torch.FloatTensor] = None) → torch.FloatTensor[source]

Making a forward pass. If the hidden state and cell state matrices are not present when the forward pass is called these are initialized with zeros.

Arg types:
  • X (PyTorch Float Tensor) - Node features.

  • edge_index (PyTorch Long Tensor) - Graph edge indices.

  • edge_weight (PyTorch Float Tensor, optional) - Edge weight vector.

  • H (PyTorch Float Tensor, optional) - Hidden state matrix for all nodes.

  • C (PyTorch Float Tensor, optional) - Cell state matrix for all nodes.

Return types:
  • H_tilde (PyTorch Float Tensor) - Output matrix for all nodes.

  • H (PyTorch Float Tensor) - Hidden state matrix for all nodes.

  • C (PyTorch Float Tensor) - Cell state matrix for all nodes.

training: bool
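
A minimal sketch (toy sizes assumed; lstm_num_layers=1 keeps the state shapes simple). Note the underlying gated graph convolution expects the input feature dimension to be at most conv_out_channels:

    import torch
    from torch_geometric_temporal.nn.recurrent import DyGrEncoder

    num_nodes = 4
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 3, 0]], dtype=torch.long)

    # 8-dimensional inputs with conv_out_channels=16 satisfy the constraint.
    encoder = DyGrEncoder(conv_out_channels=16, conv_num_layers=2, conv_aggr="add",
                          lstm_out_channels=32, lstm_num_layers=1)

    H, C = None, None
    for X in (torch.randn(num_nodes, 8) for _ in range(5)):
        H_tilde, H, C = encoder(X, edge_index, H=H, C=C)
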
class EvolveGCNH(num_of_nodes: int, in_channels: int, improved: bool = False, cached: bool = False, normalize: bool = True, add_self_loops: bool = True)[source]

An implementation of the Evolving Graph Convolutional Hidden Layer. For details see this paper: “EvolveGCN: Evolving Graph Convolutional Networks for Dynamic Graphs.”

Parameters
  • num_of_nodes (int) – Number of vertices.

  • in_channels (int) – Number of filters.

  • improved (bool, optional) – If set to True, the layer computes \(\mathbf{\hat{A}}\) as \(\mathbf{A} + 2\mathbf{I}\). (default: False)

  • cached (bool, optional) – If set to True, the layer will cache the computation of \(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2}\) on first execution, and will use the cached version for further executions. This parameter should only be set to True in transductive learning scenarios. (default: False)

  • normalize (bool, optional) – Whether to add self-loops and apply symmetric normalization. (default: True)

  • add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

forward(X: torch.FloatTensor, edge_index: torch.LongTensor, edge_weight: Optional[torch.FloatTensor] = None) → torch.FloatTensor[source]

Making a forward pass.

Arg types:
  • X (PyTorch Float Tensor) - Node embedding.

  • edge_index (PyTorch Long Tensor) - Graph edge indices.

  • edge_weight (PyTorch Float Tensor, optional) - Edge weight vector.

Return types:
  • X (PyTorch Float Tensor) - Output matrix for all nodes.

reset_parameters()[source]
training: bool
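
A minimal sketch (toy graph assumed). The layer is stateful: its GCN weights are evolved by an internal recurrent unit across successive calls, so it is simply called once per snapshot, and the output feature dimension equals in_channels:

    import torch
    from torch_geometric_temporal.nn.recurrent import EvolveGCNH

    num_nodes, in_channels = 4, 8
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 3, 0]], dtype=torch.long)

    layer = EvolveGCNH(num_of_nodes=num_nodes, in_channels=in_channels)

    for X in (torch.randn(num_nodes, in_channels) for _ in range(5)):
        X_out = layer(X, edge_index)  # (num_nodes, in_channels)
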
class EvolveGCNO(in_channels: int, improved: bool = False, cached: bool = False, normalize: bool = True, add_self_loops: bool = True)[source]

An implementation of the Evolving Graph Convolutional Layer without a hidden layer. For details see this paper: “EvolveGCN: Evolving Graph Convolutional Networks for Dynamic Graphs.”

Parameters
  • in_channels (int) – Number of filters.

  • improved (bool, optional) – If set to True, the layer computes \(\mathbf{\hat{A}}\) as \(\mathbf{A} + 2\mathbf{I}\). (default: False)
  • cached (bool, optional) – If set to True, the layer will cache the computation of \(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2}\) on first execution, and will use the cached version for further executions. This parameter should only be set to True in transductive learning scenarios. (default: False)

  • normalize (bool, optional) – Whether to add self-loops and apply symmetric normalization. (default: True)

  • add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

forward(X: torch.FloatTensor, edge_index: torch.LongTensor, edge_weight: Optional[torch.FloatTensor] = None) → torch.FloatTensor[source]

Making a forward pass.

Arg types:
  • X (PyTorch Float Tensor) - Node embedding.

  • edge_index (PyTorch Long Tensor) - Graph edge indices.

  • edge_weight (PyTorch Float Tensor, optional) - Edge weight vector.

Return types:
  • X (PyTorch Float Tensor) - Output matrix for all nodes.

reset_parameters()[source]
training: bool
class GCNConv_Fixed_W(in_channels: int, out_channels: int, improved: bool = False, cached: bool = False, add_self_loops: bool = True, normalize: bool = True, **kwargs)[source]

The graph convolutional operator adapted from the “Semi-Supervised Classification with Graph Convolutional Networks” paper, with non-trainable weights:

\(\mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta},\)

where \(\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}\) denotes the adjacency matrix with inserted self-loops and \(\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}\) its diagonal degree matrix. The adjacency matrix can include values other than 1 representing edge weights via the optional edge_weight tensor. Its node-wise formulation is given by:

\(\mathbf{x}^{\prime}_i = \mathbf{\Theta} \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{e_{j,i}}{\sqrt{\hat{d}_j \hat{d}_i}} \mathbf{x}_j,\)

with \(\hat{d}_i = 1 + \sum_{j \in \mathcal{N}(i)} e_{j,i}\), where \(e_{j,i}\) denotes the edge weight from source node j to target node i. (default: 1.0)

Parameters
  • in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.

  • out_channels (int) – Size of each output sample.

  • improved (bool, optional) – If set to True, the layer computes \(\mathbf{\hat{A}}\) as \(\mathbf{A} + 2\mathbf{I}\). (default: False)

  • cached (bool, optional) – If set to True, the layer will cache the computation of \(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2}\) on first execution, and will use the cached version for further executions. This parameter should only be set to True in transductive learning scenarios. (default: False)

  • add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

  • normalize (bool, optional) – Whether to add self-loops and compute symmetric normalization coefficients on the fly. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

forward(W: torch.FloatTensor, x: torch.Tensor, edge_index: Union[torch.Tensor, torch_sparse.tensor.SparseTensor], edge_weight: Optional[torch.Tensor] = None) → torch.Tensor[source]
message(x_j: torch.Tensor, edge_weight: Optional[torch.Tensor]) → torch.Tensor[source]

Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take any argument as input which was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.

reset_parameters()[source]
class TGCN(in_channels: int, out_channels: int, improved: bool = False, cached: bool = False, add_self_loops: bool = True)[source]

An implementation of the Temporal Graph Convolutional Gated Recurrent Cell. For details see this paper: “T-GCN: A Temporal Graph Convolutional Network for Traffic Prediction.”

Parameters
  • in_channels (int) – Number of input features.

  • out_channels (int) – Number of output features.

  • improved (bool) – Stronger self loops. Default is False.

  • cached (bool) – Caching the message weights. Default is False.

  • add_self_loops (bool) – Adding self-loops for smoothing. Default is True.

forward(X: torch.FloatTensor, edge_index: torch.LongTensor, edge_weight: Optional[torch.FloatTensor] = None, H: Optional[torch.FloatTensor] = None) → torch.FloatTensor[source]

Making a forward pass. If edge weights are not present the forward pass defaults to an unweighted graph. If the hidden state matrix is not present when the forward pass is called it is initialized with zeros.

Arg types:
  • X (PyTorch Float Tensor) - Node features.

  • edge_index (PyTorch Long Tensor) - Graph edge indices.

  • edge_weight (PyTorch Float Tensor, optional) - Edge weight vector.

  • H (PyTorch Float Tensor, optional) - Hidden state matrix for all nodes.

Return types:
  • H (PyTorch Float Tensor) - Hidden state matrix for all nodes.

training: bool
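
A minimal sketch (toy graph and sizes assumed), following the same recurrent pattern as the cells above:

    import torch
    from torch_geometric_temporal.nn.recurrent import TGCN

    num_nodes, in_channels, out_channels = 4, 8, 16
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 3, 0]], dtype=torch.long)

    cell = TGCN(in_channels, out_channels)

    H = None
    for X in (torch.randn(num_nodes, in_channels) for _ in range(5)):
        H = cell(X, edge_index, H=H)  # (num_nodes, out_channels)
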
class TGCN2(in_channels: int, out_channels: int, batch_size: int, improved: bool = False, cached: bool = False, add_self_loops: bool = True)[source]

An implementation of the Temporal Graph Convolutional Gated Recurrent Cell with batch support. For details see this paper: “T-GCN: A Temporal Graph Convolutional Network for Traffic Prediction.”

Parameters
  • in_channels (int) – Number of input features.

  • out_channels (int) – Number of output features.

  • batch_size (int) – Size of the batch.

  • improved (bool) – Stronger self loops. Default is False.

  • cached (bool) – Caching the message weights. Default is False.

  • add_self_loops (bool) – Adding self-loops for smoothing. Default is True.

forward(X: torch.FloatTensor, edge_index: torch.LongTensor, edge_weight: Optional[torch.FloatTensor] = None, H: Optional[torch.FloatTensor] = None) → torch.FloatTensor[source]

Making a forward pass. If edge weights are not present the forward pass defaults to an unweighted graph. If the hidden state matrix is not present when the forward pass is called it is initialized with zeros.

Arg types:
  • X (PyTorch Float Tensor) - Node features.

  • edge_index (PyTorch Long Tensor) - Graph edge indices.

  • edge_weight (PyTorch Float Tensor, optional) - Edge weight vector.

  • H (PyTorch Float Tensor, optional) - Hidden state matrix for all nodes.

Return types:
  • H (PyTorch Float Tensor) - Hidden state matrix for all nodes.

training: bool
class A3TGCN(in_channels: int, out_channels: int, periods: int, improved: bool = False, cached: bool = False, add_self_loops: bool = True)[source]

An implementation of the Attention Temporal Graph Convolutional Cell. For details see this paper: “A3T-GCN: Attention Temporal Graph Convolutional Network for Traffic Forecasting.”

Parameters
  • in_channels (int) – Number of input features.

  • out_channels (int) – Number of output features.

  • periods (int) – Number of time periods.

  • improved (bool) – Stronger self loops (default False).

  • cached (bool) – Caching the message weights (default False).

  • add_self_loops (bool) – Adding self-loops for smoothing (default True).

forward(X: torch.FloatTensor, edge_index: torch.LongTensor, edge_weight: Optional[torch.FloatTensor] = None, H: Optional[torch.FloatTensor] = None) → torch.FloatTensor[source]

Making a forward pass. If edge weights are not present the forward pass defaults to an unweighted graph. If the hidden state matrix is not present when the forward pass is called it is initialized with zeros.

Arg types:
  • X (PyTorch Float Tensor): Node features for T time periods.

  • edge_index (PyTorch Long Tensor): Graph edge indices.

  • edge_weight (PyTorch Float Tensor, optional): Edge weight vector.

  • H (PyTorch Float Tensor, optional): Hidden state matrix for all nodes.

Return types:
  • H (PyTorch Float Tensor): Hidden state matrix for all nodes.

training: bool
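
A minimal sketch (toy sizes assumed). Unlike the single-snapshot cells above, A3TGCN consumes a whole window of periods time steps at once:

    import torch
    from torch_geometric_temporal.nn.recurrent import A3TGCN

    num_nodes, in_channels, periods = 4, 8, 12
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 3, 0]], dtype=torch.long)

    cell = A3TGCN(in_channels=in_channels, out_channels=16, periods=periods)

    # Node features for `periods` time steps: (num_nodes, in_channels, periods).
    X = torch.randn(num_nodes, in_channels, periods)
    H = cell(X, edge_index)  # (num_nodes, out_channels)
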
class A3TGCN2(in_channels: int, out_channels: int, periods: int, batch_size: int, improved: bool = False, cached: bool = False, add_self_loops: bool = True)[source]

An implementation of the Attention Temporal Graph Convolutional Cell with batch support. For details see this paper: “A3T-GCN: Attention Temporal Graph Convolutional Network for Traffic Forecasting.”

Parameters
  • in_channels (int) – Number of input features.

  • out_channels (int) – Number of output features.

  • periods (int) – Number of time periods.

  • improved (bool) – Stronger self loops (default False).

  • cached (bool) – Caching the message weights (default False).

  • add_self_loops (bool) – Adding self-loops for smoothing (default True).

forward(X: torch.FloatTensor, edge_index: torch.LongTensor, edge_weight: Optional[torch.FloatTensor] = None, H: Optional[torch.FloatTensor] = None) → torch.FloatTensor[source]

Making a forward pass. If edge weights are not present the forward pass defaults to an unweighted graph. If the hidden state matrix is not present when the forward pass is called it is initialized with zeros.

Arg types:
  • X (PyTorch Float Tensor): Node features for T time periods.

  • edge_index (PyTorch Long Tensor): Graph edge indices.

  • edge_weight (PyTorch Float Tensor, optional): Edge weight vector.

  • H (PyTorch Float Tensor, optional): Hidden state matrix for all nodes.

Return types:
  • H (PyTorch Float Tensor): Hidden state matrix for all nodes.

training: bool
class MPNNLSTM(in_channels: int, hidden_size: int, num_nodes: int, window: int, dropout: float)[source]

An implementation of the Message Passing Neural Network with Long Short Term Memory. For details see this paper: “Transfer Graph Neural Networks for Pandemic Forecasting.”

Parameters
  • in_channels (int) – Number of input features.

  • hidden_size (int) – Dimension of hidden representations.

  • num_nodes (int) – Number of nodes in the network.

  • window (int) – Number of past samples included in the input.

  • dropout (float) – Dropout rate.

forward(X: torch.FloatTensor, edge_index: torch.LongTensor, edge_weight: torch.FloatTensor) → torch.FloatTensor[source]

Making a forward pass through the whole architecture.

Arg types:
  • X (PyTorch FloatTensor) - Node features.

  • edge_index (PyTorch LongTensor) - Graph edge indices.

  • edge_weight (PyTorch FloatTensor, optional) - Edge weight vector.

Return types:
  • H (PyTorch FloatTensor) - The hidden representation of size 2*hidden_size + in_channels + window - 1 for each node.

training: bool
class DCRNN(in_channels: int, out_channels: int, K: int, bias: bool = True)[source]

An implementation of the Diffusion Convolutional Gated Recurrent Unit. For details see: “Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting”

Parameters
  • in_channels (int) – Number of input features.

  • out_channels (int) – Number of output features.

  • K (int) – Filter size \(K\).

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

forward(X: torch.FloatTensor, edge_index: torch.LongTensor, edge_weight: Optional[torch.FloatTensor] = None, H: Optional[torch.FloatTensor] = None) → torch.FloatTensor[source]

Making a forward pass. If edge weights are not present the forward pass defaults to an unweighted graph. If the hidden state matrix is not present when the forward pass is called it is initialized with zeros.

Arg types:
  • X (PyTorch Float Tensor) - Node features.

  • edge_index (PyTorch Long Tensor) - Graph edge indices.

  • edge_weight (PyTorch Float Tensor, optional) - Edge weight vector.

  • H (PyTorch Float Tensor, optional) - Hidden state matrix for all nodes.

Return types:
  • H (PyTorch Float Tensor) - Hidden state matrix for all nodes.

training: bool
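
A minimal sketch (toy graph assumed). Diffusion convolution operates on weighted transitions, so an edge_weight vector is supplied here:

    import torch
    from torch_geometric_temporal.nn.recurrent import DCRNN

    num_nodes, in_channels, out_channels = 4, 8, 16
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 3, 0]], dtype=torch.long)
    edge_weight = torch.rand(edge_index.size(1))

    cell = DCRNN(in_channels, out_channels, K=2)

    H = None
    for X in (torch.randn(num_nodes, in_channels) for _ in range(5)):
        H = cell(X, edge_index, edge_weight, H)  # (num_nodes, out_channels)
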
class AGCRN(number_of_nodes: int, in_channels: int, out_channels: int, K: int, embedding_dimensions: int)[source]

An implementation of the Adaptive Graph Convolutional Recurrent Unit. For details see: “Adaptive Graph Convolutional Recurrent Network for Traffic Forecasting”

Parameters
  • number_of_nodes (int) – Number of vertices.

  • in_channels (int) – Number of input features.

  • out_channels (int) – Number of output features.

  • K (int) – Filter size \(K\).

  • embedding_dimensions (int) – Number of node embedding dimensions.

forward(X: torch.FloatTensor, E: torch.FloatTensor, H: Optional[torch.FloatTensor] = None) → torch.FloatTensor[source]

Making a forward pass.

Arg types:
  • X (PyTorch Float Tensor) - Node feature matrix.

  • E (PyTorch Float Tensor) - Node embedding matrix.

  • H (PyTorch Float Tensor) - Node hidden state matrix. Default is None.

Return types:
  • H (PyTorch Float Tensor) - Hidden state matrix for all nodes.

training: bool
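
A minimal sketch (toy sizes assumed). No edge_index is passed: the adjacency is learned from the node embedding matrix E, and the cell operates on batched inputs:

    import torch
    from torch_geometric_temporal.nn.recurrent import AGCRN

    batch_size, num_nodes, in_channels, out_channels, emb_dim = 2, 4, 8, 16, 4

    cell = AGCRN(number_of_nodes=num_nodes, in_channels=in_channels,
                 out_channels=out_channels, K=2, embedding_dimensions=emb_dim)

    E = torch.randn(num_nodes, emb_dim)  # learned/initialized node embeddings

    H = None
    for X in (torch.randn(batch_size, num_nodes, in_channels) for _ in range(5)):
        H = cell(X, E, H)  # (batch_size, num_nodes, out_channels)
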

Temporal Graph Attention Layers

class STConv(num_nodes: int, in_channels: int, hidden_channels: int, out_channels: int, kernel_size: int, K: int, normalization: str = 'sym', bias: bool = True)[source]

Spatio-temporal convolution block using ChebConv Graph Convolutions. For details see: “Spatio-Temporal Graph Convolutional Networks: A Deep Learning Framework for Traffic Forecasting”

NB. The ST-Conv block contains two temporal convolutions (TemporalConv) with kernel size k. Hence for an input sequence of length m, the output sequence will be length m-2(k-1).

Parameters
  • in_channels (int) – Number of input features.

  • hidden_channels (int) – Number of hidden units output by the graph convolution block.

  • out_channels (int) – Number of output features.

  • kernel_size (int) – Size of the kernel considered.

  • K (int) – Chebyshev filter size \(K\).

  • normalization (str, optional) –

    The normalization scheme for the graph Laplacian (default: "sym"):

    1. None: No normalization \(\mathbf{L} = \mathbf{D} - \mathbf{A}\)

    2. "sym": Symmetric normalization \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}\)

    3. "rw": Random-walk normalization \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}\)

    You need to pass lambda_max to the forward() method of this operator in case the normalization is non-symmetric. lambda_max should be a torch.Tensor of size [num_graphs] in a mini-batch scenario and a scalar/zero-dimensional tensor when operating on single graphs. You can pre-compute lambda_max via the torch_geometric.transforms.LaplacianLambdaMax transform.

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

forward(X: torch.FloatTensor, edge_index: torch.LongTensor, edge_weight: Optional[torch.FloatTensor] = None) → torch.FloatTensor[source]

Forward pass. If edge weights are not present the forward pass defaults to an unweighted graph.

Arg types:
  • X (PyTorch FloatTensor) - Sequence of node features of shape (Batch size X Input time steps X Num nodes X In channels).

  • edge_index (PyTorch LongTensor) - Graph edge indices.

  • edge_weight (PyTorch FloatTensor, optional) - Edge weight vector.

Return types:
  • T (PyTorch FloatTensor) - Sequence of node features.

training: bool
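
A minimal sketch (toy sizes assumed) illustrating the sequence-length arithmetic from the note above: with m = 10 input steps and kernel size k = 3, the two temporal convolutions shrink the sequence to m - 2(k - 1) = 6 steps.

    import torch
    from torch_geometric_temporal.nn.attention import STConv

    num_nodes = 4
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 3, 0]], dtype=torch.long)

    block = STConv(num_nodes=num_nodes, in_channels=8, hidden_channels=16,
                   out_channels=16, kernel_size=3, K=2)

    # Input: (batch size, input time steps, num nodes, in channels).
    X = torch.randn(2, 10, num_nodes, 8)
    T = block(X, edge_index)
    print(T.shape)  # expected: torch.Size([2, 6, 4, 16])
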
class ASTGCN(nb_block: int, in_channels: int, K: int, nb_chev_filter: int, nb_time_filter: int, time_strides: int, num_for_predict: int, len_input: int, num_of_vertices: int, normalization: Optional[str] = None, bias: bool = True)[source]

An implementation of the Attention Based Spatial-Temporal Graph Convolutional Cell. For details see this paper: “Attention Based Spatial-Temporal Graph Convolutional Networks for Traffic Flow Forecasting.”

Parameters
  • nb_block (int) – Number of ASTGCN blocks in the model.

  • in_channels (int) – Number of input features.

  • K (int) – Order of Chebyshev polynomials. Degree is K-1.

  • nb_chev_filter (int) – Number of Chebyshev filters.

  • nb_time_filter (int) – Number of time filters.

  • time_strides (int) – Time strides during temporal convolution.

  • num_for_predict (int) – Number of predictions to make in the future.

  • len_input (int) – Length of the input sequence.

  • num_of_vertices (int) – Number of vertices in the graph.

  • normalization (str, optional) –

    The normalization scheme for the graph Laplacian (default: None):

    1. None: No normalization \(\mathbf{L} = \mathbf{D} - \mathbf{A}\)

    2. "sym": Symmetric normalization \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}\)

    3. "rw": Random-walk normalization \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}\)

    You need to pass lambda_max to the forward() method of this operator in case the normalization is non-symmetric. lambda_max should be a torch.Tensor of size [num_graphs] in a mini-batch scenario and a scalar/zero-dimensional tensor when operating on single graphs. You can pre-compute lambda_max via the torch_geometric.transforms.LaplacianLambdaMax transform.

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

forward(X: torch.FloatTensor, edge_index: torch.LongTensor) → torch.FloatTensor[source]

Making a forward pass.

Arg types:
  • X (PyTorch FloatTensor) - Node features for T time periods, with shape (B, N_nodes, F_in, T_in).

  • edge_index (PyTorch LongTensor): Edge indices, either a single tensor or a list of tensors, depending on whether the edges change over time.

Return types:
  • X (PyTorch FloatTensor) - Hidden state tensor for all nodes, with shape (B, N_nodes, T_out).

training: bool
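
A minimal sketch (toy sizes assumed); normalization="sym" is chosen explicitly here so that lambda_max does not need to be supplied:

    import torch
    from torch_geometric_temporal.nn.attention import ASTGCN

    num_nodes, len_input, num_for_predict = 4, 12, 6
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 3, 0]], dtype=torch.long)

    model = ASTGCN(nb_block=2, in_channels=1, K=3, nb_chev_filter=16,
                   nb_time_filter=16, time_strides=1,
                   num_for_predict=num_for_predict, len_input=len_input,
                   num_of_vertices=num_nodes, normalization="sym")

    X = torch.randn(2, num_nodes, 1, len_input)  # (B, N_nodes, F_in, T_in)
    Y = model(X, edge_index)
    print(Y.shape)  # expected: torch.Size([2, 4, 6]), i.e. (B, N_nodes, T_out)
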
class MSTGCN(nb_block: int, in_channels: int, K: int, nb_chev_filter: int, nb_time_filter: int, time_strides: int, num_for_predict: int, len_input: int)[source]

An implementation of the Multi-Component Spatial-Temporal Graph Convolution Networks, a degraded version of ASTGCN. For details see this paper: “Attention Based Spatial-Temporal Graph Convolutional Networks for Traffic Flow Forecasting.”

Parameters
  • nb_block (int) – Number of ASTGCN blocks in the model.

  • in_channels (int) – Number of input features.

  • K (int) – Order of Chebyshev polynomials. Degree is K-1.

  • nb_chev_filter (int) – Number of Chebyshev filters.

  • nb_time_filter (int) – Number of time filters.

  • time_strides (int) – Time strides during temporal convolution.

  • num_for_predict (int) – Number of predictions to make in the future.

  • len_input (int) – Length of the input sequence.

forward(X: torch.FloatTensor, edge_index: torch.LongTensor) → torch.FloatTensor[source]

Making a forward pass. This module takes a list of MSTGCN blocks and uses a final convolution to serve as a multi-component fusion. B is the batch size. N_nodes is the number of nodes in the graph. F_in is the dimension of input features. T_in is the length of input sequence in time. T_out is the length of output sequence in time.

Arg types:
  • X (PyTorch FloatTensor) - Node features for T time periods, with shape (B, N_nodes, F_in, T_in).

  • edge_index (PyTorch LongTensor): Edge indices, either a single tensor or a list of tensors, depending on whether the edges change over time.

Return types:
  • X (PyTorch FloatTensor) - Hidden state tensor for all nodes, with shape (B, N_nodes, T_out).

training: bool
class GMAN(L: int, K: int, d: int, num_his: int, bn_decay: float, steps_per_day: int, use_bias: bool, mask: bool)[source]

An implementation of GMAN. For details see this paper: “GMAN: A Graph Multi-Attention Network for Traffic Prediction.”

Parameters
  • L (int) – Number of STAtt blocks in the encoder/decoder.

  • K (int) – Number of attention heads.

  • d (int) – Dimension of each attention head's output.

  • num_his (int) – Number of history steps.

  • bn_decay (float) – Batch normalization momentum.

  • steps_per_day (int) – Number of steps in a day.

  • use_bias (bool) – Whether to use bias in Fully Connected layers.

  • mask (bool) – Whether to mask attention score in temporal attention.

forward(X: torch.FloatTensor, SE: torch.FloatTensor, TE: torch.FloatTensor) → torch.FloatTensor[source]

Making a forward pass of GMAN.

Arg types:
  • X (PyTorch Float Tensor) - Input sequence, with shape (batch_size, num_his, num_nodes).

  • SE (PyTorch Float Tensor) - Spatial embedding, with shape (num_nodes, K * d).

  • TE (PyTorch Float Tensor) - Temporal embedding, with shape (batch_size, num_his + num_pred, 2).

Return types:
  • X (PyTorch Float Tensor) - Output sequence for prediction, with shape (batch_size, num_pred, num_nodes).

training: bool
class SpatioTemporalAttention(K: int, d: int, bn_decay: float, mask: bool)[source]

An implementation of the spatial-temporal attention block, with spatial attention and temporal attention followed by gated fusion. For details see this paper: “GMAN: A Graph Multi-Attention Network for Traffic Prediction.”

Parameters
  • K (int) – Number of attention heads.

  • d (int) – Dimension of each attention head's output.

  • bn_decay (float) – Batch normalization momentum.

  • mask (bool) – Whether to mask attention score in temporal attention.

forward(X: torch.FloatTensor, STE: torch.FloatTensor) → torch.FloatTensor[source]

Making a forward pass of the spatial-temporal attention block.

Arg types:
  • X (PyTorch Float Tensor) - Input sequence, with shape (batch_size, num_step, num_nodes, K*d).

  • STE (PyTorch Float Tensor) - Spatial-temporal embedding, with shape (batch_size, num_step, num_nodes, K*d).

Return types:
  • X (PyTorch Float Tensor) - Attention scores, with shape (batch_size, num_step, num_nodes, K*d).

training: bool
class GraphConstructor(nnodes: int, k: int, dim: int, alpha: float, xd: Optional[int] = None)[source]

An implementation of the graph learning layer to construct an adjacency matrix. For details see this paper: “Connecting the Dots: Multivariate Time Series Forecasting with Graph Neural Networks.”

Parameters
  • nnodes (int) – Number of nodes in the graph.

  • k (int) – Number of largest values to consider in constructing the neighbourhood of a node (pick the “nearest” k nodes).

  • dim (int) – Dimension of the node embedding.

  • alpha (float) – Tanh alpha for generating the adjacency matrix; alpha controls the saturation rate.

  • xd (int, optional) – Static feature dimension, default None.

forward(idx: torch.LongTensor, FE: Optional[torch.FloatTensor] = None) → torch.FloatTensor[source]

Making a forward pass to construct an adjacency matrix from node embeddings.

Arg types:
  • idx (PyTorch Long Tensor) - Input indices, a permutation of the node indices.

  • FE (PyTorch Float Tensor, optional) - Static feature, default None.

Return types:
  • A (PyTorch Float Tensor) - Adjacency matrix constructed from node embeddings.

training: bool
class MTGNN(gcn_true: bool, build_adj: bool, gcn_depth: int, num_nodes: int, kernel_set: list, kernel_size: int, dropout: float, subgraph_size: int, node_dim: int, dilation_exponential: int, conv_channels: int, residual_channels: int, skip_channels: int, end_channels: int, seq_length: int, in_dim: int, out_dim: int, layers: int, propalpha: float, tanhalpha: float, layer_norm_affline: bool, xd: Optional[int] = None)[source]

An implementation of the Multivariate Time Series Forecasting Graph Neural Networks. For details see this paper: “Connecting the Dots: Multivariate Time Series Forecasting with Graph Neural Networks.”

Parameters
  • gcn_true (bool) – Whether to add a graph convolution layer.

  • build_adj (bool) – Whether to construct adaptive adjacency matrix.

  • gcn_depth (int) – Graph convolution depth.

  • num_nodes (int) – Number of nodes in the graph.

  • kernel_set (list of int) – List of kernel sizes.

  • kernel_size (int) – Size of kernel for convolution, to calculate receptive field size.

  • dropout (float) – Dropout rate.

  • subgraph_size (int) – Size of subgraph.

  • node_dim (int) – Dimension of nodes.

  • dilation_exponential (int) – Dilation exponential.

  • conv_channels (int) – Convolution channels.

  • residual_channels (int) – Residual channels.

  • skip_channels (int) – Skip channels.

  • end_channels (int) – End channels.

  • seq_length (int) – Length of input sequence.

  • in_dim (int) – Input dimension.

  • out_dim (int) – Output dimension.

  • layers (int) – Number of layers.

  • propalpha (float) – Prop alpha, ratio of retaining the root node's original states in mix-hop propagation, a value between 0 and 1.

  • tanhalpha (float) – Tanh alpha for generating adjacency matrix, alpha controls the saturation rate.

  • layer_norm_affline (bool) – Whether to do elementwise affine in Layer Normalization.

  • xd (int, optional) – Static feature dimension, default None.

forward(X_in: torch.FloatTensor, A_tilde: Optional[torch.FloatTensor] = None, idx: Optional[torch.LongTensor] = None, FE: Optional[torch.FloatTensor] = None) → torch.FloatTensor[source]

Making a forward pass of MTGNN.

Arg types:
  • X_in (PyTorch FloatTensor) - Input sequence, with shape (batch_size, in_dim, num_nodes, seq_len).

  • A_tilde (PyTorch FloatTensor, optional) - Predefined adjacency matrix, default None.

  • idx (PyTorch LongTensor, optional) - Input indices, a permutation of the num_nodes, default None (no permutation).

  • FE (PyTorch FloatTensor, optional) - Static feature, default None.

Return types:
  • X (PyTorch FloatTensor) - Output sequence for prediction, with shape (batch_size, seq_len, num_nodes, 1).

training: bool
class AAGCN(in_channels: int, out_channels: int, edge_index: torch.LongTensor, num_nodes: int, stride: int = 1, residual: bool = True, adaptive: bool = True, attention: bool = True)[source]

Two-Stream Adaptive Graph Convolutional Network.

For details see this paper: “Two-Stream Adaptive Graph Convolutional Networks for Skeleton-Based Action Recognition.” This implementation is based on the authors' GitHub repo https://github.com/lshiwjx/2s-AGCN. The authors use it for classifying actions from sequences of 3D body joint coordinates.

Parameters
  • in_channels (int) – Number of input features.

  • out_channels (int) – Number of output features.

  • edge_index (PyTorch LongTensor) – Graph edge indices.

  • num_nodes (int) – Number of nodes in the network.

  • stride (int, optional) – Time strides during temporal convolution. (default: 1)

  • residual (bool, optional) – Applying residual connection. (default: True)

  • adaptive (bool, optional) – Adaptive node connection weights. (default: True)

  • attention (bool, optional) – Applying spatial-temporal-channel attention. (default: True)

forward(x)[source]

Making a forward pass.

Arg types:
  • X (PyTorch FloatTensor) - Node features for T time periods, with shape (B, F_in, T_in, N_nodes).

Return types:
  • X (PyTorch FloatTensor) - Sequence of node features, with shape (B, out_channels, T_in//stride, N_nodes).

training: bool
class DNNTSP(items_total: int, item_embedding_dim: int, n_heads: int)[source]

An implementation of the Deep Neural Network for Temporal Set Prediction. For details see: “Predicting Temporal Sets with Deep Neural Networks”

Parameters
  • items_total (int) – Total number of items in the sets. Cardinality of the union.

  • item_embedding_dim (int) – Item embedding dimensions.

  • n_heads (int) – Number of attention heads.

forward(X: torch.FloatTensor, edge_index: torch.LongTensor, edge_weight: Optional[torch.FloatTensor] = None)[source]

Making a forward pass. If edge weights are not present the forward pass defaults to an unweighted graph.

Arg types:
  • X (PyTorch Float Tensor) - Node features.

  • edge_index (PyTorch Long Tensor) - Graph edge indices.

  • edge_weight (PyTorch Float Tensor, optional) - Edge weight vector.

Return types:
  • H (PyTorch Float Tensor) - Hidden state matrix for all nodes.

training: bool

Auxiliary Graph Convolutional Layers

class TemporalConv(in_channels: int, out_channels: int, kernel_size: int = 3)[source]

Temporal convolution block applied to nodes in the STGCN Layer. For details see: “Spatio-Temporal Graph Convolutional Networks: A Deep Learning Framework for Traffic Forecasting.” Based on the temporal convolution introduced in “Convolutional Sequence to Sequence Learning.”

Parameters
  • in_channels (int) – Number of input features.

  • out_channels (int) – Number of output features.

  • kernel_size (int) – Convolutional kernel size.

forward(X: torch.FloatTensor) → torch.FloatTensor[source]

Forward pass through temporal convolution block.

Arg types:
  • X (torch.FloatTensor) - Input data of shape (batch_size, input_time_steps, num_nodes, in_channels).

Return types:
  • H (torch.FloatTensor) - Output data of shape (batch_size, in_channels, num_nodes, input_time_steps).

training: bool
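
A minimal sketch (toy sizes assumed). The block applies a gated 1D convolution along the time axis of a (batch, time, nodes, channels) tensor:

    import torch
    from torch_geometric_temporal.nn.attention import TemporalConv

    conv = TemporalConv(in_channels=8, out_channels=16, kernel_size=3)

    # (batch_size, input_time_steps, num_nodes, in_channels)
    X = torch.randn(2, 10, 4, 8)
    H = conv(X)
    print(H.shape)  # the convolution shortens the time dimension
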
class DConv(in_channels: int, out_channels: int, K: int, bias: bool = True)[source]

An implementation of the Diffusion Convolution Layer. For details see: “Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting”

Parameters
  • in_channels (int) – Number of input features.

  • out_channels (int) – Number of output features.

  • K (int) – Filter size \(K\).

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias (default True).

forward(X: torch.FloatTensor, edge_index: torch.LongTensor, edge_weight: torch.FloatTensor) → torch.FloatTensor[source]

Making a forward pass. If edge weights are not present the forward pass defaults to an unweighted graph.

Arg types:
  • X (PyTorch Float Tensor) - Node features.

  • edge_index (PyTorch Long Tensor) - Graph edge indices.

  • edge_weight (PyTorch Float Tensor, optional) - Edge weight vector.

Return types:
  • H (PyTorch Float Tensor) - Hidden state matrix for all nodes.

message(x_j, norm)[source]

Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take any argument as input which was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.

class ChebConvAttention(in_channels: int, out_channels: int, K: int, normalization: Optional[str] = None, bias: bool = True, **kwargs)[source]

The Chebyshev spectral graph convolutional operator with attention from the “Attention Based Spatial-Temporal Graph Convolutional Networks for Traffic Flow Forecasting” paper, where \(\mathbf{\hat{L}}\) denotes the scaled and normalized Laplacian \(\frac{2\mathbf{L}}{\lambda_{\max}} - \mathbf{I}\).

Parameters
  • in_channels (int) – Size of each input sample.

  • out_channels (int) – Size of each output sample.

  • K (int) – Chebyshev filter size \(K\).

  • normalization (str, optional) –

    The normalization scheme for the graph Laplacian (default: None):

    1. None: No normalization \(\mathbf{L} = \mathbf{D} - \mathbf{A}\)

    2. "sym": Symmetric normalization \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}\)

    3. "rw": Random-walk normalization \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}\)

    You need to pass lambda_max to the forward() method of this operator in case the normalization is non-symmetric. lambda_max should be a torch.Tensor of size [num_graphs] in a mini-batch scenario and a scalar/zero-dimensional tensor when operating on single graphs. You can pre-compute lambda_max via the torch_geometric.transforms.LaplacianLambdaMax transform.

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

forward(x: torch.FloatTensor, edge_index: torch.LongTensor, spatial_attention: torch.FloatTensor, edge_weight: Optional[torch.Tensor] = None, batch: Optional[torch.Tensor] = None, lambda_max: Optional[torch.Tensor] = None) → torch.FloatTensor[source]

Making a forward pass of the ChebConv Attention layer (Chebyshev graph convolution operation).

Arg types:
  • x (PyTorch Float Tensor) - Node features for T time periods, with shape (B, N_nodes, F_in).

  • edge_index (Tensor array) - Edge indices.

  • spatial_attention (PyTorch Float Tensor) - Spatial attention weights, with shape (B, N_nodes, N_nodes).

  • edge_weight (PyTorch Float Tensor, optional) - Edge weights corresponding to edge indices.

  • batch (PyTorch Tensor, optional) - Batch labels for each edge.

  • lambda_max (optional, but mandatory if normalization is None) - Largest eigenvalue of Laplacian.

Return types:
  • out (PyTorch Float Tensor) - Hidden state tensor for all nodes, with shape (B, N_nodes, F_out).

message(x_j, norm)[source]

Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take any argument as input which was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.

class AVWGCN(in_channels: int, out_channels: int, K: int, embedding_dimensions: int)[source]

An implementation of the Node Adaptive Graph Convolution Layer. For details see: “Adaptive Graph Convolutional Recurrent Network for Traffic Forecasting”

Parameters
  • in_channels (int) – Number of input features.

  • out_channels (int) – Number of output features.

  • K (int) – Filter size \(K\).

  • embedding_dimensions (int) – Number of node embedding dimensions.

forward(X: torch.FloatTensor, E: torch.FloatTensor) → torch.FloatTensor[source]

Making a forward pass.

Arg types:
  • X (PyTorch Float Tensor) - Node features.

  • E (PyTorch Float Tensor) - Node embeddings.

Return types:
  • X_G (PyTorch Float Tensor) - Hidden state matrix for all nodes.

training: bool
class UnitGCN(in_channels: int, out_channels: int, A: torch.FloatTensor, coff_embedding: int = 4, num_subset: int = 3, adaptive: bool = True, attention: bool = True)[source]

Graph Convolutional Block applied to nodes in the Two-Stream Adaptive Graph Convolutional Network as originally implemented in the GitHub repo https://github.com/lshiwjx/2s-AGCN. For implementation details see https://arxiv.org/abs/1805.07694. Temporal attention, spatial attention and channel-wise attention will be applied.

Parameters
  • in_channels (int) – Number of input features.

  • out_channels (int) – Number of output features.

  • A (Tensor array) – Adaptive Graph.

  • coff_embedding (int, optional) – Coefficient Embeddings. (default: 4)

  • num_subset (int, optional) – Subsets for adaptive graphs, see \(\mathbf{A}\), \(\mathbf{B}\) and \(\mathbf{C}\) in https://arxiv.org/abs/1805.07694 for details. (default: 3)

  • adaptive (bool, optional) – Apply Adaptive Graph Convolutions. (default: True)

  • attention (bool, optional) – Apply Attention. (default: True)

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class UnitTCN(in_channels: int, out_channels: int, kernel_size: int = 9, stride: int = 1)[source]

Temporal Convolutional Block applied to nodes in the Two-Stream Adaptive Graph Convolutional Network as originally implemented in the GitHub repo https://github.com/lshiwjx/2s-AGCN. For implementation details see https://arxiv.org/abs/1805.07694.

Parameters
  • in_channels (int) – Number of input features.

  • out_channels (int) – Number of output features.

  • kernel_size (int) – Convolutional kernel size. (default: 9)

  • stride (int) – Temporal Convolutional kernel stride. (default: 1)

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool

Heterogeneous Graph Convolutional Layers

class HeteroGCLSTM(in_channels_dict: dict, out_channels: int, metadata: tuple, bias: bool = True)[source]

An implementation similar to the Integrated Graph Convolutional Long Short Term Memory Cell for heterogeneous graphs.

Parameters
  • in_channels_dict (dict of keys=str and values=int) – Dimension of each node’s input features.

  • out_channels (int) – Number of output features.

  • metadata (tuple) – Metadata on node types and edge types in the graphs. Can be generated via PyG method snapshot.metadata() where snapshot is a single HeteroData object.

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

forward(x_dict, edge_index_dict, h_dict=None, c_dict=None)[source]

Making a forward pass. If the hidden state and cell state matrix dicts are not present when the forward pass is called these are initialized with zeros.

Arg types:
  • x_dict (Dictionary where keys=Strings and values=PyTorch Float Tensors) - Node feature dicts. Can be obtained via the PyG method snapshot.x_dict where snapshot is a single HeteroData object.

  • edge_index_dict (Dictionary where keys=Tuples and values=PyTorch Long Tensors) - Graph edge type and index dicts. Can be obtained via the PyG method snapshot.edge_index_dict.

  • h_dict (Dictionary where keys=Strings and values=PyTorch Float Tensor, optional) - Node type and hidden state matrix dict for all nodes.

  • c_dict (Dictionary where keys=Strings and values=PyTorch Float Tensor, optional) - Node type and cell state matrix dict for all nodes.

Return types:
  • h_dict (Dictionary where keys=Strings and values=PyTorch Float Tensor) - Node type and hidden state matrix dict for all nodes.

  • c_dict (Dictionary where keys=Strings and values=PyTorch Float Tensor) - Node type and cell state matrix dict for all nodes.

training: bool
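
A minimal sketch (assuming the usual import path; the node types, edge types and sizes below are illustrative). A toy HeteroData snapshot provides the metadata, feature dicts and edge index dicts; both edge directions are included so every node type receives messages:

    import torch
    from torch_geometric.data import HeteroData
    from torch_geometric_temporal.nn.hetero import HeteroGCLSTM

    # Toy heterogeneous snapshot with two node types and both edge directions.
    snapshot = HeteroData()
    snapshot["user"].x = torch.randn(3, 8)
    snapshot["item"].x = torch.randn(5, 4)
    snapshot["user", "rates", "item"].edge_index = torch.tensor([[0, 1, 2],
                                                                 [0, 2, 4]])
    snapshot["item", "rated_by", "user"].edge_index = torch.tensor([[0, 2, 4],
                                                                    [0, 1, 2]])

    cell = HeteroGCLSTM(in_channels_dict={"user": 8, "item": 4},
                        out_channels=16, metadata=snapshot.metadata())

    # First call initializes the state dicts with zeros; later calls reuse them.
    h_dict, c_dict = cell(snapshot.x_dict, snapshot.edge_index_dict)
    h_dict, c_dict = cell(snapshot.x_dict, snapshot.edge_index_dict,
                          h_dict=h_dict, c_dict=c_dict)
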