import torch
import numpy as np
from typing import Sequence, Union
from torch_geometric.data import Batch
# Type aliases for the constructor inputs. The temporal sequences hold one
# entry per snapshot; a ``None`` entry is returned as ``None`` instead of
# being converted to a tensor.
Edge_Index = Union[np.ndarray, None]
Edge_Weight = Union[np.ndarray, None]
Node_Features = Sequence[Union[np.ndarray, None]]
Targets = Sequence[Union[np.ndarray, None]]
Batches = Union[np.ndarray, None]
Additional_Features = Sequence[np.ndarray]


class StaticGraphTemporalSignalBatch(object):
    r"""A data iterator over a static graph with a batched temporal signal.

    The graph structure (``edge_index``, ``edge_weight``) and the batch index
    vector are shared by every snapshot. The node features, node targets
    (labels) and any additional keyword attributes change from one snapshot
    to the next, one snapshot per constant-length time period (e.g. a day or
    a week). Indexing with an integer, or iterating, yields a single snapshot
    as a PyTorch Geometric ``Batch``. Indexing with a slice yields a new
    ``StaticGraphTemporalSignalBatch`` restricted to those time steps.

    Args:
        edge_index (Numpy array): Index tensor of edges.
        edge_weight (Numpy array): Edge weight tensor.
        features (Sequence of Numpy arrays): Sequence of node feature tensors.
        targets (Sequence of Numpy arrays): Sequence of node label (target) tensors.
        batches (Numpy array): Batch index tensor.
        **kwargs (optional Sequence of Numpy arrays): Sequence of additional attributes.

    Raises:
        AssertionError: If ``features``, ``targets`` and the additional
            attribute sequences do not all have the same length.
        ValueError: If an additional attribute name collides with an
            attribute this class manages internally.
    """

    # Attribute names this object writes itself. A keyword attribute with one
    # of these names would be silently overwritten or would corrupt the object.
    _RESERVED_KEYS = frozenset({"t", "snapshot_count", "additional_feature_keys"})

    def __init__(
        self,
        edge_index: Edge_Index,
        edge_weight: Edge_Weight,
        features: Node_Features,
        targets: Targets,
        batches: Batches,
        **kwargs: Additional_Features
    ):
        self.edge_index = edge_index
        self.edge_weight = edge_weight
        self.features = features
        self.targets = targets
        self.batches = batches
        # Iteration cursor; initialised here so ``next()`` works even when
        # ``iter()`` was never called.
        self.t = 0
        self.additional_feature_keys = []
        for key, value in kwargs.items():
            if key in self._RESERVED_KEYS:
                raise ValueError(
                    f"Additional feature name '{key}' is reserved by "
                    f"{type(self).__name__}."
                )
            # Each additional attribute is stored as a plain instance attribute.
            setattr(self, key, value)
            self.additional_feature_keys.append(key)
        self._check_temporal_consistency()
        self._set_snapshot_count()

    def _check_temporal_consistency(self):
        """Ensure every temporal sequence has one entry per snapshot.

        Raises:
            AssertionError: If the targets or any additional attribute sequence
                differ in length from the features. Raised explicitly rather
                than via ``assert`` so the check still runs under ``python -O``.
        """
        expected = len(self.features)
        lengths = [len(self.targets)]
        lengths.extend(len(getattr(self, key)) for key in self.additional_feature_keys)
        if any(length != expected for length in lengths):
            raise AssertionError("Temporal dimension inconsistency.")

    def _set_snapshot_count(self):
        """Record the number of snapshots (the length of ``features``)."""
        self.snapshot_count = len(self.features)

    @staticmethod
    def _to_tensor(value):
        """Convert one per-snapshot array to a tensor, passing ``None`` through.

        Integer arrays become ``LongTensor`` and floating arrays become
        ``FloatTensor``. Unsigned integers are widened to int64 first, booleans
        become a bool tensor, and any other dtype is handed to
        ``torch.tensor``, which either converts it or raises. Previously these
        other dtypes were silently turned into ``None``.
        """
        if value is None:
            return None
        array = np.asarray(value)
        kind = array.dtype.kind
        if kind == "i":
            return torch.LongTensor(array)
        if kind == "u":
            # Not every unsigned width can be ingested by torch directly; widen in numpy.
            return torch.LongTensor(array.astype(np.int64))
        if kind == "f":
            return torch.FloatTensor(array)
        if kind == "b":
            return torch.tensor(array, dtype=torch.bool)
        return torch.tensor(array)

    def _get_edge_index(self):
        """Return the shared edge index as a ``LongTensor``, or ``None``."""
        if self.edge_index is None:
            return None
        return torch.LongTensor(self.edge_index)

    def _get_batch_index(self):
        """Return the shared batch index vector as a ``LongTensor``, or ``None``."""
        if self.batches is None:
            return None
        return torch.LongTensor(self.batches)

    def _get_edge_weight(self):
        """Return the shared edge weights as a ``FloatTensor``, or ``None``."""
        if self.edge_weight is None:
            return None
        return torch.FloatTensor(self.edge_weight)

    def _get_feature(self, time_index: int):
        """Return the node features of one snapshot as a ``FloatTensor``, or ``None``."""
        feature = self.features[time_index]
        if feature is None:
            return None
        return torch.FloatTensor(feature)

    def _get_target(self, time_index: int):
        """Return the target of one snapshot as a tensor (dtype-dependent), or ``None``."""
        return self._to_tensor(self.targets[time_index])

    def _get_additional_feature(self, time_index: int, feature_key: str):
        """Return additional attribute ``feature_key`` at ``time_index`` as a tensor."""
        return self._to_tensor(getattr(self, feature_key)[time_index])

    def _get_additional_features(self, time_index: int):
        """Return a dict mapping every additional attribute name to its tensor at ``time_index``."""
        return {
            key: self._get_additional_feature(time_index, key)
            for key in self.additional_feature_keys
        }

    def __getitem__(self, time_index: Union[int, slice]):
        """Return one snapshot (int index) or a time-sliced copy (slice index).

        A slice keeps the shared graph and batch index and slices every
        temporal sequence, including the additional attributes.
        """
        if isinstance(time_index, slice):
            return StaticGraphTemporalSignalBatch(
                self.edge_index,
                self.edge_weight,
                self.features[time_index],
                self.targets[time_index],
                self.batches,
                **{key: getattr(self, key)[time_index] for key in self.additional_feature_keys}
            )
        x = self._get_feature(time_index)
        edge_index = self._get_edge_index()
        edge_weight = self._get_edge_weight()
        batch = self._get_batch_index()
        y = self._get_target(time_index)
        additional_features = self._get_additional_features(time_index)
        return Batch(x=x, edge_index=edge_index, edge_attr=edge_weight,
                     y=y, batch=batch, **additional_features)

    def __next__(self):
        """Return the next snapshot; reset the cursor and stop after the last one."""
        if self.t < len(self.features):
            snapshot = self[self.t]
            self.t = self.t + 1
            return snapshot
        self.t = 0
        raise StopIteration

    def __iter__(self):
        """Restart iteration and return ``self``.

        The object is its own iterator, so the cursor ``self.t`` is shared:
        nested loops over the same instance interfere with each other.
        """
        self.t = 0
        return self