"""General utilities."""
import os
import json
import re
import time
import stat
import random
import warnings
from functools import wraps
from logging import getLogger
from typing import Optional, Union, List
from itertools import takewhile
from packaging.version import Version as PkgVersion
from collections import defaultdict
import acl
import torch
import torch_npu
import socket
from torch import distributed as dist
import numpy as np
import megatron
from megatron.training import get_args
from megatron.core import mpu
from megatron.core.parallel_state import is_pipeline_last_stage
from tqdm import tqdm
from mindspeed.core.context_parallel.utils import pad_data
from mindspeed.core.context_parallel.get_batch_utils import set_actual_seq_len
from mindspeed.utils import (set_position_ids,
_get_batch_on_this_cp_rank_in_megatron_cp,
_get_batch_on_this_cp_rank_in_hybrid_cp_general,
_get_batch_on_this_cp_rank_in_hybrid_cp,
broadcast_dynamic, _broadcast, get_ring_degree)
from mindspeed.core.tensor_parallel_y_union_cp import TensorParallelYUnionCP
from mindspeed.model.transformer import set_attention_mask
from mindspeed.utils import _get_batch_on_this_tp_y_cp_rank_in_megatron_cp
from mindspeed_llm.tasks.dataset.shared_memory_manager import SharedMemoryManager
# Optional import: when the dualpipev schedules cannot be imported the error is
# ignored and `get_post_process_flag` is simply left undefined in this module.
try:
    from mindspeed.core.pipeline_parallel.dualpipev.dualpipev_schedules import get_post_process_flag
except Exception:
    pass
# Parse the installed torch version once; fall back to "0.0.0" when the
# version string cannot be parsed.
try:
    _torch_version = PkgVersion(torch.__version__)
except Exception:
    _torch_version = PkgVersion("0.0.0")
logger = getLogger(__name__)
# Looked up by check_model_inputs() with str(self.__class__) as the key.
_CAN_RECORD_REGISTRY = {}
# Default flags / permission bits for writing files (owner read + write).
WRITE_FILE_DEFAULT_FLAGS = os.O_WRONLY | os.O_CREAT
WRITE_FILE_DEFAULT_MODES = stat.S_IWUSR | stat.S_IRUSR
# Module-level state read and written by the set_*/get_* helpers below.
_MTP_POSITION_ID = None
_MTP_BATCH_LIST = None
_ACTUAL_SEQ_LEN_LIST = None
_ACTUAL_ATTN_RATIO_LIST = []
_ACTUAL_COUNT = 0
# Maps a normalized Hugging Face architecture name to a model type name;
# consulted by infer_model_type_from_hf_config() before matching `choices`.
ARCH_ALIAS_MAP = {
    "bailingmoev2": "bailing_mini",
    "phi3": "phi3.5",
    "glm4moe": "glm45-moe"
}
def get_attn_ratio(actual_seq_len, seq_length):
    """Return ``0.5 * sum(len_i ** 2) / seq_length ** 2`` for a packed sample.

    Args:
        actual_seq_len: Cumulative end offsets of the packed sub-sequences;
            the individual lengths are recovered as consecutive differences,
            with the first offset taken as the first length.
        seq_length: Total sequence length used for normalisation.

    Returns:
        The ratio described above.
    """
    ends = np.array(actual_seq_len[1:])
    starts = np.array(actual_seq_len[:-1])
    tail_lengths = (ends - starts).tolist()
    lengths = np.array([actual_seq_len[0], *tail_lengths])
    squared_total = sum(lengths * lengths)
    return 0.5 * squared_total / (seq_length * seq_length)
def clear_actual_attn_ratio():
    """Reset the recorded attention ratios to an empty list and the counter to 0."""
    global _ACTUAL_ATTN_RATIO_LIST, _ACTUAL_COUNT
    _ACTUAL_ATTN_RATIO_LIST, _ACTUAL_COUNT = [], 0
def set_actual_seq_len_list(actual_seq_len):
    """Store ``actual_seq_len`` and, when throughput logging applies, record its attention ratio.

    The ratio is appended (and the counter incremented) only when
    ``actual_seq_len`` is not None, ``args.log_throughput`` is truthy and this
    rank is on the last pipeline stage.
    """
    global _ACTUAL_SEQ_LEN_LIST, _ACTUAL_ATTN_RATIO_LIST, _ACTUAL_COUNT
    _ACTUAL_SEQ_LEN_LIST = actual_seq_len
    args = get_args()
    if actual_seq_len is None:
        return
    # Keep the short-circuit order: the pipeline-stage query only runs when
    # throughput logging is enabled.
    if not (args.log_throughput and is_pipeline_last_stage()):
        return
    ratio = get_attn_ratio(actual_seq_len, args.seq_length)
    _ACTUAL_ATTN_RATIO_LIST.append(ratio)
    _ACTUAL_COUNT += 1
def get_actual_seq_len_list():
    """Return the value last stored by ``set_actual_seq_len_list`` (initially None)."""
    # Reading a module global needs no ``global`` declaration.
    return _ACTUAL_SEQ_LEN_LIST
def get_actual_attn_ratio():
    """Return ``(recorded attention ratios, number of recorded ratios)``."""
    recorded = _ACTUAL_ATTN_RATIO_LIST
    count = _ACTUAL_COUNT
    return recorded, count
def set_mtp_batch_list(mtp_batch_list):
    """Store ``mtp_batch_list`` in the module-level ``_MTP_BATCH_LIST`` slot."""
    global _MTP_BATCH_LIST
    _MTP_BATCH_LIST = mtp_batch_list
def get_mtp_batch_list():
    """Return the value last stored by ``set_mtp_batch_list`` (initially None)."""
    return _MTP_BATCH_LIST
def set_mtp_position_ids(position_ids_mtp):
    """Store the MTP position-id payload for later retrieval by ``get_mtp_position_ids``."""
    global _MTP_POSITION_ID
    _MTP_POSITION_ID = position_ids_mtp
def get_torch_version():
    """Return the module-level parsed torch version object."""
    return _torch_version
def get_mtp_position_ids():
    """Return the stored MTP position-id payload.

    Raises:
        AssertionError: If nothing has been stored yet (the slot is None).
    """
    if _MTP_POSITION_ID is None:
        raise AssertionError("_MTP_POSITION_ID is None")
    return _MTP_POSITION_ID
def _compute_actual_seq_len(origin_seq):
seq = origin_seq.view(-1)
zero_pos = (seq == 0).nonzero()[1:].squeeze(dim=1)
res = zero_pos.tolist()
res.append(len(seq))
return res
def recompute_valid_actual_seq_len(actual_seq_len, micro_batch_size):
    """Truncate ``actual_seq_len`` after its ``micro_batch_size``-th unit step.

    A "unit step" is a position where two consecutive offsets differ by
    exactly 1. When fewer than ``micro_batch_size`` such positions exist (or
    the input has at most one entry) the input is returned unchanged;
    otherwise a tensor holding the offsets up to and including that position
    plus the final offset is returned.
    """
    if len(actual_seq_len) <= 1:
        return actual_seq_len
    offsets = torch.tensor(actual_seq_len)
    unit_steps = torch.nonzero((offsets[1:] - offsets[:-1]) == 1)
    if len(unit_steps) < micro_batch_size:
        return actual_seq_len
    cut = unit_steps[micro_batch_size - 1].item()
    return torch.cat((offsets[:cut + 1], offsets[-1:]))
def compute_actual_seq_len(origin_seq):
    """Compute boundary offsets for ``origin_seq``, plus MTP variants if enabled.

    Without ``args.mtp_num_layers`` the plain offset list is returned. With it,
    a list of ``mtp_num_layers + 1`` lists is returned: the plain offsets
    followed by, for each ``i`` in ``1..mtp_num_layers``, a copy in which
    every offset that is not a multiple of ``origin_seq.shape[1]`` is reduced
    by ``i``.
    """
    args = get_args()
    base_offsets = _compute_actual_seq_len(origin_seq)
    if not args.mtp_num_layers:
        return base_offsets

    row_len = origin_seq.shape[1]
    per_layer = [base_offsets]
    for shift in range(1, args.mtp_num_layers + 1):
        per_layer.append(
            [end if end % row_len == 0 else end - shift for end in base_offsets]
        )
    return per_layer
def regenerate_position_ids(tensor, offset):
    """Rebuild the leading position ids of every row after a shift.

    For each row of a copy of ``tensor``:
      * if the row contains a zero, the entries before the first zero are
        replaced by ``0, 1, ..., first_zero_idx - 1``;
      * otherwise ``offset`` is subtracted from that row only.

    Args:
        tensor: Position ids, indexed as ``tensor[row, col]``; may be None.
        offset: Amount subtracted from rows that contain no zero.

    Returns:
        A new tensor (the input is not modified), or None when ``tensor`` is None.
    """
    if tensor is None:
        return None
    tensor = tensor.clone()
    for i in range(tensor.size(0)):
        zero_mask = tensor[i] == 0
        if zero_mask.any():
            first_zero_idx = torch.argmax(zero_mask.int()).item()
            # Build the ramp on the same device/dtype as the destination.
            tensor[i, :first_zero_idx] = torch.arange(
                first_zero_idx, dtype=tensor.dtype, device=tensor.device
            )
        else:
            # Shift only this row. Rebinding the whole tensor here would shift
            # every row once per zero-free row.
            tensor[i] = tensor[i] - offset
    return tensor
def parse_args():
    """Delegate to ``megatron.training.arguments.parse_args`` and return its result."""
    megatron_arguments = megatron.training.arguments
    return megatron_arguments.parse_args()
def is_rank_0():
    """Return True on global rank 0 and on every rank divisible by the device count.

    When ``torch.distributed`` is not initialized the function always returns
    True.
    """
    if not torch.distributed.is_initialized():
        return True
    rank = torch.distributed.get_rank()
    # The explicit ``rank == 0`` test short-circuits before the modulo.
    return rank == 0 or rank % torch.cuda.device_count() == 0
def print_rank0_by_args(args, message):
    """Print ``message`` (flushed) only when ``args.rank`` is 0.

    Uses ``args.rank`` rather than ``torch.distributed`` so it can be called
    before the distributed backend is initialized.
    """
    if args.rank != 0:
        return
    print(message, flush=True)
def get_tune_attention_mask(attention_mask_1d):
    """Build a 4-D boolean attention mask from a 2-D per-token mask.

    A lower-triangular mask of shape (micro_batch_size, 1, seq_length,
    seq_length) is created, positions selected by the zero entries of
    ``attention_mask_1d`` are cleared, and the result is inverted in place
    before being returned.

    NOTE(review): the final ``bitwise_not_`` means entries that were allowed
    by the triangular/padding mask end up False - presumably True marks
    "masked out" for the consumer; confirm against the attention kernel.
    """
    args = get_args()
    micro_batch_size, seq_length = attention_mask_1d.size()
    if args.stage in ['dpo']:
        # Only the first half of the rows is kept for the 'dpo' stage.
        micro_batch_size = attention_mask_1d.shape[0] // 2
        attention_mask_1d = attention_mask_1d[:micro_batch_size]
    attention_mask = torch.ones((micro_batch_size, seq_length, seq_length),
                                device=attention_mask_1d.device,
                                dtype=torch.bool).tril_().view(micro_batch_size, 1, seq_length, seq_length)
    if args.tokenizer_padding_side == "left":
        # Left padding: the 1-D mask is re-viewed before being applied below.
        attention_mask_1d = attention_mask_1d.view(seq_length, 1, -1)
    # `.bool()` yields a new tensor for non-bool input, so the in-place
    # negation does not touch the caller's tensor in that case.
    attention_mask = attention_mask.masked_fill_(attention_mask_1d.bool().bitwise_not_().view(-1, 1, 1, seq_length), value=0)
    attention_mask.bitwise_not_()
    return attention_mask
def get_batch_on_this_cp_rank_wrapper(fn):
    """Wrap ``fn`` so that the batch's position ids are registered after slicing.

    After ``fn(batch)`` runs, ``batch['position_ids']`` (if present) is passed
    to ``set_position_ids``; with ``args.reset_position_ids`` it is first
    transposed on dims (0, 1) and made contiguous.
    """
    @wraps(fn)
    def wrapper(batch):
        batch = fn(batch)
        args = get_args()
        if 'position_ids' not in batch:
            return batch
        position_ids = batch['position_ids']
        if args.reset_position_ids:
            position_ids = position_ids.transpose(0, 1).contiguous()
        set_position_ids(position_ids)
        return batch

    return wrapper
def print_args_wrapper(fn):
    """Gate an argument-printing function behind an ``after_validate`` switch.

    The returned wrapper calls ``fn(title, args)`` only when
    ``after_validate`` is truthy and always returns None.
    """
    @wraps(fn)
    def wrapper(title, args, after_validate=False):
        if not after_validate:
            return
        fn(title, args)

    return wrapper
def print_args(title, args):
    """Print every attribute of ``args`` between two banner lines, on rank 0 only.

    Rows are formatted as ``' <name> <dots> <value>'`` and printed sorted
    case-insensitively.
    """
    if args.rank != 0:
        return
    print(f'------------------------ {title} ------------------------', flush=True)
    rows = [
        ' {} {} {}'.format(name, '.' * (48 - len(name)), getattr(args, name))
        for name in vars(args)
    ]
    rows.sort(key=str.lower)
    for row in rows:
        print(row, flush=True)
    print(f'-------------------- end of {title} ---------------------', flush=True)
def seed_all(seed=1234):
    """Seed Python, NumPy, torch and the NPU generators and enable deterministic torch algorithms.

    Also exports ``PYTHONHASHSEED`` with the same seed.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    for host_seeder in (np.random.seed, torch.manual_seed):
        host_seeder(seed)
    torch.use_deterministic_algorithms(True)
    npu = torch_npu.npu
    npu.manual_seed_all(seed)
    npu.manual_seed(seed)
def emit(self, record):
    """Handler-style ``emit`` that writes formatted records through ``tqdm.write``.

    Records are written only on rank 0, or when the rank cannot be queried
    (in which case it is treated as -1). Failures while formatting or writing
    are routed to ``self.handleError``.
    """
    try:
        rank = dist.get_rank()
    except Exception:
        rank = -1
    if rank not in (0, -1):
        return
    try:
        tqdm.write(self.format(record))
        self.flush()
    except Exception:
        self.handleError(record)
def get_device_wrapper(fn):
    """Wrap a device getter so that the 'hccl' backend yields an NPU device.

    With backend 'hccl' the wrapper returns ``torch.device('npu')`` or
    ``torch.device(f'npu:{local_rank}')``; for any other backend it returns
    ``fn(local_rank)``. Extra positional/keyword arguments are accepted but
    not forwarded.
    """
    @wraps(fn)
    def wrapper(local_rank=None, *arg, **kwargs):
        if torch.distributed.get_backend() != 'hccl':
            return fn(local_rank)
        if local_rank is None:
            return torch.device('npu')
        return torch.device(f'npu:{local_rank}')

    return wrapper
def unwrap_model_wrapper(fn):
    """Wrap ``fn`` so a falsy ``module_instances`` falls back to Megatron's default.

    The default is ``megatron.training.utils.ALL_MODULE_WRAPPER_CLASSNAMES``,
    looked up only when it is actually needed.
    """
    @wraps(fn)
    def wrapper(model, module_instances=None):
        wrapper_classes = (
            module_instances
            or megatron.training.utils.ALL_MODULE_WRAPPER_CLASSNAMES
        )
        return fn(model, wrapper_classes)

    return wrapper
def get_finetune_data_on_this_tp_rank(data_iterator):
    """Fetch one fine-tuning batch and share it across the tensor-parallel group.

    Every rank pulls a sample from ``data_iterator`` (the micro batch size is
    taken from its ``input_ids``). TP rank 0 then broadcasts the sequence
    length, the tokens and the 1-D attention mask; the other ranks allocate
    empty buffers of matching shape and receive them.

    Returns:
        ``(tokens, attention_mask)`` where ``attention_mask`` is produced by
        ``get_tune_attention_mask`` from the broadcast 1-D mask.
    """
    try:
        ds = next(data_iterator)
    except StopIteration as e:
        warnings.warn(f"An exception occurred in dataloader: {e}")
        data_iterator = iter(data_iterator)
        ds = next(data_iterator)

    local_tokens = ds.get('input_ids').long().cuda(non_blocking=True)
    get_args()
    tokens_shape = local_tokens.shape
    micro_batch_size = tokens_shape[0]

    def _broadcast(item):
        if item is not None:
            torch.distributed.broadcast(item, mpu.get_tensor_model_parallel_src_rank(),
                                        group=mpu.get_tensor_model_parallel_group())

    def _send_from_source():
        # Source rank: publish length, tokens, then the 1-D mask (in that order).
        length = torch.LongTensor([tokens_shape[1]]).cuda(non_blocking=True)
        _broadcast(length)
        _broadcast(local_tokens)
        mask_1d = ds.get('attention_mask').long().cuda(non_blocking=True)
        _broadcast(mask_1d)
        return local_tokens, mask_1d

    def _receive_on_peer():
        # Peer rank: receive the length first so the buffers can be sized.
        length = torch.empty((1), dtype=torch.int64, device=torch.cuda.current_device())
        _broadcast(length)
        received = torch.empty((micro_batch_size, length), dtype=torch.int64,
                               device=torch.cuda.current_device())
        _broadcast(received)
        mask_1d = torch.empty((micro_batch_size, length), dtype=torch.int64,
                              device=torch.cuda.current_device())
        _broadcast(mask_1d)
        return received, mask_1d

    if mpu.get_tensor_model_parallel_rank() == 0:
        tokens, attention_mask_1d = _send_from_source()
    else:
        tokens, attention_mask_1d = _receive_on_peer()
    attention_mask = get_tune_attention_mask(attention_mask_1d)
    return tokens, attention_mask
# Process-wide SharedMemoryManager instance, created lazily by get_sharedmem_mgr().
_GLOBAL_SHM_MANAGER = None
# Set by get_sharedmem_mgr() once it decides shared memory does not apply, so
# that later calls return None immediately; cleared by reset_sharedmem_mgr().
_SHM_SKIP_FLAG = False
# Base shared-memory name passed to get_sharedmem_mgr() by get_batch_on_this_tp_rank().
BASE_SHM_NAME = "g_shm"
def reset_sharedmem_mgr():
    """Close the global shared memory manager (if any) and clear the skip flag.

    Errors raised by ``close()`` are printed as a warning and otherwise
    ignored; the globals are reset in every case.
    """
    global _GLOBAL_SHM_MANAGER, _SHM_SKIP_FLAG

    manager = _GLOBAL_SHM_MANAGER
    if manager is not None:
        try:
            manager.close()
        except Exception as e:
            print(f"[SharedMemoryManager] [WARN] Error during SharedMemoryManager shutdown: {e}")

    _GLOBAL_SHM_MANAGER, _SHM_SKIP_FLAG = None, False
def get_sharedmem_mgr(base_shm_name="g_shm", buffer_length=4096):
    """
    Retrieve the global shared memory manager for data transfer through shared memory.

    The manager is created once per process and cached. TP rank 0 creates the
    shared memory segment; the other TP ranks receive rank 0's pid via a
    broadcast and attach to the existing segment.

    :param base_shm_name: Base name of the shared memory
    :param buffer_length: Size of the shared memory buffer, default: 4K
    :return: `SharedMemoryManager` instance, or None when shared memory is
        skipped (distributed not initialized, or the configuration check fails)
    """
    global _GLOBAL_SHM_MANAGER, _SHM_SKIP_FLAG
    # Fast paths: a previous call already decided to skip, or already built it.
    if _SHM_SKIP_FLAG:
        return None
    if _GLOBAL_SHM_MANAGER is not None:
        return _GLOBAL_SHM_MANAGER
    # NOTE(review): the TP rank is queried before the is_initialized() check
    # below - confirm mpu tolerates being called without a process group.
    rank = mpu.get_tensor_model_parallel_rank()
    global_rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else -1
    if not torch.distributed.is_initialized():
        print(
            f"[SharedMemoryManager][Rank {rank}][global_rank {global_rank}]"
            f"[Func: get_sharedmem_mgr] <ERROR> "
            f"torch.distributed not initialized, skipping..."
        )
        return None
    args = get_args()
    reset_position_ids = args.reset_position_ids
    enable_shm = args.enable_share_memory
    tp_size = mpu.get_tensor_model_parallel_world_size()
    device_count = torch.cuda.device_count()
    # Shared memory is only used when position ids are reset, the feature is
    # enabled, and the TP size is greater than 1 and at most the local device count.
    if not (reset_position_ids and enable_shm and tp_size > 1 and tp_size <= device_count):
        print(
            f"[SharedMemoryManager][Rank {rank}][global_rank {global_rank}]"
            f"[Func: get_sharedmem_mgr] <INFO> Skip creation. "
            f"reset_position_ids={reset_position_ids}, enable_shm={enable_shm}, "
            f"tp_size={tp_size}, device_count={device_count}"
        )
        # Remember the decision so later calls return immediately.
        _SHM_SKIP_FLAG = True
        return None
    if rank == 0:
        # TP rank 0 owns the segment; its pid is passed to the manager.
        pid = os.getpid()
        _GLOBAL_SHM_MANAGER = SharedMemoryManager(
            base_shm_name, rank0_pid=pid, buffer_length=buffer_length, tp_size=tp_size
        )
        print(
            f"[SharedMemoryManager][Rank {rank}][global_rank {global_rank}] <INFO> Created: "
            f"{_GLOBAL_SHM_MANAGER.shm_name}, TP_size: {tp_size}, TP_Group: {_GLOBAL_SHM_MANAGER.tp_group_id}"
        )
    # Wait until rank 0 has created the segment; a RuntimeError from the
    # barrier is only reported, execution continues.
    try:
        torch.distributed.barrier(group=mpu.get_tensor_model_parallel_group())
    except RuntimeError as e:
        print(
            f"[SharedMemoryManager][Rank {rank}][global_rank {global_rank}]"
            f"[Func: get_sharedmem_mgr] <ERROR> Barrier timeout: {e}"
        )
    if rank == 0:
        # Publish rank 0's pid to the rest of the TP group.
        pid = os.getpid()
        pid_tensor = torch.tensor([pid], dtype=torch.int32, device="cuda")
        torch.distributed.broadcast(pid_tensor, mpu.get_tensor_model_parallel_src_rank(),
                                    group=mpu.get_tensor_model_parallel_group())
    else:
        # Receive rank 0's pid and attach to the segment it created.
        pid_tensor = torch.zeros(1, dtype=torch.int32, device="cuda")
        torch.distributed.broadcast(pid_tensor, mpu.get_tensor_model_parallel_src_rank(),
                                    group=mpu.get_tensor_model_parallel_group())
        pid = pid_tensor.item()
        _GLOBAL_SHM_MANAGER = SharedMemoryManager(
            base_shm_name, rank0_pid=pid, buffer_length=buffer_length, tp_size=tp_size, existing=True
        )
        print(
            f"[SharedMemoryManager][Rank {rank}][global_rank {global_rank}] <INFO> Connected to: "
            f"{_GLOBAL_SHM_MANAGER.shm_name}, TP_size: {tp_size}, TP_Group: {_GLOBAL_SHM_MANAGER.tp_group_id}"
        )
    # Every TP rank is attached once this barrier completes.
    torch.distributed.barrier(group=mpu.get_tensor_model_parallel_group())
    return _GLOBAL_SHM_MANAGER
def get_batch_on_this_tp_rank(data_iterator):
    """Load a batch on TP rank 0 and broadcast it to the tensor-parallel group.

    TP rank 0 pulls the next sample from ``data_iterator``, moves the tensors
    to the device and broadcasts the fields each pipeline stage needs. The
    other TP ranks allocate empty buffers of shape
    ``(micro_batch_size, seq_length)`` and receive the same fields in the same
    order; fields a stage does not need are set to None on those ranks.

    Side effects:
        * With shared memory enabled, rank 0 writes the computed actual
          sequence lengths to the shared memory manager and the other ranks
          read them.
        * On non-910B SoCs with MTP layers and the post-process flag set, the
          MTP position ids are stored through ``set_mtp_position_ids``.
        * With ``args.reset_attention_mask`` the actual sequence lengths are
          broadcast dynamically and registered via ``set_actual_seq_len``.

    Args:
        data_iterator: Iterator yielding dict-like samples; only consumed on
            TP rank 0.

    Returns:
        dict with keys 'tokens', 'labels', 'loss_mask', 'attention_mask' and
        'position_ids' (plus 'document_ids' and 'idx' on rank 0 when
        ``args.return_document_ids`` applies).
    """
    args = get_args()

    def _broadcast(item):
        # None entries are skipped on sender and receiver alike.
        if item is not None:
            torch.distributed.broadcast(item, mpu.get_tensor_model_parallel_src_rank(),
                                        group=mpu.get_tensor_model_parallel_group())

    shm_manager = None
    actual_seq_len = None
    if args.enable_share_memory:
        shm_manager = get_sharedmem_mgr(BASE_SHM_NAME, args.micro_batch_size * args.seq_length)

    if mpu.get_tensor_model_parallel_rank() == 0:
        if data_iterator is not None:
            data = next(data_iterator)
        else:
            data = None

        if args.enable_share_memory and shm_manager is not None:
            position_ids = data["position_ids"]
            actual_seq_len = compute_actual_seq_len(position_ids)
            shm_manager.write(actual_seq_len)
            if '910B' not in acl.get_soc_name() and args.mtp_num_layers and get_post_process_flag():
                from mindspeed_llm.core.transformer.multi_token_prediction import roll_tensor
                # One shifted and regenerated copy of the position ids per MTP layer.
                position_ids_mtp = []
                cur_position_id = data["position_ids"]
                for _ in range(args.mtp_num_layers):
                    cur_position_id, _ = roll_tensor(cur_position_id, shifts=-1, dims=-1)
                    cur_position_id = regenerate_position_ids(cur_position_id, 1)
                    position_ids_mtp.append(cur_position_id)
                set_mtp_position_ids((position_ids_mtp, shm_manager))

        if args.return_document_ids and mpu.get_context_parallel_rank() == 0 and mpu.get_pipeline_model_parallel_rank() == 0:
            # Rows are cut at the first -100 entry.
            document_ids = [
                [x.item() for x in takewhile(lambda y: y.item() != -100, row)]
                for row in data['document_ids']
            ]
            data_idx = [
                [x.item() for x in takewhile(lambda y: y.item() != -100, row)]
                for row in data['idx']
            ]
            data.pop("document_ids", None)
            data.pop("idx", None)
            batch = {
                'tokens': data["tokens"].cuda(non_blocking=True),
                'labels': data["labels"].cuda(non_blocking=True),
                'loss_mask': data["loss_mask"].cuda(non_blocking=True),
                'attention_mask': None if "attention_mask" not in data else data["attention_mask"].cuda(non_blocking=True),
                'position_ids': data["position_ids"].cuda(non_blocking=True),
                'document_ids': document_ids,
                'idx': data_idx
            }
        else:
            batch = {
                'tokens': data["tokens"].cuda(non_blocking=True),
                'labels': data["labels"].cuda(non_blocking=True),
                'loss_mask': data["loss_mask"].cuda(non_blocking=True),
                'attention_mask': None if "attention_mask" not in data else data["attention_mask"].cuda(non_blocking=True),
                'position_ids': data["position_ids"].cuda(non_blocking=True)
            }

        # The broadcast sequence below must match the receiving branch exactly.
        if args.pipeline_model_parallel_size == 1:
            _broadcast(batch['tokens'])
            _broadcast(batch['labels'])
            _broadcast(batch['loss_mask'])
            _broadcast(batch['attention_mask'])
            _broadcast(batch['position_ids'])
        elif mpu.is_pipeline_first_stage():
            _broadcast(batch['tokens'])
            _broadcast(batch['attention_mask'])
            _broadcast(batch['position_ids'])
            if args.schedules_method == 'dualpipev':
                _broadcast(batch['loss_mask'])
                _broadcast(batch['labels'])
        elif mpu.is_pipeline_last_stage():
            if args.mtp_num_layers or args.schedules_method == 'dualpipev':
                _broadcast(batch['tokens'])
            _broadcast(batch['labels'])
            _broadcast(batch['loss_mask'])
            _broadcast(batch['attention_mask'])
            if args.reset_attention_mask or args.mtp_num_layers or args.schedules_method == 'dualpipev':
                _broadcast(batch['position_ids'])
        else:
            # Intermediate stage: receivers expect the attention mask first and
            # then, with reset_attention_mask, the position ids. Sending only
            # one of the two desynchronises the group when a mask is present.
            _broadcast(batch['attention_mask'])
            if args.reset_attention_mask:
                _broadcast(batch['position_ids'])

        if args.reset_attention_mask:
            actual_seq_len = broadcast_dynamic(data['actual_seq_len'])
            if args.attention_mask_type == 'causal' \
                    and args.context_parallel_size > 1 \
                    and args.context_parallel_algo == 'megatron_cp_algo':
                actual_seq_len = pad_data(actual_seq_len, batch, args.context_parallel_size, args.tensor_model_parallel_size)
                actual_seq_len /= get_ring_degree()
            set_actual_seq_len(actual_seq_len)
    else:
        if args.enable_share_memory and shm_manager is not None:
            actual_seq_len = shm_manager.read()
            if '910B' not in acl.get_soc_name() and args.mtp_num_layers and get_post_process_flag():
                set_mtp_position_ids((None, shm_manager))

        # Receive buffers; shapes and dtypes mirror what rank 0 sends.
        tokens = torch.empty((args.micro_batch_size, args.seq_length),
                             dtype=torch.int64,
                             device=torch.cuda.current_device())
        labels = torch.empty((args.micro_batch_size, args.seq_length),
                             dtype=torch.int64,
                             device=torch.cuda.current_device())
        loss_mask = torch.empty((args.micro_batch_size, args.seq_length),
                                dtype=torch.float32,
                                device=torch.cuda.current_device())
        if getattr(args, 'create_attention_mask_in_dataloader', False):
            attention_mask = torch.empty(
                (args.micro_batch_size, 1, args.seq_length, args.seq_length), dtype=torch.bool, device=torch.cuda.current_device()
            )
        else:
            attention_mask = None
        position_ids = torch.empty((args.micro_batch_size, args.seq_length),
                                   dtype=torch.int64,
                                   device=torch.cuda.current_device())

        if args.pipeline_model_parallel_size == 1:
            _broadcast(tokens)
            _broadcast(labels)
            _broadcast(loss_mask)
            _broadcast(attention_mask)
            _broadcast(position_ids)
        elif mpu.is_pipeline_first_stage():
            _broadcast(tokens)
            _broadcast(attention_mask)
            _broadcast(position_ids)
            if args.schedules_method == 'dualpipev':
                _broadcast(loss_mask)
                _broadcast(labels)
            else:
                labels = None
                loss_mask = None
        elif mpu.is_pipeline_last_stage():
            if args.mtp_num_layers or args.schedules_method == 'dualpipev':
                _broadcast(tokens)
            else:
                tokens = None
            _broadcast(labels)
            _broadcast(loss_mask)
            _broadcast(attention_mask)
            if args.reset_attention_mask or args.mtp_num_layers or args.schedules_method == 'dualpipev':
                _broadcast(position_ids)
            else:
                position_ids = None
        else:
            tokens = None
            labels = None
            loss_mask = None
            _broadcast(attention_mask)
            if args.reset_attention_mask:
                _broadcast(position_ids)
            else:
                position_ids = None

        batch = {
            'tokens': tokens,
            'labels': labels,
            'loss_mask': loss_mask,
            'attention_mask': attention_mask,
            'position_ids': position_ids
        }

        if args.reset_attention_mask:
            actual_seq_len = broadcast_dynamic(None)
            if args.attention_mask_type == 'causal' \
                    and args.context_parallel_size > 1 \
                    and args.context_parallel_algo == 'megatron_cp_algo':
                actual_seq_len = pad_data(actual_seq_len, batch, args.context_parallel_size,
                                          args.tensor_model_parallel_size)
                actual_seq_len /= get_ring_degree()
            set_actual_seq_len(actual_seq_len)

    return batch
def get_batch_on_this_cp_rank(batch):
    """Slice the batch along the sequence dimension for this context-parallel rank.

    The slicing routine is chosen from ``args.context_parallel_algo`` and
    ``args.attention_mask_type``; the batch is returned untouched when the
    effective context-parallel size is not greater than 1 or the algorithm is
    not one of the handled names.
    """
    args = get_args()
    if args.tp_2d:
        tp_y_cp_size = TensorParallelYUnionCP().get_parallel_group_world_size()
    else:
        tp_y_cp_size = args.context_parallel_size
    if not tp_y_cp_size > 1:
        return batch

    uses_general_mask = args.attention_mask_type == 'general'
    if uses_general_mask and batch.get("attention_mask", None) is not None:
        set_attention_mask(batch['attention_mask'].squeeze())

    cp_expanded_by_2d_tp = args.tp_y > 1
    algo = args.context_parallel_algo

    if algo == 'megatron_cp_algo':
        if uses_general_mask:
            slicer = _get_batch_on_this_cp_rank_in_megatron_cp_general
        elif cp_expanded_by_2d_tp:
            slicer = _get_batch_on_this_tp_y_cp_rank_in_megatron_cp
        else:
            slicer = _get_batch_on_this_cp_rank_in_megatron_cp
    elif algo in ('ulysses_cp_algo', 'mamba_cp_algo'):
        slicer = _get_batch_on_this_cp_rank_in_ulysses_cp
    elif algo == 'hybrid_cp_algo':
        if uses_general_mask:
            slicer = _get_batch_on_this_cp_rank_in_hybrid_cp_general
        else:
            slicer = _get_batch_on_this_cp_rank_in_hybrid_cp
    else:
        return batch
    return slicer(batch)
def _get_batch_on_this_cp_rank_in_megatron_cp_general(batch):
    """Slice every non-None batch entry for this context-parallel rank (general mask).

    The attention mask is chunked along its row dimension (dim 2 for a 4-D
    mask, dim 0 otherwise), this rank's chunk is kept, split again along the
    following dimension into ``cp_size`` blocks, and the blocks are stacked.
    Every other entry is chunked along dim 1 and this rank's chunk is kept.
    """
    cp_rank = mpu.get_context_parallel_rank()
    cp_size = mpu.get_context_parallel_world_size()
    for key, val in batch.items():
        if val is None:
            continue
        if key == 'attention_mask':
            row_dim = 2 if len(val.shape) == 4 else 0
            local_rows = val.chunk(cp_size, dim=row_dim)[cp_rank].contiguous()
            column_blocks = local_rows.chunk(cp_size, dim=row_dim + 1)
            batch[key] = torch.stack([block.contiguous() for block in column_blocks])
        else:
            batch[key] = val.chunk(cp_size, dim=1)[cp_rank].contiguous()
    return batch
def _get_batch_on_this_cp_rank_in_ulysses_cp(batch):
    """Keep this context-parallel rank's chunk (along dim 1) of every batch entry.

    The 'attention_mask' entry and None entries are left untouched.
    """
    cp_rank = mpu.get_context_parallel_rank()
    cp_size = mpu.get_context_parallel_world_size()
    for key, val in batch.items():
        if key == 'attention_mask' or val is None:
            continue
        # The mask is skipped above, so the sequence dimension is always 1 here.
        batch[key] = val.chunk(cp_size, dim=1)[cp_rank].contiguous()
    return batch
def is_last_rank_wrapper(fn):
    """Wrap an ``is_last_rank`` predicate for scale-in (elastic) training.

    While scale-in training is running, "last rank" means the last rank of
    the scale-in world group; otherwise the wrapped ``fn`` decides.
    """
    @wraps(fn)
    def wrapper():
        from mindspeed_llm.core.high_availability import elastic_training_common
        if elastic_training_common.zit_scale_in_running_state():
            current_rank = torch.distributed.get_rank()
            scale_in_ranks = torch.distributed.get_process_group_ranks(
                group=elastic_training_common.zit_get_scale_in_world_group())
            return current_rank == scale_in_ranks[-1]
        return fn()

    return wrapper
def print_rank_last_wrapper(fn):
    """Wrap a last-rank print function for scale-in (elastic) training.

    While scale-in training is running, the ' global batch size: ... |'
    fragment computed from micro batch size, data-parallel size and number of
    micro batches is replaced in the message by one built from
    ``get_args().global_batch_size`` before ``fn`` is called.
    """
    @wraps(fn)
    def wrapper(message):
        from mindspeed_llm.core.high_availability import elastic_training_common
        if elastic_training_common.zit_scale_in_running_state():
            args = get_args()
            from megatron.core.num_microbatches_calculator import get_num_microbatches
            computed_size = args.micro_batch_size * args.data_parallel_size * \
                get_num_microbatches()
            src_str = f' global batch size: {computed_size:5d} |'
            configured_size = get_args().global_batch_size
            dest_str = f' global batch size: {configured_size:5d} |'
            message = message.replace(src_str, dest_str)
        return fn(message)

    return wrapper
def is_shared_path(path: str, retry: int = 3, wait: float = 0.5) -> bool:
    """Detect whether ``path`` is storage shared between the participating hosts.

    Local rank 0 of every host writes a marker file named after its hostname
    into ``path``; each rank then counts the marker files it can see. Seeing
    more than one marker anywhere is treated as shared storage.

    Args:
        path: Directory to probe; created if missing.
        retry: Number of directory listings attempted while waiting for other
            hosts' markers to appear.
        wait: Seconds slept between listings.

    Returns:
        True when distributed is not initialized, the world size is 1, all
        ranks run on a single host, or more than one marker file was visible;
        False otherwise.

    Raises:
        RuntimeError: On rank 0 when no marker file is visible at all. Any
            exception raised during the probe is logged on rank 0 and re-raised.
    """
    if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1:
        return True
    hostname = socket.gethostname()
    hostnames = [None] * dist.get_world_size()
    dist.all_gather_object(hostnames, hostname)
    # Every rank on the same host: nothing to probe.
    if len(set(hostnames)) == 1:
        return True
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    os.makedirs(path, exist_ok=True)
    marker_file = os.path.join(path, f".share_test_{hostname}")
    try:
        # One marker per host, written by that host's local rank 0.
        if local_rank == 0:
            with open(marker_file, "w") as f:
                f.write(f"marker from {hostname}")
        torch.distributed.barrier()
        visible_files = set()
        for _ in range(retry):
            visible_files = {f for f in os.listdir(path) if f.startswith(".share_test_")}
            # Stop as soon as a marker from another host shows up.
            if len(visible_files) > 1 or world_size == 1:
                break
            time.sleep(wait)
        visible_count = len(visible_files)
        visible_tensor = torch.tensor(
            [visible_count],
            dtype=torch.int,
            device="npu"
        )
        # Use the largest count seen by any rank.
        torch.distributed.all_reduce(visible_tensor, op=torch.distributed.ReduceOp.MAX)
        total_visible = visible_tensor.item()
        if rank == 0:
            if total_visible > 1:
                logger.info(f"[is_shared_path] Detection result: Shared storage ({path}), detected {total_visible} node marker files.")
                shared = True
            elif total_visible == 1:
                logger.info(f"[is_shared_path] Detection result: Non-shared storage ({path}), only local node can access its own marker.")
                shared = False
            else:
                raise RuntimeError(f"[is_shared_path] Detection failed: No visible marker files, please check mount configuration.")
        else:
            shared = None
        # Rank 0's verdict is authoritative; other ranks send a placeholder 0
        # that the broadcast overwrites.
        shared = torch.tensor([1 if shared else 0], dtype=torch.int, device="npu")
        torch.distributed.broadcast(shared, src=0)
        torch.distributed.barrier()
        # Remove the marker only after every rank has finished listing.
        if local_rank == 0 and os.path.exists(marker_file):
            os.remove(marker_file)
        torch.distributed.barrier()
        return bool(shared.item())
    except Exception as e:
        if rank == 0:
            logger.info(f"[is_shared_path] Exception during shared path check: {e}")
        raise
def check_model_inputs(func):
    """Decorator that captures selected sub-layer outputs without using hooks.

    For the decorated method's class, ``_CAN_RECORD_REGISTRY`` (keyed by
    ``str(self.__class__)``) maps a capture key to a layer class or a list of
    layer classes. While the decorated method runs, ``forward`` of every
    matching sub-module (found through ``self.named_modules()``) is
    temporarily replaced by a wrapper that records ``output[2]`` whenever it
    is not None. Afterwards the collected tuples are published as module
    globals under their capture key.

    The original ``forward`` methods are always restored, even when the
    decorated method raises, and in reverse patch order so that a module
    matched by several specs ends up with its real ``forward`` again.
    """
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        capture_flags = _CAN_RECORD_REGISTRY.get(str(self.__class__)) or {}
        if not capture_flags:
            return func(self, *args, **kwargs)

        collected_outputs = defaultdict(tuple)
        monkey_patched_layers = []

        def make_capture_wrapper(orig_forward, key):
            @wraps(orig_forward)
            def wrapped_forward(*f_args, **f_kwargs):
                output = orig_forward(*f_args, **f_kwargs)
                if output[2] is not None:
                    # defaultdict(tuple) starts every key with an empty tuple.
                    collected_outputs[key] += (output[2],)
                return output
            return wrapped_forward

        # Flatten {key: spec | [spec, ...]} into (key, spec) pairs.
        capture_tasks = []
        for key, layer_specs in capture_flags.items():
            if not isinstance(layer_specs, list):
                layer_specs = [layer_specs]
            capture_tasks.extend((key, specs) for specs in layer_specs)

        try:
            for _name, module in self.named_modules():
                for key, specs in capture_tasks:
                    if specs is not None and isinstance(module, specs):
                        original_forward = module.forward
                        module.forward = make_capture_wrapper(original_forward, key)
                        monkey_patched_layers.append((module, original_forward))
            outputs = func(self, *args, **kwargs)
        finally:
            # Undo patches newest-first: a module patched twice stores the
            # first wrapper as the "original" of the second patch.
            for module, original_forward in reversed(monkey_patched_layers):
                module.forward = original_forward

        for key in collected_outputs:
            globals()[key] = collected_outputs[key]
        return outputs

    return wrapper
def is_distributed_ckpt_complete(
    save_path: str,
    iteration: int,
    weight_filename: str = "model_optim_rng.pt",
) -> bool:
    """Check that every expected per-rank weight file of a checkpoint exists.

    The expected ``mp_rank_*`` directories under
    ``<save_path>/iter_<iteration:07d>`` are derived from the tensor, pipeline
    and expert parallel sizes in the global args. With ``torch.distributed``
    initialized only rank 0 inspects the file system and the verdict is
    broadcast to all ranks.

    Args:
        save_path: Checkpoint root directory.
        iteration: Iteration number of the checkpoint.
        weight_filename: File expected inside every rank directory.

    Returns:
        True when the iteration directory and all expected files exist.
    """
    args = get_args()

    def get_etp_valid_ckpts_list(tp: int, ep: int):
        # Pair each rank of the larger dimension with its rank modulo the
        # smaller one; no pairs when neither size divides the other.
        if tp % ep == 0:
            return [(tp_rank, tp_rank % ep) for tp_rank in range(tp)]
        if ep % tp == 0:
            return [(ep_rank % tp, ep_rank) for ep_rank in range(ep)]
        return []

    def _check_ckpt() -> bool:
        tp = args.tensor_model_parallel_size
        pp = args.pipeline_model_parallel_size
        ep = args.expert_model_parallel_size
        etp = args.expert_tensor_parallel_size
        enable_etp = (etp == 1) and (tp != 1)

        iter_dir = os.path.join(save_path, f"iter_{iteration:07d}")
        if not os.path.isdir(iter_dir):
            return False

        if enable_etp and ep > 1:
            tp_ep_pairs = get_etp_valid_ckpts_list(tp, ep)
        else:
            tp_ep_pairs = [(tp_rank, ep_rank) for tp_rank in range(tp) for ep_rank in range(ep)]

        for tp_rank, ep_rank in tp_ep_pairs:
            for pp_rank in range(pp):
                # The pp / ep suffixes appear only when that dimension is not 1.
                suffixes = [f"{tp_rank:02d}"]
                if pp != 1:
                    suffixes.append(f"{pp_rank:03d}")
                if ep != 1:
                    suffixes.append(f"{ep_rank:03d}")
                rank_dir = "mp_rank_" + "_".join(suffixes)
                if not os.path.isfile(os.path.join(iter_dir, rank_dir, weight_filename)):
                    return False
        return True

    if not torch.distributed.is_initialized():
        return _check_ckpt()

    torch.distributed.barrier()
    result = _check_ckpt() if torch.distributed.get_rank() == 0 else False
    torch.distributed.barrier()
    flag = torch.tensor(int(result), device="npu")
    torch.distributed.broadcast(flag, src=0)
    return bool(flag.item())
def _normalize_name(name: str) -> str:
name = name.lower()
name = re.sub(r'for.*$', '', name)
if name.startswith("deepseekv"):
name = re.sub(r'^deepseekv(\d+)$', r'deepseek\1', name)
name = re.sub(r'[^a-z0-9]', '', name)
return name
def infer_model_type_from_hf_config(
    config_path: str,
    choices: List[str]
) -> str:
    """Infer ``model_type_hf`` from the ``architectures`` field of a Hugging Face config.json.

    The first architecture entry is normalized with ``_normalize_name``. A hit
    in ``ARCH_ALIAS_MAP`` wins; otherwise the first entry of ``choices`` whose
    normalized form equals the normalized architecture is returned.

    Args:
        config_path: Path to the Hugging Face ``config.json``.
        choices: Accepted model type names.

    Returns:
        The alias target, or the matching element of ``choices``.

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        RuntimeError: If ``architectures`` is missing or empty, or no alias or
            choice matches.
    """
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"config.json not found: {config_path}")
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    architectures = config.get("architectures", [])
    if not architectures:
        raise RuntimeError(
            "Cannot infer model type: 'architectures' field is missing in config.json. "
            "Please specify --model-type-hf explicitly."
        )

    arch_raw = architectures[0]
    arch_norm = _normalize_name(arch_raw)
    if arch_norm in ARCH_ALIAS_MAP:
        return ARCH_ALIAS_MAP[arch_norm]

    for choice in choices:
        if _normalize_name(choice) == arch_norm:
            return choice

    # Interpolate the offending architecture so the user sees what failed.
    raise RuntimeError(
        f"Cannot infer model type from architectures of Huggingface config.json '{arch_raw}'. "
        "Please specify --model-type-hf explicitly."
    )
def auto_coverage(func):
    """Run ``func`` under ``coverage`` when START_COVERAGE is 'true' (case-insensitive).

    Otherwise the wrapped function is called directly. Coverage data is
    stopped and saved even when ``func`` raises.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        coverage_enabled = os.environ.get('START_COVERAGE', '').lower() == 'true'
        if not coverage_enabled:
            return func(*args, **kwargs)

        # Imported lazily so the dependency is only needed when enabled.
        import coverage
        suffix = f"usecase-{time.time_ns()}_{random.randint(0, 100)}"
        cov = coverage.Coverage(data_suffix=suffix)
        cov.start()
        try:
            return func(*args, **kwargs)
        finally:
            cov.stop()
            cov.save()

    return wrapper
def check_pipeline_config(num_layers, pp, vpp_stage, noop_layers):
    """Reject pipeline layouts in which a whole stage consists of noop layers.

    Layers ``0..num_layers-1`` are split into ``pp`` contiguous groups of
    ``num_layers // pp`` layers. When ``vpp_stage`` is truthy each group is
    further split into contiguous chunks of ``vpp_stage`` layers.

    Args:
        num_layers: Total number of layers.
        pp: Number of pipeline stages.
        vpp_stage: Number of layers per virtual pipeline chunk, or a falsy
            value to skip the virtual-stage check.
        noop_layers: Comma-separated layer indices; ``None``, an empty string
            and blank tokens (e.g. a trailing comma) mean "no noop layers".

    Raises:
        ValueError: If every layer of a pipeline stage, or of a virtual
            pipeline chunk, is a noop layer.
    """
    # Tolerate a missing/empty option and blank tokens instead of crashing.
    noop_set = {int(tok) for tok in (noop_layers or "").split(",") if tok.strip()}
    all_layers = list(range(num_layers))
    layers_per_pp_group = num_layers // pp
    for pp_idx in range(pp):
        pp_start = pp_idx * layers_per_pp_group
        pp_layers = all_layers[pp_start:pp_start + layers_per_pp_group]
        if all(layer in noop_set for layer in pp_layers):
            raise ValueError(
                f"Interception Error: PP Stage {pp_idx} contains layers {pp_layers} that are all noop_layers!\n"
                f"Please re-adjust the PP or noop_layers indices."
            )
        if not vpp_stage:
            continue
        vpp_size = layers_per_pp_group // vpp_stage
        for vpp_idx in range(vpp_size):
            v_start = vpp_idx * vpp_stage
            vpp_layers = pp_layers[v_start:v_start + vpp_stage]
            if all(layer in noop_set for layer in vpp_layers):
                raise ValueError(
                    f"Interception Error: VPP Stage {vpp_idx} in PP Stage {pp_idx} consists entirely of empty layers!\n"
                    f"Corresponding logical layer indices: {vpp_layers}\n"
                    f"Please modify noop_layers or vpp_stage."
                )