""" export mindir model for mindspore lite."""
import argparse
import os
import mindspore as ms
from mindformers.tools import logger
from mindformers.tools.utils import str2bool
from mindformers.tools.utils import get_output_subpath
from visualglm import VisualGLMImageToTextGeneration
from visualglm_config import VisualGLMConfig
def init_context(device_id):
"""
init context
:param device_id: npu device id
"""
ms.set_context(mode=0, device_target="Ascend", device_id=device_id, max_device_memory="30GB", save_graphs=False)
def main(args):
init_context(device_id=args.device_id)
model_config = VisualGLMConfig.from_pretrained(args.config_path)
model_config.max_txt_len = args.seq_length
if args.checkpoint is not None:
print(f"checkpoint: {args.checkpoint}")
model_config.checkpoint_name_or_path = args.checkpoint
if args.batch_size > 1:
model_config.batch_size = args.batch_size
else:
model_config.batch_size = 1
model_config.text_config.batch_size = model_config.batch_size
model_config.text_config.seq_length = args.seq_length + model_config.qformer_config.query_length
model_config.text_config.do_sample = args.do_sample
model_config.text_config.top_p = args.top_p
model_config.text_config.top_k = args.top_k
model_config.text_config.use_past = args.use_past
model = VisualGLMImageToTextGeneration(model_config)
if args.mode == "export":
model.set_train(False)
rank_id = int(os.getenv("RANK_ID", "0"))
model.add_flags_recursive(is_first_iteration=True)
full_inputs = model.llm_model.prepare_inputs_for_export(full_model=True)
save_path = get_output_subpath(
f"mindir_full_checkpoint_bs_{model_config.batch_size}_{model_config.text_config.seq_length}",
rank_id)
ms.export(model.llm_model, *full_inputs, file_name=save_path, file_format='MINDIR')
if model_config.text_config.use_past:
model.add_flags_recursive(is_first_iteration=False)
inc_inputs = model.llm_model.prepare_inputs_for_export(full_model=False)
save_path = get_output_subpath(
f"mindir_inc_checkpoint_bs_{model_config.batch_size}_{model_config.text_config.seq_length}",
rank_id)
ms.export(model.llm_model, *inc_inputs, file_name=save_path, file_format='MINDIR')
logger.info(".........Export Over!.............")
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--model_type', default="visualglm_6b", type=str, required=False,
help='model type')
parser.add_argument('--config_path', default="run_visualglm_6b_image_to_text_generation.yaml", type=str,
required=False, help='config path')
parser.add_argument('--device_id', type=int, default=1, required=False, help='device id')
parser.add_argument('--batch_size', type=int, default=1, required=False, help='batch_size')
parser.add_argument('--checkpoint', type=str, default=None, required=False, help='checkpoint path')
parser.add_argument('--generate_repeat_time', type=int, default=5, required=False,
help='generate repeat time')
parser.add_argument('--use_past', type=str2bool, default=False, required=False,
help='whether use past')
parser.add_argument('--do_sample', type=str2bool, default=False, required=False,
help='whether do sample')
parser.add_argument('--top_p', type=float, default=1, required=False, help='top p')
parser.add_argument('--top_k', type=int, default=0, required=False, help='top k')
parser.add_argument('--seq_length', type=int, default=32, required=False, help='seq length')
parser.add_argument('--mode', type=str, default="export", required=False, help='')
args_ = parser.parse_args()
print(args_)
main(args_)