from __future__ import annotations
import sys
from itertools import product
from pathlib import Path
from typing import Any, Callable, Generator
from .evaluator import SafeExprEval, _parse_shape_expr
try:
from .model_configs import get_matmul_nk_pairs
from .generators import (
TheoryShapeRow,
generate_dispatch_ffn_combine_rows,
generate_fused_attention_rows,
generate_grouped_matmul_rows,
generate_split_qkv_rmsnorm_rope_rows,
)
from .utils import (
INPUT_SHAPES_COLUMN,
OUTPUT_SHAPES_COLUMN,
align_shape_slot_count,
build_generated_row,
collect_generated_rows,
parse_shape_text,
)
from .shape_grids import (
ATTN_BATCH_GRID,
ATTN_SEQ_GRID,
ELEM_HIDDEN_GRID,
ELEM_TOKENS_GRID,
HEAD_DIM_GRID,
HEADS_GRID,
KV_HEADS_GRID,
M_GRID,
MOE_TOKENS_GRID,
NK_GRID,
PAD_TOKENS_GRID,
)
except ImportError:
from .model_configs import get_matmul_nk_pairs
from generators import (
TheoryShapeRow,
generate_dispatch_ffn_combine_rows,
generate_fused_attention_rows,
generate_grouped_matmul_rows,
generate_split_qkv_rmsnorm_rope_rows,
)
from .utils import (
INPUT_SHAPES_COLUMN,
OUTPUT_SHAPES_COLUMN,
align_shape_slot_count,
build_generated_row,
collect_generated_rows,
parse_shape_text,
)
from .shape_grids import (
ATTN_BATCH_GRID,
ATTN_SEQ_GRID,
ELEM_HIDDEN_GRID,
ELEM_TOKENS_GRID,
HEAD_DIM_GRID,
HEADS_GRID,
KV_HEADS_GRID,
M_GRID,
MOE_TOKENS_GRID,
NK_GRID,
PAD_TOKENS_GRID,
)
_GRID_REGISTRY: dict[str, list[int]] = {
"M_GRID": M_GRID,
"NK_GRID": NK_GRID,
"ELEM_TOKENS_GRID": ELEM_TOKENS_GRID,
"ELEM_HIDDEN_GRID": ELEM_HIDDEN_GRID,
"HEADS_GRID": HEADS_GRID,
"HEAD_DIM_GRID": HEAD_DIM_GRID,
"KV_HEADS_GRID": KV_HEADS_GRID,
"ATTN_SEQ_GRID": ATTN_SEQ_GRID,
"ATTN_BATCH_GRID": ATTN_BATCH_GRID,
"MOE_TOKENS_GRID": MOE_TOKENS_GRID,
"PAD_TOKENS_GRID": PAD_TOKENS_GRID,
}
def _resolve_grid(spec, registry: dict[str, list[int]] = _GRID_REGISTRY) -> list[int]:
if isinstance(spec, str):
return registry[spec]
if isinstance(spec, list):
return spec
raise ValueError(f"Invalid grid spec: {spec}")
def generate_from_template(
pattern: dict,
model_names: list[str] | None,
) -> Generator[TheoryShapeRow, None, None]:
iterators: dict[str, list[int]] = {}
for var_name, grid_spec in pattern.get("iterators", {}).items():
iterators[var_name] = _resolve_grid(grid_spec)
constants = {k: int(v) for k, v in pattern.get("constants", {}).items()}
constraints = pattern.get("constraints", [])
input_templates = pattern["inputs"]
output_templates = pattern["outputs"]
extra_values = {}
for source_key, csv_key in (
("input_dtypes", "Input Data Types"),
("input_formats", "Input Formats"),
("output_dtypes", "Output Data Types"),
("output_formats", "Output Formats"),
):
values = pattern.get(source_key)
if values:
extra_values[csv_key] = ";".join(str(value) for value in values)
use_nk = pattern.get("model_nk_pairs", False)
if use_nk and model_names:
nk_pairs = sorted(get_matmul_nk_pairs(model_names))
iterators.pop("N", None)
iterators.pop("K", None)
var_names = list(iterators.keys())
var_grids = [iterators[n] for n in var_names]
for vals in product(*var_grids):
base = dict(zip(var_names, vals))
base.update(constants)
evaluator = SafeExprEval(base)
for n, k in nk_pairs:
evaluator.vars.update({"N": n, "K": k})
inputs = [_parse_shape_expr(t, evaluator) for t in input_templates]
outputs = [_parse_shape_expr(t, evaluator) for t in output_templates]
yield TheoryShapeRow(inputs, outputs, extra_values=dict(extra_values))
return
if use_nk:
for var_name, grid_spec in pattern.get("fallback_iterators", {}).items():
iterators[var_name] = _resolve_grid(grid_spec)
var_names = list(iterators.keys())
var_grids = [iterators[n] for n in var_names]
for vals in product(*var_grids):
vd = dict(zip(var_names, vals))
vd.update(constants)
evaluator = SafeExprEval(vd)
if constraints and not all(evaluator.eval(c) for c in constraints):
continue
inputs = [_parse_shape_expr(t, evaluator) for t in input_templates]
outputs = [_parse_shape_expr(t, evaluator) for t in output_templates]
yield TheoryShapeRow(inputs, outputs, extra_values=dict(extra_values))
def resolve_theory_pattern_name(
kernel_type: str,
assignments: dict[str, str],
kernel_meta: dict[str, Any],
) -> str | None:
pattern_name = assignments.get(kernel_type)
if not pattern_name:
parent = kernel_meta.get("alternates_of")
if parent:
pattern_name = assignments.get(parent)
if not pattern_name and kernel_meta.get("query_mode") == "elementwise":
pattern_name = "elementwise_binary"
return pattern_name
def resolve_complex_generator(
func_name: str,
model_names: list[str] | None,
complex_generators: dict[str, Callable],
signature_cache: dict[Callable, bool],
) -> Generator[TheoryShapeRow, None, None] | None:
func = complex_generators.get(func_name)
if not func:
return None
if func not in signature_cache:
import inspect
signature_cache[func] = "model_names" in inspect.signature(func).parameters
if signature_cache[func]:
return func(model_names)
return func()
def get_theory_generator(
kernel_type: str,
model_names: list[str] | None,
config: dict,
op_meta: dict[str, dict],
*,
complex_generators: dict[str, Callable],
signature_cache: dict[Callable, bool],
) -> Generator[TheoryShapeRow, None, None] | None:
assignments = config.get("assignments", {})
patterns = config.get("patterns", {})
km = op_meta.get(kernel_type, {})
if km.get("zero_cost") or km.get("composite") or km.get("communication"):
return None
pattern_name = resolve_theory_pattern_name(kernel_type, assignments, km)
if not pattern_name:
return None
pattern = patterns.get(pattern_name)
if not pattern:
return None
func_name = pattern.get("generator_function")
if func_name:
return resolve_complex_generator(func_name, model_names, complex_generators, signature_cache)
return generate_from_template(pattern, model_names)
def default_complex_generators() -> dict[str, Callable]:
return {
"_theory_grouped_matmul": generate_grouped_matmul_rows,
"_theory_dfc": generate_dispatch_ffn_combine_rows,
"_theory_fused_attention": generate_fused_attention_rows,
"_theory_split_qkv_rmsnorm_rope": generate_split_qkv_rmsnorm_rope_rows,
}
def get_default_theory_generator(
kernel_type: str,
model_names: list[str] | None,
config: dict,
op_meta: dict[str, dict],
) -> Generator[TheoryShapeRow, None, None] | None:
return get_theory_generator(
kernel_type,
model_names,
config,
op_meta,
complex_generators=default_complex_generators(),
signature_cache={},
)
def collect_theory_generated_rows(
headers: list[str],
source_rows: list[dict[str, str]],
generated: Generator[TheoryShapeRow, None, None],
*,
csv_path: Path,
file_index: int,
total_files: int,
max_rows: int | None,
rng,
max_hbm_bytes: int | None = None,
) -> list[dict[str, str]]:
template_row = source_rows[0]
if max_rows is not None:
all_rows = list(generated)
if len(all_rows) > max_rows:
all_rows = rng.sample(all_rows, max_rows) if rng is not None else all_rows[:max_rows]
grid_iter = iter(all_rows)
else:
grid_iter = generated
template_input_text = template_row.get(INPUT_SHAPES_COLUMN, "")
template_inputs = parse_shape_text(template_input_text)
template_output_text = template_row.get(OUTPUT_SHAPES_COLUMN, "")
template_outputs = parse_shape_text(template_output_text) if template_output_text else []
memory_filter_active = max_hbm_bytes is not None and max_hbm_bytes > 0
input_dtypes: list[str] = []
output_dtypes: list[str] = []
if memory_filter_active:
try:
from ..memory_estimator import (
exceeds_memory_budget,
format_bytes,
parse_dtype_from_template_row,
)
except ImportError:
from memory_estimator import (
exceeds_memory_budget,
format_bytes,
parse_dtype_from_template_row,
)
input_dtypes, output_dtypes = parse_dtype_from_template_row(template_row)
skipped_count = 0
total_count = 0
def split_metadata_slots(value: str) -> list[str]:
raw = str(value or "").strip().strip('"')
if not raw:
return []
return [part.strip() for part in raw.split(";")]
def is_absent_slot(value: str) -> bool:
return value.strip().upper() in {"", "NULL", "NONE", "UNDEFINED", "DT_UNDEFINED"}
def clear_absent_shape_slots(
shapes: list[tuple[int, ...]],
dtype_cell: str,
format_cell: str,
) -> list[tuple[int, ...]]:
dtypes = split_metadata_slots(dtype_cell)
formats = split_metadata_slots(format_cell)
sanitized = list(shapes)
for index in range(len(sanitized)):
dtype_absent = index < len(dtypes) and is_absent_slot(dtypes[index])
format_absent = index < len(formats) and is_absent_slot(formats[index])
if dtype_absent or format_absent:
sanitized[index] = ()
return sanitized
def build_theory_generated_row(row: TheoryShapeRow) -> dict[str, str] | None:
nonlocal skipped_count, total_count
total_count += 1
input_shapes = align_shape_slot_count(template_inputs, row.input_shapes)
output_shapes = align_shape_slot_count(template_outputs, row.output_shapes)
input_dtype_cell = row.extra_values.get(
"Input Data Types",
template_row.get("Input Data Types", ""),
)
input_format_cell = row.extra_values.get(
"Input Formats",
template_row.get("Input Formats", ""),
)
output_dtype_cell = row.extra_values.get(
"Output Data Types",
template_row.get("Output Data Types", ""),
)
output_format_cell = row.extra_values.get(
"Output Formats",
template_row.get("Output Formats", ""),
)
input_shapes = clear_absent_shape_slots(input_shapes, input_dtype_cell, input_format_cell)
output_shapes = clear_absent_shape_slots(output_shapes, output_dtype_cell, output_format_cell)
if memory_filter_active:
exceeded, est_bytes = exceeds_memory_budget(
input_shapes,
output_shapes,
input_dtypes,
output_dtypes,
max_bytes=max_hbm_bytes,
)
if exceeded:
skipped_count += 1
return None
return build_generated_row(
headers,
template_row,
input_shapes,
output_shapes,
extra_values=row.extra_values,
)
filtered_rows = collect_generated_rows(
grid_iter,
build_theory_generated_row,
file_index=file_index,
total_files=total_files,
csv_path=csv_path,
total_rows=max_rows,
progress_interval=500,
)
if memory_filter_active and skipped_count > 0:
budget_str = format_bytes(max_hbm_bytes)
print(
f"\n[MEMORY] {csv_path.name}: filtered {skipped_count}/{total_count} rows "
f"exceeding {budget_str} HBM budget",
file=sys.stderr,
)
return filtered_rows