from mindspeed_mm.fsdp.features.memory.grad_offload import GradientOffload, GradientRestore
"""
Gradient Offload Wrapper for Auxiliary Loss Management
Purpose: Manage gradient offloading for auxiliary loss computation.
Usage:
1. In forward propagation: Call `restore_wrapper` where router_logits are generated
2. In loss computation: Call `offload_wrapper` where router_logits are used for auxiliary loss
Workflow:
Generate router_logits → restore_wrapper → ... → offload_wrapper → Compute auxiliary loss
(Restore gradients) (Offload gradients to CPU)
"""
_ROUTER_LOGITS_ID = 0
def offload_wrapper(router_logits):
gate_logits = list(router_logits)
num_hidden_layers = len(gate_logits)
for i, layer_gate in enumerate(gate_logits):
if i == num_hidden_layers - 1:
continue
gate_logits[i] = GradientOffload.apply(layer_gate, i)
return tuple(gate_logits)
def restore_wrapper(router_logits, num_hidden_layers):
global _ROUTER_LOGITS_ID
if _ROUTER_LOGITS_ID == num_hidden_layers - 1:
return router_logits
prefetch_keys = (_ROUTER_LOGITS_ID - 1,) if _ROUTER_LOGITS_ID > 0 else None
router_logits = GradientRestore.apply(router_logits, _ROUTER_LOGITS_ID, prefetch_keys)
_ROUTER_LOGITS_ID += 1
_ROUTER_LOGITS_ID = _ROUTER_LOGITS_ID % num_hidden_layers
return router_logits