from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sklearn
import argparse
import json
import os
import random
import numpy as np
import torch
import torch.nn.parallel
import torch.distributed as dist
from torch.utils.data import DataLoader
from model import KGEModel
import model
from dataloader import TrainDataset
from dataloader import BidirectionalOneShotIterator
from apex import amp
from util import myprint
from util import AverageMeter
def parse_args(args=None):
parser = argparse.ArgumentParser(
description='Training and Testing Knowledge Graph Embedding Models',
usage='train.py [<args>] [-h | --help]'
)
parser.add_argument('--cuda', action='store_true', help='use GPU')
parser.add_argument('--npu', action='store_true', help='use NPU')
parser.add_argument('--do_train', action='store_true')
parser.add_argument('--do_valid', action='store_true')
parser.add_argument('--do_test', action='store_true')
parser.add_argument('--evaluate_train', action='store_true', help='Evaluate on training data')
parser.add_argument('--countries', action='store_true', help='Use Countries S1/S2/S3 datasets')
parser.add_argument('--regions', type=int, nargs='+', default=None,
help='Region Id for Countries S1/S2/S3 datasets, DO NOT MANUALLY SET')
parser.add_argument('--data_path', type=str, default=None)
parser.add_argument('--model', default='TransE', type=str)
parser.add_argument('-de', '--double_entity_embedding', action='store_true')
parser.add_argument('-dr', '--double_relation_embedding', action='store_true')
parser.add_argument('-n', '--negative_sample_size', default=128, type=int)
parser.add_argument('-d', '--hidden_dim', default=500, type=int)
parser.add_argument('-g', '--gamma', default=12.0, type=float)
parser.add_argument('-adv', '--negative_adversarial_sampling', action='store_true')
parser.add_argument('-a', '--adversarial_temperature', default=1.0, type=float)
parser.add_argument('-b', '--batch_size', default=1024, type=int)
parser.add_argument('-r', '--regularization', default=0.0, type=float)
parser.add_argument('--test_batch_size', default=4, type=int, help='valid/test batch size')
parser.add_argument('--uni_weight', action='store_true',
help='Otherwise use subsampling weighting like in word2vec')
parser.add_argument('-lr', '--learning_rate', default=0.0001, type=float)
parser.add_argument('-cpu', '--cpu_num', default=10, type=int)
parser.add_argument('-init', '--init_checkpoint', default=None, type=str)
parser.add_argument('-save', '--save_path', default=None, type=str)
parser.add_argument('--max_steps', default=100000, type=int)
parser.add_argument('--warm_up_steps', default=None, type=int)
parser.add_argument('--save_checkpoint_steps', default=10000, type=int)
parser.add_argument('--valid_steps', default=10000, type=int)
parser.add_argument('--log_steps', default=100, type=int, help='train log every xx steps')
parser.add_argument('--test_log_steps', default=1000, type=int, help='valid/test log every xx steps')
parser.add_argument('--nentity', type=int, default=0, help='DO NOT MANUALLY SET')
parser.add_argument('--nrelation', type=int, default=0, help='DO NOT MANUALLY SET')
parser.add_argument('--backend', default='nccl', type=str,
help='Name of the backend to use')
parser.add_argument('--world_size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--dist_url', default='tcp://127.0.0.1:50000', type=str,
help='url used to set up distributed training')
parser.add_argument('--port', default='50000', type=str,
help='port used to set up distributed training')
parser.add_argument('--distributed', action='store_true', help='Use distributed', default=False)
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--test_cuda', action='store_true', help='Use cuda to test', default=False)
parser.add_argument('--apex', action='store_true', help='Use apex',default=False)
parser.add_argument('--apex_level', default='O2', type=str)
parser.add_argument('--loss_scale', default='128.0', type=str)
parser.add_argument('--node_id', default=0, type=int)
parser.add_argument('--prof', action='store_true', help='Use prof', default=False)
parser.add_argument('--fused', action='store_true', help='Use fuse', default=False)
return parser.parse_args(args)
def override_config(args):
'''
Override model and data configuration
'''
with open(os.path.join(args.init_checkpoint, 'config_0.json'), 'r') as fjson:
argparse_dict = json.load(fjson)
args.countries = argparse_dict['countries']
if args.data_path is None:
args.data_path = argparse_dict['data_path']
args.model = argparse_dict['model']
args.double_entity_embedding = argparse_dict['double_entity_embedding']
args.double_relation_embedding = argparse_dict['double_relation_embedding']
args.hidden_dim = argparse_dict['hidden_dim']
args.test_batch_size = argparse_dict['test_batch_size']
def save_model(model, optimizer, save_variable_list, args):
'''
Save the parameters of the model and the optimizer,
as well as some other variables such as step and learning_rate
'''
if args.master_node:
argparse_dict = vars(args)
with open(os.path.join(args.save_path, 'config_{}.json'.format(args.node_id)), 'w') as fjson:
json.dump(argparse_dict, fjson)
torch.save({
**save_variable_list,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict() if optimizer != None else None},
os.path.join(args.save_path, 'checkpoint_{}'.format(args.node_id))
)
if not args.distributed:
entity_embedding = model.entity_embedding.detach().cpu().numpy()
np.save(
os.path.join(args.save_path, 'entity_embedding'),
entity_embedding
)
relation_embedding = model.relation_embedding.detach().cpu().numpy()
np.save(
os.path.join(args.save_path, 'relation_embedding'),
relation_embedding
)
def read_triple(file_path, entity2id, relation2id):
'''
Read triples and map them into ids.
'''
triples = []
with open(file_path) as fin:
for line in fin:
h, r, t = line.strip().split('\t')
triples.append((entity2id[h], relation2id[r], entity2id[t]))
return triples
def log_metrics(mode, step, metrics,args):
'''
Print the evaluation logs
'''
for metric in metrics:
myprint('%s %s at step %d: %f' % (mode, metric, step, metrics[metric]),args)
def main(proc, nprocs,args):
if args.fused:
from apex.optimizers import NpuFusedSGD
from apex.optimizers import NpuFusedAdam
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = args.port
optimizer = None
time_res = AverageMeter(name='time', fmt=':.6f', start_count_index=10)
if args.distributed:
rank = proc
args.local_rank = rank
dist_url = args.dist_url
if args.cuda:
torch.cuda.set_device(rank)
dist_init_method = 'env://'
print(rank)
print('port is:',args.port)
if args.npu:
loc = 'npu:{}'.format(rank)
device = torch.device(loc)
torch.npu.set_device(loc)
dist.init_process_group(backend=args.backend, world_size=args.world_size, rank=rank)
else:
dist.init_process_group(backend=args.backend, init_method=dist_init_method, world_size=args.world_size, rank=rank)
args.master_node = (not args.distributed) or (torch.distributed.is_initialized and torch.distributed.get_rank() == 0)
if args.master_node and args.save_path and not os.path.exists(args.save_path):
os.makedirs(args.save_path)
import time
time.sleep(1)
myprint('is master? '+str(args.master_node),args)
args.node_id = torch.distributed.get_rank() if args.distributed and torch.distributed.is_initialized else 0
print('node id is ',args.node_id)
myprint('local_rank:{}'.format(args.local_rank),args)
if (not args.do_train) and (not args.do_valid) and (not args.do_test):
raise ValueError('one of train/val/test mode must be choosed.')
if args.init_checkpoint:
override_config(args)
elif args.data_path is None:
raise ValueError('one of init_checkpoint/data_path must be choosed.')
if args.do_train and args.save_path is None:
raise ValueError('Where do you want to save your trained model?')
with open(os.path.join(args.data_path, 'entities.dict')) as fin:
entity2id = dict()
for line in fin:
eid, entity = line.strip().split('\t')
entity2id[entity] = int(eid)
with open(os.path.join(args.data_path, 'relations.dict')) as fin:
relation2id = dict()
for line in fin:
rid, relation = line.strip().split('\t')
relation2id[relation] = int(rid)
if args.countries:
regions = list()
with open(os.path.join(args.data_path, 'regions.list')) as fin:
for line in fin:
region = line.strip()
regions.append(entity2id[region])
args.regions = regions
nentity = len(entity2id)
nrelation = len(relation2id)
args.nentity = nentity
args.nrelation = nrelation
myprint('Model: %s' % args.model,args)
myprint('Data Path: %s' % args.data_path,args)
myprint('#entity: %d' % nentity,args)
myprint('#relation: %d' % nrelation,args)
train_triples = read_triple(os.path.join(args.data_path, 'train.txt'), entity2id, relation2id)
myprint('#train: %d' % len(train_triples),args)
valid_triples = read_triple(os.path.join(args.data_path, 'valid.txt'), entity2id, relation2id)
myprint('#valid: %d' % len(valid_triples),args)
test_triples = read_triple(os.path.join(args.data_path, 'test.txt'), entity2id, relation2id)
myprint('#test: %d' % len(test_triples),args)
all_true_triples = train_triples + valid_triples + test_triples
kge_model = KGEModel(
model_name=args.model,
nentity=nentity,
nrelation=nrelation,
hidden_dim=args.hidden_dim,
gamma=args.gamma,
double_entity_embedding=args.double_entity_embedding,
double_relation_embedding=args.double_relation_embedding
)
myprint('Model Parameter Configuration:',args)
for name, param in kge_model.named_parameters():
myprint('Parameter %s: %s, require_grad = %s' % (name, str(param.size()), str(param.requires_grad)),args)
if args.cuda:
kge_model = kge_model.cuda()
if args.npu:
kge_model = kge_model.npu()
if args.do_train:
train_dataset_head = TrainDataset(train_triples, nentity, nrelation, args.negative_sample_size, 'head-batch')
train_dataset_tail = TrainDataset(train_triples, nentity, nrelation, args.negative_sample_size, 'tail-batch')
nw = max(1, args.cpu_num//2)
if args.distributed:
train_sampler_head = torch.utils.data.distributed.DistributedSampler(train_dataset_head)
train_sampler_tail = torch.utils.data.distributed.DistributedSampler(train_dataset_tail)
train_dataloader_head = DataLoader(
train_dataset_head,
batch_size=args.batch_size,
shuffle=(train_sampler_head is None),
sampler=train_sampler_head,
num_workers=nw,
collate_fn=TrainDataset.collate_fn
)
train_dataloader_tail = DataLoader(
train_dataset_head,
batch_size=args.batch_size,
shuffle=(train_sampler_tail is None),
sampler=train_sampler_tail,
num_workers=nw,
collate_fn=TrainDataset.collate_fn
)
else:
train_dataloader_head = DataLoader(
train_dataset_head,
batch_size=args.batch_size,
shuffle=True,
num_workers=nw,
collate_fn=TrainDataset.collate_fn
)
train_dataloader_tail = DataLoader(
train_dataset_tail,
batch_size=args.batch_size,
shuffle=True,
num_workers=nw,
collate_fn=TrainDataset.collate_fn
)
train_iterator = BidirectionalOneShotIterator(train_dataloader_head, train_dataloader_tail)
current_learning_rate = args.learning_rate
if args.fused:
optimizer = NpuFusedAdam(
filter(lambda p: p.requires_grad, kge_model.parameters()),
lr=current_learning_rate
)
else:
optimizer = torch.optim.Adam(
filter(lambda p: p.requires_grad, kge_model.parameters()),
lr=current_learning_rate
)
if args.warm_up_steps:
warm_up_steps = args.warm_up_steps
else:
warm_up_steps = args.max_steps // 2
if args.apex:
myprint('apex settings: apex_level:{} loss_scale:{}'.format(args.apex_level,args.loss_scale),args)
loss_scale = float(args.loss_scale) if args.loss_scale != 'None' else 'dynamic'
print(loss_scale)
if args.fused and args.apex_level != 'O0':
kge_model, optimizer = amp.initialize(kge_model, optimizer, opt_level=args.apex_level, loss_scale=loss_scale,combine_grad=True)
else:
kge_model, optimizer = amp.initialize(kge_model, optimizer, opt_level=args.apex_level, loss_scale=loss_scale)
if args.distributed:
if args.npu:
kge_model = torch.nn.parallel.DistributedDataParallel(kge_model, device_ids=[args.local_rank],broadcast_buffers=False)
else:
kge_model = torch.nn.parallel.DistributedDataParallel(kge_model, device_ids=[args.local_rank])
if args.init_checkpoint:
myprint('Loading checkpoint %s...' % args.init_checkpoint,args)
kge_model_cpu = KGEModel(
model_name=args.model,
nentity=nentity,
nrelation=nrelation,
hidden_dim=args.hidden_dim,
gamma=args.gamma,
double_entity_embedding=args.double_entity_embedding,
double_relation_embedding=args.double_relation_embedding
)
checkpoint = torch.load(os.path.join(args.save_path, 'checkpoint_0'),map_location=torch.device('cpu'))
init_step = checkpoint['step']
state_dict = checkpoint['model_state_dict']
for key in list(state_dict.keys()):
new_key_l = key.split('.')
if new_key_l[0] == 'module':
new_key = new_key_l[1]
else:
new_key = key
state_dict[new_key] = state_dict.pop(key)
kge_model_cpu.load_state_dict(state_dict)
kge_model = kge_model_cpu
if args.cuda:
kge_model = kge_model.cuda()
if args.test_cuda:
kge_model = kge_model.cuda()
if args.npu:
kge_model = kge_model.npu()
current_learning_rate = checkpoint['current_learning_rate']
warm_up_steps = checkpoint['warm_up_steps']
if args.do_train:
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
myprint('Initializing Model Finished',args)
else:
myprint('Ramdomly Initializing %s Model...' % args.model,args)
init_step = 0
step = init_step
print('train')
myprint('Start Training...',args)
myprint('init_step = %d' % init_step,args)
myprint('batch_size = %d' % args.batch_size,args)
myprint('negative_adversarial_sampling = %d' % args.negative_adversarial_sampling,args)
myprint('hidden_dim = %d' % args.hidden_dim,args)
myprint('gamma = %f' % args.gamma,args)
myprint('negative_adversarial_sampling = %s' % str(args.negative_adversarial_sampling),args)
if args.negative_adversarial_sampling:
myprint('adversarial_temperature = %f' % args.adversarial_temperature,args)
if args.do_train:
myprint('learning_rate = %d' % current_learning_rate,args)
training_logs = []
args.first = False
for step in range(init_step, args.max_steps):
if step == 5:
args.first = True
else:
args.first = False
if args.distributed:
train_sampler_head.set_epoch(step)
train_sampler_tail.set_epoch(step)
log = model.train_step(kge_model, optimizer, train_iterator,time_res, args)
training_logs.append(log)
if step >= warm_up_steps:
save_variable_list = {
'step': step,
'current_learning_rate': current_learning_rate,
'warm_up_steps': warm_up_steps
}
save_model(kge_model, optimizer, save_variable_list, args)
current_learning_rate = current_learning_rate / 10
myprint('Change learning_rate to %f at step %d' % (current_learning_rate, step),args)
if args.fused:
optimizer = NpuFusedAdam(
filter(lambda p: p.requires_grad, kge_model.parameters()),
lr=current_learning_rate
)
else:
optimizer = torch.optim.Adam(
filter(lambda p: p.requires_grad, kge_model.parameters()),
lr=current_learning_rate
)
kge_model = KGEModel(
model_name=args.model,
nentity=nentity,
nrelation=nrelation,
hidden_dim=args.hidden_dim,
gamma=args.gamma,
double_entity_embedding=args.double_entity_embedding,
double_relation_embedding=args.double_relation_embedding
)
checkpoint = torch.load(os.path.join(args.save_path, 'checkpoint_{}'.format(args.node_id)),map_location=torch.device('cpu'))
state_dict = checkpoint['model_state_dict']
for key in list(state_dict.keys()):
new_key_l = key.split('.')
if new_key_l[0] == 'module':
new_key = new_key_l[1]
else:
new_key = key
state_dict[new_key] = state_dict.pop(key)
kge_model.load_state_dict(state_dict)
if args.cuda:
kge_model = kge_model.cuda()
if args.npu:
kge_model = kge_model.npu()
if args.apex:
myprint('apex settings: apex_level:{} loss_scale:{}'.format(args.apex_level, args.loss_scale), args)
loss_scale = float(args.loss_scale) if args.loss_scale != 'None' else 'dynamic'
print(loss_scale)
if args.fused:
kge_model, optimizer = amp.initialize(kge_model, optimizer, opt_level=args.apex_level, loss_scale=loss_scale,combine_grad=True)
else:
kge_model, optimizer = amp.initialize(kge_model, optimizer, opt_level=args.apex_level, loss_scale=loss_scale)
if args.distributed:
if args.npu:
kge_model = torch.nn.parallel.DistributedDataParallel(kge_model, device_ids=[args.local_rank],broadcast_buffers=False)
else:
kge_model = torch.nn.parallel.DistributedDataParallel(kge_model, device_ids=[args.local_rank])
warm_up_steps = warm_up_steps * 3
if step % args.save_checkpoint_steps == 0:
save_variable_list = {
'step': step,
'current_learning_rate': current_learning_rate,
'warm_up_steps': warm_up_steps
}
save_model(kge_model, optimizer, save_variable_list, args)
if step % args.log_steps == 0:
metrics = {}
for metric in training_logs[0].keys():
metrics[metric] = sum([log[metric] for log in training_logs])/len(training_logs)
log_metrics('Training average', step, metrics,args)
training_logs = []
if args.do_valid and step % args.valid_steps == 0:
if not args.test_cuda and not args.npu:
save_variable_list = {
'step': step,
'current_learning_rate': current_learning_rate,
'warm_up_steps': warm_up_steps
}
save_model(kge_model, optimizer, save_variable_list, args)
kge_model_cpu = KGEModel(
model_name=args.model,
nentity=nentity,
nrelation=nrelation,
hidden_dim=args.hidden_dim,
gamma=args.gamma,
double_entity_embedding=args.double_entity_embedding,
double_relation_embedding=args.double_relation_embedding
)
checkpoint = torch.load(os.path.join(args.save_path, 'checkpoint_{}'.format(args.node_id)),map_location=torch.device('cpu'))
state_dict = checkpoint['model_state_dict']
for key in list(state_dict.keys()):
new_key_l = key.split('.')
if new_key_l[0] == 'module':
new_key = new_key_l[1]
else:
new_key = key
state_dict[new_key] = state_dict.pop(key)
kge_model_cpu.load_state_dict(state_dict)
test_model = kge_model_cpu
else:
test_model = kge_model
myprint('Evaluating on Valid Dataset...',args)
metrics = model.test_step(test_model, valid_triples, all_true_triples, args)
log_metrics('Valid', step, metrics,args)
save_variable_list = {
'step': step,
'current_learning_rate': current_learning_rate,
'warm_up_steps': warm_up_steps
}
save_model(kge_model, optimizer, save_variable_list, args)
if args.do_valid:
myprint('Evaluating on Valid Dataset...',args)
if not args.test_cuda and not args.npu:
save_variable_list = {
'step': step,
'current_learning_rate': current_learning_rate,
'warm_up_steps': warm_up_steps
}
save_model(kge_model, optimizer, save_variable_list, args)
kge_model_cpu = KGEModel(
model_name=args.model,
nentity=nentity,
nrelation=nrelation,
hidden_dim=args.hidden_dim,
gamma=args.gamma,
double_entity_embedding=args.double_entity_embedding,
double_relation_embedding=args.double_relation_embedding
)
checkpoint = torch.load(os.path.join(args.save_path, 'checkpoint_{}'.format(args.node_id)),map_location=torch.device('cpu'))
state_dict = checkpoint['model_state_dict']
for key in list(state_dict.keys()):
new_key_l = key.split('.')
if new_key_l[0] == 'module':
new_key = new_key_l[1]
else:
new_key = key
state_dict[new_key] = state_dict.pop(key)
kge_model_cpu.load_state_dict(state_dict)
test_model = kge_model_cpu
else:
test_model = kge_model
metrics = model.test_step(test_model, valid_triples, all_true_triples, args)
log_metrics('Valid', step, metrics,args)
if args.do_test:
myprint('Evaluating on Test Dataset...',args)
if not args.test_cuda and not args.npu:
save_variable_list = {
'step': step,
'current_learning_rate': current_learning_rate,
'warm_up_steps': warm_up_steps
}
save_model(kge_model, optimizer, save_variable_list, args)
kge_model_cpu = KGEModel(
model_name=args.model,
nentity=nentity,
nrelation=nrelation,
hidden_dim=args.hidden_dim,
gamma=args.gamma,
double_entity_embedding=args.double_entity_embedding,
double_relation_embedding=args.double_relation_embedding
)
checkpoint = torch.load(os.path.join(args.save_path, 'checkpoint_{}'.format(args.node_id)),map_location=torch.device('cpu'))
state_dict = checkpoint['model_state_dict']
for key in list(state_dict.keys()):
new_key_l = key.split('.')
if new_key_l[0] == 'module':
new_key = new_key_l[1]
else:
new_key = key
state_dict[new_key] = state_dict.pop(key)
kge_model_cpu.load_state_dict(state_dict)
test_model = kge_model_cpu
else:
test_model = kge_model
metrics = model.test_step(test_model, test_triples, all_true_triples, args)
log_metrics('Test', step, metrics,args)
if args.evaluate_train:
save_variable_list = {
'step': step,
'current_learning_rate': current_learning_rate,
'warm_up_steps': warm_up_steps
}
save_model(kge_model, optimizer, save_variable_list, args)
kge_model_cpu = KGEModel(
model_name=args.model,
nentity=nentity,
nrelation=nrelation,
hidden_dim=args.hidden_dim,
gamma=args.gamma,
double_entity_embedding=args.double_entity_embedding,
double_relation_embedding=args.double_relation_embedding
)
checkpoint = torch.load(os.path.join(args.save_path, 'checkpoint_{}'.format(args.node_id)),map_location=torch.device('cpu'))
state_dict = checkpoint['model_state_dict']
for key in list(state_dict.keys()):
new_key_l = key.split('.')
if new_key_l[0] == 'module':
new_key = new_key_l[1]
else:
new_key = key
state_dict[new_key] = state_dict.pop(key)
kge_model_cpu.load_state_dict(state_dict)
myprint('Evaluating on Training Dataset...',args)
metrics = model.test_step(kge_model_cpu, train_triples, all_true_triples, args)
log_metrics('Test', step, metrics,args)
if args.do_train and time_res.avg>0:
myprint('time: {:.6f}'.format(time_res.avg),args)
myprint('fps: {:.6f}'.format(args.batch_size*args.world_size/time_res.avg),args)
if not args.master_node:
os.remove(os.path.join(args.save_path, 'checkpoint_{}'.format(args.node_id)))
else:
print('fine')
if __name__ == '__main__':
import torch.multiprocessing as mp
args = parse_args()
npu = int(os.environ['RANK_ID'])
main(npu, 8, args)