# Commit a02f348e, created on 2024-12-07 (commit history).
general:
  # Run-wide settings: backend framework, target device class, logging
  # verbosity and task identification.
  backend: mindspore
  device_category: NPU
  logger:
    level: info
  task:
    local_base_path: './'
    task_id: ms_quant_sample_pt
  device_evaluate_before_train: false

register:
  # NOTE(review): presumably the directories searched for the modules listed
  # below — confirm against the loader.
  pkg_path:
    - './'
  modules:
    # Module to import.
    - module: 'sample_net_mindspore'
      script_network:
        - 'get_model'

pipeline:
  # Single step; its name matches the top-level `nas` section of this file.
  - nas

nas:
  # Step type for the `nas` entry listed under `pipeline`.
  pipe_step:
    type: 'SearchPipeStep'
  model:
    model_desc:
      type: 'ScriptModelGen'
      common:
        # Same name as the `script_network` entry under `register.modules`.
        network: 'get_model'
    # NOTE(review): looks like (batch, channels, height, width), matching the
    # dataset's batch_size / channel_size / image_size — confirm.
    input_shape: [1, 3, 224, 224]
  dataset:
    type: 'RandomDataset'
    common:
      image_size: 224
      batch_size: 1
      channel_size: 3
      # NOTE(review): presumably the number of generated samples — confirm.
      img_len: 1000
  search_algorithm:
    type: 'MsQuantRL'
    codec: 'QuantRLCodec'
    policy:
      # Maximum number of episodes. Values above 100 are recommended: bigger
      # is better, but learning takes longer.
      max_episode: 30
      # Warm-up phase: no training, only filling the replay memory.
      # Recommended: 10-20.
      num_warmup: 10
    # `accuracy` must be one of the objective keys.
    objective_keys:
      - 'accuracy'
      - 'compress_ratio'
      - 'flops'
    # Choices: acc_first | compress_first
    reward_type: compress_first
    custom_reward: false
    # Metric used to calculate the reward.
    metric_to_reward: 'flops'
    metric_ratio: 0.5
    stop_early: false
    acc_threshold: 0.5
    latency_threshold: 5
    compress_threshold: 40

  search_space:
    type: 'SearchSpace'
    hyperparameters:
      # NOTE(review): 8 and 32 are presumably the candidate bit widths —
      # confirm against the codec.
      - key: 'network.bit_candidates'
        type: 'CATEGORY'
        range: [8, 32]
  trainer:
    type: 'QuantTrainer'
    epochs: 1
    seed: 234
    calib_portion: 0.1
    callbacks:
      - OptExportCallback
    # NOTE(review): the absolute /automl/resnet/ paths below are
    # environment-specific — confirm they exist on the target machine.
    custom_calib:
      pkg_path: '/automl/resnet/'
      path: '/automl/resnet/train.py'
      func: 'train_func'
    custom_eval:
      pkg_path: '/automl/resnet/'
      path: '/automl/resnet/train.py'
      func: 'eval_func'
      metric_name: 'accuracy_top1'