Source code for cogdl.models.nn.deepergcn

from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

from .. import register_model, BaseModel
from cogdl.utils import mul_edge_softmax, spmm, get_activation
from cogdl.trainers.sampled_trainer import DeeperGCNTrainer


class GENConv(nn.Module):
    def __init__(
        self,
        in_feat,
        out_feat,
        aggr="softmax_sg",
        beta=1.0,
        p=1.0,
        learn_beta=False,
        learn_p=False,
        use_msg_norm=False,
        learn_msg_scale=True,
    ):
        super(GENConv, self).__init__()
        self.use_msg_norm = use_msg_norm
        self.mlp = nn.Linear(in_feat, out_feat)

        self.message_encoder = torch.nn.ReLU()

        self.aggr = aggr
        if aggr == "softmax_sg":
            self.beta = torch.nn.Parameter(torch.Tensor([beta]), requires_grad=learn_beta)
        else:
            self.register_buffer("beta", None)
        if aggr == "powermean":
            self.p = torch.nn.Parameter(torch.Tensor([p]), requires_grad=learn_p)
        else:
            self.register_buffer("p", None)
        self.eps = 1e-7

        self.s = torch.nn.Parameter(torch.Tensor([1.0]), requires_grad=learn_msg_scale)
        self.act = nn.ReLU()

    def message_norm(self, x, msg):
        # MsgNorm from the DeeperGCN paper: x + s * ||x||_2 * (msg / ||msg||_2),
        # where s is a (possibly learnable) scalar message scale.
        x_norm = torch.norm(x, dim=1, p=2)
        msg_norm = F.normalize(msg, p=2, dim=1)
        msg_norm = msg_norm * x_norm.unsqueeze(-1)
        return x + self.s * msg_norm

    def forward(self, x, edge_index, edge_attr=None):
        device = x.device
        dim = x.shape[1]
        num_nodes = x.shape[0]
        # Gather messages from neighboring nodes; add edge features when given.
        edge_msg = x[edge_index[1]] if edge_attr is None else x[edge_index[1]] + edge_attr
        edge_msg = self.act(edge_msg) + self.eps

        if self.aggr == "softmax_sg":
            h = mul_edge_softmax(edge_index, self.beta * edge_msg, shape=(num_nodes, num_nodes))
            h = edge_msg * h
        elif self.aggr == "softmax":
            h = mul_edge_softmax(edge_index, edge_msg, shape=(num_nodes, num_nodes))
            h = edge_msg * h
        elif self.aggr == "powermean":
            # Node degrees, computed as a sparse A @ 1 product.
            deg = spmm(
                indices=edge_index,
                values=torch.ones(edge_index.shape[1]).to(device),
                b=torch.ones(num_nodes).unsqueeze(-1).to(device),
            ).view(-1)
            h = edge_msg.pow(self.p) / deg[edge_index[0]].unsqueeze(-1)
        else:
            raise NotImplementedError

        # Scatter-sum per-edge messages into their receiving nodes (edge_index[0]).
        h = torch.zeros_like(x).scatter_add_(dim=0, index=edge_index[0].unsqueeze(-1).repeat(1, dim), src=h)
        if self.aggr == "powermean":
            h = h.pow(1.0 / self.p)
        if self.use_msg_norm:
            h = self.message_norm(x, h)
        h = self.mlp(h)
        return h
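
# A minimal usage sketch for GENConv (an illustrative addition, not part of the
# original module): it assumes a hypothetical 4-node cycle graph; `edge_index`
# follows the (2, num_edges) layout used throughout this file.
#
#     conv = GENConv(in_feat=8, out_feat=8, aggr="softmax_sg")
#     x = torch.randn(4, 8)
#     edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
#     out = conv(x, edge_index)  # -> shape (4, 8)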


class DeepGCNLayer(nn.Module):
    """
    Implementation of DeeperGCN in paper `"DeeperGCN: All You Need to Train Deeper GCNs"` <https://arxiv.org/abs/2006.07739>

    Parameters
    -----------
    in_feat : int
        Size of each input sample
    out_feat : int
        Size of each output sample
    conv : class
        Base convolution layer.
    connection : str
        Residual connection type, `res` or `res+`.
    activation : str
    dropout : float
    checkpoint_grad : bool
    """

    def __init__(
        self,
        in_feat,
        out_feat,
        conv,
        connection="res",
        activation="relu",
        dropout=0.0,
        checkpoint_grad=False,
    ):
        super(DeepGCNLayer, self).__init__()
        self.conv = conv
        self.activation = get_activation(activation)
        self.dropout = dropout
        self.connection = connection
        self.norm = nn.BatchNorm1d(out_feat, affine=True)
        self.checkpoint_grad = checkpoint_grad

    def forward(self, x, edge_index):
        if self.connection == "res+":
            h = self.norm(x)
            h = self.activation(h)
            h = F.dropout(h, p=self.dropout, training=self.training)
            if self.checkpoint_grad:
                h = checkpoint(self.conv, h, edge_index)
            else:
                h = self.conv(h, edge_index)
        elif self.connection == "res":
            h = self.conv(x, edge_index)
            h = self.norm(h)
            h = self.activation(h)
        else:
            raise NotImplementedError
        return x + h
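
# A usage sketch for DeepGCNLayer (an illustrative addition): a GENConv wrapped
# in a pre-activation `res+` residual block, mirroring what DeeperGCN below
# builds internally; `x` and `edge_index` are as in the GENConv sketch above.
#
#     layer = DeepGCNLayer(
#         in_feat=8,
#         out_feat=8,
#         conv=GENConv(8, 8),
#         connection="res+",
#         dropout=0.1,
#     )
#     out = layer(x, edge_index)  # same shape as x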


@register_model("deepergcn")
class DeeperGCN(BaseModel):
    @staticmethod
    def add_args(parser):
        # fmt: off
        parser.add_argument("--num-features", type=int)
        parser.add_argument("--num-classes", type=int)
        parser.add_argument("--num-layers", type=int, default=14)
        parser.add_argument("--hidden-size", type=int, default=128)
        parser.add_argument("--dropout", type=float, default=0.5)
        parser.add_argument("--connection", type=str, default="res+")
        parser.add_argument("--activation", type=str, default="relu")
        parser.add_argument("--aggr", type=str, default="softmax_sg")
        parser.add_argument("--beta", type=float, default=1.0)
        parser.add_argument("--p", type=float, default=1.0)
        parser.add_argument("--learn-beta", action="store_true")
        parser.add_argument("--learn-p", action="store_true")
        parser.add_argument("--learn-msg-scale", action="store_true")
        parser.add_argument("--use-msg-norm", action="store_true")
        # fmt: on

    # Reference hyper-parameters for ogbn-products:
    #     num_layers: 14
    #     self_loop:
    #     aggr: softmax_sg
    #     beta: 0.1

    @classmethod
    def build_model_from_args(cls, args):
        return cls(
            in_feat=args.num_features,
            hidden_size=args.hidden_size,
            out_feat=args.num_classes,
            num_layers=args.num_layers,
            connection=args.connection,
            activation=args.activation,
            dropout=args.dropout,
            aggr=args.aggr,
            beta=args.beta,
            p=args.p,
            learn_beta=args.learn_beta,
            learn_p=args.learn_p,
            learn_msg_scale=args.learn_msg_scale,
            use_msg_norm=args.use_msg_norm,
        )

    def __init__(
        self,
        in_feat,
        hidden_size,
        out_feat,
        num_layers,
        connection="res+",
        activation="relu",
        dropout=0.0,
        aggr="max",
        beta=1.0,
        p=1.0,
        learn_beta=False,
        learn_p=False,
        learn_msg_scale=True,
        use_msg_norm=False,
    ):
        super(DeeperGCN, self).__init__()
        self.dropout = dropout
        self.feat_encoder = nn.Linear(in_feat, hidden_size)

        self.layers = nn.ModuleList()
        self.layers.append(GENConv(hidden_size, hidden_size))
        for i in range(num_layers - 1):
            self.layers.append(
                DeepGCNLayer(
                    in_feat=hidden_size,
                    out_feat=hidden_size,
                    conv=GENConv(
                        in_feat=hidden_size,
                        out_feat=hidden_size,
                        aggr=aggr,
                        beta=beta,
                        p=p,
                        learn_beta=learn_beta,
                        learn_p=learn_p,
                        use_msg_norm=use_msg_norm,
                        learn_msg_scale=learn_msg_scale,
                    ),
                    connection=connection,
                    activation=activation,
                    dropout=dropout,
                    checkpoint_grad=(num_layers > 3) and ((i + 1) == num_layers // 2),
                )
            )
        self.norm = nn.BatchNorm1d(hidden_size, affine=True)
        self.activation = get_activation(activation)
        self.fc = nn.Linear(hidden_size, out_feat)

    def forward(self, x, edge_index, edge_attr=None):
        h = self.feat_encoder(x)
        for layer in self.layers:
            h = layer(h, edge_index)
        h = self.activation(self.norm(h))
        h = F.dropout(h, p=self.dropout, training=self.training)
        h = self.fc(h)
        return h

    def predict(self, data):
        return self.forward(data.x, data.edge_index)

    @staticmethod
    def get_trainer(taskType: Any, args):
        return DeeperGCNTrainer
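

# A minimal smoke-test sketch (an illustrative addition, assuming the cogdl
# utilities imported above behave as they are used in this module): build a
# small random graph and check the output shape of a forward pass.
if __name__ == "__main__":
    num_nodes, num_feats, num_classes = 10, 16, 3
    model = DeeperGCN(
        in_feat=num_feats,
        hidden_size=32,
        out_feat=num_classes,
        num_layers=4,
        aggr="softmax_sg",
    )
    x = torch.randn(num_nodes, num_feats)
    # Random directed edges; duplicates and self-loops are harmless here.
    edge_index = torch.randint(0, num_nodes, (2, 40))
    out = model(x, edge_index)
    assert out.shape == (num_nodes, num_classes)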