import os
import time
from functools import wraps
from logging import getLogger
import torch
import torch_npu
import megatron
from megatron.core import mpu
from megatron.training import get_args, print_rank_0
from megatron.training.arguments import validate_args
from megatron.training.yaml_arguments import validate_yaml
from megatron.training.checkpointing import load_args_from_checkpoint
from megatron.training.async_utils import init_persistent_async_worker
from megatron.training.global_vars import set_global_variables
from megatron.training.initialize import (
_initialize_distributed, _set_random_seed,
_init_autoresume, _initialize_tp_communicators,
_warmup_jit_function
)
from mindspeed.core.tensor_parallel.ascend_turbo.initialize import initialize_cfg_from_args
from mindspeed_llm.training.arguments import parse_args_decorator
from mindspeed_llm.tasks.utils.error_utils import ensure_valid
from mindspeed_llm.training.utils import seed_all, is_shared_path
from mindspeed_llm.training.checkpointing import _convert_weights_if_needed
from mindspeed_llm.core.datasets.dataset_preprocess import convert_datasets, _is_raw_data_path
logger = getLogger(__name__)
def _compile_dependencies():
"""
Compile dataset index builder dependencies on the first rank of each node.
This function compiles the C++/CUDA helpers for dataset indexing on one GPU
per node to avoid redundant compilation across all ranks.
Note:
Only the first rank on each node (rank % device_count == 0) performs
the compilation to save time in multi-GPU setups.
"""
device_count = torch.cuda.device_count()
if device_count == 0:
raise ZeroDivisionError
if torch.distributed.get_rank() % device_count == 0:
start_time = time.time()
print('> compiling dataset index builder ...')
from megatron.core.datasets.utils import compile_helpers
compile_helpers()
print('>>> done with dataset index builder. Compilation time: {:.3f} '
'seconds'.format(time.time() - start_time), flush=True)
def initialize_megatron(
extra_args_provider=None,
args_defaults={},
ignore_unknown_args=False,
allow_no_cuda=False,
skip_mpu_initialization=False,
get_embedding_ranks=None,
get_position_embedding_ranks=None,
):
"""Set global variables, initialize distributed, and
set autoresume and random seeds.
`allow_no_cuda` should not be set unless using megatron for cpu only
data processing. In general this arg should not be set unless you know
what you are doing.
Returns a function to finalize distributed env initialization
(optionally, only when args.lazy_mpu_init == True)
"""
if not allow_no_cuda:
ensure_valid(torch.cuda.is_available(), "Megatron requires CUDA.")
parse_args = parse_args_decorator(megatron.training.arguments.parse_args)
args = parse_args(extra_args_provider, ignore_unknown_args)
if args.use_checkpoint_args or args_defaults.get("use_checkpoint_args", False):
ensure_valid(args.load is not None,
"--use-checkpoints-args requires --load argument")
load_args_from_checkpoint(args)
if args.async_save and args.use_persistent_ckpt_worker:
init_persistent_async_worker()
if args.yaml_cfg is not None:
args = validate_yaml(args, args_defaults)
else:
validate_args(args, args_defaults)
set_global_variables(args)
if args.use_deter_comp:
seed_all(args.seed)
print_rank_0("deterministic computing is applied for npu.")
def finish_mpu_init():
args = get_args()
_initialize_distributed(get_embedding_ranks, get_position_embedding_ranks)
if args.rank == 0:
print("> setting random seeds to {} ...".format(args.seed))
_set_random_seed(args.seed, args.data_parallel_random_init)
if args.use_ascend_mc2:
initialize_cfg_from_args(args)
if skip_mpu_initialization:
return None
args = get_args()
if args.lazy_mpu_init:
args.use_cpu_initialization = True
mpu.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
mpu.set_tensor_model_parallel_rank(args.rank)
return finish_mpu_init
else:
finish_mpu_init()
_init_autoresume()
_compile_dependencies()
if args.tp_comm_overlap:
_initialize_tp_communicators()
return None
def set_jit_fusion_options():
"""Set PyTorch JIT layer fusion options."""
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(True)
torch._C._jit_override_can_fuse_on_cpu(False)
torch._C._jit_override_can_fuse_on_gpu(False)
torch._C._jit_set_texpr_fuser_enabled(False)
torch._C._debug_set_autodiff_subgraph_inlining(False)
_warmup_jit_function()
args = get_args()
if args.jit_compile:
torch_npu.npu.set_compile_mode(jit_compile=True)
def coc_registration_wrapper(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
res = fn(*args, **kwargs)
from mindspeed.core.tensor_parallel.lcal_coc.user_config import initialize_coc_from_cfg
args = get_args()
initialize_coc_from_cfg(args)
return res
return wrapper
def initialize_megatron_wrapper(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
result = fn(*args, **kwargs)
args = get_args()
data_path = getattr(args, "data_path", None)
if data_path:
if isinstance(data_path, (list, tuple)):
raw_path = str(data_path[0])
else:
raw_path = str(data_path).split(",")[0]
raw_path = raw_path.strip().strip('"').strip("'")
if _is_raw_data_path(raw_path):
logger.info("[InitHook] Megatron initialization completed. Starting data preprocessing...")
if os.path.isfile(raw_path):
base_dir = os.path.dirname(raw_path)
elif os.path.isdir(raw_path):
base_dir = raw_path
else:
base_dir = os.path.dirname(raw_path) or "."
shared = is_shared_path(base_dir)
convert_datasets(args, shared)
logger.info("[InitHook] Data preprocessing finished.")
else:
logger.info(f"[InitHook] data_path={raw_path} is not raw. Skip preprocessing.")
else:
logger.info("[InitHook] args.data_path is empty. Skip preprocessing.")
if getattr(args, 'enable_hf2mg_convert', False):
logger.info("[InitHook] Starting weight conversion check...")
if not os.path.exists(args.load):
raise ValueError(f"Specified weight path does not exist: {args.load}")
files = os.listdir(args.load)
if not (
any(f == 'config.json' for f in files) and
any(f.endswith(('.bin', '.safetensors')) and 'model' in f.lower() for f in files)
):
raise ValueError(
f"When enable_hf2mg_convert is enabled, path {args.load} is not a valid HuggingFace path."
)
supported_models = [
'llama2', 'qwen3', 'qwen3-moe', 'deepseek3', 'glm45-air', 'glm45', 'bailing_mini',
'qwen3-next', 'seed-oss', 'deepseek32', 'magistral', 'deepseek2-lite', 'mamba2'
]
if args.model_type_hf not in supported_models:
raise ValueError(
f"Current --enable-hf2mg-convert does not support model type '{args.model_type_hf}'. "
f"Supported models: {', '.join(supported_models)}"
)
if not getattr(args, "mg_save_dir", None):
def _safe_int(attr_name):
val = getattr(args, attr_name, None)
try:
return int(val) if val not in (None, "", 0) else 1
except Exception:
return 1
tp = _safe_int("tensor_model_parallel_size")
pp = _safe_int("pipeline_model_parallel_size")
ep = _safe_int("expert_model_parallel_size")
args.mg_save_dir = os.path.join(args.load, f"megatron_cache_tp{tp}pp{pp}ep{ep}")
os.makedirs(args.mg_save_dir, exist_ok=True)
logger.info(f"[InitHook] Conversion cache path: {args.mg_save_dir}")
shared = is_shared_path(args.mg_save_dir)
logger.info(f"[InitHook] save_dir={args.mg_save_dir}, shared_storage={shared}")
_convert_weights_if_needed(args, shared)
args.load = args.mg_save_dir
logger.info("[InitHook] Weight conversion phase completed.")
if getattr(args, 'enable_mg2hf_convert', False):
shared = is_shared_path(args.save)
if not shared:
raise ValueError(
f"When enable_mg2hf_convert is enabled, path {args.save} should be shared path."
)
return result
return wrapper