Source code for cogdl.models.emb.gatne

import numpy as np
import networkx as nx
from collections import defaultdict
from gensim.models.keyedvectors import Vocab
import random
import math
import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from .. import BaseModel, register_model


@register_model("gatne")
class GATNE(BaseModel):
    r"""The GATNE model from the `"Representation Learning for Attributed
    Multiplex Heterogeneous Network" <https://dl.acm.org/doi/10.1145/3292500.3330964>`_ paper.

    Args:
        walk_length (int) : The walk length.
        walk_num (int) : The number of walks to sample for each node.
        window_size (int) : The actual context size which is considered in the language model.
        worker (int) : The number of workers for word2vec.
        epoch (int) : The number of training epochs.
        batch_size (int) : The size of each training batch.
        edge_dim (int) : Number of edge embedding dimensions.
        att_dim (int) : Number of attention dimensions.
        negative_samples (int) : Negative samples for optimization.
        neighbor_samples (int) : Neighbor samples for aggregation.
        schema (str) : The metapath schema used in the model. Metapaths are separated
            with ",", while node types within each metapath are connected with "-".
            For example: "0-1-0,0-1-2-1-0".
    """
    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--walk-length', type=int, default=10,
                            help='Length of walk per source. Default is 10.')
        parser.add_argument('--walk-num', type=int, default=10,
                            help='Number of walks per source. Default is 10.')
        parser.add_argument('--window-size', type=int, default=5,
                            help='Window size of skip-gram model. Default is 5.')
        parser.add_argument('--worker', type=int, default=10,
                            help='Number of parallel workers. Default is 10.')
        parser.add_argument('--epoch', type=int, default=20,
                            help='Number of epochs. Default is 20.')
        parser.add_argument('--batch-size', type=int, default=256,
                            help='Batch size. Default is 256.')
        parser.add_argument('--edge-dim', type=int, default=10,
                            help='Number of edge embedding dimensions. Default is 10.')
        parser.add_argument('--att-dim', type=int, default=20,
                            help='Number of attention dimensions. Default is 20.')
        parser.add_argument('--negative-samples', type=int, default=5,
                            help='Negative samples for optimization. Default is 5.')
        parser.add_argument('--neighbor-samples', type=int, default=10,
                            help='Neighbor samples for aggregation. Default is 10.')
        parser.add_argument('--schema', type=str, default=None,
                            help='Input schema for metapath random walk.')
        # fmt: on
    @classmethod
    def build_model_from_args(cls, args):
        return cls(
            args.hidden_size,
            args.walk_length,
            args.walk_num,
            args.window_size,
            args.worker,
            args.epoch,
            args.batch_size,
            args.edge_dim,
            args.att_dim,
            args.negative_samples,
            args.neighbor_samples,
            args.schema,
        )
    def __init__(
        self,
        dimension,
        walk_length,
        walk_num,
        window_size,
        worker,
        epoch,
        batch_size,
        edge_dim,
        att_dim,
        negative_samples,
        neighbor_samples,
        schema,
    ):
        super(GATNE, self).__init__()
        self.embedding_size = dimension
        self.walk_length = walk_length
        self.walk_num = walk_num
        self.window_size = window_size
        self.worker = worker
        self.epochs = epoch
        self.batch_size = batch_size
        self.embedding_u_size = edge_dim
        self.dim_att = att_dim
        self.num_sampled = negative_samples
        self.neighbor_samples = neighbor_samples
        self.schema = schema
        self.multiplicity = True
    def train(self, network_data):
        all_walks = generate_walks(network_data, self.walk_num, self.walk_length, schema=self.schema)
        vocab, index2word = generate_vocab(all_walks)
        # Pass the configured window size; the original call omitted it, so
        # --window-size was silently ignored in favor of the default.
        train_pairs = generate_pairs(all_walks, vocab, self.window_size)

        edge_types = list(network_data.keys())
        num_nodes = len(index2word)
        edge_type_count = len(edge_types)

        epochs = self.epochs
        batch_size = self.batch_size
        embedding_size = self.embedding_size
        embedding_u_size = self.embedding_u_size
        num_sampled = self.num_sampled
        dim_att = self.dim_att
        neighbor_samples = self.neighbor_samples

        # Pad or subsample each node's neighbor list so every node has exactly
        # `neighbor_samples` neighbors per edge type.
        neighbors = [[[] for __ in range(edge_type_count)] for _ in range(num_nodes)]
        for r in range(edge_type_count):
            g = network_data[edge_types[r]]
            for (x, y) in g:
                ix = vocab[x].index
                iy = vocab[y].index
                neighbors[ix][r].append(iy)
                neighbors[iy][r].append(ix)
            for i in range(num_nodes):
                if len(neighbors[i][r]) == 0:
                    neighbors[i][r] = [i] * neighbor_samples
                elif len(neighbors[i][r]) < neighbor_samples:
                    neighbors[i][r].extend(
                        list(
                            np.random.choice(
                                neighbors[i][r],
                                size=neighbor_samples - len(neighbors[i][r]),
                            )
                        )
                    )
                elif len(neighbors[i][r]) > neighbor_samples:
                    neighbors[i][r] = list(np.random.choice(neighbors[i][r], size=neighbor_samples))

        model = GATNEModel(num_nodes, embedding_size, embedding_u_size, edge_type_count, dim_att)
        nsloss = NSLoss(num_nodes, num_sampled, embedding_size)

        model.to(self.device)
        nsloss.to(self.device)

        optimizer = torch.optim.Adam(
            [{"params": model.parameters()}, {"params": nsloss.parameters()}], lr=1e-4
        )

        for epoch in range(epochs):
            random.shuffle(train_pairs)
            batches = get_batches(train_pairs, neighbors, batch_size)

            data_iter = tqdm.tqdm(
                batches,
                desc="epoch %d" % (epoch),
                total=(len(train_pairs) + (batch_size - 1)) // batch_size,
                bar_format="{l_bar}{r_bar}",
            )
            avg_loss = 0.0

            for i, data in enumerate(data_iter):
                optimizer.zero_grad()
                embs = model(
                    data[0].to(self.device),
                    data[2].to(self.device),
                    data[3].to(self.device),
                )
                loss = nsloss(data[0].to(self.device), embs, data[1].to(self.device))
                loss.backward()
                optimizer.step()

                avg_loss += loss.item()

                if i % 5000 == 0:
                    post_fix = {
                        "epoch": epoch,
                        "iter": i,
                        "avg_loss": avg_loss / (i + 1),
                        "loss": loss.item(),
                    }
                    data_iter.write(str(post_fix))

        final_model = dict(zip(edge_types, [dict() for _ in range(edge_type_count)]))
        for i in range(num_nodes):
            train_inputs = torch.tensor([i for _ in range(edge_type_count)]).to(self.device)
            train_types = torch.tensor(list(range(edge_type_count))).to(self.device)
            node_neigh = torch.tensor([neighbors[i] for _ in range(edge_type_count)]).to(self.device)
            node_emb = model(train_inputs, train_types, node_neigh)
            for j in range(edge_type_count):
                final_model[edge_types[j]][index2word[i]] = node_emb[j].cpu().detach().numpy()
        return final_model
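# Illustrative usage sketch (not part of the original module). `train` expects
# a dict mapping each edge type to a list of (u, v) node-id pairs and returns
# {edge_type: {node: np.ndarray}} of final embeddings. The toy edges and the
# helper name below are hypothetical, not from a CogDL dataset.
def _example_gatne_train(model):
    # `model` is assumed to be a GATNE instance built via
    # GATNE.build_model_from_args(args) with a device already assigned.
    network_data = {
        "1": [(0, 1), (1, 2), (2, 0)],  # edges of the first edge type
        "2": [(0, 2), (1, 3)],          # edges of the second edge type
    }
    embeddings = model.train(network_data)
    return embeddings["1"][0]  # embedding of node 0 under edge type "1"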
class GATNEModel(nn.Module):
    def __init__(self, num_nodes, embedding_size, embedding_u_size, edge_type_count, dim_a):
        super(GATNEModel, self).__init__()
        self.num_nodes = num_nodes
        self.embedding_size = embedding_size
        self.embedding_u_size = embedding_u_size
        self.edge_type_count = edge_type_count
        self.dim_a = dim_a

        self.node_embeddings = Parameter(torch.FloatTensor(num_nodes, embedding_size))
        self.node_type_embeddings = Parameter(torch.FloatTensor(num_nodes, edge_type_count, embedding_u_size))
        self.trans_weights = Parameter(torch.FloatTensor(edge_type_count, embedding_u_size, embedding_size))
        self.trans_weights_s1 = Parameter(torch.FloatTensor(edge_type_count, embedding_u_size, dim_a))
        self.trans_weights_s2 = Parameter(torch.FloatTensor(edge_type_count, dim_a, 1))
        self.reset_parameters()

    def reset_parameters(self):
        self.node_embeddings.data.uniform_(-1.0, 1.0)
        self.node_type_embeddings.data.uniform_(-1.0, 1.0)
        self.trans_weights.data.normal_(std=1.0 / math.sqrt(self.embedding_size))
        self.trans_weights_s1.data.normal_(std=1.0 / math.sqrt(self.embedding_size))
        self.trans_weights_s2.data.normal_(std=1.0 / math.sqrt(self.embedding_size))

    def forward(self, train_inputs, train_types, node_neigh):
        node_embed = self.node_embeddings[train_inputs]
        node_embed_neighbors = self.node_type_embeddings[node_neigh]
        # Sum the type-i edge embeddings of each node's sampled type-i neighbors.
        node_embed_tmp = torch.cat(
            [node_embed_neighbors[:, i, :, i, :].unsqueeze(1) for i in range(self.edge_type_count)],
            dim=1,
        )
        node_type_embed = torch.sum(node_embed_tmp, dim=2)

        trans_w = self.trans_weights[train_types]
        trans_w_s1 = self.trans_weights_s1[train_types]
        trans_w_s2 = self.trans_weights_s2[train_types]

        # Self-attention over edge types; torch.tanh replaces the deprecated
        # F.tanh, and the softmax dimension is made explicit.
        attention = F.softmax(
            torch.matmul(torch.tanh(torch.matmul(node_type_embed, trans_w_s1)), trans_w_s2).squeeze(),
            dim=-1,
        ).unsqueeze(1)
        node_type_embed = torch.matmul(attention, node_type_embed)
        node_embed = node_embed + torch.matmul(node_type_embed, trans_w).squeeze()

        last_node_embed = F.normalize(node_embed, dim=1)
        return last_node_embed


class NSLoss(nn.Module):
    def __init__(self, num_nodes, num_sampled, embedding_size):
        super(NSLoss, self).__init__()
        self.num_nodes = num_nodes
        self.num_sampled = num_sampled
        self.embedding_size = embedding_size
        self.weights = Parameter(torch.FloatTensor(num_nodes, embedding_size))
        # Noise distribution for negative sampling, log-uniform over node ranks.
        self.sample_weights = F.normalize(
            torch.Tensor(
                [(math.log(k + 2) - math.log(k + 1)) / math.log(num_nodes + 1) for k in range(num_nodes)]
            ),
            dim=0,
        )
        self.reset_parameters()

    def reset_parameters(self):
        self.weights.data.normal_(std=1.0 / math.sqrt(self.embedding_size))

    def forward(self, input, embs, label):
        n = input.shape[0]
        log_target = torch.log(torch.sigmoid(torch.sum(torch.mul(embs, self.weights[label]), 1)))
        negs = torch.multinomial(self.sample_weights, self.num_sampled * n, replacement=True).view(
            n, self.num_sampled
        )
        noise = torch.neg(self.weights[negs])
        sum_log_sampled = torch.sum(torch.log(torch.sigmoid(torch.bmm(noise, embs.unsqueeze(2)))), 1).squeeze()
        loss = log_target + sum_log_sampled
        return -loss.sum() / n
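# Illustrative sketch (not part of the original module): NSLoss implements
# skip-gram negative sampling, -(1/n) * sum_i [log sigma(v_i . w_{y_i})
# + sum_k log sigma(-v_i . w_{n_k})]. The sizes and helper name below are
# hypothetical.
def _example_nsloss():
    criterion = NSLoss(num_nodes=4, num_sampled=2, embedding_size=8)
    embs = torch.randn(3, 8)          # embeddings for a batch of 3 center nodes
    inputs = torch.tensor([0, 1, 2])  # center node ids (used for the batch size)
    labels = torch.tensor([1, 2, 3])  # observed context node ids
    return criterion(inputs, embs, labels)  # scalar negative-sampling loss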
class RWGraph:
    def __init__(self, nx_G, node_type=None):
        self.G = nx_G
        self.node_type = node_type

    def walk(self, walk_length, start, schema=None):
        # Simulate a random walk starting from the start node, optionally
        # constrained to a metapath schema such as "0-1-0".
        G = self.G
        rand = random.Random()

        if schema:
            schema_items = schema.split("-")
            assert schema_items[0] == schema_items[-1]

        walk = [start]
        while len(walk) < walk_length:
            cur = walk[-1]
            candidates = []
            for node in G[cur].keys():
                if schema is None or self.node_type[node] == schema_items[len(walk) % (len(schema_items) - 1)]:
                    candidates.append(node)
            if candidates:
                walk.append(rand.choice(candidates))
            else:
                break
        return walk

    def simulate_walks(self, num_walks, walk_length, schema=None):
        G = self.G
        walks = []
        nodes = list(G.nodes())
        if schema is not None:
            schema_list = schema.split(",")
        for walk_iter in range(num_walks):
            random.shuffle(nodes)
            for node in nodes:
                if schema is None:
                    walks.append(self.walk(walk_length=walk_length, start=node))
                else:
                    for schema_iter in schema_list:
                        if schema_iter.split("-")[0] == self.node_type[node]:
                            walks.append(
                                self.walk(
                                    walk_length=walk_length,
                                    start=node,
                                    schema=schema_iter,
                                )
                            )
        return walks


def get_G_from_edges(edges):
    # Count repeated edges so that multiplicities become edge weights.
    edge_dict = dict()
    for edge in edges:
        edge_key = str(edge[0]) + "_" + str(edge[1])
        if edge_key not in edge_dict:
            edge_dict[edge_key] = 1
        else:
            edge_dict[edge_key] += 1
    tmp_G = nx.Graph()
    for edge_key in edge_dict:
        weight = edge_dict[edge_key]
        x = int(edge_key.split("_")[0])
        y = int(edge_key.split("_")[1])
        tmp_G.add_edge(x, y)
        tmp_G[x][y]["weight"] = weight
    return tmp_G


def generate_pairs(all_walks, vocab, window_size=5):
    # Produce (center, context, layer) skip-gram pairs from the walks.
    pairs = []
    skip_window = window_size // 2
    for layer_id, walks in enumerate(all_walks):
        for walk in walks:
            for i in range(len(walk)):
                for j in range(1, skip_window + 1):
                    if i - j >= 0:
                        pairs.append((vocab[walk[i]].index, vocab[walk[i - j]].index, layer_id))
                    if i + j < len(walk):
                        pairs.append((vocab[walk[i]].index, vocab[walk[i + j]].index, layer_id))
    return pairs


def generate_vocab(all_walks):
    index2word = []
    raw_vocab = defaultdict(int)

    for walks in all_walks:
        for walk in walks:
            for word in walk:
                raw_vocab[word] += 1

    vocab = {}
    for word, v in raw_vocab.items():
        vocab[word] = Vocab(count=v, index=len(index2word))
        index2word.append(word)

    # Reassign indices so that more frequent nodes get smaller indices.
    index2word.sort(key=lambda word: vocab[word].count, reverse=True)
    for i, word in enumerate(index2word):
        vocab[word].index = i

    return vocab, index2word


def get_batches(pairs, neighbors, batch_size):
    n_batches = (len(pairs) + (batch_size - 1)) // batch_size

    for idx in range(n_batches):
        x, y, t, neigh = [], [], [], []
        for i in range(batch_size):
            index = idx * batch_size + i
            if index >= len(pairs):
                break
            x.append(pairs[index][0])
            y.append(pairs[index][1])
            t.append(pairs[index][2])
            neigh.append(neighbors[pairs[index][0]])
        yield torch.tensor(x), torch.tensor(y), torch.tensor(t), torch.tensor(neigh)


def generate_walks(network_data, num_walks, walk_length, schema=None):
    all_walks = []
    for layer_id in network_data:
        tmp_data = network_data[layer_id]
        # Run the random walks on one layer (edge type) at a time.
        layer_walker = RWGraph(get_G_from_edges(tmp_data))
        layer_walks = layer_walker.simulate_walks(num_walks, walk_length, schema=schema)
        all_walks.append(layer_walks)
    return all_walks
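# Minimal end-to-end sketch of the walk/vocab/pair pipeline (hypothetical toy
# data, not a CogDL dataset; requires gensim 3.x for `Vocab`). Run as a script
# to check that the helpers compose the way `GATNE.train` uses them.
if __name__ == "__main__":
    toy_network = {
        "1": [(0, 1), (1, 2), (2, 3), (3, 0)],
        "2": [(0, 2), (1, 3)],
    }
    walks = generate_walks(toy_network, num_walks=2, walk_length=5)
    vocab, index2word = generate_vocab(walks)
    pairs = generate_pairs(walks, vocab, window_size=3)
    print(len(index2word), "nodes,", len(pairs), "skip-gram pairs")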