# Source code for torch_geometric_temporal.signal.dynamic_graph_static_signal

import torch
import numpy as np
from typing import Sequence, Union
from torch_geometric.data import Data


# A numpy array that is allowed to be absent (``None``).
_Optional_Array = Union[np.ndarray, None]

# Per-time-step sequences; individual entries may be ``None``.
Edge_Indices = Sequence[_Optional_Array]
Edge_Weights = Sequence[_Optional_Array]
Targets = Sequence[_Optional_Array]
# A single node feature array shared by every time step (may be ``None``).
Node_Feature = _Optional_Array
# Per-time-step sequences of extra attributes; entries are plain arrays.
Additional_Features = Sequence[np.ndarray]


class DynamicGraphStaticSignal(object):
    r"""A data iterator object to contain a dynamic graph with a changing edge
    set and weights. The node labels (target) are also dynamic, while the node
    feature matrix is static and shared by every snapshot. The iterator returns
    a single discrete temporal snapshot for a time period (e.g. day or week).
    This single snapshot is a Pytorch Geometric Data object. Between two
    temporal snapshots the edges, edge weights, target matrices and optionally
    passed attributes might change.

    Args:
        edge_indices (Sequence of Numpy arrays): Sequence of edge index tensors.
        edge_weights (Sequence of Numpy arrays): Sequence of edge weight tensors.
        feature (Numpy array): Node feature tensor.
        targets (Sequence of Numpy arrays): Sequence of node label (target) tensors.
        **kwargs (optional Sequence of Numpy arrays): Sequence of additional attributes.

    Raises:
        AssertionError: If edge indices, edge weights, targets and the
            additional attributes do not all have the same temporal length.
    """

    def __init__(
        self,
        edge_indices: Edge_Indices,
        edge_weights: Edge_Weights,
        feature: Node_Feature,
        targets: Targets,
        **kwargs: Additional_Features
    ):
        self.edge_indices = edge_indices
        self.edge_weights = edge_weights
        self.feature = feature
        self.targets = targets
        # Iteration cursor. Initialised here so ``next()`` works even when
        # ``iter()`` was never called (previously raised AttributeError).
        self.t = 0
        self.additional_feature_keys = []
        for key, value in kwargs.items():
            setattr(self, key, value)
            self.additional_feature_keys.append(key)
        self._check_temporal_consistency()
        self._set_snapshot_count()

    def _check_temporal_consistency(self):
        """Check that every per-time-step sequence has the same length.

        Raises:
            AssertionError: If the edge indices, edge weights, targets or any
                additional attribute differ in length.
        """
        # Raised explicitly instead of via ``assert`` so the validation is not
        # stripped under ``python -O``; the exception type is unchanged.
        reference_length = len(self.targets)
        lengths = [len(self.edge_indices), len(self.edge_weights)]
        lengths.extend(len(getattr(self, key)) for key in self.additional_feature_keys)
        if any(length != reference_length for length in lengths):
            raise AssertionError("Temporal dimension inconsistency.")

    def _set_snapshot_count(self):
        """Store the number of temporal snapshots in ``snapshot_count``."""
        self.snapshot_count = len(self.targets)

    @staticmethod
    def _to_tensor(array):
        """Convert an optional numpy array to a tensor chosen by its dtype kind.

        Signed and unsigned integer arrays become ``torch.LongTensor``,
        floating-point arrays become ``torch.FloatTensor`` and boolean arrays
        become a ``torch.bool`` tensor. ``None`` is returned unchanged.

        Raises:
            TypeError: If the array's dtype kind is not one of the above.
        """
        if array is None:
            return None
        kind = array.dtype.kind
        if kind == "i":
            return torch.LongTensor(array)
        if kind == "u":
            # Cast first: older torch releases cannot wrap every unsigned dtype.
            return torch.LongTensor(array.astype(np.int64))
        if kind == "f":
            return torch.FloatTensor(array)
        if kind == "b":
            return torch.tensor(array, dtype=torch.bool)
        # Previously such arrays were silently turned into ``None``.
        raise TypeError(
            "Unsupported array dtype {!r}; expected an integer, unsigned, "
            "floating-point or boolean numpy array.".format(array.dtype)
        )

    def _get_edge_index(self, time_index: int):
        """Return the edge index of snapshot ``time_index`` as a LongTensor (or ``None``)."""
        edge_index = self.edge_indices[time_index]
        if edge_index is None:
            return None
        return torch.LongTensor(edge_index)

    def _get_edge_weight(self, time_index: int):
        """Return the edge weights of snapshot ``time_index`` as a FloatTensor (or ``None``)."""
        edge_weight = self.edge_weights[time_index]
        if edge_weight is None:
            return None
        return torch.FloatTensor(edge_weight)

    def _get_feature(self):
        """Return the static node feature as a FloatTensor (or ``None``)."""
        if self.feature is None:
            return None
        return torch.FloatTensor(self.feature)

    def _get_target(self, time_index: int):
        """Return the target of snapshot ``time_index`` converted by dtype (or ``None``)."""
        return self._to_tensor(self.targets[time_index])

    def _get_additional_feature(self, time_index: int, feature_key: str):
        """Return attribute ``feature_key`` at ``time_index`` converted by dtype (or ``None``)."""
        return self._to_tensor(getattr(self, feature_key)[time_index])

    def _get_additional_features(self, time_index: int):
        """Return a dict mapping each additional attribute name to its tensor at ``time_index``."""
        return {
            key: self._get_additional_feature(time_index, key)
            for key in self.additional_feature_keys
        }

    def __len__(self):
        """Return the number of temporal snapshots."""
        return len(self.targets)

    def __getitem__(self, time_index: Union[int, slice]):
        """Return a single snapshot or a temporal sub-signal.

        Args:
            time_index: An integer selects one snapshot, returned as a PyTorch
                Geometric ``Data`` object with ``x`` (static feature),
                ``edge_index``, ``edge_attr`` (edge weights), ``y`` (target)
                and any additional attributes. A slice returns a new
                ``DynamicGraphStaticSignal`` over the selected time steps that
                shares the same static node feature.

        Raises:
            TypeError: If a target or additional attribute has an unsupported
                dtype.
        """
        if isinstance(time_index, slice):
            return DynamicGraphStaticSignal(
                self.edge_indices[time_index],
                self.edge_weights[time_index],
                self.feature,
                self.targets[time_index],
                **{key: getattr(self, key)[time_index] for key in self.additional_feature_keys}
            )
        x = self._get_feature()
        edge_index = self._get_edge_index(time_index)
        edge_weight = self._get_edge_weight(time_index)
        y = self._get_target(time_index)
        additional_features = self._get_additional_features(time_index)
        return Data(
            x=x,
            edge_index=edge_index,
            edge_attr=edge_weight,
            y=y,
            **additional_features
        )

    def __next__(self):
        """Return the next snapshot, resetting the cursor once exhausted."""
        if self.t < len(self.targets):
            snapshot = self[self.t]
            self.t = self.t + 1
            return snapshot
        self.t = 0
        raise StopIteration

    def __iter__(self):
        """Restart iteration from the first snapshot and return ``self``."""
        self.t = 0
        return self