"""export wizardcoder inc"""
import argparse
import numpy as np
import mindspore as ms
import mindspore.common.dtype as mstype
from wizardcoder_config import WizardCoderConfig
from wizardcoder import WizardCoderLMHeadModel
# Command-line options: the batch size and sequence length used to shape the
# exported graphs' dummy inputs, the checkpoint to load, and the target device.
parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', default=1, type=int,
                    help='batch_size')
# The help text previously read 'batch_size' (copy-paste slip), which made
# `--help` describe this option incorrectly.
parser.add_argument('--seq_length', default=2048, type=int,
                    help='seq_length')
parser.add_argument('--model_path', default='', type=str,
                    help='model path')
parser.add_argument('--device_id', default=0, type=int,
                    help='set device id.')
args = parser.parse_args()
# Run in graph mode on the Ascend device selected on the command line.
ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend", device_id=args.device_id)

# These two sizes determine the shapes of the dummy inputs used for export.
batch_size, seq_length = args.batch_size, args.seq_length

# Model hyperparameters, gathered in one place before building the config.
# NOTE(review): the fixed values presumably describe the WizardCoder-15B
# checkpoint and `use_past=True` presumably enables the cached (incremental)
# inference path — confirm against WizardCoderConfig.
config_kwargs = dict(
    batch_size=batch_size,
    seq_length=seq_length,
    n_position=8192,
    vocab_size=49153,
    hidden_size=6144,
    num_layers=40,
    num_heads=48,
    eos_token_id=0,
    pad_token_id=49152,
    checkpoint_name_or_path=args.model_path,
    use_past=True,
)
config = WizardCoderConfig(**config_kwargs)

# Build the network and switch it out of training mode before exporting.
model = WizardCoderLMHeadModel(config)
model.set_train(False)
# Dummy inputs for the prefill graph: a full (batch_size, seq_length) block of
# token ids plus one position / valid-length entry per batch element.
# NOTE(review): the concrete values (1, 127, 128) look like placeholders that
# only fix shapes and dtypes for export — confirm.
input_ids = ms.Tensor(np.ones(shape=(batch_size, seq_length)), dtype=mstype.int32)
input_position = ms.Tensor([127 for _ in range(batch_size)], dtype=mstype.int32)
init_reset = ms.Tensor([False], dtype=mstype.bool_)
batch_valid_length = ms.Tensor([[128 for _ in range(batch_size)]], dtype=mstype.int32)

# The two None entries fill positional model inputs that are not supplied here
# (see WizardCoderLMHeadModel.construct for what they correspond to).
prefill_inputs = (input_ids, None, None, input_position, init_reset, batch_valid_length)
prefill_file = f"wizardcoder-15b_mslite_inc/prefill_seq{seq_length}_bs{batch_size}"
ms.export(model, *prefill_inputs, file_name=prefill_file, file_format="MINDIR")
# Switch the model (and its sub-cells) to the non-first-iteration branch so the
# decode graph is the one that gets exported.
model.add_flags_recursive(is_first_iteration=False)

# Dummy inputs for the decode graph: a single token id per batch element.
# NOTE(review): the concrete values (1, 128, 129) look like placeholders that
# only fix shapes and dtypes for export — confirm.
input_ids = ms.Tensor(np.ones(shape=(batch_size, 1)), dtype=mstype.int32)
input_position = ms.Tensor([128 for _ in range(batch_size)], dtype=mstype.int32)
init_reset = ms.Tensor([True], dtype=mstype.bool_)
batch_valid_length = ms.Tensor([[129 for _ in range(batch_size)]], dtype=mstype.int32)

# The two None entries fill positional model inputs that are not supplied here
# (see WizardCoderLMHeadModel.construct for what they correspond to).
decode_inputs = (input_ids, None, None, input_position, init_reset, batch_valid_length)
decode_file = f"wizardcoder-15b_mslite_inc/decode_seq{seq_length}_bs{batch_size}"
ms.export(model, *decode_inputs, file_name=decode_file, file_format="MINDIR")