Source code for torch_geometric_temporal.nn.hetero.heterogclstm

import torch
from torch.nn import Parameter
from torch_geometric.nn import HeteroConv, SAGEConv
from torch_geometric.nn.inits import glorot
import torch.nn as nn

class HeteroGCLSTM(torch.nn.Module):
    r"""An implementation similar to the Integrated Graph Convolutional Long
    Short Term Memory Cell for heterogeneous Graphs.

    Every gate (input, forget, cell candidate, output) combines a per-node-type
    linear projection of the input features, a :obj:`HeteroConv` of
    :obj:`SAGEConv` layers applied to the hidden state, and a per-node-type
    bias term.

    Args:
        in_channels_dict (dict of keys=str and values=int): Dimension of each
            node's input features.
        out_channels (int): Number of output features.
        metadata (tuple): Metadata on node types and edge types in the graphs.
            Can be generated via PyG method :obj:`snapshot.metadata()` where
            snapshot is a single HeteroData object. Only the edge types
            (:obj:`metadata[1]`) are read by this module.
        bias (bool, optional): Forwarded to every :obj:`SAGEConv` layer; if
            set to :obj:`False` those layers will not learn an additive bias.
            The per-node-type gate bias terms (:obj:`b_i`, :obj:`b_f`,
            :obj:`b_c`, :obj:`b_o`) are created regardless of this flag.
            (default: :obj:`True`)
    """

    def __init__(
        self,
        in_channels_dict: dict,
        out_channels: int,
        metadata: tuple,
        bias: bool = True
    ):
        super().__init__()

        self.in_channels_dict = in_channels_dict
        self.out_channels = out_channels
        self.metadata = metadata
        self.bias = bias
        self._create_parameters_and_layers()
        self._set_parameters()

    def _create_gate_parameters_and_layers(self):
        """Build the ``(conv, weight, bias)`` triple that every gate uses."""
        conv = HeteroConv(
            {
                edge_type: SAGEConv(
                    in_channels=(-1, -1),
                    out_channels=self.out_channels,
                    bias=self.bias,
                )
                for edge_type in self.metadata[1]
            }
        )
        weight = nn.ParameterDict(
            {
                node_type: Parameter(torch.empty(in_channels, self.out_channels))
                for node_type, in_channels in self.in_channels_dict.items()
            }
        )
        bias = nn.ParameterDict(
            {
                node_type: Parameter(torch.empty(1, self.out_channels))
                for node_type in self.in_channels_dict
            }
        )
        return conv, weight, bias

    # NOTE: the attribute names and their assignment order (conv, W, b per
    # gate; gates in i, f, c, o order) determine the state_dict keys and the
    # parameter ordering, so they are kept exactly as they always were.
    def _create_input_gate_parameters_and_layers(self):
        self.conv_i, self.W_i, self.b_i = self._create_gate_parameters_and_layers()

    def _create_forget_gate_parameters_and_layers(self):
        self.conv_f, self.W_f, self.b_f = self._create_gate_parameters_and_layers()

    def _create_cell_state_parameters_and_layers(self):
        self.conv_c, self.W_c, self.b_c = self._create_gate_parameters_and_layers()

    def _create_output_gate_parameters_and_layers(self):
        self.conv_o, self.W_o, self.b_o = self._create_gate_parameters_and_layers()

    def _create_parameters_and_layers(self):
        self._create_input_gate_parameters_and_layers()
        self._create_forget_gate_parameters_and_layers()
        self._create_cell_state_parameters_and_layers()
        self._create_output_gate_parameters_and_layers()

    def _set_parameters(self):
        """Glorot-initialise all projection weights, then all gate biases."""
        # All weights first, then all biases: this keeps the order in which
        # the random number generator is consumed stable for seeded runs.
        parameter_dicts = (
            self.W_i, self.W_f, self.W_c, self.W_o,
            self.b_i, self.b_f, self.b_c, self.b_o,
        )
        for parameter_dict in parameter_dicts:
            for key in parameter_dict:
                glorot(parameter_dict[key])

    def _zero_state(self, x_dict):
        """Return an all-zero state dict with one entry per key of ``x_dict``.

        Each state has as many rows as the matching feature tensor and
        ``out_channels`` columns. It is allocated on the feature tensor's own
        device and with its dtype, so a module that lives on an accelerator
        (or uses a non-default dtype) can be called without explicit states.
        """
        return {
            node_type: torch.zeros(
                X.shape[0], self.out_channels, device=X.device, dtype=X.dtype
            )
            for node_type, X in x_dict.items()
        }

    def _set_hidden_state(self, x_dict, h_dict):
        """Return ``h_dict`` unchanged, or a zero state if it is ``None``."""
        if h_dict is None:
            h_dict = self._zero_state(x_dict)
        return h_dict

    def _set_cell_state(self, x_dict, c_dict):
        """Return ``c_dict`` unchanged, or a zero state if it is ``None``."""
        if c_dict is None:
            c_dict = self._zero_state(x_dict)
        return c_dict

    def _gate_pre_activation(self, x_dict, edge_index_dict, h_dict, conv, weight, bias):
        """Compute ``X @ W + conv(H) + b`` for every node type in ``x_dict``.

        The convolution output is indexed with every key of ``x_dict``, so a
        node type missing from that output raises ``KeyError``.
        """
        projected = {
            node_type: torch.matmul(X, weight[node_type])
            for node_type, X in x_dict.items()
        }
        conv_out = conv(h_dict, edge_index_dict)
        return {
            node_type: P + conv_out[node_type] + bias[node_type]
            for node_type, P in projected.items()
        }

    # ``c_dict`` is unused by the input, forget and output gates; it is kept
    # in their signatures for backward compatibility.
    def _calculate_input_gate(self, x_dict, edge_index_dict, h_dict, c_dict):
        pre = self._gate_pre_activation(
            x_dict, edge_index_dict, h_dict, self.conv_i, self.W_i, self.b_i
        )
        return {node_type: torch.sigmoid(I) for node_type, I in pre.items()}

    def _calculate_forget_gate(self, x_dict, edge_index_dict, h_dict, c_dict):
        pre = self._gate_pre_activation(
            x_dict, edge_index_dict, h_dict, self.conv_f, self.W_f, self.b_f
        )
        return {node_type: torch.sigmoid(F) for node_type, F in pre.items()}

    def _calculate_cell_state(self, x_dict, edge_index_dict, h_dict, c_dict, i_dict, f_dict):
        pre = self._gate_pre_activation(
            x_dict, edge_index_dict, h_dict, self.conv_c, self.W_c, self.b_c
        )
        t_dict = {node_type: torch.tanh(T) for node_type, T in pre.items()}
        # New cell state: forget part of the old state, add the gated candidate.
        return {
            node_type: f_dict[node_type] * C + i_dict[node_type] * t_dict[node_type]
            for node_type, C in c_dict.items()
        }

    def _calculate_output_gate(self, x_dict, edge_index_dict, h_dict, c_dict):
        pre = self._gate_pre_activation(
            x_dict, edge_index_dict, h_dict, self.conv_o, self.W_o, self.b_o
        )
        return {node_type: torch.sigmoid(O) for node_type, O in pre.items()}

    def _calculate_hidden_state(self, o_dict, c_dict):
        return {
            node_type: o_dict[node_type] * torch.tanh(C)
            for node_type, C in c_dict.items()
        }

    def forward(
        self,
        x_dict,
        edge_index_dict,
        h_dict=None,
        c_dict=None,
    ):
        """
        Making a forward pass. If the hidden state and cell state
        matrix dicts are not present when the forward pass is called these are
        initialized with zeros (on the device and with the dtype of the
        matching node feature tensors).

        Arg types:
            * **x_dict** *(Dictionary where keys=Strings and values=PyTorch Float Tensors)* - Node features dicts. Can
                be obtained via PyG method :obj:`snapshot.x_dict` where snapshot is a single HeteroData object.
            * **edge_index_dict** *(Dictionary where keys=Tuples and values=PyTorch Long Tensors)* - Graph edge type
                and index dicts. Can be obtained via PyG method :obj:`snapshot.edge_index_dict`.
            * **h_dict** *(Dictionary where keys=Strings and values=PyTorch Float Tensor, optional)* - Node type and
                hidden state matrix dict for all nodes.
            * **c_dict** *(Dictionary where keys=Strings and values=PyTorch Float Tensor, optional)* - Node type and
                cell state matrix dict for all nodes.

        Return types:
            * **h_dict** *(Dictionary where keys=Strings and values=PyTorch Float Tensor)* - Node type and
                hidden state matrix dict for all nodes.
            * **c_dict** *(Dictionary where keys=Strings and values=PyTorch Float Tensor)* - Node type and
                cell state matrix dict for all nodes.
        """
        h_dict = self._set_hidden_state(x_dict, h_dict)
        c_dict = self._set_cell_state(x_dict, c_dict)
        i_dict = self._calculate_input_gate(x_dict, edge_index_dict, h_dict, c_dict)
        f_dict = self._calculate_forget_gate(x_dict, edge_index_dict, h_dict, c_dict)
        # The cell state is updated before the output gate and the new hidden
        # state are computed; both receive the updated cell state.
        c_dict = self._calculate_cell_state(x_dict, edge_index_dict, h_dict, c_dict, i_dict, f_dict)
        o_dict = self._calculate_output_gate(x_dict, edge_index_dict, h_dict, c_dict)
        h_dict = self._calculate_hidden_state(o_dict, c_dict)
        return h_dict, c_dict