from functools import wraps
import os
import torch
import torch_npu
from megatron.training.global_vars import set_args, get_args
from megatron.training.tokenizer.tokenizer import build_tokenizer
from megatron.training.utils import print_rank_0
from megatron.training.training import print_datetime
from mindspeed.auto_settings.auto_settings import AutoSettings
from mindspeed.auto_settings.module.parse.profiling_parse import get_settings, get_model_params
from mindspeed.auto_settings.module.parse.profiling_parse.profiling_node_parse import GatherNodeProfiling
POLICY = None
OPTIMIZED_MBS_LIST = None
PP_SCHEDULE_LIST = None
OPTIMAL_LAYERS = None
ORIGIN_MBS = None
DATA_PARALLEL_SIZE = 1
ENABLE_SCHEDULER = False
FLOPS_COUNTER = None
RECORDED_COUNT = 0
TRAVERSED_COUNT = 0
def auto_settings_fun(argument):
set_args(argument)
print("pretrain_decorator set_args ========================================")
argument = get_args()
working_dir_root = os.path.realpath(argument.auto_settings_work_dir)
if not os.path.exists(working_dir_root) and argument.rank % torch.cuda.device_count() == 0:
os.makedirs(working_dir_root)
if argument.rank % torch.cuda.device_count() == 0:
print("only rank 0 run auto tuning ========================================")
settings = AutoSettings()
settings.auto_setting_fun(argument)
return
def auto_settings_parse_args():
args = get_args()
if not args.vocab_size:
tokenizer = build_tokenizer(args)
args.vocab_size = tokenizer.vocab_size
get_settings(args, args.profile_save_path)
print_rank_0("================OOTB_OPTIMIZER_PARSE_ARGS END EXIT!====================")
return
def auto_settings_parse_model(model, mpu, args):
get_model_params(model, mpu.get_pipeline_model_parallel_rank(), args.profile_save_path, args.context_parallel_size * args.tensor_model_parallel_size * args.data_parallel_size)
print_rank_0("================OOTB_OPTIMIZER_PARSE_MODEL END EXIT!====================")
return
def auto_settings_profile(args):
res_dir = args.profile_save_path
cur_rank = torch.distributed.get_rank()
if res_dir and cur_rank % torch.cuda.device_count() == 0:
GatherNodeProfiling(res_dir).parse_node_pkl(args)
print_datetime('after training is done')
return
def train_decorator(step_fn):
@wraps(step_fn)
def wrapper(*args, **kwargs):
args_ = get_args()
if args_.profile:
args_.profile_npu = True
args_.profile = False
else:
args_.profile_npu = False
if judge_if_profile(args_):
active = args_.profile_step_end - args_.profile_step_start
skip_first = args_.profile_step_start
if args_.profile_with_cpu:
activities = [torch_npu.profiler.ProfilerActivity.NPU, torch_npu.profiler.ProfilerActivity.CPU]
else:
activities = [torch_npu.profiler.ProfilerActivity.NPU]
if args_.profile_level == 'level0':
profiler_level = torch_npu.profiler.ProfilerLevel.Level0
elif args_.profile_level == 'level1':
profiler_level = torch_npu.profiler.ProfilerLevel.Level1
elif args_.profile_level == 'level2':
profiler_level = torch_npu.profiler.ProfilerLevel.Level2
else:
raise ValueError(f"profiler_level only support level0, level1, level2, but gets {args_.profile_level}")
experimental_config = torch_npu.profiler._ExperimentalConfig(
aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
profiler_level=profiler_level,
l2_cache=False
)
with torch_npu.profiler.profile(
activities=activities,
record_shapes=args_.profile_record_shapes,
profile_memory=args_.profile_with_memory,
with_stack=args_.profile_with_stack,
experimental_config=experimental_config,
schedule=torch_npu.profiler.schedule(wait=0, warmup=0, active=active, repeat=1, skip_first=skip_first),
on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(args_.profile_save_path)
) as prof:
args_.prof = prof
return step_fn(*args, **kwargs)
else:
return step_fn(*args, **kwargs)
return wrapper
def train_step_decorator(step_fn):
@wraps(step_fn)
def wrapper(*args, **kwargs):
args_ = get_args()
flop_count = None
if args_.op_cal_tflops:
flop_count = get_flops_counter()
flop_count.start()
ret = step_fn(*args, **kwargs)
if args_.profile_npu and (torch.distributed.get_rank() in args_.profile_ranks):
args_.prof.step()
if args_.op_cal_tflops:
flop_count = get_flops_counter()
counts = flop_count.get_flops()
set_count(counts)
flop_count.stop()
return ret
return wrapper
def generated_flops_counter():
from torch_npu.utils.flops_count import FlopsCounter
global FLOPS_COUNTER
FLOPS_COUNTER = FlopsCounter()
def get_flops_counter():
global FLOPS_COUNTER
if FLOPS_COUNTER is None:
generated_flops_counter()
return FLOPS_COUNTER
def set_count(count):
global RECORDED_COUNT
global TRAVERSED_COUNT
RECORDED_COUNT = count[0]
TRAVERSED_COUNT = count[1]
def get_count():
global RECORDED_COUNT
global TRAVERSED_COUNT
if RECORDED_COUNT == 0 and TRAVERSED_COUNT == 0:
flops_counter = get_flops_counter()
count = flops_counter.get_flops()
set_count(count)
return RECORDED_COUNT, TRAVERSED_COUNT
def judge_if_profile(args):
if not hasattr(args, 'profile_npu') or not args.profile_npu:
return False
if (torch.distributed.get_rank() in args.profile_ranks) or (-1 in args.profile_ranks):
return True
return False