"""mbpp evaluate using online method"""
import json
import argparse
from tqdm import tqdm
import mindspore as ms
from mindformers.generation import GenerationConfig
from wizardcoder_config import WizardCoderConfig
from wizardcoder_tokenizer import WizardCoderTokenizer
from wizardcoder import WizardCoderLMHeadModel
def read_mbpp(path):
"""read mbpp file"""
mbpp_problems = {}
with open(path, "r", encoding="utf-8") as in_file:
for line in in_file:
item = json.loads(line.strip())
mbpp_problems[item["task_id"]] = item
return mbpp_problems
def generate_prompt(input_problem):
"""construct the prompt"""
query_prompt = \
f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
Create a Python script for this problem:
{input_problem}
### Response:"""
return query_prompt
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', default=1, type=int,
help='batch_size')
parser.add_argument('--seq_length', default=2048, type=int,
help='batch_size')
parser.add_argument('--tokenizer_path', default='/path/mindspore_models/', type=str,
help='tokenizer_path')
parser.add_argument('--model_path', default='/path/mindspore_models/wizardcoder.ckpt', type=str,
help='wizardcoder_model_path')
parser.add_argument('--device_id', default=0, type=int,
help='set device id.')
parser.add_argument('--start_index', default=0, type=int,
help='start_index')
parser.add_argument('--end_index', default=0, type=int,
help='end_index')
parser.add_argument('--output_path', default="", type=str,
help='output_path')
parser.add_argument('--mbpp_path', default="", type=str,
help='mbpp path')
args = parser.parse_args()
ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend", device_id=args.device_id)
print(args.start_index, args.end_index)
tokenizer = WizardCoderTokenizer(
vocab_file=args.tokenizer_path + "vocab.json",
merge_file=args.tokenizer_path + "merges.txt"
)
gen_config = GenerationConfig(
do_sample=False,
use_past=True,
max_length=args.seq_length,
eos_token_id=0,
pad_token_id=49152
)
wizardcoder_config = WizardCoderConfig(
batch_size=args.batch_size,
seq_length=args.seq_length,
n_position=8192,
vocab_size=49153,
hidden_size=6144,
num_layers=40,
num_heads=48,
eos_token_id=0,
pad_token_id=49152,
checkpoint_name_or_path=args.model_path,
use_past=True
)
model = WizardCoderLMHeadModel(config=wizardcoder_config)
print("model create success!")
problems = read_mbpp(args.mbpp_path)
task_ids = sorted(problems.keys())[args.start_index: args.end_index]
prompts = []
for task_id in task_ids:
prompt = f"\n{problems[task_id]['text']}\nTest examples:"
if task_id == 493:
test_example = problems[task_id]['test_list'][0]
prompt += f"\ncalculate_polygons(startx, starty, endx, endy, radius)"
else:
for test_example in problems[task_id]['test_list']:
prompt += f"\n{test_example}"
prompts.append(prompt)
num_samples = len(prompts)
print(num_samples)
for i in tqdm(range(num_samples), ncols=0, total=num_samples):
output_file = args.output_path + '/{}.json'.format(args.start_index + i)
prompt = prompts[i].replace(' ', '\t')
prompt_batch = [tokenizer(generate_prompt(prompt))['input_ids']]
print(prompt_batch)
output = model.generate(input_ids=prompt_batch, generation_config=gen_config, max_length=args.seq_length)
decode_output = tokenizer.decode(output)
print(decode_output)
with open(output_file, "w", encoding='utf-8') as f:
json.dump({"output": decode_output}, f)