# Snapshot of commit 93ae0ae6, created on 2025-01-03 (from the commit history page).
# Top-level run settings.
seed: 0
# Directory where checkpoints and strategy files are saved.
output_dir: './output'
load_checkpoint: 'path/checkpoint.ckpt'
# When true, load_checkpoint is automatically transformed so that it can be
# loaded into the distributed model.
auto_trans_ckpt: false
only_save_strategy: false
resume_training: false
# NOTE(review): parallelism is switched off here, while parallel_config further
# down describes an 8-device layout (data_parallel 2 x model_parallel 4).
# Confirm which of the two is intended.
use_parallel: false
run_mode: 'finetune'

# Trainer selection.
# NOTE(review): the trainer type is named for contrastive language-image
# pretraining, yet model_name is 'mllama' and run_mode is 'finetune'. Confirm
# that this is the trainer meant to drive this job.
trainer:
  model_name: 'mllama'
  type: ContrastiveLanguageImagePretrainTrainer

# Runner settings: epoch count, batch size and sink options.
runner_config:
  batch_size: 1
  epochs: 1
  sink_mode: true
  sink_size: 1

# Optimizer settings.
optimizer:
  type: AdamW
  betas: [0.9, 0.999]
  eps: 1.0e-8
  weight_decay: 0.0

# Learning-rate schedule.
lr_schedule:
  type: CosineWithWarmUpLR
  learning_rate: 1.0e-6
  lr_end: 1.0e-6
  warmup_ratio: 0.0
  # -1 means the total number of steps is loaded from the dataset.
  total_steps: -1

# Default parallel layout for 8 devices on Atlas 800T A2.
parallel_config:
  data_parallel: 2
  model_parallel: 4
  pipeline_stage: 1
  micro_batch_num: 2
  gradient_aggregation_group: 4
  use_seq_parallel: true
  vocab_emb_dp: true

# When model_parallel is greater than 1, micro_batch_interleave_num can be set
# to 2, which may speed up training.
micro_batch_interleave_num: 1

# Recompute settings.
recompute_config:
  recompute: true
  select_recompute: false
  mp_comm_recompute: true
  parallel_optimizer_comm_recompute: false
  recompute_slice_activation: true

# Training callbacks.
callbacks:
  - type: MFLossMonitor
  - type: CheckpointMonitor
    prefix: 'mllama_11b'
    save_checkpoint_steps: 20000
    integrated_save: false
    async_save: false
  - type: ObsMonitor

# Model definition: a top-level config that nests a vision model and a text
# model. The anchors defined here (seq_length, ignore_token_id, is_dynamic,
# image_size) can be reused through aliases later in the file.
model:
  model_config:
    type: MllamaConfig
    stage: 2  # 1: pretrain stage; 2: finetune stage
    freeze_vision: false
    batch_size: 1
    seq_length: &seq_length 2048
    compute_dtype: 'bfloat16'
    layernorm_compute_type: 'float32'
    softmax_compute_type: 'float32'
    rotary_dtype: 'bfloat16'
    param_init_type: 'bfloat16'
    ignore_token_id: &ignore_token_id -100
    pad_token_id: 128004
    repetition_penalty: 1
    use_past: false
    block_size: 16
    num_blocks: 512
    is_dynamic: &is_dynamic false
    use_flash_attention: true
    top_k: 0
    top_p: 0.8
    max_decode_length: 50  # 1024
    do_sample: false

    # Vision sub-model.
    vision_model:
      arch:
        type: MllamaVisionModel
      model_config:
        type: MllamaVisionConfig
        hidden_size: 1280
        intermediate_size: 5120
        num_attention_heads: 16
        image_size: &image_size 560
        patch_size: 14
        hidden_act: 'gelu'
        intermediate_layers_indices: [3, 7, 15, 23, 30]
        length_penalty: 1.0
        max_length: 20
        max_num_tiles: 4
        min_length: 0
        model_type: mllama_vision_model
        norm_eps: 1.0e-05
        num_channels: 3
        num_global_layers: 8
        num_hidden_layers: 32
        supported_aspect_ratios:
          - [1, 1]
          - [1, 2]
          - [1, 3]
          - [1, 4]
          - [2, 1]
          - [2, 2]
          - [3, 1]
          - [4, 1]
        vision_output_dim: 7680

    # Text sub-model.
    text_model:
      arch:
        type: MllamaForCausalLM
      model_config:
        type: MllamaTextConfig
        image_token_index: 128256
        bos_token_id: 128000
        cross_attention_layers: [3, 8, 13, 18, 23, 28, 33, 38]
        eos_token_id: 128001
        hidden_size: 4096
        intermediate_size: 14336
        num_heads: 32
        num_layers: 40
        n_kv_heads: 8
        output_attentions: false
        output_hidden_states: false
        max_position_embedding: 131072
        # Supported values: "None", "PI", "NTK", "LLAMA3".
        extend_method: 'LLAMA3'
        scaling_factor:
          factor: 8.0
          low_freq_factor: 1.0
          high_freq_factor: 4.0
          original_max_position_embeddings: 8192
        pad_token_id: 128004
        rms_norm_eps: 1.0e-05
        theta: 500000.0
        vocab_size: 128256
        use_past_shard: false
        checkpoint_name_or_path: ''
        do_sample: false
  arch:
    type: MllamaForConditionalGeneration


# Training dataset. The mapping is anchored as train_dataset and passed to
# train_dataset_task below through an alias.
train_dataset: &train_dataset
  data_loader:
    type: BaseMultiModalDataLoader
    annotation_file: 'path/train_data.json'
    shuffle: false
  construct_args_key:
    - 'input_ids'
    - 'labels'
    - 'pixel_values'
    - 'aspect_ratio_mask'
    - 'aspect_ratio_ids'
    - 'cross_attention_mask'
  num_parallel_workers: 1
  python_multiprocessing: false
  drop_remainder: true
  batch_size: 1
  repeat: 1
  numa_enable: false
  prefetch_size: 1
  seed: 42
  modal_to_text_transform:
    type: BaseXModalToTextTransform
    # Reuses the seq_length value anchored in the model config.
    max_length: *seq_length
    model_transform_template:
      type: MllamaProcessor
      add_special_tokens: true
      output_columns:
        - 'input_ids'
        - 'labels'
        - 'pixel_values'
        - 'aspect_ratio_mask'
        - 'aspect_ratio_ids'
        - 'cross_attention_mask'
      mode: 'train'
      ignore_token_id: *ignore_token_id
      image_mean: [0.48145466, 0.4578275, 0.40821073]
      image_std: [0.26862954, 0.26130258, 0.27577711]
      max_num_images: 1
      image_size: *image_size
  tokenizer:
    type: MllamaTokenizer
    vocab_file: 'path/tokenizer.model'
    pad_token: '<|finetune_right_pad_id|>'
    add_bos_token: true

train_dataset_task:
  type: ModalToTextSFTDataset
  dataset_config: *train_dataset

# Wrapper cell around the training step.
runner_wrapper:
  type: MFTrainOneStepCell
  use_clip_grad: true
  scale_sense: 1.0

# MindSpore context initialization.
context:
  mode: 0  # 0: graph mode; 1: PyNative mode
  device_target: 'Ascend'
  device_id: 0
  max_call_depth: 10000
  max_device_memory: '59GB'
  ascend_config:
    precision_mode: 'must_keep_origin_dtype'


# Parallel context settings.
parallel:
  # 0: data parallel; 1: semi-auto parallel; 2: auto parallel; 3: hybrid parallel
  parallel_mode: 1
  search_mode: 'sharding_propagation'
  full_batch: true
  gradients_mean: false
  enable_alltoall: false
  enable_parallel_optimizer: true