import torch
import numpy as np
from typing import Sequence, Union
from torch_geometric.data import Data
# Per-snapshot inputs may hold ``None`` in place of an array, so most
# aliases are built from this optional-array building block.
_OptionalArray = Union[np.ndarray, None]
Edge_Indices = Sequence[_OptionalArray]
Edge_Weights = Sequence[_OptionalArray]
Node_Feature = _OptionalArray
Targets = Sequence[_OptionalArray]
Additional_Features = Sequence[np.ndarray]
class DynamicGraphStaticSignal(object):
    r"""A data iterator object to contain a dynamic graph with a
    changing edge set and edge weights. The node feature matrix is static
    (the same matrix is attached to every snapshot), while the node labels
    (targets) are dynamic. The iterator returns a single discrete temporal
    snapshot for a time period (e.g. day or week). This single snapshot is a
    PyTorch Geometric Data object. Between two temporal snapshots the edges,
    edge weights, target matrices and optionally passed attributes might
    change.

    Args:
        edge_indices (Sequence of Numpy arrays): Sequence of edge index
            arrays, one per snapshot; an entry may be ``None``.
        edge_weights (Sequence of Numpy arrays): Sequence of edge weight
            arrays, one per snapshot; an entry may be ``None``.
        feature (Numpy array): Node feature array shared by all snapshots,
            or ``None``.
        targets (Sequence of Numpy arrays): Sequence of node label (target)
            arrays, one per snapshot; an entry may be ``None``.
        **kwargs (optional Sequence of Numpy arrays): Sequences of additional
            per-snapshot attributes. Each is attached to every snapshot under
            its keyword name.

    Raises:
        ValueError: If an additional attribute name clashes with an existing
            attribute or method of this object.
    """

    def __init__(
        self,
        edge_indices: Edge_Indices,
        edge_weights: Edge_Weights,
        feature: Node_Feature,
        targets: Targets,
        **kwargs: Additional_Features
    ):
        self.edge_indices = edge_indices
        self.edge_weights = edge_weights
        self.feature = feature
        self.targets = targets
        # Iteration cursor; initialised here so ``next()`` also works before
        # ``iter()`` has been called.
        self.t = 0
        self.additional_feature_keys = []
        for key, value in kwargs.items():
            # Extra attributes live directly on the instance, so a name that
            # matches existing state/methods (or the count set below) would
            # silently overwrite it.
            if hasattr(self, key) or key == "snapshot_count":
                raise ValueError(
                    f"Additional feature name {key!r} clashes with an "
                    "existing attribute."
                )
            setattr(self, key, value)
            self.additional_feature_keys.append(key)
        self._check_temporal_consistency()
        self._set_snapshot_count()

    def _check_temporal_consistency(self):
        """Assert that every per-snapshot sequence has the same length."""
        assert len(self.edge_indices) == len(
            self.edge_weights
        ), "Temporal dimension inconsistency."
        assert len(self.targets) == len(
            self.edge_indices
        ), "Temporal dimension inconsistency."
        for key in self.additional_feature_keys:
            assert len(self.targets) == len(
                getattr(self, key)
            ), "Temporal dimension inconsistency."

    def _set_snapshot_count(self):
        """Record the number of snapshots (the length of ``targets``)."""
        self.snapshot_count = len(self.targets)

    @staticmethod
    def _to_tensor(array):
        """Convert a per-snapshot Numpy array to a tensor, passing ``None`` through.

        Signed and unsigned integer arrays become ``torch.LongTensor``;
        floating-point arrays become ``torch.FloatTensor``. Any other dtype
        (e.g. bool) is handed to ``torch.as_tensor`` rather than being
        silently dropped; dtypes torch cannot represent are expected to raise
        there.
        """
        if array is None:
            return None
        kind = array.dtype.kind
        if kind == "i":
            return torch.LongTensor(array)
        if kind == "u":
            # Cast explicitly so unsigned input follows the signed-int path.
            return torch.LongTensor(array.astype(np.int64))
        if kind == "f":
            return torch.FloatTensor(array)
        return torch.as_tensor(array)

    def _get_edge_index(self, time_index: int):
        """Return the edge index of snapshot ``time_index`` as a LongTensor, or None."""
        if self.edge_indices[time_index] is None:
            return None
        return torch.LongTensor(self.edge_indices[time_index])

    def _get_edge_weight(self, time_index: int):
        """Return the edge weights of snapshot ``time_index`` as a FloatTensor, or None."""
        if self.edge_weights[time_index] is None:
            return None
        return torch.FloatTensor(self.edge_weights[time_index])

    def _get_feature(self):
        """Return the static node feature matrix as a FloatTensor, or None."""
        if self.feature is None:
            return None
        return torch.FloatTensor(self.feature)

    def _get_target(self, time_index: int):
        """Return the target of snapshot ``time_index`` as a tensor, or None."""
        return self._to_tensor(self.targets[time_index])

    def _get_additional_feature(self, time_index: int, feature_key: str):
        """Return additional attribute ``feature_key`` at ``time_index`` as a tensor, or None."""
        return self._to_tensor(getattr(self, feature_key)[time_index])

    def _get_additional_features(self, time_index: int):
        """Return a dict mapping each additional attribute name to its tensor at ``time_index``."""
        return {
            key: self._get_additional_feature(time_index, key)
            for key in self.additional_feature_keys
        }

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, time_index: Union[int, slice]):
        """Return one snapshot or a sub-range of snapshots.

        An integer index returns a ``Data`` object for that time step. A slice
        returns a new ``DynamicGraphStaticSignal`` over the selected time
        steps that shares the same static node feature matrix.
        """
        if isinstance(time_index, slice):
            return DynamicGraphStaticSignal(
                self.edge_indices[time_index],
                self.edge_weights[time_index],
                self.feature,
                self.targets[time_index],
                **{
                    key: getattr(self, key)[time_index]
                    for key in self.additional_feature_keys
                }
            )
        x = self._get_feature()
        edge_index = self._get_edge_index(time_index)
        edge_weight = self._get_edge_weight(time_index)
        y = self._get_target(time_index)
        additional_features = self._get_additional_features(time_index)
        return Data(
            x=x,
            edge_index=edge_index,
            edge_attr=edge_weight,
            y=y,
            **additional_features
        )

    def __next__(self):
        if self.t < len(self.targets):
            snapshot = self[self.t]
            self.t = self.t + 1
            return snapshot
        # Reset so the object can be iterated again after exhaustion.
        self.t = 0
        raise StopIteration

    def __iter__(self):
        self.t = 0
        return self