"""generate mindrecord script"""
import os
import argparse
import collections
from random import shuffle
import datasets
import numpy as np
from tqdm import tqdm
from mindspore.mindrecord import FileWriter
from telechat_tokenizer import TelechatTokenizer
class TelechatDataset:
    """TelechatDataset

    Loads a JSON dataset through ``datasets.load_dataset`` and offers helpers
    that wrap a sample's ``input`` / ``output`` fields in the Telechat
    ``<_user>`` / ``<_bot>`` / ``<_end>`` markers.
    """

    def __init__(self, output_path, seed, dataset_name):
        """Remember the settings and load ``dataset_name`` as a JSON dataset.

        ``output_path`` and ``seed`` are only stored on the instance; nothing
        in this class reads them back.
        """
        self.raw_datasets = datasets.load_dataset(path="json", data_files=dataset_name)
        self.seed = seed
        self.output_path = output_path

    def get_train_data(self):
        """Return the ``"train"`` split of the loaded dataset."""
        return self.raw_datasets["train"]

    def get_prompt(self, sample):
        """Return ``sample['input']`` wrapped as ``<_user>...<_bot>``."""
        return "".join(("<_user>", sample['input'], "<_bot>"))

    def get_prompt_and_answer(self, sample):
        """Return the prompt followed by ``sample['output']`` and ``<_end>``."""
        prompt = self.get_prompt(sample)
        return prompt + sample['output'] + "<_end>"
def write_instance_to_file(writer, instance):
    """write the instance to file

    Converts ``instance["input_ids"]`` and ``instance["labels"]`` to int32
    NumPy arrays, hands them to ``writer.write_raw_data`` as a single-row
    batch, and returns the ordered feature mapping that was written.
    """
    features = collections.OrderedDict(
        (field, np.asarray(instance[field]).astype(np.int32))
        for field in ("input_ids", "labels")
    )
    writer.write_raw_data([features])
    return features
def make_input_mask(labels, tokenizer):
    """generate input mask

    Builds a ``(1, args.max_length)`` array that is 1 at every position from
    a bot token through its matching end token (inclusive) and at the matching
    user-token position, and 0 elsewhere. The i-th user, bot and end tokens
    found in ``labels`` are treated as one turn.

    Args:
        labels: Sequence of token ids to scan for the special tokens.
        tokenizer: Object exposing ``convert_tokens_to_ids``; used to resolve
            ``args.user_token``, ``args.bot_token`` and ``args.end_token``.

    Returns:
        A NumPy array of shape ``(1, args.max_length)`` created by
        ``np.zeros`` (so it has NumPy's default float dtype).

    Raises:
        ValueError: If the numbers of user, bot and end tokens differ.
    """
    # Convert once instead of once per special token.
    token_ids = np.asarray(labels)
    user_token_id = tokenizer.convert_tokens_to_ids(args.user_token)
    bot_token_id = tokenizer.convert_tokens_to_ids(args.bot_token)
    end_token_id = tokenizer.convert_tokens_to_ids(args.end_token)

    indices_user = np.where(token_ids == user_token_id)[0]
    indices_bot = np.where(token_ids == bot_token_id)[0]
    indices_end = np.where(token_ids == end_token_id)[0]

    # An explicit raise (rather than `assert`) keeps this data check active
    # when Python runs with -O; otherwise unbalanced turns would be masked
    # incorrectly without any error.
    if not len(indices_user) == len(indices_bot) == len(indices_end):
        raise ValueError(
            "special token counts differ: "
            f"{len(indices_user)} user, {len(indices_bot)} bot, {len(indices_end)} end"
        )

    target_labels = np.zeros((1, args.max_length))
    for user_idx, bot_idx, end_idx in zip(indices_user, indices_bot, indices_end):
        target_labels[0][bot_idx:end_idx + 1] = 1
        target_labels[0][user_idx] = 1
    return target_labels
def _build_conversation_line(sample):
    """Render one raw sample as a ``<_user>...<_bot>...<_end>`` string."""
    input_data = sample['input']
    if not input_data.startswith("<_user>"):
        input_data = "<_user>" + input_data
    output = sample['output']

    if "<_bot>" in input_data:
        # Multi-turn input: a turn that already contains a bot reply only
        # needs closing with <_end>; a turn without one is the open question
        # that `output` answers, so it gets <_bot> appended instead.
        pieces = []
        for turn in input_data.split("<_user>")[1:]:
            closing = "<_end>" if "<_bot>" in turn else "<_bot>"
            pieces.append("<_user>" + turn + closing)
        pieces.append(output + "<_end>")
        line = "".join(pieces)
    else:
        line = str(input_data) + "<_bot>" + str(output) + "<_end>"

    # Explicit raise instead of `assert` so the check survives `python -O`.
    if not line.count("<_user>") == line.count("<_bot>") == line.count("<_end>"):
        raise ValueError(f"unbalanced <_user>/<_bot>/<_end> markers in sample: {line!r}")
    return line


def _pad_shard(shard, token_count, tokenizer, max_seq_len):
    """Join ``shard`` and left-pad it with pad tokens up to ``max_seq_len``."""
    padded = (max_seq_len - token_count) * tokenizer.pad_token + "".join(shard)
    actual_len = len(tokenizer(padded)["input_ids"])
    # Packing relies on token counts being additive under concatenation and
    # on each pad_token string encoding to exactly one id; fail loudly if the
    # tokenizer does not behave that way.
    if actual_len != max_seq_len:
        raise ValueError(
            f"packed shard tokenizes to {actual_len} tokens, expected {max_seq_len}"
        )
    return padded


def process_dataset(current_dataset, tokenizer, max_seq_len):
    """process dataset.

    Turns every sample into a conversation string, shuffles them, greedily
    packs consecutive conversations into windows of fewer than ``max_seq_len``
    tokens, left-pads each window to exactly ``max_seq_len`` tokens, and
    tokenizes the result together with its label mask.

    Args:
        current_dataset: Iterable of mappings with ``'input'`` and
            ``'output'`` entries.
        tokenizer: Callable returning a mapping with ``"input_ids"`` for a
            string; must also expose ``pad_token``.
        max_seq_len: Token length of every produced window. The label mask
            comes from ``make_input_mask``, which sizes it from the global
            ``args.max_length``, so the two values are expected to match.

    Returns:
        A list of tokenizer outputs, each with an added ``'labels'`` entry.

    Raises:
        ValueError: If a sample has unbalanced markers or a packed window
            does not tokenize to exactly ``max_seq_len`` tokens.
    """
    all_lines = [_build_conversation_line(sample) for sample in current_dataset]
    shuffle(all_lines)

    shard = []
    shard_token_cnt = 0
    padding_out = []
    for corpus in tqdm(all_lines):
        corpus_len = len(tokenizer(corpus)["input_ids"])
        if shard_token_cnt + corpus_len < max_seq_len:
            shard.append(corpus)
            shard_token_cnt += corpus_len
            continue

        # The window is full: emit it (an empty window would be padding only,
        # so it is skipped) and start a new one.
        if shard:
            padding_out.append(_pad_shard(shard, shard_token_cnt, tokenizer, max_seq_len))
        if corpus_len < max_seq_len:
            shard = [corpus]
            shard_token_cnt = corpus_len
        else:
            # A conversation that cannot fit in a window by itself is dropped.
            shard = []
            shard_token_cnt = 0

    # Flush the window still being filled when the input ran out; without
    # this the trailing conversations would never be written.
    if shard:
        padding_out.append(_pad_shard(shard, shard_token_cnt, tokenizer, max_seq_len))

    print("prompt length: ", len(padding_out))
    dataset = []
    for dt in padding_out:
        tokens = tokenizer(dt)
        tokens['labels'] = make_input_mask(tokens["input_ids"], tokenizer)
        dataset.append(tokens)
    return dataset
def make_dataset():
    """make dataset.

    Loads the raw JSON dataset named by ``args.input_dataset_file``, packs and
    tokenizes it with ``process_dataset``, and writes every resulting sample
    to ``args.output_dataset_file`` as a MindRecord with int32 ``input_ids``
    and ``labels`` columns. All settings are read from the module-level
    ``args`` namespace.
    """
    # Function-scope import: the file header only imports `shuffle`.
    import random

    # `shuffle` (used by process_dataset) draws from the random module's
    # shared generator; seeding it here is what makes --seed take effect and
    # the sample order reproducible.
    random.seed(args.seed)

    raw_dataset = TelechatDataset(args.output_path, args.seed, args.input_dataset_file)
    train_dataset = raw_dataset.get_train_data()
    tokenizer = TelechatTokenizer(args.vocab_file_path, fast_tokenizer=True,
                                  trust_remote_code=True, padding_side="left")
    train_dataset = process_dataset(train_dataset, tokenizer, args.max_length)

    print("***** Writing to output files *****")
    # print() does not interpolate %-style arguments, so format explicitly.
    print(f"Output File: {args.output_dataset_file}")

    writer = FileWriter(args.output_dataset_file, 1)
    data_schema = {"input_ids": {"type": "int32", "shape": [-1]},
                   "labels": {"type": "int32", "shape": [-1]}}
    writer.add_schema(data_schema, "lm-schema")
    for sample in tqdm(train_dataset):
        instance = {"input_ids": sample["input_ids"], "labels": sample["labels"]}
        write_instance_to_file(writer, instance=instance)
    writer.commit()
    print(">>>> Transform dataset finished <<<<")
if __name__ == "__main__":
    # `args` is bound at module level on purpose: make_input_mask and
    # make_dataset read it as a global.
    cli_options = (
        ("--input_dataset_file", {"type": str, "default": ""}),
        ("--output_path", {"type": str, "default": ""}),
        ("--vocab_file_path", {"type": str, "default": "", "help": "which model to use."}),
        ("--max_length", {"type": int, "default": 2048}),
        ("--seed", {"type": int, "default": 1233}),
        ("--user_token", {"type": str, "default": "<_user>", "help": "user token"}),
        ("--bot_token", {"type": str, "default": "<_bot>", "help": "bot token"}),
        ("--end_token", {"type": str, "default": "<_end>", "help": "end token"}),
    )
    parser = argparse.ArgumentParser()
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    # Create the output directory when one was given and it does not exist yet.
    out_dir = args.output_path
    if out_dir and not os.path.exists(out_dir):
        os.makedirs(out_dir, exist_ok=True)

    # The MindRecord file name is fixed; only its directory is configurable.
    args.output_dataset_file = os.path.join(out_dir, "new_dataset.mindrecord")
    make_dataset()