from functools import reduce
from scipy.special import iv
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_sparse import spspmm, spmm
from torch_geometric.utils import degree
# ==========
# Attention
# ==========
class NodeAttention(nn.Module):
    """Node-level attention over a sparse adjacency.

    Each node gets a learned scalar score sigmoid(p(x)); the (degree-scaled)
    adjacency is right-multiplied by diag(scores), so every edge (u, v) is
    reweighted by node v's score. Result is returned in COO form.
    """

    def __init__(self, in_feat):
        super(NodeAttention, self).__init__()
        # Learned projection of node features to one scalar score per node.
        self.p = nn.Linear(in_feat, 1)
        # NOTE(review): dropout rate 0.7 is unusually aggressive — confirm intended.
        self.dropout = nn.Dropout(0.7)

    def forward(self, x, edge_index, edge_attr):
        """Reweight edges by per-target-node attention scores.

        Args:
            x: (N, in_feat) node feature matrix.
            edge_index: (2, E) LongTensor of COO edge endpoints.
            edge_attr: (E,) edge weights.

        Returns:
            (edge_index, new_edge_attr): same sparsity pattern, reweighted values.
        """
        device = x.device
        N, dim = x.shape
        # Per-node score squashed into (0, 1). (torch.sigmoid: F.sigmoid is deprecated.)
        diag_val = torch.sigmoid(self.p(x))
        # BUG FIX: the original called self.dropout(diag_val) and discarded the
        # result, so dropout never had any effect; keep the returned tensor
        # (matches how EdgeAttention applies its dropout).
        diag_val = self.dropout(diag_val)
        # Degree-scale the adjacency values.
        # NOTE(review): degree is counted over `col` but indexed by `row` —
        # equivalent only for undirected graphs; confirm.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv = deg.pow(-1)
        edge_attr_t = deg_inv[row] * edge_attr
        # Sparse identity index pattern used to host diag_val on the diagonal.
        diag_ind = torch.LongTensor([range(N)] * 2).to(device)
        # (scaled A) @ diag(scores): rescales column j by node j's score.
        _, adj_mat_val = spspmm(edge_index, edge_attr_t, diag_ind,
                                diag_val.view(-1), N, N, N, True)
        return edge_index, adj_mat_val
class EdgeAttention(nn.Module):
    """Edge-level attention mixing two learned per-node score vectors.

    Combines A_norm @ diag(p-scores) with diag(q-scores) @ A and returns the
    summed values on the original edge index.
    """

    def __init__(self, in_feat):
        super(EdgeAttention, self).__init__()
        self.p = nn.Linear(in_feat, 1)
        self.q = nn.Linear(in_feat, 1)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x, edge_index, edge_attr):
        """Return (edge_index, combined attention values) in COO form."""
        device = x.device
        num_nodes, feat_dim = x.shape
        src, dst = edge_index
        # Symmetric normalization of the edge weights: D^-1/2 A D^-1/2.
        node_deg = degree(dst, x.size(0), dtype=x.dtype)
        inv_sqrt_deg = node_deg.pow(-0.5)
        norm_attr = inv_sqrt_deg[src] * edge_attr * inv_sqrt_deg[dst]
        # Two independent non-negative per-node scores, each with dropout.
        score_p = self.dropout(F.relu(self.p(x)))
        score_q = self.dropout(F.relu(self.q(x)))
        # Diagonal index pattern for placing per-node scores on a sparse diagonal.
        diag_ind = torch.LongTensor([range(num_nodes)] * 2).to(device)
        # A_norm @ diag(score_p)
        _, p_vals = spspmm(edge_index, norm_attr, diag_ind, score_p.view(-1),
                           num_nodes, num_nodes, num_nodes, True)
        # diag(score_q) @ A — NOTE(review): this branch uses the *raw*
        # edge_attr, not norm_attr; confirm the asymmetry is intentional.
        _, q_vals = spspmm(diag_ind, score_q.view(-1), edge_index, edge_attr,
                           num_nodes, num_nodes, num_nodes, True)
        return edge_index, p_vals + q_vals
class Identity(nn.Module):
    """No-op attention: passes the graph through unchanged."""

    def __init__(self, in_feat):
        # `in_feat` is accepted only for signature parity with the other
        # attention modules; it is ignored.
        super(Identity, self).__init__()

    def forward(self, x, edge_index, edge_attr):
        """Return the edge structure and weights untouched."""
        return edge_index, edge_attr
class Gaussian(nn.Module):
    """Gaussian-style spectral graph filter approximated by a truncated
    Chebyshev expansion whose coefficients are modified Bessel functions
    iv(k, theta) (scipy.special.iv).

    NOTE(review): `in_feat` is unused — kept for constructor-signature
    compatibility with the other attention modules.
    """

    def __init__(self, in_feat):
        super(Gaussian, self).__init__()
        # Center of the filter on the Laplacian spectrum — TODO confirm semantics.
        self.mu = 0.2
        # Scale parameter fed to the Bessel-function coefficients.
        self.theta = 1.
        # Number of Chebyshev terms kept in the truncated series.
        self.steps = 4

    def forward(self, x, edge_index, edge_attr):
        """Return (indices, values) of the filtered sparse operator in COO form."""
        N = x.shape[0]
        row, col = edge_index
        # Random-walk scaling of the adjacency values: deg is taken over rows here.
        deg = degree(row, x.size(0), dtype=x.dtype)
        deg_inv = deg.pow(-1)
        adj = torch.sparse_coo_tensor(edge_index, deg_inv[row] * edge_attr , size=(N, N))
        identity = torch.sparse_coo_tensor([range(N)] * 2, torch.ones(N), size=(N, N)).to(x.device)
        # Random-walk Laplacian: L = I - D^-1 A.
        laplacian = identity - adj
        # Chebyshev basis seed: T0 = I.
        t0 = identity
        t1 = laplacian - self.mu * identity
        # Square (L - mu*I); densified because sparse-sparse mm is not used here
        # (see the commented spspmm variant below).
        t1 = t1.mm(t1.to_dense()).to_sparse()
        # Rescale the squared, shifted argument into the Chebyshev domain.
        l_x = -0.5 * (t1 - identity)
        # l_x = -0.5 * ((laplacian - self.mu * identity).pow(2) - identity)
        # Bessel coefficients; terms k >= 1 get the standard factor of 2 with
        # alternating sign (positive at k = 1).
        ivs = [iv(i, self.theta) for i in range(self.steps)]
        ivs[1:] = [(-1) ** i * 2 * x for i, x in enumerate(ivs[1:])]
        ivs = torch.tensor(ivs).to(x.device)
        # Chebyshev recurrence T_k = 2 * l_x @ T_{k-1} - T_{k-2}.
        result = [t0, l_x]
        for i in range(2, self.steps):
            result.append(2*l_x.mm(result[i-1].to_dense()).to_sparse().sub(result[i-2]))
            # adj_ind, adj_val = spspmm(l_x._indices(), l_x._values(), result[i-1]._indices(), result[i-1]._values(), N, N, N, True)
            # obj = torch.sparse_coo_tensor(adj_ind, adj_val * 2, size=(N, N))
            # obj = obj.sub(result[i-2])
            # result.append(obj)
        # Weight each Chebyshev term by its coefficient and sum the series.
        result = [result[i] * ivs[i] for i in range(self.steps)]
        def fn(x, y):
            return x.add(y)
        res = reduce(fn, result)
        return res._indices(), res._values()
class PPR(nn.Module):
    """Truncated Personalized-PageRank diffusion of the normalized adjacency.

    Approximates sum_{k>=0} alpha * (1 - alpha)^k * A_sym^k with `steps`
    terms, where A_sym = D^-1/2 A D^-1/2. Returned in COO form.
    """

    def __init__(self, in_feat):
        super(PPR, self).__init__()
        # Teleport (restart) probability.
        self.alpha = 0.4
        # Number of terms kept in the truncated series.
        self.steps = 4

    def forward(self, x, edge_index, edge_attr):
        """Return (indices, values) of the diffused sparse matrix."""
        row, col = edge_index
        # Symmetric normalization: D^-1/2 A D^-1/2.
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        edge_attr_t = deg_inv_sqrt[row] * edge_attr * deg_inv_sqrt[col]
        N = x.size(0)
        adj = torch.sparse_coo_tensor(edge_index, edge_attr_t, size=(N, N))
        decay = 1 - self.alpha
        # k = 1 term: alpha * (1 - alpha) * A_sym.
        result = [self.alpha * decay * adj]
        for i in range(1, self.steps - 1):
            adj_ind, adj_val = spspmm(edge_index, edge_attr_t,
                                      result[i - 1]._indices(), result[i - 1]._values(),
                                      N, N, N, True)
            # BUG FIX: the original recomputed `theta *= (1 - alpha)` each
            # iteration but never applied it, so term k was missing its
            # (1 - alpha) per-step decay (the commented-out dense variant
            # `result[-1] * theta` confirms the intent). Apply it here.
            result.append(torch.sparse_coo_tensor(adj_ind, adj_val * decay, size=(N, N)))
        # k = 0 term: alpha * I.
        identity = torch.sparse_coo_tensor([range(N)] * 2, torch.ones(N), size=(N, N)).to(x.device)
        result.append(self.alpha * identity)
        res = reduce(lambda a, b: a.add(b), result)
        return res._indices(), res._values()
class HeatKernel(nn.Module):
    """Edge reweighting with a learnable heat-kernel-style temperature."""

    def __init__(self, in_feat):
        super(HeatKernel, self).__init__()
        # Learnable diffusion time; zero-initialized, so exp(0) = 1 gives
        # identity edge weights at the start of training.
        self.t = nn.Parameter(torch.zeros(1,))

    def forward(self, x, edge_index, edge_attr):
        """Return exp(t * w_uv / deg(v) - t) for every edge (u, v)."""
        src, dst = edge_index
        node_deg = degree(dst, x.size(0), dtype=x.dtype)
        inv_deg = node_deg.pow(-1)
        # Kept in the original operation order for bit-identical float results.
        shifted = self.t * edge_attr * inv_deg[dst] - self.t
        return edge_index, shifted.exp()
def act_attention(attn_type):
    """Resolve an attention-type name to its module class.

    Args:
        attn_type: one of "identity", "node", "edge", "gaussian", "ppr", "heat".

    Raises:
        ValueError: for any unrecognized name.
    """
    if attn_type == "heat":
        return HeatKernel
    if attn_type == "ppr":
        return PPR
    if attn_type == "gaussian":
        return Gaussian
    if attn_type == "edge":
        return EdgeAttention
    if attn_type == "node":
        return NodeAttention
    if attn_type == "identity":
        return Identity
    raise ValueError("no such attention type")
# ===============
# Normalization
# ===============
class NormIdentity(nn.Module):
    """No-op normalization: edge weights pass through unchanged."""

    def __init__(self):
        super(NormIdentity, self).__init__()

    def forward(self, edge_index, edge_attr, N):
        """`edge_index` and `N` are accepted for interface parity and ignored."""
        return edge_attr
class RowSoftmax(nn.Module):
    """Softmax-normalize sparse edge weights within each row."""

    def __init__(self):
        super(RowSoftmax, self).__init__()

    def forward(self, edge_index, edge_attr, N):
        """Return exp(w) / row_sum(exp(w)) for an (N, N) sparse matrix."""
        device = edge_attr.device
        exp_attr = torch.exp(edge_attr)
        # Row sums of exp(w) computed as a sparse-dense product with a ones vector.
        ones = torch.ones(N, 1, device=device)
        inv_row_sum = 1. / spmm(edge_index, exp_attr, N, N, ones).view(-1)
        # Scale each edge by the inverse sum of its source row.
        return inv_row_sum[edge_index[0]] * exp_attr
class SymmetryNorm(nn.Module):
    """Symmetric normalization of sparse edge weights: D^-1/2 W D^-1/2,
    where D is built from the row sums of W."""

    def __init__(self):
        super(SymmetryNorm, self).__init__()

    def forward(self, edge_index, edge_attr, N):
        """Return the symmetrically normalized edge weights."""
        device = edge_attr.device
        ones = torch.ones(N, 1, device=device)
        # Row sums via sparse-dense matmul, then raised to -1/2.
        inv_sqrt_sum = spmm(edge_index, edge_attr, N, N, ones).view(-1).pow(-0.5)
        src_scale = inv_sqrt_sum[edge_index[0]]
        dst_scale = inv_sqrt_sum[edge_index[1]]
        return src_scale * edge_attr * dst_scale
def act_normalization(norm_type):
    """Resolve a normalization-type name to its module class.

    NOTE(review): `RowUniform` and `ColumnUniform` are not defined in this
    chunk of the file — presumably defined elsewhere; verify.

    Raises:
        ValueError: for any unrecognized name.
    """
    if norm_type == "symmetry":
        return SymmetryNorm
    if norm_type == "col_uniform":
        return ColumnUniform
    if norm_type == "row_softmax":
        return RowSoftmax
    if norm_type == "row_uniform":
        return RowUniform
    if norm_type == "identity":
        return NormIdentity
    raise ValueError("no such normalization type")
# ============
# activation
# ============
def act_map(act):
    """Map an activation name to the corresponding callable.

    Args:
        act: one of "linear", "elu", "sigmoid", "tanh", "relu", "relu6",
            "softplus", "leaky_relu".

    Raises:
        Exception: for any unrecognized name.
    """
    table = {
        "linear": lambda x: x,
        "elu": torch.nn.functional.elu,
        "sigmoid": torch.sigmoid,
        "tanh": torch.tanh,
        "relu": torch.nn.functional.relu,
        "relu6": torch.nn.functional.relu6,
        "softplus": torch.nn.functional.softplus,
        "leaky_relu": torch.nn.functional.leaky_relu,
    }
    if act not in table:
        raise Exception("wrong activate function")
    return table[act]