"""Source code for cogdl.models.nn.compgcn: CompGCN layers and a link-prediction wrapper."""

import torch

# import torch.fft
import torch.nn as nn
import torch.nn.functional as F

from cogdl.data import Graph
from cogdl.utils import row_normalization
from cogdl.utils.link_prediction_utils import GNNLinkPredict, sampling_edge_uniform, cal_mrr
from .. import BaseModel


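# NOTE: circular-correlation composition ("corr") is currently disabled. Re-enabling it requires
# restoring the `torch.fft` import near the top of this file and the "corr" branch in
# `CompGCNLayer.rel_transform`.
# def ccorr(a, b):
#     return torch.fft.irfft(
#         torch.multiply(torch.fft.rfft(a, dim=1).conj(), torch.fft.rfft(b, dim=1)), n=a.shape[-1], dim=1
#     )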


class BasesRelEmbLayer(nn.Module):
    """Relation embeddings expressed as linear combinations of shared basis vectors.

    The module holds ``num_bases`` basis vectors of size ``in_feats`` (``weight``) and one
    row of mixing coefficients per relation (``alpha``). There are ``2 * num_rels`` relations,
    presumably forward plus inverse relations (TODO confirm against the data pipeline).
    Calling the module returns the ``(2 * num_rels, in_feats)`` matrix ``alpha @ weight``.

    Parameters
    ----------
    num_bases : int
        Number of basis vectors.
    num_rels : int
        Number of relation types; ``2 * num_rels`` embeddings are produced.
    in_feats : int
        Size of each relation embedding.
    """

    def __init__(self, num_bases, num_rels, in_feats):
        super(BasesRelEmbLayer, self).__init__()
        self.num_bases = num_bases
        self.num_rels = num_rels
        # Backward-compatible alias for the previously misspelled attribute name.
        self.num_resl = num_rels
        self.in_feats = in_feats
        self.weight = nn.Parameter(torch.empty(num_bases, in_feats))
        self.alpha = nn.Parameter(torch.empty(2 * num_rels, num_bases))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize bases and mixing coefficients with Xavier-uniform values."""
        nn.init.xavier_uniform_(self.weight)
        nn.init.xavier_uniform_(self.alpha)

    def forward(self):
        """Return the ``(2 * num_rels, in_feats)`` relation embedding matrix."""
        return torch.matmul(self.alpha, self.weight)


class CompGCNLayer(nn.Module):
    """
    Implementation of CompGCN in paper `"Composition-based Multi-Relational Graph Convolutional Networks"`
    <https://arxiv.org/abs/1911.03082>

    Parameters
    ----------
    in_feats : int
        Size of each input embedding
    out_feats : int
        Size of each output embedding
    num_rels : int
        The number of relations
    opn : str
        Composition operator mixing a tail-entity embedding with its relation embedding:
        ``"sub"`` (entity - relation) or ``"mult"`` (element-wise product). Any other value
        raises ``NotImplementedError`` when the layer is run.
    num_bases : Optional[int]
        The number of bases, default : ``None``. When positive, relation embeddings are
        produced by an internal ``BasesRelEmbLayer`` and the ``rel_embed`` passed to
        ``forward`` is ignored.
    activation : Callable
        Applied to the node output, default: identity.
    dropout : float
        Dropout probability applied to the output of each message-passing direction.
    bias : bool
        Whether to add a learnable bias before batch normalization.
    """

    def __init__(
        self, in_feats, out_feats, num_rels, opn="mult", num_bases=None, activation=lambda x: x, dropout=0.0, bias=True
    ):
        super(CompGCNLayer, self).__init__()
        self.in_feats = in_feats
        self.out_feats = out_feats
        self.num_rels = num_rels
        self.opn = opn
        self.use_bases = num_bases is not None and num_bases > 0

        # Separate projections for original-direction edges ("in"), reverse edges ("out"),
        # self-loops ("loop") and for transforming the relation embeddings themselves.
        self.weight_in = self.get_param(in_feats, out_feats)
        self.weight_out = self.get_param(in_feats, out_feats)
        self.weight_rel = self.get_param(in_feats, out_feats)
        self.weight_loop = self.get_param(in_feats, out_feats)
        # Learnable embedding of the synthetic self-loop relation.
        self.loop_rel = self.get_param(1, in_feats)

        if self.use_bases:
            self.basis_weight = BasesRelEmbLayer(num_bases, num_rels, in_feats)
        else:
            self.register_buffer("basis_weight", None)

        self.dropout = dropout
        self.activation = activation
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_feats))
        else:
            self.register_buffer("bias", None)
        self.bn = nn.BatchNorm1d(out_feats)

    def get_param(self, num_in, num_out):
        """Create a ``(num_in, num_out)`` parameter initialized with Xavier-normal values."""
        weight = nn.Parameter(torch.empty(num_in, num_out))
        nn.init.xavier_normal_(weight.data)
        return weight

    def forward(self, graph, x, rel_embed):
        """Run one CompGCN layer.

        Assumes the first half of ``graph.edge_index`` holds the original edges and the
        second half their reverses, with matching ``graph.edge_attr`` relation ids
        (TODO confirm against the caller that builds the graph).

        Returns
        -------
        (node_output, relation_output)
            ``node_output`` has ``out_feats`` columns. ``relation_output`` is the input
            relation embeddings projected by ``weight_rel``, with the self-loop relation
            row removed.
        """
        device = x.device
        if self.use_bases:
            rel_embed = self.basis_weight()
        edge_index = graph.edge_index
        edge_type = graph.edge_attr
        # The self-loop relation becomes the last row of the relation table.
        rel_embed = torch.cat((rel_embed, self.loop_rel), dim=0)
        num_edges = edge_index[0].shape[0] // 2
        num_entities = x.shape[0]

        row, col = edge_index
        i_row, i_col = row[:num_edges], col[:num_edges]
        rev_row, rev_col = row[num_edges:], col[num_edges:]
        nodes = torch.arange(num_entities, device=device)
        loop_index = torch.stack((nodes, nodes))
        types, rev_types = edge_type[:num_edges], edge_type[num_edges:]
        loop_types = torch.full((num_entities,), rel_embed.shape[0] - 1, dtype=torch.long, device=device)

        in_norm = row_normalization(num_entities, i_row, i_col)
        rev_norm = row_normalization(num_entities, rev_row, rev_col)

        emb = self.message_passing(x, rel_embed, (i_row, i_col), types, "in", in_norm)
        rev_emb = self.message_passing(x, rel_embed, (rev_row, rev_col), rev_types, "out", rev_norm)
        loop_emb = self.message_passing(x, rel_embed, loop_index, loop_types, "loop")

        # Average the three directions.
        out = 1 / 3 * (emb + rev_emb + loop_emb)
        if self.bias is not None:
            out += self.bias
        out = self.bn(out)
        # Drop the self-loop relation row that was appended above.
        return self.activation(out), torch.matmul(rel_embed, self.weight_rel)[:-1]

    def message_passing(self, x, rel_embed, edge_index, edge_types, mode, edge_weight=None):
        """Compose, project and sum messages for one edge direction.

        Each edge ``(r, c)`` sends ``rel_transform(x[c], rel_embed[type]) @ weight_{mode}``,
        optionally scaled by ``edge_weight``, to node ``r``. Messages are summed per
        receiving node and dropout is applied.

        Parameters
        ----------
        x : Tensor
            Node embeddings, one row per node.
        rel_embed : Tensor
            Relation embedding table indexed by ``edge_types``.
        edge_index : pair of index tensors
            ``(receivers, senders)``.
        edge_types : LongTensor
            Relation id per edge.
        mode : str
            ``"in"``, ``"out"`` or ``"loop"``; selects ``self.weight_{mode}``.
        edge_weight : Optional[Tensor]
            Per-edge scaling factor.
        """
        tail_emb = x[edge_index[1]]
        rel_emb = rel_embed[edge_types]
        weight = getattr(self, f"weight_{mode}")

        trans_embed = torch.matmul(self.rel_transform(tail_emb, rel_emb), weight)
        if edge_weight is not None:
            trans_embed = trans_embed * edge_weight.unsqueeze(-1)
        # Allocate the accumulator with the messages' dtype/device so scatter_add_ works for
        # non-default dtypes (e.g. float64 models or edge weights).
        embed = trans_embed.new_zeros((x.shape[0], trans_embed.shape[1]))
        embed.scatter_add_(0, edge_index[0].unsqueeze(-1).expand_as(trans_embed), trans_embed)
        return F.dropout(embed, p=self.dropout, training=self.training)

    def rel_transform(self, ent_embed, rel_embed):
        """Compose entity and relation embeddings according to ``self.opn``."""
        # if self.opn == "corr":
        #     trans_embed = ccorr(ent_embed, rel_embed)
        if self.opn == "sub":
            trans_embed = ent_embed - rel_embed
        elif self.opn == "mult":
            trans_embed = ent_embed * rel_embed
        else:
            raise NotImplementedError(f"{self.opn}")
        return trans_embed


class CompGCN(nn.Module):
    """Stack of ``CompGCNLayer`` modules over a multi-relational graph.

    The first layer maps ``in_feats -> hidden_size``. Each additional layer maps
    ``hidden_size -> hidden_size``, except the last, which maps ``hidden_size -> out_feats``.
    Relation embeddings start from ``init_rel``, a learnable table; they are transformed
    by every layer alongside the node embeddings.

    Parameters
    ----------
    num_entities : int
        Number of entities (stored only).
    num_rels : int
        Number of relation types; ``2 * num_rels`` relation embeddings are used.
    num_bases : Optional[int]
        When positive, the first layer builds relation embeddings from this many bases.
    in_feats, hidden_size, out_feats : int
        Feature sizes (see above).
    layers : int
        Number of CompGCN layers; values below 1 still build one layer.
    dropout : float
        Dropout probability passed to each layer.
    activation : Callable
        Activation passed to each layer.
    opn : str
        Composition operator passed to each layer (``"sub"`` or ``"mult"``).
    """

    def __init__(
        self, num_entities, num_rels, num_bases, in_feats, hidden_size, out_feats, layers, dropout, activation, opn
    ):
        super(CompGCN, self).__init__()
        self.opn = opn
        self.num_rels = num_rels
        self.num_entities = num_entities
        # Computed once so that ``num_bases=None`` is handled consistently everywhere.
        use_bases = num_bases is not None and num_bases > 0
        if use_bases:
            # With bases the first layer generates its own relation embeddings; this table
            # is still passed in but is not consumed by that layer.
            self.init_rel = nn.Embedding(num_bases, in_feats)
        else:
            self.init_rel = nn.Embedding(2 * num_rels, in_feats)

        self.convs = nn.ModuleList()
        self.convs.append(
            CompGCNLayer(
                in_feats=in_feats,
                out_feats=hidden_size,
                num_rels=num_rels,
                opn=self.opn,
                num_bases=num_bases if use_bases else None,
                activation=activation,
                dropout=dropout,
            )
        )
        for i in range(1, layers):
            # The final layer projects to ``out_feats``; intermediate layers keep ``hidden_size``.
            self.convs.append(
                CompGCNLayer(
                    in_feats=hidden_size,
                    out_feats=out_feats if i == layers - 1 else hidden_size,
                    num_rels=num_rels,
                    opn=self.opn,
                    activation=activation,
                    dropout=dropout,
                )
            )

    def forward(self, graph, x):
        """Return ``(node_embed, rel_embed)`` after passing through every layer."""
        rel_embed = self.init_rel.weight
        node_embed = x
        for layer in self.convs:
            node_embed, rel_embed = layer(graph, node_embed, rel_embed)
        return node_embed, rel_embed


class LinkPredictCompGCN(GNNLinkPredict, BaseModel):
    """CompGCN encoder wrapped for knowledge-graph link prediction.

    Entity features come from a learnable ``nn.Embedding`` of width ``hidden_size // 2``.
    They are encoded by a :class:`CompGCN` whose hidden and output sizes are both
    ``hidden_size``. Training samples edges with ``sampling_edge_uniform``. The loss is
    ``self._loss`` plus ``penalty`` times ``self._regularization``; both are inherited
    helpers not shown in this file.
    """

    @staticmethod
    def add_args(parser):
        """Register model-specific command-line arguments on ``parser``."""
        # fmt: off
        parser.add_argument("--hidden-size", type=int, default=200)
        parser.add_argument("--penalty", type=float, default=0.001)
        parser.add_argument("--dropout", type=float, default=0.3)
        parser.add_argument("--num-bases", type=int, default=10)
        parser.add_argument("--num-layers", type=int, default=1)
        parser.add_argument("--sampling-rate", type=float, default=0.01)
        parser.add_argument("--lbl_smooth", type=float, default=0.1)
        parser.add_argument("--opn", type=str, default="sub")
        # fmt: on

    @classmethod
    def build_model_from_args(cls, args):
        """Construct the model from a parsed argument namespace."""
        return cls(
            num_entities=args.num_entities,
            num_rels=args.num_rels,
            hidden_size=args.hidden_size,
            num_bases=args.num_bases,
            sampling_rate=args.sampling_rate,
            penalty=args.penalty,
            layers=args.num_layers,
            dropout=args.dropout,
            lbl_smooth=args.lbl_smooth,
            opn=args.opn,
        )

    def __init__(
        self,
        num_entities,
        num_rels,
        hidden_size,
        num_bases=0,
        layers=1,
        sampling_rate=0.01,
        penalty=0.001,
        dropout=0.0,
        lbl_smooth=0.1,
        opn="sub",
    ):
        BaseModel.__init__(self)
        GNNLinkPredict.__init__(self)
        # torch.tanh is the non-deprecated equivalent of F.tanh.
        activation = torch.tanh
        self.model = CompGCN(
            num_entities,
            num_rels,
            num_bases,
            hidden_size // 2,
            hidden_size,
            hidden_size,
            layers,
            dropout,
            activation,
            opn,
        )
        self.emb = nn.Embedding(num_entities, hidden_size // 2)
        self.sampling_rate = sampling_rate
        self.penalty = penalty
        self.num_rels = num_rels
        self.num_entities = num_entities
        # Sorted unique node ids seen by the most recent ``forward`` call.
        self.cache_index = None
        self.lbl_smooth = lbl_smooth

    def add_reverse_edges(self, edge_index, edge_types):
        """Append reversed copies of all edges.

        Reverse edges get relation ids shifted by ``num_rels``. Returns a ``(2, 2E)`` edge
        index tensor and the ``2E`` matching relation ids.
        """
        row, col = edge_index
        _row = torch.cat([row, col])
        _col = torch.cat([col, row])
        # The rows/cols are 1-D, so they must be stacked (concatenating along dim=1 raises).
        edge_index = torch.stack((_row, _col))
        edge_types_rev = edge_types + self.num_rels
        edge_types = torch.cat([edge_types, edge_types_rev])
        return edge_index, edge_types

    def forward(self, graph):
        """Encode the nodes touched by ``graph``'s edges.

        Nodes are re-indexed to ``0..K-1`` in sorted-id order. That order is cached in
        ``self.cache_index`` and the rows of the returned node embeddings follow it.
        """
        edge_index = torch.stack(graph.edge_index)
        # edge_index, edge_types = self.add_reverse_edges(edge_index, edge_types)
        reindexed_node, reindexed_edge_index = torch.unique(edge_index, return_inverse=True, sorted=True)
        self.cache_index = reindexed_node
        node_embed = self.emb(reindexed_node)
        with graph.local_graph():
            graph.edge_index = reindexed_edge_index
            node_embed, rel_embed = self.model(graph, node_embed)
        return node_embed, rel_embed

    def loss(self, data: Graph, scoring):
        """Compute the sampled link-prediction loss plus weighted regularization."""
        row, col = data.edge_index
        edge_types = data.edge_attr
        edge_index = torch.stack([row, col])
        self.get_edge_set(edge_index, edge_types)
        batch_edges, batch_attr, samples, rels, labels = sampling_edge_uniform(
            (row, col),
            edge_types,
            self.edge_set,
            self.sampling_rate,
            self.num_rels,
            label_smoothing=self.lbl_smooth,
            num_entities=self.num_entities,
        )
        with data.local_graph():
            data.edge_index = batch_edges
            data.edge_attr = batch_attr
            node_embed, rel_embed = self.forward(data)

        sampled_nodes, reindexed_edges = torch.unique(samples, sorted=True, return_inverse=True)
        # NOTE(review): node_embed rows follow self.cache_index, but they are indexed below by
        # positions relative to sampled_nodes, which is only valid when the two are identical.
        # `.any()` checks only that some element matches; `.all()` may be what is intended.
        assert (self.cache_index == sampled_nodes).any()
        loss_n = self._loss(
            node_embed[reindexed_edges[0]], node_embed[reindexed_edges[1]], rel_embed[rels], labels, scoring
        )
        loss_r = self.penalty * self._regularization([self.emb(sampled_nodes), rel_embed])
        return loss_n + loss_r

    def predict(self, graph):
        """Encode all ``num_entities`` entities over ``graph``; returns ``(node_embed, rel_embed)``."""
        device = next(self.parameters()).device
        indices = torch.arange(0, self.num_entities).to(device)
        x = self.emb(indices)
        node_embed, rel_embed = self.model(graph, x)
        return node_embed, rel_embed