"""Source code for ``cogdl.layers.maggregator``: mean and sum neighbourhood aggregator layers."""

import torch
import torch.nn as nn
from cogdl.utils import spmm


class MeanAggregator(nn.Module):
    """Project node features linearly, then aggregate them over a row-normalised graph.

    Args:
        in_channels: Size of each input feature vector.
        out_channels: Size of each output feature vector.
        bias: Whether the internal ``nn.Linear`` projection has a bias term.
    """

    def __init__(self, in_channels, out_channels, bias=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Not read anywhere in this class; kept so existing attribute access keeps working.
        self.cached_result = None
        self.linear = nn.Linear(in_channels, out_channels, bias=bias)

    @staticmethod
    def norm(graph, x):
        """Row-normalise ``graph`` and multiply it with ``x`` via ``spmm``.

        NOTE(review): ``graph.row_norm()`` is invoked on the caller's graph object
        every time; it presumably rewrites the graph's edge weights in place, so
        other layers sharing this graph may observe the change. Confirm against
        the Graph implementation.
        """
        graph.row_norm()
        return spmm(graph, x)

    def forward(self, graph, x):
        """Apply the linear projection first, then the normalised aggregation.

        ``x`` is presumably a (num_nodes, in_channels) node-feature tensor. TODO: confirm.
        """
        projected = self.linear(x)
        return self.norm(graph, projected)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.in_channels}, {self.out_channels})"
class SumAggregator(nn.Module):
    """Project node features linearly, then aggregate them with a sparse graph product.

    Args:
        in_channels: Size of each input feature vector.
        out_channels: Size of each output feature vector.
        bias: Whether the internal ``nn.Linear`` projection has a bias term.
    """

    def __init__(self, in_channels, out_channels, bias=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Not read anywhere in this class; kept so existing attribute access keeps working.
        self.cached_result = None
        self.linear = nn.Linear(in_channels, out_channels, bias=bias)

    @staticmethod
    def aggr(graph, x):
        """Return ``spmm(graph, x)``, using the graph as it is without renormalising it."""
        return spmm(graph, x)

    def forward(self, graph, x):
        """Apply the linear projection first, then the sparse aggregation.

        ``x`` is presumably a (num_nodes, in_channels) node-feature tensor. TODO: confirm.
        """
        projected = self.linear(x)
        return self.aggr(graph, projected)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.in_channels}, {self.out_channels})"