import torch
from torch.nn import GRU
from torch_geometric.nn import TopKPooling
from .evolvegcno import glorot, GCNConv_Fixed_W
class EvolveGCNH(torch.nn.Module):
    r"""An implementation of the Evolving Graph Convolutional Hidden Layer.
    For details see this paper: `"EvolveGCN: Evolving Graph Convolutional
    Networks for Dynamic Graph." <https://arxiv.org/abs/1902.10191>`_

    The layer summarises the node features with a top-k pooling step, feeds
    that summary through a GRU whose hidden state is the GCN weight matrix,
    and applies a graph convolution that uses the GRU's new hidden state as
    its weight.

    Args:
        num_of_nodes (int): Number of vertices.
        in_channels (int): Number of filters. The same value is used as the
            input and output size of the convolution and of the GRU.
        improved (bool, optional): If set to :obj:`True`, the layer computes
            :math:`\mathbf{\hat{A}}` as :math:`\mathbf{A} + 2\mathbf{I}`.
            (default: :obj:`False`)
        cached (bool, optional): If set to :obj:`True`, the layer will cache
            the computation of :math:`\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}}
            \mathbf{\hat{D}}^{-1/2}` on first execution, and will use the
            cached version for further executions.
            This parameter should only be set to :obj:`True` in transductive
            learning scenarios. (default: :obj:`False`)
        normalize (bool, optional): Whether to add self-loops and apply
            symmetric normalization. (default: :obj:`True`)
        add_self_loops (bool, optional): If set to :obj:`False`, will not add
            self-loops to the input graph. (default: :obj:`True`)
    """

    def __init__(
        self,
        num_of_nodes: int,
        in_channels: int,
        improved: bool = False,
        cached: bool = False,
        normalize: bool = True,
        add_self_loops: bool = True,
    ):
        super().__init__()

        self.num_of_nodes = num_of_nodes
        self.in_channels = in_channels
        self.improved = improved
        self.cached = cached
        self.normalize = normalize
        self.add_self_loops = add_self_loops
        # Optional override for the GRU's starting hidden state. It stays
        # ``None`` unless a caller assigns a tensor of shape
        # (in_channels, in_channels); ``forward`` then uses it instead of
        # ``initial_weight``.
        self.weight = None
        # Learnable starting hidden state of the GRU; the values are filled
        # in by ``reset_parameters``.
        self.initial_weight = torch.nn.Parameter(
            torch.empty(in_channels, in_channels)
        )
        self._create_layers()
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise ``initial_weight`` with the ``glorot`` helper."""
        glorot(self.initial_weight)

    def _create_layers(self):
        """Build the pooling, recurrent and convolution sub-layers."""
        # The pooled summary is used as a GRU input whose batch dimension
        # has to equal ``in_channels`` (see ``forward``).
        # NOTE(review): this presumes TopKPooling keeps ceil(ratio * N) nodes
        # and that every graph has ``num_of_nodes`` nodes -- confirm against
        # the installed torch_geometric version.
        self.ratio = self.in_channels / self.num_of_nodes

        self.pooling_layer = TopKPooling(self.in_channels, self.ratio)

        self.recurrent_layer = GRU(
            input_size=self.in_channels,
            hidden_size=self.in_channels,
            num_layers=1,
        )

        self.conv_layer = GCNConv_Fixed_W(
            in_channels=self.in_channels,
            out_channels=self.in_channels,
            improved=self.improved,
            cached=self.cached,
            normalize=self.normalize,
            add_self_loops=self.add_self_loops,
        )

    def forward(
        self,
        X: torch.FloatTensor,
        edge_index: torch.LongTensor,
        edge_weight: torch.FloatTensor = None,
    ) -> torch.FloatTensor:
        """
        Making a forward pass.

        Arg types:
            * **X** *(PyTorch Float Tensor)* - Node embedding.
            * **edge_index** *(PyTorch Long Tensor)* - Graph edge indices.
            * **edge_weight** *(PyTorch Float Tensor, optional)* - Edge weight vector.

        Return types:
            * **X** *(PyTorch Float Tensor)* - Output matrix for all nodes.
        """
        # Only the first element of the pooling result (the pooled features)
        # is needed; a leading axis of length one is added so it forms a
        # one-step sequence for the GRU.
        pooled = self.pooling_layer(X, edge_index)[0]
        summary = pooled[None, :, :]

        # Use the live parameter rather than ``initial_weight.data``: going
        # through ``.data`` detaches the tensor from autograd, so the
        # parameter would never receive a gradient, and a cached alias could
        # be left on the wrong device after ``module.to(...)``.
        start = self.initial_weight if self.weight is None else self.weight
        hidden = start[None, :, :]

        # The GRU's output sequence is not used; its final hidden state is
        # the weight matrix handed to the convolution. The result is not
        # stored, so each call starts again from ``start``.
        _, evolved = self.recurrent_layer(summary, hidden)

        return self.conv_layer(evolved.squeeze(dim=0), X, edge_index, edge_weight)