import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .. import BaseModel, register_model
from cogdl.utils import add_remaining_self_loops


class SpecialSpmmFunction(torch.autograd.Function):
    """Sparse matrix multiplication that backpropagates only through the nonzero (sparse) region."""

    @staticmethod
    def forward(ctx, indices, values, shape, b):
        assert not indices.requires_grad
        a = torch.sparse_coo_tensor(indices, values, shape)
        ctx.save_for_backward(a, b)
        ctx.N = shape[0]
        return torch.matmul(a, b)

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        grad_values = grad_b = None
        if ctx.needs_input_grad[1]:
            # Gradient w.r.t. the nonzero values: compute the dense gradient
            # grad_output @ b^T and gather it at the stored (row, col)
            # positions (assumes a square N x N sparse matrix).
            grad_a_dense = grad_output.matmul(b.t())
            edge_idx = a._indices()[0, :] * ctx.N + a._indices()[1, :]
            grad_values = grad_a_dense.view(-1)[edge_idx]
        if ctx.needs_input_grad[3]:
            grad_b = a.t().matmul(grad_output)
        return None, grad_values, None, grad_b


class SpecialSpmm(nn.Module):
    def forward(self, indices, values, shape, b):
        return SpecialSpmmFunction.apply(indices, values, shape, b)
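

# Usage sketch (illustrative only; the shapes and values below are made up):
# the module multiplies a sparse COO matrix by a dense matrix, and gradients
# flow only to the edge values and the dense operand.
#
#   indices = torch.tensor([[0, 1, 1], [1, 0, 2]])        # 2 x E COO indices
#   values = torch.rand(3, requires_grad=True)             # one weight per edge
#   dense = torch.rand(3, 4)                                # N x D dense matrix
#   out = SpecialSpmm()(indices, values, torch.Size([3, 3]), dense)
#   out.sum().backward()                                    # grad reaches `values`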


class SpGraphAttentionLayer(nn.Module):
    """
    Sparse version of the GAT layer, similar to https://arxiv.org/abs/1710.10903
    """

    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(SpGraphAttentionLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat

        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
        nn.init.xavier_normal_(self.W.data, gain=1.414)

        self.a = nn.Parameter(torch.zeros(size=(1, 2 * out_features)))
        nn.init.xavier_normal_(self.a.data, gain=1.414)

        self.dropout = nn.Dropout(dropout)
        self.leakyrelu = nn.LeakyReLU(self.alpha)
        self.special_spmm = SpecialSpmm()

    def forward(self, input, edge):
        N = input.size()[0]

        h = torch.mm(input, self.W)
        # h: N x out
        assert not torch.isnan(h).any()

        # Self-attention on the nodes - shared attention mechanism
        edge_h = torch.cat((h[edge[0, :], :], h[edge[1, :], :]), dim=1).t()
        # edge_h: 2*out x E

        edge_e = torch.exp(-self.leakyrelu(self.a.mm(edge_h).squeeze()))
        assert not torch.isnan(edge_e).any()
        # edge_e: E

        # Row-wise normalizer of the unnormalized attention coefficients.
        e_rowsum = self.special_spmm(
            edge, edge_e, torch.Size([N, N]), torch.ones(size=(N, 1)).to(input.device)
        )
        # e_rowsum: N x 1

        edge_e = self.dropout(edge_e)
        # edge_e: E

        h_prime = self.special_spmm(edge, edge_e, torch.Size([N, N]), h)
        assert not torch.isnan(h_prime).any()
        # h_prime: N x out

        h_prime = h_prime.div(e_rowsum + 1e-8)
        # h_prime: N x out
        assert not torch.isnan(h_prime).any()

        if self.concat:
            # If this layer is not the last layer, apply a nonlinearity.
            return F.elu(h_prime)
        else:
            # If this layer is the last layer, return the raw outputs.
            return h_prime

    def __repr__(self):
        return (
            self.__class__.__name__
            + " ("
            + str(self.in_features)
            + " -> "
            + str(self.out_features)
            + ")"
        )
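

# Usage sketch (illustrative only; sizes are arbitrary): one attention head
# mapping 5-dimensional node features to 3-dimensional outputs over a small
# directed edge list.
#
#   layer = SpGraphAttentionLayer(5, 3, dropout=0.5, alpha=0.2, concat=True)
#   x = torch.rand(4, 5)                                   # 4 nodes, 5 features
#   edge = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])      # 2 x E edge index
#   out = layer(x, edge)                                    # 4 x 3, ELU-activated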


@register_model("gat")
class PetarVSpGAT(BaseModel):
    r"""The GAT model from the `"Graph Attention Networks"
    <https://arxiv.org/abs/1710.10903>`_ paper.

    Args:
        num_features (int) : Number of input features.
        num_classes (int) : Number of classes.
        hidden_size (int) : The dimension of node representation.
        dropout (float) : Dropout rate for model training.
        alpha (float) : Negative slope of the LeakyReLU activation.
        nheads (int) : Number of attention heads.
    """

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument("--num-features", type=int)
        parser.add_argument("--num-classes", type=int)
        parser.add_argument("--hidden-size", type=int, default=8)
        parser.add_argument("--dropout", type=float, default=0.6)
        parser.add_argument("--alpha", type=float, default=0.2)
        parser.add_argument("--nheads", type=int, default=8)
        # fmt: on

    @classmethod
    def build_model_from_args(cls, args):
        return cls(
            args.num_features,
            args.hidden_size,
            args.num_classes,
            args.dropout,
            args.alpha,
            args.nheads,
        )
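
    # Sketch (illustrative; the argparse import and the numbers below are only
    # for demonstration): build_model_from_args expects a namespace exposing
    # the options registered in add_args.
    #
    #   import argparse
    #   args = argparse.Namespace(num_features=1433, num_classes=7,
    #                             hidden_size=8, dropout=0.6, alpha=0.2, nheads=8)
    #   model = PetarVSpGAT.build_model_from_args(args)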

    def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):
        """Sparse version of GAT."""
        super(PetarVSpGAT, self).__init__()
        self.dropout = dropout

        self.attentions = [
            SpGraphAttentionLayer(
                nfeat, nhid, dropout=dropout, alpha=alpha, concat=True
            )
            for _ in range(nheads)
        ]
        for i, attention in enumerate(self.attentions):
            self.add_module("attention_{}".format(i), attention)

        self.out_att = SpGraphAttentionLayer(
            nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False
        )

    def forward(self, x, edge_index):
        edge_index, _ = add_remaining_self_loops(edge_index)
        x = F.dropout(x, self.dropout, training=self.training)
        x = torch.cat([att(x, edge_index) for att in self.attentions], dim=1)
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.elu(self.out_att(x, edge_index))
        return F.log_softmax(x, dim=1)

    def loss(self, data):
        return F.nll_loss(
            self.forward(data.x, data.edge_index)[data.train_mask],
            data.y[data.train_mask],
        )

    def predict(self, data):
        return self.forward(data.x, data.edge_index)
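

if __name__ == "__main__":
    # Minimal smoke test (illustrative only; because of the relative imports
    # above, this runs only when the file is executed as a module inside the
    # package). The ring graph below is not a meaningful dataset.
    torch.manual_seed(0)
    num_nodes, num_features, num_classes = 10, 16, 3
    x = torch.rand(num_nodes, num_features)
    edge_index = torch.tensor(
        [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]]
    )
    model = PetarVSpGAT(num_features, 8, num_classes, dropout=0.6, alpha=0.2, nheads=4)
    out = model(x, edge_index)
    print(out.shape)  # expected: torch.Size([10, 3])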