import os
import csv
from datetime import datetime
from typing import Dict, Any, Optional, List
import openpyxl
from openpyxl.styles import Font, Alignment, PatternFill, Border, Side
class Record:
    """Recorder that writes test-case results into an Excel workbook.

    Row 1 holds the styled header row; each call to :meth:`record` appends one
    result row. Call :meth:`save` (or use the instance as a context manager)
    to write the workbook to ``output_path``.
    """

    # A gradient without an explicit "passed" flag passes when its
    # "actual-fp32_out_ref" value is at most this multiple of its
    # "fp32_out_ref" value (see _check_pass).
    FP32_REF_MULTIPLIER = 5

    def __init__(self, output_path: str = "test_results.xlsx"):
        """Initialize the recorder and write the header row.

        Args:
            output_path: Path of the Excel file to write.
        """
        self.output_path = output_path
        self.workbook = openpyxl.Workbook()
        self.worksheet = self.workbook.active
        self.worksheet.title = "Test Results"
        self.header_font = Font(bold=True, color="FFFFFF")
        self.header_fill = PatternFill(start_color="4472C4", end_color="4472C4", fill_type="solid")
        self.header_alignment = Alignment(horizontal="center", vertical="center")
        # Shared alignment for pass/fail cells, instead of building one per cell.
        self.center_alignment = Alignment(horizontal="center", vertical="center")
        self.pass_fill = PatternFill(start_color="C6EFCE", end_color="C6EFCE", fill_type="solid")
        self.fail_fill = PatternFill(start_color="FFC7CE", end_color="FFC7CE", fill_type="solid")
        self.border = Border(
            left=Side(style="thin"),
            right=Side(style="thin"),
            top=Side(style="thin"),
            bottom=Side(style="thin")
        )
        self.base_headers = [
            "Test Case Name",
            "Batch Size",
            "Head Num",
            "Head Dim QK",
            "Head Dim V",
            "Max SeqLen Q",
            "Max SeqLen K",
            "Has RAB",
            "Data Type",
            "Seed",
            "Overall Pass"
        ]
        self.grad_headers = ["DQ", "DK", "DV", "DRAB"]
        self.detail_headers = ["Actual-FP32 Ref", "FP32 Ref", "Actual-Out Ref", "Passed"]
        self.seq_stat_headers = ["SeqLen Q Mean", "SeqLen Q Max", "SeqLen Q Min",
                                 "SeqLen K Mean", "SeqLen K Max", "SeqLen K Min"]
        self._init_headers()
        # Next worksheet row to write (row 1 is the header).
        self.current_row = 2

    def record(
        self,
        case_name: str,
        params: Dict[str, Any],
        detail: Optional[Dict[str, Any]] = None,
        seq_stats: Optional[Dict[str, Any]] = None
    ):
        """Append one test-case result as a new worksheet row.

        Args:
            case_name: Test-case name (first column).
            params: Test-case parameters. Keys read: batch_size, head_num,
                head_dim_qk, head_dim_v, max_seqlen_q, max_seqlen_k, has_rab,
                data_type, seed. Missing keys produce empty cells.
            detail: Detailed precision data, i.e. the detail dict returned by
                Validator.backward_verify, keyed by "DQ", "DK", "DV", "DRAB".
                Each per-gradient dict may carry "actual-fp32_out_ref",
                "fp32_out_ref", "actual-out_ref", "passed" and "try_allclose".
                Gradients that are missing are written as "N/A".
            seq_stats: Sequence-length statistics with keys seq_lens_q_mean,
                seq_lens_q_max, seq_lens_q_min, seq_lens_k_mean,
                seq_lens_k_max, seq_lens_k_min. When falsy, "N/A" is written.
        """
        na = "N/A"
        col = self._write_values(1, [
            case_name,
            params.get("batch_size"),
            params.get("head_num"),
            params.get("head_dim_qk"),
            params.get("head_dim_v"),
            params.get("max_seqlen_q"),
            params.get("max_seqlen_k"),
            params.get("has_rab"),
            str(params.get("data_type")),
            params.get("seed"),
        ])

        # One entry per gradient column group; None means "no data" (N/A).
        grad_details = [
            detail.get(grad) if detail else None for grad in self.grad_headers
        ]

        # Overall verdict: every gradient that has data must pass. A case
        # recorded without any gradient data counts as passed.
        passed = all(self._check_pass(d) for d in grad_details if d is not None)
        self._write_status(col, passed, passed)
        col += 1

        metric_keys = ("actual-fp32_out_ref", "fp32_out_ref", "actual-out_ref")
        for grad_detail in grad_details:
            if grad_detail is None:
                col = self._write_values(col, [na] * len(self.detail_headers))
                continue
            if "actual-fp32_out_ref" in grad_detail:
                metrics = [grad_detail.get(key) for key in metric_keys]
            else:
                metrics = [na] * len(metric_keys)
            col = self._write_values(col, metrics)
            # Use the same rule as the Overall Pass column so the per-gradient
            # mark can never contradict the overall verdict.
            grad_passed = self._check_pass(grad_detail)
            self._write_status(col, "✔" if grad_passed else "✗", grad_passed)
            col += 1

        seq_keys = ("seq_lens_q_mean", "seq_lens_q_max", "seq_lens_q_min",
                    "seq_lens_k_mean", "seq_lens_k_max", "seq_lens_k_min")
        if seq_stats:
            self._write_values(col, [seq_stats.get(key) for key in seq_keys])
        else:
            self._write_values(col, [na] * len(self.seq_stat_headers))

        self.current_row += 1

    def _cell(self, col: int, value: Any):
        """Write ``value`` at (current_row, col), apply the thin border, and return the cell."""
        cell = self.worksheet.cell(row=self.current_row, column=col, value=value)
        cell.border = self.border
        return cell

    def _write_values(self, col: int, values: List[Any]) -> int:
        """Write ``values`` into consecutive cells starting at ``col``; return the next free column."""
        for offset, value in enumerate(values):
            self._cell(col + offset, value)
        return col + len(values)

    def _write_status(self, col: int, value: Any, ok: bool):
        """Write a centered pass/fail cell: pass fill when ``ok`` is truthy, fail fill otherwise."""
        cell = self._cell(col, value)
        cell.fill = self.pass_fill if ok else self.fail_fill
        cell.alignment = self.center_alignment
        return cell

    def _check_pass(self, grad_detail: Dict[str, Any]) -> bool:
        """Return whether a single gradient passed verification.

        An explicit "passed" entry wins. Otherwise the gradient passes when
        ``actual-fp32_out_ref <= FP32_REF_MULTIPLIER * fp32_out_ref`` (missing
        values default to 0), or when "try_allclose" is truthy.
        """
        if "passed" in grad_detail:
            return grad_detail["passed"]
        actual_err = grad_detail.get("actual-fp32_out_ref", 0)
        ref_err = grad_detail.get("fp32_out_ref", 0)
        try_allclose = grad_detail.get("try_allclose", False)
        return (actual_err <= self.FP32_REF_MULTIPLIER * ref_err) or try_allclose

    def save(self):
        """Save the workbook to ``output_path``, creating parent directories if needed."""
        output_dir = os.path.dirname(self.output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        self.workbook.save(self.output_path)
        print(f"Test results saved to: {self.output_path}")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Save unconditionally so results recorded before an exception are
        # kept; returning None lets any exception propagate.
        self.save()

    def _column_widths(self) -> Dict[int, int]:
        """Return the width to use for each 1-based column index.

        The name column is wide. The other base columns are 12 wide. All
        gradient-detail and sequence-statistics columns that follow are 14 wide.
        """
        n_base = len(self.base_headers)
        n_total = (n_base
                   + len(self.grad_headers) * len(self.detail_headers)
                   + len(self.seq_stat_headers))
        widths = {1: 30}
        widths.update({col: 12 for col in range(2, n_base + 1)})
        widths.update({col: 14 for col in range(n_base + 1, n_total + 1)})
        return widths

    def _init_headers(self):
        """Write the styled header row (row 1) and set the column widths."""
        from openpyxl.utils import get_column_letter

        headers = list(self.base_headers)
        headers.extend(
            f"{grad}_{detail}"
            for grad in self.grad_headers
            for detail in self.detail_headers
        )
        headers.extend(self.seq_stat_headers)
        for col, header in enumerate(headers, start=1):
            cell = self.worksheet.cell(row=1, column=col, value=header)
            cell.font = self.header_font
            cell.fill = self.header_fill
            cell.alignment = self.header_alignment
            cell.border = self.border
        for col, width in self._column_widths().items():
            self.worksheet.column_dimensions[get_column_letter(col)].width = width
class BenchmarkRecord:
    """Benchmark recorder that appends each test case's input parameters to a CSV file.

    The file is (re)created with a header row on construction. Every call to
    :meth:`record` opens the file in append mode, writes one row and closes
    it again, so recorded rows are on disk without any final flush.
    """

    def __init__(self, output_path: str = "tmp_benchmark.csv"):
        """Initialize the recorder and write the CSV header row.

        Any existing file at ``output_path`` is replaced, and missing parent
        directories are created.

        Args:
            output_path: Path of the CSV file to write.
        """
        self.output_path = output_path
        # Keys read from ``params``, in CSV column order.
        self.param_keys = [
            'batch_size', 'head_num', 'max_seqlen_q', 'max_seqlen_k',
            'head_dim_qk', 'head_dim_v', 'has_rab', 'data_type', 'seed',
        ]
        # Keys read from ``seq_stats``, in CSV column order.
        self.seq_stat_keys = [
            'seq_lens_q_mean', 'seq_lens_q_max', 'seq_lens_q_min',
            'seq_lens_k_mean', 'seq_lens_k_max', 'seq_lens_k_min',
        ]
        self.headers = self.param_keys + self.seq_stat_keys
        self.current_row = 0  # number of data rows written so far
        self._init_file()

    def _init_file(self):
        """Create (or replace) the CSV file and write the header row."""
        output_dir = os.path.dirname(self.output_path)
        if output_dir:
            # Same behavior as Record.save: create missing parent directories.
            os.makedirs(output_dir, exist_ok=True)
        try:
            os.remove(self.output_path)
        except FileNotFoundError:
            pass  # nothing to replace
        with open(self.output_path, 'w', newline='') as f:
            csv.writer(f).writerow(self.headers)

    def record(self, params: Dict[str, Any], seq_stats: Optional[Dict[str, Any]] = None):
        """Append one test case's input parameters as a CSV row.

        Args:
            params: Test-case parameters. Keys read: batch_size, head_num,
                max_seqlen_q, max_seqlen_k, head_dim_qk, head_dim_v, has_rab,
                data_type, seed. Missing keys become empty fields. data_type is
                stringified with every "torch." substring removed.
            seq_stats: Sequence-length statistics with keys seq_lens_q_mean,
                seq_lens_q_max, seq_lens_q_min, seq_lens_k_mean,
                seq_lens_k_max, seq_lens_k_min. When falsy (None or empty),
                the statistics fields are left empty.
        """
        row = []
        for key in self.param_keys:
            value = params.get(key)
            if key == 'data_type':
                value = str(value).replace('torch.', '')
            row.append(value)
        stats = seq_stats or {}
        row.extend(stats.get(key) for key in self.seq_stat_keys)
        with open(self.output_path, 'a', newline='') as f:
            csv.writer(f).writerow(row)
        self.current_row += 1

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Nothing to flush: each record() call already wrote and closed the file.
        pass