"""Telechat predict scripts."""
import os
import argparse
import json
import copy
from typing import Optional, Union, List, Dict
import mindspore as ms
from mindspore import Model, Tensor
from mindspore.common import initializer
from mindformers import MindFormerConfig
from mindformers import build_context
from mindformers.tools.utils import str2bool
from mindformers.tools.logger import logger
from mindformers.generation import GenerationConfig
from mindformers.trainer.utils import transform_and_load_checkpoint
from mindformers.core.parallel_config import build_parallel_config
from telechat_tokenizer import TelechatTokenizer
from telechat_config import TelechatConfig
from telechat_predict_utils import History
from telechat import TelechatForCausalLM
def chat(model, tokenizer, question: str = '', history: Union[List[Dict], History] = None,
generation_config: Optional[GenerationConfig] = None):
"""
Args:
tokenizer: the tokenizer of telechat
question: question which the model reply in this turn
history: history which will format the input for telechat
stream: if return the full text at last or yield the text in token
generation_config: configuration for generation
**kwargs: args which will update the generation config or pass to model forward
"""
if not generation_config:
logger.error("generation_config is None")
raise ValueError("generation_config must not be None")
if not question:
logger.error("question is empty")
raise ValueError("question must not be empty")
if history is None:
history = []
generation_config = copy.deepcopy(generation_config)
if not isinstance(history, History):
history = History(tokenizer, history)
inputs = build_inputs_for_chat(tokenizer, question, history, generation_config)
history.append({"role": "user", "content": question})
outputs = model.generate(inputs,
max_length=generation_config.max_decode_length,
do_sample=generation_config.do_sample,
top_k=generation_config.top_k,
top_p=generation_config.top_p,
max_new_tokens=generation_config.max_new_tokens)
response = tokenizer.decode(outputs[0][len(inputs):-1])
history.append({"role": "bot", "content": response})
return response, history
def build_inputs_for_chat(tokenizer, question, history, generation_config):
"""
check history and build inputs here
"""
q_token = tokenizer(question)
qa_history = copy.deepcopy(history)
model_max_length = generation_config.seq_length
build_max_length = max(0, model_max_length - generation_config.max_new_tokens) \
if generation_config.max_new_tokens else max(0, generation_config.max_decode_length)
if build_max_length < 3:
raise ValueError("the model can not meet the requirements of input length,Please check config")
start_id = generation_config.start_token_id
usr_id = generation_config.user_token_id
bot_id = generation_config.bot_token_id
input_tokens = [usr_id] + q_token["input_ids"][-build_max_length + 1:] + [bot_id]
length = len(input_tokens)
while len(qa_history) >= 1:
message = qa_history.pop()
if message["role"] == "user":
tokens = [usr_id] + message["input_ids"]
elif message["role"] == "bot":
tokens = [bot_id] + message["input_ids"] + [generation_config.eos_token_id]
else:
tokens = []
if len(tokens) + length >= build_max_length:
break
else:
input_tokens = tokens + input_tokens
input_tokens = [start_id] + input_tokens
return input_tokens
def main():
"""main function."""
input_questions = []
input_file = open(args.input_file, 'r', encoding='utf-8')
for line in input_file.readlines():
dic = json.loads(line)
input_questions.append(dic["input"])
input_file.close()
config = MindFormerConfig(args.yaml_file)
os.environ['MS_INTERNAL_DISABLE_CUSTOM_KERNEL_LIST'] = 'InferenceMatmulSplit,PagedAttention'
if args.device_id is not None:
config.context.device_id = args.device_id
if args.checkpoint_path is not None:
config.load_checkpoint = args.checkpoint_path
if args.use_parallel is not None:
config.use_parallel = args.use_parallel
if args.auto_trans_ckpt is not None:
config.auto_trans_ckpt = args.auto_trans_ckpt
if args.src_strategy_path_or_dir is not None:
config.src_strategy_path_or_dir = args.src_strategy_path_or_dir
if args.vocab_file_path is not None:
config.processor.tokenizer.vocab_file = args.vocab_file_path
build_context(config)
build_parallel_config(config)
tokenizer = TelechatTokenizer(config.processor.tokenizer.vocab_file, \
fast_tokenizer=True, trust_remote_code=True)
model_config = config.model.model_config
model_config.parallel_config = config.parallel_config
model_config.batch_size = 1
model_config.use_past = args.use_past
model_config.use_flash_attention = True
model_config.start_token_id = tokenizer.convert_tokens_to_ids(args.start_token)
model_config.user_token_id = tokenizer.convert_tokens_to_ids(args.user_token)
model_config.bot_token_id = tokenizer.convert_tokens_to_ids(args.bot_token)
model_config.max_new_tokens = None
model_config = TelechatConfig(**model_config)
model = TelechatForCausalLM(model_config)
ms_model = Model(model)
logger.info(f"[INFO_config]: {model_config}")
if config.load_checkpoint:
logger.info("----------------Transform and load checkpoint----------------")
seq_length = model_config.seq_length
input_ids = Tensor(shape=(model_config.batch_size, seq_length), dtype=ms.int32, init=initializer.One())
infer_data = model.prepare_inputs_for_predict_layout(input_ids)
transform_and_load_checkpoint(config, ms_model, model, infer_data, do_predict=True)
history = []
for question in input_questions:
logger.info(f"question : {str(question)}")
answer, history = chat(model, tokenizer, question, history, generation_config=model_config)
logger.info(f"answer:, {str(answer)}")
logger.info(f"\n截至目前的聊天记录是:, {str(history)}")
logger.info("\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--input_file', default=None, type=str,
help='input to infer.')
parser.add_argument('--vocab_file_path', default=None, type=str,
help='which model to use.')
parser.add_argument('--checkpoint_path', default=None, type=str,
help='set checkpoint path.')
parser.add_argument('--use_parallel', default=True, type=str2bool,
help='whether use parallel.')
parser.add_argument('--auto_trans_ckpt', default=False, type=str2bool,
help='Auto transform load_checkpoint to load in distributed model.')
parser.add_argument('--src_strategy_path_or_dir', default=None, type=str,
help='set src strategy path.')
parser.add_argument('--use_past', default=True, type=str2bool,
help='whether use past.')
parser.add_argument('--yaml_file', default=None, type=str,
help='predict yaml path')
parser.add_argument('--device_id', default=0, type=int,
help='device id set when run on single card. Default: 0')
parser.add_argument('--start_token', default="<_start>", type=str,
help='start_token')
parser.add_argument('--user_token', default="<_user>", type=str,
help='user_token')
parser.add_argument('--bot_token', default="<_bot>", type=str,
help='bot_token')
args = parser.parse_args()
main()