from functools import wraps
import torch
import megatron
from megatron.training import get_args
def initialize_distributed_wrapper(fn):
    """Decorate a distributed-initialization function with high-availability setup.

    The returned wrapper performs, in order:

    1. Selects CUDA device ``rank % device_count`` for this process
       (skipped when no device is visible).
    2. Calls ``tft_init_controller_processor()``.
    3. Calls ``fn`` with the caller's arguments.
    4. Calls ``ttp_initialize_replica_dp_group`` with the pipeline, tensor,
       context, expert and expert-tensor parallel sizes read from the
       Megatron args, followed by the world size and the rank-order string.

    Args:
        fn: The callable to wrap.

    Returns:
        A wrapper with ``fn``'s metadata that returns whatever ``fn`` returns.
    """

    @wraps(fn)
    def wrapper(*args, **kwargs):
        device_count = torch.cuda.device_count()
        # With no visible device the modulo below would raise
        # ZeroDivisionError, so device selection is skipped in that case.
        if device_count > 0:
            torch.cuda.set_device(get_args().rank % device_count)

        # Imported at call time rather than module import time
        # (presumably to avoid an import cycle or a hard dependency -- confirm).
        from mindspeed_llm.core.high_availability import (
            tft_init_controller_processor,
            ttp_initialize_replica_dp_group,
        )

        tft_init_controller_processor()
        result = fn(*args, **kwargs)

        # Queried only after ``fn`` has run (presumably ``fn`` is what
        # initializes the process group -- verify against the patched target).
        world_size: int = torch.distributed.get_world_size()
        # Kept under a distinct name so the wrapper's ``*args`` is not shadowed.
        megatron_args = megatron.training.get_args()
        if megatron_args.use_tp_pp_dp_mapping:
            order = 'tp-cp-ep-pp-dp'
        else:
            order = 'tp-cp-ep-dp-pp'

        ttp_initialize_replica_dp_group(
            megatron_args.pipeline_model_parallel_size,
            megatron_args.tensor_model_parallel_size,
            megatron_args.context_parallel_size,
            megatron_args.expert_model_parallel_size,
            megatron_args.expert_tensor_parallel_size,
            world_size,
            order,
        )
        # Propagate the wrapped function's result instead of dropping it.
        return result

    return wrapper
def build_train_valid_test_data_iterators_wrapper(fn):
    """Decorate a data-iterator builder so that training is re-enabled on an ARF reboot node.

    The returned callable forwards all arguments to ``fn`` and hands back
    ``fn``'s result unchanged. After ``fn`` has run, if
    ``tft_is_arf_reboot_node()`` is truthy, ``do_train`` on the global
    Megatron args is set to ``True``.

    Args:
        fn: The callable to wrap.

    Returns:
        A wrapper carrying ``fn``'s metadata.
    """

    @wraps(fn)
    def wrapper(*args, **kwargs):
        iterators = fn(*args, **kwargs)

        # Resolved lazily, and only once ``fn`` has completed.
        from mindspeed_llm.core.high_availability import tft_is_arf_reboot_node

        if not tft_is_arf_reboot_node():
            return iterators

        # Reboot-node path: force the training flag on before returning.
        megatron_args = get_args()
        megatron_args.do_train = True
        return iterators

    return wrapper