import functools
import inspect

import torch
from torch.distributed.distributed_c10d import ReduceOp
"""
When using expert parallelism (EP) together with set_gradient_divide_factor/set_reduce_scatter_divide_factor for gradient
reduction in MoE layers, you must manually import apply_hccl_premul_sum_patch from this file and call it.
TODO: Remove this patch once torch_npu supports the PREMUL_SUM reduce operator.
"""
def hccl_premul_sum_wrapper(op, output_name):
    """Wrap a ``torch.distributed`` collective so it tolerates ``ReduceOp.PREMUL_SUM``.

    Huawei HCCL does not support ``ReduceOp.PREMUL_SUM``. When the wrapped
    collective is called with such an op (positionally or as ``op=``), the
    wrapper extracts its scale factor, runs the collective with
    ``ReduceOp.SUM`` instead, waits for it to complete, and then multiplies the
    output tensor in place by the factor. Calls with any other op are forwarded
    to the collective unchanged.

    Args:
        op: The collective to wrap (e.g. ``torch.distributed.all_reduce``).
        output_name: Name of the collective's parameter that holds the tensor
            receiving the reduced result (``"tensor"`` for ``all_reduce``,
            ``"output"`` for the reduce-scatter variants).

    Returns:
        A callable with the same signature as ``op`` that returns whatever
        ``op`` returns (a work handle or ``None``).

    Note:
        For a ``PREMUL_SUM`` call with ``async_op=True`` the returned handle has
        already been waited on, because the scaling can only be applied once
        the reduction has finished; such calls therefore behave synchronously.
    """
    try:
        signature = inspect.signature(op)
    except (TypeError, ValueError):
        # Not introspectable: only a keyword ``op=`` can be detected.
        signature = None

    def _bind(args, kwargs):
        """Bind one call to ``op``'s signature, or return None if that is impossible."""
        if signature is None:
            return None
        try:
            return signature.bind(*args, **kwargs)
        except TypeError:
            # Malformed call; let the collective itself raise its own error.
            return None

    @functools.wraps(op)
    def wrapper(*args, **kwargs):
        bound = _bind(args, kwargs)
        # Binding resolves ``op`` whether it was passed positionally or by keyword.
        arguments = bound.arguments if bound is not None else kwargs
        reduce_op = arguments.get("op")
        is_premul_sum = reduce_op is not None and reduce_op == ReduceOp.PREMUL_SUM
        if not is_premul_sum:
            # Regular collective: forward untouched, without forcing a wait.
            return op(*args, **kwargs)

        # Assumes the PREMUL_SUM ReduceOp's state is (op_type, factor).
        factor = reduce_op.__getstate__()[1]
        arguments["op"] = ReduceOp.SUM
        if bound is not None:
            handle = op(*bound.args, **bound.kwargs)
            output = bound.arguments[output_name]
        else:
            handle = op(*args, **kwargs)
            output = args[0] if args else kwargs[output_name]
        if handle is not None:
            # async_op=True: the reduced result must be ready before scaling it.
            handle.wait()
        output.data.mul_(factor)
        return handle

    return wrapper
def apply_hccl_premul_sum_patch():
    """Replace selected ``torch.distributed`` collectives with PREMUL_SUM-safe wrappers.

    Patches ``all_reduce``, ``reduce_scatter`` and ``reduce_scatter_tensor`` on
    the ``torch.distributed`` module using ``hccl_premul_sum_wrapper``. Calling
    this function more than once is safe: collectives that were already patched
    are left as they are instead of being wrapped again.

    Only lookups through the ``torch.distributed`` module attribute see the
    patch; a reference bound before patching (e.g. via
    ``from torch.distributed import all_reduce``) keeps the original function.
    """
    marker = "_hccl_premul_sum_patched"
    # (attribute on torch.distributed, name of the collective's output-tensor parameter)
    targets = (
        ("all_reduce", "tensor"),
        ("reduce_scatter", "output"),
        ("reduce_scatter_tensor", "output"),
    )
    for fn_name, output_name in targets:
        original = getattr(torch.distributed, fn_name)
        if getattr(original, marker, False):
            continue  # Already wrapped by an earlier call; avoid stacking wrappers.
        patched = hccl_premul_sum_wrapper(original, output_name)
        setattr(patched, marker, True)
        setattr(torch.distributed, fn_name, patched)