import networkx as nx
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
from scipy.linalg import fractional_matrix_power, inv
from sklearn.preprocessing import MinMaxScaler
from .. import BaseModel, register_model
from .dgi import GCN, AvgReadout
from cogdl.utils import add_remaining_self_loops, symmetric_normalization
from cogdl.trainers.self_supervised_trainer import SelfSupervisedTrainer


def sparse_mx_to_torch_sparse_tensor(sparse_mx):
"""Convert a scipy sparse matrix to a torch sparse tensor."""
sparse_mx = sparse_mx.tocoo().astype(np.float32)
indices = torch.from_numpy(np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
values = torch.from_numpy(sparse_mx.data)
shape = torch.Size(sparse_mx.shape)
    return torch.sparse_coo_tensor(indices, values, shape)
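

# A minimal usage sketch (illustrative only, not part of the original module):
# round-trip a small random scipy matrix and compare against its dense form.
#
#   mx = sp.random(4, 4, density=0.5, format="csr")
#   t = sparse_mx_to_torch_sparse_tensor(mx)
#   assert torch.allclose(t.to_dense(), torch.from_numpy(mx.toarray()).float())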


def compute_ppr(graph: nx.Graph, alpha=0.2, self_loop=True):
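    """Personalized PageRank diffusion: alpha * (I_n - (1 - alpha) * A~)^(-1)."""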
a = nx.convert_matrix.to_numpy_array(graph)
if self_loop:
a = a + np.eye(a.shape[0]) # A^ = A + I_n
    d = np.diag(np.sum(a, 1))  # D^_ii = sum_j A^_ij
dinv = fractional_matrix_power(d, -0.5) # D^(-1/2)
at = np.matmul(np.matmul(dinv, a), dinv) # A~ = D^(-1/2) x A^ x D^(-1/2)
return alpha * inv((np.eye(a.shape[0]) - (1 - alpha) * at)) # a(I_n-(1-a)A~)^-1
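

# Sanity-check sketch (illustrative only): the diffusion matrix is dense and
# symmetric here, since the normalized adjacency A~ is symmetric.
#
#   s = compute_ppr(nx.karate_club_graph(), alpha=0.2)
#   assert s.shape == (34, 34) and np.allclose(s, s.T)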


# Borrowed from https://github.com/kavehhassani/mvgrl
class Discriminator(nn.Module):
def __init__(self, n_h):
super(Discriminator, self).__init__()
self.f_k = nn.Bilinear(n_h, n_h, 1)
for m in self.modules():
self.weights_init(m)

    def weights_init(self, m):
if isinstance(m, nn.Bilinear):
torch.nn.init.xavier_uniform_(m.weight.data)
if m.bias is not None:
m.bias.data.fill_(0.0)

    def forward(self, c1, c2, h1, h2, h3, h4):
c_x1 = torch.unsqueeze(c1, 1)
c_x1 = c_x1.expand_as(h1).contiguous()
c_x2 = torch.unsqueeze(c2, 1)
c_x2 = c_x2.expand_as(h2).contiguous()
# positive
sc_1 = torch.squeeze(self.f_k(h2, c_x1), 2)
sc_2 = torch.squeeze(self.f_k(h1, c_x2), 2)
        # negative
sc_3 = torch.squeeze(self.f_k(h4, c_x1), 2)
sc_4 = torch.squeeze(self.f_k(h3, c_x2), 2)
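        # logits have shape (batch, 4 * n_nodes): the first two blocks are
        # cross-view positive scores, the last two score the shuffled
        # (negative) features, matching the ones/zeros labels in MVGRL.loss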
logits = torch.cat((sc_1, sc_2, sc_3, sc_4), 1)
return logits


# Mainly borrowed from https://github.com/kavehhassani/mvgrl
@register_model("mvgrl")
class MVGRL(BaseModel):
    @staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
# fmt: off
parser.add_argument("--hidden-size", type=int, default=512)
parser.add_argument("--max-epochs", type=int, default=1000)
parser.add_argument("--sample-size", type=int, default=2000)
parser.add_argument("--batch-size", type=int, default=4)
parser.add_argument("--sparse", action="store_true")
# fmt: on

    @classmethod
def build_model_from_args(cls, args):
return cls(args.num_features, args.hidden_size, args.sample_size, args.batch_size, args.sparse, args.dataset)

    def __init__(self, in_feats, hidden_size, sample_size=2000, batch_size=4, sparse=False, dataset="cora"):
super(MVGRL, self).__init__()
self.sample_size = sample_size
self.batch_size = batch_size
self.sparse = sparse
self.dataset_name = dataset
self.gcn1 = GCN(in_feats, hidden_size, "prelu")
self.gcn2 = GCN(in_feats, hidden_size, "prelu")
self.read = AvgReadout()
self.sigm = nn.Sigmoid()
self.disc = Discriminator(hidden_size)
self.loss_f = nn.BCEWithLogitsLoss()
self.cache = None

    def _forward(self, seq1, seq2, adj, diff, sparse, msk):
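        # One encoder per view: gcn1 consumes the normalized adjacency, gcn2
        # the PPR diffusion matrix. seq1 holds the true features, seq2 the
        # shuffled features that produce the negative embeddings h_3 / h_4.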
h_1 = self.gcn1(seq1, adj, sparse)
c_1 = self.read(h_1, msk)
c_1 = self.sigm(c_1)
h_2 = self.gcn2(seq1, diff, sparse)
c_2 = self.read(h_2, msk)
c_2 = self.sigm(c_2)
h_3 = self.gcn1(seq2, adj, sparse)
h_4 = self.gcn2(seq2, diff, sparse)
ret = self.disc(c_1, c_2, h_1, h_2, h_3, h_4)
return ret, h_1, h_2

    def preprocess(self, x, edge_index, edge_attr=None):
num_nodes = x.shape[0]
edge_index, edge_weight = add_remaining_self_loops(edge_index, edge_attr)
adj = edge_index.cpu().numpy()
edge_weight = symmetric_normalization(x.shape[0], edge_index, edge_weight)
        # use toarray() (ndarray) rather than todense() (np.matrix) so that
        # torch.from_numpy accepts the cached matrix in embed()
        adj = sp.coo_matrix((edge_weight.cpu().numpy(), (adj[0], adj[1])), shape=(num_nodes, num_nodes)).toarray()
g = nx.Graph()
g.add_nodes_from(list(range(num_nodes)))
g.add_edges_from(edge_index.cpu().numpy().transpose())
diff = compute_ppr(g, 0.2)
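        # For citeseer, sparsify the dense diffusion matrix: pick the threshold
        # whose average count of surviving entries per row is closest to the
        # graph's average degree, then min-max rescale the remaining entries.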
if self.dataset_name == "citeseer":
epsilons = [1e-5, 1e-4, 1e-3, 1e-2]
avg_degree = np.sum(adj) / adj.shape[0]
epsilon = epsilons[
np.argmin([abs(avg_degree - np.argwhere(diff >= e).shape[0] / diff.shape[0]) for e in epsilons])
]
diff[diff < epsilon] = 0.0
scaler = MinMaxScaler()
scaler.fit(diff)
diff = scaler.transform(diff)
if self.cache is None:
self.cache = dict()
self.cache["diff"] = diff
self.cache["adj"] = adj
self.device = next(self.gcn1.parameters()).device

    def forward(self, x, edge_index, edge_attr=None):
if self.cache is None or "diff" not in self.cache:
self.preprocess(x, edge_index, edge_attr)
diff, adj = self.cache["diff"], self.cache["adj"]
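        # sample batch_size random contiguous blocks of sample_size nodes; each
        # block yields a sub-adjacency, a sub-diffusion matrix and feature rows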
idx = np.random.randint(0, adj.shape[-1] - self.sample_size + 1, self.batch_size)
ba, bd, bf = [], [], []
for i in idx:
ba.append(adj[i: i + self.sample_size, i: i + self.sample_size])
bd.append(diff[i: i + self.sample_size, i: i + self.sample_size])
bf.append(x[i: i + self.sample_size])
ba = np.array(ba).reshape(self.batch_size, self.sample_size, self.sample_size)
bd = np.array(bd)
bd = bd.reshape(self.batch_size, self.sample_size, self.sample_size)
bf = torch.stack(bf).reshape(self.batch_size, self.sample_size, x.shape[1])
        if self.sparse:
            # ba/bd are 3-D (batch, n, n); scipy's coo_matrix cannot represent
            # that, so build a batched torch sparse tensor instead (this
            # assumes the downstream GCN accepts a batched sparse adjacency)
            ba = torch.from_numpy(ba).float().to_sparse()
            bd = torch.from_numpy(bd).float().to_sparse()
else:
ba = torch.FloatTensor(ba)
bd = torch.FloatTensor(bd)
idx = np.random.permutation(self.sample_size)
shuf_fts = bf[:, idx, :]
bf = bf.to(self.device)
ba = ba.to(self.device)
bd = bd.to(self.device)
shuf_fts = shuf_fts.to(self.device)
logits, _, _ = self._forward(bf, shuf_fts, ba, bd, self.sparse, None)
return logits

    def loss(self, data):
        if self.cache is None or "labels" not in self.cache:
            self.device = next(self.gcn1.parameters()).device
            lbl_1 = torch.ones(self.batch_size, self.sample_size * 2)
            lbl_2 = torch.zeros(self.batch_size, self.sample_size * 2)
            lbl = torch.cat((lbl_1, lbl_2), 1).to(self.device)
            # merge into the cache instead of replacing it, so labels survive
            # when forward() has already stored "diff"/"adj" (and vice versa)
            if self.cache is None:
                self.cache = dict()
            self.cache["labels"] = lbl
lbl = self.cache["labels"]
logits = self.forward(data.x, data.edge_index, data.edge_attr)
loss = self.loss_f(logits, lbl)
return loss

    def node_classification_loss(self, data):
return self.loss(data)

    def embed(self, data, msk=None):
        # make sure the adjacency/diffusion matrices have been computed, in
        # case embed() is called before any forward pass
        if self.cache is None or "diff" not in self.cache:
            self.preprocess(data.x, data.edge_index, data.edge_attr)
        adj = torch.from_numpy(self.cache["adj"]).float().to(data.x.device)
        diff = torch.from_numpy(self.cache["diff"]).float().to(data.x.device)
h_1 = self.gcn1(data.x, adj, self.sparse)
h_2 = self.gcn2(data.x, diff, self.sparse)
# c = self.read(h_1, msk)
return (h_1 + h_2).detach() # , c.detach()

    @staticmethod
def get_trainer(taskType, args):
return SelfSupervisedTrainer
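

# Minimal usage sketch (hypothetical; in cogdl the SelfSupervisedTrainer
# returned above normally drives training). `data` stands for a cogdl graph
# whose feature dimension matches in_feats, e.g. 1433 for Cora.
#
#   model = MVGRL(in_feats=1433, hidden_size=512)
#   optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
#   optimizer.zero_grad()
#   model.loss(data).backward()
#   optimizer.step()
#   embeddings = model.embed(data)  # sum of both views' node embeddings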