from __future__ import annotations
import argparse
import random
from pathlib import Path
from typing import Iterable
from .config import load_op_mapping_metadata, load_shape_grid_config
from .generators.fused_attention import FIA_RUNTIME_COLUMNS
from .theory_router import collect_theory_generated_rows, get_default_theory_generator
from .utils import clear_progress, process_csv_with_generated_rows
def process_theory_csv(
csv_path: Path,
model_names: list[str] | None,
config: dict,
op_meta: dict[str, dict],
*,
max_rows: int | None = None,
rng: random.Random | None = None,
file_index: int,
total_files: int,
max_hbm_bytes: int | None = None,
) -> int | None:
kernel_type = csv_path.stem
gen = get_default_theory_generator(kernel_type, model_names, config, op_meta)
if gen is None:
return None
def build_theory_rows(headers: list[str], source_rows: list[dict[str, str]]) -> list[dict[str, str]]:
return collect_theory_generated_rows(
headers,
source_rows,
gen,
csv_path=csv_path,
file_index=file_index,
total_files=total_files,
max_rows=max_rows,
rng=rng,
max_hbm_bytes=max_hbm_bytes,
)
return process_csv_with_generated_rows(
csv_path,
require_rows=False,
extra_headers=FIA_RUNTIME_COLUMNS if kernel_type == "FusedInferAttentionScore" else None,
generated_rows_builder=build_theory_rows,
)
def iter_csv_files(data_dir: Path) -> Iterable[Path]:
return sorted(
path for path in data_dir.rglob("*.csv") if f".tmp{path.suffix}" not in path.name
)
def load_csv_files(data_dir: Path) -> list[Path]:
if not data_dir.is_dir():
raise ValueError(f"Data directory does not exist: {data_dir}")
csv_files = list(iter_csv_files(data_dir))
if not csv_files:
raise ValueError(f"No CSV files found under: {data_dir}")
return csv_files
def run_theory_mode(args: argparse.Namespace, data_dir: Path, csv_files: list[Path]) -> tuple[int, list[Path]]:
total_files = len(csv_files)
total_appended_rows = 0
skipped_files: list[Path] = []
model_names = (
[m.strip() for m in args.target_models.split(",") if m.strip()]
if args.target_models
else None
)
CURRENT_DIR = Path(__file__).resolve().parent
config_path = CURRENT_DIR / "config.yaml"
config = load_shape_grid_config(config_path)
op_meta = load_op_mapping_metadata(data_dir)
max_rows = args.rows if args.rows > 0 else None
rng = random.Random(args.seed) if max_rows else None
max_hbm_gb = getattr(args, 'max_hbm_gb', 32.0)
max_hbm_bytes = int(max_hbm_gb * 1024 ** 3) if max_hbm_gb and max_hbm_gb > 0 else None
print(f"Mode: theory | Target models: {model_names or 'ALL (full grid)'}")
print(f"Config: {config_path.name} | op_mapping: {bool(op_meta)} | max_rows/csv: {max_rows or 'unlimited'}")
if max_hbm_bytes:
print(f"HBM budget: {max_hbm_gb:.1f} GiB per shape row")
for file_index, csv_path in enumerate(csv_files, start=1):
appended_rows = process_theory_csv(
csv_path=csv_path,
model_names=model_names,
config=config,
op_meta=op_meta,
max_rows=max_rows,
rng=rng,
file_index=file_index,
total_files=total_files,
max_hbm_bytes=max_hbm_bytes,
)
if appended_rows is None or appended_rows == 0:
skipped_files.append(csv_path)
continue
total_appended_rows += appended_rows
return total_appended_rows, skipped_files