Source code for cogdl.models.nn.sign

import torch
import torch.nn.functional as F

from .. import BaseModel, register_model
from cogdl.utils import (
    add_remaining_self_loops,
    remove_self_loops,
    row_normalization,
    symmetric_normalization,
    to_undirected,
    spmm,
    dropout_adj,
)


def get_adj(row, col, asymm_norm=False, set_diag=True, remove_diag=False):
    edge_index = torch.stack([row, col])
    edge_attr = torch.ones(edge_index.shape[1]).to(edge_index.device)
    if set_diag:
        edge_index, edge_attr = add_remaining_self_loops(edge_index, edge_attr)
    elif remove_diag:
        edge_index, _ = remove_self_loops(edge_index)

    num_nodes = int(torch.max(edge_index)) + 1
    if asymm_norm:
        # asymmetric (random-walk) normalization: D^{-1} A
        edge_attr = row_normalization(num_nodes, edge_index, edge_attr)
    else:
        # symmetric normalization: D^{-1/2} A D^{-1/2}
        edge_attr = symmetric_normalization(num_nodes, edge_index, edge_attr)
    return edge_index, edge_attr
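
# Illustrative example (not part of the original module): normalizing a toy
# 3-node directed cycle. `row` and `col` are LongTensors of edge endpoints;
# with set_diag=True, remaining self-loops are added before normalization.
#
#   row = torch.tensor([0, 1, 2])
#   col = torch.tensor([1, 2, 0])
#   edge_index, edge_attr = get_adj(row, col, set_diag=True)
#   # edge_attr now holds symmetrically normalized weights over A + I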


@register_model("sign")
class MLP(BaseModel):
    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument("--num-features", type=int)
        parser.add_argument("--num-classes", type=int)
        parser.add_argument("--hidden-size", type=int, default=512)
        parser.add_argument("--num-layers", type=int, default=3)
        parser.add_argument("--dropout", type=float, default=0.3)
        parser.add_argument("--dropedge-rate", type=float, default=0.2)
        parser.add_argument("--directed", action="store_true")
        parser.add_argument("--num-propagations", type=int, default=1)
        parser.add_argument("--asymm-norm", action="store_true")
        parser.add_argument("--set-diag", action="store_true")
        parser.add_argument("--remove-diag", action="store_true")
        # fmt: on
    @classmethod
    def build_model_from_args(cls, args):
        # Argument order must match __init__: `dropedge_rate` precedes
        # `undirected` (the inverse of the `--directed` flag).
        return cls(
            args.num_features,
            args.hidden_size,
            args.num_classes,
            args.num_layers,
            args.dropout,
            args.dropedge_rate,
            not args.directed,
            args.num_propagations,
            args.asymm_norm,
            args.set_diag,
            args.remove_diag,
        )
    def __init__(
        self,
        num_features,
        hidden_size,
        num_classes,
        num_layers,
        dropout,
        dropedge_rate,
        undirected,
        num_propagations,
        asymm_norm,
        set_diag,
        remove_diag,
    ):
        super(MLP, self).__init__()
        self.dropout = dropout
        self.dropedge_rate = dropedge_rate
        self.undirected = undirected
        self.num_propagations = num_propagations
        self.asymm_norm = asymm_norm
        self.set_diag = set_diag
        self.remove_diag = remove_diag

        self.lins = torch.nn.ModuleList()
        self.lins.append(torch.nn.Linear((1 + 2 * self.num_propagations) * num_features, hidden_size))
        self.bns = torch.nn.ModuleList()
        self.bns.append(torch.nn.BatchNorm1d(hidden_size))
        for _ in range(num_layers - 2):
            self.lins.append(torch.nn.Linear(hidden_size, hidden_size))
            self.bns.append(torch.nn.BatchNorm1d(hidden_size))
        self.lins.append(torch.nn.Linear(hidden_size, num_classes))

        self.cache_x = None

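    # Dimension check (illustrative): with num_features=128 and
    # num_propagations=1, the first Linear layer expects
    # (1 + 2 * 1) * 128 = 384 inputs: the raw features plus one propagation
    # over the normalized adjacency and one over its transpose.
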
    def reset_parameters(self):
        for lin in self.lins:
            lin.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()

    def _preprocessing(self, x, edge_index):
        num_nodes = x.shape[0]
        op_embedding = [x]

        # Randomly drop edges (DropEdge) before propagation
        edge_index, _ = dropout_adj(edge_index, drop_rate=self.dropedge_rate)
        if self.undirected:
            edge_index = to_undirected(edge_index, num_nodes)
        row, col = edge_index

        # Propagate over the normalized adjacency matrix
        edge_index, edge_attr = get_adj(
            row, col, asymm_norm=self.asymm_norm, set_diag=self.set_diag, remove_diag=self.remove_diag
        )
        nx = x
        for _ in range(self.num_propagations):
            nx = spmm(edge_index, edge_attr, nx)
            op_embedding.append(nx)

        # Propagate over the transposed adjacency matrix
        edge_index, edge_attr = get_adj(
            col, row, asymm_norm=self.asymm_norm, set_diag=self.set_diag, remove_diag=self.remove_diag
        )
        nx = x
        for _ in range(self.num_propagations):
            nx = spmm(edge_index, edge_attr, nx)
            op_embedding.append(nx)

        return torch.cat(op_embedding, dim=1)

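    # Sketch (illustrative): SIGN-style features are precomputed in one shot.
    # Assuming `model` is an MLP instance with num_propagations=1:
    #
    #   x = torch.randn(4, 16)                        # 4 nodes, 16 features
    #   edge_index = torch.tensor([[0, 1, 2, 3],
    #                              [1, 2, 3, 0]])
    #   feats = model._preprocessing(x, edge_index)   # shape [4, 3 * 16]
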
    def forward(self, x, edge_index):
        # Propagated features are computed once on the first forward pass
        # and cached; afterwards only the MLP runs.
        if self.cache_x is None:
            self.cache_x = self._preprocessing(x, edge_index)
        x = self.cache_x
        for i, lin in enumerate(self.lins[:-1]):
            x = lin(x)
            x = self.bns[i](x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lins[-1](x)
        return torch.log_softmax(x, dim=-1)

    def node_classification_loss(self, data, mask=None):
        if mask is None:
            mask = data.train_mask
        edge_index = (
            data.edge_index_train if hasattr(data, "edge_index_train") and self.training else data.edge_index
        )
        pred = self.forward(data.x, edge_index)
        return self.loss_fn(pred[mask], data.y[mask])

    def predict(self, data):
        return self.forward(data.x, data.edge_index)
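
# Minimal usage sketch (illustrative; assumes a cogdl-style `data` object with
# `x`, `y`, `edge_index`, and `train_mask` attributes, and 40 target classes):
#
#   model = MLP(
#       num_features=data.x.shape[1], hidden_size=512, num_classes=40,
#       num_layers=3, dropout=0.3, dropedge_rate=0.2, undirected=True,
#       num_propagations=1, asymm_norm=False, set_diag=True, remove_diag=False,
#   )
#   loss = model.node_classification_loss(data)   # training objective
#   pred = model.predict(data).argmax(dim=-1)     # predicted classes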