import warnings
from typing import Optional, Union
import torch_npu._C
from .analysis.prof_common_func._constant import (
Constant,
print_info_msg,
print_warn_msg,
)
__all__ = [
"_ExperimentalConfig",
"supported_profiler_level",
"supported_ai_core_metrics",
"supported_export_type",
"ProfilerLevel",
"AiCMetrics",
"ExportType",
"HostSystem",
]
def supported_profiler_level():
return {
ProfilerLevel.Level0,
ProfilerLevel.Level1,
ProfilerLevel.Level2,
ProfilerLevel.Level_none,
}
def supported_ai_core_metrics():
return {
AiCMetrics.AiCoreNone,
AiCMetrics.PipeUtilization,
AiCMetrics.ArithmeticUtilization,
AiCMetrics.Memory,
AiCMetrics.MemoryL0,
AiCMetrics.MemoryUB,
AiCMetrics.ResourceConflictRatio,
AiCMetrics.L2Cache,
AiCMetrics.MemoryAccess,
}
def supported_export_type():
return {ExportType.Db, ExportType.Text}
class ProfilerLevel:
Level0 = Constant.LEVEL0
Level1 = Constant.LEVEL1
Level2 = Constant.LEVEL2
Level_none = Constant.LEVEL_NONE
class AiCMetrics:
PipeUtilization = Constant.AicPipeUtilization
ArithmeticUtilization = Constant.AicArithmeticUtilization
Memory = Constant.AicMemory
MemoryL0 = Constant.AicMemoryL0
MemoryUB = Constant.AicMemoryUB
ResourceConflictRatio = Constant.AicResourceConflictRatio
L2Cache = Constant.AicL2Cache
MemoryAccess = Constant.AicMemoryAccess
AiCoreNone = Constant.AicMetricsNone
class ExportType:
Db = Constant.Db
Text = Constant.Text
class HostSystem:
CPU = Constant.CPU
MEM = Constant.MEM
DISK = Constant.DISK
NETWORK = Constant.NETWORK
OSRT = Constant.OSRT
NUMA = Constant.NUMA
class _ExperimentalConfig:
def __init__(
self,
profiler_level: int = Constant.LEVEL0,
aic_metrics: int = Constant.AicPipeUtilization,
l2_cache: bool = False,
msprof_tx: bool = False,
mstx: bool = False,
data_simplification: bool = True,
record_op_args: bool = False,
op_attr: bool = False,
gc_detect_threshold: Optional[float] = None,
export_type: Union[str, list] = ExportType.Text,
host_sys: Optional[list] = None,
mstx_domain_include: Optional[list] = None,
mstx_domain_exclude: Optional[list] = None,
sys_io: bool = False,
sys_interconnection: bool = False,
):
self._profiler_level = profiler_level
self._aic_metrics = aic_metrics
self._l2_cache = l2_cache
self._msprof_tx = msprof_tx
self._mstx = mstx
self._data_simplification = data_simplification
self.record_op_args = record_op_args
self._export_type = self._conver_export_type_to_list(export_type)
self._host_sys = host_sys if host_sys else []
self._op_attr = op_attr
self._gc_detect_threshold = gc_detect_threshold
self._mstx_domain_include = mstx_domain_include if mstx_domain_include else []
self._mstx_domain_exclude = mstx_domain_exclude if mstx_domain_exclude else []
self._sys_io = sys_io
self._sys_interconnection = sys_interconnection
self._check_params()
self._check_mstx_domain_params()
self._check_host_sys_params()
def __call__(self) -> torch_npu._C._profiler._ExperimentalConfig:
return torch_npu._C._profiler._ExperimentalConfig(
trace_level=self._profiler_level,
metrics=self._aic_metrics,
l2_cache=self._l2_cache,
record_op_args=self.record_op_args,
msprof_tx=self._msprof_tx or self._mstx,
op_attr=self._op_attr,
host_sys=self._host_sys,
mstx_domain_include=self._mstx_domain_include,
mstx_domain_exclude=self._mstx_domain_exclude,
sys_io=self._sys_io,
sys_interconnection=self._sys_interconnection,
)
@property
def export_type(self):
return self._export_type
@property
def with_gc(self):
return self._gc_detect_threshold is not None
@property
def gc_detect_threshold(self):
return self._gc_detect_threshold
def _conver_export_type_to_list(self, export_type: Union[str, list]) -> list:
if not export_type:
print_warn_msg(
f"Invalid parameter export_type: {export_type}, reset it to text."
)
return [ExportType.Text]
if isinstance(export_type, str):
return [export_type]
elif isinstance(export_type, list):
try:
return list(set(export_type))
except Exception as error:
print_warn_msg(
f"Invalid parameter export_type: {export_type}, reset it to text. Error is {error}"
)
return [ExportType.Text]
else:
print_warn_msg(
f"Invalid parameter export_type: {export_type}, reset it to text."
)
return [ExportType.Text]
def _check_params(self):
if (
self._profiler_level == Constant.LEVEL0
or self._profiler_level == Constant.LEVEL_NONE
) and self._aic_metrics != Constant.AicMetricsNone:
print_warn_msg(
"Please use level1 or level2 if you want to collect aic metrics, reset aic metrics to None!"
)
self._aic_metrics = Constant.AicMetricsNone
if not isinstance(self._l2_cache, bool):
print_warn_msg(
"Invalid parameter l2_cache, which must be of bool type, reset it to False."
)
self._l2_cache = False
if not isinstance(self._msprof_tx, bool):
print_warn_msg(
"Invalid parameter msprof_tx, which must be of bool type, reset it to False."
)
self._msprof_tx = False
if self._msprof_tx:
warnings.warn(
"The parameter msprof_tx will be deprecated. Please use the new parameter mstx instead."
)
if not isinstance(self._mstx, bool):
print_warn_msg(
"Invalid parameter mstx, which must be of bool type, reset it to False."
)
self._mstx = False
if (
self._profiler_level == ProfilerLevel.Level_none
and not self._mstx
and not self._msprof_tx
):
self._mstx = True
print_warn_msg(
"Parameter mstx or msprof_tx must be True if profiler_level is set to Level_none."
)
if self._data_simplification is not None and not isinstance(
self._data_simplification, bool
):
print_warn_msg(
"Invalid parameter data_simplification, which must be of bool type, reset it to default."
)
self._data_simplification = True
if not isinstance(self.record_op_args, bool):
print_warn_msg(
"Invalid parameter record_op_args, which must be of bool type, reset it to False."
)
self.record_op_args = False
if self._profiler_level not in (
ProfilerLevel.Level0,
ProfilerLevel.Level1,
ProfilerLevel.Level2,
ProfilerLevel.Level_none,
):
print_warn_msg(
"Invalid parameter profiler_level, reset it to ProfilerLevel.Level0."
)
self._profiler_level = ProfilerLevel.Level0
if self._aic_metrics not in (
AiCMetrics.L2Cache,
AiCMetrics.MemoryL0,
AiCMetrics.Memory,
AiCMetrics.MemoryUB,
AiCMetrics.PipeUtilization,
AiCMetrics.ArithmeticUtilization,
AiCMetrics.ResourceConflictRatio,
AiCMetrics.MemoryAccess,
AiCMetrics.AiCoreNone,
):
print_warn_msg("Invalid parameter aic_metrics, reset it to default.")
if self._profiler_level == ProfilerLevel.Level0:
self._aic_metrics = AiCMetrics.AiCoreNone
else:
self._aic_metrics = AiCMetrics.PipeUtilization
if not isinstance(self._op_attr, bool):
print_warn_msg(
"Invalid parameter op_attr, which must be of bool type, reset it to False."
)
self._op_attr = False
if not all(
export_type in [ExportType.Text, ExportType.Db]
for export_type in self._export_type
):
print_warn_msg("Invalid parameter export_type, reset it to text.")
self._export_type = [ExportType.Text]
if self._op_attr and ExportType.Db not in self._export_type:
print_warn_msg("op_attr switch is invalid with export type set as text.")
self._op_attr = False
if self._gc_detect_threshold is not None:
if not isinstance(self._gc_detect_threshold, (int, float)):
print_warn_msg(
"Parameter gc_detect_threshold is not int or float type, reset it to default."
)
self._gc_detect_threshold = None
elif self._gc_detect_threshold < 0.0:
print_warn_msg(
"Parameter gc_detect_threshold can not be negative, reset it to default."
)
self._gc_detect_threshold = None
elif self._gc_detect_threshold == 0.0:
print_info_msg(
"Parameter gc_detect_threshold is set to 0, it will collect all gc events."
)
if not isinstance(self._sys_io, bool):
print_warn_msg(
"Invalid parameter sys_io, which must be of bool type, reset it to False."
)
self._sys_io = False
if not isinstance(self._sys_interconnection, bool):
print_warn_msg(
"Invalid parameter sys_interconnection, which must be of bool type, reset it to False."
)
self._sys_interconnection = False
def _check_mstx_domain_params(self):
if not self._msprof_tx and not self._mstx:
if self._mstx_domain_include or self._mstx_domain_exclude:
print_warn_msg(
"mstx_domain_include and mstx_domain_exclude are valid when msprof_tx or mstx is True."
)
self._mstx_domain_include = []
self._mstx_domain_exclude = []
return
if self._mstx_domain_include:
if not isinstance(self._mstx_domain_include, list):
print_warn_msg(
"Invalid parameter mstx_domain_include, which must be of list type, "
"reset it to default."
)
self._mstx_domain_include = []
if any(not isinstance(domain, str) for domain in self._mstx_domain_include):
print_warn_msg(
"Invalid parameter mstx_domain_include, which contents must be of str type, "
"reset it to default."
)
self._mstx_domain_include = []
else:
self._mstx_domain_include = list(set(self._mstx_domain_include))
if self._mstx_domain_exclude:
if not isinstance(self._mstx_domain_exclude, list):
print_warn_msg(
"Invalid parameter mstx_domain_exclude, which must be of list type, "
"reset it to default."
)
self._mstx_domain_exclude = []
if any(not isinstance(domain, str) for domain in self._mstx_domain_exclude):
print_warn_msg(
"Invalid parameter _mstx_domain_exclude, which contents must be of str type, "
"reset it to default."
)
self._mstx_domain_exclude = []
else:
self._mstx_domain_exclude = list(set(self._mstx_domain_exclude))
if self._mstx_domain_include and self._mstx_domain_exclude:
print_warn_msg(
"Parameter mstx_domain_include and mstx_domain_exclude can not be both set, "
"only mstx_domain_include will work."
)
self._mstx_domain_exclude = []
def _check_host_sys_params(self):
if not isinstance(self._host_sys, list):
print_warn_msg(
"Invalid parameter host_sys, which must be of list type, reset it to empty."
)
self._host_sys = []
if not all(
host_sys
in [
HostSystem.CPU,
HostSystem.MEM,
HostSystem.DISK,
HostSystem.NETWORK,
HostSystem.OSRT,
HostSystem.NUMA,
]
for host_sys in self._host_sys
):
print_warn_msg("Invalid parameter host_sys, reset it to empty.")
self._host_sys = list()
self._host_sys = list({str(item) for item in self._host_sys})