Source code for cogdl.models.agc.agc

import torch
import torch.nn
import torch.nn.functional as F

from .. import BaseModel, register_model
from cogdl.trainers.agc_trainer import AGCTrainer

@register_model("agc")
class AGC(BaseModel):
    r"""The AGC model from the `"Attributed Graph Clustering via Adaptive Graph
    Convolution" <https://arxiv.org/abs/1906.01210>`_ paper.

    Args:
        num_clusters (int): Number of clusters.
        max_iter (int): Maximum number of iterations for increasing the
            filter order ``k``.
    """

    @staticmethod
    def add_args(parser):
        """Register model-specific command-line arguments on ``parser``.

        Only ``--max-iter`` (default 60) is added here. ``num_clusters`` is
        read in :meth:`build_model_from_args` but is not registered by this
        method; presumably the task supplies it (TODO confirm).
        """
        parser.add_argument("--max-iter", type=int, default=60)

    @classmethod
    def build_model_from_args(cls, args):
        """Build an ``AGC`` from parsed arguments.

        Reads ``args.num_clusters`` and ``args.max_iter``.
        """
        return cls(args.num_clusters, args.max_iter)

    def __init__(self, num_clusters, max_iter):
        super().__init__()
        self.num_clusters = num_clusters
        self.max_iter = max_iter
        # Filter order ``k``; starts at 0. Nothing in this class changes it.
        # Presumably AGCTrainer updates it while training (TODO confirm).
        self.k = 0
        # Placeholder for the learned feature matrix. It stays None until
        # something assigns it, presumably AGCTrainer (TODO confirm).
        self.features_matrix = None

    def get_trainer(self, task, args):
        """Return the trainer *class* (not an instance) for this model.

        ``task`` and ``args`` are unused.
        """
        return AGCTrainer

    def get_features(self, data):
        """Return ``features_matrix`` detached from the graph and moved to CPU.

        ``data`` is unused. If ``features_matrix`` is still ``None`` (its
        value after ``__init__``), the ``.detach()`` call raises
        ``AttributeError``.
        """
        return self.features_matrix.detach().cpu()