"""检查msprobe数据,确定分析级别(L1或mix)。"""
import argparse
import csv
import os
import sqlite3
def find_first_dump_file(root_path):
    """Return the path of the first ``dump.json`` found under *root_path*.

    The directory tree is walked top-down with sub-directories visited in
    sorted name order, so the result is the same on every filesystem when
    several ``dump.json`` files exist.

    Args:
        root_path: Directory to search, or a direct path to a file named
            ``dump.json``.

    Returns:
        The path to the first ``dump.json`` found, or None if there is none.
    """
    # A direct path to a dump.json file is accepted as-is.
    if os.path.isfile(root_path):
        return root_path if os.path.basename(root_path) == 'dump.json' else None
    for dirpath, dirnames, filenames in os.walk(root_path):
        # Sorting in place steers os.walk's descent order, which is otherwise
        # whatever order the filesystem returns entries in.
        dirnames.sort()
        if 'dump.json' in filenames:
            return os.path.join(dirpath, 'dump.json')
    return None
def detect_path_type(root_path):
    """Detect which kind of result file *root_path* holds.

    Args:
        root_path: A file path, or a directory whose direct entries are
            inspected (no recursion).

    Returns:
        ``'db'`` if the first matching name ends with ``.vis.db``,
        ``'csv_xlsx'`` if it ends with ``.xlsx`` or ``.csv``, or None when
        no name matches.
    """
    if os.path.isfile(root_path):
        names = [root_path]
    else:
        # Sorted so that a directory containing several kinds of files gives
        # the same answer on every filesystem (os.listdir order is arbitrary).
        names = sorted(os.listdir(root_path))
    for name in names:
        if name.endswith('.vis.db'):
            return 'db'
        if name.endswith(('.xlsx', '.csv')):
            return 'csv_xlsx'
    return None
def _first_file(root_path, exts):
if os.path.isfile(root_path):
return root_path
for name in sorted(os.listdir(root_path)):
if name.endswith(exts):
return os.path.join(root_path, name)
return None
def _check_l0_names(names_iter, max_scan=100):
"""扫描前max_scan个name,全为Module./Cell.前缀则返回True。"""
count = 0
for i, name in enumerate(names_iter):
if i >= max_scan:
break
if name:
count += 1
if not name.startswith(('Module.', 'Cell.')):
return False
return count > 0
def validate_csv_xlsx(root_path):
    """Validate the header of a csv/xlsx file and detect L0 (Module-only) data.

    The first ``.csv``/``.xlsx`` file selected by ``_first_file`` is opened,
    its header row is read, and the leading values of the ``NPU Name`` column
    are passed to ``_check_l0_names``.

    Args:
        root_path: A csv/xlsx file, or a directory containing one.

    Returns:
        A user-facing error message when no file is found, the file cannot
        be read, the data is Module-only, or the ``NPU MD5`` / ``BENCH MD5``
        columns are missing; None when the file passes every check.
    """
    l0_message = "当前数据仅包含Module级信息,没有API级数据,无法分析确定性问题。"
    first = _first_file(root_path, ('.csv', '.xlsx'))
    if not first:
        return "未找到 .csv 或 .xlsx 文件"
    try:
        if first.endswith('.csv'):
            # utf-8-sig strips a leading BOM if present (and reads plain
            # UTF-8 unchanged); with plain utf-8 a BOM would be glued onto the
            # first header name and the 'NPU Name' lookup would never match.
            with open(first, encoding='utf-8-sig', newline='') as f:
                reader = csv.DictReader(f)
                fieldnames = reader.fieldnames or []
                if _check_l0_names(row.get('NPU Name', '') for row in reader):
                    return l0_message
        else:
            import openpyxl
            wb = openpyxl.load_workbook(first, read_only=True)
            try:
                ws = wb.active
                # Default to an empty header so an empty sheet is reported as
                # "missing columns" rather than as an empty read error.
                header = next(ws.iter_rows(min_row=1, max_row=1, values_only=True), ())
                fieldnames = [str(value) if value is not None else '' for value in header]
                # Locate the name column by header, as the csv branch does;
                # fall back to the first column when the header lacks it.
                name_col = fieldnames.index('NPU Name') if 'NPU Name' in fieldnames else 0
                names = (
                    str(row[name_col]) if len(row) > name_col and row[name_col] is not None else ''
                    for row in ws.iter_rows(min_row=2, values_only=True)
                )
                if _check_l0_names(names):
                    return l0_message
            finally:
                # Close the workbook on every path, including exceptions.
                wb.close()
    except Exception as e:
        return f"读取文件失败: {first}\n {e}"
    missing = [c for c in ('NPU MD5', 'BENCH MD5') if c not in fieldnames]
    if missing:
        return f"缺少比对字段: {', '.join(missing)},没有包含tensor的CRC-32校验值,无法分析确定性问题。"
    return None
def validate_db(root_path):
    """Validate the tb_config table of a .vis.db file and detect L0 data.

    The database is opened read-only. The check passes when some row of
    ``tb_config`` has ``task`` equal to ``'md5'`` and some row of
    ``tb_nodes`` has ``node_type`` equal to ``'1'``.

    Args:
        root_path: A ``.vis.db`` file, or a directory containing one.

    Returns:
        A user-facing error message when no file is found, the database
        cannot be read, or either check fails; None when both checks pass.
    """
    first = _first_file(root_path, ('.vis.db',))
    if not first:
        return "未找到 .vis.db 文件"
    from urllib.parse import quote
    conn = None
    try:
        # Percent-encode the path so characters such as '?', '#' or '%' are
        # not parsed as URI syntax; SQLite decodes the escapes again.
        conn = sqlite3.connect(f"file:{quote(first)}?mode=ro", uri=True)
        task_values = {row[0] for row in conn.execute("SELECT task FROM tb_config")}
        has_api = any(row[0] == '1' for row in conn.execute("SELECT DISTINCT node_type FROM tb_nodes"))
    except Exception as e:
        return f"读取db文件失败: {first}\n {e}"
    finally:
        # Release the connection even when a query fails (e.g. missing table).
        if conn is not None:
            conn.close()
    if 'md5' not in task_values:
        return "tb_config 表的 task 字段不是 md5,没有包含tensor的CRC-32校验值,无法分析确定性问题。"
    if not has_api:
        return "当前数据仅包含Module级信息,没有API级数据,无法分析确定性问题。"
    return None
def check_dump_file(filepath, label, max_lines=100):
    """Inspect the head of a dump.json file and return its level.

    Only the first *max_lines* lines are read. They must contain the text
    ``"md5":`` somewhere, and one line must start with ``"level"`` (after
    stripping whitespace); the value on that line is the level.

    Args:
        filepath: Path to the dump.json file; a falsy value means the file
            was not found.
        label: Name used as a prefix in error messages.
        max_lines: Number of leading lines to scan (default 100).

    Returns:
        The level string, either ``'L1'`` or ``'mix'``.

    Raises:
        RuntimeError: If *filepath* is falsy, the file cannot be read, the
            ``"md5":`` marker or the level line is absent from the scanned
            lines, or the level is neither ``'L1'`` nor ``'mix'``.
    """
    if not filepath:
        raise RuntimeError(f"({label}) 未找到 dump.json 文件")
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            # readline() yields '' past EOF, so short files are handled too.
            lines = [f.readline() for _ in range(max_lines)]
    except Exception as e:
        # Chain the original error so its traceback is not lost.
        raise RuntimeError(f"({label}) 读取文件失败: {filepath}\n {e}") from e
    content = ''.join(lines)
    if '"md5":' not in content:
        raise RuntimeError(f"({label}) 当前dump数据没有包含tensor的CRC-32校验值,无法分析确定性问题。\n 文件: {filepath}")
    level = None
    for line in lines:
        stripped = line.strip().rstrip(',')
        if stripped.startswith('"level"'):
            # Take the text after the first ':' and drop surrounding quotes.
            level = stripped.split(':', 1)[-1].strip().strip('"')
            break
    if level is None:
        raise RuntimeError(f"({label}) dump.json 中未找到 level 字段。\n 文件: {filepath}")
    if level not in ('L1', 'mix'):
        raise RuntimeError(f'({label}) dump数据的level="{level}",需要为"L1"或"mix"。\n 文件: {filepath}')
    return level
def main():
    """Command-line entry point.

    With only ``target`` given, the path is treated as a db/csv/xlsx result:
    it is validated and ``level="mix"`` (db) or ``level="L1"`` (csv/xlsx) is
    printed. With both ``target`` and ``golden`` given, the first dump.json
    under each path is checked and the shared level is printed. Any failure
    ends the process with exit status 1.
    """
    cli = argparse.ArgumentParser(description='检查msprobe数据,确定分析级别(L1或mix)。')
    cli.add_argument('target', help='dump target路径,或 db/csv/xlsx 路径')
    cli.add_argument('golden', nargs='?', help='dump golden路径(db/csv/xlsx 路径不需要)')
    opts = cli.parse_args()

    if opts.golden is None:
        # Single-path mode: a db or csv/xlsx comparison result.
        if not os.path.exists(opts.target):
            cli.exit(1, f"错误: 路径不存在: {opts.target}\n")
        kind = detect_path_type(opts.target)
        if kind not in ('db', 'csv_xlsx'):
            cli.exit(1, f"错误: 未找到 .vis.db 或 .csv/.xlsx 文件: {opts.target}\n")
        validator = validate_db if kind == 'db' else validate_csv_xlsx
        problem = validator(opts.target)
        if problem:
            cli.exit(1, f"错误: {kind}文件校验不通过。\n {problem}\n")
        print('level="mix"' if kind == 'db' else 'level="L1"')
        return

    # Two-path mode: compare the dump.json found under target and golden.
    for candidate in (opts.target, opts.golden):
        if not os.path.exists(candidate):
            cli.exit(1, f"错误: 路径不存在: {candidate}\n")
    dump_files = {
        'target': find_first_dump_file(opts.target),
        'golden': find_first_dump_file(opts.golden),
    }

    found_levels = {}
    failed = False
    for side, dump_path in dump_files.items():
        try:
            found_levels[side] = check_dump_file(dump_path, side)
        except RuntimeError as exc:
            # Report every failing side before exiting, not just the first.
            print(exc)
            failed = True

    if not failed and found_levels['target'] != found_levels['golden']:
        print(f"target和golden的level不一致: target=\"{found_levels['target']}\", golden=\"{found_levels['golden']}\"")
        failed = True

    if failed:
        cli.exit(1)
    print(f"level=\"{found_levels['target']}\"")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()