import math
from typing import Optional, List, Union
import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.typing import OptTensor
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.transforms import LaplacianLambdaMax
from torch_geometric.utils import remove_self_loops, add_self_loops, get_laplacian
class ChebConvAttention(MessagePassing):
r"""The chebyshev spectral graph convolutional operator with attention from the
`Attention Based Spatial-Temporal Graph Convolutional
Networks for Traffic Flow Forecasting." <https://ojs.aaai.org/index.php/AAAI/article/view/3881>`_ paper
:math:`\mathbf{\hat{L}}` denotes the scaled and normalized Laplacian
:math:`\frac{2\mathbf{L}}{\lambda_{\max}} - \mathbf{I}`.
Args:
in_channels (int): Size of each input sample.
out_channels (int): Size of each output sample.
K (int): Chebyshev filter size :math:`K`.
        normalization (str, optional): The normalization scheme for the graph
            Laplacian (default: :obj:`None`):
1. :obj:`None`: No normalization
:math:`\mathbf{L} = \mathbf{D} - \mathbf{A}`
2. :obj:`"sym"`: Symmetric normalization
:math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A}
\mathbf{D}^{-1/2}`
3. :obj:`"rw"`: Random-walk normalization
:math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}`
You need to pass :obj:`lambda_max` to the :meth:`forward` method of
this operator in case the normalization is non-symmetric.
            :obj:`lambda_max` should be a :class:`torch.Tensor` of size
:obj:`[num_graphs]` in a mini-batch scenario and a
scalar/zero-dimensional tensor when operating on single graphs.
You can pre-compute :obj:`lambda_max` via the
:class:`torch_geometric.transforms.LaplacianLambdaMax` transform.
bias (bool, optional): If set to :obj:`False`, the layer will not learn
an additive bias. (default: :obj:`True`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.MessagePassing`.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
K: int,
normalization: Optional[str] = None,
bias: bool = True,
**kwargs
):
kwargs.setdefault("aggr", "add")
super(ChebConvAttention, self).__init__(**kwargs)
assert K > 0
assert normalization in [None, "sym", "rw"], "Invalid normalization"
self._in_channels = in_channels
self._out_channels = out_channels
self._normalization = normalization
self._weight = Parameter(torch.Tensor(K, in_channels, out_channels))
if bias:
self._bias = Parameter(torch.Tensor(out_channels))
else:
self.register_parameter("_bias", None)
self._reset_parameters()
def _reset_parameters(self):
nn.init.xavier_uniform_(self._weight)
if self._bias is not None:
nn.init.uniform_(self._bias)
    # -- scaled Laplacian normalization helper --
def __norm__(
self,
edge_index,
num_nodes: Optional[int],
edge_weight: OptTensor,
normalization: Optional[str],
lambda_max,
dtype: Optional[int] = None,
batch: OptTensor = None,
):
edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
edge_index, edge_weight = get_laplacian(
edge_index, edge_weight, normalization, dtype, num_nodes
)
if batch is not None and lambda_max.numel() > 1:
lambda_max = lambda_max[batch[edge_index[0]]]
edge_weight = (2.0 * edge_weight) / lambda_max
edge_weight.masked_fill_(edge_weight == float("inf"), 0)
edge_index, edge_weight = add_self_loops(
edge_index, edge_weight, fill_value=-1.0, num_nodes=num_nodes
)
assert edge_weight is not None
        return edge_index, edge_weight  # e.g. 340 original edges plus 307 added self-loops for a 307-node graph
    def forward(
self,
x: torch.FloatTensor,
edge_index: torch.LongTensor,
spatial_attention: torch.FloatTensor,
edge_weight: OptTensor = None,
batch: OptTensor = None,
lambda_max: OptTensor = None,
) -> torch.FloatTensor:
"""
Making a forward pass of the ChebConv Attention layer (Chebyshev graph convolution operation).
Arg types:
* x (PyTorch Float Tensor) - Node features for T time periods, with shape (B, N_nodes, F_in).
* edge_index (Tensor array) - Edge indices.
* spatial_attention (PyTorch Float Tensor) - Spatial attention weights, with shape (B, N_nodes, N_nodes).
* edge_weight (PyTorch Float Tensor, optional) - Edge weights corresponding to edge indices.
* batch (PyTorch Tensor, optional) - Batch labels for each edge.
* lambda_max (optional, but mandatory if normalization is None) - Largest eigenvalue of Laplacian.
Return types:
* out (PyTorch Float Tensor) - Hidden state tensor for all nodes, with shape (B, N_nodes, F_out).
"""
if self._normalization != "sym" and lambda_max is None:
raise ValueError(
"You need to pass `lambda_max` to `forward() in`"
"case the normalization is non-symmetric."
)
if lambda_max is None:
lambda_max = torch.tensor(2.0, dtype=x.dtype, device=x.device)
if not isinstance(lambda_max, torch.Tensor):
lambda_max = torch.tensor(lambda_max, dtype=x.dtype, device=x.device)
assert lambda_max is not None
edge_index, norm = self.__norm__(
edge_index,
x.size(self.node_dim),
edge_weight,
self._normalization,
lambda_max,
dtype=x.dtype,
batch=batch,
)
        row, col = edge_index  # source and target node indices of each edge, each of shape (num_edges,), e.g. (954,)
        Att_norm = norm * spatial_attention[:, row, col]  # e.g. (954,) * (32, 954) -> (32, 954)
num_nodes = x.size(self.node_dim) #for example 307
# (307, 307) * (32, 307, 307) -> (32, 307, 307) -permute-> (32, 307,307) * (32, 307, 1) -> (32, 307, 1)
TAx_0 = torch.matmul(
(torch.eye(num_nodes).to(edge_index.device) * spatial_attention).permute(
0, 2, 1
),
x,
) #for example (32, 307, 1)
out = torch.matmul(TAx_0, self._weight[0]) #for example (32, 307, 1) * [1, 64] -> (32, 307, 64)
edge_index_transpose = edge_index[[1, 0]]
if self._weight.size(0) > 1:
TAx_1 = self.propagate(
edge_index_transpose, x=TAx_0, norm=Att_norm, size=None
)
out = out + torch.matmul(TAx_1, self._weight[1])
for k in range(2, self._weight.size(0)):
TAx_2 = self.propagate(edge_index_transpose, x=TAx_1, norm=norm, size=None)
TAx_2 = 2.0 * TAx_2 - TAx_0
out = out + torch.matmul(TAx_2, self._weight[k])
TAx_0, TAx_1 = TAx_1, TAx_2
if self._bias is not None:
out += self._bias
        return out  # (B, N, F_out), e.g. (32, 307, 64)
    def message(self, x_j, norm):
        if norm.dim() == 1:
            return norm.view(-1, 1) * x_j  # e.g. (954, 1) * (32, 954, 1) -> (32, 954, 1)
else:
d1, d2 = norm.shape
return norm.view(d1, d2, 1) * x_j
def __repr__(self):
return "{}({}, {}, K={}, normalization={})".format(
self.__class__.__name__,
self._in_channels,
self._out_channels,
self._weight.size(0),
self._normalization,
)
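# A minimal usage sketch for ChebConvAttention (not part of the original module).
# It assumes a toy graph with 4 nodes, a batch of 2, and K=3; the shapes mirror the
# comments above: x is (B, N_nodes, F_in) for a single time period and the attention
# matrix is (B, N_nodes, N_nodes).
def _example_chebconv_attention():
    x = torch.rand(2, 4, 8)  # (B, N_nodes, F_in)
    edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])  # a 4-node ring
    spatial_attention = torch.softmax(torch.rand(2, 4, 4), dim=1)  # (B, N, N)
    conv = ChebConvAttention(in_channels=8, out_channels=16, K=3, normalization="sym")
    out = conv(x, edge_index, spatial_attention)
    return out  # (B, N_nodes, F_out) == (2, 4, 16)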
class SpatialAttention(nn.Module):
r"""An implementation of the Spatial Attention Module (i.e compute spatial attention scores). For details see this paper:
`"Attention Based Spatial-Temporal Graph Convolutional Networks for Traffic Flow
Forecasting." <https://ojs.aaai.org/index.php/AAAI/article/view/3881>`_
Args:
in_channels (int): Number of input features.
num_of_vertices (int): Number of vertices in the graph.
num_of_timesteps (int): Number of time lags.
"""
def __init__(self, in_channels: int, num_of_vertices: int, num_of_timesteps: int):
super(SpatialAttention, self).__init__()
self._W1 = nn.Parameter(torch.FloatTensor(num_of_timesteps)) #for example (12)
self._W2 = nn.Parameter(torch.FloatTensor(in_channels, num_of_timesteps)) #for example (1, 12)
self._W3 = nn.Parameter(torch.FloatTensor(in_channels)) #for example (1)
self._bs = nn.Parameter(torch.FloatTensor(1, num_of_vertices, num_of_vertices)) #for example (1,307, 307)
self._Vs = nn.Parameter(torch.FloatTensor(num_of_vertices, num_of_vertices)) #for example (307, 307)
self._reset_parameters()
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
else:
nn.init.uniform_(p)
def forward(self, X: torch.FloatTensor) -> torch.FloatTensor:
"""
Making a forward pass of the spatial attention layer.
Arg types:
* **X** (PyTorch FloatTensor) - Node features for T time periods, with shape (B, N_nodes, F_in, T_in).
Return types:
* **S** (PyTorch FloatTensor) - Spatial attention score matrices, with shape (B, N_nodes, N_nodes).
"""
        # LHS = left-hand-side embedding;
        # to calculate it:
        # multiply with W1 (B, N, F_in, T)(T) -> (B, N, F_in)
        # multiply with W2 (B, N, F_in)(F_in, T) -> (B, N, T)
        # for example (32, 307, 1, 12) * (12) -> (32, 307, 1) * (1, 12) -> (32, 307, 12)
LHS = torch.matmul(torch.matmul(X, self._W1), self._W2)
        # RHS = right-hand-side embedding
        # to calculate it:
        # multiply W3 with X (F)(B, N, F, T) -> (B, N, T)
        # transpose (B, N, T) -> (B, T, N)
        # for example (1)(32, 307, 1, 12) -> (32, 307, 12) -transpose-> (32, 12, 307)
RHS = torch.matmul(self._W3, X).transpose(-1, -2)
# Then, we multiply LHS with RHS :
# (B,N,T)(B,T, N)->(B,N,N)
# for example (32, 307, 12) * (32, 12, 307) -> (32, 307, 307)
# Then multiply Vs(N,N) with the output
# (N,N)(B, N, N)->(B,N,N) (32, 307, 307)
# for example (307, 307) * (32, 307, 307) -> (32, 307, 307)
S = torch.matmul(self._Vs, torch.sigmoid(torch.matmul(LHS, RHS) + self._bs))
S = F.softmax(S, dim=1)
return S # (B,N,N) for example (32, 307, 307)
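# A minimal usage sketch for SpatialAttention (not part of the original module).
# The sizes are illustrative assumptions: a batch of 2, 4 vertices, 8 input features
# and 12 time steps, matching the (B, N_nodes, F_in, T_in) layout documented above.
def _example_spatial_attention():
    x = torch.rand(2, 4, 8, 12)  # (B, N_nodes, F_in, T_in)
    attention = SpatialAttention(in_channels=8, num_of_vertices=4, num_of_timesteps=12)
    scores = attention(x)
    return scores  # (B, N_nodes, N_nodes) == (2, 4, 4)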
class TemporalAttention(nn.Module):
r"""An implementation of the Temporal Attention Module( i.e. compute temporal attention scores). For details see this paper:
`"Attention Based Spatial-Temporal Graph Convolutional Networks for Traffic Flow
Forecasting." <https://ojs.aaai.org/index.php/AAAI/article/view/3881>`_
Args:
in_channels (int): Number of input features.
num_of_vertices (int): Number of vertices in the graph.
num_of_timesteps (int): Number of time lags.
"""
def __init__(self, in_channels: int, num_of_vertices: int, num_of_timesteps: int):
super(TemporalAttention, self).__init__()
self._U1 = nn.Parameter(torch.FloatTensor(num_of_vertices)) # for example 307
self._U2 = nn.Parameter(torch.FloatTensor(in_channels, num_of_vertices)) #for example (1, 307)
self._U3 = nn.Parameter(torch.FloatTensor(in_channels)) # for example (1)
self._be = nn.Parameter(
torch.FloatTensor(1, num_of_timesteps, num_of_timesteps)
) # for example (1,12,12)
self._Ve = nn.Parameter(torch.FloatTensor(num_of_timesteps, num_of_timesteps)) #for example (12, 12)
self._reset_parameters()
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
else:
nn.init.uniform_(p)
def forward(self, X: torch.FloatTensor) -> torch.FloatTensor:
"""
Making a forward pass of the temporal attention layer.
Arg types:
* **X** (PyTorch FloatTensor) - Node features for T time periods, with shape (B, N_nodes, F_in, T_in).
Return types:
* **E** (PyTorch FloatTensor) - Temporal attention score matrices, with shape (B, T_in, T_in).
"""
        # LHS = left-hand-side embedding;
        # to calculate it:
        # permute x: (B, N, F_in, T) -> (B, T, F_in, N)
        # multiply with U1 (B, T, F_in, N)(N) -> (B, T, F_in)
        # multiply with U2 (B, T, F_in)(F_in, N) -> (B, T, N)
        # for example (32, 307, 1, 12) -permute-> (32, 12, 1, 307) * (307) -> (32, 12, 1) * (1, 307) -> (32, 12, 307)
LHS = torch.matmul(torch.matmul(X.permute(0, 3, 2, 1), self._U1), self._U2) # (32, 12, 307)
        # RHS = right-hand-side embedding
        # to calculate it:
        # multiply U3 with X (F)(B, N, F, T) -> (B, N, T)
        # for example (1)(32, 307, 1, 12) -> (32, 307, 12)
RHS = torch.matmul(self._U3, X) # (32, 307, 12)
        # Then we multiply LHS with RHS:
# (B,T,N)(B,N,T)->(B,T,T)
# for example (32, 12, 307) * (32, 307, 12) -> (32, 12, 12)
# Then multiply Ve(T,T) with the output
# (T,T)(B, T, T)->(B,T,T)
# for example (12, 12) * (32, 12, 12) -> (32, 12, 12)
E = torch.matmul(self._Ve, torch.sigmoid(torch.matmul(LHS, RHS) + self._be))
E = F.softmax(E, dim=1) # (B, T, T) for example (32, 12, 12)
return E
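# A minimal usage sketch for TemporalAttention (not part of the original module).
# Same illustrative sizes as above; the module returns attention scores over the
# T_in time steps rather than over the vertices.
def _example_temporal_attention():
    x = torch.rand(2, 4, 8, 12)  # (B, N_nodes, F_in, T_in)
    attention = TemporalAttention(in_channels=8, num_of_vertices=4, num_of_timesteps=12)
    scores = attention(x)
    return scores  # (B, T_in, T_in) == (2, 12, 12)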
class ASTGCNBlock(nn.Module):
r"""An implementation of the Attention Based Spatial-Temporal Graph Convolutional Block.
For details see this paper: `"Attention Based Spatial-Temporal Graph Convolutional
Networks for Traffic Flow Forecasting." <https://ojs.aaai.org/index.php/AAAI/article/view/3881>`_
Args:
in_channels (int): Number of input features.
K (int): Order of Chebyshev polynomials. Degree is K-1.
nb_chev_filter (int): Number of Chebyshev filters.
nb_time_filter (int): Number of time filters.
time_strides (int): Time strides during temporal convolution.
num_of_vertices (int): Number of vertices in the graph.
num_of_timesteps (int): Number of time lags.
        normalization (str, optional): The normalization scheme for the graph
            Laplacian (default: :obj:`None`):
1. :obj:`None`: No normalization
:math:`\mathbf{L} = \mathbf{D} - \mathbf{A}`
2. :obj:`"sym"`: Symmetric normalization
:math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A}
\mathbf{D}^{-1/2}`
3. :obj:`"rw"`: Random-walk normalization
:math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}`
You need to pass :obj:`lambda_max` to the :meth:`forward` method of
this operator in case the normalization is non-symmetric.
            :obj:`lambda_max` should be a :class:`torch.Tensor` of size
:obj:`[num_graphs]` in a mini-batch scenario and a
scalar/zero-dimensional tensor when operating on single graphs.
You can pre-compute :obj:`lambda_max` via the
:class:`torch_geometric.transforms.LaplacianLambdaMax` transform.
bias (bool, optional): If set to :obj:`False`, the layer will not learn
an additive bias. (default: :obj:`True`)
"""
def __init__(
self,
in_channels: int,
K: int,
nb_chev_filter: int,
nb_time_filter: int,
time_strides: int,
num_of_vertices: int,
num_of_timesteps: int,
normalization: Optional[str] = None,
bias: bool = True,
):
super(ASTGCNBlock, self).__init__()
self._temporal_attention = TemporalAttention(
in_channels, num_of_vertices, num_of_timesteps
)
self._spatial_attention = SpatialAttention(
in_channels, num_of_vertices, num_of_timesteps
)
self._chebconv_attention = ChebConvAttention(
in_channels, nb_chev_filter, K, normalization, bias
)
self._time_convolution = nn.Conv2d(
nb_chev_filter,
nb_time_filter,
kernel_size=(1, 3),
stride=(1, time_strides),
padding=(0, 1),
)
self._residual_convolution = nn.Conv2d(
in_channels, nb_time_filter, kernel_size=(1, 1), stride=(1, time_strides)
)
self._layer_norm = nn.LayerNorm(nb_time_filter)
self._normalization = normalization
self._reset_parameters()
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
else:
nn.init.uniform_(p)
def forward(
self,
X: torch.FloatTensor,
edge_index: Union[torch.LongTensor, List[torch.LongTensor]],
) -> torch.FloatTensor:
"""
Making a forward pass with the ASTGCN block.
Arg types:
* **X** (PyTorch Float Tensor) - Node features for T time periods, with shape (B, N_nodes, F_in, T_in).
            * **edge_index** (PyTorch LongTensor) - Edge indices; can be a single tensor or a list of tensors (one per time step), depending on whether the edges change over time.
Return types:
* **X** (PyTorch Float Tensor) - Hidden state tensor for all nodes, with shape (B, N_nodes, nb_time_filter, T_out).
"""
        batch_size, num_of_vertices, num_of_features, num_of_timesteps = X.shape  # e.g. (32, 307, 1, 12)
        X_tilde = self._temporal_attention(X)  # (B, T, T), e.g. (32, 12, 12)
        # apply the temporal attention: reshaped X, e.g. (32, 307, 12) @ (32, 12, 12) -> (32, 307, 12)
        X_tilde = torch.matmul(X.reshape(batch_size, -1, num_of_timesteps), X_tilde)
X_tilde = X_tilde.reshape(
batch_size, num_of_vertices, num_of_features, num_of_timesteps
)
X_tilde = self._spatial_attention(X_tilde) # (B,N,N) for example (32, 307, 307)
if not isinstance(edge_index, list):
data = Data(
edge_index=edge_index, edge_attr=None, num_nodes=num_of_vertices
)
if self._normalization != "sym":
lambda_max = LaplacianLambdaMax()(data).lambda_max
else:
lambda_max = None
X_hat = []
for t in range(num_of_timesteps):
X_hat.append(
torch.unsqueeze(
self._chebconv_attention(
X[:, :, :, t], edge_index, X_tilde, lambda_max=lambda_max
),
-1,
)
)
X_hat = F.relu(torch.cat(X_hat, dim=-1))
else:
X_hat = []
for t in range(num_of_timesteps):
data = Data(
edge_index=edge_index[t], edge_attr=None, num_nodes=num_of_vertices
)
if self._normalization != "sym":
lambda_max = LaplacianLambdaMax()(data).lambda_max
else:
lambda_max = None
X_hat.append(
torch.unsqueeze(
self._chebconv_attention(
X[:, :, :, t], edge_index[t], X_tilde, lambda_max=lambda_max
),
-1,
)
)
X_hat = F.relu(torch.cat(X_hat, dim=-1))
        # (B, N, F, T) -permute-> (B, F, N, T), for example (32, 307, 64, 12) -> (32, 64, 307, 12)
        # then a convolution along the time axis is applied
        X_hat = self._time_convolution(X_hat.permute(0, 2, 1, 3))  # e.g. (32, 64, 307, 12)
        # residual shortcut: (B, N, F_in, T) -permute-> (B, F_in, N, T) -1x1 conv-> (B, F_out, N, T)
        X = self._residual_convolution(X.permute(0, 2, 1, 3))  # e.g. (32, 64, 307, 12)
        # add X + X_hat -> (32, 64, 307, 12) -permute-> (32, 12, 307, 64) -layer norm-> permute back -> (32, 307, 64, 12)
X = self._layer_norm(F.relu(X + X_hat).permute(0, 3, 2, 1))
X = X.permute(0, 2, 3, 1)
        return X  # (B, N, F_out, T), for example (32, 307, 64, 12)
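# A minimal usage sketch for ASTGCNBlock (not part of the original module). The
# hyper-parameters below (K=3, 16 Chebyshev and time filters, time_strides=1) are
# illustrative assumptions, not values prescribed by the paper.
def _example_astgcn_block():
    x = torch.rand(2, 4, 8, 12)  # (B, N_nodes, F_in, T_in)
    edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])  # a 4-node ring
    block = ASTGCNBlock(
        in_channels=8,
        K=3,
        nb_chev_filter=16,
        nb_time_filter=16,
        time_strides=1,
        num_of_vertices=4,
        num_of_timesteps=12,
        normalization="sym",
    )
    out = block(x, edge_index)
    return out  # (B, N_nodes, nb_time_filter, T_in // time_strides) == (2, 4, 16, 12)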
class ASTGCN(nn.Module):
r"""An implementation of the Attention Based Spatial-Temporal Graph Convolutional Cell.
For details see this paper: `"Attention Based Spatial-Temporal Graph Convolutional
Networks for Traffic Flow Forecasting." <https://ojs.aaai.org/index.php/AAAI/article/view/3881>`_
Args:
nb_block (int): Number of ASTGCN blocks in the model.
in_channels (int): Number of input features.
K (int): Order of Chebyshev polynomials. Degree is K-1.
        nb_chev_filter (int): Number of Chebyshev filters.
        nb_time_filter (int): Number of time filters.
        time_strides (int): Time strides during temporal convolution.
num_for_predict (int): Number of predictions to make in the future.
len_input (int): Length of the input sequence.
num_of_vertices (int): Number of vertices in the graph.
        normalization (str, optional): The normalization scheme for the graph
            Laplacian (default: :obj:`None`):
1. :obj:`None`: No normalization
:math:`\mathbf{L} = \mathbf{D} - \mathbf{A}`
2. :obj:`"sym"`: Symmetric normalization
:math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A}
\mathbf{D}^{-1/2}`
3. :obj:`"rw"`: Random-walk normalization
:math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}`
You need to pass :obj:`lambda_max` to the :meth:`forward` method of
this operator in case the normalization is non-symmetric.
            :obj:`lambda_max` should be a :class:`torch.Tensor` of size
:obj:`[num_graphs]` in a mini-batch scenario and a
scalar/zero-dimensional tensor when operating on single graphs.
You can pre-compute :obj:`lambda_max` via the
:class:`torch_geometric.transforms.LaplacianLambdaMax` transform.
bias (bool, optional): If set to :obj:`False`, the layer will not learn
an additive bias. (default: :obj:`True`)
"""
def __init__(
self,
nb_block: int,
in_channels: int,
K: int,
nb_chev_filter: int,
nb_time_filter: int,
time_strides: int,
num_for_predict: int,
len_input: int,
num_of_vertices: int,
normalization: Optional[str] = None,
bias: bool = True,
):
super(ASTGCN, self).__init__()
self._blocklist = nn.ModuleList(
[
ASTGCNBlock(
in_channels,
K,
nb_chev_filter,
nb_time_filter,
time_strides,
num_of_vertices,
len_input,
normalization,
bias,
)
]
)
self._blocklist.extend(
[
ASTGCNBlock(
nb_time_filter,
K,
nb_chev_filter,
nb_time_filter,
1,
num_of_vertices,
len_input // time_strides,
normalization,
bias,
)
for _ in range(nb_block - 1)
]
)
self._final_conv = nn.Conv2d(
int(len_input / time_strides),
num_for_predict,
kernel_size=(1, nb_time_filter),
)
self._reset_parameters()
def _reset_parameters(self):
"""
Resetting the parameters.
"""
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
else:
nn.init.uniform_(p)
    def forward(
self, X: torch.FloatTensor, edge_index: torch.LongTensor
) -> torch.FloatTensor:
"""
Making a forward pass.
Arg types:
* **X** (PyTorch FloatTensor) - Node features for T time periods, with shape (B, N_nodes, F_in, T_in).
            * **edge_index** (PyTorch LongTensor) - Edge indices; can be a single tensor or a list of tensors (one per time step), depending on whether the edges change over time.
Return types:
            * **X** (PyTorch FloatTensor) - Hidden state tensor for all nodes, with shape (B, N_nodes, T_out).
"""
for block in self._blocklist:
            # each block maps (B, N, F_in, T) -> (B, N, F_out, T), for example (32, 307, 1, 12) -> (32, 307, 64, 12)
            X = block(X, edge_index)
        # (B, N, F, T) -permute-> (B, T, N, F) -conv(1, F)-> (B, T_out, N, 1)
        # for example (32, 307, 64, 12) -permute-> (32, 12, 307, 64) -final_conv-> (32, 12, 307, 1)
        X = self._final_conv(X.permute(0, 3, 1, 2))
        # drop the trailing singleton dimension: (B, T_out, N, 1) -> (B, T_out, N)
        X = X[:, :, :, -1]  # for example (32, 12, 307)
        X = X.permute(0, 2, 1)  # (B, T_out, N) -> (B, N, T_out)
        return X  # (B, N, T_out), for example (32, 307, 12)
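# A minimal end-to-end sketch for ASTGCN (not part of the original module). It runs
# two blocks on a toy 4-node graph and predicts 6 future steps from 12 input steps;
# all sizes are illustrative assumptions.
def _example_astgcn():
    x = torch.rand(2, 4, 8, 12)  # (B, N_nodes, F_in, T_in)
    edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])  # a 4-node ring
    model = ASTGCN(
        nb_block=2,
        in_channels=8,
        K=3,
        nb_chev_filter=16,
        nb_time_filter=16,
        time_strides=1,
        num_for_predict=6,
        len_input=12,
        num_of_vertices=4,
        normalization="sym",
    )
    out = model(x, edge_index)
    return out  # (B, N_nodes, num_for_predict) == (2, 4, 6)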