import argparse
from typing import Dict
import numpy as np
import networkx as nx
from sklearn.cluster import KMeans, SpectralClustering
from sklearn.metrics.cluster import normalized_mutual_info_score
import torch
import torch.nn.functional as F
from cogdl.datasets import build_dataset
from cogdl.models import build_model
from . import BaseTask, register_task
@register_task("attributed_graph_clustering")
class AttributedGraphClustering(BaseTask):
    """Attributed graph clustering task.

    Builds a node feature matrix (raw node attributes, a structural
    embedding, or features from a trained model), clusters the rows with
    KMeans or spectral clustering, and scores the clusters against the
    ground-truth labels in ``data.y``.
    """

    @staticmethod
    def add_args(parser: argparse.ArgumentParser):
        """Add task-specific arguments to the parser."""
        # fmt: off
        # parser.add_argument("--num-features", type=int)
        parser.add_argument("--num-clusters", type=int, default=7)
        parser.add_argument("--cluster-method", type=str, default="kmeans")
        parser.add_argument("--hidden-size", type=int, default=128)
        parser.add_argument("--model-type", type=str, default="content")
        parser.add_argument("--evaluate", type=str, default="full")
        parser.add_argument('--enhance', type=str, default=None, help='use prone or prone++ to enhance embedding')
        # fmt: on

    def __init__(
        self,
        args,
        dataset=None,
        _=None,
    ):
        """Validate the options, then load the dataset and build the model.

        Args:
            args: Parsed arguments; the attributes read here are ``model``,
                ``cpu``, ``device_id``, ``hidden_size``, ``num_clusters``,
                ``cluster_method``, ``model_type``, ``evaluate`` and
                ``enhance``.
            dataset: Optional pre-built dataset; built from ``args`` when
                ``None``.
            _: Unused placeholder kept for signature compatibility.

        Raises:
            ValueError: If ``cluster_method``, ``model_type`` or ``evaluate``
                holds an unsupported value.
        """
        super(AttributedGraphClustering, self).__init__(args)

        # Reject bad options before doing any expensive dataset/model work.
        if args.cluster_method not in ["kmeans", "spectral"]:
            raise ValueError("cluster method must be kmeans or spectral")
        if args.model_type not in ["content", "spectral", "both"]:
            raise ValueError("model type must be content, spectral or both")
        if args.evaluate not in ["full", "NMI"]:
            raise ValueError("evaluation must be full or NMI")

        self.args = args
        self.model_name = args.model
        self.device = "cpu" if not torch.cuda.is_available() or args.cpu else args.device_id[0]

        if dataset is None:
            dataset = build_dataset(args)
        self.dataset = dataset
        self.data = dataset[0]
        self.num_nodes = self.data.y.shape[0]

        if args.model == "prone":
            # NOTE(review): the fixed size 13 for "prone" is carried over
            # unchanged; its rationale is not visible here - confirm.
            self.hidden_size = args.hidden_size = args.num_features = 13
        else:
            self.hidden_size = args.hidden_size
            args.num_features = dataset.num_features
        self.model = build_model(args)

        self.num_clusters = args.num_clusters
        self.cluster_method = args.cluster_method
        self.model_type = args.model_type
        self.evaluate = args.evaluate
        self.is_weighted = self.data.edge_attr is not None
        self.enhance = args.enhance

    def train(self) -> Dict[str, float]:
        """Compute node features, cluster them, and return the scores.

        Returns:
            The metric dictionary produced by the evaluation step: only
            ``NMI`` when ``evaluate`` is ``"NMI"``, otherwise ``Accuracy``,
            ``NMI`` and ``Micro_F1``.
        """
        if self.model_type == "content":
            features_matrix = self.data.x
        elif self.model_type == "spectral":
            features_matrix = self._spectral_features()
        else:
            trainer = self.model.get_trainer(AttributedGraphClustering, self.args)(self.args)
            self.model = trainer.fit(self.model, self.data)
            features_matrix = self.model.get_features(self.data)
        # detach() first: Tensor.numpy() refuses tensors that require grad.
        features_matrix = features_matrix.detach().cpu().numpy()

        print("Clustering...")
        if self.cluster_method == "kmeans":
            kmeans = KMeans(n_clusters=self.num_clusters, random_state=0).fit(features_matrix)
            clusters = kmeans.labels_
        else:
            clustering = SpectralClustering(
                n_clusters=self.num_clusters, assign_labels="discretize", random_state=0
            ).fit(features_matrix)
            clusters = clustering.labels_

        return self.__evaluate(clusters, self.evaluate == "full")

    def _spectral_features(self):
        """Embed the graph structure and return L2-normalised node features."""
        graph = nx.Graph()
        edges = self.data.edge_index.t().tolist()
        if self.is_weighted:
            weights = self.data.edge_attr.tolist()
            # Only the first component of each edge attribute is the weight.
            graph.add_weighted_edges_from([(u, v, w[0]) for (u, v), w in zip(edges, weights)])
        else:
            graph.add_edges_from(edges)

        embeddings = self.model.train(graph)
        if self.enhance is not None:
            # NOTE(review): enhance_emb is not defined in this class;
            # presumably it is provided by BaseTask - confirm.
            embeddings = self.enhance_emb(graph, embeddings)

        # Embedding rows follow graph.nodes() order; scatter them back to
        # node-id order. Nodes that appear in no edge keep an all-zero row.
        features_matrix = np.zeros((self.num_nodes, self.hidden_size))
        for vid, node in enumerate(graph.nodes()):
            features_matrix[node] = embeddings[vid]
        return F.normalize(torch.tensor(features_matrix), p=2, dim=1)

    @staticmethod
    def _same_label_pairs(labels) -> int:
        """Count unordered pairs of entries (rows, if 2-D) with equal labels."""
        if len(labels) == 0:
            return 0
        _, counts = np.unique(labels, axis=0, return_counts=True)
        # Python ints keep the arithmetic exact for any group size.
        return sum(int(c) * (int(c) - 1) // 2 for c in counts)

    def __evaluate(self, clusters, full=True) -> Dict[str, float]:
        """Score predicted clusters against the labels in ``self.data.y``.

        Args:
            clusters: Predicted cluster id per node.
            full: When true, also report the pair-counting metrics.

        Returns:
            ``{"NMI": ...}`` when ``full`` is false, otherwise a dict with
            ``Accuracy``, ``NMI`` and ``Micro_F1``. ``Accuracy`` keeps the
            value this task has always reported: the fraction of same-label
            node pairs that were placed in the same cluster (pairwise
            recall). A metric whose denominator is zero is reported as 0.0.
        """
        print("Evaluating...")
        truth = self.data.y.cpu().numpy()
        if not full:
            return dict(NMI=normalized_mutual_info_score(clusters, truth))

        pred = np.asarray(clusters)[: self.num_nodes]
        gold = np.asarray(truth)[: self.num_nodes]
        num_pairs = len(pred) * (len(pred) - 1) // 2

        # Pair counting from group sizes instead of an O(n^2) Python loop.
        # A pair is "positive" when both nodes share a predicted cluster.
        tp = self._same_label_pairs(np.column_stack((pred, gold)))
        fp = self._same_label_pairs(pred) - tp  # same cluster, different truth
        fn = self._same_label_pairs(gold) - tp  # different cluster, same truth
        tn = num_pairs - tp - fp - fn
        print("TP", tp, "FP", fp, "TN", tn, "FN", fn)

        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        if precision + recall:
            micro_f1 = 2 * (precision * recall) / (precision + recall)
        else:
            micro_f1 = 0.0
        return dict(Accuracy=recall, NMI=normalized_mutual_info_score(clusters, truth), Micro_F1=micro_f1)