import os
import sys
import argparse
import importlib
import itertools
import torch
import numpy as np
from backends.cpu_impl import tforward
from utils.data import generate_qkv, generate_npu_mask, trans_bnsd_to_layout, gen_block_table
from utils.compare import data_compare_benchmark_new, Result, check_result, analyze_fail_distribution
from utils.io import save_tensor_to_txt, load_tensor_from_txt
# Case-file module names used as the default value of the --case_files option.
_DEFAULT_CASE_FILES = [
    "base",
    "fia_stc",
    "functional_stc",
    "functional_redline_train_tnd",
    "functional_redline_train",
    "functional_redline_infer",
    "performance_redline_infer",
]

# Optional accelerator backends. Each entry point stays None and its
# *_AVAILABLE flag stays False unless the backend module imports cleanly.
GPU_AVAILABLE = False
NPU_AVAILABLE = False
flash_attn_gpu = None
flash_attn_npu = None
flash_attn_metadata_only = None
flash_attn_npu_graph = None

try:
    from backends.gpu_impl import flash_attn_gpu as _gpu_func
except ImportError:
    # GPU backend is optional; keep the "unavailable" defaults.
    pass
else:
    flash_attn_gpu = _gpu_func
    GPU_AVAILABLE = True

try:
    from backends.npu_impl import (
        flash_attn_npu as _npu_func,
        flash_attn_metadata_only as _meta_func,
        flash_attn_npu_graph as _graph_func,
    )
except ImportError:
    # NPU backend is optional; keep the "unavailable" defaults.
    pass
else:
    flash_attn_npu = _npu_func
    flash_attn_metadata_only = _meta_func
    flash_attn_npu_graph = _graph_func
    NPU_AVAILABLE = True
def load_case_modules(module_names):
    """Import case modules and expand their parameter grids into concrete cases.

    Each module must expose a ``TestCases`` mapping of
    ``case_name -> {field: iterable_of_values}``. Every case is expanded into
    the cartesian product of its field values.

    Args:
        module_names: Importable module names to load.

    Returns:
        dict mapping ``"<module>/<case>"`` (when the grid yields exactly one
        combination) or ``"<module>/<case>_<i>"`` (one entry per combination
        otherwise) to a ``{field: value}`` dict.
    """
    expanded = {}
    for module_name in module_names:
        module = importlib.import_module(module_name)
        for case_name, grid in module.TestCases.items():
            fields = list(grid)
            variants = [
                dict(zip(fields, picks))
                for picks in itertools.product(*(grid[field] for field in fields))
            ]
            base_id = f"{module_name}/{case_name}"
            if len(variants) == 1:
                # A single combination keeps the bare case name.
                expanded[base_id] = variants[0]
                continue
            for index, variant in enumerate(variants):
                expanded[f"{base_id}_{index}"] = variant
    return expanded
def normalize_case(raw):
    """Convert a raw test-case dict into the kwargs expected by call_flash_attn.

    The input is copied, never mutated. Missing fields receive defaults
    (``layout_kv``/``layout_out`` follow ``layout_q``; ``N2``/``S2``/``DV``
    follow ``N1``/``S1``/``D``), and sequence-length fields whose value is
    ``None`` or ``[None]`` are dropped.

    For ``layout_q == "TND"`` the per-batch lengths are derived from the
    cumulative ``cu_seqlens_*`` lists and ``B`` is forced to 1. For paged KV
    layouts (``PA_BBND``/``PA_BNBD``/``PA_NZ``) a ``block_table`` int32 tensor
    is taken from the case or generated with consecutive block ids, and
    ``actual_b`` is set to its first dimension.

    Args:
        raw: Mapping of case fields as produced by the case modules.

    Returns:
        A new dict with the normalized fields.

    Raises:
        KeyError: If ``layout_q`` is ``"TND"`` and ``cu_seqlens_q`` is absent.
    """
    paged_layouts = ("PA_BBND", "PA_BNBD", "PA_NZ")

    c = dict(raw)
    layout_q = c.get("layout_q", "BNSD")
    c["input_layout"] = layout_q
    c.setdefault("layout_kv", layout_q)
    c.setdefault("layout_out", layout_q)
    c.setdefault("N2", c.get("N1"))
    c.setdefault("S2", c.get("S1"))
    c.setdefault("DV", c.get("D"))
    c.setdefault("DRope", 0)
    c.setdefault("q_start_idx", 0)
    c.setdefault("win_left", 2147483647)
    c.setdefault("win_right", 2147483647)
    c.setdefault("prefix", [])
    c.setdefault("q_range", None)
    c.setdefault("k_range", None)
    c.setdefault("v_range", None)

    # Treat "unset" placeholders as absent so the defaults below can apply.
    for key in ("cu_seqlens_q", "cu_seqlens_kv", "seqused_q", "seqused_kv"):
        if c.get(key) == [None] or c.get(key) is None:
            c.pop(key, None)

    if layout_q == "TND":
        cu_q = c["cu_seqlens_q"]
        c.setdefault("seqused_q",
                     [cu_q[i + 1] - cu_q[i] for i in range(len(cu_q) - 1)])
        if c["layout_kv"] not in paged_layouts:
            c.setdefault("cu_seqlens_kv", list(cu_q))
            cu_kv = c["cu_seqlens_kv"]
            c.setdefault("seqused_kv",
                         [cu_kv[i + 1] - cu_kv[i] for i in range(len(cu_kv) - 1)])
        c["B"] = 1

    if c["layout_kv"] in paged_layouts:
        c.setdefault("cu_seqlens_kv", None)
        c.setdefault("block_size", None)
        c.setdefault("block_table_shape", None)
        block_table_shape = c["block_table_shape"]
        if "seqused_kv" not in c:
            c["seqused_kv"] = [c.get("S2", c.get("S1"))] * c.get("B", 1)
        seqused_kv = c["seqused_kv"]

        if c.get("block_table") is not None:
            block_table = torch.tensor(c["block_table"], dtype=torch.int32)
        else:
            # A missing block size falls back to 1. Previously the key had
            # already been defaulted to None, so the fallback never applied
            # and the arithmetic below raised TypeError. The value is recorded
            # so later consumers see the size the table was built with.
            if c["block_size"] is None:
                c["block_size"] = 1
            block_size = c["block_size"]

            b_val = c.get("B")
            if block_table_shape:
                b_val = block_table_shape[0]
                max_block_num_per_batch = block_table_shape[1]
            else:
                if layout_q == "TND":
                    # B was forced to 1 above; the real batch count comes
                    # from the cumulative query lengths.
                    b_val = len(c["cu_seqlens_q"]) - 1
                max_block_num_per_batch = (max(seqused_kv) + block_size - 1) // block_size

            # -1 marks unused slots; used slots get consecutive block ids.
            block_table = torch.full((b_val, max_block_num_per_batch), -1, dtype=torch.int32)
            next_block = 0
            for i in range(b_val):
                # Reuse the first length when seqused_kv is shorter than the batch.
                b_seq = seqused_kv[i] if len(seqused_kv) > i else seqused_kv[0]
                b_block_num = (b_seq + block_size - 1) // block_size
                for j in range(b_block_num):
                    block_table[i][j] = next_block
                    next_block += 1

        c["block_table"] = block_table
        c["actual_b"] = block_table.shape[0]
    return c
def resolve_case_ids(case_id_arg, all_cases):
    """Resolve a user-supplied case selector into full case names.

    Args:
        case_id_arg: ``"all"`` or a comma-separated list of case ids. An id
            matches a full name ``k`` when ``k == id``, ``k`` ends with
            ``"/<id>"``, or the part of ``k`` after the last ``/`` starts
            with ``"<id>_"`` (which selects the expanded ``<case>_<i>``
            variants).
        all_cases: Mapping whose keys are the full case names.

    Returns:
        For ``"all"``, every case name sorted. Otherwise the matched names in
        first-match order, each listed once even when several ids select it.
        Ids that match nothing are reported with a warning and skipped; empty
        tokens (for example from a trailing comma) are ignored.
    """
    if case_id_arg == "all":
        return sorted(all_cases.keys())

    selected = []
    seen = set()
    missing = []
    for token in case_id_arg.split(","):
        cid = token.strip()
        if not cid:
            continue
        matches = [
            k for k in all_cases
            if k.endswith(f"/{cid}")
            or k == cid
            or k.rsplit("/", 1)[-1].startswith(f"{cid}_")
        ]
        if not matches:
            missing.append(cid)
            continue
        for full_name in matches:
            # Overlapping ids (e.g. "a" and "a_0") must not schedule a case twice.
            if full_name not in seen:
                seen.add(full_name)
                selected.append(full_name)
    if missing:
        print(f"[WARN] 以下 case 不存在: {missing}")
    return selected
def call_flash_attn(test_name, dump_tensors=False, dump_dir="./dump_output",
                    verbose_diff=False, visualize=False, viz_dir="./viz_output",
                    meta_only=False, compare_mode=False, graph_mode=False,
                    load_gpu_dump=None, load_npu_dump=None, fail_analysis=False,
                    **kwargs):
    """Run one flash-attention case and compare accelerator output with the CPU reference.

    Steps: generate q/k/v, compute the CPU reference via ``tforward``, run
    the GPU backend (when ``kwargs["_use_gpu"]`` is truthy) or the NPU
    backend (otherwise), fall back to dumped outputs when a result is
    missing or a dump path is given, then compare the results.

    Args:
        test_name: Case name; used in log lines and as the dump sub-directory.
        dump_tensors: Save q/k/v, the CPU output and (if computed) the GPU
            output as text files under ``dump_dir/test_name``.
        dump_dir: Root directory for dumps and for default dump lookups.
        verbose_diff: Forwarded to ``check_result``.
        visualize: Call ``utils.precision_visual.visualize_from_tensors``.
        viz_dir: Output directory passed to the visualizer.
        meta_only: Accepted for interface compatibility; not used here.
        compare_mode: Use ``data_compare_benchmark_new`` on NPU/GPU/CPU
            outputs when both accelerator results exist; otherwise the
            standard pairwise comparison is used.
        graph_mode: Use ``flash_attn_npu_graph`` instead of ``flash_attn_npu``.
        load_gpu_dump: Path of a GPU output dump to load (overrides a
            computed GPU result).
        load_npu_dump: Path of an NPU output dump to load (overrides a
            computed NPU result).
        fail_analysis: Call ``analyze_fail_distribution`` when the CPU-vs-NPU
            check returns a dict with ``passed`` false.
        **kwargs: Normalized case fields (see ``normalize_case``); forwarded
            to the backends and helpers.

    Returns:
        ``{"attn": stats, "lse": stats}`` where each ``stats`` dict has the
        keys ``passed``, ``max_abs``, ``mean_abs``, ``fail_cnt``, ``total``
        and ``fail_ratio``.
    """
    b = kwargs.get("B", 1)
    n1 = kwargs.get("N1")
    n2 = kwargs.get("N2", n1)
    sq = kwargs.get("S1", -1)
    skv = kwargs.get("S2", sq)
    d = kwargs.get("D")
    d_v = kwargs.get("DV", d)
    d_rope = kwargs.get("DRope", 0)
    input_layout = kwargs.get("input_layout")
    layout_kv = kwargs.get("layout_kv", input_layout)
    output_layout = kwargs.get('layout_out')
    dtype = kwargs.get("Dtype", "bf16")
    pttype = torch.float16 if dtype == "fp16" else torch.bfloat16
    input_dtype = pttype
    mask_mode = kwargs.get("mask_mode", None)
    win_left = kwargs.get("win_left", 2147483647)
    win_right = kwargs.get("win_right", 2147483647)
    prefix = kwargs.get("prefix", [])
    q_range = kwargs.get("q_range", None)
    k_range = kwargs.get("k_range", None)
    v_range = kwargs.get("v_range", None)
    lse_flag = kwargs.get("return_softmax_lse", 0)

    # *_gen are the lengths passed to generate_qkv; sq/skv are the per-batch
    # maxima used for mask generation.
    sq_gen = sq
    skv_gen = skv
    if input_layout == "TND":
        cu_q = kwargs["cu_seqlens_q"]
        sq = max(cu_q[i + 1] - cu_q[i] for i in range(len(cu_q) - 1))
        sq_gen = cu_q[-1]
        if layout_kv in ("PA_BBND", "PA_BNBD", "PA_NZ"):
            skv = max(kwargs["seqused_kv"])
            skv_gen = sum(kwargs["seqused_kv"])
        else:
            cu_kv = kwargs["cu_seqlens_kv"]
            skv = max(cu_kv[i + 1] - cu_kv[i] for i in range(len(cu_kv) - 1))
            skv_gen = cu_kv[-1]
        kwargs["S1"] = sq
        kwargs["S2"] = skv

    q, k, v, q_rope, k_rope, qf, kf = generate_qkv(b, n1, n2, sq_gen, skv_gen, d, d_v, d_rope,
                                                   input_layout, input_dtype,
                                                   q_range=q_range, k_range=k_range, v_range=v_range)
    print(f"[{test_name}] CPU 参考计算...")
    out_cpu, x_max, x_sum = tforward(qf, kf, v, **kwargs)
    out_cpu_out_layout = trans_bnsd_to_layout(out_cpu, output_layout, **kwargs)
    # Reference log-sum-exp rebuilt from the softmax running max and sum.
    lse_out = torch.log(x_sum) + x_max

    if dump_tensors:
        dump_path = os.path.join(dump_dir, test_name)
        os.makedirs(dump_path, exist_ok=True)
        save_tensor_to_txt(q, os.path.join(dump_path, "q.txt"))
        save_tensor_to_txt(k, os.path.join(dump_path, "k.txt"))
        save_tensor_to_txt(v, os.path.join(dump_path, "v.txt"))
        save_tensor_to_txt(out_cpu_out_layout.float(), os.path.join(dump_path, "cpu_out.txt"))

    gpu_out = None
    npu_out = None
    # Stays None unless the NPU operator actually runs; an NPU output loaded
    # from a dump carries no LSE tensor.
    lse_npu = None

    if kwargs.get("_use_gpu", False):
        if not GPU_AVAILABLE:
            print(f"[{test_name}] 警告: GPU 不可用,但要求 GPU 模式,尝试从 dump 加载。")
        if GPU_AVAILABLE:
            print(f"[{test_name}] GPU 计算...")
            atten_mask = generate_npu_mask(b, sq, skv, mask_mode, win_left, win_right, prefix)
            device = torch.cuda.current_device()
            q_gpu = q.cuda(device)
            k_gpu = k.cuda(device)
            v_gpu = v.cuda(device)
            q_rope_gpu = q_rope.cuda(device) if q_rope is not None else None
            k_rope_gpu = k_rope.cuda(device) if k_rope is not None else None
            mask_gpu = atten_mask.cuda(device) if atten_mask is not None else None
            gpu_out_bnsd = flash_attn_gpu(q_gpu, k_gpu, v_gpu, q_rope_gpu, k_rope_gpu, mask_gpu, **kwargs)
            gpu_out = trans_bnsd_to_layout(gpu_out_bnsd, output_layout, **kwargs).cpu()
            if dump_tensors:
                dump_path = os.path.join(dump_dir, test_name)
                os.makedirs(dump_path, exist_ok=True)
                save_tensor_to_txt(gpu_out.float(), os.path.join(dump_path, "gpu_out.txt"))
                print(f"[{test_name}] 已保存 q/k/v/gpu_out → {dump_path}/")
        else:
            print(f"[{test_name}] GPU 不可用,跳过计算。")

    # Dump fallback: an explicit path always wins over a computed result.
    if gpu_out is None or load_gpu_dump:
        gpu_dump_path = load_gpu_dump if load_gpu_dump else os.path.join(dump_dir, test_name, "gpu_out.txt")
        if os.path.exists(gpu_dump_path):
            print(f"[{test_name}] 从 dump 加载 GPU 结果: {gpu_dump_path}")
            gpu_out = load_tensor_from_txt(gpu_dump_path, target_dtype=torch.float32, target_device='cpu')
        else:
            print(f"[{test_name}] 未找到 GPU dump 文件: {gpu_dump_path}")

    if not kwargs.get("_use_gpu", False):
        if not NPU_AVAILABLE:
            print(f"[{test_name}] 警告: NPU 不可用,尝试从 dump 加载。")
        if NPU_AVAILABLE and (n1 >= n2 and n1 % n2 == 0):
            atten_mask = generate_npu_mask(b, sq, skv, mask_mode, win_left, win_right, prefix)
            if graph_mode:
                print(f"[{test_name}] NPU 图模式计算...")
                npu_out, lse_npu = flash_attn_npu_graph(q, k, v, q_rope, k_rope, atten_mask, **kwargs)
            else:
                print(f"[{test_name}] NPU 单算子模式计算...")
                npu_out, lse_npu = flash_attn_npu(q, k, v, q_rope, k_rope, atten_mask, **kwargs)
        elif NPU_AVAILABLE:
            print(f"[{test_name}] 跳过 NPU: N1={n1} 不满足 N1>=N2 且 N1%N2==0 (N2={n2})")
        else:
            print(f"[{test_name}] NPU 不可用,跳过计算。")

    if npu_out is None or load_npu_dump:
        npu_dump_path = load_npu_dump if load_npu_dump else os.path.join(dump_dir, test_name, "npu_out.txt")
        if os.path.exists(npu_dump_path):
            print(f"[{test_name}] 从 dump 加载 NPU 结果: {npu_dump_path}")
            npu_out = load_tensor_from_txt(npu_dump_path, target_dtype=torch.float32, target_device='cpu')
        else:
            print(f"[{test_name}] 未找到 NPU dump 文件: {npu_dump_path}")

    passed_cpu_gpu = False
    passed_cpu_npu = False
    passed_gpu_npu = False

    if compare_mode:
        if gpu_out is not None and npu_out is not None:
            print(f"\n{'='*40} 三方对比模式(详细精度统计) {'='*40}")
            # Same dtype (and default) as the generated tensors, so the
            # thresholds match the data that was actually computed.
            dtype_str = dtype
            params = {
                "op_name": "flash_attn", "case_name": test_name,
                "dtype_output": [dtype_str], "dtype_input": [dtype_str],
                "red_range": {
                    "fp32": "0.000001/0.00001/0.0001/0.0005",
                    "fp16": "0.001/0.002/0.005/0.01",
                    "bf16": "0.001/0.002/0.005/0.01",
                },
                "bm_cmp_std": {
                    "fp32": {"max_re_rtol": 10.0, "avg_re_rtol": 2.0, "rmse_rtol": 2.0,
                             "small_value": 1e-06, "small_value_atol": 0.0},
                    "fp16": {"max_re_rtol": 10.0, "avg_re_rtol": 2.0, "rmse_rtol": 2.0,
                             "small_value": 0.001, "small_value_atol": 0.001},
                    "bf16": {"max_re_rtol": 10.0, "avg_re_rtol": 2.0, "rmse_rtol": 2.0,
                             "small_value": 1e-07, "small_value_atol": 0.004},
                },
            }
            str1, str2, _ = data_compare_benchmark_new(params, npu_out.float().numpy(),
                                                       gpu_out.float().numpy(),
                                                       out_cpu_out_layout.float().numpy(),
                                                       dtype_str, i=0)
            print("\n====== 三方对比结果 ======")
            print(f"状态: {str1}, 原因: {str2}")
            # One three-way verdict drives all three pairwise flags.
            three_way_ok = (str1 == 'Pass' or str1 == 'warning')
            passed_cpu_gpu = three_way_ok
            passed_cpu_npu = three_way_ok
            passed_gpu_npu = three_way_ok
        else:
            print(f"[{test_name}] 缺少 GPU 或 NPU 结果,无法执行三方对比,使用标准对比模式。")
            compare_mode = False

    if not compare_mode:
        if gpu_out is not None:
            print(f"\n{'='*40} 对比 CPU vs GPU {'='*40}")
            passed_cpu_gpu = check_result(test_name + "_CPU_vs_GPU", out_cpu_out_layout.float(), gpu_out.float(),
                                          except_label="CPU", comp_label="GPU", verbose_diff=verbose_diff)
        else:
            print(f"[{test_name}] 缺少 GPU 结果,跳过 CPU vs GPU 对比。")
        if npu_out is not None:
            print(f"\n{'='*40} 对比 CPU vs NPU {'='*40}")
            passed_cpu_npu = check_result(test_name + "_CPU_vs_NPU", out_cpu_out_layout.float(), npu_out.float(),
                                          except_label="CPU", comp_label="NPU", verbose_diff=verbose_diff)
            if fail_analysis and isinstance(passed_cpu_npu, dict) and not passed_cpu_npu["passed"]:
                analyze_fail_distribution(test_name, out_cpu_out_layout, npu_out, dump_dir=dump_dir, **kwargs)
        else:
            print(f"[{test_name}] 缺少 NPU 结果,跳过 CPU vs NPU 对比。")
        if gpu_out is not None and npu_out is not None:
            print(f"\n{'='*40} 对比 GPU vs NPU {'='*40}")
            passed_gpu_npu = check_result(test_name + "_GPU_vs_NPU", gpu_out.float(), npu_out.float(),
                                          except_label="GPU", comp_label="NPU", verbose_diff=verbose_diff)
        else:
            print(f"[{test_name}] 缺少 GPU 或 NPU 结果,跳过 GPU vs NPU 对比。")

    if visualize:
        try:
            from utils.precision_visual import visualize_from_tensors
            if npu_out is not None:
                visualize_from_tensors(out_cpu_out_layout.float(), npu_out.float(), case_name=test_name, out_dir=viz_dir)
            if gpu_out is not None:
                visualize_from_tensors(out_cpu_out_layout.float(), gpu_out.float(), case_name=test_name + "_gpu", out_dir=viz_dir)
        except ImportError:
            print("[WARN] precision_visual 导入失败,请确认 matplotlib 已安装")
        except Exception as exc:
            # Visualization is best-effort and must not fail the case.
            print(f"[WARN] 可视化异常: {exc}")

    passed_lse = 1
    if lse_flag == 1 and npu_out is not None and lse_npu is not None:
        print(f"\n{'='*40} 对比 lse goldem vs lse npu {'='*40}")
        passed_lse = check_result("LSE CHECK", lse_out.float(), lse_npu.float(),
                                  except_label="lse_golden", comp_label="lse_npu", verbose_diff=verbose_diff)
    elif lse_flag == 1:
        # Also reached when the NPU output came from a dump (no LSE tensor);
        # previously that path raised NameError on lse_npu.
        print(f"[{test_name}] 缺少 NPU 结果,跳过 LSE 对比。")

    def _failed_stats():
        # Placeholder statistics for a comparison that failed or never ran.
        return {"passed": False, "max_abs": float('nan'), "mean_abs": float('nan'),
                "fail_cnt": -1, "total": -1, "fail_ratio": float('nan')}

    def _to_stats(val):
        # Dict results pass through; a truthy non-dict becomes a synthetic
        # "passed" record; anything falsy means "no usable result".
        if isinstance(val, dict):
            return val
        if val:
            return {"passed": True, "max_abs": 0.0, "mean_abs": 0.0, "fail_cnt": 0, "total": 0, "fail_ratio": 0.0}
        return None

    # When an NPU output exists its verdict is authoritative, so a passing
    # GPU comparison cannot mask a falsy (failed) NPU comparison.
    if npu_out is not None:
        attn_stats = _to_stats(passed_cpu_npu)
    else:
        attn_stats = _to_stats(passed_cpu_gpu)
    if attn_stats is None:
        attn_stats = _failed_stats()

    if lse_flag != 1:
        lse_stats = {"passed": True, "max_abs": 0.0, "mean_abs": 0.0, "fail_cnt": 0, "total": 0, "fail_ratio": 0.0}
    else:
        lse_stats = _to_stats(passed_lse)
        if lse_stats is None:
            lse_stats = _failed_stats()
    return {"attn": attn_stats, "lse": lse_stats}
def _build_result_row(name, res, max_name_len):
    """Format one results-table row for a case.

    Args:
        name: Case name, left-aligned to ``max_name_len`` characters.
        res: ``{"attn": stats, "lse": stats}``; each stats dict provides
            ``passed``, ``total``, ``max_abs`` and ``fail_ratio``.
        max_name_len: Width of the case-name column.

    Returns:
        The formatted row. Numeric cells are shown only when ``total > 0``;
        otherwise attention cells read ``N/A`` and LSE cells read ``---``.
    """
    def _cells(stats, placeholder):
        # (status tag, max-abs-error cell, fail-ratio cell) for one column group.
        tag = "✓ PASS" if stats["passed"] else "✗ FAIL"
        if stats['total'] > 0:
            return tag, f"{stats['max_abs']:.6f}", f"{stats['fail_ratio']*100:.4f}%"
        return tag, placeholder, placeholder

    attn_tag, attn_max, attn_ratio = _cells(res["attn"], "N/A")
    lse_tag, lse_max, lse_ratio = _cells(res["lse"], "---")

    name_part = f"│ {name:<{max_name_len}} "
    attn_part = f"{attn_tag:>8} {attn_max:>12} {attn_ratio:>10} │ "
    lse_part = f"{lse_tag:>8} {lse_max:>12} {lse_ratio:>10} │"
    return name_part + attn_part + lse_part
def _print_results_table(results, total_cases, is_final=False):
    """Print the results table and count passing/failing cases.

    Args:
        results: Mapping of case name to ``{"attn": stats, "lse": stats}``.
        total_cases: Number of scheduled cases (shown in intermediate reports).
        is_final: Selects the final-summary title/footer instead of the
            intermediate-progress ones.

    Returns:
        ``(pass_cnt, fail_cnt)``; a case passes only when both its attention
        and LSE stats are marked passed.
    """
    rule = "─" * 120
    name_width = max(map(len, results), default=28)
    executed = len(results)

    if is_final:
        title = f" 汇总结果 ({executed} cases)"
    else:
        title = f" 中间统计结果 ({executed}/{total_cases} cases)"
    header = (f"│ {'Case':<{name_width}} "
              f"{'Attn':>8} {'MaxAbsErr':>12} {'FailRatio':>10} │ "
              f"{'LSE':>8} {'MaxAbsErr':>12} {'FailRatio':>10} │")
    for line in (f"\n┌{rule}┐", f"│{title}", f"├{rule}┤", header, f"├{rule}┤"):
        print(line)

    pass_cnt = 0
    for case_name, record in results.items():
        print(_build_result_row(case_name, record, name_width))
        if record["attn"]["passed"] and record["lse"]["passed"]:
            pass_cnt += 1
    fail_cnt = executed - pass_cnt

    print(f"├{rule}┤")
    if is_final:
        footer = f"│ 通过: {pass_cnt} 失败: {fail_cnt} 共: {executed}"
    else:
        footer = f"│ 已执行: {executed}/{total_cases} 通过: {pass_cnt} 失败: {fail_cnt}"
    print(footer)
    print(f"└{rule}┘")
    return pass_cnt, fail_cnt
if __name__ == "__main__":
    import traceback

    # Command-line options, registered in the order they appear in --help.
    option_specs = [
        ("--case_files", dict(type=str, default=",".join(_DEFAULT_CASE_FILES),
                              help="逗号分隔的 test case 模块名(不含 .py)")),
        ("--case_id", dict(type=str, default="all", help="case名称,多个用逗号分隔")),
        ("--device_id", dict(type=int, default=0, help="NPU device id")),
        ("--use_gpu", dict(action="store_true", help="使用 GPU 进行计算")),
        ("--gpu-device", dict(type=int, default=0, help="GPU 设备 ID")),
        ("--dump_tensors", dict(action="store_true", help="保存 q/k/v 及输出为 txt")),
        ("--dump_dir", dict(type=str, default="./dump_output")),
        ("--verbose_diff", dict(action="store_true", help="逐元素输出精度对比表")),
        ("--visualize", dict(action="store_true", help="生成精度热力图")),
        ("--viz_dir", dict(type=str, default="./viz_output")),
        ("--meta_only", dict(action="store_true", help="只调用 metadata 算子")),
        ("--load_gpu_dump", dict(type=str, default=None, help="GPU dump 文件路径")),
        ("--load_npu_dump", dict(type=str, default=None, help="NPU dump 文件路径")),
        ("--compare_mode", dict(action="store_true", help="三方对比模式")),
        ("--graph_mode", dict(action="store_true", help="图模式调用 NPU 算子")),
        ("--fail_analysis", dict(action="store_true", help="失败元素分布分析")),
        ("--report_interval", dict(type=int, default=20, help="每 N 条用例输出一次中间统计")),
    ]
    parser = argparse.ArgumentParser(description="FlashAttn 精度测试(CPU vs GPU vs NPU)")
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    # Device selection: explicit GPU mode, compare mode (GPU, dump files or
    # NPU, in that order), or plain NPU mode.
    if args.use_gpu:
        if not torch.cuda.is_available():
            print("[ERROR] CUDA 不可用。")
            sys.exit(1)
        torch.cuda.set_device(args.gpu_device)
        print(f"[INFO] GPU 模式,使用设备 cuda:{args.gpu_device}")
    elif args.compare_mode:
        if torch.cuda.is_available():
            torch.cuda.set_device(args.gpu_device)
            print(f"[INFO] Compare 模式,使用 GPU 设备 cuda:{args.gpu_device}")
        elif args.load_gpu_dump and args.load_npu_dump:
            print(f"[INFO] Compare 模式,从 dump 文件对比(无 GPU/NPU)")
        else:
            try:
                import torch_npu
                torch.npu.set_device(args.device_id)
                print(f"[INFO] Compare 模式,使用 NPU 设备 npu:{args.device_id}")
            except ImportError:
                print("[ERROR] compare_mode 需要至少一种设备或同时指定 --load_gpu_dump 和 --load_npu_dump")
                sys.exit(1)
    else:
        try:
            import torch_npu
            torch.npu.set_device(args.device_id)
            print(f"[INFO] NPU 模式,使用设备 npu:{args.device_id}")
        except ImportError:
            print("[ERROR] torch_npu 未安装。")
            sys.exit(1)

    module_names = ["test_cases." + part.strip() for part in args.case_files.split(",")]
    all_cases = load_case_modules(module_names)
    run_cases = resolve_case_ids(args.case_id, all_cases)
    if not run_cases:
        print("[ERROR] 没有可运行的 case。")
        sys.exit(1)

    def _failure_record():
        # Result stored for a case that raised before producing statistics.
        return {
            group: {"passed": False, "max_abs": float('nan'), "mean_abs": float('nan'),
                    "fail_cnt": -1, "total": -1, "fail_ratio": float('nan')}
            for group in ("attn", "lse")
        }

    def _report_due(done):
        # Intermediate report every report_interval cases, never after the last one.
        return (args.report_interval > 0
                and done % args.report_interval == 0
                and done < len(run_cases))

    results = {}
    banner = '=' * 66
    for done, name in enumerate(run_cases, start=1):
        if NPU_AVAILABLE:
            torch.npu.synchronize()
        config = all_cases[name]
        try:
            kwargs = normalize_case(config)
        except Exception as e:
            print(f"[ERROR] {name} normalize_case 异常: {e}")
            traceback.print_exc()
            results[name] = _failure_record()
        else:
            kwargs["_use_gpu"] = args.use_gpu
            print(f"\n{banner}")
            print(f" Case: {name} "
                  f"B={kwargs.get('B')} N1={kwargs.get('N1')} N2={kwargs.get('N2')} "
                  f"S1={kwargs.get('S1')} S2={kwargs.get('S2')} D={kwargs.get('D')} "
                  f"layout={kwargs.get('input_layout')} layout_out={kwargs.get('layout_out')} dtype={kwargs.get('Dtype')}")
            print(f"{banner}")
            try:
                outcome = call_flash_attn(
                    name, dump_tensors=args.dump_tensors, dump_dir=args.dump_dir,
                    verbose_diff=args.verbose_diff, visualize=args.visualize, viz_dir=args.viz_dir,
                    meta_only=args.meta_only, compare_mode=args.compare_mode, graph_mode=args.graph_mode,
                    load_gpu_dump=args.load_gpu_dump, load_npu_dump=args.load_npu_dump,
                    fail_analysis=args.fail_analysis, **kwargs)
            except Exception as e:
                # One broken case must not abort the whole run.
                print(f"[ERROR] {name} 运行异常: {e}")
                traceback.print_exc()
                outcome = _failure_record()
            results[name] = outcome
        if _report_due(done):
            _print_results_table(results, len(run_cases))

    _, fail_cnt = _print_results_table(results, len(run_cases), is_final=True)
    sys.exit(0 if fail_cnt == 0 else 1)