from functools import wraps
import torch
from torch.distributed.tensor import distribute_tensor
from mindspeed_mm.models.transformers.global_vars import get_ep_fsdp_group, get_check_moe_func
def finalize_model_grads_wrapper(fn):
@wraps(fn)
def wrapper(model, num_tokens=None):
fn(model, num_tokens)
ep_fsdp_group = get_ep_fsdp_group()
check_moe_func = get_check_moe_func()
if ep_fsdp_group is None:
return
for model_chunk in model:
for name, param in model_chunk.named_parameters():
if param.grad is None or not check_moe_func(name):
continue
local_grad = param.grad.to_local()
torch.distributed.all_reduce(
local_grad,
op=torch.distributed.ReduceOp.AVG,
group=ep_fsdp_group,
async_op=True,
)
dtensor = distribute_tensor(
local_grad,
device_mesh=param.grad.device_mesh,
placements=param.grad.placements,
)
param.grad.copy_(dtensor)
torch.npu.current_stream().synchronize()
return wrapper