from dataclasses import dataclass, field
from typing import List, Literal, Optional
import logging
from mindspeed_mm.config.arguments.base_args import BaseArguments
logger = logging.getLogger(__name__)
class StaticParam(BaseArguments):
level: str = field(
default="level1",
metadata={"help": "The info level of profiler."},
)
with_stack: bool = field(
default=False,
metadata={"help": "Whether to collect operator call stack info."},
)
with_memory: bool = field(
default=False,
metadata={"help": "Whether to collect the memory usage of the operator."},
)
record_shapes: bool = field(
default=False,
metadata={"help": "Whether to collect the innput shapes and input types of operators."},
)
with_cpu: bool = field(
default=False,
metadata={"help": "Whether to collect CPU events."},
)
save_path: str = field(
default="./profiling",
metadata={"help": "Direction to export the profiling result."},
)
start_step: int = field(
default=10,
metadata={"help": "Start step for profiling."},
)
end_step: int = field(
default=11,
metadata={"help": "End step for profiling."},
)
data_simplification: bool = field(
default=False,
metadata={"help": "Whether to enable the data simplification mode."},
)
aic_metrics_type: str = field(
default="PipeUtilization",
metadata={"help": "AI Core performance metric collection items."},
)
analyse_flag: bool = field(
default=True,
metadata={"help": "Whether to analyse profiling online."},
)
class Profiler(BaseArguments):
enable: bool = field(
default=False,
metadata={"help": "Enable profiling."},
)
profile_type: str = field(
default="static",
metadata={"help": "the type of profiling"},
)
ranks: List[int] = field(
default_factory=lambda: [0],
metadata={
"help": "List of ranks to profile (default is rank 0 only)"
},
)
static_param: StaticParam = field(default_factory=StaticParam)
class MemoryProfiler(BaseArguments):
enable: bool = field(
default=False,
metadata={"help": "Enable memory profiling."},
)
start_step: int = field(
default=1,
metadata={"help": "Start step for memory profiling."},
)
end_step: int = field(
default=2,
metadata={"help": "End step for memory profiling."},
)
save_path: str = field(
default="./memory_snapshot",
metadata={"help": "Direction to export the memory profiling result."},
)
dump_ranks: List[int] = field(
default_factory=lambda: [0],
metadata={"help": "List of ranks to memory profile (default is rank 0 only)"},
)
stacks: Literal["python", "all"] = field(
default="all",
metadata={
"help": "python, include Python, TorchScript, and inductor frames in tracebacks, all, additionally include C++ frames."},
)
max_entries: Optional[int] = field(
default=None,
metadata={"help": "Keep a maximum of `max_entries` alloc/free events in the recorded history recorded."},
)
mem_info: bool = field(
default=False,
metadata={"help": "Whether to print memory infos."},
)
class ToolsArguments(BaseArguments):
profile: Profiler = field(default_factory=Profiler)
memory_profile: MemoryProfiler = field(default_factory=MemoryProfiler)