"""DeepSeek-V3/R1 predict script"""
import argparse
import mindspore as ms
from mindspore import Model, Tensor
from mindspore.common import initializer
from mindformers import MindFormerConfig
from mindformers import build_context
from mindformers.tools.logger import logger
from mindformers.trainer.utils import transform_and_load_checkpoint
from mindformers.core.parallel_config import build_parallel_config
from mindformers.models.llama.llama_tokenizer_fast import LlamaTokenizerFast
from deepseek3_model_infer import InferenceDeepseekV3ForCausalLM
from deepseek3_config import DeepseekV3Config
def run_predict(args):
    """Generate an answer for a single prompt with DeepSeek-V3/R1 and print it.

    The YAML file named by ``args.config`` is loaded into a ``MindFormerConfig``,
    which drives context setup, parallel configuration, tokenizer construction
    and model construction. When the config has a truthy ``load_checkpoint``
    entry, weights are transformed and loaded before generation.

    Args:
        args: Object exposing ``input`` (the prompt text) and ``config``
            (path to the YAML configuration), e.g. an ``argparse.Namespace``.

    Returns:
        None. The decoded generation result is written to stdout.
    """
    # The tokenizer and generate() are fed a one-element batch.
    prompts = [args.input]

    # Load the YAML config and initialise runtime context / parallel settings.
    config = MindFormerConfig(args.config)
    build_context(config)
    build_parallel_config(config)

    # Attach the parallel and MoE sections to the raw model section, then
    # turn the merged mapping into a typed DeepseekV3Config.
    raw_model_cfg = config.model.model_config
    raw_model_cfg.parallel_config = config.parallel_config
    raw_model_cfg.moe_config = config.moe_config
    model_config = DeepseekV3Config(**raw_model_cfg)

    # Build the tokenizer from the processor section of the config.
    tok_cfg = config.processor.tokenizer
    tokenizer = LlamaTokenizerFast(
        tok_cfg.vocab_file,
        tok_cfg.tokenizer_file,
        unk_token=tok_cfg.unk_token,
        bos_token=tok_cfg.bos_token,
        eos_token=tok_cfg.eos_token,
        fast_tokenizer=True,
        trust_remote_code=True,
    )
    # The EOS token doubles as the padding token.
    tokenizer.pad_token = tokenizer.eos_token

    # Instantiate the inference network and wrap it in a MindSpore Model.
    network = InferenceDeepseekV3ForCausalLM(model_config)
    ms_model = Model(network)

    if config.load_checkpoint:
        logger.info("----------------Transform and load checkpoint----------------")
        # All-ones int32 placeholder of shape (batch_size, seq_length); it is
        # only used to derive the inputs handed to the checkpoint loader.
        placeholder_shape = (model_config.batch_size, model_config.seq_length)
        placeholder_ids = Tensor(shape=placeholder_shape, dtype=ms.int32, init=initializer.One())
        layout_inputs = network.prepare_inputs_for_predict_layout(placeholder_ids)
        transform_and_load_checkpoint(config, ms_model, network, layout_inputs, do_predict=True)

    # Tokenize the prompt, padding it to a fixed length of 64 tokens.
    encoded = tokenizer(prompts, max_length=64, padding="max_length")
    generation_kwargs = {
        "max_length": 1024,
        "do_sample": False,
        "top_k": 5,
        "top_p": 1,
        "max_new_tokens": 128,
    }
    generated = network.generate(encoded["input_ids"], **generation_kwargs)

    answer = tokenizer.decode(generated)
    print("answer: ", answer)
if __name__ == "__main__":
    # Command-line entry point: parse the config path and the prompt, then
    # run a single prediction.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--config',
        type=str,
        required=True,
        # The trailing space keeps the example path from being glued onto
        # "such as" when the two adjacent literals are concatenated.
        help='YAML config files, such as '
             './research/deepseek3/deepseek3_671b/predict_deepseek3_671b.yaml')
    parser.add_argument(
        '--input',
        type=str,
        # Default prompt (Chinese): "What is the difference between light
        # soy sauce and dark soy sauce?"
        default="生抽和老抽的区别是什么?",
        help='Prompt text to send to the model.')
    args_ = parser.parse_args()
    run_predict(args_)