import torch
import numpy as np
import torch.nn as nn
from typing import List
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
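# Multi-head self-attention with a lower-triangular (causal) mask over the
# sequence dimension: each position attends only to itself and to earlier
# positions. Per-head results are concatenated or averaged depending on
# ``attention_aggregate``.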
class MaskedSelfAttention(nn.Module):
def __init__(self, input_dim, output_dim, n_heads, attention_aggregate="mean"):
super(MaskedSelfAttention, self).__init__()
self.attention_aggregate = attention_aggregate
self.input_dim = input_dim
self.output_dim = output_dim
self.n_heads = n_heads
if attention_aggregate == "concat":
self.per_head_dim = self.dq = self.dk = self.dv = output_dim // n_heads
elif attention_aggregate == "mean":
self.per_head_dim = self.dq = self.dk = self.dv = output_dim
else:
raise ValueError(f"wrong value for aggregate {attention_aggregate}")
self.Wq = nn.Linear(input_dim, n_heads * self.dq, bias=False)
self.Wk = nn.Linear(input_dim, n_heads * self.dk, bias=False)
self.Wv = nn.Linear(input_dim, n_heads * self.dv, bias=False)
def forward(self, input_tensor):
seq_length = input_tensor.shape[1]
Q = self.Wq(input_tensor)
K = self.Wk(input_tensor)
V = self.Wv(input_tensor)
Q = Q.reshape(
input_tensor.shape[0], input_tensor.shape[1], self.n_heads, self.dq
).transpose(1, 2)
K = K.reshape(
input_tensor.shape[0], input_tensor.shape[1], self.n_heads, self.dk
).permute(0, 2, 3, 1)
V = V.reshape(
input_tensor.shape[0], input_tensor.shape[1], self.n_heads, self.dv
).transpose(1, 2)
attention_score = Q.matmul(K) / np.sqrt(self.per_head_dim)
attention_mask = (
torch.zeros(seq_length, seq_length)
.masked_fill(torch.tril(torch.ones(seq_length, seq_length)) == 0, -np.inf)
.to(input_tensor.device)
)
attention_score = attention_score + attention_mask
attention_score = torch.softmax(attention_score, dim=-1)
multi_head_result = attention_score.matmul(V)
if self.attention_aggregate == "concat":
output = multi_head_result.transpose(1, 2).reshape(
input_tensor.shape[0], seq_length, self.n_heads * self.per_head_dim
)
elif self.attention_aggregate == "mean":
output = multi_head_result.transpose(1, 2).mean(dim=2)
else:
raise ValueError(f"wrong value for aggregate {self.attention_aggregate}")
return output
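# Learns a per-item gate ``alpha`` that blends the static item embedding with
# the dynamically computed node features for every element of the batch.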
class GlobalGatedUpdater(nn.Module):
def __init__(self, items_total, item_embedding):
super(GlobalGatedUpdater, self).__init__()
self.items_total = items_total
self.item_embedding = item_embedding
self.alpha = nn.Parameter(torch.rand(items_total, 1), requires_grad=True)
def forward(self, nodes_output):
batch_size = nodes_output.shape[0] // self.items_total
id = 0
num_nodes = self.items_total
        items_embedding = self.item_embedding(
            torch.arange(self.items_total, device=nodes_output.device)
        )
batch_embedding = []
for _ in range(batch_size):
output_node_features = nodes_output[id : id + num_nodes, :]
embed = (1 - self.alpha) * items_embedding
embed = embed + self.alpha * output_node_features
batch_embedding.append(embed)
id += num_nodes
batch_embedding = torch.stack(batch_embedding)
return batch_embedding
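# Applies a shared linear projection to each slice along the first dimension of
# the input and concatenates the projected slices along the node axis.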
class AggregateTemporalNodeFeatures(nn.Module):
def __init__(self, item_embed_dim):
super(AggregateTemporalNodeFeatures, self).__init__()
self.Wq = nn.Linear(item_embed_dim, item_embed_dim, bias=False)
def forward(self, nodes_output):
aggregated_features = []
for l in range(nodes_output.shape[0]):
output_node_features = nodes_output[l, :, :]
weights = self.Wq(output_node_features)
aggregated_features.append(weights)
aggregated_features = torch.cat(aggregated_features, dim=0)
return aggregated_features
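# A stack of weighted GCNConv layers, each followed by batch normalization and
# a ReLU non-linearity.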
class WeightedGCNBlock(nn.Module):
def __init__(self, in_features: int, hidden_sizes: List[int], out_features: int):
super(WeightedGCNBlock, self).__init__()
gcns, relus, bns = nn.ModuleList(), nn.ModuleList(), nn.ModuleList()
input_size = in_features
for hidden_size in hidden_sizes:
gcns.append(GCNConv(input_size, hidden_size))
relus.append(nn.ReLU())
bns.append(nn.BatchNorm1d(hidden_size))
input_size = hidden_size
gcns.append(GCNConv(hidden_sizes[-1], out_features))
relus.append(nn.ReLU())
bns.append(nn.BatchNorm1d(out_features))
self.gcns, self.relus, self.bns = gcns, relus, bns
def forward(
self,
node_features: torch.FloatTensor,
edge_index: torch.LongTensor,
        edges_weight: torch.FloatTensor,
):
h = node_features
for gcn, relu, bn in zip(self.gcns, self.relus, self.bns):
h = gcn(h, edge_index, edges_weight)
h = bn(h.transpose(1, -1)).transpose(1, -1)
h = relu(h)
return h
class DNNTSP(nn.Module):
r"""An implementation of the Deep Neural Network for Temporal Set Prediction.
For details see: `"Predicting Temporal Sets with Deep Neural Networks" <https://dl.acm.org/doi/abs/10.1145/3394486.3403152>`_
Args:
items_total (int): Total number of items in the sets. Cardinality of the union.
item_embedding_dim (int): Item embedding dimensions.
n_heads (int): Number of attention heads.
"""
def __init__(self, items_total: int, item_embedding_dim: int, n_heads: int):
super(DNNTSP, self).__init__()
self.item_embedding = nn.Embedding(items_total, item_embedding_dim)
self.item_embedding_dim = item_embedding_dim
self.items_total = items_total
self.stacked_gcn = WeightedGCNBlock(
item_embedding_dim, [item_embedding_dim], item_embedding_dim
)
self.masked_self_attention = MaskedSelfAttention(
input_dim=item_embedding_dim, output_dim=item_embedding_dim, n_heads=n_heads
)
self.aggregate_nodes_temporal_feature = AggregateTemporalNodeFeatures(
item_embed_dim=item_embedding_dim
)
self.global_gated_updater = GlobalGatedUpdater(
items_total=items_total, item_embedding=self.item_embedding
)
    def forward(
self,
X: torch.FloatTensor,
edge_index: torch.LongTensor,
edge_weight: torch.FloatTensor = None,
):
r"""Making a forward pass. If edge weights are not present the forward pass
defaults to an unweighted graph.
Arg types:
* **X** (PyTorch Float Tensor) - Node features.
* **edge_index** (PyTorch Long Tensor) - Graph edge indices.
            * **edge_weight** (PyTorch Float Tensor, optional) - Edge weight vector.
Return types:
* **H** (PyTorch Float Tensor) - Hidden state matrix for all nodes.
"""
H = self.stacked_gcn(X, edge_index, edge_weight)
H = H.view(-1, self.items_total, self.item_embedding_dim)
H = self.masked_self_attention(H)
H = self.aggregate_nodes_temporal_feature(H)
H = self.global_gated_updater(H)
return H
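if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module: the toy
    # sizes, the chain-shaped edge_index and the uniform edge weights below
    # are illustrative assumptions, chosen only to exercise one forward pass.
    items_total, item_embedding_dim, n_heads = 5, 8, 2
    model = DNNTSP(
        items_total=items_total,
        item_embedding_dim=item_embedding_dim,
        n_heads=n_heads,
    )
    # One batch element: ``items_total`` item nodes with random features and a
    # simple chain of edges between consecutive items.
    X = torch.rand(items_total, item_embedding_dim)
    edge_index = torch.stack(
        [torch.arange(items_total - 1), torch.arange(1, items_total)]
    )
    edge_weight = torch.ones(edge_index.shape[1])
    H = model(X, edge_index, edge_weight)
    # H has shape (batch_size, items_total, item_embedding_dim) == (1, 5, 8).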