from functools import wraps
import torch_npu
from megatron.training.utils import print_rank_0
from mindspeed.args_utils import get_full_args as get_args
from mindspeed.core.hccl_buffer.hccl_adaptive_func import (
hccl_buffer_auto_adaptive,
parse_hccl_buffer_string,
_HCCL_GROUP_BUFFER,
)
def get_nccl_options_wrapper(get_nccl_options):
    """Wrap ``get_nccl_options`` to supply a per-group HCCL buffer size.

    When ``args.hccl_group_buffer_adaptive`` is truthy and
    ``_HCCL_GROUP_BUFFER`` holds a non-``None`` entry for ``pg_name``, the
    wrapper returns a ``ProcessGroupHCCL.Options`` object whose
    ``hccl_config`` carries that entry as ``"hccl_buffer_size"``. In every
    other case the call is forwarded unchanged to ``get_nccl_options``.

    Args:
        get_nccl_options: The callable being wrapped; invoked as
            ``get_nccl_options(pg_name, nccl_comm_cfgs)``.

    Returns:
        The wrapping function, with the wrapped callable's metadata.
    """

    @wraps(get_nccl_options)
    def wrapper(pg_name, nccl_comm_cfgs):
        # NOTE(review): the buffer table is only consulted when the adaptive
        # flag is on, so sizes registered via ``hccl_buffer_set_wrapper``
        # alone are not picked up here -- confirm that is intended.
        if not get_args().hccl_group_buffer_adaptive:
            return get_nccl_options(pg_name, nccl_comm_cfgs)

        # The table is only read, so no ``global`` declaration is required.
        buffer_size = _HCCL_GROUP_BUFFER.get(pg_name)
        if buffer_size is None:
            # No size recorded for this group: use the wrapped behaviour.
            return get_nccl_options(pg_name, nccl_comm_cfgs)

        hccl_options = torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options()
        hccl_options.hccl_config = {"hccl_buffer_size": buffer_size}
        return hccl_options

    return wrapper
def hccl_buffer_adaptive_wrapper(initialize_model_parallel):
    """Wrap ``initialize_model_parallel`` with adaptive HCCL buffer setup.

    If ``hccl_group_buffer_adaptive`` is truthy on the full argument
    namespace, ``hccl_buffer_auto_adaptive`` is called with that namespace
    and the current ``_HCCL_GROUP_BUFFER`` is printed through
    ``print_rank_0``. The wrapped function is then always called with the
    original arguments and its result returned.

    Args:
        initialize_model_parallel: The callable being wrapped.

    Returns:
        The wrapping function, with the wrapped callable's metadata.
    """

    @wraps(initialize_model_parallel)
    def wrapper(*args, **kwargs):
        full_args = get_args()
        if full_args.hccl_group_buffer_adaptive:
            # Presumably fills _HCCL_GROUP_BUFFER before the process groups
            # are created -- verify against hccl_adaptive_func.
            hccl_buffer_auto_adaptive(full_args)
            print_rank_0(f"hccl_group_buffer_adaptive: {_HCCL_GROUP_BUFFER}")
        return initialize_model_parallel(*args, **kwargs)

    return wrapper
def hccl_buffer_set_wrapper(initialize_model_parallel):
    """Wrap ``initialize_model_parallel`` with explicit HCCL buffer setup.

    If ``hccl_group_buffer`` on the full argument namespace is not ``None``,
    it is passed to ``parse_hccl_buffer_string`` and the current
    ``_HCCL_GROUP_BUFFER`` is printed through ``print_rank_0``. The wrapped
    function is then always called with the original arguments and its
    result returned.

    Args:
        initialize_model_parallel: The callable being wrapped.

    Returns:
        The wrapping function, with the wrapped callable's metadata.
    """

    @wraps(initialize_model_parallel)
    def wrapper(*args, **kwargs):
        buffer_spec = get_args().hccl_group_buffer
        if buffer_spec is not None:
            # Presumably records the parsed sizes in _HCCL_GROUP_BUFFER --
            # verify against hccl_adaptive_func.
            parse_hccl_buffer_string(buffer_spec)
            print_rank_0(f"hccl_group_buffer_set: {_HCCL_GROUP_BUFFER}")
        return initialize_model_parallel(*args, **kwargs)

    return wrapper