Source code for cogdl.layers.deepergcn_layer

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

from .mlp_layer import MLP
from cogdl.utils import get_activation, mul_edge_softmax, get_norm_layer, batch_max_pooling


class GENConv(nn.Module):
    def __init__(
        self,
        in_feats: int,
        out_feats: int,
        aggr: str = "softmax_sg",
        beta: float = 1.0,
        p: float = 1.0,
        learn_beta: bool = False,
        learn_p: bool = False,
        use_msg_norm: bool = False,
        learn_msg_scale: bool = True,
        norm: Optional[str] = None,
        residual: bool = False,
        activation: Optional[str] = None,
        num_mlp_layers: int = 2,
        edge_attr_size: Optional[list] = None,
    ):
        super(GENConv, self).__init__()
        self.use_msg_norm = use_msg_norm
        self.mlp = MLP(in_feats, out_feats, in_feats * 2, num_layers=num_mlp_layers, activation=activation, norm=norm)
        self.message_encoder = torch.nn.ReLU()
        self.aggr = aggr
        if aggr == "softmax_sg":
            self.beta = torch.nn.Parameter(torch.Tensor([beta]), requires_grad=learn_beta)
        else:
            self.register_buffer("beta", None)
        if aggr == "powermean":
            self.p = torch.nn.Parameter(torch.Tensor([p]), requires_grad=learn_p)
        else:
            self.register_buffer("p", None)
        self.eps = 1e-7
        self.s = torch.nn.Parameter(torch.Tensor([1.0]), requires_grad=learn_msg_scale and use_msg_norm)
        self.residual = residual

        if edge_attr_size is not None and edge_attr_size[0] > 0:
            if len(edge_attr_size) > 1:
                self.edge_encoder = BondEncoder(edge_attr_size, in_feats)
            else:
                self.edge_encoder = EdgeEncoder(edge_attr_size[0], in_feats)
        else:
            self.edge_encoder = None

    def message_norm(self, x, msg):
        x_norm = torch.norm(x, dim=1, p=2)
        msg_norm = F.normalize(msg, p=2, dim=1)
        msg_norm = msg_norm * x_norm.unsqueeze(-1)
        return x + self.s * msg_norm

    def forward(self, graph, x):
        edge_index = graph.edge_index
        dim = x.shape[1]
        # Build one message per edge from the features of the node at edge_index[1]
        edge_msg = x[edge_index[1]]
        if self.edge_encoder is not None and graph.edge_attr is not None:
            edge_msg += self.edge_encoder(graph.edge_attr)
        edge_msg = self.message_encoder(edge_msg) + self.eps

        # Weight the messages according to the chosen aggregator
        if self.aggr == "softmax_sg":
            h = mul_edge_softmax(graph, self.beta * edge_msg.contiguous())
            h = edge_msg * h
        elif self.aggr == "softmax":
            h = mul_edge_softmax(graph, edge_msg)
            h = edge_msg * h
        elif self.aggr == "powermean":
            deg = graph.degrees()
            torch.clamp_(edge_msg, 1e-7, 1.0)
            h = edge_msg.pow(self.p) / deg[edge_index[0]].unsqueeze(-1)
        elif self.aggr == "mean":
            deg = graph.degrees()
            deg_rev = deg.pow(-1)
            deg_rev[torch.isinf(deg_rev)] = 0
            h = edge_msg * deg_rev[edge_index[0]].unsqueeze(-1)
        else:
            h = edge_msg

        # Aggregate the weighted messages onto the nodes at edge_index[0]
        if self.aggr == "max":
            h = batch_max_pooling(h, edge_index[0])
        else:
            h = torch.zeros_like(x).scatter_add_(dim=0, index=edge_index[0].unsqueeze(-1).repeat(1, dim), src=h)
        if self.aggr == "powermean":
            h = h.pow(1.0 / self.p)
        if self.use_msg_norm:
            h = self.message_norm(x, h)
        if self.residual:
            h = h + x
        h = self.mlp(h)
        return h
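
The following is a minimal, hypothetical usage sketch rather than part of the module: it assumes that ``cogdl.data.Graph`` can be constructed from ``x`` and ``edge_index`` tensors, and all shapes and values are illustrative.

# Hypothetical usage sketch (assumes cogdl.data.Graph; shapes illustrative)
import torch
from cogdl.data import Graph
from cogdl.layers.deepergcn_layer import GENConv

x = torch.randn(4, 16)                    # 4 nodes, 16 features each
edge_index = torch.tensor([[0, 1, 2, 3],  # messages aggregate to these nodes
                           [1, 2, 3, 0]])
graph = Graph(x=x, edge_index=edge_index)

conv = GENConv(in_feats=16, out_feats=16, aggr="softmax_sg", learn_beta=True)
out = conv(graph, x)                      # shape: (4, 16)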

class ResGNNLayer(nn.Module):
    """
    Implementation of DeeperGCN in paper `"DeeperGCN: All You Need to Train Deeper GCNs" <https://arxiv.org/abs/2006.07739>`_

    Parameters
    ----------
    conv : nn.Module
        An instance of a GNN layer, receiving (graph, x) as inputs
    in_channels : int
        Size of input features
    activation : str
    norm : str
        Type of normalization, ``batchnorm`` by default
    dropout : float
    checkpoint_grad : bool
    """

    def __init__(
        self,
        conv,
        in_channels,
        activation="relu",
        norm="batchnorm",
        dropout=0.0,
        out_norm=None,
        out_channels=-1,
        residual=True,
        checkpoint_grad=False,
    ):
        super(ResGNNLayer, self).__init__()
        self.conv = conv
        self.activation = get_activation(activation)
        self.dropout = dropout
        self.norm = get_norm_layer(norm, in_channels)
        self.residual = residual
        if out_norm:
            self.out_norm = get_norm_layer(out_norm, out_channels)
        else:
            self.out_norm = None
        self.checkpoint_grad = checkpoint_grad

    def forward(self, graph, x, dropout=None, *args, **kwargs):
        # Pre-activation residual ordering: norm -> activation -> dropout -> conv
        h = self.norm(x)
        h = self.activation(h)
        if isinstance(dropout, float) or dropout is None:
            h = F.dropout(h, p=self.dropout, training=self.training)
        else:
            # A precomputed dropout mask was passed in; apply it only in training
            if self.training:
                h = h * dropout
        if self.checkpoint_grad:
            h = checkpoint(self.conv, graph, h, *args, **kwargs)
        else:
            h = self.conv(graph, h, *args, **kwargs)
        if self.residual:
            h = h + x
        if self.out_norm:
            return self.out_norm(h)
        else:
            return h
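
A hedged sketch of how ResGNNLayer can wrap a convolution such as GENConv (again assuming the ``cogdl.data.Graph`` constructor; shapes illustrative):

# Hypothetical sketch: wrap a GENConv in a pre-activation residual block
import torch
from cogdl.data import Graph
from cogdl.layers.deepergcn_layer import GENConv, ResGNNLayer

x = torch.randn(4, 16)
graph = Graph(x=x, edge_index=torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]]))

layer = ResGNNLayer(GENConv(16, 16), in_channels=16, activation="relu", norm="batchnorm", dropout=0.1)
out = layer(graph, x)  # norm -> activation -> dropout -> conv -> residual; shape (4, 16)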

class EdgeEncoder(nn.Module):
    def __init__(self, in_feats, out_feats, bias=False):
        super(EdgeEncoder, self).__init__()
        self.nn = nn.Linear(in_feats, out_feats, bias=bias)

    def forward(self, edge_attr):
        return self.nn(edge_attr)
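
A quick sketch of EdgeEncoder, which projects continuous edge attributes to the node feature width so GENConv can add them to its edge messages (shapes illustrative):

# Hypothetical sketch: project continuous edge attributes
import torch
from cogdl.layers.deepergcn_layer import EdgeEncoder

enc = EdgeEncoder(in_feats=8, out_feats=16)
edge_attr = torch.randn(4, 8)  # one continuous attribute vector per edge
bias = enc(edge_attr)          # (4, 16), same width as the node features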

class BondEncoder(nn.Module):
    def __init__(self, bond_dim_list, emb_size):
        super(BondEncoder, self).__init__()
        self.bond_emb_list = nn.ModuleList()
        for size in bond_dim_list:
            self.bond_emb_list.append(nn.Embedding(size, emb_size))

    def forward(self, edge_attr):
        out = 0
        for i in range(edge_attr.shape[1]):
            out += self.bond_emb_list[i](edge_attr[:, i])
        return out
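
And a sketch of BondEncoder for categorical edge attributes (e.g. molecular bond features), where each integer column gets its own embedding table and the lookups are summed; the vocabulary sizes below are made up:

# Hypothetical sketch: embed categorical edge attributes column by column
import torch
from cogdl.layers.deepergcn_layer import BondEncoder

enc = BondEncoder(bond_dim_list=[5, 3], emb_size=16)  # two categorical columns
edge_attr = torch.tensor([[0, 2], [4, 1]])            # (num_edges, num_columns) integer codes
emb = enc(edge_attr)                                  # (2, 16): per-column embeddings summed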