"""test wizardcoder pytorch"""
import sys
import time
import argparse
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
if torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
def evaluate(prompts, tokenizer, model, max_length=1024, **kwargs):
"""evaluate function"""
start_time_with_tokenizer = time.time()
inputs = tokenizer(prompts, return_tensors="pt", max_length=max_length, truncation=True, padding=True)
input_ids = inputs["input_ids"].to(device)
generation_config = GenerationConfig(
temperature=1,
top_p=1,
top_k=1,
do_sample=False,
num_beams=1,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
max_length=max_length,
**kwargs
)
start_time_no_tokenizer = time.time()
with torch.no_grad():
generation_output = model.generate(
input_ids=input_ids,
generation_config=generation_config,
return_dict_in_generate=True,
output_scores=True
)
seq = generation_output.sequences
end_time_no_tokenizer = time.time()
output = tokenizer.batch_decode(seq, skip_special_tokens=True)
end_time_with_tokenizer = time.time()
elapsed_time_with_tokenizer = end_time_with_tokenizer - start_time_with_tokenizer
elapsed_time_no_tokenizer = end_time_no_tokenizer - start_time_no_tokenizer
generate_length = sum([len(item) for item in seq]) - sum([len(ids) for ids in input_ids])
return output, generate_length, elapsed_time_with_tokenizer, elapsed_time_no_tokenizer
def generate_prompt(input_query):
return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
Create a Python script for this problem:
{input_query}
### Response:"""
def main(args, with_prompt=True, load_8bit: bool = False):
tokenizer = AutoTokenizer.from_pretrained(args.base_model)
model = AutoModelForCausalLM.from_pretrained(
args.base_model,
load_in_8bit=load_8bit,
torch_dtype=torch.float16,
device_map="auto",
)
model.config.pad_token_id = tokenizer.pad_token_id
if not load_8bit:
model.half()
model.eval()
if torch.__version__ >= "2" and sys.platform != "win32":
model = torch.compile(model)
input_data = [["使用python编写快速排序代码"] * args.batch_size]
for _, instruction in enumerate(input_data):
print('\n开始推理.......')
if with_prompt:
prompt = instruction
else:
prompt = generate_prompt(instruction)
decode_output, generate_length, time_with_tokenizer, time_no_tokenizer = \
evaluate(prompt, tokenizer, model, max_length=args.seq_length)
print("output: \n", decode_output[0])
speed_with_tokenizer = generate_length / time_with_tokenizer
speed_no_tokenizer = generate_length / time_no_tokenizer
print("\n generate length: ", generate_length,
" elapsed_time_with_tokenizer: ", time_with_tokenizer,
" elapsed_time_no_tokenizer: ", time_no_tokenizer,
" speed_with_tokenizer: ", speed_with_tokenizer,
" speed_no_tokenizer: ", speed_no_tokenizer)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--base_model', default='', type=str,
help='base model')
parser.add_argument('--seq_length', default=2048, type=int,
help='batch_size')
parser.add_argument('--batch_size', default=1, type=int,
help='batch_size')
opt = parser.parse_args()
main(opt)