from typing import Any, Dict, Iterable, List, Optional
from .context import AdaptationContext
from .hints import HintLedger, UserHint
from .insight import ObservedKernel, RawInsightSummary
# Known profiling kernel name -> (expected op name, confidence label).
# _expected_op_from_kernel looks a kernel's normalized_name up here when no
# user hint supplies a mapping for that kernel.
_RAW_TO_TC_OP = {
"FusedInferAttentionScore": ("tensor_cast.attention.default", "medium"),
"MoeGatingTopK": ("tensor_cast.moe_gating_top_k_softmax.default", "medium"),
"RmsNorm": ("tensor_cast.rms_norm.default", "low"),
}
# Kernel names that have no entry in _RAW_TO_TC_OP but are still reported:
# _expected_op_from_kernel emits them as low-confidence "profiling.<name>"
# placeholder entries instead of dropping them.
_PROFILING_ONLY_KERNELS = {
"DispatchFFNCombine",
"QuantBatchMatmulV3",
"DynamicQuant",
"MatMulV2",
"AddRmsNormBias",
}
def _hinted_mapping(hints: Iterable[UserHint], profiling_name: str) -> Optional[Dict[str, Any]]:
    """Build an expected-op entry from the first usable hint for a kernel.

    A hint is usable when its ``kind`` is ``"op_mapping_hint"``, its
    ``data["profiling_op"]`` equals ``profiling_name`` and its
    ``data["tc_op"]`` is truthy. Hints are examined in iteration order and
    the first usable one wins.

    Args:
        hints: Hint objects exposing ``kind``, ``data`` and ``confidence``.
        profiling_name: Kernel name to look up among the hints.

    Returns:
        A new dict with ``name``, ``confidence`` and ``source`` keys, or
        ``None`` when no hint is usable.
    """
    # Lazily narrow down to mapping hints that target this kernel, pairing
    # each with the op name it proposes.
    relevant = (
        (entry.data.get("tc_op"), entry)
        for entry in hints
        if entry.kind == "op_mapping_hint"
        and entry.data.get("profiling_op") == profiling_name
    )
    for target_op, entry in relevant:
        # A hint without a target op carries no mapping; keep looking.
        if target_op:
            return {
                "name": target_op,
                "confidence": entry.confidence,
                "source": f"user_hint:{profiling_name}",
            }
    return None
def _expected_op_from_kernel(kernel: ObservedKernel, hints: Iterable[UserHint]) -> Optional[Dict[str, Any]]:
    """Translate one observed kernel into an expected-op entry.

    Resolution order:
      1. a user hint matching the kernel's ``normalized_name``;
      2. the built-in ``_RAW_TO_TC_OP`` table;
      3. a ``profiling.<name>`` placeholder with ``"low"`` confidence for
         names listed in ``_PROFILING_ONLY_KERNELS``.

    Args:
        kernel: Observed kernel; ``normalized_name`` and ``occurrences``
            are read from it.
        hints: User hints forwarded to ``_hinted_mapping``.

    Returns:
        A dict with ``name``, ``count``, ``confidence`` and ``source``
        keys, or ``None`` when none of the three sources knows the kernel.
    """
    kernel_name = kernel.normalized_name

    from_hint = _hinted_mapping(hints, kernel_name)
    if from_hint is not None:
        # The hint entry has no count yet; attach the observed one.
        from_hint["count"] = kernel.occurrences
        return from_hint

    if kernel_name in _RAW_TO_TC_OP:
        mapped_name, mapped_confidence = _RAW_TO_TC_OP[kernel_name]
    elif kernel_name in _PROFILING_ONLY_KERNELS:
        mapped_name, mapped_confidence = f"profiling.{kernel_name}", "low"
    else:
        return None

    # Both table-driven branches share the same entry shape and provenance.
    return {
        "name": mapped_name,
        "count": kernel.occurrences,
        "confidence": mapped_confidence,
        "source": f"raw_insight:{kernel_name}",
    }
def build_evidence_draft(
    context: AdaptationContext,
    raw_insight: RawInsightSummary,
    hints: Optional[HintLedger] = None,
    case_name: Optional[str] = None,
    top_n: int = 20,
) -> Dict[str, Any]:
    """Assemble a single-case evidence draft from raw Insight profiling data.

    Args:
        context: Source of ``model_id``, ``raw_command`` and the normalized
            arguments used for the case input and the default case name.
        raw_insight: Profiling summary; ``top_kernels(top_n)`` and
            ``total_wall_duration_ms`` are read from it.
        hints: Optional ledger whose ``hints`` are consulted when mapping
            kernels to expected ops. ``None`` means no hints.
        case_name: Name for the generated case. When falsy, a name is
            derived from ``context`` via ``_default_case_name``.
        top_n: Forwarded to ``raw_insight.top_kernels``.

    Returns:
        A dict with ``version``, ``model`` and ``cases`` keys. ``cases``
        holds exactly one case containing the observed kernels, the
        expected total forward time (wall duration converted from
        milliseconds to seconds) and the de-duplicated major ops.
    """
    hint_items = [] if hints is None else hints.hints

    # Query the top kernels once and reuse the result: this avoids repeating
    # the work and guarantees that "major_ops" and "observed_kernels" are
    # derived from the very same kernel list.
    kernels = list(raw_insight.top_kernels(top_n))

    major_ops: List[Dict[str, Any]] = []
    seen_ops = set()
    for kernel in kernels:
        expected = _expected_op_from_kernel(kernel, hint_items)
        if expected is None:
            continue
        # An op is listed once per (name, source) pair; later duplicates
        # are dropped.
        key = (expected["name"], expected.get("source"))
        if key in seen_ops:
            continue
        seen_ops.add(key)
        major_ops.append(expected)

    generated_case_name = case_name or _default_case_name(context)
    return {
        "version": 1,
        "model": {
            "model_id": context.model_id,
            "raw_command": context.raw_command,
        },
        "cases": [
            {
                "name": generated_case_name,
                "input": _evidence_input_from_context(context),
                "observed_kernels": [kernel.to_dict() for kernel in kernels],
                "expected": {
                    "total_forward": {
                        "time_s": raw_insight.total_wall_duration_ms / 1000.0,
                        "rel_tolerance": 0.2,
                        "source": "raw_insight:Totals.wall_duration_ms",
                    },
                    "major_ops": major_ops,
                },
                "notes": [
                    "Generated from raw Insight profiling and optional user hints.",
                    "raw Insight Totals wall duration is used as expected total_forward time.",
                    "Low-confidence profiling.* entries are placeholders for fused or profiling-only kernels.",
                ],
            }
        ],
    }
def _default_case_name(context: AdaptationContext) -> str:
    """Derive a case name of the form ``<model>-<phase>[-<quant>]``.

    ``<model>`` is the last ``/``-separated component of
    ``context.model_id`` (trailing slashes ignored), lower-cased with
    underscores turned into hyphens. ``<phase>`` is ``decode`` when the
    normalized args have a truthy ``decode`` entry, otherwise ``prefill``.
    ``<quant>`` is the lower-cased ``quantize_linear_action`` value and is
    only appended when that value is truthy.
    """
    last_component = context.model_id.rstrip("/").split("/")[-1]
    args = context.normalized_args

    parts = [last_component.lower().replace("_", "-")]
    parts.append("decode" if args.get("decode") else "prefill")

    quant_action = args.get("quantize_linear_action")
    if quant_action:
        parts.append(str(quant_action).lower())
    return "-".join(parts)
def _evidence_input_from_context(context: AdaptationContext) -> Dict[str, Any]:
    """Return a copy of the normalized args extended with alias keys.

    For each (argument name, alias name) pair below, the alias key is added
    with the argument's value when the argument is present and the alias is
    not already set; the original key is kept. ``model_id`` is filled in
    from ``context.model_id`` unless the args already contain it. The
    mapping stored on ``context`` is not modified.
    """
    payload = dict(context.normalized_args)
    alias_pairs = (
        ("compile", "do_compile"),
        ("compile_allow_graph_break", "allow_graph_break"),
        ("num_devices", "world_size"),
        ("query_length", "query_len"),
    )
    for arg_name, alias_name in alias_pairs:
        if arg_name in payload:
            # setdefault leaves an explicitly provided alias value untouched.
            payload.setdefault(alias_name, payload[arg_name])
    payload.setdefault("model_id", context.model_id)
    return payload