Source code for torch_geometric_temporal.nn.attention.tsagcn

import math
import torch
import numpy as np
import torch.nn as nn
from torch.autograd import Variable
from torch_geometric.utils import to_dense_adj
import torch.nn.functional as F


class GraphAAGCN:
    r"""
    Defining the Graph for the Two-Stream Adaptive Graph Convolutional Network.
    It is composed of the normalized self-links, inward links and
    outward links between the nodes, as originally defined in the
    `authors' repo <https://github.com/lshiwjx/2s-AGCN/blob/master/graph/tools.py>`_,
    resulting in the shape of (3, num_nodes, num_nodes).
    Args:
        edge_index (PyTorch LongTensor): Edge indices.
        num_nodes (int): Number of nodes.
    Return types:
            * **A** (PyTorch FloatTensor) - Stack of three normalized
              adjacency matrices (self-links, inward links, outward links).
    """

    def __init__(self, edge_index: torch.LongTensor, num_nodes: int):
        self.num_nodes = num_nodes
        self.edge_index = edge_index
        self.A = self.get_spatial_graph(self.num_nodes)

    def get_spatial_graph(self, num_nodes):
        self_mat = torch.eye(num_nodes)
        inward_mat = torch.squeeze(to_dense_adj(self.edge_index))
        inward_mat_norm = F.normalize(inward_mat, dim=0, p=1)
        outward_mat = inward_mat.transpose(0, 1)
        outward_mat_norm = F.normalize(outward_mat, dim=0, p=1)
        adj_mat = torch.stack((self_mat, inward_mat_norm, outward_mat_norm))
        return adj_mat
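
# A minimal usage sketch (illustrative, not part of the original module):
# building the (3, num_nodes, num_nodes) adjacency stack for a toy 4-node
# chain graph. Note that ``to_dense_adj`` infers the node count from
# ``edge_index``, so every node should appear in at least one edge for the
# dense matrices to match the identity matrix in shape.
def _demo_graph_aagcn():
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])  # directed chain 0->1->2->3
    graph = GraphAAGCN(edge_index, num_nodes=4)
    # One normalized matrix per subset: self-links, inward links, outward links.
    assert graph.A.shape == (3, 4, 4)
    return graph.A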


class UnitTCN(nn.Module):
    r"""
    Temporal Convolutional Block applied to nodes in the Two-Stream Adaptive
    Graph Convolutional Network as originally implemented in the
    `Github Repo <https://github.com/lshiwjx/2s-AGCN>`_. For implementational
    details see https://arxiv.org/abs/1805.07694.
    Args:
        in_channels (int): Number of input features.
        out_channels (int): Number of output features.
        kernel_size (int): Convolutional kernel size. (default: :obj:`9`)
        stride (int): Temporal Convolutional kernel stride. (default: :obj:`1`)
    """

    def __init__(
        self, in_channels: int, out_channels: int, kernel_size: int = 9, stride: int = 1
    ):
        super(UnitTCN, self).__init__()
        pad = int((kernel_size - 1) / 2)
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(kernel_size, 1),
            padding=(pad, 0),
            stride=(stride, 1),
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self._conv_init(self.conv)
        self._bn_init(self.bn, 1)

    def _bn_init(self, bn, scale):
        nn.init.constant_(bn.weight, scale)
        nn.init.constant_(bn.bias, 0)

    def _conv_init(self, conv):
        nn.init.kaiming_normal_(conv.weight, mode="fan_out")
        nn.init.constant_(conv.bias, 0)
    def forward(self, x):
        x = self.bn(self.conv(x))
        return x
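
# A minimal shape sketch (illustrative, not part of the original module):
# UnitTCN convolves along the temporal axis only, so an input of shape
# (batch, channels, frames, nodes) maps to
# (batch, out_channels, frames // stride, nodes).
def _demo_unit_tcn():
    x = torch.randn(8, 3, 20, 25)  # (N, C, T, V)
    tcn = UnitTCN(in_channels=3, out_channels=16, stride=2)
    assert tcn(x).shape == (8, 16, 10, 25)  # time halved, node axis untouched
    return tcn(x)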
class UnitGCN(nn.Module):
    r"""
    Graph Convolutional Block applied to nodes in the Two-Stream Adaptive
    Graph Convolutional Network as originally implemented in the
    `Github Repo <https://github.com/lshiwjx/2s-AGCN>`_. For implementational
    details see https://arxiv.org/abs/1805.07694.
    Temporal attention, spatial attention and channel-wise attention will be
    applied.
    Args:
        in_channels (int): Number of input features.
        out_channels (int): Number of output features.
        A (Tensor array): Adaptive Graph.
        coff_embedding (int, optional): Coefficient Embeddings. (default: :obj:`4`)
        num_subset (int, optional): Subsets for adaptive graphs, see
            :math:`\mathbf{A}, \mathbf{B}, \mathbf{C}` in
            https://arxiv.org/abs/1805.07694 for details. (default: :obj:`3`)
        adaptive (bool, optional): Apply Adaptive Graph Convolutions. (default: :obj:`True`)
        attention (bool, optional): Apply Attention. (default: :obj:`True`)
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        A: torch.FloatTensor,
        coff_embedding: int = 4,
        num_subset: int = 3,
        adaptive: bool = True,
        attention: bool = True,
    ):
        super(UnitGCN, self).__init__()
        self.inter_c = out_channels // coff_embedding
        self.out_c = out_channels
        self.in_c = in_channels
        self.num_subset = num_subset
        self.A = A
        self.num_jpts = A.shape[-1]
        self.attention = attention
        self.adaptive = adaptive

        self.conv_d = nn.ModuleList()
        for i in range(self.num_subset):
            self.conv_d.append(nn.Conv2d(in_channels, out_channels, 1))

        if self.adaptive:
            self._init_adaptive_layers()
        else:
            self.A = Variable(self.A, requires_grad=False)

        if self.attention:
            self._init_attention_layers()

        if in_channels != out_channels:
            self.down = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1), nn.BatchNorm2d(out_channels)
            )
        else:
            self.down = lambda x: x

        self.bn = nn.BatchNorm2d(out_channels)
        self.soft = nn.Softmax(-2)
        self.tan = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU(inplace=True)

        self._init_conv_bn()

    def _bn_init(self, bn, scale):
        nn.init.constant_(bn.weight, scale)
        nn.init.constant_(bn.bias, 0)

    def _conv_init(self, conv):
        nn.init.kaiming_normal_(conv.weight, mode="fan_out")
        nn.init.constant_(conv.bias, 0)

    def _conv_branch_init(self, conv, branches):
        weight = conv.weight
        n = weight.size(0)
        k1 = weight.size(1)
        k2 = weight.size(2)
        nn.init.normal_(weight, 0, math.sqrt(2.0 / (n * k1 * k2 * branches)))
        nn.init.constant_(conv.bias, 0)

    def _init_conv_bn(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                self._conv_init(m)
            elif isinstance(m, nn.BatchNorm2d):
                self._bn_init(m, 1)
        self._bn_init(self.bn, 1e-6)
        for i in range(self.num_subset):
            self._conv_branch_init(self.conv_d[i], self.num_subset)

    def _init_attention_layers(self):
        # temporal attention
        self.conv_ta = nn.Conv1d(self.out_c, 1, 9, padding=4)
        nn.init.constant_(self.conv_ta.weight, 0)
        nn.init.constant_(self.conv_ta.bias, 0)

        # spatial attention (kernel size must be odd to keep the node axis)
        ker_jpt = self.num_jpts - 1 if not self.num_jpts % 2 else self.num_jpts
        pad = (ker_jpt - 1) // 2
        self.conv_sa = nn.Conv1d(self.out_c, 1, ker_jpt, padding=pad)
        nn.init.xavier_normal_(self.conv_sa.weight)
        nn.init.constant_(self.conv_sa.bias, 0)

        # channel attention
        rr = 2
        self.fc1c = nn.Linear(self.out_c, self.out_c // rr)
        self.fc2c = nn.Linear(self.out_c // rr, self.out_c)
        nn.init.kaiming_normal_(self.fc1c.weight)
        nn.init.constant_(self.fc1c.bias, 0)
        nn.init.constant_(self.fc2c.weight, 0)
        nn.init.constant_(self.fc2c.bias, 0)

    def _init_adaptive_layers(self):
        self.PA = nn.Parameter(self.A)
        self.alpha = nn.Parameter(torch.zeros(1))
        self.conv_a = nn.ModuleList()
        self.conv_b = nn.ModuleList()
        for i in range(self.num_subset):
            self.conv_a.append(nn.Conv2d(self.in_c, self.inter_c, 1))
            self.conv_b.append(nn.Conv2d(self.in_c, self.inter_c, 1))

    def _attentive_forward(self, y):
        # spatial attention
        se = y.mean(-2)  # N C V
        se1 = self.sigmoid(self.conv_sa(se))
        y = y * se1.unsqueeze(-2) + y

        # temporal attention
        se = y.mean(-1)
        se1 = self.sigmoid(self.conv_ta(se))
        y = y * se1.unsqueeze(-1) + y

        # channel attention
        se = y.mean(-1).mean(-1)
        se1 = self.relu(self.fc1c(se))
        se2 = self.sigmoid(self.fc2c(se1))
        y = y * se2.unsqueeze(-1).unsqueeze(-1) + y
        return y

    def _adaptive_forward(self, x, y):
        N, C, T, V = x.size()

        A = self.PA
        for i in range(self.num_subset):
            # data-dependent graph: normalized similarity between node embeddings
            A1 = (
                self.conv_a[i](x)
                .permute(0, 3, 1, 2)
                .contiguous()
                .view(N, V, self.inter_c * T)
            )
            A2 = self.conv_b[i](x).view(N, self.inter_c * T, V)
            A1 = self.tan(torch.matmul(A1, A2) / A1.size(-1))  # N V V
            A1 = A[i] + A1 * self.alpha
            A2 = x.view(N, C * T, V)
            z = self.conv_d[i](torch.matmul(A2, A1).view(N, C, T, V))
            y = z + y if y is not None else z
        return y

    def _non_adaptive_forward(self, x, y):
        N, C, T, V = x.size()
        for i in range(self.num_subset):
            A1 = self.A[i]
            A2 = x.view(N, C * T, V)
            z = self.conv_d[i](torch.matmul(A2, A1).view(N, C, T, V))
            y = z + y if y is not None else z
        return y
    def forward(self, x):
        N, C, T, V = x.size()

        y = None
        if self.adaptive:
            y = self._adaptive_forward(x, y)
        else:
            y = self._non_adaptive_forward(x, y)
        y = self.bn(y)
        y += self.down(x)
        y = self.relu(y)
        if self.attention:
            y = self._attentive_forward(y)
        return y
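
# A minimal usage sketch (illustrative, not part of the original module):
# a UnitGCN block driven by the (3, V, V) adjacency stack from GraphAAGCN.
# With adaptive=True the learned offsets in PA and the embedded similarities
# from conv_a/conv_b refine each subset matrix; attention=True then applies
# spatial, temporal and channel attention to the output.
def _demo_unit_gcn():
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
    A = GraphAAGCN(edge_index, num_nodes=4).A
    gcn = UnitGCN(in_channels=3, out_channels=16, A=A)
    x = torch.randn(8, 3, 20, 4)  # (N, C, T, V)
    assert gcn(x).shape == (8, 16, 20, 4)  # only the channel axis changes
    return gcn(x)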
class AAGCN(nn.Module):
    r"""Two-Stream Adaptive Graph Convolutional Network.

    For details see this paper: `"Two-Stream Adaptive Graph Convolutional
    Networks for Skeleton-Based Action Recognition."
    <https://arxiv.org/abs/1805.07694>`_. This implementation is based on the
    authors' Github Repo https://github.com/lshiwjx/2s-AGCN. It is used by the
    authors for classifying actions from sequences of 3D body joint coordinates.

    Args:
        in_channels (int): Number of input features.
        out_channels (int): Number of output features.
        edge_index (PyTorch LongTensor): Graph edge indices.
        num_nodes (int): Number of nodes in the network.
        stride (int, optional): Time strides during temporal convolution.
            (default: :obj:`1`)
        residual (bool, optional): Applying residual connection.
            (default: :obj:`True`)
        adaptive (bool, optional): Adaptive node connection weights.
            (default: :obj:`True`)
        attention (bool, optional): Applying spatial-temporal-channel-attention.
            (default: :obj:`True`)
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        edge_index: torch.LongTensor,
        num_nodes: int,
        stride: int = 1,
        residual: bool = True,
        adaptive: bool = True,
        attention: bool = True,
    ):
        super(AAGCN, self).__init__()
        self.edge_index = edge_index
        self.num_nodes = num_nodes
        self.graph = GraphAAGCN(self.edge_index, self.num_nodes)
        self.A = self.graph.A

        self.gcn1 = UnitGCN(
            in_channels, out_channels, self.A, adaptive=adaptive, attention=attention
        )
        self.tcn1 = UnitTCN(out_channels, out_channels, stride=stride)
        self.relu = nn.ReLU(inplace=True)
        self.attention = attention

        if not residual:
            self.residual = lambda x: 0
        elif (in_channels == out_channels) and (stride == 1):
            self.residual = lambda x: x
        else:
            self.residual = UnitTCN(
                in_channels, out_channels, kernel_size=1, stride=stride
            )
    def forward(self, x):
        """
        Making a forward pass.

        Arg types:
            * **X** (PyTorch FloatTensor) - Node features for T time periods,
              with shape (B, F_in, T_in, N_nodes).

        Return types:
            * **X** (PyTorch FloatTensor) - Sequence of node features,
              with shape (B, out_channels, T_in//stride, N_nodes).
        """
        y = self.relu(self.tcn1(self.gcn1(x)) + self.residual(x))
        return y
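
# A minimal end-to-end sketch (illustrative, not part of the original
# module): one AAGCN block on random joint features. The graph is built
# internally from ``edge_index``, and the temporal stride of the TCN
# (and of the residual branch) downsamples the time axis.
def _demo_aagcn():
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
    model = AAGCN(
        in_channels=3, out_channels=16, edge_index=edge_index,
        num_nodes=4, stride=2,
    )
    x = torch.randn(8, 3, 20, 4)  # (B, F_in, T_in, N_nodes)
    assert model(x).shape == (8, 16, 10, 4)  # (B, out_channels, T_in//stride, N)
    return model(x)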