"""
-------------------------------------------------------------------------
This file is part of the MindStudio project.
Copyright (c) 2026 Huawei Technologies Co.,Ltd.
MindStudio is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""
import argparse
import logging
import os.path
import time
from enum import Enum
from typing import List, Optional, Tuple

from torch_npu.profiler.analysis._profiler_config import ProfilerConfig
from torch_npu.profiler.analysis.prof_parse._event_tree_parser import EventTree, _EventType

from profiler_event_analyze_patch import init_patch
from memscope import MemoryEvent, PythonTraceEvent, MemScopeDb
from prof_adaptor import EventBuilder
# Timestamped log line format shared by every message this script emits.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'

# Emit INFO and above through a single StreamHandler (which writes to stderr by default).
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, handlers=[logging.StreamHandler()])

# Install the patches from profiler_event_analyze_patch at import time, before any profiler
# data is parsed (presumably they adjust torch_npu's event analysis — TODO confirm).
init_patch()
def analyze_profiler_events() -> Tuple[List[MemoryEvent], List[PythonTraceEvent]]:
    """
    Parse the profiler data at DumpConfig.profiler_path and convert it into MemScope events.

    After parsing, the crop window in DumpConfig is intersected with the time span of the
    parsed events (DumpConfig.reset_time_range_by_events_time). Then:
      * allocation events are kept when cropping is off, or when their start time lies in
        [DumpConfig.start, DumpConfig.end];
      * Python call events are kept when cropping is off, or when they overlap that window,
        in which case their start/end times are clipped to the window.
    If the resolved window is empty (start > end), a warning is logged and no events are kept.

    :return: tuple of (allocation events, Python trace events)
    :raises RuntimeError: if no events were parsed from the profiler data.
    """
    logging.info("Starting to parse profiler data...")
    event_tree = EventTree(DumpConfig.profiler_path)
    sorted_events = event_tree.sorted_events
    if not sorted_events:
        raise RuntimeError("No events were parsed from the profiler data.")
    # sorted_events is presumably ordered by start time, so these are the earliest and the
    # latest event start timestamps — TODO confirm against EventTree.
    min_time = sorted_events[0].start_time_ns
    max_time = sorted_events[-1].start_time_ns
    logging.info(f"Parsing profiler data finished. "
                 f"(The first profiler event start timestamp: {min_time}, "
                 f"The last profiler event start timestamp {max_time})")
    DumpConfig.reset_time_range_by_events_time(min_time, max_time)

    allocation_events: List[MemoryEvent] = []
    python_trace_events: List[PythonTraceEvent] = []
    crop_enabled = DumpConfig.crop_mode is not CropMode.NO_CROPPING

    # The requested window may not overlap the data at all (e.g. --start later than the last
    # event). Clipping a long-running call to such a window would produce an inverted
    # interval (start > end), so keep nothing instead.
    if crop_enabled and DumpConfig.start > DumpConfig.end:
        logging.warning(f"The crop range [{DumpConfig.start}, {DumpConfig.end}] does not overlap "
                        f"the profiler data (event start timestamps {min_time} - {max_time}); "
                        f"no events will be kept.")
        return allocation_events, python_trace_events

    for event in sorted_events:
        if event.tag is _EventType.Allocation:
            if not crop_enabled or DumpConfig.start <= event.start_time_ns <= DumpConfig.end:
                allocation_events.append(EventBuilder.build_memory_event_by_prof_event(event))
        elif event.tag is _EventType.PyCall:
            if not crop_enabled:
                python_trace_events.append(EventBuilder.build_python_trace_event_by_prof_event(event))
                continue
            # Drop calls lying entirely outside the window.
            if event.start_time_ns > DumpConfig.end or event.end_time_ns < DumpConfig.start:
                continue
            # Clip calls that straddle a window edge so they stay within [start, end].
            trace_event = EventBuilder.build_python_trace_event_by_prof_event(event)
            trace_event.start_time_ns = max(DumpConfig.start, trace_event.start_time_ns)
            trace_event.end_time_ns = min(DumpConfig.end, trace_event.end_time_ns)
            python_trace_events.append(trace_event)
    return allocation_events, python_trace_events
def save_memory_events(events: List[MemoryEvent], db: MemScopeDb):
    """
    Persist memory events into the dump table of the given MemScope database.

    :param events: events to store; each one is serialized with its `to_dict()` method.
    :param db: MemScope database; its `conn` connection is used for the insert.
    """
    target_table = db.get_dump_table()
    rows = list(map(lambda memory_event: memory_event.to_dict(), events))
    target_table.insert_records(db.conn, rows)
def save_python_trace_events(events: List[PythonTraceEvent], db: MemScopeDb):
    """
    Persist Python trace events, one table per process id.

    Events are grouped by their `pid` (keeping first-seen pid order and per-pid event order),
    serialized with `to_dict()`, and each group is inserted into the table returned by
    `db.get_python_trace_table(pid)`.

    :param events: trace events to store.
    :param db: MemScope database; its `conn` connection is used for the inserts.
    """
    records_by_pid = {}
    for trace_event in events:
        records_by_pid.setdefault(trace_event.pid, []).append(trace_event.to_dict())
    for process_id in records_by_pid:
        per_pid_table = db.get_python_trace_table(process_id)
        per_pid_table.insert_records(db.conn, records_by_pid[process_id])
def dump_profiler_data_to_memscope():
    """
    Convert the profiler data described by DumpConfig into a MemScope db file.

    Allocation events are written to the dump table and Python trace events to per-pid
    trace tables of the db at DumpConfig.output_path.

    :raises RuntimeError: if no profiler events were parsed, or the output db cannot be created.
    """
    allocation_events, python_trace_events = analyze_profiler_events()
    output_dir = os.path.dirname(DumpConfig.output_path)
    try:
        # The default output path points into a "dump_data" subdirectory that nothing else in
        # this script creates; make sure the parent directory exists (no-op if it does).
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        db = MemScopeDb(DumpConfig.output_path)
    except Exception as e:
        raise RuntimeError("Failed to create output db file.") from e
    logging.info(
        f"Starting save to db, allocation_events:{len(allocation_events)}, trace_events:{len(python_trace_events)} ...")
    try:
        save_memory_events(allocation_events, db)
        save_python_trace_events(python_trace_events, db)
    finally:
        # Close the connection even when an insert fails, so the db handle is not leaked.
        db.conn.close()
    logging.info(f"Dump profiler data to memscope data success, output path: {DumpConfig.output_path}")
class CropMode(Enum):
    """
    Strategy used to derive the time window that profiler events are cropped to.

    The resolved window is always intersected with the span of parsed event start times.
    """
    # No window: every event is kept unchanged.
    NO_CROPPING = 1
    # Only a duration was given: window = [earliest event start, earliest event start + duration].
    CROP_FIXED_DURATION_FROM_MIN_EVENT_TIME = 2
    # Both start and duration were given: window = [start, start + duration].
    CROP_FIXED_DURATION_FROM_SPECIFIED_START_TIME = 3
    # Only a start was given: window = [start, latest event start].
    CROP_FROM_SPECIFIED_START_TIME_TO_THE_END = 4
class DumpConfig:
    """
    Script-wide configuration, populated from the command line by `init()`.

    All state is stored in class attributes; the class is used without instantiation.
    """
    # PyTorch Profiler output directory to parse.
    profiler_path: str
    # Requested crop start in ns; -1 when not given by the user (not set at all in NO_CROPPING).
    start: int
    # Requested crop duration in ns; -1 when not given (not set at all in NO_CROPPING).
    duration: int
    # Path of the MemScope db file to write.
    output_path: str
    # Crop end in ns; -1 until it is resolved from the parsed events (not set in NO_CROPPING).
    end: int
    # How the crop window is derived from start/duration.
    crop_mode: CropMode

    @staticmethod
    def _positive_int(value: str) -> int:
        """argparse `type` converter that accepts only integers greater than zero."""
        try:
            number = int(value)
        except ValueError:
            raise argparse.ArgumentTypeError(f"{value} is not a positive integer") from None
        if number <= 0:
            raise argparse.ArgumentTypeError(f"{value} is not a positive integer")
        return number

    @classmethod
    def init(cls, argv: Optional[List[str]] = None):
        """
        Parse command-line arguments and populate the configuration.

        :param argv: arguments to parse; when None, argparse reads sys.argv[1:].
        :raises RuntimeError: if the profiler path does not exist.
        """
        parser = argparse.ArgumentParser(description="Parse PyTorch Profiler data into memscope format.")
        parser.add_argument("profiler_path", type=str,
                            help="Specify the PyTorch Profiler directory to be parsed.")
        # A non-string default is not passed through `type`, so -1 marks "not provided".
        parser.add_argument("-s", "--start", type=cls._positive_int, default=-1, required=False,
                            help="Specify the start time for data parsing and trimming, in nanoseconds (ns). "
                                 "If not provided, trimming will start from the beginning of the profiler data.")
        parser.add_argument("-d", "--duration", type=cls._positive_int, default=-1, required=False,
                            help="Specify the duration of data to be trimmed, starting from the start time, "
                                 "in nanoseconds (ns). "
                                 "If not provided, the data will be trimmed from the start time to the end "
                                 "of the profiler data.")
        parser.add_argument("-o", "--output_path", type=str, required=False,
                            help="Specify the output file path for the parsed results; "
                                 "defaults to the dump_data subdirectory under the profiler path.")
        args = parser.parse_args(argv)

        cls.profiler_path = args.profiler_path
        if not os.path.exists(cls.profiler_path):
            raise RuntimeError(f"The specified profiler path does not exist: {cls.profiler_path}.")
        cls.set_time_range_and_crop_mode(args.start, args.duration)
        cls.output_path = args.output_path
        if not cls.output_path:
            # Timestamped file name so repeated runs do not overwrite each other.
            cls.output_path = os.path.join(cls.profiler_path, "dump_data", f"leaks_dump_{time.time_ns()}.db")
        ProfilerConfig().load_info(cls.profiler_path)

    @classmethod
    def set_time_range_and_crop_mode(cls, start: int, duration: int):
        """
        Choose the crop mode from the user-supplied start and duration.

        :param start: requested crop start time in ns, or -1 when not given
        :param duration: requested crop duration in ns, or -1 when not given
        """
        if start == -1 and duration == -1:
            cls.crop_mode = CropMode.NO_CROPPING
            return
        if start > 0 and duration > 0:
            cls.start = start
            cls.duration = duration
            cls.end = start + duration
            cls.crop_mode = CropMode.CROP_FIXED_DURATION_FROM_SPECIFIED_START_TIME
            return
        if start == -1:
            # Only a duration: start/end are resolved later from the earliest event time.
            cls.start = -1
            cls.duration = duration
            cls.end = -1
            cls.crop_mode = CropMode.CROP_FIXED_DURATION_FROM_MIN_EVENT_TIME
            return
        # Only a start: the end is resolved later from the latest event time.
        cls.start = start
        cls.duration = -1
        cls.end = -1
        cls.crop_mode = CropMode.CROP_FROM_SPECIFIED_START_TIME_TO_THE_END

    @classmethod
    def reset_time_range_by_events_time(cls, min_time: int, max_time: int):
        """
        Intersect the requested crop range with the actual time span of the parsed events.

        The user-specified range may exceed the data, so it is narrowed once the events are
        known. The result can be empty (start > end) when the two ranges do not overlap.

        :param min_time: earliest profiler event start time in ns
        :param max_time: latest profiler event start time in ns
        """
        if cls.crop_mode is CropMode.NO_CROPPING:
            return
        if cls.crop_mode is CropMode.CROP_FIXED_DURATION_FROM_SPECIFIED_START_TIME:
            cls.start = max(min_time, cls.start)
            cls.end = min(max_time, cls.end)
            return
        if cls.crop_mode is CropMode.CROP_FIXED_DURATION_FROM_MIN_EVENT_TIME:
            cls.start = min_time
            cls.end = min(max_time, cls.start + cls.duration)
            return
        # CROP_FROM_SPECIFIED_START_TIME_TO_THE_END
        cls.start = max(min_time, cls.start)
        cls.end = max_time
if __name__ == "__main__":
    # Populate DumpConfig from the command line, then run the profiler-to-MemScope conversion.
    DumpConfig.init()
    dump_profiler_data_to_memscope()