import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import ChebConv
class TemporalConv(nn.Module):
    r"""Temporal convolution block applied to nodes in the STGCN Layer.

    For details see: `"Spatio-Temporal Graph Convolutional Networks:
    A Deep Learning Framework for Traffic Forecasting."
    <https://arxiv.org/abs/1709.04875>`_ Based off the temporal convolution
    introduced in `"Convolutional Sequence to Sequence Learning"
    <https://arxiv.org/abs/1705.03122>`_.

    Args:
        in_channels (int): Number of input features.
        out_channels (int): Number of output features.
        kernel_size (int): Convolutional kernel size. (default: 3)
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
        super(TemporalConv, self).__init__()
        # Three parallel convolutions along the time axis (kernel (1, k))
        # implementing a gated linear unit with a residual-style third branch:
        #   H = relu((conv_1(X) * sigmoid(conv_2(X))) + conv_3(X))
        self.conv_1 = nn.Conv2d(in_channels, out_channels, (1, kernel_size))
        self.conv_2 = nn.Conv2d(in_channels, out_channels, (1, kernel_size))
        self.conv_3 = nn.Conv2d(in_channels, out_channels, (1, kernel_size))

    def forward(self, X: torch.FloatTensor) -> torch.FloatTensor:
        """Forward pass through temporal convolution block.

        Arg types:
            * **X** (torch.FloatTensor) - Input data of shape
              (batch_size, input_time_steps, num_nodes, in_channels).

        Return types:
            * **H** (torch.FloatTensor) - Output data of shape
              (batch_size, input_time_steps - kernel_size + 1, num_nodes,
              out_channels); the unpadded convolution shortens the time axis
              by kernel_size - 1.
        """
        # Conv2d expects (batch, channels, height, width); here height is the
        # node axis and width is the time axis.
        X = X.permute(0, 3, 2, 1)
        P = self.conv_1(X)
        Q = torch.sigmoid(self.conv_2(X))  # gate values in (0, 1)
        PQ = P * Q  # gated linear unit
        H = F.relu(PQ + self.conv_3(X))
        # Restore (batch, time, num_nodes, channels) layout.
        H = H.permute(0, 3, 2, 1)
        return H
class STConv(nn.Module):
    r"""Spatio-temporal convolution block using ChebConv Graph Convolutions.

    For details see: `"Spatio-Temporal Graph Convolutional Networks:
    A Deep Learning Framework for Traffic Forecasting"
    <https://arxiv.org/abs/1709.04875>`_

    NB. The ST-Conv block contains two temporal convolutions (TemporalConv)
    with kernel size k. Hence for an input sequence of length m,
    the output sequence will be length m-2(k-1).

    Args:
        num_nodes (int): Number of nodes in the graph.
        in_channels (int): Number of input features.
        hidden_channels (int): Number of hidden units output by graph convolution block
        out_channels (int): Number of output features.
        kernel_size (int): Size of the kernel considered.
        K (int): Chebyshev filter size :math:`K`.
        normalization (str, optional): The normalization scheme for the graph
            Laplacian (default: :obj:`"sym"`):

            1. :obj:`None`: No normalization
            :math:`\mathbf{L} = \mathbf{D} - \mathbf{A}`

            2. :obj:`"sym"`: Symmetric normalization
            :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A}
            \mathbf{D}^{-1/2}`

            3. :obj:`"rw"`: Random-walk normalization
            :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}`

            You need to pass :obj:`lambda_max` to the :meth:`forward` method of
            this operator in case the normalization is non-symmetric.
            :obj:`\lambda_max` should be a :class:`torch.Tensor` of size
            :obj:`[num_graphs]` in a mini-batch scenario and a
            scalar/zero-dimensional tensor when operating on single graphs.
            You can pre-compute :obj:`lambda_max` via the
            :class:`torch_geometric.transforms.LaplacianLambdaMax` transform.
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
    """

    def __init__(
        self,
        num_nodes: int,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        kernel_size: int,
        K: int,
        normalization: str = "sym",
        bias: bool = True,
    ):
        super(STConv, self).__init__()
        self.num_nodes = num_nodes
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.K = K
        self.normalization = normalization
        self.bias = bias

        # Sandwich architecture: temporal conv -> graph conv -> temporal conv.
        self._temporal_conv1 = TemporalConv(
            in_channels=in_channels,
            out_channels=hidden_channels,
            kernel_size=kernel_size,
        )
        self._graph_conv = ChebConv(
            in_channels=hidden_channels,
            out_channels=hidden_channels,
            K=K,
            normalization=normalization,
            bias=bias,
        )
        self._temporal_conv2 = TemporalConv(
            in_channels=hidden_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
        )
        # Normalizes over the node axis (treated as the "channel" dim after a
        # permute in forward).
        self._batch_norm = nn.BatchNorm2d(num_nodes)

    def forward(
        self,
        X: torch.FloatTensor,
        edge_index: torch.LongTensor,
        edge_weight: torch.FloatTensor = None,
    ) -> torch.FloatTensor:
        r"""Forward pass. If edge weights are not present the forward pass
        defaults to an unweighted graph.

        Arg types:
            * **X** (PyTorch FloatTensor) - Sequence of node features of shape (Batch size X Input time steps X Num nodes X In channels).
            * **edge_index** (PyTorch LongTensor) - Graph edge indices.
            * **edge_weight** (PyTorch LongTensor, optional)- Edge weight vector.

        Return types:
            * **T** (PyTorch FloatTensor) - Sequence of node features of shape
              (Batch size X (Input time steps - 2(kernel_size - 1)) X Num nodes
              X Out channels).
        """
        T_0 = self._temporal_conv1(X)
        # zeros_like already matches T_0's device and dtype; no .to() needed.
        T = torch.zeros_like(T_0)
        # ChebConv operates on a single (num_nodes, channels) feature matrix,
        # so apply it independently per batch element and time step.
        for b in range(T_0.size(0)):
            for t in range(T_0.size(1)):
                T[b, t] = self._graph_conv(T_0[b, t], edge_index, edge_weight)

        T = F.relu(T)
        T = self._temporal_conv2(T)
        # BatchNorm2d was built with num_nodes channels, so move the node axis
        # into channel position, normalize, then restore the layout.
        T = T.permute(0, 2, 1, 3)
        T = self._batch_norm(T)
        T = T.permute(0, 2, 1, 3)
        return T