from functools import wraps
import torch_npu
def get_nccl_options_add_group_info_wrapper(get_nccl_options):
@wraps(get_nccl_options)
def wrapper(pg_name, nccl_comm_cfgs):
options = get_nccl_options(pg_name, nccl_comm_cfgs)
if hasattr(torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options, 'hccl_config'):
options = options if options is not None else torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options()
try:
hccl_config = options.hccl_config
hccl_config.update({'group_name': pg_name})
options.hccl_config = hccl_config
except TypeError as e:
pass
return options
return wrapper