import os
import shutil
import logging
import re
import pandas as pd
from packaging import version
import torch
from .base_test import BaseTest, TestSuite
from ..utils.result import Result
from ..utils.file_system import WorkingDir
from ..utils.utils import ColorText
class DecomposeTestSuite(TestSuite):
    """Test suite that registers the memscope dump check case.

    Before the suite runs, ``set_up`` makes sure the work path exists and
    removes artifacts left there by an earlier run.
    """

    def __init__(self, name: str, config, work_path: str, cmd: str, max_time):
        """Forward all arguments to ``TestSuite`` and register the cases."""
        super().__init__(name, config, work_path, cmd, max_time)
        # Only one case today; the tuple keeps adding more a one-line change.
        cases = (
            DecomposeTestCase("check_dump", name, work_path, ""),
        )
        for case in cases:
            self.register(case)

    def __str__(self):
        return (f"msmemscope test suite. suite name: {self.name}, "
                f"suite work path: {self._work_path}")

    def set_up(self):
        """Create the work path if needed and clear stale artifacts in it.

        Removes the ``memscopeDumpResults`` directory and every ``*.log`` and
        ``*.json`` entry found directly in the work path.
        """
        super().set_up()
        work_dir = self._work_path
        if not os.path.exists(work_dir):
            os.mkdir(work_dir)
        # NOTE(review): WorkingDir presumably switches the current directory
        # to work_dir for the duration of the block; the relative paths below
        # rely on that. Confirm against utils.file_system.
        with WorkingDir(work_dir):
            dump_dir = 'memscopeDumpResults'
            if os.path.exists(dump_dir):
                shutil.rmtree(dump_dir)
            # Logs are removed first, then JSON files, each from a fresh
            # directory listing.
            for suffix in ('.log', '.json'):
                for entry in os.listdir("."):
                    if entry.endswith(suffix):
                        os.remove(entry)

    def tear_down(self):
        """No suite-specific cleanup; defer to the base class."""
        super().tear_down()
class DecomposeTestCase(BaseTest):
    """Validate the CSV files found under ``<real_path>/memscopeDumpResults``.

    The ``check_dump`` case verifies that the expected number of
    ``memscope_dump_<digits>.csv`` files exists, that their combined columns
    match the expected header, and that the ``Attr`` column contains the
    expected owner tags (some of them within a count range).
    """

    # Number of memscope_dump CSV files the dump directory must contain.
    FILE_GEN_COUNT = 3
    # Directory (relative to the case's real path) holding the dump files.
    DUMP_DIR_NAME = 'memscopeDumpResults'
    # Full file-name convention for a dump file; used with fullmatch so that
    # trailing garbage such as "memscope_dump_1.csv.bak.csv" is rejected.
    DUMP_FILE_PATTERN = re.compile(r'memscope_dump_\d{1,20}\.csv')
    # Header the concatenated CSV data must have, in this exact order.
    EXPECTED_COLUMNS = (
        "ID,Event,Event Type,Name,Timestamp(ns),Process Id,Thread Id,Device Id,"
        "Ptr,Attr,Call Stack(Python),Call Stack(C)"
    )
    # Below this torch version the Attr content check is skipped.
    MIN_TORCH_VERSION = "2.3.0"
    # Owner tags that must occur in the Attr column.
    ATTR_VALID_RULES = (
        "owner:PTA", "owner:PTA@model@optimizer_state", "owner:PTA@ops@aten",
        "owner:PTA@model@gradient", "owner:PTA@model@weight", "owner:CANN",
        "owner:CANN@HCCL", "owner:CANN@APP", "owner:CANN@RUNTIME", "owner:CANN@GE",
    )
    # Tags whose total occurrence count must fall inside [min, max]; every
    # other tag in ATTR_VALID_RULES only has to occur at least once.
    ATTR_VALID_THRESHOLD = {
        "owner:PTA": {"min": 5450, "max": 5650},
        "owner:CANN": {"min": 200, "max": 350},
    }

    def __init__(self, name: str, case_name: str, real_path: str, golden_path: str):
        """Store the case identity and the paths it works on.

        Args:
            name: Name passed to ``BaseTest``; ``run`` dispatches on it.
            case_name: Stored as ``self.case_name``.
            real_path: Directory that contains ``memscopeDumpResults``.
            golden_path: Stored but not read anywhere in this class.
        """
        super().__init__(name)
        self.case_name = case_name
        self._golden_path = golden_path
        self._real_path = real_path

    def __str__(self):
        return (f"case name: {self.name}, "
                f"case path: {self._real_path}")

    def comp_memscope_csv_contents(self, file_paths, column):
        """Check header and ``Attr`` content of the concatenated CSV files.

        Args:
            file_paths: Paths of the CSV files to load and concatenate.
            column: Comma-separated header the combined data must match.

        Returns:
            A ``Result`` whose first argument is ``True`` when every check
            passes (or when the torch version is too old and the content
            check is skipped), otherwise ``False``.
            NOTE(review): the second and third ``Result`` arguments look like
            expected/actual lists; confirm against utils.result.
        """
        frames = []
        for file in file_paths:
            try:
                frames.append(pd.read_csv(file))
            except Exception as e:
                # read_csv can fail in many ways (parse, empty file, I/O,
                # decoding); log and keep going with the files that load.
                logging.error(f"Error reading {file}: {str(e)}")

        if not frames:
            # Without this guard the column check below would fail with an
            # unbound variable instead of a reportable test failure.
            logging.error("no readable csv file found, please check your dump file")
            return Result(False, ["at least one readable csv file"], [])
        data = pd.concat(frames, ignore_index=True)

        if list(data.columns) != column.split(','):
            logging.error(
                "the sum data column of %s is error, please check your dump file",
                ", ".join(map(str, file_paths)),
            )
            return Result(False, [column], [data.columns])
        if 'Attr' not in data.columns:
            logging.error("ATTR column not found in data")
            return Result(False, ["ATTR column missing"], [])

        if version.parse(torch.__version__) < version.parse(self.MIN_TORCH_VERSION):
            # Reported as success on purpose: the content check is skipped,
            # not failed, on older torch versions.
            logging.error(
                "PyTorch version %s is below minimum required version 2.3.0 for memory profiling",
                torch.__version__,
            )
            return Result(
                True,
                ["PyTorch version requirement not met", "Required: >=2.3.0",
                 f"Current: {torch.__version__}"],
                [],
            )

        # Normalise once: an all-NaN column is not string-typed and would make
        # the .str accessor raise; empty strings simply never match a tag.
        attr_values = data["Attr"].fillna("").astype(str)
        for check_element in self.ATTR_VALID_RULES:
            # str.count takes a regex; escape so the tag is matched literally.
            actual_count = int(attr_values.str.count(re.escape(check_element)).sum())
            bounds = self.ATTR_VALID_THRESHOLD.get(check_element)
            if bounds is not None:
                if not (bounds["min"] <= actual_count <= bounds["max"]):
                    logging.error(f"{check_element} does not match the expectation, please check your dump file")
                    return Result(False, [f"{check_element} does not match the expectation", -1], [actual_count])
            elif actual_count < 1:
                logging.error(f"{check_element} does not exist, please check your dump file")
                return Result(False, [f"{check_element} does not exist", -1], [-1])
        return Result(True, [], [])

    def compare_memscope_csv(self):
        """Locate the dump CSV files and validate their count, names and content.

        Returns:
            A ``Result`` that is successful only when the dump directory
            exists, contains exactly ``FILE_GEN_COUNT`` matching CSV files with
            valid names, and ``comp_memscope_csv_contents`` succeeds.
        """
        dump_dir = os.path.join(self._real_path, self.DUMP_DIR_NAME)
        logging.info("checking csv...")
        if not os.path.exists(dump_dir):
            logging.error("directory %s not exist", dump_dir)
            return Result(False, [], [])

        # Collect candidates from the whole tree below the dump directory.
        csv_names = []
        csv_paths = []
        for root, _dirs, files in os.walk(dump_dir):
            for file in files:
                if file.endswith('.csv') and file.startswith('memscope_dump'):
                    csv_names.append(file)
                    csv_paths.append(os.path.join(root, file))

        if len(csv_names) != self.FILE_GEN_COUNT:
            logging.error("Failed to generate %d CSV files", self.FILE_GEN_COUNT)
            return Result(False, [self.FILE_GEN_COUNT], [len(csv_names)])

        for csv_name in csv_names:
            if not self.DUMP_FILE_PATTERN.fullmatch(csv_name):
                logging.error("A CSV file matching naming convention memscope_dump could not be found")
                return Result(False, [], [])

        result = self.comp_memscope_csv_contents(csv_paths, self.EXPECTED_COLUMNS)
        if not result.success:
            return result
        logging.info("check finish")
        return Result(True, [], [])

    def run(self) -> Result:
        """Run the case, report its result and return it.

        Only the ``check_dump`` name is handled; any other name reports the
        default failing ``Result``.
        """
        super().run()
        logging.debug(f"run {self}")
        print(f"{ColorText.run_test} {self}")
        result = Result(False, [], [])
        if self._name == "check_dump":
            result = self.compare_memscope_csv()
        self.report(result)
        return result

    def set_up(self):
        """No case-specific preparation; defer to the base class."""
        super().set_up()

    def tear_down(self):
        """No case-specific cleanup; defer to the base class."""
        super().tear_down()