静态图使用动态shape微调

🚨 弃用说明

本文档已过时,不再进行维护,并将在 1.6.0 版本下架,其中可能包含过时的信息或已被更新的功能替代。建议参考最新的 官方文档 ,以获取准确的信息。

如果您仍需使用本文档中的内容,请仔细核对其适用性,并结合最新版本的相关资源进行验证。

如有任何问题或建议,请通过 社区Issue 提交反馈。感谢您的理解与支持!

功能

Mindspore在2.3.0版本之后为finetune提供了动态shape能力

当前版本的动态shape仅包含动态SeqLength。基本原理是尽量减小计算的shape。动态SeqLength不再将数据统一padding到SeqLength长度,而是统一padding到该batch内最长数据的长度,既不损失精度又减小了计算量。

finetune开启动态shape,最大的优点在于精度不变的情况下性能可以提高数倍。以通用的alpaca-en-52k数据集为例,静态shape下llama2 7b单机8卡跑完两个Epoch的finetune,最佳性能大约是4小时;启用动态shape后,这个时间可以缩短到37分钟。

使用方式

在yaml中开启动态shape配置并选择动态shape支持的data_loader。

参考配置

train_dataset: &train_dataset
  data_loader:
    type: SFTDataLoader
    dataset_dir: '/path/to/alpaca_data_en_52k.json'
    tokenizer:
      unk_token: '<unk>'
      bos_token: '<s>'
      eos_token: '</s>'
      pad_token: '<unk>'
      type: LlamaTokenizer
      vocab_file: '/path/to/tokenizer.model'
    max_length: 4097
    file_format: json
    dataset_name: multi-insruct-dyn
    shuffle: False
  divisor: 2
  remainder: 1
  input_columns: ['input_ids', 'labels']
  num_parallel_workers: 8
  python_multiprocessing: False
  drop_remainder: True
  batch_size: 2
  repeat: 1
  numa_enabel: False
  prefetch_size: 1
  dynamic_batch: True

data_loader相关参数:

  • type: data_loader类型,如果是SFT数据任务,需要使用SFTDataLoader,否则无法使用SFT_MAP_FUNCTIONS解析数据
  • tokenizer: 同下方processor中的tokenizer。因为动态shape是直接读json进来转成input_ids,所以需要tokenizer能力
  • dataset_name:data_loader类型,根据要解析的数据选择。例如读原始alpaca的json数据,推荐使用multi-instruct-dyn类型;读mindformers中alpaca_converter.py格式化过的json数据时(即转mindrecord前的json数据),推荐使用multi-round-chat-dyn类型。
  • divisor: 动态shape网络张量seqlen的约数(详细文档)
  • remainder: 动态shape网络张量seqlen的余数(详细文档)
  • dynamic_batch: 动态shape开关

model_config相关参数:

  • is_dynamic: 静态图的动态shape开关。需要一并开启,否则会导致CausalMask的shape不匹配

关于自定义prompt

prompt对齐,常用在需要和其它baseline对齐的场景。是否对齐以最终输入模型的input_ids、labels以及其它相关的mask是否和目标对齐了为准。 一般把没对齐的部分decode出来,并修改对应的prompt内容来完成。如果是padding或者mask字符不一致导致的误差,可以通过修改data_loader中的pad_token_id或者map_function_kwargs中ignore_token_id参数解决。

prompt示例

这里分别给出llama2和llama3和原mindrecord数据对齐时候的prompt供参考

llama2:

train_dataset: &train_dataset
  data_loader:
    map_function_kwargs: {"user_prompt":"A chat between a curious urer and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. ", "user_prompt_role":"USER: ", "assistant_prompt_role":"ASSISTANT:"}

llama3:

train_dataset: &train_dataset
  data_loader:
    map_function_kwargs: {"user_prompt":"A chat between a curious urer and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. ", "user_prompt_role":"USER: ", "assistant_prompt_role":" ASSISTANT:", "bos_token":"", "sep_token":" "}