import torch
import numpy as np
from typing import Sequence, Union
from torch_geometric.data import Batch
# Type aliases for the per-snapshot sequences accepted by the iterator below.
# Each sequence holds one entry per temporal snapshot; a ``None`` entry means
# that piece of data is absent for that snapshot.
Edge_Indices = Sequence[Union[np.ndarray, None]]
Edge_Weights = Sequence[Union[np.ndarray, None]]
Node_Features = Sequence[Union[np.ndarray, None]]
Targets = Sequence[Union[np.ndarray, None]]
Batches = Sequence[Union[np.ndarray, None]]
Additional_Features = Sequence[np.ndarray]


class DynamicGraphTemporalSignalBatch(object):
    r"""A data iterator object to contain a dynamic graph with a
    changing edge set and weights. The feature set and node labels
    (target) are also dynamic. The iterator returns a single discrete temporal
    snapshot for a time period (e.g. day or week). This single snapshot is a
    Pytorch Geometric Batch object. Between two temporal snapshots the edges,
    edge weights, the feature matrix, target matrices and optionally passed
    attributes might change.

    Indexing with an integer returns the snapshot at that position; indexing
    with a slice returns a new ``DynamicGraphTemporalSignalBatch`` restricted
    to the selected snapshots.

    Args:
        edge_indices (Sequence of Numpy arrays): Sequence of edge index tensors.
        edge_weights (Sequence of Numpy arrays): Sequence of edge weight tensors.
        features (Sequence of Numpy arrays): Sequence of node feature tensors.
        targets (Sequence of Numpy arrays): Sequence of node label (target) tensors.
        batches (Sequence of Numpy arrays): Sequence of batch index tensors.
        **kwargs (optional Sequence of Numpy arrays): Sequence of additional attributes.

    Raises:
        AssertionError: If the sequences do not all have the same length.
    """

    def __init__(
        self,
        edge_indices: Edge_Indices,
        edge_weights: Edge_Weights,
        features: Node_Features,
        targets: Targets,
        batches: Batches,
        **kwargs: Additional_Features
    ):
        self.edge_indices = edge_indices
        self.edge_weights = edge_weights
        self.features = features
        self.targets = targets
        self.batches = batches
        # Iteration cursor. Initialised here so that calling ``next()``
        # without a prior ``iter()`` does not fail with AttributeError.
        self.t = 0
        self.additional_feature_keys = []
        # Every extra keyword becomes an attribute of the same name and is
        # remembered so it can be validated, sliced and emitted per snapshot.
        for key, value in kwargs.items():
            setattr(self, key, value)
            self.additional_feature_keys.append(key)
        self._check_temporal_consistency()
        self._set_snapshot_count()

    def _check_temporal_consistency(self):
        """Verify that every stored sequence has one entry per snapshot.

        Raises:
            AssertionError: If any sequence length differs from
                ``len(self.features)``.
        """
        expected = len(self.features)
        sequences = [
            self.targets,
            self.edge_indices,
            self.edge_weights,
            self.batches,
        ]
        sequences.extend(getattr(self, key) for key in self.additional_feature_keys)
        for sequence in sequences:
            if len(sequence) != expected:
                # Raised explicitly (rather than via ``assert``) so the check
                # still runs under ``python -O``; the exception type and
                # message are the ones callers have always seen.
                raise AssertionError("Temporal dimension inconsistency.")

    def _set_snapshot_count(self):
        """Store the number of temporal snapshots in ``self.snapshot_count``."""
        self.snapshot_count = len(self.features)

    @staticmethod
    def _to_typed_tensor(values):
        """Convert a Numpy array to a tensor chosen by its dtype kind.

        Signed and unsigned integer arrays become a ``torch.LongTensor`` and
        floating point arrays become a ``torch.FloatTensor``.

        Returns:
            The converted tensor, or ``None`` when ``values`` is ``None`` or
            its dtype kind is neither integer nor floating point.
        """
        if values is None:
            return None
        kind = values.dtype.kind
        if kind == "i":
            return torch.LongTensor(values)
        if kind == "u":
            # Unsigned labels (e.g. uint8) are widened to int64 in Numpy
            # first so they are not silently dropped.
            return torch.LongTensor(values.astype(np.int64))
        if kind == "f":
            return torch.FloatTensor(values)
        return None

    def _get_edge_index(self, time_index: int):
        """Return the edge index of a snapshot as a ``LongTensor`` (or ``None``)."""
        edge_index = self.edge_indices[time_index]
        if edge_index is None:
            return None
        return torch.LongTensor(edge_index)

    def _get_batch_index(self, time_index: int):
        """Return the batch index of a snapshot as a ``LongTensor`` (or ``None``)."""
        batch_index = self.batches[time_index]
        if batch_index is None:
            return None
        return torch.LongTensor(batch_index)

    def _get_edge_weight(self, time_index: int):
        """Return the edge weights of a snapshot as a ``FloatTensor`` (or ``None``)."""
        edge_weight = self.edge_weights[time_index]
        if edge_weight is None:
            return None
        return torch.FloatTensor(edge_weight)

    def _get_feature(self, time_index: int):
        """Return the node features of a snapshot as a ``FloatTensor`` (or ``None``)."""
        feature = self.features[time_index]
        if feature is None:
            return None
        return torch.FloatTensor(feature)

    def _get_target(self, time_index: int):
        """Return the target of a snapshot, typed according to its dtype.

        Integer targets yield a ``LongTensor`` and floating point targets a
        ``FloatTensor``; a missing target or any other dtype yields ``None``.
        """
        return self._to_typed_tensor(self.targets[time_index])

    def _get_additional_feature(self, time_index: int, feature_key: str):
        """Return one additional attribute of a snapshot as a typed tensor.

        A ``None`` entry is passed through as ``None``, matching the other
        getters.
        """
        return self._to_typed_tensor(getattr(self, feature_key)[time_index])

    def _get_additional_features(self, time_index: int):
        """Return a dict mapping each additional attribute key to its tensor."""
        return {
            key: self._get_additional_feature(time_index, key)
            for key in self.additional_feature_keys
        }

    def __getitem__(self, time_index: Union[int, slice]):
        """Return one snapshot (integer index) or a sub-iterator (slice)."""
        if isinstance(time_index, slice):
            sliced_extras = {
                key: getattr(self, key)[time_index]
                for key in self.additional_feature_keys
            }
            return DynamicGraphTemporalSignalBatch(
                self.edge_indices[time_index],
                self.edge_weights[time_index],
                self.features[time_index],
                self.targets[time_index],
                self.batches[time_index],
                **sliced_extras
            )
        return Batch(
            x=self._get_feature(time_index),
            edge_index=self._get_edge_index(time_index),
            edge_attr=self._get_edge_weight(time_index),
            y=self._get_target(time_index),
            batch=self._get_batch_index(time_index),
            **self._get_additional_features(time_index)
        )

    def __next__(self):
        """Return the next snapshot, or reset the cursor and stop at the end."""
        if self.t < len(self.features):
            snapshot = self[self.t]
            self.t = self.t + 1
            return snapshot
        # Rewind so the object can be iterated again after exhaustion.
        self.t = 0
        raise StopIteration

    def __iter__(self):
        """Reset the cursor to the first snapshot and return ``self``."""
        self.t = 0
        return self