# Commit f77d0c52, created on 2025-05-22 (commit history).
import time
from functools import wraps
import torch
import torch_npu
from megatron.training import get_args
from megatron.training.utils import print_rank_0
from megatron.training.initialize import _warmup_jit_function
from mindspeed.core.tensor_parallel.ascend_turbo.initialize import initialize_cfg_from_args
from .utils import extend_seed_all


def _compile_dependencies():
    """Run ``compile_helpers`` on distributed rank 0 and print the elapsed time.

    Every rank other than 0 returns immediately without doing any work.
    """
    if torch.distributed.get_rank() != 0:
        return

    began = time.time()
    print('> compiling dataset index builder ...')
    # The import sits on the rank-0 path, so other ranks never execute it.
    from megatron.core.datasets.utils import compile_helpers

    compile_helpers()
    elapsed = time.time() - began
    print('>>> done with dataset index builder. Compilation time: {:.3f} '
          'seconds'.format(elapsed), flush=True)


def set_jit_fusion_options_wrapper(fn):
    """Decorate ``fn`` with an nvfuser stub and an optional NPU JIT-compile switch.

    The returned wrapper:

    1. replaces ``torch._C._jit_set_nvfuser_enabled`` with a function that
       does nothing, so calls to it made by ``fn`` have no effect;
    2. calls ``fn`` with the given arguments (its return value is discarded);
    3. calls ``torch_npu.npu.set_compile_mode(jit_compile=True)`` when the
       global args have a truthy ``jit_compile``.

    The wrapper itself always returns ``None``.
    """
    def _ignore_nvfuser_toggle(option):
        # Deliberately empty: the nvfuser switch becomes a no-op.
        return None

    @wraps(fn)
    def wrapper(*args, **kwargs):
        torch._C._jit_set_nvfuser_enabled = _ignore_nvfuser_toggle
        fn(*args, **kwargs)
        # Global args are read only after ``fn`` has run.
        if not get_args().jit_compile:
            return
        torch_npu.npu.set_compile_mode(jit_compile=True)

    return wrapper


def coc_registration_wrapper(fn):
    """Decorate ``fn`` so ``initialize_coc_from_cfg`` runs on the global args afterwards.

    The wrapper calls ``fn`` first, then passes the result of ``get_args()`` to
    ``initialize_coc_from_cfg``, and finally returns what ``fn`` returned.
    """
    @wraps(fn)
    def wrapper(*args, **kwargs):
        outcome = fn(*args, **kwargs)
        # Imported at call time, after the wrapped function has finished.
        from mindspeed.core.tensor_parallel.lcal_coc.user_config import initialize_coc_from_cfg
        initialize_coc_from_cfg(get_args())
        return outcome

    return wrapper


def mc2_wrapper(fn):
    """Decorate ``fn`` so ``initialize_cfg_from_args`` runs on the global args afterwards.

    The wrapper calls ``fn`` first, then passes the result of ``get_args()`` to
    ``initialize_cfg_from_args``, and finally returns what ``fn`` returned.
    """
    @wraps(fn)
    def wrapper(*args, **kwargs):
        outcome = fn(*args, **kwargs)
        initialize_cfg_from_args(get_args())
        return outcome

    return wrapper


def deter_comp_wrapper(fn):
    """Decorate a seed-setting function so ``extend_seed_all`` runs after it.

    The returned wrapper calls ``fn`` with the seed and the caller's
    ``data_parallel_random_init`` value, then calls ``extend_seed_all(seed_)``
    and logs a notice via ``print_rank_0``. It returns ``None``.

    Args (of the returned wrapper):
        seed_: Seed passed both to ``fn`` and to ``extend_seed_all``.
        data_parallel_random_init: Forwarded to ``fn`` unchanged.
        te_rng_tracker, inference_rng_tracker, use_cudagraphable_rng:
            Accepted so callers can pass them, but not forwarded to ``fn``.
            NOTE(review): confirm whether ``fn`` should receive these as well.
    """
    @wraps(fn)
    def wrapper(seed_, data_parallel_random_init=False, te_rng_tracker=False, inference_rng_tracker=False, use_cudagraphable_rng=False):
        # Pass through the caller's flag; a hard-coded False here would
        # silently ignore an explicit ``data_parallel_random_init=True``.
        fn(seed_, data_parallel_random_init=data_parallel_random_init)
        extend_seed_all(seed_)
        print_rank_0("deterministic computing is applied for npu.")
    return wrapper