# Sharding, mixed-precision and recomputation settings for the `predictor`
# sub-modules listed below.
# NOTE(review): the key names resemble PyTorch FSDP2 options (`fully_shard`,
# `MixedPrecisionPolicy`). The code that consumes this file is not visible
# here, so every semantic note below is an assumption to confirm against it.

# Presumably the number of ranks in one sharding group -- TODO confirm.
sharding_size: 8

# Module paths that are wrapped individually.
# `{*}` is presumably a wildcard that the loader expands over the indexed
# children of the named container -- TODO confirm. It is safe as a plain
# scalar because the value does not start with a YAML indicator character.
sub_modules_to_wrap:
- predictor.wan_dit.blocks.{*}
- predictor.wan_dit.head
- predictor.vace_dit.vace_blocks.{*}

# Booleans in this file use the canonical lowercase spelling: some strict
# YAML 1.2 loaders only recognise `true` / `false` as booleans.
reshard_after_forward: true

# Dtype names are kept as quoted strings; the loader presumably maps them to
# framework dtypes -- TODO confirm the accepted spellings.
param_dtype: "bf16"
reduce_dtype: "fp32"

offload_to_cpu: false
cast_forward_inputs: true

# Same two block patterns as in `sub_modules_to_wrap`; `predictor.wan_dit.head`
# is wrapped above but not listed here.
recompute_modules:
- predictor.wan_dit.blocks.{*}
- predictor.vace_dit.vace_blocks.{*}