import dataclasses
import fnmatch
from typing import Any, Dict, List, Optional, Tuple
from tensor_cast.layers.mla import DeepseekSparseAttention
from tensor_cast.transformers.custom_model_registry import ModelProfile
from .inspect import ModelStructureFacts, ProfileCandidate
@dataclasses.dataclass(eq=True, frozen=True)
class AdapterRecipe:
    """Declarative description of a model-adapter recipe.

    Field names are grouped per pass (for example "MLA" or "MoE").

    NOTE: the dataclass is frozen, but the list/dict fields themselves remain
    mutable, and instances are not hashable because those fields are not.
    """

    # Identifier of the recipe.
    name: str
    # Names of the passes this recipe targets.
    target_passes: List[str]
    # Pass name -> field names the recipe lists as required.
    required_fields: Dict[str, List[str]]
    # Pass name -> field names the recipe lists as optional (fresh empty dict by default).
    optional_fields: Dict[str, List[str]] = dataclasses.field(default_factory=dict)
    # Free-form human-readable summary.
    description: str = ""
@dataclasses.dataclass(eq=True, frozen=True)
class RecipeProfileHint:
    """Extra ModelProfile settings attached to a named recipe.

    A hint applies to a candidate only when :meth:`matches` accepts it. Its
    non-None override fields are then used when the profile is materialized.
    """

    # Recipe name this hint belongs to (compared with the candidate's recipe value).
    recipe_name: str
    # Case-sensitive fnmatch patterns for structure.model_type; "*" accepts anything.
    model_type_patterns: Tuple[str, ...] = ("*",)
    # Case-sensitive fnmatch patterns for the candidate's MLA module name; empty = no check.
    mla_module_name_patterns: Tuple[str, ...] = ()
    # Class to use as mla_module_class_type when the hint applies; None leaves it unset.
    mla_module_class_type: Optional[type] = None
    # Override for moe_gate_returns_raw_logits; None keeps the candidate's value.
    moe_gate_returns_raw_logits: Optional[bool] = None
    # Human-readable provenance of the hint.
    source: str = "builtin recipe hint"

    def matches(self, structure: ModelStructureFacts, candidate: ProfileCandidate) -> bool:
        """Return True when this hint applies to *structure* and *candidate*.

        A missing model type or MLA module name is matched as the empty string.
        The MLA module name is only consulted when name patterns are configured.
        """

        def _hits(text: str, patterns: Tuple[str, ...]) -> bool:
            return any(fnmatch.fnmatchcase(text, pat) for pat in patterns)

        type_name = structure.model_type or ""
        if not _hits(type_name, self.model_type_patterns):
            return False
        if self.mla_module_name_patterns:
            mla_field = candidate.mla_module_name
            mla_name = "" if not mla_field else mla_field.value
            return _hits(mla_name, self.mla_module_name_patterns)
        return True
# Recipe for DeepSeek-style models: MLA attention plus routed MoE.
# Field names are grouped by the pass ("MLA" / "MoE") they belong to.
DEEPSEEK_LIKE_MLA_MOE_RECIPE = AdapterRecipe(
    name="deepseek_like_mla_moe",
    description="DeepSeek-like MLA plus standard routed MoE structure.",
    target_passes=["MLA", "MoE", "Shard"],
    required_fields=dict(
        MLA=["kv_a_proj_with_mqa", "kv_b_proj", "o_proj", "kv_a_layernorm"],
        MoE=["gate", "experts"],
    ),
    optional_fields=dict(
        MLA=["q_proj", "q_a_proj", "q_b_proj", "q_a_layernorm"],
        MoE=["shared_experts", "shared_experts_gate", "top_k"],
    ),
)

# Builtin hints consulted when materializing a profile; selected by exact
# recipe-name equality plus RecipeProfileHint.matches.
# NOTE(review): the hint below uses recipe_name "deepseek_like_mla", while
# DEEPSEEK_LIKE_MLA_MOE_RECIPE.name is "deepseek_like_mla_moe" -- confirm which
# name the inspector assigns to candidates, otherwise this hint never fires.
# NOTE(review): "Deepseek*Attention" also covers non-sparse attention modules, and
# all of them get DeepseekSparseAttention as the class type -- confirm intended.
_RECIPE_PROFILE_HINTS: Tuple[RecipeProfileHint, ...] = (
    RecipeProfileHint(
        recipe_name="deepseek_like_mla",
        source="deepseek-like sparse MLA recipe",
        mla_module_class_type=DeepseekSparseAttention,
        model_type_patterns=("deepseek*", "glm_moe_dsa"),
        mla_module_name_patterns=("Deepseek*SparseAttention", "Deepseek*Attention", "GlmMoeDsaAttention"),
    ),
)
def _candidate_value(candidate: ProfileCandidate, field_name: str, default: Any = None) -> Any:
    """Return ``candidate.<field_name>.value``, or *default* when that field is None.

    The attribute lookup is strict: a *field_name* that *candidate* lacks
    raises AttributeError rather than falling back to *default*.
    """
    entry = getattr(candidate, field_name)
    if entry is not None:
        return entry.value
    return default
def _matching_recipe_hints(
    structure: ModelStructureFacts,
    candidate: ProfileCandidate,
) -> Tuple[RecipeProfileHint, ...]:
    """Collect the builtin hints that apply to *candidate*.

    A hint is selected when its ``recipe_name`` equals the candidate's
    ``recipe`` value and ``hint.matches(structure, candidate)`` is True.
    A candidate without a recipe gets no hints. Order follows
    ``_RECIPE_PROFILE_HINTS``.
    """
    wanted = _candidate_value(candidate, "recipe")
    if wanted is None:
        return ()
    selected = []
    for hint in _RECIPE_PROFILE_HINTS:
        # Cheap name check first; matches() is only evaluated for same-recipe hints.
        if hint.recipe_name != wanted:
            continue
        if hint.matches(structure, candidate):
            selected.append(hint)
    return tuple(selected)
def materialize_profile_candidate(
    structure: ModelStructureFacts,
    candidate: ProfileCandidate,
) -> ModelProfile:
    """Build a ModelProfile from *candidate*, applying any matching recipe hints.

    Candidate field values are copied into the profile. Absent fields fall back
    to the defaults written below; ``model_type`` falls back to
    ``structure.model_type``.

    Matching hints then override two settings, and a later hint wins over an
    earlier one:

    * ``mla_module_class_type`` -- passed to ModelProfile only if some hint sets it;
    * ``moe_gate_returns_raw_logits`` -- otherwise the candidate's value (default False).
    """
    hints = _matching_recipe_hints(structure, candidate)

    raw_logits = _candidate_value(candidate, "moe_gate_returns_raw_logits", False)
    hinted_raw_logits = [
        h.moe_gate_returns_raw_logits for h in hints if h.moe_gate_returns_raw_logits is not None
    ]
    if hinted_raw_logits:
        raw_logits = hinted_raw_logits[-1]
    hinted_classes = [h.mla_module_class_type for h in hints if h.mla_module_class_type is not None]

    def pick(name: str, default: Any = None) -> Any:
        return _candidate_value(candidate, name, default)

    experts_key = pick("moe_num_experts_key")
    if experts_key is None:
        experts_key = "num_experts"

    profile_kwargs = dict(
        model_type=pick("model_type", structure.model_type),
        moe_module_name=pick("moe_module_name"),
        moe_num_experts_key=experts_key,
        moe_field_names_override=pick("moe_field_names_override"),
        moe_gate_returns_raw_logits=raw_logits,
        mtp_block_module_name=pick("mtp_block_module_name"),
        mla_module_name=pick("mla_module_name"),
        mla_field_names_override=pick("mla_field_names_override"),
        model_family=pick("model_family"),
        visual_module_path=pick("visual_module_path"),
        language_module_path=pick("language_module_path"),
        visual_layers_module_path=pick("visual_layers_module_path"),
        visual_layers_path_str=pick("visual_layers_path_str"),
        language_layers_path_str=pick("language_layers_path_str"),
        visual_merger_linear_mapping=pick("visual_merger_linear_mapping", {}),
        visual_mlp_linear_mapping=pick("visual_mlp_linear_mapping", {}),
    )
    # Leave ModelProfile's own default in place unless a hint supplied a class.
    if hinted_classes:
        profile_kwargs["mla_module_class_type"] = hinted_classes[-1]
    return ModelProfile(**profile_kwargs)
def _qualified_type_name(cls: Optional[type]) -> Optional[str]:
    """Return the dotted ``module.QualName`` path of *cls*, or None when *cls* is None."""
    if cls is None:
        return None
    # __qualname__ (unlike __name__) keeps enclosing class names for nested classes,
    # so the resulting dotted path stays resolvable; for top-level classes both agree.
    return f"{cls.__module__}.{cls.__qualname__}"


def materialization_hints_to_dict(
    structure: ModelStructureFacts,
    candidate: ProfileCandidate,
) -> List[Dict[str, Any]]:
    """Describe the recipe hints that apply to *candidate* as plain dicts.

    Each entry has the keys ``recipe_name``, ``source``,
    ``mla_module_class_type`` (a dotted class path string, or None) and
    ``moe_gate_returns_raw_logits``. Entries follow the order returned by
    ``_matching_recipe_hints``.
    """
    return [
        {
            "recipe_name": hint.recipe_name,
            "source": hint.source,
            "mla_module_class_type": _qualified_type_name(hint.mla_module_class_type),
            "moe_gate_returns_raw_logits": hint.moe_gate_returns_raw_logits,
        }
        for hint in _matching_recipe_hints(structure, candidate)
    ]
@dataclasses.dataclass(eq=True, frozen=True)
class SkillTask:
    """Immutable task record: what to do, why, with which inputs, and how to verify it.

    NOTE: the instance is frozen, but the dict/list fields themselves remain mutable.
    """

    # Short summary of the task.
    title: str
    # Why the task is needed.
    reason: str
    # Arbitrary task inputs, stored as given.
    inputs: Dict[str, Any]
    # Artifacts the task is expected to produce.
    expected_outputs: List[str]
    # Steps to run to check the result.
    verification_steps: List[str]
    # Optional recipe name associated with the task.
    recipe: Optional[str] = None
def build_unsupported_semantics_task(reason: str, inputs: Dict[str, Any], recipe: Optional[str] = None) -> SkillTask:
    """Create the SkillTask for implementing model-adapter semantics that are not yet supported.

    The title, expected outputs and verification steps are fixed text; only
    *reason*, *inputs* and *recipe* vary per call. *inputs* is stored as-is
    (not copied). New lists are created on every call.
    """
    deliverables = [
        "candidate ModelProfile diff",
        "candidate wrapper or performance-model implementation",
        "tests or evidence verifier assertions",
    ]
    checks = [
        "run patch dry-run and inspect PatchReport",
        "run actual summary collection for evidence case",
        "run EvidenceVerifier and require deterministic PASS",
    ]
    return SkillTask(
        title="Implement unsupported TensorCast model adapter semantics",
        reason=reason,
        inputs=inputs,
        expected_outputs=deliverables,
        verification_steps=checks,
        recipe=recipe,
    )