import torch
from torch.distributed.distributed_c10d import ReduceOp
def hccl_premul_sum_wrapper(op, output_name):
def wrapper(*args, **kwargs):
factor = None
if 'op' in kwargs and kwargs['op'] == ReduceOp.PREMUL_SUM:
factor = kwargs['op'].__getstate__()[1]
kwargs['op'] = ReduceOp.SUM
handle = op(*args, **kwargs)
if handle is not None:
handle.wait()
if factor is not None:
output = args[0] if len(args) > 0 else kwargs[output_name]
output.data.mul_(factor)
return handle
return wrapper
def apply_hccl_premul_sum_patch():
torch.distributed.all_reduce = hccl_premul_sum_wrapper(torch.distributed.all_reduce, 'tensor')
torch.distributed.reduce_scatter = hccl_premul_sum_wrapper(torch.distributed.reduce_scatter, 'output')
torch.distributed.reduce_scatter_tensor = hccl_premul_sum_wrapper(torch.distributed.reduce_scatter_tensor, 'output')