from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from .. import BaseModel
from cogdl.datasets.kg_data import TestDataset
class KGEModel(BaseModel):
    """Base class for knowledge-graph embedding (KGE) models.

    Holds learnable entity and relation embedding tables, implements the
    shared embedding lookup in ``forward``, and provides generic
    ``train_step`` / ``test_step`` helpers. Subclasses must implement
    ``score``.
    """

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        parser.add_argument("--embedding_size", type=int, default=256, help="Dimensionality of embedded vectors")
        parser.add_argument("--nentity", type=int, help="Number of entities")
        parser.add_argument("--nrelation", type=int, help="Number of relations")
        # gamma is stored as a float tensor and also determines the embedding
        # initialization range (see __init__), so fractional values must be
        # accepted; parse as float (integer inputs still parse fine).
        parser.add_argument("--gamma", type=float, help="Hyperparameter for embedding")
        parser.add_argument("--double_entity_embedding", action="store_true")
        parser.add_argument("--double_relation_embedding", action="store_true")

    @classmethod
    def build_model_from_args(cls, args):
        """Construct the model from parsed command-line ``args``."""
        return cls(args.nentity, args.nrelation, args.embedding_size, args.gamma, args.double_entity_embedding, args.double_relation_embedding)

    def __init__(self, nentity, nrelation, hidden_dim, gamma,
                 double_entity_embedding=False, double_relation_embedding=False):
        """Create uniformly-initialized entity and relation embedding tables.

        Args:
            nentity: number of entities (rows of ``entity_embedding``).
            nrelation: number of relations (rows of ``relation_embedding``).
            hidden_dim: base embedding width.
            gamma: scalar hyperparameter; together with ``epsilon`` it sets
                the init range ``(gamma + epsilon) / hidden_dim``.
            double_entity_embedding: if True, entity vectors have width
                ``2 * hidden_dim``.
            double_relation_embedding: if True, relation vectors have width
                ``2 * hidden_dim``.
        """
        super(KGEModel, self).__init__()
        self.nentity = nentity
        self.nrelation = nrelation
        self.hidden_dim = hidden_dim
        self.epsilon = 2.0

        # Frozen Parameters: registered on the module (so they move with
        # .to()/.cuda() and are saved in state_dict) but never optimized.
        self.gamma = nn.Parameter(torch.Tensor([gamma]), requires_grad=False)
        self.embedding_range = nn.Parameter(
            torch.Tensor([(self.gamma.item() + self.epsilon) / hidden_dim]),
            requires_grad=False,
        )

        self.entity_dim = hidden_dim * 2 if double_entity_embedding else hidden_dim
        self.relation_dim = hidden_dim * 2 if double_relation_embedding else hidden_dim

        # Both tables are initialized uniformly in [-embedding_range, embedding_range].
        self.entity_embedding = nn.Parameter(torch.zeros(nentity, self.entity_dim))
        nn.init.uniform_(
            tensor=self.entity_embedding,
            a=-self.embedding_range.item(),
            b=self.embedding_range.item(),
        )
        self.relation_embedding = nn.Parameter(torch.zeros(nrelation, self.relation_dim))
        nn.init.uniform_(
            tensor=self.relation_embedding,
            a=-self.embedding_range.item(),
            b=self.embedding_range.item(),
        )

    def forward(self, sample, mode='single'):
        """Look up embeddings for a batch of triples and score them.

        In the ``'single'`` mode, ``sample`` is a batch of triples whose
        columns 0, 1, 2 hold head, relation and tail ids.

        In the ``'head-batch'`` or ``'tail-batch'`` mode, ``sample`` consists
        of two parts ``(positive, negative)``. The first part is usually the
        positive triples (columns 0, 1, 2 as above). The second part is a
        ``(batch, negative_sample_size)`` tensor of entity ids that replace
        the head (``'head-batch'``) or the tail (``'tail-batch'``). Negative
        and positive samples share the other two elements of the triple
        ((relation, tail) or (head, relation)), so those are looked up once
        per row with a singleton dimension 1.

        Returns:
            Whatever ``self.score`` returns. ``train_step``/``test_step``
            treat it as ``(batch, negative_sample_size)``; the exact shape is
            defined by the subclass.

        Raises:
            ValueError: if ``mode`` is not one of the three supported modes.
        """
        if mode == 'single':
            head = torch.index_select(self.entity_embedding, dim=0, index=sample[:, 0]).unsqueeze(1)
            relation = torch.index_select(self.relation_embedding, dim=0, index=sample[:, 1]).unsqueeze(1)
            tail = torch.index_select(self.entity_embedding, dim=0, index=sample[:, 2]).unsqueeze(1)

        elif mode == 'head-batch':
            tail_part, head_part = sample
            batch_size, negative_sample_size = head_part.size(0), head_part.size(1)
            head = torch.index_select(
                self.entity_embedding, dim=0, index=head_part.view(-1)
            ).view(batch_size, negative_sample_size, -1)
            relation = torch.index_select(self.relation_embedding, dim=0, index=tail_part[:, 1]).unsqueeze(1)
            tail = torch.index_select(self.entity_embedding, dim=0, index=tail_part[:, 2]).unsqueeze(1)

        elif mode == 'tail-batch':
            head_part, tail_part = sample
            batch_size, negative_sample_size = tail_part.size(0), tail_part.size(1)
            head = torch.index_select(self.entity_embedding, dim=0, index=head_part[:, 0]).unsqueeze(1)
            relation = torch.index_select(self.relation_embedding, dim=0, index=head_part[:, 1]).unsqueeze(1)
            tail = torch.index_select(
                self.entity_embedding, dim=0, index=tail_part.view(-1)
            ).view(batch_size, negative_sample_size, -1)

        else:
            raise ValueError('mode %s not supported' % mode)

        return self.score(head, relation, tail, mode)

    def score(self, head, relation, tail, mode):
        """Score embedded triples; must be implemented by subclasses.

        Receives the ``head``/``relation``/``tail`` embeddings built by
        ``forward`` and the same ``mode`` string.
        """
        raise NotImplementedError

    @staticmethod
    def train_step(model, optimizer, train_iterator, args):
        """Run one optimization step on the next batch and return the losses.

        ``train_iterator`` must yield ``(positive_sample, negative_sample,
        subsampling_weight, mode)``. The loss is the mean of the positive
        term ``-log sigmoid(score)`` and the negative term
        ``-log sigmoid(-score)``. The per-sample terms are either averaged
        uniformly or weighted by ``subsampling_weight``. When regularization
        is enabled, an L3 penalty on both embedding tables is added.

        Reads from ``args``: ``cuda``, ``negative_adversarial_sampling``,
        ``adversarial_temperature`` (only when adversarial sampling is on),
        ``uni_weight`` and ``regularization``.

        Returns:
            dict of Python floats with keys ``positive_sample_loss``,
            ``negative_sample_loss``, ``loss`` and, when regularization is
            non-zero, ``regularization``.
        """
        model.train()
        optimizer.zero_grad()

        positive_sample, negative_sample, subsampling_weight, mode = next(train_iterator)
        if args.cuda:
            positive_sample = positive_sample.cuda()
            negative_sample = negative_sample.cuda()
            subsampling_weight = subsampling_weight.cuda()

        negative_score = model((positive_sample, negative_sample), mode=mode)
        if args.negative_adversarial_sampling:
            # In self-adversarial sampling, we do not apply back-propagation on the
            # sampling weight, hence the detach() on the softmax weights.
            negative_score = (
                F.softmax(negative_score * args.adversarial_temperature, dim=1).detach()
                * F.logsigmoid(-negative_score)
            ).sum(dim=1)
        else:
            negative_score = F.logsigmoid(-negative_score).mean(dim=1)

        positive_score = model(positive_sample)
        positive_score = F.logsigmoid(positive_score).squeeze(dim=1)

        if args.uni_weight:
            positive_sample_loss = -positive_score.mean()
            negative_sample_loss = -negative_score.mean()
        else:
            weight_sum = subsampling_weight.sum()
            positive_sample_loss = -(subsampling_weight * positive_score).sum() / weight_sum
            negative_sample_loss = -(subsampling_weight * negative_score).sum() / weight_sum

        loss = (positive_sample_loss + negative_sample_loss) / 2

        if args.regularization != 0.0:
            # L3 regularization (used for ComplEx and DistMult): sum of cubed
            # 3-norms of the two embedding tables.
            regularization = args.regularization * (
                model.entity_embedding.norm(p=3) ** 3
                + model.relation_embedding.norm(p=3) ** 3
            )
            loss = loss + regularization
            regularization_log = {'regularization': regularization.item()}
        else:
            regularization_log = {}

        loss.backward()
        optimizer.step()

        return {
            **regularization_log,
            'positive_sample_loss': positive_sample_loss.item(),
            'negative_sample_loss': negative_sample_loss.item(),
            'loss': loss.item(),
        }

    @staticmethod
    def test_step(model, test_triples, all_true_triples, args):
        """Evaluate the model on test or valid triples with filtered ranking metrics.

        Each test triple is ranked twice, once with its head corrupted and
        once with its tail corrupted. Candidates come from ``TestDataset``,
        and its ``filter_bias`` is added to the scores so that other known
        true triples are filtered out. All candidates are sorted explicitly,
        which avoids test exposure bias.

        Reads from ``args``: ``nentity``, ``nrelation``, ``test_batch_size``,
        ``cuda`` and ``test_log_steps``.

        Returns:
            dict mapping ``'MRR'``, ``'MR'``, ``'HITS@1'``, ``'HITS@3'`` and
            ``'HITS@10'`` to their averages over all computed ranks.

        Raises:
            ValueError: if no test triples were evaluated (empty input).
        """
        model.eval()

        # Prepare one dataloader per corruption mode (head first, then tail).
        test_dataset_list = [
            DataLoader(
                TestDataset(test_triples, all_true_triples, args.nentity, args.nrelation, corruption_mode),
                batch_size=args.test_batch_size,
                collate_fn=TestDataset.collate_fn,
            )
            for corruption_mode in ('head-batch', 'tail-batch')
        ]

        logs = []
        step = 0
        total_steps = sum(len(dataset) for dataset in test_dataset_list)

        with torch.no_grad():
            for test_dataset in test_dataset_list:
                for positive_sample, negative_sample, filter_bias, mode in test_dataset:
                    if args.cuda:
                        positive_sample = positive_sample.cuda()
                        negative_sample = negative_sample.cuda()
                        filter_bias = filter_bias.cuda()

                    score = model((positive_sample, negative_sample), mode)
                    score += filter_bias

                    # Explicitly sort all the entities to ensure that there is no test exposure bias.
                    argsort = torch.argsort(score, dim=1, descending=True)

                    if mode == 'head-batch':
                        positive_arg = positive_sample[:, 0]
                    elif mode == 'tail-batch':
                        positive_arg = positive_sample[:, 2]
                    else:
                        raise ValueError('mode %s not supported' % mode)

                    for sorted_ids, true_id in zip(argsort, positive_arg):
                        # argsort holds entity ids ordered by score, not ranks; the
                        # position of the true entity in that row is its 0-based rank.
                        position = (sorted_ids == true_id).nonzero()
                        assert position.size(0) == 1
                        ranking = 1 + position.item()
                        logs.append({
                            'MRR': 1.0 / ranking,
                            'MR': float(ranking),
                            'HITS@1': 1.0 if ranking <= 1 else 0.0,
                            'HITS@3': 1.0 if ranking <= 3 else 0.0,
                            'HITS@10': 1.0 if ranking <= 10 else 0.0,
                        })

                    if step % args.test_log_steps == 0:
                        logging.info('Evaluating the model... (%d/%d)', step, total_steps)
                    step += 1

        if not logs:
            raise ValueError('test_step: no test triples were evaluated')

        return {
            metric: sum(log[metric] for log in logs) / len(logs)
            for metric in logs[0]
        }