# Source code for torch_geometric_temporal.nn.recurrent.agcrn

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn.inits import glorot, zeros


class AVWGCN(nn.Module):
    r"""An implementation of the Node Adaptive Graph Convolution Layer.
    For details see: `"Adaptive Graph Convolutional Recurrent Network
    for Traffic Forecasting" <https://arxiv.org/abs/2007.02842>`_

    Args:
        in_channels (int): Number of input features.
        out_channels (int): Number of output features.
        K (int): Filter size :math:`K`, i.e. the number of support matrices
            :math:`T_0 = I, T_1 = A, \dots, T_{K-1}` that are mixed. Must be
            at least 1.
        embedding_dimensions (int): Number of node embedding dimensions.

    Raises:
        ValueError: If ``K`` is smaller than 1.
    """

    def __init__(
        self, in_channels: int, out_channels: int, K: int, embedding_dimensions: int
    ):
        super(AVWGCN, self).__init__()
        if K < 1:
            raise ValueError(f"K must be a positive integer, got {K}.")
        self.K = K
        # Node-specific weights/biases are generated in ``forward`` by
        # projecting the node embeddings E onto these shared pools.
        self.weights_pool = torch.nn.Parameter(
            torch.empty(embedding_dimensions, K, in_channels, out_channels)
        )
        self.bias_pool = torch.nn.Parameter(
            torch.empty(embedding_dimensions, out_channels)
        )
        glorot(self.weights_pool)
        zeros(self.bias_pool)

    def forward(self, X: torch.FloatTensor, E: torch.FloatTensor) -> torch.FloatTensor:
        r"""Making a forward pass.

        Arg types:
            * **X** (PyTorch Float Tensor) - Node features of shape
              ``(batch_size, number_of_nodes, in_channels)``.
            * **E** (PyTorch Float Tensor) - Node embeddings of shape
              ``(number_of_nodes, embedding_dimensions)``.

        Return types:
            * **X_G** (PyTorch Float Tensor) - Hidden state matrix for all
              nodes, of shape ``(batch_size, number_of_nodes, out_channels)``.
        """
        number_of_nodes = E.shape[0]
        # Adaptive adjacency learned from the embeddings: each row of
        # relu(E E^T) is normalised with a softmax.
        supports = F.softmax(F.relu(torch.mm(E, E.transpose(0, 1))), dim=1)

        # Chebyshev-style recurrence: T_0 = I, T_1 = A,
        # T_k = 2 A T_{k-1} - T_{k-2}. Build exactly K terms so the stack
        # matches the K slices of ``weights_pool`` (K == 1 -> identity only).
        # The identity is created in the supports' dtype/device so stacking
        # never mixes precisions.
        support_set = [
            torch.eye(number_of_nodes, dtype=supports.dtype, device=supports.device)
        ]
        if self.K > 1:
            support_set.append(supports)
        for _ in range(2, self.K):
            support_set.append(
                torch.matmul(2 * supports, support_set[-1]) - support_set[-2]
            )
        supports = torch.stack(support_set, dim=0)  # (K, N, N)

        # Per-node weights (N, K, C_in, C_out) and biases (N, C_out).
        W = torch.einsum("nd,dkio->nkio", E, self.weights_pool)
        bias = torch.matmul(E, self.bias_pool)

        # Propagate features over every support, laid out as (B, N, K, C_in),
        # then apply the node-specific weights summed over K and C_in.
        X_G = torch.einsum("knm,bmc->bnkc", supports, X)
        X_G = torch.einsum("bnki,nkio->bno", X_G, W) + bias
        return X_G
class AGCRN(nn.Module):
    r"""An implementation of the Adaptive Graph Convolutional Recurrent Unit.
    For details see: `"Adaptive Graph Convolutional Recurrent Network
    for Traffic Forecasting" <https://arxiv.org/abs/2007.02842>`_

    Args:
        number_of_nodes (int): Number of vertices.
        in_channels (int): Number of input features.
        out_channels (int): Number of output features.
        K (int): Filter size :math:`K`.
        embedding_dimensions (int): Number of node embedding dimensions.
    """

    def __init__(
        self,
        number_of_nodes: int,
        in_channels: int,
        out_channels: int,
        K: int,
        embedding_dimensions: int,
    ):
        super(AGCRN, self).__init__()
        # NOTE: number_of_nodes is stored but not used by this class's own
        # computations; runtime node counts come from the shapes of X and E.
        self.number_of_nodes = number_of_nodes
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.K = K
        self.embedding_dimensions = embedding_dimensions
        self._setup_layers()

    def _setup_layers(self):
        # Both graph convolutions consume the concatenation [X, H].
        # The gate emits two out_channels-wide gates (split in ``forward``).
        self._gate = AVWGCN(
            in_channels=self.in_channels + self.out_channels,
            out_channels=2 * self.out_channels,
            K=self.K,
            embedding_dimensions=self.embedding_dimensions,
        )
        self._update = AVWGCN(
            in_channels=self.in_channels + self.out_channels,
            out_channels=self.out_channels,
            K=self.K,
            embedding_dimensions=self.embedding_dimensions,
        )

    def _set_hidden_state(self, X, H):
        # Default to an all-zero state matching X's batch/node dims, dtype and
        # device, so concatenating [X, H] never mixes precisions and no
        # host-side tensor has to be copied to the accelerator.
        if H is None:
            H = torch.zeros(
                X.shape[0],
                X.shape[1],
                self.out_channels,
                dtype=X.dtype,
                device=X.device,
            )
        return H

    def forward(
        self, X: torch.FloatTensor, E: torch.FloatTensor, H: torch.FloatTensor = None
    ) -> torch.FloatTensor:
        r"""Making a forward pass.

        Arg types:
            * **X** (PyTorch Float Tensor) - Node feature matrix of shape
              ``(batch_size, number_of_nodes, in_channels)``.
            * **E** (PyTorch Float Tensor) - Node embedding matrix.
            * **H** (PyTorch Float Tensor) - Node hidden state matrix of shape
              ``(batch_size, number_of_nodes, out_channels)``. Default is
              None, in which case a zero state is used.

        Return types:
            * **H** (PyTorch Float Tensor) - Hidden state matrix for all nodes,
              of shape ``(batch_size, number_of_nodes, out_channels)``.
        """
        H = self._set_hidden_state(X, H)
        X_H = torch.cat((X, H), dim=-1)
        Z_R = torch.sigmoid(self._gate(X_H, E))
        # Z scales the previous state inside the candidate input;
        # R interpolates between the previous state and the candidate.
        Z, R = torch.split(Z_R, self.out_channels, dim=-1)
        C = torch.cat((X, Z * H), dim=-1)
        HC = torch.tanh(self._update(C, E))
        H = R * H + (1 - R) * HC
        return H