from functools import wraps
import torch
import torch_npu
from mindspeed.args_utils import get_full_args
PROFILE_RECORD = None
def train_wrapper(train):
@wraps(train)
def wrapper(*args, **kwargs):
args_ = get_full_args()
if args_.profile:
args_.profile_npu = True
args_.profile = False
else:
args_.profile_npu = False
is_profile = hasattr(args_, 'profile_npu') and args_.profile_npu \
and ((torch.distributed.get_rank() in args_.profile_ranks) or (-1 in args_.profile_ranks))
if is_profile:
global PROFILE_RECORD
active = args_.profile_step_end - args_.profile_step_start
skip_first = args_.profile_step_start
if args_.profile_with_cpu:
activities = [torch_npu.profiler.ProfilerActivity.NPU, torch_npu.profiler.ProfilerActivity.CPU]
else:
activities = [torch_npu.profiler.ProfilerActivity.NPU]
level2level = {
'level0': torch_npu.profiler.ProfilerLevel.Level0,
'level1': torch_npu.profiler.ProfilerLevel.Level1,
'level2': torch_npu.profiler.ProfilerLevel.Level2
}
profiler_level = level2level[args_.profile_level]
experimental_config = torch_npu.profiler._ExperimentalConfig(
aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
profiler_level=profiler_level,
l2_cache=False
)
with torch_npu.profiler.profile(
activities=activities,
record_shapes=args_.profile_record_shapes,
profile_memory=args_.profile_with_memory,
with_stack=args_.profile_with_stack,
experimental_config=experimental_config,
schedule=torch_npu.profiler.schedule(wait=0, warmup=0, active=active, repeat=1, skip_first=skip_first),
on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(args_.profile_save_path)
) as prof:
PROFILE_RECORD = prof
return train(*args, **kwargs)
return train(*args, **kwargs)
return wrapper
def train_step_wrapper(train_step):
@wraps(train_step)
def wrapper(*args, **kwargs):
args_ = get_full_args()
ret = train_step(*args, **kwargs)
is_profile = args_.profile_npu and (
(torch.distributed.get_rank() in args_.profile_ranks)
or (-1 in args_.profile_ranks)
)
if is_profile:
global PROFILE_RECORD
PROFILE_RECORD.step()
return ret
return wrapper