# Sharding, wrapping, precision and recomputation settings (per the key names).
# NOTE(review): the per-key notes below are inferred from the key names only;
# nothing in this file defines their semantics. Confirm each against the code
# that loads this config before relying on them.

# Presumably the number of ranks in one sharding group - TODO confirm.
sharding_size: 16

# Module-path patterns selected for wrapping. "{*}" looks like a wildcard that
# the consumer expands (e.g. one entry per block) - verify against the loader.
# Quoted so the braces and asterisk can never be read as YAML flow/alias
# indicators; the loaded string is unchanged.
sub_modules_to_wrap:
- "predictor.blocks.{*}"
- predictor.head

# Presumably controls whether parameters are re-sharded once forward finishes
# - TODO confirm.
reshard_after_forward: true

# Dtype names as strings; the accepted spellings are defined by the consumer.
# NOTE(review): looks like a mixed-precision pair (compute vs. reduction dtype)
# - confirm.
param_dtype: "bf16"
reduce_dtype: "fp32"

# Module names presumably excluded from sharding/wrapping - TODO confirm.
ignored_modules:
- ae
- text_encoder

# Presumably casts forward inputs to the parameter dtype - TODO confirm.
cast_forward_inputs: true

# Module-path patterns presumably selected for activation recomputation
# - TODO confirm. Same "{*}" pattern syntax as in sub_modules_to_wrap.
recompute_modules:
- "predictor.blocks.{*}"

# Prefetch depths; 0 presumably means no explicit prefetching - verify.
num_to_forward_prefetch: 0
num_to_backward_prefetch: 0

# Presumably toggles offloading of sharded state to host memory - TODO confirm.
offload_to_cpu: false