from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import average_precision_score
from torch.utils.data import DataLoader
from dataloader import TestDataset
from apex import amp
from util import myprint
from util import time_print
class KGEModel(nn.Module):
    """Knowledge-graph embedding model that scores (head, relation, tail) triples.

    Entity and relation embeddings are dense tables initialised uniformly in
    ``[-embedding_range, embedding_range]`` with
    ``embedding_range = (gamma + epsilon) / hidden_dim``.  The scoring
    function is chosen by ``model_name`` (TransE, DistMult, ComplEx, RotatE or
    pRotatE).  Training and evaluation in this module treat a higher score as
    a more plausible triple.
    """

    # Full-precision pi used by the rotation models to map embedding values
    # to phases.  pRotatE previously used a mistyped literal (3.14159262...),
    # which mis-scaled every phase.
    _PI = 3.14159265358979323846
    # Each name is also the name of the scoring method implementing it.
    _SUPPORTED_MODELS = ('TransE', 'DistMult', 'ComplEx', 'RotatE', 'pRotatE')

    def __init__(self, model_name, nentity, nrelation, hidden_dim, gamma,
                 double_entity_embedding=False, double_relation_embedding=False):
        """Allocate and initialise the embedding tables.

        Args:
            model_name: scoring function, one of ``_SUPPORTED_MODELS``.
            nentity: number of rows of the entity table.
            nrelation: number of rows of the relation table.
            hidden_dim: base embedding width.
            gamma: score offset used by TransE/RotatE/pRotatE; also sets the
                initialisation range.
            double_entity_embedding: give entities ``2 * hidden_dim`` columns
                (real/imaginary halves); required by RotatE and ComplEx.
            double_relation_embedding: give relations ``2 * hidden_dim``
                columns; required by ComplEx, not allowed for RotatE.

        Raises:
            ValueError: unsupported ``model_name`` or an embedding-width
                combination the chosen model cannot use.
        """
        super(KGEModel, self).__init__()
        # Validate before allocating the (possibly large) embedding tables.
        if model_name not in self._SUPPORTED_MODELS:
            raise ValueError('model %s not supported' % model_name)
        if model_name == 'RotatE' and (not double_entity_embedding or double_relation_embedding):
            raise ValueError('RotatE should use --double_entity_embedding')
        if model_name == 'ComplEx' and (not double_entity_embedding or not double_relation_embedding):
            raise ValueError('ComplEx should use --double_entity_embedding and --double_relation_embedding')

        self.model_name = model_name
        self.nentity = nentity
        self.nrelation = nrelation
        self.hidden_dim = hidden_dim
        self.epsilon = 2.0
        # Frozen Parameters so these constants follow the module across
        # devices / dtypes together with the embeddings.
        self.gamma = nn.Parameter(
            torch.Tensor([gamma]),
            requires_grad=False
        )
        self.embedding_range = nn.Parameter(
            torch.Tensor([(self.gamma.item() + self.epsilon) / hidden_dim]),
            requires_grad=False
        )
        self.entity_dim = hidden_dim * 2 if double_entity_embedding else hidden_dim
        self.relation_dim = hidden_dim * 2 if double_relation_embedding else hidden_dim

        bound = self.embedding_range.item()
        self.entity_embedding = nn.Parameter(torch.zeros(nentity, self.entity_dim))
        nn.init.uniform_(tensor=self.entity_embedding, a=-bound, b=bound)
        self.relation_embedding = nn.Parameter(torch.zeros(nrelation, self.relation_dim))
        nn.init.uniform_(tensor=self.relation_embedding, a=-bound, b=bound)

        if model_name == 'pRotatE':
            # Learnable scale applied to the pRotatE distance.
            self.modulus = nn.Parameter(torch.Tensor([[0.5 * bound]]))

    def forward(self, sample, mode='single'):
        """Score a batch of triples.

        Args:
            sample: in ``'single'`` mode, an integer tensor whose rows are
                (head, relation, tail) indices.  In ``'head-batch'`` or
                ``'tail-batch'`` mode, a pair ``(positive, candidates)``:
                ``positive`` holds index rows as in single mode (usually the
                positive triples) and ``candidates`` is a 2-D tensor of entity
                indices that replace the head (head-batch) or the tail
                (tail-batch) of each positive triple.  Only the corrupted
                entity is supplied per candidate because the other two
                elements are shared with the positive triple.
            mode: ``'single'``, ``'head-batch'`` or ``'tail-batch'``.

        Returns:
            Scores of shape ``(batch, 1)`` in single mode and
            ``(batch, n_candidates)`` in the batch modes.

        Raises:
            ValueError: unknown ``mode`` or ``model_name``.
        """
        if mode == 'single':
            # One full triple per row -> (batch, 1, dim) per component.
            head = torch.index_select(self.entity_embedding, dim=0, index=sample[:, 0]).unsqueeze(1)
            relation = torch.index_select(self.relation_embedding, dim=0, index=sample[:, 1]).unsqueeze(1)
            tail = torch.index_select(self.entity_embedding, dim=0, index=sample[:, 2]).unsqueeze(1)
        elif mode == 'head-batch':
            tail_part, head_part = sample
            batch_size, negative_sample_size = head_part.size(0), head_part.size(1)
            # Candidate heads: (batch, n_candidates, dim); relation and tail
            # stay (batch, 1, dim) and broadcast over the candidates.
            head = torch.index_select(
                self.entity_embedding,
                dim=0,
                index=head_part.view(-1)
            ).view(batch_size, negative_sample_size, -1)
            relation = torch.index_select(self.relation_embedding, dim=0, index=tail_part[:, 1]).unsqueeze(1)
            tail = torch.index_select(self.entity_embedding, dim=0, index=tail_part[:, 2]).unsqueeze(1)
        elif mode == 'tail-batch':
            head_part, tail_part = sample
            batch_size, negative_sample_size = tail_part.size(0), tail_part.size(1)
            head = torch.index_select(self.entity_embedding, dim=0, index=head_part[:, 0]).unsqueeze(1)
            relation = torch.index_select(self.relation_embedding, dim=0, index=head_part[:, 1]).unsqueeze(1)
            # Candidate tails: (batch, n_candidates, dim).
            tail = torch.index_select(
                self.entity_embedding,
                dim=0,
                index=tail_part.view(-1)
            ).view(batch_size, negative_sample_size, -1)
        else:
            raise ValueError('mode %s not supported' % mode)

        if self.model_name not in self._SUPPORTED_MODELS:
            raise ValueError('model %s not supported' % self.model_name)
        # Scoring methods are named exactly after the model they implement.
        return getattr(self, self.model_name)(head, relation, tail, mode)

    def TransE(self, head, relation, tail, mode):
        """Return ``gamma - ||head + relation - tail||_1`` over the last dim."""
        # In head-batch mode combine the two (batch, 1, dim) operands first so
        # only one operation runs over the candidate axis.
        if mode == 'head-batch':
            score = head + (relation - tail)
        else:
            score = (head + relation) - tail
        return self.gamma.item() - torch.norm(score, p=1, dim=2)

    def DistMult(self, head, relation, tail, mode):
        """Return the tri-linear product ``sum(head * relation * tail)``."""
        if mode == 'head-batch':
            score = head * (relation * tail)
        else:
            score = (head * relation) * tail
        return score.sum(dim=2)

    def ComplEx(self, head, relation, tail, mode):
        """Return ``Re(<head, relation, conj(tail)>)``.

        The first half of the last dim holds real parts, the second half
        imaginary parts.
        """
        re_head, im_head = torch.chunk(head, 2, dim=2)
        re_relation, im_relation = torch.chunk(relation, 2, dim=2)
        re_tail, im_tail = torch.chunk(tail, 2, dim=2)
        if mode == 'head-batch':
            # relation * conj(tail), then dot with head.
            re_score = re_relation * re_tail + im_relation * im_tail
            im_score = re_relation * im_tail - im_relation * re_tail
            score = re_head * re_score + im_head * im_score
        else:
            # head * relation, then dot with conj(tail).
            re_score = re_head * re_relation - im_head * im_relation
            im_score = re_head * im_relation + im_head * re_relation
            score = re_score * re_tail + im_score * im_tail
        return score.sum(dim=2)

    def RotatE(self, head, relation, tail, mode):
        """Return ``gamma - sum_k |head_k * exp(i*phase_k) - tail_k|``.

        Relation values are mapped to phases by dividing by
        ``embedding_range / pi``; the initial embedding range maps to
        ``[-pi, pi]``.
        """
        re_head, im_head = torch.chunk(head, 2, dim=2)
        re_tail, im_tail = torch.chunk(tail, 2, dim=2)
        phase_relation = relation / (self.embedding_range.item() / self._PI)
        re_relation = torch.cos(phase_relation)
        im_relation = torch.sin(phase_relation)
        if mode == 'head-batch':
            # Equivalent form |conj(r) * tail - head| (|r| == 1), so the
            # rotation is applied to the (batch, 1, dim) side only.
            re_score = re_relation * re_tail + im_relation * im_tail
            im_score = re_relation * im_tail - im_relation * re_tail
            re_score = re_score - re_head
            im_score = im_score - im_head
        else:
            re_score = re_head * re_relation - im_head * im_relation
            im_score = re_head * im_relation + im_head * re_relation
            re_score = re_score - re_tail
            im_score = im_score - im_tail
        # Per-dimension complex modulus, summed over the embedding dim.
        score = torch.stack([re_score, im_score], dim=0).norm(dim=0)
        return self.gamma.item() - score.sum(dim=2)

    def pRotatE(self, head, relation, tail, mode):
        """Return ``gamma - modulus * sum |sin(phase_h + phase_r - phase_t)|``."""
        phase_scale = self.embedding_range.item() / self._PI
        phase_head = head / phase_scale
        phase_relation = relation / phase_scale
        phase_tail = tail / phase_scale
        if mode == 'head-batch':
            score = phase_head + (phase_relation - phase_tail)
        else:
            score = (phase_head + phase_relation) - phase_tail
        score = torch.abs(torch.sin(score))
        return self.gamma.item() - score.sum(dim=2) * self.modulus
def _train_iteration(model, optimizer, train_iterator, args):
    """Run one forward/backward/optimizer step on the next training batch.

    Returns:
        ``(positive_sample_loss, negative_sample_loss, loss, regularization_log)``
        where the first three are scalar tensors and ``regularization_log`` is
        ``{'regularization': float}`` when regularization is enabled, else ``{}``.
    """
    optimizer.zero_grad()
    positive_sample, negative_sample, subsampling_weight, mode = next(train_iterator)
    if args.cuda:
        positive_sample = positive_sample.cuda()
        negative_sample = negative_sample.cuda()
        subsampling_weight = subsampling_weight.cuda()
    if args.npu:
        positive_sample = positive_sample.npu()
        negative_sample = negative_sample.npu()
        subsampling_weight = subsampling_weight.npu()

    negative_score = model((positive_sample, negative_sample), mode=mode)
    if args.negative_adversarial_sampling:
        # Weight each negative's log-sigmoid term by a temperature-scaled
        # softmax over the negatives; detach() keeps the weights out of the
        # gradient.
        negative_score = (F.softmax(negative_score * args.adversarial_temperature, dim=1).detach()
                          * F.logsigmoid(-negative_score)).sum(dim=1)
    else:
        negative_score = F.logsigmoid(-negative_score).mean(dim=1)

    positive_score = model(positive_sample)
    positive_score = F.logsigmoid(positive_score).squeeze(dim=1)

    if args.uni_weight:
        positive_sample_loss = -positive_score.mean()
        negative_sample_loss = -negative_score.mean()
    else:
        positive_sample_loss = -(subsampling_weight * positive_score).sum() / subsampling_weight.sum()
        negative_sample_loss = -(subsampling_weight * negative_score).sum() / subsampling_weight.sum()
    loss = (positive_sample_loss + negative_sample_loss) / 2

    if args.regularization != 0.0:
        # Cubed L3 norms of both embedding tables (a second .norm() on the
        # resulting 0-d tensor would be a no-op, so it is not applied).
        regularization = args.regularization * (
            model.entity_embedding.norm(p=3) ** 3 +
            model.relation_embedding.norm(p=3) ** 3
        )
        loss = loss + regularization
        regularization_log = {'regularization': regularization.item()}
    else:
        regularization_log = {}

    if args.apex:
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()
    optimizer.step()
    return positive_sample_loss, negative_sample_loss, loss, regularization_log


def train_step(model, optimizer, train_iterator, time_res, args):
    """Run a single training step (back-propagation included) and return its losses.

    When both ``args.first`` and ``args.prof`` are set, the step runs under the
    autograd profiler and a Chrome trace is written to
    ``output_<node_id>.prof`` inside ``args.save_path`` (falling back to
    ``args.init_checkpoint`` when ``save_path`` is falsy).

    The step's wall-clock time and throughput
    (``batch_size * world_size`` samples per second) are reported with
    ``time_print``, and the time is passed to ``time_res.update``.

    Returns:
        dict with ``positive_sample_loss``, ``negative_sample_loss`` and
        ``loss`` (floats), plus ``regularization`` when it is enabled.
    """
    start = time.time()
    model.train()
    if args.first and args.prof:
        with torch.autograd.profiler.profile(use_cuda=True) as prof:
            positive_sample_loss, negative_sample_loss, loss, regularization_log = \
                _train_iteration(model, optimizer, train_iterator, args)
        # The trace can only be exported after the profiler context has exited.
        log_file = os.path.join(args.save_path or args.init_checkpoint,
                                "output_{}.prof".format(args.node_id))
        print(log_file)
        prof.export_chrome_trace(log_file)
    else:
        positive_sample_loss, negative_sample_loss, loss, regularization_log = \
            _train_iteration(model, optimizer, train_iterator, args)
    end = time.time()

    log = {
        **regularization_log,
        'positive_sample_loss': positive_sample_loss.item(),
        'negative_sample_loss': negative_sample_loss.item(),
        'loss': loss.item(),
    }
    run_time = end - start
    fps = args.batch_size * args.world_size / run_time
    time_print('{:.6f},{:.6f}'.format(run_time, fps), args)
    time_res.update(run_time)
    return log
def test_step(model, test_triples, all_true_triples, args):
    """Evaluate ``model`` on test or validation triples.

    With ``args.countries`` set, each test triple's tail is replaced by every
    candidate in ``args.regions``, and ``{'auc_pr': ...}`` is returned, where
    the value is ``average_precision_score`` over those candidates.

    Otherwise each triple is ranked against the candidate entities produced
    by ``TestDataset`` in both head-batch and tail-batch mode. ``filter_bias``
    is added to the scores first; presumably it masks other known-true
    triples (see ``TestDataset``). The mean of MRR, MR, HITS@1, HITS@3 and
    HITS@10 over all rankings is returned.

    Raises:
        ValueError: if the ranking branch evaluates no triples.
    """
    model.eval()

    if args.countries:
        sample = []
        y_true = []
        for head, relation, tail in test_triples:
            for candidate_region in args.regions:
                y_true.append(1 if candidate_region == tail else 0)
                sample.append((head, relation, candidate_region))
        sample = torch.LongTensor(sample)
        if args.cuda:
            sample = sample.cuda()
        if args.npu:
            sample = sample.npu()
        with torch.no_grad():
            y_score = model(sample).squeeze(1).cpu().numpy()
        y_true = np.array(y_true)
        auc_pr = average_precision_score(y_true, y_score)
        return {'auc_pr': auc_pr}

    nw = max(1, args.cpu_num // 2)
    # The metrics are averages, so batch order is irrelevant. Not shuffling
    # keeps evaluation deterministic, including the float summation order.
    test_dataset_list = [
        DataLoader(
            TestDataset(test_triples, all_true_triples, args.nentity, args.nrelation, batch_mode),
            batch_size=args.test_batch_size,
            shuffle=False,
            num_workers=nw,
            collate_fn=TestDataset.collate_fn
        )
        for batch_mode in ('head-batch', 'tail-batch')
    ]

    rankings = []
    step = 0
    total_steps = sum(len(dataset) for dataset in test_dataset_list)
    with torch.no_grad():
        for test_dataset in test_dataset_list:
            for positive_sample, negative_sample, filter_bias, mode in test_dataset:
                if args.cuda and args.test_cuda:
                    positive_sample = positive_sample.cuda()
                    negative_sample = negative_sample.cuda()
                    filter_bias = filter_bias.cuda()
                if args.npu:
                    positive_sample = positive_sample.npu()
                    negative_sample = negative_sample.npu()
                    filter_bias = filter_bias.npu()
                batch_size = positive_sample.size(0)
                score = model((positive_sample, negative_sample), mode)
                score += filter_bias
                argsort = torch.argsort(score, dim=1, descending=True)
                if mode == 'head-batch':
                    positive_arg = positive_sample[:, 0]
                elif mode == 'tail-batch':
                    positive_arg = positive_sample[:, 2]
                else:
                    raise ValueError('mode %s not supported' % mode)
                # Find the true entity's position in every row at once. Each
                # argsort row is a permutation, so a valid target matches
                # exactly once per row. nonzero() lists matches in row order.
                # This needs one host transfer per batch instead of one
                # .item() sync per triple.
                match_pos = (argsort == positive_arg.unsqueeze(1)).nonzero()
                assert match_pos.size(0) == batch_size
                rankings.extend((match_pos[:, 1] + 1).tolist())
                if step % args.test_log_steps == 0:
                    logging.info('Evaluating the model... (%d/%d)' % (step, total_steps))
                    myprint('Evaluating the model... (%d/%d)' % (step, total_steps), args)
                step += 1

    if not rankings:
        raise ValueError('no triples were evaluated')
    count = len(rankings)
    metrics = {
        'MRR': sum(1.0 / ranking for ranking in rankings) / count,
        'MR': sum(float(ranking) for ranking in rankings) / count,
        'HITS@1': sum(1.0 if ranking <= 1 else 0.0 for ranking in rankings) / count,
        'HITS@3': sum(1.0 if ranking <= 3 else 0.0 for ranking in rankings) / count,
        'HITS@10': sum(1.0 if ranking <= 10 else 0.0 for ranking in rankings) / count,
    }
    return metrics