---
# Sharding / mixed-precision configuration.
# NOTE(review): key names mirror PyTorch FSDP2 (`fully_shard`) options;
# confirm exact semantics against the loader that consumes this file.

# Size of the sharding group. "auto" is read as a plain string;
# presumably the framework picks the size itself (TODO confirm).
sharding_size: auto

# Module paths wrapped as individual shard units. `{*}` looks like a
# wildcard over layer indices (TODO confirm). Entries containing braces are
# quoted so they always load as strings, never as flow-mapping syntax.
sub_modules_to_wrap:
- text_decoder.output_layer
- text_decoder.embedding
- text_decoder.rotary_pos_emb
- 'text_decoder.decoder.layers.{*}'

# Mixed precision: presumably the compute dtype for parameters and the
# dtype used for gradient reduction, respectively (TODO confirm).
param_dtype: bf16
reduce_dtype: fp32
# Presumably casts forward inputs to param_dtype (TODO confirm).
cast_forward_inputs: true

# Modules left out of wrapping (TODO confirm whether they are also kept
# unsharded).
ignored_modules:
- image_encoder

# Modules selected for recomputation, presumably activation
# checkpointing (TODO confirm).
recompute_modules:
- 'text_decoder.decoder.layers.{*}'

# Number of modules to prefetch ahead in forward/backward; 0 presumably
# disables explicit prefetching (TODO confirm).
num_to_forward_prefetch: 0
num_to_backward_prefetch: 0

# CPU offloading toggle.
offload_to_cpu: false