import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def cal_mrr(embedding, rel_embedding, edge_index, edge_type, scoring, protocol="raw", batch_size=1000, hits=()):
    """Compute Mean Reciprocal Rank (and Hits@k) for link prediction.

    Args:
        embedding: entity embedding matrix, indexed by entity id.
        rel_embedding: relation embedding matrix, indexed by relation id.
        edge_index: 2 x E tensor of (head, tail) entity indices.
        edge_type: length-E tensor of relation indices.
        scoring: scoring layer exposing a ``predict(sub, all_obj, rel)`` method
            (e.g. DistMultLayer / ConvELayer).
        protocol: only "raw" is implemented; "filtered" is not yet supported.
        batch_size: number of test triples scored per batch.
        hits: iterable of k values for Hits@k (default empty; was a mutable
            list default, now an immutable tuple).

    Returns:
        (mrr, hits_count): scalar MRR and a list of Hits@k fractions.

    Raises:
        NotImplementedError: if protocol == "filtered".
        ValueError: for any other unknown protocol.
    """
    with torch.no_grad():
        if protocol == "raw":
            heads = edge_index[0]
            tails = edge_index[1]
            # Rank in both directions: tail prediction and head prediction.
            ranks_h = get_raw_rank(heads, tails, edge_type, embedding, rel_embedding, batch_size, scoring)
            ranks_t = get_raw_rank(tails, heads, edge_type, embedding, rel_embedding, batch_size, scoring)
            # +1 converts 0-based rank positions into 1-based ranks for MRR.
            ranks = np.concatenate((ranks_h, ranks_t)) + 1
        elif protocol == "filtered":
            raise NotImplementedError
        else:
            raise ValueError
        mrr = (1. / ranks).mean()
        hits_count = []
        for hit in hits:
            # np.float was removed in NumPy 1.24; builtin float is np.float64.
            hits_count.append(np.mean((ranks <= hit).astype(float)))
        return mrr, hits_count
class DistMultLayer(nn.Module):
    """DistMult scoring: a diagonal bilinear product <sub, rel, obj>."""

    def __init__(self):
        super(DistMultLayer, self).__init__()

    def forward(self, sub_emb, obj_emb, rel_emb):
        """Score matched triples; elementwise product summed over features."""
        return (sub_emb * obj_emb * rel_emb).sum(dim=-1)

    def predict(self, sub_emb, obj_emb, rel_emb):
        """Score each (sub, rel) query against every candidate object row."""
        return (sub_emb * rel_emb) @ obj_emb.t()
class ConvELayer(nn.Module):
    """ConvE-style scoring layer: stack (entity, relation) embeddings into a
    2D "image", apply a 2D convolution, and project back to the embedding
    dimension for dot-product scoring.

    NOTE(review): ``forward`` adds ``self.bias`` inside the final sum while
    ``predict`` uses no bias term at all — confirm this train/inference
    asymmetry is intentional.
    """

    def __init__(self, dim, num_filter=20, kernel_size=7, k_w=10, dropout=0.3):
        super(ConvELayer, self).__init__()
        # dim must factor as k_w * k_h so an embedding reshapes into a 2D grid.
        assert dim % k_w == 0
        self.k_w = k_w
        self.k_h = dim // k_w
        self.dim = dim
        self.bn0 = torch.nn.BatchNorm2d(1)
        self.bn1 = torch.nn.BatchNorm2d(num_filter)
        self.bn2 = torch.nn.BatchNorm1d(dim)
        self.hidden_drop = nn.Dropout(dropout)
        self.hidden_drop2 = nn.Dropout(dropout)
        self.feature_drop = nn.Dropout(dropout)
        self.conv = nn.Conv2d(1, out_channels=num_filter, kernel_size=(kernel_size, kernel_size), stride=1, padding=0, bias=True)
        # Spatial output of an unpadded stride-1 conv: input is 2*k_w tall
        # (entity stacked with relation) by k_h wide; see concat().
        flat_size_h = int(2*self.k_w) - kernel_size + 1
        flat_size_w = self.k_h - kernel_size + 1
        self.flat_size = flat_size_h * flat_size_w * num_filter
        self.fc = nn.Linear(self.flat_size, dim)
        self.bias = nn.Parameter(torch.zeros(dim))

    def concat(self, ent, rel):
        """Fold entity and relation embeddings into a (B, 1, 2*k_w, k_h) image."""
        ent = ent.view(-1, 1, self.dim)
        rel = rel.view(-1, 1, self.dim)
        ent_rel = torch.cat([ent, rel], dim=1)
        # transpose(2, 1) interleaves entity/relation feature-wise before
        # reshaping into the 2D convolution input grid.
        ent_rel = ent_rel.transpose(2, 1).reshape(-1, 1, 2 * self.k_w, self.k_h)
        return ent_rel

    def forward(self, sub_emb, obj_emb, rel_emb):
        """Score matched (sub, rel, obj) triples; returns a (B,) tensor.

        Training path: includes feature/hidden dropout around conv and fc.
        """
        h = self.concat(sub_emb, rel_emb)
        h = self.bn0(h)
        h = self.conv(h)
        h = F.relu(self.bn1(h))
        h = self.feature_drop(h)
        h = h.view(-1, self.flat_size)
        h = self.hidden_drop(self.fc(h))
        h = F.relu(self.bn2(self.hidden_drop2(h)))
        # Dot product with the object embedding; bias is added before the sum.
        x = torch.sum(h * obj_emb + self.bias, dim=-1)
        return x

    def predict(self, sub_emb, obj_emb, rel_emb):
        """Score each (sub, rel) query against all candidate objects.

        Inference path: no dropout layers are applied here; returns a
        (B, num_candidates) score matrix.
        """
        h = self.concat(sub_emb, rel_emb)
        h = self.bn0(h)
        h = self.conv(h)
        h = F.relu(self.bn1(h))
        h = h.view(-1, self.flat_size)
        h = self.fc(h)
        h = F.relu(self.bn2(h))
        x = torch.matmul(h, obj_emb.t())
        return x
class GNNLinkPredict(nn.Module):
    """Base class for GNN link-prediction models with a pluggable scorer."""

    def __init__(self, score_func, dim):
        """
        Args:
            score_func: "distmult" or "conve" — selects the scoring layer.
            dim: embedding dimension (used to build the ConvE scorer).

        Raises:
            NotImplementedError: for any other score_func.
        """
        super(GNNLinkPredict, self).__init__()
        self.edge_set = None  # lazily-built set of (head, tail, rel) triples
        self.score_func = score_func
        if score_func == "distmult":
            self.scoring = DistMultLayer()
        elif score_func == "conve":
            self.scoring = ConvELayer(dim)
        else:
            raise NotImplementedError

    def forward(self, edge_index, edge_type):
        # Subclasses must implement message passing. Fixed: the original did
        # `raise NotImplemented`, which raises the NotImplemented singleton
        # and produces a confusing TypeError instead of the intended error.
        raise NotImplementedError

    def get_score(self, heads, tails, rels):
        """Score a batch of (head, tail, rel) embedding triples."""
        return self.scoring(heads, tails, rels)

    def get_edge_set(self, edge_index, edge_types):
        """Cache the set of observed (head, tail, rel) triples; built once."""
        if self.edge_set is None:
            # Stack to a (3, E) tensor, then transpose ONCE so rows are
            # (h, t, r) triples. The original transposed twice (.T then .T),
            # which yielded three E-length tuples instead of E triples.
            edge_list = torch.cat((edge_index, edge_types.unsqueeze(0)), dim=0)
            edge_list = edge_list.cpu().T.numpy().tolist()
            torch.cuda.empty_cache()
            self.edge_set = set([tuple(x) for x in edge_list])  # tuple(h, t, r)

    def _loss(self, head_embed, tail_embed, rel_embed, labels):
        """Binary cross-entropy (with logits) against 0/1 triple labels."""
        score = self.get_score(head_embed, tail_embed, rel_embed)
        prediction_loss = F.binary_cross_entropy_with_logits(score, labels.float())
        return prediction_loss

    def _regularization(self, embs):
        """Mean-squared (L2) regularization summed over embedding tensors."""
        loss = 0
        for emb in embs:
            loss += torch.mean(emb.pow(2))
        return loss
def get_rank(scores, target):
    """Return the 0-based rank of each row's target under descending scores.

    Args:
        scores: (B, num_candidates) score matrix.
        target: (B,) tensor of true candidate indices.

    Returns:
        (B,) tensor where entry i is the position of target[i] when row i
        is sorted from highest to lowest score.
    """
    order = scores.argsort(dim=1, descending=True)
    matches = (order == target.view(-1, 1)).nonzero()
    return matches[:, 1].view(-1)
def get_raw_rank(heads, tails, rels, embedding, rel_embedding, batch_size, scoring):
    """Raw (unfiltered) 0-based ranks of the true targets, batched.

    Args:
        heads: (E,) source entity indices of the test triples.
        tails: (E,) true target entity indices.
        rels: (E,) relation indices.
        embedding: entity embedding matrix (all candidates).
        rel_embedding: relation embedding matrix.
        batch_size: triples scored per batch (bounds peak memory).
        scoring: layer exposing ``predict(sub, all_obj, rel)``.

    Returns:
        float64 numpy array of length E with each target's rank.
    """
    test_size = heads.shape[0]
    # Ceiling division so a partial final batch is still scored.
    num_batch = (test_size + batch_size - 1) // batch_size
    ranks = []
    for i in range(num_batch):
        start = batch_size * i
        end = start + batch_size
        # Score each query (head, rel) against every entity embedding.
        scores = torch.sigmoid(scoring.predict(embedding[heads[start:end]], embedding, rel_embedding[rels[start:end]]))
        target = tails[start:end]
        rank = get_rank(scores, target).cpu().numpy()
        torch.cuda.empty_cache()
        ranks.append(rank)
    # Handle an empty test set: np.concatenate raises on an empty list.
    if not ranks:
        return np.zeros(0, dtype=float)
    # np.float was removed in NumPy 1.24; builtin float is np.float64.
    return np.concatenate(ranks).astype(float)
def get_filtered_rank(heads, tails, rels, embedding, rel_embedding, batch_size, seen_data):
    """Filtered-protocol ranking — placeholder, not yet implemented.

    Intended to rank targets while excluding other known-true triples
    (presumably ``seen_data``); currently does nothing and returns None.
    """
    pass