"""
Generate Q/K/V inputs and golden outputs (QK, softmax P, PV, O) for TBMM_QK cases.
Supports dynamic case configuration via CLI or generated_cases.json.
Softmax/PV tiling follows TILE_S1 (default 128, see TILE_S1_DEFAULT).
"""
import argparse
import json
import os
import shutil
from pathlib import Path
import numpy as np
# Fixed seed so that repeated runs of the script draw the same random inputs.
np.random.seed(7)

S0_BASE = 64
HEAD_SIZE = 128
TILE_S1_DEFAULT = 128


def gen_case(path, s0, s1, head_size=HEAD_SIZE, cube_s1=128, tile_s1=TILE_S1_DEFAULT, is_causal=False):
    """Generate random Q/K/V inputs and tiled-attention golden data into *path*.

    The reference follows an online-softmax scheme: the S1 axis is processed
    in ``tile_s1``-wide tiles, a running row maximum and running row sum are
    maintained, and earlier partial results are rescaled whenever the maximum
    grows.

    Files written (raw ``ndarray.tofile`` dumps, no header):

    * ``q.bin`` (s0, head_size), ``k.bin`` (head_size, s1), ``kt.bin`` (K
      transposed), ``v.bin`` (s1, head_size), ``vt.bin`` (V transposed):
      float16 inputs.
    * ``qk.bin``: float32 ``Q @ K`` (before masking and scaling).
    * ``p_fp32.bin`` / ``p.bin``: per-tile ``exp((qk - running_max) * scale)``
      as float32 / float16. These are NOT normalised by the row sum.
    * ``pv_tile_fifo{i}.bin``: float32 ``P_tile(fp16) @ V_tile`` per tile;
      ``pv.bin`` is their plain sum.
    * ``global_sum_part{i}.bin`` / ``exp_max_part{i}.bin``: running row sum
      and rescale factor after tile ``i`` (float32, one value per row).
    * ``o_part{i}.bin``: running output after tile ``i`` (the last one is
      divided by the final row sum); ``o.bin`` is the final output.

    Args:
        path: Existing directory that receives the ``.bin`` files.
        s0: Number of query rows.
        s1: Number of key/value rows; must be a multiple of ``tile_s1``.
        head_size: Size of the contracted (head) dimension.
        cube_s1: Only validated here: ``tile_s1`` must be a multiple of it.
        tile_s1: Width of one softmax/PV tile along S1.
        is_causal: When true, entries above the main diagonal of QK are
            pushed towards the most negative float32 before the softmax.

    Raises:
        ValueError: If the tiling parameters are non-positive or do not
            divide evenly. (Previously these were ``assert`` statements,
            which disappear under ``python -O``.)
    """
    # Validate before touching the RNG or the file system.
    if tile_s1 <= 0 or cube_s1 <= 0:
        raise ValueError("TILE_S1 and CUBE_S1 must be positive")
    if s1 % tile_s1 != 0:
        raise ValueError("S1 must be divisible by TILE_S1")
    if tile_s1 % cube_s1 != 0:
        raise ValueError("TILE_S1 must be divisible by CUBE_S1")

    def _save(name, array):
        # ndarray.tofile always emits the elements in C order.
        array.tofile(os.path.join(path, name))

    # RNG draw order is Q, K and (further down) V; keep it so the generated
    # inputs stay the same for a given seed.
    q = np.random.randn(s0, head_size).astype(np.float16) * 1.5
    k = np.random.randn(head_size, s1).astype(np.float16) * 1.5
    golden = q.astype(np.float32).dot(k.astype(np.float32))

    _save('q.bin', q)
    _save('k.bin', k)
    _save('kt.bin', k.T.astype(np.float16))
    _save('qk.bin', golden)

    if is_causal:
        # Strictly-upper-triangular additive mask; masked scores end up so
        # negative that their exponential is zero.
        mask = np.triu(np.full(golden.shape, -3.40282e+38, dtype=np.float32), 1)
        scores = golden + mask
    else:
        scores = golden

    # NOTE: `scale` is a NumPy float64 scalar. How it promotes against the
    # float32 tiles depends on the NumPy version; it is deliberately left
    # exactly as it was so the golden values do not change.
    scale = 1 / np.sqrt(head_size)
    num_tiles = s1 // tile_s1

    full_exp = np.zeros((s0, s1), dtype=np.float32)
    global_sums = []    # running row sum after each tile, flattened to (s0,)
    exp_max_parts = []  # factor that rescales the previous state, (s0,)
    running_max = None
    running_sum = None
    for ti in range(num_tiles):
        c0 = ti * tile_s1
        c1 = c0 + tile_s1
        tile = scores[:, c0:c1]
        tile_max = np.max(tile, axis=1, keepdims=True)
        if running_max is None:
            # First tile: nothing to rescale yet.
            new_max = tile_max
            rescale = np.ones_like(new_max)
        else:
            new_max = np.maximum(tile_max, running_max)
            rescale = np.exp((running_max - new_max) * scale).astype(np.float32)
        tile_exp = np.exp((tile - new_max) * scale).astype(np.float32)
        tile_sum = np.sum(tile_exp, axis=1, keepdims=True)
        if running_sum is None:
            new_sum = tile_sum
        else:
            new_sum = rescale * running_sum + tile_sum

        full_exp[:, c0:c1] = tile_exp
        global_sums.append(new_sum.reshape(-1))
        exp_max_parts.append(rescale.reshape(-1))
        running_max = new_max
        running_sum = new_sum

    soft = full_exp.astype(np.float16)
    _save('p.bin', soft)
    _save('p_fp32.bin', full_exp)

    v = np.random.randn(s1, head_size).astype(np.float16) * 1.2

    # PV is computed from the fp16-rounded probabilities, one tile at a time.
    soft_f32 = soft.astype(np.float32)
    pv = np.zeros((s0, head_size), dtype=np.float32)
    pv_tile_fifo_parts = []
    for ti in range(num_tiles):
        c0 = ti * tile_s1
        c1 = c0 + tile_s1
        v_tile = v[c0:c1, :].astype(np.float32)
        pv_tile = soft_f32[:, c0:c1].dot(v_tile)
        pv_tile_fifo_parts.append(pv_tile)
        pv += pv_tile

    _save('v.bin', v)
    _save('vt.bin', v.T.astype(np.float16))
    _save('pv.bin', pv)
    for idx, part in enumerate(pv_tile_fifo_parts):
        _save(f'pv_tile_fifo{idx}.bin', part)
    for idx, row_sum in enumerate(global_sums):
        _save(f'global_sum_part{idx}.bin', row_sum.astype(np.float32))
    for idx, factor in enumerate(exp_max_parts):
        _save(f'exp_max_part{idx}.bin', factor.astype(np.float32))

    # Replay the running-output update: rescale the accumulated output by the
    # tile's factor, add the tile's PV, and normalise once after the last tile.
    o_running = np.zeros((s0, head_size), dtype=np.float32)
    for ti, part in enumerate(pv_tile_fifo_parts):
        if ti == 0:
            o_running = part.copy()
        else:
            o_running = exp_max_parts[ti].reshape((s0, 1)) * o_running + part
        if ti == num_tiles - 1:
            o_running = o_running / global_sums[ti].reshape((s0, 1))
        _save(f'o_part{ti}.bin', o_running.astype(np.float32))
    _save('o.bin', o_running.astype(np.float32))
def parse_case_entry(entry: str):
    """Parse a ``HEAD_SIZE,S0,S1[,CUBE_S0[,TILE_S1]]`` case description.

    Empty fields are dropped before parsing. ``CUBE_S0`` defaults to ``S0``
    and is only validated (``S0`` must be a multiple of it); it is not part
    of the result. ``CUBE_S1`` is fixed at 128.

    Returns:
        ``(head_size, s0, s1, cube_s1, tile_s1)`` as ints.

    Raises:
        ValueError: Fewer than three fields, a non-integer field, a zero
            ``CUBE_S0``, or ``S0`` not divisible by ``CUBE_S0``.
    """
    parts = [p.strip() for p in entry.split(',') if p.strip()]
    if len(parts) < 3:
        raise ValueError("Case entry must be HEAD_SIZE,S0,S1 or HEAD_SIZE,S0,S1,CUBE_S0[,TILE_S1]")
    head, s0, s1 = map(int, parts[:3])
    cube_s0 = int(parts[3]) if len(parts) >= 4 else s0
    if cube_s0 == 0:
        # Guard explicitly: the modulo below would raise ZeroDivisionError.
        raise ValueError("CUBE_S0 must be non-zero")
    if s0 % cube_s0 != 0:
        raise ValueError("S0 must be divisible by CUBE_S0")
    tile_s1 = int(parts[4]) if len(parts) >= 5 else TILE_S1_DEFAULT
    cube_s1 = 128
    return head, s0, s1, cube_s1, tile_s1


def _case_name(head, s0, s1):
    """Return the directory/case name used for a (head, s0, s1) triple."""
    return f"case_float_H_{head}_S0_{s0}_S1_{s1}"


def _build_parser():
    """Create the command-line parser for the generator script."""
    parser = argparse.ArgumentParser(description="Generate TFA golden data")
    parser.add_argument("--case", dest="case_name", help="Generate only the specified case name (e.g. case_float_H_128_S0_512_S1_2048)")
    parser.add_argument("--cases", action="append", help="Explicit case entry HEAD_SIZE,S0,S1[,CUBE_S0[,TILE_S1]] (repeatable; CUBE_S1 fixed at 128)")
    parser.add_argument("--cases-json", dest="cases_json", default=None, help="Path to generated_cases.json (defaults to ../generated_cases.json if present)")
    parser.add_argument("--head-size", type=int, help="HEAD_SIZE for a single on-demand case")
    parser.add_argument("--s0", type=int, help="S0 for a single on-demand case")
    parser.add_argument("--s1", type=int, help="S1 for a single on-demand case")
    parser.add_argument("--causal-mask", type=int, help="Enable causal mask (pass a non-zero integer)")
    return parser


def _collect_cases(args, default_json):
    """Build the list of ``(name, (s0, head_size, s1, cube_s1, tile_s1))``.

    Sources are tried in this order: ``--cases`` entries, the single
    ``--head-size/--s0/--s1`` triple, a JSON case file (``--cases-json`` or
    *default_json* when it exists), and finally a built-in default list.
    """
    if args.cases:
        cases = []
        for entry in args.cases:
            head, s0, s1, cube_s1, tile_s1 = parse_case_entry(entry)
            cases.append((_case_name(head, s0, s1), (s0, head, s1, cube_s1, tile_s1)))
        return cases

    if args.head_size and args.s0 and args.s1:
        name = _case_name(args.head_size, args.s0, args.s1)
        return [(name, (args.s0, args.head_size, args.s1, 128, TILE_S1_DEFAULT))]

    if args.cases_json or default_json.exists():
        json_path = Path(args.cases_json) if args.cases_json else default_json
        payload = json.loads(json_path.read_text(encoding="utf-8"))
        return [
            (
                entry["name"],
                (
                    entry["s0"],
                    entry["head_size"],
                    entry["s1"],
                    entry.get("cube_s1", 128),
                    entry.get("tile_s1", TILE_S1_DEFAULT),
                ),
            )
            for entry in payload
        ]

    return [
        ('case_float_H_128_S0_128_S1_1024', (128, HEAD_SIZE, 1024, 128, TILE_S1_DEFAULT)),
        ('case_float_H_128_S0_128_S1_2048', (128, HEAD_SIZE, 2048, 128, TILE_S1_DEFAULT)),
        ('case_float_H_128_S0_128_S1_8192', (128, HEAD_SIZE, 8192, 128, TILE_S1_DEFAULT)),
        ('case_float_H_128_S0_512_S1_1024', (512, HEAD_SIZE, 1024, 128, TILE_S1_DEFAULT)),
        ('case_float_H_128_S0_512_S1_2048', (512, HEAD_SIZE, 2048, 128, TILE_S1_DEFAULT)),
        ('case_float_H_128_S0_512_S1_8192', (512, HEAD_SIZE, 8192, 128, TILE_S1_DEFAULT)),
    ]


def _select_case(cases, requested):
    """Narrow *cases* to the one named by ``--case``.

    A trailing ``_precision_debug`` is ignored when matching. If no
    configured case has that name, the value is tried as a
    ``HEAD_SIZE,S0,S1[,...]`` entry and a synthetic case is returned.

    Raises:
        ValueError: The name matches nothing and is not a parsable entry.
    """
    target = requested
    if target.endswith('_precision_debug'):
        target = target[:-len('_precision_debug')]
    matches = [entry for entry in cases if entry[0] == target]
    if matches:
        return matches
    try:
        head, s0, s1, cube_s1, tile_s1 = parse_case_entry(target)
    except ValueError as err:
        raise ValueError(f"Requested case '{requested}' not found in configured cases") from err
    return [(_case_name(head, s0, s1), (s0, head, s1, cube_s1, tile_s1))]


def main(argv=None):
    """Parse the command line and generate every selected case.

    Each case is written to ``<script_root>/build/<case name>/``, where
    ``script_root`` is the parent of the directory containing this file.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = _build_parser().parse_args(argv)
    script_root = Path(__file__).resolve().parents[1]
    cases = _collect_cases(args, script_root / "generated_cases.json")
    if args.case_name:
        cases = _select_case(cases, args.case_name)

    build_dir = script_root / "build"
    for name, (s0, head_size, s1, cube_s1, tile_s1) in cases:
        case_dir = build_dir / name
        os.makedirs(case_dir, exist_ok=True)
        gen_case(str(case_dir), s0, s1, head_size, cube_s1, tile_s1, bool(args.causal_mask))


if __name__ == '__main__':
    main()