"""分析msprobe mix级别db文件,寻找首个输入一致输出不一致的API。"""
import argparse
import sqlite3
import json
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from msprobe_utils import build_excluded_rules, is_key_excluded, print_kv_table, is_api_excluded
def format_hierarchy(hierarchy):
    """Render a hierarchy path as display lines.

    Only nodes whose ``node_type`` is '0' (shown as Module) or '1' (shown as
    API) are kept; every other node is skipped. The n-th kept node (0-based)
    is indented by n spaces and, from the second kept node on, prefixed with
    '→ '. A node without ``node_name`` is shown as '?'.

    Returns:
        list[str]: one line per kept node, in input order.
    """
    kind_label = {'0': 'Module', '1': 'API'}
    kept = [node for node in hierarchy if node.get('node_type', '') in ('0', '1')]
    return [
        "{}{}{} [{}]".format(
            ' ' * depth,
            '→ ' if depth > 0 else '',
            node.get('node_name', '?'),
            kind_label[node.get('node_type', '')],
        )
        for depth, node in enumerate(kept)
    ]
def extract_md5_values(data_json, node_name=None, excluded_rules=None):
    """Parse a node's input/output JSON and collect the md5 of each entry.

    Args:
        data_json: JSON text (str or bytes) holding an object whose values may
            be dicts with an ``md5`` field. Empty/None input, malformed JSON,
            a non-str/bytes argument, or JSON whose top level is not an object
            all yield an empty map.
        node_name: Name of the node; only used together with ``excluded_rules``.
        excluded_rules: Exclusion rules passed to ``is_key_excluded``. When both
            this and ``node_name`` are truthy, a key containing ``.input.`` or
            ``.output.`` is dropped if ``is_key_excluded`` reports its last
            dot-separated component (the index) as excluded.

    Returns:
        dict mapping entry key -> md5 value, for entries whose value is a dict
        containing ``md5``.
    """
    md5_map = {}
    if not data_json:
        return md5_map
    try:
        data = json.loads(data_json)
    except (json.JSONDecodeError, TypeError):
        # Unparseable content is treated as "no md5 data", not as an error.
        return md5_map
    if not isinstance(data, dict):
        # e.g. "null" or a JSON list: there are no keyed entries to read.
        return md5_map
    check_exclusion = bool(node_name and excluded_rules)
    for key, value in data.items():
        if not (isinstance(value, dict) and 'md5' in value):
            continue
        # Stop at the first io_type whose rule excludes this key.
        if check_exclusion and any(
            f'.{io_type}.' in key
            and is_key_excluded(node_name, io_type, key.rsplit('.', 1)[1], excluded_rules)
            for io_type in ('input', 'output')
        ):
            continue
        md5_map[key] = value['md5']
    return md5_map
def compare_md5(npu_md5_map, bench_md5_map):
    """Compare NPU and Bench md5 maps key by key.

    A key present on only one side is reported with the placeholder
    ``<MISSING>`` for the absent side.

    Returns:
        tuple ``(is_match, diffs, all_keys)``:
            is_match: True when every key has equal values on both sides.
            diffs: one "key: NPU=... vs Bench=..." line per mismatching key.
            all_keys: sorted union of the keys of both maps.
    """
    all_keys = sorted(npu_md5_map.keys() | bench_md5_map.keys())
    diffs = []
    for key in all_keys:
        npu_m = npu_md5_map.get(key, "<MISSING>")
        bench_m = bench_md5_map.get(key, "<MISSING>")
        if npu_m != bench_m:
            diffs.append(f"{key}: NPU={npu_m} vs Bench={bench_m}")
    return not diffs, diffs, all_keys
def build_hierarchy(cursor, node_id, data_source, step, rank, max_depth=50):
    """Build the path from the top-most ancestor down to ``node_id`` via ``up_node``.

    ``up_node`` holds the parent's ``node_name``; the parent row is resolved by
    that name within the same ``data_source``/``step``/``rank``. The walk stops
    when a node has no ``up_node`` (NULL, empty or the string 'None'), when the
    parent cannot be found, or when a node id repeats (cyclic ``up_node`` chain).

    Args:
        cursor: sqlite3 cursor on a db containing ``tb_nodes``.
        node_id: ``id`` of the starting node.
        data_source, step, rank: scope used to resolve parent names.
        max_depth: maximum number of nodes to collect. If it is reached, a
            placeholder ``{'node_name': '... (超过N层)', 'node_type': '?'}`` is
            placed at the head of the returned path.

    Returns:
        list of dicts ordered ancestor -> node. Each real node has ``node_id``,
        ``node_name``, ``node_type`` and ``sub_nodes`` (children parsed from the
        JSON ``sub_nodes`` column; [] when absent or unparseable). Empty if
        ``node_id`` does not exist.
    """
    path = []
    visited = set()
    current_id = node_id
    for _ in range(max_depth):
        if current_id in visited:
            # Cyclic up_node chain: stop instead of repeating nodes until max_depth.
            break
        visited.add(current_id)
        cursor.execute(
            "SELECT id, node_name, node_type, up_node, sub_nodes FROM tb_nodes WHERE id=?",
            (current_id,)
        )
        row = cursor.fetchone()
        if not row:
            break
        nid, nname, ntype, up_node, sub_nodes = row
        children = []
        if sub_nodes:
            try:
                children = json.loads(sub_nodes)
            except (json.JSONDecodeError, TypeError):
                pass  # best effort: keep [] for malformed sub_nodes
        path.append({'node_id': nid, 'node_name': nname, 'node_type': ntype, 'sub_nodes': children})
        if not up_node or up_node == 'None':
            break
        cursor.execute(
            "SELECT id FROM tb_nodes WHERE data_source=? AND step=? AND rank=? AND node_name=?",
            (data_source, step, rank, up_node)
        )
        parent_row = cursor.fetchone()
        if not parent_row:
            break
        current_id = parent_row[0]
    else:
        # Loop exhausted without a natural stop: mark the path as truncated.
        path.append({'node_name': f'... (超过{max_depth}层)', 'node_type': '?'})
    path.reverse()
    return path
def _fetch_bench_io(cursor, npu_id):
    """Return the (input_data, output_data) row of the Bench node paired with ``npu_id``, or None.

    The Bench id is derived by replacing the first "NPU_" in the NPU id with "Bench_".
    """
    bench_id = npu_id.replace("NPU_", "Bench_", 1)
    cursor.execute("SELECT input_data, output_data FROM tb_nodes WHERE id=?", (bench_id,))
    return cursor.fetchone()


def _find_first_api_mismatch(cursor, step, rank, excluded_apis, excluded_rules):
    """Return details of the first flagged NPU node whose inputs match Bench but outputs differ.

    NPU nodes with precision_index=1 are scanned in node_order. A node is skipped
    when ``is_api_excluded`` matches it, when it has no Bench counterpart, when
    any input md5 differs, or when all output md5 agree. Returns a dict with
    node_name, node_type (label), node_order, input_keys, output_diffs and
    hierarchy, or None if no node qualifies.
    """
    # NOTE(review): this query has no node_type filter, so Module nodes ('0') are
    # scanned here too (hence the 'Module' label below). If that is intended, the
    # Module-only scan in analyze_rank can only surface modules rejected by
    # is_api_excluded here but not by its exact-name check — confirm intent.
    cursor.execute(
        "SELECT id, node_name, node_order, node_type, input_data, output_data"
        " FROM tb_nodes WHERE data_source='NPU' AND precision_index=1"
        " AND step=? AND rank=? ORDER BY node_order",
        (step, rank)
    )
    # fetchall() materializes the rows before the cursor is reused inside the loop.
    for npu_id, node_name, node_order, node_type, npu_input, npu_output in cursor.fetchall():
        if is_api_excluded(node_name, excluded_apis):
            continue
        bench_row = _fetch_bench_io(cursor, npu_id)
        if not bench_row:
            continue
        bench_input, bench_output = bench_row
        input_match, _, all_input_keys = compare_md5(
            extract_md5_values(npu_input, node_name, excluded_rules),
            extract_md5_values(bench_input, node_name, excluded_rules)
        )
        if not input_match:
            continue
        # Outputs are only parsed once the inputs are known to agree.
        output_match, output_diffs, _ = compare_md5(
            extract_md5_values(npu_output, node_name, excluded_rules),
            extract_md5_values(bench_output, node_name, excluded_rules)
        )
        if output_match:
            continue
        hierarchy = build_hierarchy(cursor, npu_id, 'NPU', step, rank)
        type_label = 'API' if node_type == '1' else ('Module' if node_type == '0' else f'Type_{node_type}')
        return dict(node_name=node_name, node_type=type_label, node_order=node_order,
                    input_keys=all_input_keys, output_diffs=output_diffs, hierarchy=hierarchy)
    return None


def _find_first_module_mismatch(cursor, step, rank, excluded_apis, excluded_rules):
    """Return details of the first flagged NPU Module whose inputs match Bench but outputs differ.

    Only node_type '0' rows with precision_index=1 are scanned, in node_order.
    A Module is skipped when its name is exactly in ``excluded_apis``, when it
    has no Bench counterpart, when its NPU side has no md5 data at all, when any
    input md5 differs, or when all output md5 agree. Returns a dict with
    module_name, node_order, output_diffs and hierarchy, or None.
    """
    cursor.execute(
        "SELECT id, node_name, node_order, node_type, input_data, output_data"
        " FROM tb_nodes WHERE data_source='NPU' AND node_type='0'"
        " AND precision_index=1 AND step=? AND rank=? ORDER BY node_order",
        (step, rank)
    )
    for npu_id, node_name, node_order, _node_type, npu_input, npu_output in cursor.fetchall():
        # Modules use an exact-name match here, unlike the prefix-capable is_api_excluded.
        if excluded_apis and node_name in excluded_apis:
            continue
        bench_row = _fetch_bench_io(cursor, npu_id)
        if not bench_row:
            continue
        bench_input, bench_output = bench_row
        npu_in_md5 = extract_md5_values(npu_input, node_name, excluded_rules)
        npu_out_md5 = extract_md5_values(npu_output, node_name, excluded_rules)
        if not npu_in_md5 and not npu_out_md5:
            continue
        input_match, _, _ = compare_md5(
            npu_in_md5, extract_md5_values(bench_input, node_name, excluded_rules)
        )
        if not input_match:
            continue
        output_match, output_diffs, _ = compare_md5(
            npu_out_md5, extract_md5_values(bench_output, node_name, excluded_rules)
        )
        if output_match:
            continue
        return dict(module_name=node_name, node_order=node_order,
                    output_diffs=output_diffs, hierarchy=build_hierarchy(cursor, npu_id, 'NPU', step, rank))
    return None


def analyze_rank(db_path, step, rank, excluded_apis=None):
    """Analyze one (step, rank) of the db and return its result dict.

    First looks for the first flagged NPU node whose inputs match Bench but
    outputs differ. If found, returns ``dict(step, rank, found=True, ...)`` with
    that node's details. Otherwise returns ``found=False`` with a ``reason`` and,
    when a matching Module is found, a ``module_analysis`` entry.

    The db is opened read-only and the connection is always closed, even if a
    query or helper raises.
    """
    conn = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)
    try:
        cursor = conn.cursor()
        excluded_rules = build_excluded_rules()
        api_result = _find_first_api_mismatch(cursor, step, rank, excluded_apis, excluded_rules)
        if api_result is not None:
            return dict(step=step, rank=rank, found=True, **api_result)
        module_result = _find_first_module_mismatch(cursor, step, rank, excluded_apis, excluded_rules)
    finally:
        conn.close()
    result = dict(step=step, rank=rank, found=False, reason='未找到输入一致输出不一致的API')
    if module_result:
        result['module_analysis'] = module_result
    return result
def _read_step_ranks(db_path):
    """Return (distinct (step, rank) rows, distinct steps) from tb_nodes, both ascending.

    The db is opened read-only; the connection is closed even if a query raises.
    Raises sqlite3.Error if the file cannot be read as a db containing tb_nodes.
    """
    conn = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT DISTINCT step, rank FROM tb_nodes ORDER BY step, rank")
        step_ranks = cursor.fetchall()
        cursor.execute("SELECT DISTINCT step FROM tb_nodes ORDER BY step")
        steps = [row[0] for row in cursor.fetchall()]
        return step_ranks, steps
    finally:
        conn.close()


def _analyze_step(db_path, step, ranks_in_step, excluded_apis):
    """Run analyze_rank for every rank of ``step`` in a thread pool; return results sorted by rank.

    An exception from one rank becomes a found=False result whose reason holds
    the exception text, so a single failing rank does not abort the step.
    """
    step_results = []
    with ThreadPoolExecutor(max_workers=min(len(ranks_in_step), 32)) as executor:
        future_to_rank = {executor.submit(analyze_rank, db_path, step, rank, excluded_apis): rank
                          for rank in ranks_in_step}
        for future in as_completed(future_to_rank):
            rank = future_to_rank[future]
            try:
                step_results.append(future.result())
            except Exception as e:  # per-rank boundary: report, keep analyzing other ranks
                step_results.append({'step': step, 'rank': rank, 'found': False, 'reason': f"分析异常: {e}"})
    step_results.sort(key=lambda r: r['rank'])
    return step_results


def _build_kv_rows(r):
    """Build the (label, value) rows shown by print_kv_table for one rank result."""
    if r['found']:
        hier = '\n'.join(format_hierarchy(r.get('hierarchy', []))) or '-'
        items = ['Input MD5 (全部一致):'] + [f' {k}' for k in r.get('input_keys', [])]
        items += ['Output MD5 (不一致):'] + [f' {d}' for d in r.get('output_diffs', [])]
        return [('首个问题API', r['node_name']), ('API所在Module层级', hier),
                ('API分析依据', '\n'.join(items)),
                ('首个问题Module', '已找到首个问题API,不分析'), ('Module分析依据', '-')]
    kv_rows = [('首个问题API', '无'), ('API所在Module层级', '-'),
               ('API分析依据', r.get('reason', '未找到首个输入一致输出不一致的API'))]
    ma = r.get('module_analysis')
    if ma:
        items = ['Module Output MD5 (不一致):'] + [f' {d}' for d in ma.get('output_diffs', [])]
        kv_rows.append(('首个问题Module', ma['module_name']))
        kv_rows.append(('Module分析依据', '\n'.join(items)))
    else:
        kv_rows.extend([('首个问题Module', '无'),
                        ('Module分析依据', '未找到首个输入一致输出不一致的Module')])
    return kv_rows


def _print_module_hint(r):
    """When no API was found but a Module was, print the missed-API hint and its hierarchy."""
    if r['found'] or not r.get('module_analysis'):
        return
    ma = r['module_analysis']
    print(" 可能原因: 该Module内可能有被msprobe漏采的API,导致Module整体输出不一致。")
    if ma.get('hierarchy'):
        print(" 请检查以下层级的子节点列表,确认是否有API未被采集:")
        for line in format_hierarchy(ma['hierarchy']):
            print(f" {line}")


def analyze_db(db_path, excluded_apis=None):
    """Analyze a mix-level db file step by step and print per-rank results.

    Steps are processed in ascending order; the ranks of a step are analyzed in
    parallel. Analysis stops after the first step in which any rank found an
    API. Prints an error and returns if the file does not exist or cannot be
    read as a db with a tb_nodes table. Always returns None.
    """
    if not os.path.exists(db_path):
        print(f"错误: 文件不存在: {db_path}")
        return
    try:
        step_ranks, steps = _read_step_ranks(db_path)
    except sqlite3.Error as e:
        # Not a sqlite db, missing tb_nodes, unreadable file, ...
        print(f"错误: 无法读取db文件: {db_path} ({e})")
        return
    if not step_ranks:
        print("db文件中没有数据")
        return
    print(f"共发现 {len(step_ranks)} 个 (step, rank) 组合: {step_ranks}")
    print(f"共 {len(steps)} 个step: {steps}")
    print("=" * 80)
    for idx, step in enumerate(steps):
        ranks_in_step = [r for s, r in step_ranks if s == step]
        print(f"\n分析 step={step}, ranks={ranks_in_step} (并行分析):")
        print("-" * 60)
        step_results = _analyze_step(db_path, step, ranks_in_step, excluded_apis)
        for r in step_results:
            print_kv_table(r['rank'], _build_kv_rows(r))
            _print_module_hint(r)
        if any(r['found'] for r in step_results):
            if not excluded_apis:
                print("\n提示: 如果不认为上述API是问题根因,可输入API名称排除后重新分析。支持前缀匹配,多个API以空格分隔。")
            # steps is DISTINCT and ordered, so "not last" means later steps exist;
            # avoids max(), which fails on NULL / mixed-type step values.
            if idx < len(steps) - 1:
                print(f"\nstep={step} 已找到结果,跳过后续step。")
            break
def main():
    """Command-line entry point: parse arguments, echo exclusions, run analyze_db."""
    cli = argparse.ArgumentParser(description='分析msprobe mix级别db文件,寻找首个输入一致输出不一致的API。')
    cli.add_argument('db_path', help='mix级别比对结果db文件(.vis.db)')
    cli.add_argument(
        '--exclude-api',
        nargs='+',
        default=[],
        metavar='NAME',
        help='要排除的API名称(支持前缀匹配,多个以空格分隔)',
    )
    opts = cli.parse_args()
    excluded_prefixes = {name for name in opts.exclude_api}
    if excluded_prefixes:
        print(f"排除以下API前缀: {sorted(excluded_prefixes)}")
    analyze_db(opts.db_path, excluded_prefixes)


if __name__ == "__main__":
    # Only run the CLI when executed as a script, not on import.
    main()