Source code for cogdl.layers.disengcn_layer

import torch
import torch.nn as nn

from cogdl.utils import edge_softmax


class DisenGCNLayer(nn.Module):
    """
    Implementation of `"Disentangled Graph Convolutional Networks" <http://proceedings.mlr.press/v97/ma19a.html>`_.
    """

    def __init__(self, in_feats, out_feats, K, iterations, tau=1.0, activation="leaky_relu"):
        super(DisenGCNLayer, self).__init__()
        self.K = K
        self.tau = tau
        self.iterations = iterations
        self.factor_dim = int(out_feats / K)

        self.weight = nn.Parameter(torch.Tensor(in_feats, out_feats))
        self.bias = nn.Parameter(torch.Tensor(out_feats))

        self.reset_parameters()

        if activation == "leaky_relu":
            self.activation = nn.LeakyReLU()
        elif activation == "sigmoid":
            self.activation = nn.Sigmoid()
        elif activation == "tanh":
            self.activation = nn.Tanh()
        elif activation == "prelu":
            self.activation = nn.PReLU()
        elif activation == "relu":
            self.activation = nn.ReLU()
        else:
            raise NotImplementedError
    def reset_parameters(self):
        nn.init.xavier_normal_(self.weight.data, gain=1.414)
        nn.init.zeros_(self.bias.data)
    def forward(self, graph, x):
        num_nodes = x.shape[0]
        device = x.device
        # Linear transform, then split the output into K disentangled channels.
        h = self.activation(torch.matmul(x, self.weight) + self.bias)

        h = h.split(self.factor_dim, dim=-1)
        h = torch.cat([dt.unsqueeze(0) for dt in h], dim=0)
        norm = h.pow(2).sum(dim=-1).sqrt().unsqueeze(-1)

        # Multi-channel softmax: vectorized over all K channels at once, faster
        # than looping over channels.
        h_normed = h / norm  # (K, N, d)
        h_src = h_dst = h_normed.permute(1, 0, 2)  # (N, K, d)
        add_shape = h.shape  # (K, N, d)

        edge_index = graph.edge_index
        for _ in range(self.iterations):
            # Per-edge, per-channel affinity between endpoint factors.
            src_edge_attr = h_dst[edge_index[0]] * h_src[edge_index[1]]
            src_edge_attr = src_edge_attr.sum(dim=-1) / self.tau  # shape: (E, K)
            edge_attr_softmax = edge_softmax(graph, src_edge_attr).T  # shape: (K, E)
            edge_attr_softmax = edge_attr_softmax.unsqueeze(-1)  # shape: (K, E, 1)

            dst_edge_attr = h_src.index_select(0, edge_index[1]).permute(1, 0, 2)  # shape: (E, K, d) -> (K, E, d)
            dst_edge_attr = dst_edge_attr * edge_attr_softmax

            # Scatter the routing-weighted neighbor features back onto the source nodes.
            edge_index_ = edge_index[0].unsqueeze(-1).unsqueeze(0).repeat(self.K, 1, h.shape[-1])
            node_attr = torch.zeros(add_shape).to(device).scatter_add_(1, edge_index_, dst_edge_attr)  # (K, N, d)
            node_attr = node_attr + h_normed
            # Re-normalize each channel to unit length before the next routing iteration.
            node_attr_norm = node_attr.pow(2).sum(-1).sqrt().unsqueeze(-1)  # shape: (K, N, 1)
            node_attr = (node_attr / node_attr_norm).permute(1, 0, 2)  # shape: (N, K, d)
            h_dst = node_attr
        # Concatenate the K channels back into a single (N, out_feats) tensor.
        h_dst = h_dst.reshape(num_nodes, -1)

        # Alternative: calculate the softmax of each channel separately
        # h_src = h_dst = h / norm  # (K, N, d)
        #
        # for _ in range(self.iterations):
        #     for i in range(self.K):
        #         h_attr = h_dst[i]
        #         edge_attr = h_attr[edge_index[0]] * h_src[i][edge_index[1]]
        #
        #         edge_attr = edge_attr.sum(-1) / self.tau
        #         edge_attr = edge_softmax(edge_index, edge_attr, shape=(num_nodes, num_nodes))
        #
        #         node_attr = spmm(edge_index, edge_attr, h_src[i])
        #
        #         node_attr = node_attr + h_src[i]
        #         h_src[i] = node_attr / node_attr.pow(2).sum(-1).sqrt().unsqueeze(-1)
        #
        # h_dst = h_dst.permute(1, 0, 2).reshape(num_nodes, -1)

        return h_dst
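
For reference, a minimal usage sketch of the layer (not part of the CogDL source). It assumes a `cogdl.data.Graph` carrying an `edge_index` tensor and the import paths shown below; adjust both to your CogDL version.

    import torch
    from cogdl.data import Graph  # assumed location of CogDL's Graph class
    from cogdl.layers.disengcn_layer import DisenGCNLayer

    # Toy graph: 4 nodes with bidirectional edges, 16-dimensional input features.
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                               [1, 0, 2, 1, 3, 2]])
    graph = Graph(edge_index=edge_index)
    x = torch.randn(4, 16)

    # out_feats=32 is split into K=4 channels of 32 / 4 = 8 dimensions each;
    # 6 neighborhood-routing iterations refine the per-channel assignments.
    layer = DisenGCNLayer(in_feats=16, out_feats=32, K=4, iterations=6, tau=1.0)
    out = layer(graph, x)
    print(out.shape)  # torch.Size([4, 32]) -- the K channels concatenated

The vectorized path in `forward` routes all K channels through a single `edge_softmax` call and one `scatter_add_`; the commented-out variant at the end of the method shows the equivalent per-channel loop using `spmm`, which is easier to follow but slower.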