"""Tests for fill_fia_runtime_metadata JOIN-based backfill."""
import json
import tempfile
from pathlib import Path
import pytest
from tools.perf_data_collection.fill_fia_runtime_metadata import (
backfill,
build_csv_key,
build_jsonl_key,
build_runtime_payload,
load_jsonl,
)
from tests.helpers.cli_runner import run_module_main
class TestBuildCsvKey:
"""Extract JOIN key from CSV Input Shapes 31-slot format."""
def test_mla_prefill_tnd_3d(self):
"""MLA prefill: 3D TND shapes, mask present, no block_table."""
row = {
"Input Shapes": (
'"8192,16,128;8192,16,128;8192,16,128;;2048,2048;2;2;;;;;;;;;;;;;;;;;;8192,16,64;8192,16,64;;;;;"'
),
}
key = build_csv_key(row)
assert key == (
(8192, 16, 128),
(8192, 16, 128),
(8192, 16, 128),
2,
(2048, 2048),
None,
)
def test_mla_prefill_3seq(self):
"""3-seq batch: slot[6] = 3."""
row = {
"Input Shapes": (
'"8192,16,128;8192,16,128;8192,16,128;;2048,2048;3;3;;;;;;;;;;;;;;;;;;8192,16,64;8192,16,64;;;;;"'
),
}
key = build_csv_key(row)
assert key[3] == 3
def test_mla_chunk_no_mask(self):
"""Chunk context: Q != K, no atten_mask."""
row = {
"Input Shapes": ('"4105,16,128;4093,16,128;4093,16,128;;;2;2;;;;;;;;;;;;;;;;;;4105,16,64;4093,16,64;;;;;"'),
}
key = build_csv_key(row)
assert key[0] == (4105, 16, 128)
assert key[1] == (4093, 16, 128)
assert key[4] is None
def test_mla_decode_4d_bnsd(self):
"""MLA decode: 4D BNSD, block_table present."""
row = {
"Input Shapes": (
'"1,16,1,512;1170,1,128,512;1170,1,128,512;;;;1;;;;;;;;1,512;;;;;;;;;1,16,1,64;1170,1,128,64;;;;;"'
),
}
key = build_csv_key(row)
assert key[0] == (1, 16, 1, 512)
assert key[5] == (1, 512)
def test_empty_slots_produce_none(self):
"""Empty or missing slots should produce None, not crash."""
row = {"Input Shapes": '"4099,16,128;4099,16,128;4099,16,128"'}
key = build_csv_key(row)
assert key[3] is None
assert key[4] is None
assert key[5] is None
class TestBuildJsonlKey:
"""Extract JOIN key from JSONL dump record."""
def test_mla_prefill(self):
record = {
"query_shape": [8192, 16, 128],
"key_shape": [8192, 16, 128],
"value_shape": [8192, 16, 128],
"actual_seq_lengths_kv": [4099, 8192],
"atten_mask_shape": [2048, 2048],
"block_table_shape": None,
}
key = build_jsonl_key(record)
assert key == (
(8192, 16, 128),
(8192, 16, 128),
(8192, 16, 128),
2,
(2048, 2048),
None,
)
def test_mla_decode(self):
record = {
"query_shape": [1, 16, 1, 512],
"key_shape": [1170, 1, 128, 512],
"value_shape": [1170, 1, 128, 512],
"actual_seq_lengths_kv": [1],
"atten_mask_shape": None,
"block_table_shape": [1, 512],
}
key = build_jsonl_key(record)
assert key[3] == 1
assert key[5] == (1, 512)
def test_null_seq_lengths_kv(self):
"""actual_seq_lengths_kv can be None (some decode paths)."""
record = {
"query_shape": [1, 16, 1, 512],
"key_shape": [1170, 1, 128, 512],
"value_shape": [1170, 1, 128, 512],
"actual_seq_lengths_kv": None,
"atten_mask_shape": None,
"block_table_shape": [1, 512],
}
key = build_jsonl_key(record)
assert key[3] is None
class TestBuildRuntimePayload:
def test_prefill_payload(self):
record = {
"actual_seq_lengths": [4099, 8192],
"actual_seq_lengths_kv": [4099, 8192],
"block_table_valid_blocks": None,
"num_heads": 16,
"num_key_value_heads": 16,
"sparse_mode": 3,
"input_layout": "TND",
"block_size": 0,
}
payload = build_runtime_payload(record)
assert payload["Runtime actual_seq_lengths_values"] == "4099,8192"
assert payload["Runtime actual_seq_lengths_kv_values"] == "4099,8192"
assert payload["Runtime avg_seq_len"] == "6145.500000"
assert payload["Runtime num_heads"] == "16"
assert payload["Runtime sparse_mode"] == "3"
def test_none_seq_lengths(self):
record = {
"actual_seq_lengths": None,
"actual_seq_lengths_kv": [1],
"block_table_valid_blocks": [1],
"num_heads": 16,
"num_key_value_heads": 1,
"sparse_mode": 0,
"input_layout": "BNSD_NBSD",
"block_size": 128,
}
payload = build_runtime_payload(record)
assert payload["Runtime actual_seq_lengths_values"] == ""
assert payload["Runtime actual_seq_lengths_kv_values"] == "1"
assert payload["Runtime avg_seq_len"] == "1.000000"
class TestBackfill:
"""End-to-end backfill with in-memory data."""
def _make_csv_row(self, input_shapes: str, duration: str = "100.0") -> dict:
return {
"Input Shapes": input_shapes,
"Profiling Average Duration(us)": duration,
}
def _make_jsonl_record(
self,
q,
k,
v,
seq_kv,
mask=None,
bt=None,
sparse=3,
num_heads=16,
kv_heads=16,
) -> dict:
return {
"query_shape": q,
"key_shape": k,
"value_shape": v,
"actual_seq_lengths": seq_kv,
"actual_seq_lengths_kv": seq_kv,
"atten_mask_shape": mask,
"block_table_shape": bt,
"block_table_valid_blocks": None,
"num_heads": num_heads,
"num_key_value_heads": kv_heads,
"sparse_mode": sparse,
"input_layout": "TND",
"block_size": 0,
}
def test_1_to_1_match(self):
"""One CSV row matches exactly one JSONL record."""
csv_row = self._make_csv_row(
'"4099,16,128;4099,16,128;4099,16,128;;2048,2048;1;1;;;;;;;;;;;;;;;;;;4099,16,64;4099,16,64;;;;;"',
"600.0",
)
jsonl_record = self._make_jsonl_record(
[4099, 16, 128],
[4099, 16, 128],
[4099, 16, 128],
[4099],
mask=[2048, 2048],
)
jsonl_index = {build_jsonl_key(jsonl_record): [build_runtime_payload(jsonl_record)]}
rows, matched, total = backfill([csv_row], jsonl_index, "test")
assert matched == 1
assert total == 1
assert rows[0]["Profiling Average Duration(us)"] == "600.0"
assert rows[0]["Runtime avg_seq_len"] == "4099.000000"
def test_1_to_n_expansion(self):
"""One CSV row matches two JSONL records with different seq values."""
csv_row = self._make_csv_row(
'"8192,16,128;8192,16,128;8192,16,128;;2048,2048;;2;;;;;;;;;;;;;;;;;;8192,16,64;8192,16,64;;;;;"',
"1200.0",
)
rec_a = self._make_jsonl_record(
[8192, 16, 128],
[8192, 16, 128],
[8192, 16, 128],
[4099, 8192],
mask=[2048, 2048],
)
rec_b = self._make_jsonl_record(
[8192, 16, 128],
[8192, 16, 128],
[8192, 16, 128],
[100, 8192],
mask=[2048, 2048],
)
key = build_jsonl_key(rec_a)
jsonl_index = {key: [build_runtime_payload(rec_a), build_runtime_payload(rec_b)]}
rows, matched, total = backfill([csv_row], jsonl_index, "test")
assert matched == 1
assert total == 2
assert rows[0]["Profiling Average Duration(us)"] == "1200.0"
assert rows[1]["Profiling Average Duration(us)"] == "1200.0"
avg_values = {r["Runtime avg_seq_len"] for r in rows}
assert len(avg_values) == 2
def test_no_match_passthrough(self):
"""Unmatched CSV row passes through unchanged."""
csv_row = self._make_csv_row('"999,16,128;999,16,128;999,16,128"', "50.0")
rows, matched, total = backfill([csv_row], {}, "test")
assert matched == 0
assert total == 1
assert rows[0]["Profiling Average Duration(us)"] == "50.0"
assert rows[0].get("Runtime avg_seq_len") is None
def test_prefill_vs_chunk_no_cross_match(self):
"""Prefill (K==Q, mask) and chunk (K!=Q, no mask) don't cross-match."""
prefill_csv = self._make_csv_row(
'"4105,16,128;4105,16,128;4105,16,128;;2048,2048;;2;;;;;;;;;;;;;;;;;;4105,16,64;4105,16,64;;;;;"',
)
chunk_csv = self._make_csv_row(
'"4105,16,128;4093,16,128;4093,16,128;;;;2;;;;;;;;;;;;;;;;;;4105,16,64;4093,16,64;;;;;"',
)
prefill_rec = self._make_jsonl_record(
[4105, 16, 128],
[4105, 16, 128],
[4105, 16, 128],
[6, 4105],
mask=[2048, 2048],
)
chunk_rec = self._make_jsonl_record(
[4105, 16, 128],
[4093, 16, 128],
[4093, 16, 128],
[4093, 4093],
mask=None,
sparse=0,
)
jsonl_index = {
build_jsonl_key(prefill_rec): [build_runtime_payload(prefill_rec)],
build_jsonl_key(chunk_rec): [build_runtime_payload(chunk_rec)],
}
rows, matched, total = backfill([prefill_csv, chunk_csv], jsonl_index, "test")
assert matched == 2
assert total == 2
assert rows[0]["Runtime sparse_mode"] == "3"
assert rows[0]["Runtime actual_seq_lengths_kv_values"] == "6,4105"
assert rows[1]["Runtime sparse_mode"] == "0"
assert rows[1]["Runtime actual_seq_lengths_kv_values"] == "4093,4093"
def test_decode_seq_q_differs_from_seq_kv(self):
"""Decode: actual_seq_lengths (Q=1 per seq) != actual_seq_lengths_kv (KV=4500)."""
csv_row = self._make_csv_row(
'"1,16,1,512;1170,1,128,512;1170,1,128,512;;;;1;;;;;;;;1,512;;;;;;;;;;1,16,1,64;1170,1,128,64;;;;;"',
)
rec = {
"query_shape": [1, 16, 1, 512],
"key_shape": [1170, 1, 128, 512],
"value_shape": [1170, 1, 128, 512],
"actual_seq_lengths": None,
"actual_seq_lengths_kv": [4500],
"atten_mask_shape": None,
"block_table_shape": [1, 512],
"block_table_valid_blocks": [36],
"num_heads": 16,
"num_key_value_heads": 1,
"sparse_mode": 0,
"input_layout": "BNSD_NBSD",
"block_size": 128,
}
jsonl_index = {build_jsonl_key(rec): [build_runtime_payload(rec)]}
rows, matched, total = backfill([csv_row], jsonl_index, "test")
assert matched == 1
assert total == 1
assert rows[0]["Runtime actual_seq_lengths_values"] == ""
assert rows[0]["Runtime actual_seq_lengths_kv_values"] == "4500"
assert rows[0]["Runtime avg_seq_len"] == "4500.000000"
assert rows[0]["Runtime num_key_value_heads"] == "1"
assert rows[0]["Runtime block_table_valid_blocks"] == "36"
class TestLoadJsonl:
def test_dedup_identical_records(self):
"""Duplicate JSONL lines with same seq values produce one payload."""
records = [
{
"query_shape": [100, 16, 128],
"key_shape": [100, 16, 128],
"value_shape": [100, 16, 128],
"actual_seq_lengths": [100],
"actual_seq_lengths_kv": [100],
"atten_mask_shape": None,
"block_table_shape": None,
"block_table_valid_blocks": None,
"num_heads": 16,
"num_key_value_heads": 16,
"sparse_mode": 3,
"input_layout": "TND",
"block_size": 0,
},
] * 5
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
for r in records:
f.write(json.dumps(r) + "\n")
path = Path(f.name)
index = load_jsonl(path)
path.unlink()
assert len(index) == 1
payloads = next(iter(index.values()))
assert len(payloads) == 1
class TestEnsureFieldnames:
def test_adds_missing_columns(self):
from tools.perf_data_collection.fill_fia_runtime_metadata import (
ensure_fieldnames,
BACKFILL_COLUMNS,
)
result = ensure_fieldnames(["Input Shapes", "Duration(us)"])
for col in BACKFILL_COLUMNS:
assert col in result
assert "Input Shapes" in result
def test_no_duplication(self):
from tools.perf_data_collection.fill_fia_runtime_metadata import (
ensure_fieldnames,
RUNTIME_AVG_SEQ_LEN,
)
result = ensure_fieldnames([RUNTIME_AVG_SEQ_LEN])
assert result.count(RUNTIME_AVG_SEQ_LEN) == 1
class TestBuildArgparser:
def test_required_args(self):
from tools.perf_data_collection.fill_fia_runtime_metadata import build_argparser
parser = build_argparser()
with pytest.raises(SystemExit):
parser.parse_args([])
def test_parses_args(self):
from tools.perf_data_collection.fill_fia_runtime_metadata import build_argparser
parser = build_argparser()
args = parser.parse_args(
[
"--csv-path",
"fia.csv",
"--jsonl-path",
"runtime.jsonl",
]
)
assert args.csv_path == "fia.csv"
assert args.jsonl_path == "runtime.jsonl"
def test_parses_optional_args(self):
from tools.perf_data_collection.fill_fia_runtime_metadata import build_argparser
parser = build_argparser()
args = parser.parse_args(
[
"--csv-path",
"fia.csv",
"--jsonl-path",
"runtime.jsonl",
"--output-path",
"out.csv",
"--metadata-tag",
"custom_tag",
]
)
assert args.output_path == "out.csv"
assert args.metadata_tag == "custom_tag"
class TestMainCli:
def test_main_missing_csv_exits_nonzero(self, tmp_path):
result = run_module_main(
"tools.perf_data_collection.fill_fia_runtime_metadata",
[
"--csv-path",
str(tmp_path / "nonexistent.csv"),
"--jsonl-path",
str(tmp_path / "nonexistent.jsonl"),
],
)
assert result.returncode != 0
def test_main_success(self, tmp_path):
csv_path = tmp_path / "fia.csv"
csv_path.write_text(
"Input Shapes,Profiling Average Duration(us)\n"
'"4099,16,128;4099,16,128;4099,16,128;;2048,2048;;2;;;;;;;;;;;;;;;;;;4099,16,64;4099,16,64;;;;;",600.0\n'
)
jsonl_path = tmp_path / "runtime.jsonl"
jsonl_path.write_text(
json.dumps(
{
"query_shape": [4099, 16, 128],
"key_shape": [4099, 16, 128],
"value_shape": [4099, 16, 128],
"actual_seq_lengths": [4099],
"actual_seq_lengths_kv": [4099],
"atten_mask_shape": [2048, 2048],
"block_table_shape": None,
"block_table_valid_blocks": None,
"num_heads": 16,
"num_key_value_heads": 16,
"sparse_mode": 3,
"input_layout": "TND",
"block_size": 0,
}
)
+ "\n"
)
result = run_module_main(
"tools.perf_data_collection.fill_fia_runtime_metadata",
[
"--csv-path",
str(csv_path),
"--jsonl-path",
str(jsonl_path),
"--output-path",
str(tmp_path / "out.csv"),
],
)
assert result.returncode == 0