"""do infer using wizardcoder"""
import os
import argparse
from mindspore import log as logger
from mindformers import MindFormerConfig
from mindformers.tools.utils import str2bool
from mindformers.core.context import build_context
from wizardcoder_config import WizardCoderConfig
from wizardcoder import WizardCoderLMHeadModel
from wizardcoder_tokenizer import WizardCoderTokenizer
def load_model_and_tokenizer(args):
    """Create the WizardCoder model and its tokenizer from parsed CLI arguments.

    The YAML file at ``args.config_path`` is read twice: once as a
    ``MindFormerConfig`` (used to build the runtime context and to locate the
    tokenizer files) and once as a ``WizardCoderConfig`` (used to build the
    model).

    Args:
        args: Namespace providing ``config_path``, ``device_id``, ``use_past``
            and ``batch_size``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    yaml_path = os.path.realpath(args.config_path)

    # The device id from the command line overrides the one in the YAML
    # before the context is built.
    run_config = MindFormerConfig(yaml_path)
    run_config.context.device_id = args.device_id
    build_context(run_config)

    # Model hyper-parameters come from the same YAML; the command-line
    # values for use_past and batch_size take precedence.
    model_config = WizardCoderConfig.from_pretrained(yaml_path)
    model_config.use_past = args.use_past
    model_config.batch_size = args.batch_size

    tokenizer_files = run_config.processor.tokenizer
    tokenizer = WizardCoderTokenizer(tokenizer_files.vocab_file, tokenizer_files.merge_file)
    model = WizardCoderLMHeadModel(model_config)
    return model, tokenizer
def main(args):
    """Run generation on the built-in sample prompts and log the results.

    Each sample prompt is repeated ``args.batch_size`` times to form a batch,
    tokenized, and passed to ``model.generate``. Only the first sequence of
    the generated batch is decoded and logged.

    Args:
        args: Namespace providing ``batch_size``, ``use_past`` and
            ``max_length``, plus the fields read by
            ``load_model_and_tokenizer``.
    """
    model, tokenizer = load_model_and_tokenizer(args)

    # Every case is a one-element list so that list repetition below yields
    # a batch of identical prompts.
    prompts = [
        ['Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nCreate a Python script for this problem:\n\nWrite a python function to find the volume of a triangular prism.\nTest examples:\nassert find_Volume(10,8,6) == 240\nassert find_Volume(3,2,2) == 6\nassert find_Volume(1,2,1) == 1\n\n### Response:'],
        ['Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nCreate a Python script for this problem:\n\nWrite a function to find sequences of lowercase letters joined with an underscore.\nTest examples:\nassert text_lowercase_underscore("aab_cbbbc")==(\'Found a match!\')\nassert text_lowercase_underscore("aab_Abbbc")==(\'Not matched!\')\nassert text_lowercase_underscore("Aaab_abbbc")==(\'Not matched!\')\n\n### Response:'],
        ['Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nCreate a Python script for this problem:\n\nWrite a function to check if the given tuple list has all k elements.\nTest examples:\nassert check_k_elements([(4, 4), (4, 4, 4), (4, 4), (4, 4, 4, 4), (4, )], 4) == True\nassert check_k_elements([(7, 7, 7), (7, 7)], 7) == True\nassert check_k_elements([(9, 9), (9, 9, 9, 9)], 7) == False\n\n### Response:'],
        ['Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nCreate a Python script for this problem:\n\nWrite a function to find the n-th rectangular number.\nTest examples:\nassert find_rect_num(4) == 20\nassert find_rect_num(5) == 30\nassert find_rect_num(6) == 42\n\n### Response:'],
    ]

    for case_idx, single_prompt in enumerate(prompts):
        logger.info(f'==============================Start Case{case_idx} infer============================')

        # Repeat the single prompt to fill the requested batch size.
        batch_prompts = single_prompt * args.batch_size
        logger.info(f"prompt: {[batch_prompts]}")

        token_ids = tokenizer(batch_prompts)["input_ids"]
        generated = model.generate(
            input_ids=token_ids,
            use_past=args.use_past,
            max_length=args.max_length,
        )

        # Only the first sequence of the batch is decoded and reported.
        decoded_text = tokenizer.decode(generated[0])
        logger.info(f"output: {[decoded_text]}")
if __name__ == "__main__":
    # Command-line options as (flag, default, type converter, help text).
    cli_options = (
        ('--config_path', 'run_wizardcoder_15b.yaml', str, 'config'),
        ('--max_length', 2048, int, 'max length'),
        ('--batch_size', 1, int, 'batch_size'),
        ('--device_id', 0, int, 'device_id'),
        ('--use_past', True, str2bool, "use past"),
    )

    parser = argparse.ArgumentParser()
    for flag, default_value, converter, help_text in cli_options:
        parser.add_argument(flag, default=default_value, type=converter, help=help_text)

    args_ = parser.parse_args()
    main(args_)