# Source code for cogdl.models.nn.rgcn

import torch
import torch.nn as nn
import torch.nn.functional as F

from cogdl.utils.link_prediction_utils import GNNLinkPredict, sampling_edge_uniform
from cogdl.layers import RGCNLayer
from .. import BaseModel


class RGCN(nn.Module):
    """Sequential stack of ``RGCNLayer`` modules with ReLU between layers.

    The first layer maps ``in_feats`` to ``out_feats``; each subsequent layer
    maps ``out_feats`` to ``out_feats``. All remaining constructor arguments
    are passed through unchanged to every ``RGCNLayer``.
    """

    def __init__(
        self,
        in_feats,
        out_feats,
        num_layers,
        num_rels,
        regularizer="basis",
        num_bases=None,
        self_loop=True,
        dropout=0.0,
        self_dropout=0.0,
    ):
        """Build ``num_layers`` relational layers and store them in ``self.layers``."""
        super().__init__()
        self.num_layers = num_layers
        # Feature width at every layer boundary: one input width followed by
        # ``num_layers`` copies of the output width.
        dims = [in_feats, *([out_feats] * num_layers)]
        stack = [
            RGCNLayer(d_in, d_out, num_rels, regularizer, num_bases, self_loop, dropout, self_dropout)
            for d_in, d_out in zip(dims[:-1], dims[1:])
        ]
        self.layers = nn.ModuleList(stack)

    def forward(self, graph, x):
        """Run ``x`` through every layer, applying ReLU after all but the last.

        "Last" is decided by ``self.num_layers``: a layer at position ``i``
        is followed by ReLU only when ``i < self.num_layers - 1``.
        """
        last = self.num_layers - 1
        hidden = x
        for depth, layer in enumerate(self.layers):
            hidden = layer(graph, hidden)
            if depth < last:
                hidden = F.relu(hidden)
        return hidden


class LinkPredictRGCN(GNNLinkPredict, BaseModel):
    """Link-prediction model that encodes entities with an ``RGCN`` stack.

    Entities and relations each get a learnable ``nn.Embedding`` of width
    ``hidden_size``. Entity embeddings are passed through ``RGCN``; relation
    embeddings are handed to the scoring/loss helpers inherited from
    ``GNNLinkPredict``.
    """

    @staticmethod
    def add_args(parser):
        """Register this model's command-line options on ``parser``."""
        # fmt: off
        parser.add_argument("--hidden-size", type=int, default=200)
        parser.add_argument("--num-layers", type=int, default=2)
        parser.add_argument("--regularizer", type=str, default="basis")
        # NOTE: ``store_false`` makes the value default to True; passing
        # ``--self-loop`` on the command line switches self-loops OFF.
        parser.add_argument("--self-loop", action="store_false")
        parser.add_argument("--penalty", type=float, default=0.001)
        parser.add_argument("--dropout", type=float, default=0.2)
        parser.add_argument("--self-dropout", type=float, default=0.4)
        parser.add_argument("--num-bases", type=int, default=5)
        parser.add_argument("--sampling-rate", type=float, default=0.01)
        # fmt: on

    @classmethod
    def build_model_from_args(cls, args):
        """Construct the model from a parsed argument namespace.

        Besides the options added by ``add_args``, ``args`` must also carry
        ``num_entities`` and ``num_rels``.
        """
        return cls(
            num_entities=args.num_entities,
            num_rels=args.num_rels,
            hidden_size=args.hidden_size,
            num_layers=args.num_layers,
            regularizer=args.regularizer,
            num_bases=args.num_bases,
            self_loop=args.self_loop,
            sampling_rate=args.sampling_rate,
            penalty=args.penalty,
            dropout=args.dropout,
            self_dropout=args.self_dropout,
        )

    def __init__(
        self,
        num_entities,
        num_rels,
        hidden_size,
        num_layers,
        regularizer="basis",
        num_bases=None,
        self_loop=True,
        sampling_rate=0.01,
        penalty=0,
        dropout=0.0,
        self_dropout=0.0,
    ):
        """Create the embeddings and the RGCN encoder.

        Args:
            num_entities: Number of rows in the entity embedding table.
            num_rels: Number of rows in the relation embedding table; also
                forwarded to ``RGCN`` and to the edge sampler in ``loss``.
            hidden_size: Embedding width, used as both input and output
                width of the ``RGCN`` stack.
            num_layers: Number of ``RGCN`` layers.
            regularizer, num_bases, self_loop, dropout, self_dropout:
                Forwarded unchanged to ``RGCN``.
            sampling_rate: Passed to ``sampling_edge_uniform`` in ``loss``.
            penalty: Weight of the regularization term added in ``loss``.
        """
        BaseModel.__init__(self)
        GNNLinkPredict.__init__(self)
        self.penalty = penalty
        self.num_nodes = num_entities
        self.num_rels = num_rels
        self.sampling_rate = sampling_rate
        self.edge_set = None
        self.model = RGCN(
            in_feats=hidden_size,
            out_feats=hidden_size,
            num_layers=num_layers,
            num_rels=num_rels,
            regularizer=regularizer,
            num_bases=num_bases,
            self_loop=self_loop,
            dropout=dropout,
            self_dropout=self_dropout,
        )
        self.rel_weight = nn.Embedding(num_rels, hidden_size)
        self.emb = nn.Embedding(num_entities, hidden_size)

    def forward(self, graph):
        """Encode the nodes that appear in ``graph.edge_index``.

        The node ids in ``graph.edge_index`` are compacted with
        ``torch.unique`` so that only nodes present in the graph are embedded.

        Side effects: ``graph.edge_index`` is overwritten with the compacted
        indices, and the sorted unique original ids are stored on ``self``.

        Returns:
            The output of the RGCN stack applied to the embeddings of the
            unique nodes.
        """
        reindexed_nodes, reindexed_edges = torch.unique(
            torch.stack(graph.edge_index), sorted=True, return_inverse=True
        )
        x = self.emb(reindexed_nodes)
        # The attribute name (including its spelling) is kept as-is: ``loss``
        # reads it back right after calling ``forward``.
        self.cahce_index = reindexed_nodes
        graph.edge_index = reindexed_edges
        return self.model(graph, x)

    def loss(self, graph, scoring):
        """Compute the link-prediction loss on a sampled batch of edges.

        A batch is drawn with ``sampling_edge_uniform``, encoded on a fresh
        graph object (so the caller's ``graph`` is not modified by
        ``forward``), scored via the inherited ``_loss`` helper, and
        regularized with ``penalty * _regularization(...)`` over the
        embeddings of the sampled nodes and relation types.
        """
        edge_index = graph.edge_index
        edge_types = graph.edge_attr

        self.get_edge_set(edge_index, edge_types)
        batch_edges, batch_attr, samples, rels, labels = sampling_edge_uniform(
            edge_index, edge_types, self.edge_set, self.sampling_rate, self.num_rels
        )

        # Build a new graph of the same class rather than mutating the input.
        graph = graph.__class__(edge_index=batch_edges, edge_attr=batch_attr)
        output = self.forward(graph)
        edge_weight = self.rel_weight(rels)

        sampled_nodes, reindexed_edges = torch.unique(samples, sorted=True, return_inverse=True)
        # NOTE(review): ``.any()`` only requires one position to match the
        # node ids cached by ``forward``; ``.all()`` may have been intended.
        # Confirm against ``sampling_edge_uniform`` before tightening.
        assert (sampled_nodes == self.cahce_index).any()
        sampled_types = torch.unique(rels)

        loss_n = self._loss(
            output[reindexed_edges[0]], output[reindexed_edges[1]], edge_weight, labels, scoring
        ) + self.penalty * self._regularization([self.emb(sampled_nodes), self.rel_weight(sampled_types)])
        return loss_n

    def predict(self, graph):
        """Encode every entity and return ``(node_output, relation_weights)``.

        All ``self.num_nodes`` entity embeddings are fed through the RGCN
        stack on ``graph``; the second element is the raw relation embedding
        matrix ``self.rel_weight.weight``.
        """
        device = next(self.parameters()).device
        # Create the index tensor directly on the target device instead of
        # allocating on CPU and copying.
        indices = torch.arange(0, self.num_nodes, device=device)
        x = self.emb(indices)
        output = self.model(graph, x)
        return output, self.rel_weight.weight