# Source code for cogdl.layers.han_layer

import torch
import torch.nn as nn

from .gat_layer import GATLayer


class AttentionLayer(nn.Module):
    """Fuse a stack of embeddings into one via learned scalar weights.

    A single ``nn.Linear(num_features, 1)`` assigns a score to each of the
    ``K`` stacked embeddings of an item.  The output is the sum of those
    embeddings, each multiplied by its score.

    NOTE(review): the raw scores are used as-is (no softmax across the ``K``
    embeddings); confirm this is intended.

    Args:
        num_features: Size of the last (feature) dimension of the input.
    """

    def __init__(self, num_features):
        super(AttentionLayer, self).__init__()
        # One scalar score per stacked embedding.
        self.linear = nn.Linear(num_features, 1)

    def forward(self, x):
        """Combine the stacked embeddings in ``x``.

        Args:
            x: Tensor of shape ``(..., K, num_features)``, holding ``K``
                embeddings per item (in ``HANLayer``, ``K`` is the number of
                edge types, since the per-edge outputs are stacked on dim 1).

        Returns:
            Tensor of shape ``(..., num_features)``: for each item, the
            score-weighted sum of its ``K`` embeddings.
        """
        # (..., K, 1) -> (..., 1, K).  A transpose is equivalent to the former
        # ``view(-1, 1, K)`` for 3-D input, but also works when there are zero
        # items (where -1 cannot be inferred) and with extra leading dims.
        att = self.linear(x).transpose(-1, -2)
        # (..., 1, K) @ (..., K, F) -> (..., 1, F) -> (..., F)
        return torch.matmul(att, x).squeeze(-2)
class HANLayer(nn.Module):
    """One HAN layer: a GAT per edge type, fused by ``AttentionLayer``.

    A separate ``GATLayer`` is run for every entry of ``graph.adj``.  The
    resulting embeddings are stacked along dim 1 and combined by
    ``AttentionLayer(w_out)``.

    Args:
        num_edge: Number of edge types, i.e. number of GAT layers created
            (``forward`` uses one per entry of ``graph.adj``).
        w_in: Input feature size passed to each ``GATLayer``.
        w_out: Output feature size; must be a multiple of ``num_heads``.
        num_heads: Attention heads per ``GATLayer`` (default 8, the value that
            was previously hard-coded).  Each head gets ``w_out // num_heads``
            features.

    Raises:
        ValueError: If ``num_heads < 1`` or ``w_out`` is not a multiple of
            ``num_heads``.
    """

    def __init__(self, num_edge, w_in, w_out, num_heads=8):
        super(HANLayer, self).__init__()
        # The stacked GAT outputs feed AttentionLayer(w_out), so
        # (w_out // num_heads) * num_heads must equal w_out (this assumes
        # GATLayer concatenates its heads -- TODO confirm).  Fail here instead
        # of with an opaque shape mismatch inside forward().
        if num_heads < 1 or w_out % num_heads != 0:
            raise ValueError(
                "w_out ({}) must be a multiple of num_heads ({})".format(w_out, num_heads)
            )
        self.gat_layer = nn.ModuleList(
            [GATLayer(w_in, w_out // num_heads, num_heads) for _ in range(num_edge)]
        )
        self.att_layer = AttentionLayer(w_out)

    def forward(self, graph, x):
        """Apply each edge-type GAT to ``x`` and fuse the results.

        Args:
            graph: Graph object exposing ``adj``, a writable ``edge_index`` and
                a ``local_graph()`` context manager.  The first element of each
                ``graph.adj`` entry is installed as ``graph.edge_index`` before
                the corresponding GAT layer runs.
            x: Node features passed unchanged to every GAT layer.

        Returns:
            Output of ``self.att_layer`` over the per-edge-type GAT outputs
            stacked along dim 1.
        """
        adj = graph.adj
        output = []
        # Temporary edge_index swaps happen inside local_graph(), presumably so
        # the caller's graph is restored on exit -- confirm against Graph.
        with graph.local_graph():
            for i, edge in enumerate(adj):
                graph.edge_index = edge[0]
                output.append(self.gat_layer[i](graph, x))
        output = torch.stack(output, dim=1)
        return self.att_layer(output)