# Source code for cogdl.models.nn.infograph

import math
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

from .mlp import MLP
from cogdl.layers import GINLayer
from cogdl.utils import batch_mean_pooling, batch_sum_pooling, split_dataset_general
from .. import BaseModel


class Encoder(nn.Module):
    r"""Encoder stacked with GIN layers.

    Each GIN layer is followed by batch normalization and a ReLU. The outputs of
    every layer are concatenated, so both returned representations have
    ``num_layers * hidden_dim`` features.

    Parameters
    ----------
    in_feats : int
        Size of each input sample.
    hidden_dim : int
        Output size of each GIN layer.
    num_layers : int, optional
        Number of GIN layers, default: ``3``.
    num_mlp_layers : int, optional
        Number of MLP layers for each GIN layer, default: ``2``.
    pooling : str, optional
        Graph-level aggregation type, ``"sum"`` or ``"mean"``, default: ``"sum"``.

    Raises
    ------
    NotImplementedError
        If ``pooling`` is not one of the supported aggregation types.
    """

    def __init__(self, in_feats, hidden_dim, num_layers=3, num_mlp_layers=2, pooling="sum"):
        super(Encoder, self).__init__()
        self.num_layers = num_layers
        self.gnn_layers = nn.ModuleList()
        self.bn_layers = nn.ModuleList()
        for i in range(num_layers):
            # Only the first layer consumes the raw input features.
            layer_in = in_feats if i == 0 else hidden_dim
            mlp = MLP(layer_in, hidden_dim, hidden_dim, num_mlp_layers, norm="batchnorm")
            self.gnn_layers.append(GINLayer(mlp, eps=0, train_eps=True))
            self.bn_layers.append(nn.BatchNorm1d(hidden_dim))

        if pooling == "sum":
            self.pooling = batch_sum_pooling
        elif pooling == "mean":
            self.pooling = batch_mean_pooling
        else:
            raise NotImplementedError(f"Unsupported pooling type: {pooling!r}")

    def forward(self, graph, x=None, *args):
        """Encode a batch of graphs.

        Parameters
        ----------
        graph
            Batched graph. ``graph.batch`` is handed to the pooling function, and
            its length is taken as the number of nodes (presumably a node-to-graph
            index vector -- TODO confirm against the data loader).
        x : torch.Tensor, optional
            Node features. If None, a single all-ones feature column is used, which
            presumably requires ``in_feats == 1``.
        *args
            Ignored; accepted so callers may pass extra positional arguments.

        Returns
        -------
        tuple of torch.Tensor
            ``(graph_rep, node_rep)``. These are the per-layer pooled
            representations and the per-layer node representations, each
            concatenated along dim 1.
        """
        batch = graph.batch
        if x is None:
            # Featureless graphs: feed one constant feature per node.
            x = torch.ones((batch.shape[0], 1), device=batch.device)
        layer_rep = []
        for gnn, bn in zip(self.gnn_layers, self.bn_layers):
            x = F.relu(bn(gnn(graph, x)))
            layer_rep.append(x)

        pooled_rep = [self.pooling(h, batch) for h in layer_rep]
        node_rep = torch.cat(layer_rep, dim=1)
        graph_rep = torch.cat(pooled_rep, dim=1)
        return graph_rep, node_rep


class FF(nn.Module):
    r"""Residual MLP block.

    .. math::
        out = \mathrm{ReLU}(\mathbf{MLP}(x)) + \mathbf{Linear}(x)

    Parameters
    ----------
    in_feats : int
        Size of each input sample.
    out_feats : int
        Size of each output sample.
    """

    def __init__(self, in_feats, out_feats):
        super(FF, self).__init__()
        # Main branch: a 3-layer MLP.
        self.block = MLP(in_feats, out_feats, out_feats, num_layers=3)
        # The skip path is a learned projection so its width matches out_feats
        # even when in_feats != out_feats.
        self.shortcut = nn.Linear(in_feats, out_feats)

    def forward(self, x):
        """Return ``relu(block(x)) + shortcut(x)``."""
        return F.relu(self.block(x)) + self.shortcut(x)


class InfoGraph(BaseModel):
    r"""Implementation of InfoGraph from the paper `"InfoGraph: Unsupervised and
    Semi-supervised Graph-Level Representation Learning via Mutual Information
    Maximization" <https://openreview.net/forum?id=r1lfF2NYvH>`__.

    Parameters
    ----------
    in_feats : int
        Size of each input sample.
    hidden_dim : int
        Output size of each GIN layer in the encoder(s).
    out_feats : int
        Size of each output sample of the supervised head.
    num_layers : int, optional
        Number of GIN layers in each encoder, default: ``3``.
    sup : bool, optional
        If True, also build a separate semantic encoder used by the supervised
        forward pass. Default: ``False`` (unsupervised).
    """

    @staticmethod
    def add_args(parser):
        """Register this model's command-line hyper-parameters on ``parser``."""
        parser.add_argument("--hidden-size", type=int, default=512)
        parser.add_argument("--batch-size", type=int, default=20)
        parser.add_argument("--target", dest="target", type=int, default=0, help="")
        parser.add_argument("--train-num", dest="train_num", type=int, default=5000)
        parser.add_argument("--num-layers", type=int, default=1)
        parser.add_argument("--sup", dest="sup", action="store_true")
        parser.add_argument("--epochs", type=int, default=15)
        parser.add_argument("--lr", type=float, default=0.0001)
        parser.add_argument("--train-ratio", type=float, default=0.7)
        parser.add_argument("--test-ratio", type=float, default=0.1)

    @classmethod
    def build_model_from_args(cls, args):
        """Construct the model from a parsed argument namespace."""
        return cls(args.num_features, args.hidden_size, args.num_classes, args.num_layers, args.sup)

    @classmethod
    def split_dataset(cls, dataset, args):
        """Split ``dataset`` using ``split_dataset_general``."""
        return split_dataset_general(dataset, args)

    def __init__(self, in_feats, hidden_dim, out_feats, num_layers=3, sup=False):
        super(InfoGraph, self).__init__()
        self.sup = sup
        self.emb_dim = hidden_dim
        self.out_feats = out_feats
        self.num_layers = num_layers

        # Encoders concatenate the outputs of all their layers, so their
        # representations have num_layers * hidden_dim features.
        rep_dim = num_layers * hidden_dim

        # Supervised head, applied to the semantic encoder's pooled graph representation.
        self.sem_fc1 = nn.Linear(rep_dim, hidden_dim)
        self.sem_fc2 = nn.Linear(hidden_dim, out_feats)

        # The unsupervised encoder is always built, before sem_encoder, so the
        # module registration order matches the original two-branch code.
        self.unsup_encoder = Encoder(in_feats, hidden_dim, num_layers)
        if sup:
            self.sem_encoder = Encoder(in_feats, hidden_dim, num_layers)
        else:
            # Placeholder so `self.sem_encoder` resolves (to None) in unsupervised mode.
            self.register_parameter("sem_encoder", None)

        # Not used by the forward methods of this class; presumably consumed by
        # the training objective defined elsewhere -- TODO confirm.
        self._fc1 = FF(rep_dim, hidden_dim)
        self._fc2 = FF(rep_dim, hidden_dim)
        self.local_dis = FF(rep_dim, hidden_dim)
        self.global_dis = FF(rep_dim, hidden_dim)
        self.criterion = nn.MSELoss()

    def reset_parameters(self):
        """Re-initialize the weights of every ``nn.Linear`` submodule with Xavier-uniform."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # nn.init functions run under no_grad, so `.data` is unnecessary.
                nn.init.xavier_uniform_(m.weight)

    def forward(self, batch):
        """Run the supervised or unsupervised pass, using ``batch.x`` as node features."""
        if self.sup:
            return self.sup_forward(batch, batch.x)
        return self.unsup_forward(batch, batch.x)

    def sup_forward(self, batch, x):
        """Predict ``out_feats`` outputs from the semantic encoder's pooled graph representation."""
        # Encoder returns (graph_rep, node_rep). Only the pooled graph representation
        # feeds the head; the original code unpacked these under swapped names.
        graph_feat, _node_feat = self.sem_encoder(batch, x)
        hidden = F.relu(self.sem_fc1(graph_feat))
        return self.sem_fc2(hidden)

    def unsup_forward(self, batch, x):
        """Encode ``batch`` with the unsupervised encoder.

        Returns ``(graph_feat, node_feat)`` in training mode and only ``graph_feat``
        otherwise.
        """
        graph_feat, node_feat = self.unsup_encoder(batch, x)
        if self.training:
            return graph_feat, node_feat
        return graph_feat