Source code for cogdl.models.nn.pprgo

import torch

from .. import BaseModel
from cogdl.utils import spmm
from cogdl.layers import PPRGoLayer


[docs]class PPRGo(BaseModel):
[docs] @staticmethod def add_args(parser): parser.add_argument("--hidden-size", type=int, default=32) parser.add_argument("--num-layers", type=int, default=2) parser.add_argument("--dropout", type=float, default=0.1) parser.add_argument("--activation", type=str, default="relu") parser.add_argument("--nprop-inference", type=int, default=2) parser.add_argument("--alpha", type=float, default=0.5)
[docs] @classmethod def build_model_from_args(cls, args): return cls( in_feats=args.num_features, hidden_size=args.hidden_size, out_feats=args.num_classes, num_layers=args.num_layers, alpha=args.alpha, dropout=args.dropout, activation=args.activation, nprop=args.nprop_inference, norm=args.norm if hasattr(args, "norm") else "sym", )
def __init__( self, in_feats, hidden_size, out_feats, num_layers, alpha, dropout, activation="relu", nprop=2, norm="sym" ): super(PPRGo, self).__init__() self.alpha = alpha self.norm = norm self.nprop = nprop self.fc = PPRGoLayer(in_feats, hidden_size, out_feats, num_layers, dropout, activation)
[docs] def forward(self, x, targets, ppr_scores): h = self.fc(x) h = ppr_scores.unsqueeze(1) * h batch_size = targets[-1] + 1 out = torch.zeros(batch_size, h.shape[1]).to(x.device).to(x.dtype) out = out.scatter_add_(dim=0, index=targets[:, None].repeat(1, h.shape[1]), src=h) return out
[docs] def predict(self, graph, batch_size=10000): device = next(self.parameters()).device x = graph.x num_nodes = x.shape[0] pred_logits = [] with torch.no_grad(): for i in range(0, num_nodes, batch_size): batch_x = x[i : i + batch_size].to(device) batch_logits = self.fc(batch_x) pred_logits.append(batch_logits.cpu()) pred_logits = torch.cat(pred_logits, dim=0) pred_logits = pred_logits.to(device) with graph.local_graph(): if self.norm == "sym": graph.sym_norm() elif self.norm == "row": graph.row_norm() else: raise NotImplementedError edge_weight = graph.edge_weight * (1 - self.alpha) graph.edge_weight = edge_weight predictions = pred_logits for _ in range(self.nprop): predictions = spmm(graph, predictions) + self.alpha * pred_logits return predictions