Source code for torch_geometric_temporal.nn.attention.mtgnn

from __future__ import division

import numbers
from typing import Optional

import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F


class Linear(nn.Module):
    r"""An implementation of the linear layer, conducting 2D convolution.
    For details see this paper: `"Connecting the Dots: Multivariate Time Series Forecasting with Graph Neural Networks."
    <https://arxiv.org/pdf/2005.11650.pdf>`_

    Args:
        c_in (int): Number of input channels.
        c_out (int): Number of output channels.
        bias (bool, optional): Whether to have bias. Default: True.
    """

    def __init__(self, c_in: int, c_out: int, bias: bool = True):
        super(Linear, self).__init__()
        self._mlp = torch.nn.Conv2d(
            c_in, c_out, kernel_size=(1, 1), padding=(0, 0), stride=(1, 1), bias=bias
        )

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
            else:
                nn.init.uniform_(p)

    def forward(self, X: torch.FloatTensor) -> torch.FloatTensor:
        """
        Making a forward pass of the linear layer.

        Arg types:
            * **X** (PyTorch Float Tensor) - Input tensor, with shape (batch_size, c_in, num_nodes, seq_len).

        Return types:
            * **X** (PyTorch Float Tensor) - Output tensor, with shape (batch_size, c_out, num_nodes, seq_len).
        """
        return self._mlp(X)
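
# Illustrative usage sketch (not part of the library source): a minimal smoke
# test for the Linear block above. The function name, shapes and channel sizes
# are hypothetical and chosen only for demonstration.
def _example_linear_usage():
    # Batch of 8 sequences, 16 input channels, 20 nodes, 12 time steps.
    X = torch.randn(8, 16, 20, 12)
    linear = Linear(c_in=16, c_out=32)
    out = linear(X)  # the 1x1 convolution mixes channels only; nodes and time are untouched
    assert out.shape == (8, 32, 20, 12)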


class MixProp(nn.Module):
    r"""An implementation of the dynatic mix-hop propagation layer.
    For details see this paper: `"Connecting the Dots: Multivariate Time Series Forecasting with Graph Neural Networks."
    <https://arxiv.org/pdf/2005.11650.pdf>`_

    Args:
        c_in (int): Number of input channels.
        c_out (int): Number of output channels.
        gdep (int): Depth of graph convolution.
        dropout (float): Dropout rate.
        alpha (float): Ratio of retaining the root node's original state, a value between 0 and 1.
    """

    def __init__(self, c_in: int, c_out: int, gdep: int, dropout: float, alpha: float):
        super(MixProp, self).__init__()
        self._mlp = Linear((gdep + 1) * c_in, c_out)
        self._gdep = gdep
        self._dropout = dropout
        self._alpha = alpha

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
            else:
                nn.init.uniform_(p)

    def forward(self, X: torch.FloatTensor, A: torch.FloatTensor) -> torch.FloatTensor:
        """
        Making a forward pass of mix-hop propagation.

        Arg types:
            * **X** (PyTorch Float Tensor) - Input feature tensor, with shape (batch_size, c_in, num_nodes, seq_len).
            * **A** (PyTorch Float Tensor) - Adjacency matrix, with shape (num_nodes, num_nodes).

        Return types:
            * **H_0** (PyTorch Float Tensor) - Hidden representation for all nodes, with shape (batch_size, c_out, num_nodes, seq_len).
        """
        A = A + torch.eye(A.size(0)).to(X.device)
        d = A.sum(1)
        H = X
        H_0 = X
        A = A / d.view(-1, 1)
        for _ in range(self._gdep):
            H = self._alpha * X + (1 - self._alpha) * torch.einsum(
                "ncwl,vw->ncvl", (H, A)
            )
            H_0 = torch.cat((H_0, H), dim=1)
        H_0 = self._mlp(H_0)
        return H_0
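
# Illustrative usage sketch (not part of the library source): mix-hop
# propagation over a small random graph. The adjacency matrix, shapes and
# hyperparameter values below are hypothetical, chosen only for demonstration.
def _example_mixprop_usage():
    num_nodes = 20
    X = torch.randn(8, 16, num_nodes, 12)  # (batch, c_in, nodes, time)
    A = torch.rand(num_nodes, num_nodes)   # dense non-negative adjacency matrix
    mixprop = MixProp(c_in=16, c_out=32, gdep=2, dropout=0.3, alpha=0.05)
    H = mixprop(X, A)
    # Hidden states from hops 0..gdep are concatenated along the channel
    # dimension and mixed back down to c_out by the 1x1 convolution.
    assert H.shape == (8, 32, num_nodes, 12)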


class DilatedInception(nn.Module):
    r"""An implementation of the dilated inception layer.
    For details see this paper: `"Connecting the Dots: Multivariate Time Series Forecasting with Graph Neural Networks."
    <https://arxiv.org/pdf/2005.11650.pdf>`_

    Args:
        c_in (int): Number of input channels.
        c_out (int): Number of output channels.
        kernel_set (list of int): List of kernel sizes.
        dilation_factor (int): Dilation factor.
    """

    def __init__(self, c_in: int, c_out: int, kernel_set: list, dilation_factor: int):
        super(DilatedInception, self).__init__()
        self._time_conv = nn.ModuleList()
        self._kernel_set = kernel_set
        c_out = int(c_out / len(self._kernel_set))
        for kern in self._kernel_set:
            self._time_conv.append(
                nn.Conv2d(c_in, c_out, (1, kern), dilation=(1, dilation_factor))
            )
        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
            else:
                nn.init.uniform_(p)

    def forward(self, X_in: torch.FloatTensor) -> torch.FloatTensor:
        """
        Making a forward pass of dilated inception.

        Arg types:
            * **X_in** (PyTorch Float Tensor) - Input feature tensor, with shape (batch_size, c_in, num_nodes, seq_len).

        Return types:
            * **X** (PyTorch Float Tensor) - Hidden representation for all nodes,
            with shape (batch_size, c_out, num_nodes, seq_len - dilation_factor * (max(kernel_set) - 1)).
        """
        X = []
        for i in range(len(self._kernel_set)):
            X.append(self._time_conv[i](X_in))
        for i in range(len(self._kernel_set)):
            X[i] = X[i][..., -X[-1].size(3) :]
        X = torch.cat(X, dim=1)
        return X
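
# Illustrative usage sketch (not part of the library source): the dilated
# inception block splits c_out evenly across the kernel set and truncates
# every branch to the length produced by the largest kernel. The kernel set
# matches the paper's default; the other values are hypothetical.
def _example_dilated_inception_usage():
    kernel_set = [2, 3, 6, 7]
    X_in = torch.randn(8, 16, 20, 19)  # (batch, c_in, nodes, time)
    conv = DilatedInception(c_in=16, c_out=32, kernel_set=kernel_set, dilation_factor=1)
    X = conv(X_in)
    # With dilation 1 the time dimension shrinks by max(kernel_set) - 1 = 6.
    assert X.shape == (8, 32, 20, 19 - 6)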


class GraphConstructor(nn.Module):
    r"""An implementation of the graph learning layer to construct an adjacency matrix.
    For details see this paper: `"Connecting the Dots: Multivariate Time Series Forecasting with Graph Neural Networks."
    <https://arxiv.org/pdf/2005.11650.pdf>`_

    Args:
        nnodes (int): Number of nodes in the graph.
        k (int): Number of largest values to consider in constructing the
            neighbourhood of a node (pick the "nearest" k nodes).
        dim (int): Dimension of the node embedding.
        alpha (float): Tanh alpha for generating the adjacency matrix; alpha controls the saturation rate.
        xd (int, optional): Static feature dimension, default None.
    """

    def __init__(
        self, nnodes: int, k: int, dim: int, alpha: float, xd: Optional[int] = None
    ):
        super(GraphConstructor, self).__init__()
        if xd is not None:
            self._static_feature_dim = xd
            self._linear1 = nn.Linear(xd, dim)
            self._linear2 = nn.Linear(xd, dim)
        else:
            self._embedding1 = nn.Embedding(nnodes, dim)
            self._embedding2 = nn.Embedding(nnodes, dim)
            self._linear1 = nn.Linear(dim, dim)
            self._linear2 = nn.Linear(dim, dim)

        self._k = k
        self._alpha = alpha

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
            else:
                nn.init.uniform_(p)

    def forward(
        self, idx: torch.LongTensor, FE: Optional[torch.FloatTensor] = None
    ) -> torch.FloatTensor:
        """
        Making a forward pass to construct an adjacency matrix from node embeddings.

        Arg types:
            * **idx** (PyTorch Long Tensor) - Input indices, a permutation of the nodes.
            * **FE** (PyTorch Float Tensor, optional) - Static feature, default None.

        Return types:
            * **A** (PyTorch Float Tensor) - Adjacency matrix constructed from node embeddings.
        """
        if FE is None:
            nodevec1 = self._embedding1(idx)
            nodevec2 = self._embedding2(idx)
        else:
            assert FE.shape[1] == self._static_feature_dim
            nodevec1 = FE[idx, :]
            nodevec2 = nodevec1

        nodevec1 = torch.tanh(self._alpha * self._linear1(nodevec1))
        nodevec2 = torch.tanh(self._alpha * self._linear2(nodevec2))

        a = torch.mm(nodevec1, nodevec2.transpose(1, 0)) - torch.mm(
            nodevec2, nodevec1.transpose(1, 0)
        )
        A = F.relu(torch.tanh(self._alpha * a))
        mask = torch.zeros(idx.size(0), idx.size(0)).to(A.device)
        mask.fill_(float("0"))
        # Keep only the k largest scores in each row and binarize the mask.
        s1, t1 = A.topk(self._k, 1)
        mask.scatter_(1, t1, s1.fill_(1))
        A = A * mask
        return A
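

# Illustrative usage sketch (not part of the library source): learning a
# sparse directed adjacency matrix from node embeddings. The sizes and alpha
# value below are hypothetical, chosen only for demonstration.
def _example_graph_constructor_usage():
    num_nodes, k, dim = 20, 5, 40
    constructor = GraphConstructor(nnodes=num_nodes, k=k, dim=dim, alpha=3.0)
    idx = torch.arange(num_nodes)
    A = constructor(idx)
    # Each row keeps at most its k largest scores; everything else is masked to zero.
    assert A.shape == (num_nodes, num_nodes)
    assert int((A > 0).sum(dim=1).max()) <= k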


class LayerNormalization(nn.Module):
    r"""An implementation of the layer normalization layer.
    For details see this paper: `"Connecting the Dots: Multivariate Time Series Forecasting with Graph Neural Networks."
    <https://arxiv.org/pdf/2005.11650.pdf>`_

    Args:
        normalized_shape (tuple of int): Expected input shape over which to normalize.
        eps (float, optional): Value added to the denominator for numerical stability. Default: 1e-5.
        elementwise_affine (bool, optional): Whether to conduct elementwise affine transformation or not. Default: True.
    """

    __constants__ = ["normalized_shape", "weight", "bias", "eps", "elementwise_affine"]

    def __init__(
        self, normalized_shape: tuple, eps: float = 1e-5, elementwise_affine: bool = True
    ):
        super(LayerNormalization, self).__init__()
        self._normalized_shape = tuple(normalized_shape)
        self._eps = eps
        self._elementwise_affine = elementwise_affine
        if self._elementwise_affine:
            self._weight = nn.Parameter(torch.Tensor(*normalized_shape))
            self._bias = nn.Parameter(torch.Tensor(*normalized_shape))
        else:
            self.register_parameter("_weight", None)
            self.register_parameter("_bias", None)
        self._reset_parameters()

    def _reset_parameters(self):
        if self._elementwise_affine:
            init.ones_(self._weight)
            init.zeros_(self._bias)

    def forward(self, X: torch.FloatTensor, idx: torch.LongTensor) -> torch.FloatTensor:
        """
        Making a forward pass of layer normalization.

        Arg types:
            * **X** (PyTorch Float Tensor) - Input tensor, with shape (batch_size, feature_dim, num_nodes, seq_len).
            * **idx** (PyTorch Long Tensor) - Input indices.

        Return types:
            * **X** (PyTorch Float Tensor) - Output tensor, with shape (batch_size, feature_dim, num_nodes, seq_len).
        """
        if self._elementwise_affine:
            return F.layer_norm(
                X,
                tuple(X.shape[1:]),
                self._weight[:, idx, :],
                self._bias[:, idx, :],
                self._eps,
            )
        else:
            return F.layer_norm(
                X, tuple(X.shape[1:]), self._weight, self._bias, self._eps
            )


class MTGNNLayer(nn.Module):
    r"""An implementation of the MTGNN layer.
    For details see this paper: `"Connecting the Dots: Multivariate Time Series Forecasting with Graph Neural Networks."
    <https://arxiv.org/pdf/2005.11650.pdf>`_

    Args:
        dilation_exponential (int): Dilation exponential.
        rf_size_i (int): Size of receptive field.
        kernel_size (int): Size of kernel for convolution, to calculate receptive field size.
        j (int): Iteration index.
        residual_channels (int): Residual channels.
        conv_channels (int): Convolution channels.
        skip_channels (int): Skip channels.
        kernel_set (list of int): List of kernel sizes.
        new_dilation (int): Dilation.
        layer_norm_affline (bool): Whether to do elementwise affine in Layer Normalization.
        gcn_true (bool): Whether to add graph convolution layer.
        seq_length (int): Length of input sequence.
        receptive_field (int): Receptive field.
        dropout (float): Dropout rate.
        gcn_depth (int): Graph convolution depth.
        num_nodes (int): Number of nodes in the graph.
        propalpha (float): Prop alpha, ratio of retaining the root node's original state in mix-hop propagation, a value between 0 and 1.
""" def __init__( self, dilation_exponential: int, rf_size_i: int, kernel_size: int, j: int, residual_channels: int, conv_channels: int, skip_channels: int, kernel_set: list, new_dilation: int, layer_norm_affline: bool, gcn_true: bool, seq_length: int, receptive_field: int, dropout: float, gcn_depth: int, num_nodes: int, propalpha: float, ): super(MTGNNLayer, self).__init__() self._dropout = dropout self._gcn_true = gcn_true if dilation_exponential > 1: rf_size_j = int( rf_size_i + (kernel_size - 1) * (dilation_exponential ** j - 1) / (dilation_exponential - 1) ) else: rf_size_j = rf_size_i + j * (kernel_size - 1) self._filter_conv = DilatedInception( residual_channels, conv_channels, kernel_set=kernel_set, dilation_factor=new_dilation, ) self._gate_conv = DilatedInception( residual_channels, conv_channels, kernel_set=kernel_set, dilation_factor=new_dilation, ) self._residual_conv = nn.Conv2d( in_channels=conv_channels, out_channels=residual_channels, kernel_size=(1, 1), ) if seq_length > receptive_field: self._skip_conv = nn.Conv2d( in_channels=conv_channels, out_channels=skip_channels, kernel_size=(1, seq_length - rf_size_j + 1), ) else: self._skip_conv = nn.Conv2d( in_channels=conv_channels, out_channels=skip_channels, kernel_size=(1, receptive_field - rf_size_j + 1), ) if gcn_true: self._mixprop_conv1 = MixProp( conv_channels, residual_channels, gcn_depth, dropout, propalpha ) self._mixprop_conv2 = MixProp( conv_channels, residual_channels, gcn_depth, dropout, propalpha ) if seq_length > receptive_field: self._normalization = LayerNormalization( (residual_channels, num_nodes, seq_length - rf_size_j + 1), elementwise_affine=layer_norm_affline, ) else: self._normalization = LayerNormalization( (residual_channels, num_nodes, receptive_field - rf_size_j + 1), elementwise_affine=layer_norm_affline, ) self._reset_parameters() def _reset_parameters(self): for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) else: nn.init.uniform_(p) def forward( self, X: torch.FloatTensor, X_skip: torch.FloatTensor, A_tilde: Optional[torch.FloatTensor], idx: torch.LongTensor, training: bool, ) -> torch.FloatTensor: """ Making a forward pass of MTGNN layer. Arg types: * **X** (PyTorch FloatTensor) - Input feature tensor, with shape (batch_size, in_dim, num_nodes, seq_len). * **X_skip** (PyTorch FloatTensor) - Input feature tensor for skip connection, with shape (batch_size, in_dim, num_nodes, seq_len). * **A_tilde** (Pytorch FloatTensor or None) - Predefined adjacency matrix. * **idx** (Pytorch LongTensor) - Input indices. * **training** (bool) - Whether in traning mode. Return types: * **X** (PyTorch FloatTensor) - Output sequence tensor, with shape (batch_size, seq_len, num_nodes, seq_len). * **X_skip** (PyTorch FloatTensor) - Output feature tensor for skip connection, with shape (batch_size, in_dim, num_nodes, seq_len). """ X_residual = X X_filter = self._filter_conv(X) X_filter = torch.tanh(X_filter) X_gate = self._gate_conv(X) X_gate = torch.sigmoid(X_gate) X = X_filter * X_gate X = F.dropout(X, self._dropout, training=training) X_skip = self._skip_conv(X) + X_skip if self._gcn_true: X = self._mixprop_conv1(X, A_tilde) + self._mixprop_conv2( X, A_tilde.transpose(1, 0) ) else: X = self._residual_conv(X) X = X + X_residual[:, :, :, -X.size(3) :] X = self._normalization(X, idx) return X, X_skip


class MTGNN(nn.Module):
    r"""An implementation of the Multivariate Time Series Forecasting Graph Neural Networks.
    For details see this paper: `"Connecting the Dots: Multivariate Time Series Forecasting with Graph Neural Networks."
    <https://arxiv.org/pdf/2005.11650.pdf>`_

    Args:
        gcn_true (bool): Whether to add graph convolution layer.
        build_adj (bool): Whether to construct adaptive adjacency matrix.
        gcn_depth (int): Graph convolution depth.
        num_nodes (int): Number of nodes in the graph.
        kernel_set (list of int): List of kernel sizes.
        kernel_size (int): Size of kernel for convolution, to calculate receptive field size.
        dropout (float): Dropout rate.
        subgraph_size (int): Size of subgraph.
        node_dim (int): Dimension of nodes.
        dilation_exponential (int): Dilation exponential.
        conv_channels (int): Convolution channels.
        residual_channels (int): Residual channels.
        skip_channels (int): Skip channels.
        end_channels (int): End channels.
        seq_length (int): Length of input sequence.
        in_dim (int): Input dimension.
        out_dim (int): Output dimension.
        layers (int): Number of layers.
        propalpha (float): Prop alpha, ratio of retaining the root node's original state in mix-hop propagation, a value between 0 and 1.
        tanhalpha (float): Tanh alpha for generating adjacency matrix, alpha controls the saturation rate.
        layer_norm_affline (bool): Whether to do elementwise affine in Layer Normalization.
        xd (int, optional): Static feature dimension, default None.
    """

    def __init__(
        self,
        gcn_true: bool,
        build_adj: bool,
        gcn_depth: int,
        num_nodes: int,
        kernel_set: list,
        kernel_size: int,
        dropout: float,
        subgraph_size: int,
        node_dim: int,
        dilation_exponential: int,
        conv_channels: int,
        residual_channels: int,
        skip_channels: int,
        end_channels: int,
        seq_length: int,
        in_dim: int,
        out_dim: int,
        layers: int,
        propalpha: float,
        tanhalpha: float,
        layer_norm_affline: bool,
        xd: Optional[int] = None,
    ):
        super(MTGNN, self).__init__()

        self._gcn_true = gcn_true
        self._build_adj_true = build_adj
        self._num_nodes = num_nodes
        self._dropout = dropout
        self._seq_length = seq_length
        self._layers = layers
        self._idx = torch.arange(self._num_nodes)

        self._mtgnn_layers = nn.ModuleList()

        self._graph_constructor = GraphConstructor(
            num_nodes, subgraph_size, node_dim, alpha=tanhalpha, xd=xd
        )

        self._set_receptive_field(dilation_exponential, kernel_size, layers)

        new_dilation = 1
        for j in range(1, layers + 1):
            self._mtgnn_layers.append(
                MTGNNLayer(
                    dilation_exponential=dilation_exponential,
                    rf_size_i=1,
                    kernel_size=kernel_size,
                    j=j,
                    residual_channels=residual_channels,
                    conv_channels=conv_channels,
                    skip_channels=skip_channels,
                    kernel_set=kernel_set,
                    new_dilation=new_dilation,
                    layer_norm_affline=layer_norm_affline,
                    gcn_true=gcn_true,
                    seq_length=seq_length,
                    receptive_field=self._receptive_field,
                    dropout=dropout,
                    gcn_depth=gcn_depth,
                    num_nodes=num_nodes,
                    propalpha=propalpha,
                )
            )

            new_dilation *= dilation_exponential

        self._setup_conv(
            in_dim, skip_channels, end_channels, residual_channels, out_dim
        )

        self._reset_parameters()

    def _setup_conv(
        self, in_dim, skip_channels, end_channels, residual_channels, out_dim
    ):
        self._start_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=residual_channels, kernel_size=(1, 1)
        )

        if self._seq_length > self._receptive_field:
            self._skip_conv_0 = nn.Conv2d(
                in_channels=in_dim,
                out_channels=skip_channels,
                kernel_size=(1, self._seq_length),
                bias=True,
            )

            self._skip_conv_E = nn.Conv2d(
                in_channels=residual_channels,
                out_channels=skip_channels,
                kernel_size=(1, self._seq_length - self._receptive_field + 1),
                bias=True,
            )

        else:
            self._skip_conv_0 = nn.Conv2d(
                in_channels=in_dim,
                out_channels=skip_channels,
                kernel_size=(1, self._receptive_field),
                bias=True,
            )

            self._skip_conv_E = nn.Conv2d(
                in_channels=residual_channels,
                out_channels=skip_channels,
                kernel_size=(1, 1),
                bias=True,
            )

        self._end_conv_1 = nn.Conv2d(
            in_channels=skip_channels,
            out_channels=end_channels,
            kernel_size=(1, 1),
            bias=True,
        )

        self._end_conv_2 = nn.Conv2d(
            in_channels=end_channels,
            out_channels=out_dim,
            kernel_size=(1, 1),
            bias=True,
        )

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
            else:
                nn.init.uniform_(p)

    def _set_receptive_field(self, dilation_exponential, kernel_size, layers):
        if dilation_exponential > 1:
            self._receptive_field = int(
                1
                + (kernel_size - 1)
                * (dilation_exponential ** layers - 1)
                / (dilation_exponential - 1)
            )
        else:
            self._receptive_field = layers * (kernel_size - 1) + 1

    def forward(
        self,
        X_in: torch.FloatTensor,
        A_tilde: Optional[torch.FloatTensor] = None,
        idx: Optional[torch.LongTensor] = None,
        FE: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """
        Making a forward pass of MTGNN.

        Arg types:
            * **X_in** (PyTorch FloatTensor) - Input sequence, with shape (batch_size, in_dim, num_nodes, seq_len).
            * **A_tilde** (PyTorch FloatTensor, optional) - Predefined adjacency matrix, default None.
            * **idx** (PyTorch LongTensor, optional) - Input indices, a permutation of the nodes, default None (no permutation).
            * **FE** (PyTorch FloatTensor, optional) - Static feature, default None.

        Return types:
            * **X** (PyTorch FloatTensor) - Output sequence for prediction, with shape (batch_size, out_dim, num_nodes, 1).
        """
        seq_len = X_in.size(3)
        assert (
            seq_len == self._seq_length
        ), "Input sequence length not equal to preset sequence length."

        if self._seq_length < self._receptive_field:
            X_in = nn.functional.pad(
                X_in, (self._receptive_field - self._seq_length, 0, 0, 0)
            )

        if self._gcn_true:
            if self._build_adj_true:
                if idx is None:
                    A_tilde = self._graph_constructor(self._idx.to(X_in.device), FE=FE)
                else:
                    A_tilde = self._graph_constructor(idx, FE=FE)

        X = self._start_conv(X_in)
        X_skip = self._skip_conv_0(
            F.dropout(X_in, self._dropout, training=self.training)
        )
        if idx is None:
            for mtgnn in self._mtgnn_layers:
                X, X_skip = mtgnn(
                    X, X_skip, A_tilde, self._idx.to(X_in.device), self.training
                )
        else:
            for mtgnn in self._mtgnn_layers:
                X, X_skip = mtgnn(X, X_skip, A_tilde, idx, self.training)

        X_skip = self._skip_conv_E(X) + X_skip
        X = F.relu(X_skip)
        X = F.relu(self._end_conv_1(X))
        X = self._end_conv_2(X)
        return X
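

# Illustrative usage sketch (not part of the library source): an end-to-end
# forward pass through MTGNN with a self-learned adjacency matrix. All
# hyperparameter values below are hypothetical and picked only so the shapes
# work out; they are not recommended settings.
def _example_mtgnn_usage():
    num_nodes, seq_length, in_dim, out_dim = 20, 24, 2, 12
    model = MTGNN(
        gcn_true=True,
        build_adj=True,
        gcn_depth=2,
        num_nodes=num_nodes,
        kernel_set=[2, 3, 6, 7],
        kernel_size=7,
        dropout=0.3,
        subgraph_size=5,
        node_dim=40,
        dilation_exponential=1,
        conv_channels=16,
        residual_channels=16,
        skip_channels=32,
        end_channels=64,
        seq_length=seq_length,
        in_dim=in_dim,
        out_dim=out_dim,
        layers=3,
        propalpha=0.05,
        tanhalpha=3.0,
        layer_norm_affline=True,
    )
    X_in = torch.randn(8, in_dim, num_nodes, seq_length)  # (batch, in_dim, nodes, time)
    X_out = model(X_in)
    # The skip connections collapse the time dimension, so the prediction has
    # one value per output step (out_dim) and node.
    assert X_out.shape == (8, out_dim, num_nodes, 1)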