# Source code for cogdl.models.nn.grand

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from .. import BaseModel, register_model
from cogdl.utils import symmetric_normalization, spmm

class MLPLayer(nn.Module):
    """Fully connected layer computing ``x @ weight + bias``.

    Parameters
    ----------
    in_features : int
        Size of the last dimension of the input.
    out_features : int
        Size of the last dimension of the output.
    bias : bool, optional
        If ``False`` the layer has no additive bias. Default: ``True``.
    """

    def __init__(self, in_features, out_features, bias=True):
        super(MLPLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = nn.Parameter(torch.FloatTensor(out_features))
        else:
            # Keep ``self.bias`` defined (as None) so reset_parameters/forward can test it.
            self.register_parameter("bias", None)
        # torch.FloatTensor(...) allocates uninitialised memory, so initialise explicitly.
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw weight and bias from U(-1/sqrt(out_features), 1/sqrt(out_features))."""
        stdv = 1.0 / math.sqrt(self.weight.size(1))
        with torch.no_grad():
            self.weight.uniform_(-stdv, stdv)
            if self.bias is not None:
                self.bias.uniform_(-stdv, stdv)

    def forward(self, x):
        """Apply the affine map to ``x`` whose last dimension is ``in_features``."""
        output = torch.matmul(x, self.weight)
        if self.bias is not None:
            return output + self.bias
        return output

    def __repr__(self):
        return self.__class__.__name__ + " (" + str(self.in_features) + " -> " + str(self.out_features) + ")"

@register_model("grand")
class Grand(BaseModel):
    """
    Implementation of GRAND in paper `"Graph Random Neural Network for Semi-Supervised Learning on Graphs"`
    <https://arxiv.org/abs/2005.11079>

    Each forward pass applies DropNode to the row-normalized features, averages
    ``X, AX, ..., A^order X`` with ``A`` the symmetrically normalized adjacency
    (random propagation), and classifies the result with a two-layer MLP.
    Training averages the supervised loss over ``sample`` augmentations and adds
    a consistency loss on the unlabelled nodes.

    Parameters
    ----------
    nfeat : int
        Size of each input feature vector.
    nhid : int
        Size of hidden features.
    nclass : int
        Number of output classes.
    input_droprate : float
        Dropout rate applied to the propagated input features.
    hidden_droprate : float
        Dropout rate applied to the hidden features.
    use_bn : bool
        Whether to apply batch normalization before each MLP layer.
    dropnode_rate : float
        Probability of zeroing a node's entire feature row during training.
    tem : float
        Temperature used to sharpen the averaged predictions.
    lam : float
        Weight of the consistency loss on unlabelled nodes.
    order : int
        Number of propagation steps.
    sample : int
        Number of random augmentations drawn per training step.
    alpha : float
        Stored as ``self.alpha``; not used elsewhere in this class.
    """

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""

        def _str2bool(value):
            # argparse's ``type=bool`` maps every non-empty string (even "False") to True.
            lowered = str(value).strip().lower()
            if lowered in ("1", "true", "t", "yes", "y"):
                return True
            if lowered in ("0", "false", "f", "no", "n"):
                return False
            raise ValueError("expected a boolean value, got %r" % (value,))

        # fmt: off
        parser.add_argument("--num-features", type=int)
        parser.add_argument("--num-classes", type=int)
        parser.add_argument("--hidden-size", type=int, default=32)
        parser.add_argument("--hidden-dropout", type=float, default=0.5)
        parser.add_argument("--input-dropout", type=float, default=0.5)
        parser.add_argument("--bn", type=_str2bool, default=False)
        parser.add_argument("--dropnode-rate", type=float, default=0.5)
        parser.add_argument('--order', type=int, default=5)
        parser.add_argument('--tem', type=float, default=0.5)
        parser.add_argument('--lam', type=float, default=0.5)
        parser.add_argument('--sample', type=int, default=2)
        parser.add_argument('--alpha', type=float, default=0.2)
        # fmt: on

    @classmethod
    def build_model_from_args(cls, args):
        return cls(
            args.num_features,
            args.hidden_size,
            args.num_classes,
            args.input_dropout,
            args.hidden_dropout,
            args.bn,
            args.dropnode_rate,
            args.tem,
            args.lam,
            args.order,
            args.sample,
            args.alpha,
        )

    def __init__(
        self,
        nfeat,
        nhid,
        nclass,
        input_droprate,
        hidden_droprate,
        use_bn,
        dropnode_rate,
        tem,
        lam,
        order,
        sample,
        alpha,
    ):
        super(Grand, self).__init__()
        self.layer1 = MLPLayer(nfeat, nhid)
        self.layer2 = MLPLayer(nhid, nclass)
        self.input_droprate = input_droprate
        self.hidden_droprate = hidden_droprate
        # Always created (even when use_bn is False) so state_dict keys do not depend on the flag.
        self.bn1 = nn.BatchNorm1d(nfeat)
        self.bn2 = nn.BatchNorm1d(nhid)
        self.use_bn = use_bn
        self.tem = tem
        self.lam = lam
        self.order = order
        self.dropnode_rate = dropnode_rate
        self.sample = sample
        self.alpha = alpha

    def dropNode(self, x):
        """DropNode augmentation.

        In training mode each row of ``x`` is kept with probability
        ``1 - dropnode_rate`` and zeroed otherwise (one Bernoulli draw per row).
        In eval mode ``x`` is scaled by ``1 - dropnode_rate`` instead, matching
        the expected value of the training-time mask.
        """
        n = x.shape[0]
        drop_rates = torch.ones(n) * self.dropnode_rate
        if self.training:
            masks = torch.bernoulli(1.0 - drop_rates).unsqueeze(1)
            x = masks.to(device=x.device, dtype=x.dtype) * x
        else:
            x = x * (1.0 - self.dropnode_rate)
        return x

    def rand_prop(self, x, edge_index, edge_weight):
        """Random propagation: mean of ``X, AX, ..., A^order X`` on DropNode'd features.

        The result carries no autograd history (propagation has no parameters).
        """
        with torch.no_grad():
            x = self.dropNode(x)
            y = x
            for _ in range(self.order):
                x = spmm(edge_index, edge_weight, x)
                # Out-of-place accumulation: never mutates the DropNode output.
                y = y + x
            return y / (self.order + 1.0)

    def consis_loss(self, logps, train_mask, multilabel=False):
        """Consistency regularisation on the unlabelled nodes.

        Parameters
        ----------
        logps : list of Tensor
            Per-augmentation log-probabilities (log-softmax for multi-class,
            log-sigmoid for multi-label), one row per node.
        train_mask : Tensor
            Boolean mask of labelled nodes; the loss uses the complement.
        multilabel : bool, optional
            If True, sharpen each label independently as a Bernoulli probability
            instead of renormalising across classes. Default: False.

        Returns
        -------
        Tensor
            ``lam`` times the mean squared distance between each augmentation's
            probabilities and the sharpened average.
        """
        ps = [torch.exp(p)[~train_mask] for p in logps]
        if ps[0].shape[0] == 0:
            # Every node is labelled: torch.mean over zero rows would be NaN.
            return ps[0].new_zeros(())
        avg_p = sum(ps) / len(ps)
        inv_tem = 1.0 / self.tem
        if multilabel:
            pos = torch.pow(avg_p, inv_tem)
            neg = torch.pow(1.0 - avg_p, inv_tem)
            sharp_p = (pos / (pos + neg)).detach()
        else:
            sharpened = torch.pow(avg_p, inv_tem)
            sharp_p = (sharpened / torch.sum(sharpened, dim=1, keepdim=True)).detach()
        loss = sum(torch.mean((p - sharp_p).pow(2).sum(1)) for p in ps) / len(ps)
        return self.lam * loss

    def normalize_x(self, x):
        """Row-normalise ``x`` so each row sums to 1; rows summing to 0 become all zeros."""
        row_inv = x.sum(1).pow(-1)
        # A zero row sum gives +/-inf (sign follows the zero); treat both as "no scaling".
        row_inv.masked_fill_(torch.isinf(row_inv), 0)
        return x * row_inv[:, None]

    def forward(self, x, edge_index, edge_weight=None):
        x = self.normalize_x(x)
        if edge_weight is None:
            edge_weight = torch.ones(edge_index.shape[1], dtype=torch.float32, device=x.device)
        edge_weight = symmetric_normalization(x.shape[0], edge_index, edge_weight)
        x = self.rand_prop(x, edge_index, edge_weight)
        if self.use_bn:
            x = self.bn1(x)
        x = F.dropout(x, self.input_droprate, training=self.training)
        x = F.relu(self.layer1(x))
        if self.use_bn:
            x = self.bn2(x)
        x = F.dropout(x, self.hidden_droprate, training=self.training)
        x = self.layer2(x)
        return x

    def node_classification_loss(self, data):
        """Supervised loss averaged over ``sample`` augmentations plus the consistency loss."""
        use_train_edges = hasattr(data, "edge_index_train") and self.training
        edge_index = data.edge_index_train if use_train_edges else data.edge_index
        # Each forward call draws a fresh DropNode mask, i.e. a new augmentation.
        output_list = [self.forward(data.x, edge_index) for _ in range(self.sample)]

        loss_train = 0.0
        for output in output_list:
            loss_train += self.loss_fn(output[data.train_mask], data.y[data.train_mask])
        loss_train = loss_train / self.sample

        multilabel = data.y.dim() > 1
        # consis_loss exponentiates its inputs, so it must receive log-probabilities.
        if multilabel:
            log_probs = [F.logsigmoid(output) for output in output_list]
        else:
            log_probs = [F.log_softmax(output, dim=-1) for output in output_list]
        loss_consis = self.consis_loss(log_probs, data.train_mask, multilabel=multilabel)

        return loss_train + loss_consis

    def predict(self, data):
        return self.forward(data.x, data.edge_index)