"""
transform wikitext-2 dataset to mindrecord.
"""
import argparse
import os
import numpy as np
from mindspore.mindrecord import FileWriter
from internlm_tokenizer import InternLMTokenizer
def tokenize_wiki(tokenizer, file_path, seq_length, min_length):
"""tokenize wikitext-2 dataset"""
content = []
with open(file_path, 'r', encoding='utf-8') as f:
for line in f:
stripped_line = line.strip()
if stripped_line:
line_ids = tokenizer(stripped_line)["input_ids"]
if len(line_ids) >= min_length:
content.append(line_ids)
for ids in content:
sample = {}
if len(ids) < seq_length:
ids = np.pad(ids, (0, seq_length - len(ids)), 'constant', constant_values=(2, 2))
else:
ids = ids[:seq_length]
sample['input_ids'] = np.array(ids, dtype=np.int32)
yield sample
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--mindrecord_schema", type=str, default="internlm_wiki")
parser.add_argument("--input_glob", type=str, default="./wikitext-2/wiki.train.tokens")
parser.add_argument("--output_file", type=str, default="./wiki_processed/wiki.mindrecord")
parser.add_argument("--model_file", type=str, default="./tokenizer.model")
parser.add_argument("--file_partition", type=int, default=1)
parser.add_argument("--seq_length", type=int, default=2048)
parser.add_argument("--min_length", type=int, default=50)
args = parser.parse_args()
out_dir, out_file = os.path.split(os.path.abspath(args.output_file))
if not os.path.exists(out_dir):
os.mkdir(out_dir)
schema = {'input_ids': {"type": "int32", "shape": [-1]},}
writer = FileWriter(file_name=args.output_file,
shard_num=args.file_partition)
writer.set_page_size(256*1024*1024)
writer.add_schema(schema, args.mindrecord_schema)
if not os.path.exists(args.model_file):
raise FileNotFoundError(f"file {args.model_file} do not exists.")
transforms_count = 0
word_tokenizer = InternLMTokenizer(vocab_file=args.model_file)
for x in tokenize_wiki(word_tokenizer, args.input_glob, args.seq_length + 1, args.min_length):
transforms_count += 1
writer.write_raw_data([x])
print("Transformed {} records.".format(transforms_count))
writer.commit()
out_file = args.output_file
if args.file_partition > 1:
out_file += '0'
print("Transform finished, output files refer: {}".format(out_file))