import functools
import os
import time
import uuid
from collections.abc import Callable, Iterable
from typing import Optional
import torch
from torch.distributed.distributed_c10d import _world as distributed_world
from torch_npu._C._profiler import (
_finalize_profiler,
_get_freq,
_get_monotonic,
_get_syscnt,
_get_syscnt_enable,
_init_profiler,
_start_profiler,
_stop_profiler,
_supported_npu_activities,
_warmup_profiler,
NpuProfilerConfig,
ProfilerActivity,
)
from torch_npu.npu import _lazy_init, Event
from torch_npu.utils.collect_env import get_cann_version, get_torch_npu_version
from ._profiler_gc_detect import ProfGCDetector
from ._profiler_path_creator import ProfPathCreator
from .analysis._npu_profiler import NpuProfiler
from .analysis.prof_common_func._constant import Constant, print_warn_msg
from .analysis.prof_common_func._file_manager import FileManager
from .analysis.prof_common_func._log import ProfilerLogger
from .analysis.prof_common_func._utils import (
check_msprof_env,
collect_env_vars,
no_exception_func,
)
from .experimental_config import _ExperimentalConfig
from ._flops_hook import FlopsHookManager
from .scheduler import ProfilerAction
__all__ = ["supported_activities"]
def _enable_event_record():
    """Wrap selected ``Event`` methods so every call shows up as a profiler range.

    ``Event.record``, ``Event.wait``, ``Event.query``, ``Event.elapsed_time``
    and ``Event.synchronize`` are each replaced by a wrapper that executes the
    original call inside ``torch.profiler.record_function("Event::<name>")``.
    The wrapped callable is stored on the wrapper as ``origin_func`` so that
    ``_disable_event_record`` can put it back.
    """

    def _make_recording_wrapper(original):
        @functools.wraps(original)
        def recording_wrapper(*args, **kwargs):
            range_name = f"Event::{original.__name__}"
            with torch.profiler.record_function(range_name):
                return original(*args, **kwargs)

        # Keep a handle on the wrapped callable for _disable_event_record.
        recording_wrapper.origin_func = original
        return recording_wrapper

    for method_name in ("record", "wait", "query", "elapsed_time", "synchronize"):
        setattr(Event, method_name, _make_recording_wrapper(getattr(Event, method_name)))
def _disable_event_record():
    """Restore the ``Event`` methods wrapped by ``_enable_event_record``.

    For each of ``record``, ``wait``, ``query``, ``elapsed_time`` and
    ``synchronize``, the callable stored in the wrapper's ``origin_func``
    attribute is put back. Methods without that attribute are left unchanged.
    """
    for method_name in ("record", "wait", "query", "elapsed_time", "synchronize"):
        current = getattr(Event, method_name)
        setattr(Event, method_name, getattr(current, "origin_func", current))
class _ProfInterface:
    """Python-side driver for the NPU profiler bindings.

    Validates the user-supplied profiling options and forwards them to the
    ``torch_npu._C._profiler`` bindings across the profiling lifecycle
    (``init_trace`` -> ``warmup_trace`` -> ``start_trace`` -> ``stop_trace``
    -> ``finalize_trace``). It also writes the profiler-info and metadata JSON
    files into the profiling directory. Every lifecycle method returns
    immediately when ``check_msprof_env()`` reported an invalid environment
    at construction time.
    """

    # Metadata key under which HCCL communication-group info is stored.
    PARALLEL_GROUP_KEY = "parallel_group_info"
    # Metadata key under which the trace id is stored.
    TRACE_ID_KEY = "trace_id"
    # Longest trace id accepted from ``custom_trace_id_callback``.
    MAX_TRACE_ID_LEN = 1024

    def __init__(
        self,
        activities: Optional[Iterable[ProfilerActivity]] = None,
        record_shapes: bool = False,
        profile_memory: bool = False,
        with_stack: bool = False,
        with_flops: bool = False,
        with_modules: bool = False,
        schedule: Optional[Callable[[int], ProfilerAction]] = None,
        metadata: Optional[dict] = None,
        experimental_config: Optional[_ExperimentalConfig] = None,
        custom_trace_id_callback: Optional[Callable[[], str]] = None,
    ) -> None:
        """Store and validate the profiling options.

        Args:
            activities: Activities to profile. Empty or ``None`` means
                ``supported_activities()``.
            record_shapes, profile_memory, with_stack, with_flops, with_modules:
                Boolean switches forwarded to ``NpuProfilerConfig``. Non-bool
                values are reset to ``False`` with a warning.
            schedule: Profiling schedule. This class only serializes it into
                the profiler-info file.
            metadata: User metadata written to ``Constant.PROFILER_META_DATA``
                by ``finalize_trace``. A supplied dict is kept as the same
                object (and cleared after a successful dump). ``None`` means
                a fresh empty dict.
            experimental_config: Extra options. Defaults to
                ``_ExperimentalConfig()``.
            custom_trace_id_callback: Optional zero-argument callable that
                returns the trace id (see ``create_trace_id``).
        """
        self._is_env_valid = check_msprof_env()
        self.prof_path = ""
        # Clock information captured by start_trace(). freq keeps this default
        # unless the system counter is enabled.
        self.syscnt_enable = False
        self.freq = 100
        self.start_cnt = 0
        self.start_monotonic = 0
        self.activities = set(activities) if activities else supported_activities()
        self.record_shapes = record_shapes
        self.profile_memory = profile_memory
        self.with_stack = with_stack
        self.with_flops = with_flops
        self.with_modules = with_modules
        # Remember whether the caller supplied an experimental config. Once the
        # default is substituted below, it can no longer be told apart, which
        # previously made the "unused without NPU" warning fire unconditionally.
        self._experimental_config_given = experimental_config is not None
        if experimental_config is None:
            experimental_config = _ExperimentalConfig()
        self.experimental_config = experimental_config
        self.schedule = schedule
        # _dump_metadata() calls self.metadata.update(...), so None must not be stored.
        self.metadata = metadata if metadata is not None else {}
        self.custom_trace_id_callback = custom_trace_id_callback
        self.trace_id = ""
        self.gc_detector = None
        self._check_params()
        if self._is_env_valid:
            _lazy_init()

    def init_trace(self):
        """Create the profiling directory, initialize the native profiler and assign a trace id."""
        if not self._is_env_valid:
            return
        ProfPathCreator().create_prof_dir()
        self.prof_path = ProfPathCreator().get_prof_dir()
        _init_profiler(self.prof_path, self.activities)
        self.trace_id = self.create_trace_id()

    def _build_npu_prof_config(self):
        """Pack the current options into an ``NpuProfilerConfig`` (argument order matters)."""
        return NpuProfilerConfig(
            self.prof_path,
            self.record_shapes,
            self.profile_memory,
            self.with_stack,
            self.with_flops,
            self.with_modules,
            self.experimental_config(),
        )

    def warmup_trace(self):
        """Warm up the native profiler with the current configuration."""
        if not self._is_env_valid:
            return
        _warmup_profiler(self._build_npu_prof_config(), self.activities)

    def start_trace(self):
        """Start collection.

        Records the framework start clocks, wraps the ``Event`` methods,
        starts the native profiler, and optionally installs the FLOPs hooks
        and the GC detector.
        """
        if not self._is_env_valid:
            return
        npu_prof_config = self._build_npu_prof_config()
        self.syscnt_enable = _get_syscnt_enable()
        if self.syscnt_enable:
            self.freq = _get_freq()
        self.start_cnt = _get_syscnt()
        self.start_monotonic = _get_monotonic()
        _enable_event_record()
        _start_profiler(npu_prof_config, self.activities)
        if self.with_flops and self.experimental_config._msprof_tx:
            FlopsHookManager.install()
        self.start_gc_detect()

    def stop_trace(self):
        """Stop collection and undo the hooks installed by ``start_trace``.

        When NPU activity is profiled, the device is synchronized first
        (presumably so queued device work is flushed before the native
        profiler stops).
        """
        if not self._is_env_valid:
            return
        if ProfilerActivity.NPU in self.activities:
            torch.npu.synchronize()
        if self.with_flops and self.experimental_config._msprof_tx:
            FlopsHookManager.uninstall()
        _stop_profiler()
        self.stop_gc_detect()
        _disable_event_record()

    def finalize_trace(self):
        """Finalize the native profiler and dump the profiler-info and metadata files."""
        if not self._is_env_valid:
            return
        _finalize_profiler()
        self._dump_profiler_info()
        self._dump_metadata()
        ProfPathCreator().is_prof_inited = False
        ProfilerLogger.destroy()

    def delete_prof_dir(self):
        """Remove the profiling directory via ``ProfPathCreator``."""
        if not self._is_env_valid:
            return
        ProfPathCreator().delete_prof_dir()

    def analyse(
        self,
        analysis_type: str = Constant.TENSORBOARD_TRACE_HANDLER,
        output_path: Optional[str] = None,
        **kwargs,
    ):
        """Parse the collected data with ``NpuProfiler.analyse``.

        This is best-effort: any parsing failure is reported as a warning and
        is not raised.
        """
        if not self._is_env_valid:
            return
        try:
            NpuProfiler.analyse(self.prof_path, analysis_type, output_path, **kwargs)
        except Exception as e:
            print_warn_msg(f"Profiling data parsing failed, error: {e}")

    def check_gc_detect_enable(self):
        """Return True when CPU activity is profiled and ``experimental_config.with_gc`` is set."""
        return (
            ProfilerActivity.CPU in self.activities and self.experimental_config.with_gc
        )

    def start_gc_detect(self):
        """Create and start a ``ProfGCDetector`` if GC detection is enabled."""
        if self.check_gc_detect_enable():
            self.gc_detector = ProfGCDetector(
                self.experimental_config.gc_detect_threshold
            )
            self.gc_detector.start()

    def stop_gc_detect(self):
        """Stop and drop the running GC detector, if any."""
        if self.check_gc_detect_enable() and self.gc_detector is not None:
            self.gc_detector.stop()
            self.gc_detector = None

    def default_trace_id(self):
        """Return a random trace id: a UUID4 as 32 upper-case hex digits."""
        uuid_raw = uuid.uuid4()
        return f"{uuid_raw.int:032X}"

    def create_trace_id(self):
        """Return the trace id for this run.

        Uses the result of ``custom_trace_id_callback`` when it is a ``str`` of
        at most ``MAX_TRACE_ID_LEN`` characters. Otherwise it falls back to
        ``default_trace_id()``, with a warning. The fallback covers a missing
        callback (no warning in that case), a non-callable, an invalid return
        value, and a callback that raised.
        """
        if not self.custom_trace_id_callback:
            return self.default_trace_id()
        if not isinstance(self.custom_trace_id_callback, Callable):
            print_warn_msg(
                "Parameter custom_trace_id_callback is not callable, reset it to default."
            )
            return self.default_trace_id()
        try:
            trace_id = self.custom_trace_id_callback()
            if isinstance(trace_id, str) and len(trace_id) <= self.MAX_TRACE_ID_LEN:
                return trace_id
            print_warn_msg(
                f"Parameter custom_trace_id_callback should return str(max length: {self.MAX_TRACE_ID_LEN}), reset it to default."
            )
        except Exception as e:
            print_warn_msg(
                f"Parameter custom_trace_id_callback raised an exception: {e}, reset it to default."
            )
        return self.default_trace_id()

    def _check_params(self):
        """Validate the stored options, resetting invalid ones to defaults with a warning."""
        supported = supported_activities()
        if any(activity not in supported for activity in self.activities):
            print_warn_msg(
                "Invalid activities, only CPU and NPU are supported, reset it to default."
            )
            self.activities = supported
        for name in (
            "record_shapes",
            "profile_memory",
            "with_stack",
            "with_flops",
            "with_modules",
        ):
            if not isinstance(getattr(self, name), bool):
                print_warn_msg(
                    f"Parameter {name} must be of boolean type, reset it to False."
                )
                setattr(self, name, False)
        if not isinstance(self.experimental_config, _ExperimentalConfig):
            print_warn_msg(
                "Parameter experimental_config must be an instance of _ExperimentalConfig, "
                "reset it to default."
            )
            self.experimental_config = _ExperimentalConfig()
        # Only warn when the caller actually passed a config. The default one
        # substituted in __init__ is always present and must not trigger this.
        if (
            ProfilerActivity.NPU not in self.activities
            and self._experimental_config_given
        ):
            print_warn_msg(
                "Experimental config will not be used while ProfilerActivity.NPU is not set."
            )
        if (
            ProfilerActivity.CPU not in self.activities
            and self.experimental_config.with_gc
        ):
            print_warn_msg(
                "GC detect will not take effect while ProfilerActivity.CPU is not set."
            )

    def _dump_profiler_info(self):
        """Write the run configuration, clocks and versions to a JSON file in ``prof_path``.

        The file is ``profiler_info_{rank}.json`` when a rank is known (from
        the ``RANK`` env var, else ``torch.distributed.get_rank()``) and
        ``profiler_info.json`` otherwise.
        """

        def _trans_obj2cfg(obj):
            # Serialize an options object through its __dict__. A falsy object
            # gives None, and an object without __dict__ gives {}.
            if not obj:
                return None
            return getattr(obj, "__dict__", {})

        common_config = {
            "activities": list(map(str, list(self.activities))),
            "schedule": _trans_obj2cfg(self.schedule),
            "record_shapes": self.record_shapes,
            "profile_memory": self.profile_memory,
            "with_stack": self.with_stack,
            "with_flops": self.with_flops,
            "with_modules": self.with_modules,
        }
        config = {
            Constant.COMMON_CONFIG: common_config,
            Constant.EXPERIMENTAL_CONFIG: _trans_obj2cfg(self.experimental_config),
        }
        start_info = {
            Constant.SyscntEable: self.syscnt_enable,
            Constant.SysCntFreq: self.freq,
            Constant.StartCnt: self.start_cnt,
            Constant.StartMonotonic: self.start_monotonic,
        }
        end_info = {
            Constant.FWK_END_TIME: time.time_ns(),
            Constant.FWK_END_MONOTONIC: time.monotonic_ns(),
        }
        total_info = {
            Constant.CONFIG: config,
            Constant.START_INFO: start_info,
            Constant.END_INFO: end_info,
            Constant.TORCH_NPU_VERSION: get_torch_npu_version()
            .replace("'", "")
            .replace(" ", ""),
            Constant.CANN_VERSION: get_cann_version(),
        }
        rank_id = os.environ.get("RANK")
        if (
            rank_id is None
            and torch.distributed.is_available()
            and torch.distributed.is_initialized()
        ):
            rank_id = torch.distributed.get_rank()
        if rank_id is None:
            path = os.path.join(os.path.realpath(self.prof_path), "profiler_info.json")
        else:
            path = os.path.join(
                os.path.realpath(self.prof_path), f"profiler_info_{rank_id}.json"
            )
            total_info["rank_id"] = rank_id
        FileManager.create_json_file_by_path(path, total_info, indent=4)

    def _dump_metadata(self):
        """Enrich ``self.metadata`` and write it to ``Constant.PROFILER_META_DATA``.

        Environment variables are added only when text export is requested.
        Parallel-group info and the trace id are always added. Writing is
        skipped (with a warning) when the profiler directory is not
        initialized. After a successful write the dict is cleared.
        """
        if Constant.Text in self.experimental_config.export_type:
            self.metadata.update(collect_env_vars())
        self._add_group_info_to_metadata()
        self._add_trace_id_to_metadata()
        if not self.metadata:
            return
        if not ProfPathCreator().is_prof_inited:
            print_warn_msg("Profiler is not initialized. Skip this metadata.")
            return
        metadata_path = os.path.join(self.prof_path, Constant.PROFILER_META_DATA)
        FileManager.create_json_file_by_path(metadata_path, self.metadata)
        self.metadata.clear()

    def _add_group_info_to_metadata(self):
        """Store HCCL communicator info per process group under ``PARALLEL_GROUP_KEY``.

        Entries are keyed by HCCL communicator name, covering every HCCL
        group in ``pg_map`` plus the default group. This is best-effort: any
        exception is reported as a warning.
        """
        try:
            if torch.distributed.is_available() and torch.distributed.is_initialized():
                group_info = {}
                global_rank = torch.distributed.get_rank()
                for group, group_config in distributed_world.pg_map.items():
                    backend = str(group_config[0]).lower()
                    if backend != "hccl":
                        continue
                    hccl_group = group._get_backend(torch.device("npu"))
                    # init_comm=False: query only, do not create a communicator.
                    comm_name = hccl_group.get_hccl_comm_name(
                        global_rank, init_comm=False
                    )
                    if comm_name:
                        group_info[comm_name] = {
                            "group_name": hccl_group.options.hccl_config.get(
                                "group_name", ""
                            ),
                            "group_rank": torch.distributed.get_group_rank(
                                group, global_rank
                            ),
                            "global_ranks": torch.distributed.get_process_group_ranks(
                                group
                            ),
                        }
                default_group = torch.distributed.distributed_c10d._get_default_group()
                comm_name = default_group._get_backend(
                    torch.device("npu")
                ).get_hccl_comm_name(global_rank, init_comm=False)
                if comm_name:
                    group_info[comm_name] = {
                        "group_name": "default_group",
                        "group_rank": torch.distributed.get_group_rank(
                            default_group, global_rank
                        ),
                        "global_ranks": torch.distributed.get_process_group_ranks(
                            default_group
                        ),
                    }
                if group_info:
                    self.metadata.update({self.PARALLEL_GROUP_KEY: group_info})
        except Exception as err:
            print_warn_msg(f"Failed to get parallel group info, Exception: {str(err)}.")

    def _add_trace_id_to_metadata(self):
        """Store the current trace id under ``TRACE_ID_KEY``."""
        self.metadata.update({self.TRACE_ID_KEY: self.trace_id})
@no_exception_func(set())
def supported_activities():
    """Return the profiler activities supported by the NPU backend.

    Delegates to ``_supported_npu_activities()``. The decorator is given
    ``set()``, presumably as the fallback result if the call raises
    (TODO confirm against ``no_exception_func``).
    """
    activities = _supported_npu_activities()
    return activities