from tqdm import tqdm
from torch.utils.data import DataLoader
import numpy as np
import os
import argparse
import sys
sys.path.append(r'KnowledgeGraphEmbedding/codes/')
from run import read_triple
import dataloader
def parse_args(args=None):
    """Parse command-line options and resolve every output directory.

    Args:
        args: Optional list of argument strings. When ``None``, argparse
            falls back to ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``. Each ``output_*`` sub-directory
        option (everything except ``output_path`` itself) is rewritten to
        ``os.path.join(output_path, <option value>)``.
    """
    cli = argparse.ArgumentParser(
        description='Data preprocessing for Knowledge Graph Embedding Models',
        usage='RotatE_preprocess.py [<args>] [-h | --help]',
    )

    cli.add_argument(
        '--data_path',
        type=str,
        default='./KnowledgeGraphEmbedding/data/FB15k-237',
    )
    cli.add_argument(
        '--test_batch_size',
        default=6,
        type=int,
        help='valid/test batch size',
    )
    cli.add_argument('-cpu', '--cpu_num', default=10, type=int)
    cli.add_argument('--output_path', default='bin', type=str)

    # (option suffix, default sub-directory relative to --output_path).
    # Kept in the original declaration order.
    output_subdirs = (
        ('head_post', 'head/post'),
        ('tail_post', 'tail/post'),
        ('head_pos', 'head/pos'),
        ('head_neg', 'head/neg'),
        ('head_mode', 'head/mode'),
        ('head_pp', 'head/possamp'),
        ('head_np', 'head/negsamp'),
        ('tail_pos', 'tail/pos'),
        ('tail_neg', 'tail/neg'),
        ('tail_mode', 'tail/mode'),
        ('tail_pp', 'tail/possamp'),
        ('tail_np', 'tail/negsamp'),
    )
    for suffix, subdir in output_subdirs:
        cli.add_argument(f'--output_{suffix}', default=subdir, type=str)

    # Placeholders that main() overwrites with the real dictionary sizes.
    cli.add_argument('--nentity', type=int, default=0,
                     help='DO NOT MANUALLY SET')
    cli.add_argument('--nrelation', type=int, default=0,
                     help='DO NOT MANUALLY SET')

    parsed = cli.parse_args(args)

    # Anchor every sub-directory option under the common output root.
    for suffix, _ in output_subdirs:
        attr = f'output_{suffix}'
        setattr(parsed, attr,
                os.path.join(parsed.output_path, getattr(parsed, attr)))
    return parsed
def _load_id_map(path):
    """Read a tab-separated ``<id> <name>`` file into a ``{name: id}`` dict."""
    mapping = {}
    with open(path) as fin:
        for line in fin:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing empty line) instead
                # of failing on the two-field unpack below.
                continue
            idx, name = line.split('\t')
            mapping[name] = int(idx)
    return mapping


def _build_test_loader(args, test_triples, all_true_triples, mode):
    """Create the test DataLoader for ``mode`` ('head-batch' or 'tail-batch')."""
    return DataLoader(
        dataloader.TestDataset(
            test_triples,
            all_true_triples,
            args.nentity,
            args.nrelation,
            mode
        ),
        batch_size=args.test_batch_size,
        num_workers=max(1, args.cpu_num // 2),
        collate_fn=dataloader.TestDataset.collate_fn
    )


def _export_batches(loader, batch_size, pos_dir, pos_txt_dir, neg_dir,
                    post_dir, desc):
    """Write every batch produced by ``loader`` into the given directories."""
    for index, (positive_sample, negative_sample, filter_bias, _) in enumerate(
            tqdm(loader, desc=desc)):
        first = batch_size * index
        # The name is derived from the nominal batch size only; it does not
        # look at how many samples the batch actually holds.
        filename = f'bin{first}-{first + batch_size - 1}'

        # Convert once and reuse for both the binary and the text dump.
        positive = positive_sample.long().numpy()
        positive.tofile(os.path.join(pos_dir, f'{filename}.bin'))
        np.savetxt(os.path.join(pos_txt_dir, f'{filename}.txt'), positive)

        negative_sample.int().numpy().tofile(
            os.path.join(neg_dir, f'{filename}.bin'))

        np.savetxt(os.path.join(post_dir, f'{filename}.txt'),
                   filter_bias.numpy())


def main(args):
    """Dump the head-batch and tail-batch test sets to per-batch files.

    Reads ``entities.dict`` / ``relations.dict`` and the train/valid/test
    triple files from ``args.data_path``, stores the dictionary sizes in
    ``args.nentity`` / ``args.nrelation``, and then iterates the test set
    once per mode. For each batch it writes:

    * ``<pos dir>/<name>.bin``  - positive samples as raw int64 bytes
    * ``<pp dir>/<name>.txt``   - the same positive samples as text
    * ``<neg dir>/<name>.bin``  - negative samples as raw int32 bytes
    * ``<post dir>/<name>.txt`` - ``filter_bias`` as text

    where ``<name>`` is ``bin<start>-<end>`` computed from
    ``args.test_batch_size`` and the batch index.

    Args:
        args: Namespace as returned by ``parse_args``; ``nentity`` and
            ``nrelation`` are overwritten in place.
    """
    entity2id = _load_id_map(os.path.join(args.data_path, 'entities.dict'))
    relation2id = _load_id_map(os.path.join(args.data_path, 'relations.dict'))
    args.nentity = len(entity2id)
    args.nrelation = len(relation2id)

    train_triples = read_triple(os.path.join(
        args.data_path, 'train.txt'), entity2id, relation2id)
    valid_triples = read_triple(os.path.join(
        args.data_path, 'valid.txt'), entity2id, relation2id)
    test_triples = read_triple(os.path.join(
        args.data_path, 'test.txt'), entity2id, relation2id)
    all_true_triples = train_triples + valid_triples + test_triples

    test_dataloader_head = _build_test_loader(
        args, test_triples, all_true_triples, 'head-batch')
    test_dataloader_tail = _build_test_loader(
        args, test_triples, all_true_triples, 'tail-batch')

    # Create every output directory once, up front. exist_ok avoids the
    # check-then-create race of `if not exists: makedirs`.
    # NOTE(review): the *_mode directories are created but nothing in this
    # function writes into them - presumably used by a later stage; confirm.
    for directory in (
            args.output_head_pos, args.output_head_neg,
            args.output_head_mode, args.output_head_pp,
            args.output_head_post,
            args.output_tail_pos, args.output_tail_neg,
            args.output_tail_mode, args.output_tail_pp,
            args.output_tail_post):
        os.makedirs(directory, exist_ok=True)

    _export_batches(
        test_dataloader_head, args.test_batch_size,
        args.output_head_pos, args.output_head_pp,
        args.output_head_neg, args.output_head_post,
        "Preprocessing head data...")
    _export_batches(
        test_dataloader_tail, args.test_batch_size,
        args.output_tail_pos, args.output_tail_pp,
        args.output_tail_neg, args.output_tail_post,
        "Preprocessing tail data...")
# Script entry point: parse the command line, then run the preprocessing.
if __name__ == '__main__':
    cli_args = parse_args()
    main(cli_args)