Source code for cogdl.wrappers.model_wrapper.clustering.agc_mw

from cogdl.wrappers.tools.wrapper_utils import evaluate_clustering
from .. import EmbeddingModelWrapper


[docs]class AGCModelWrapper(EmbeddingModelWrapper):
[docs] @staticmethod def add_args(parser): # fmt: off parser.add_argument("--num-clusters", type=int, default=7) parser.add_argument("--cluster-method", type=str, default="kmeans", help="option: kmeans or spectral") parser.add_argument("--evaluation", type=str, default="full", help="option: full or NMI")
# fmt: on def __init__(self, model, optimizer_cfg, num_clusters, cluster_method="kmeans", evaluation="full", max_iter=5): super(AGCModelWrapper, self).__init__() self.model = model self.optimizer_cfg = optimizer_cfg self.num_clusters = num_clusters self.cluster_method = cluster_method self.full = evaluation == "full"
[docs] def train_step(self, graph): emb = self.model.forward(graph) return emb
[docs] def test_step(self, batch): features_matrix, graph = batch return evaluate_clustering( features_matrix, graph.y, self.cluster_method, self.num_clusters, graph.num_nodes, self.full )