# Sharding and mixed-precision settings, one key per line.
# NOTE(review): the key names look like PyTorch FSDP2 (`fully_shard` /
# `MixedPrecisionPolicy`) options - confirm against the code that loads
# this file before relying on that interpretation.
reshard_after_forward: true

# Dtype names are kept as quoted strings so they are never retyped by the
# YAML parser; the loader is presumably responsible for mapping them to
# real dtypes - TODO confirm.
param_dtype: "bf16"
reduce_dtype: "fp32"
output_dtype: "bf16"
cast_forward_inputs: true

# Integer count; presumably how many upcoming units to prefetch during
# forward - verify the exact semantics with the consumer.
num_to_forward_prefetch: 2