---
# Sharding / mixed-precision settings.
# NOTE(review): the keys look like the arguments of PyTorch FSDP2
# `fully_shard` and `MixedPrecisionPolicy`; confirm against the code that
# loads this file.

# Presumably: free unsharded params after forward and re-gather them for
# backward (saves memory at the cost of extra communication) — confirm.
reshard_after_forward: true

# Dtypes are given as strings; presumably the loader maps "bf16" -> bfloat16
# and "fp32" -> float32 — TODO confirm the accepted spellings.
# Presumably the dtype parameters are cast to for forward/backward compute.
param_dtype: "bf16"

# Presumably the dtype used for gradient reduction — confirm.
reduce_dtype: "fp32"

# Presumably the dtype forward outputs are cast to — confirm.
output_dtype: "bf16"

# Presumably: cast floating-point forward inputs to param_dtype — confirm.
cast_forward_inputs: true

# Presumably how many upcoming modules to prefetch (all-gather early) during
# forward — confirm against the consumer.
num_to_forward_prefetch: 2