import dataclasses
from typing import Any, Dict, List, Optional, Tuple, Union
from tensor_cast.layers import COLWISE_LINEAR, ROWWISE_LINEAR
from tensor_cast.model_config import MlaFieldNames, MoEFieldNames
from tensor_cast.transformers.custom_model_registry import get_model_profile
@dataclasses.dataclass(frozen=True)
class ModuleFacts:
    """Immutable snapshot of one submodule, as assembled by ``_module_facts``."""

    # Path string handed to ``_module_facts`` (the name yielded by ``named_modules``).
    path: str
    # ``type(module).__name__`` of the inspected module.
    class_name: str
    # Sorted names collected by ``_module_field_names``: instance attributes plus
    # the keys of ``_modules``, ``_parameters`` and ``_buffers`` when present.
    fields: Tuple[str, ...]
    # Shape of each direct (``recurse=False``) parameter, keyed by parameter name.
    parameter_shapes: Dict[str, Tuple[int, ...]] = dataclasses.field(default_factory=dict)
@dataclasses.dataclass(frozen=True)
class ModelStructureFacts:
    """Raw observations about a model, as gathered by ``inspect_model_structure``."""

    # The following six values are read from the config with ``getattr(..., None)``.
    model_type: Optional[str]
    num_hidden_layers: Optional[int]
    hidden_size: Optional[int]
    num_attention_heads: Optional[int]
    num_key_value_heads: Optional[int]
    intermediate_size: Optional[int]
    # Output of ``_collect_expert_fields``: dotted config key ->
    # {"profile_key": ..., "value": ...}.
    expert_fields: Dict[str, Any]
    # Modules whose leaf name contains "attn" or whose class name contains "attention".
    attention_like_modules: Tuple[ModuleFacts, ...]
    # Modules accepted by ``_is_moe_like_module`` ("experts" plus "gate" or "router").
    moe_like_modules: Tuple[ModuleFacts, ...]
    # Modules named/classed like an MLP or exposing gate_proj/up_proj/down_proj.
    mlp_like_modules: Tuple[ModuleFacts, ...]
    # Candidate attribute paths that resolved to a non-None object on the model root.
    visual_module_paths: Tuple[str, ...] = ()
    language_module_paths: Tuple[str, ...] = ()
    # Layer-container paths found beneath the module paths above.
    visual_layers_path_candidates: Tuple[str, ...] = ()
    language_layers_path_candidates: Tuple[str, ...] = ()
    # Wildcarded module path -> linear parallel kind, from ``_collect_visual_linear_mappings``.
    visual_merger_linear_mapping: Dict[str, str] = dataclasses.field(default_factory=dict)
    visual_mlp_linear_mapping: Dict[str, str] = dataclasses.field(default_factory=dict)
    # Recipe labels appended during inspection, e.g. "deepseek_like_mla",
    # "standard_moe", "visual_language".
    known_recipe_matches: Tuple[str, ...] = ()
@dataclasses.dataclass(frozen=True)
class CandidateField:
    """A proposed profile value together with where it came from and how much to trust it."""

    # The proposed value itself (string, bool, dict, list ... depending on the field).
    value: Any
    # Free-text description of the evidence, e.g. "hf_config.model_type" or "module_tree_scan".
    source: str
    # Confidence label; this file uses "high", "medium" and "low".
    confidence: str = "medium"
@dataclasses.dataclass(frozen=True)
class ProfileCandidate:
    """Suggested profile settings; every entry is optional and None means "no suggestion"."""

    # Taken from the config's ``model_type`` when it is set.
    model_type: Optional[CandidateField] = None
    # MoE-related suggestions (module class name, config key holding the expert
    # count, attribute-name overrides, gate output convention).
    moe_module_name: Optional[CandidateField] = None
    moe_num_experts_key: Optional[CandidateField] = None
    moe_field_names_override: Optional[CandidateField] = None
    moe_gate_returns_raw_logits: Optional[CandidateField] = None
    # MLA attention suggestions (module class name and attribute-name overrides).
    mla_module_name: Optional[CandidateField] = None
    mla_field_names_override: Optional[CandidateField] = None
    # Not populated by ``inspect_model_structure`` in the code visible here.
    mtp_block_module_name: Optional[CandidateField] = None
    # Result of ``_infer_model_family`` when it returns a non-empty value.
    model_family: Optional[CandidateField] = None
    # First entries of the scanned visual / language path tuples.
    visual_module_path: Optional[CandidateField] = None
    language_module_path: Optional[CandidateField] = None
    # Both visual layer fields are filled from the same first layer-path candidate.
    visual_layers_module_path: Optional[CandidateField] = None
    visual_layers_path_str: Optional[CandidateField] = None
    language_layers_path_str: Optional[CandidateField] = None
    # Wildcarded path -> parallel kind mappings found under the visual modules.
    visual_merger_linear_mapping: Optional[CandidateField] = None
    visual_mlp_linear_mapping: Optional[CandidateField] = None
    # First entry of ``ModelStructureFacts.known_recipe_matches``.
    recipe: Optional[CandidateField] = None
def _module_field_names(module: Any) -> Tuple[str, ...]:
    """Return the sorted, de-duplicated names visible on ``module``.

    Combines the instance ``__dict__`` keys with the keys of the ``_modules``,
    ``_parameters`` and ``_buffers`` registries when the module has them.
    """
    names = set(vars(module))
    for registry_attr in ("_modules", "_parameters", "_buffers"):
        names.update(getattr(module, registry_attr, {}).keys())
    return tuple(sorted(names))
def _module_facts(path: str, module: Any) -> ModuleFacts:
    """Build a ``ModuleFacts`` record for ``module`` located at ``path``.

    Parameter shapes cover only the module's direct parameters
    (``recurse=False``); objects without ``named_parameters`` get an empty map.
    """
    if hasattr(module, "named_parameters"):
        shapes = {
            param_name: tuple(param.shape)
            for param_name, param in module.named_parameters(recurse=False)
        }
    else:
        shapes = {}
    class_name = type(module).__name__
    field_names = _module_field_names(module)
    return ModuleFacts(path=path, class_name=class_name, fields=field_names, parameter_shapes=shapes)
def _module_has_any(module: Any, names: Tuple[str, ...]) -> bool:
    """Return True when ``module`` exposes at least one attribute listed in ``names``."""
    for attr_name in names:
        if hasattr(module, attr_name):
            return True
    return False
def _facts_has_any(facts: ModuleFacts, names: Tuple[str, ...]) -> bool:
fields = set(facts.fields)
return any(name in fields for name in names)
def _facts_has_all(facts: ModuleFacts, names: Tuple[str, ...]) -> bool:
fields = set(facts.fields)
return all(name in fields for name in names)
def _is_mla_like_attention(facts: ModuleFacts) -> bool:
    """Decide whether an attention module's field names look like an MLA layout.

    All four compressed-KV fields must be present; in addition the module needs
    either a query path (``q_proj`` or the full ``q_a``/``q_b`` trio) or at
    least one of the latent-dimension attributes.
    """
    compressed_kv_fields = ("kv_a_proj_with_mqa", "kv_b_proj", "kv_a_layernorm", "o_proj")
    if not _facts_has_all(facts, compressed_kv_fields):
        return False

    # Query path: a plain projection or the low-rank a/b split with its norm.
    if _facts_has_any(facts, ("q_proj",)):
        return True
    if _facts_has_all(facts, ("q_a_proj", "q_b_proj", "q_a_layernorm")):
        return True

    latent_config_fields = (
        "q_lora_rank",
        "kv_lora_rank",
        "qk_nope_head_dim",
        "qk_rope_head_dim",
        "v_head_dim",
    )
    return _facts_has_any(facts, latent_config_fields)
def _is_moe_like_module(facts: ModuleFacts) -> bool:
fields = set(facts.fields)
has_expert_container = "experts" in fields
has_router = "gate" in fields or "router" in fields
return has_expert_container and has_router
def _pick_module_name(modules: Tuple[ModuleFacts, ...]) -> Optional[str]:
if not modules:
return None
class_counts: Dict[str, int] = {}
for module in modules:
class_counts[module.class_name] = class_counts.get(module.class_name, 0) + 1
return sorted(class_counts.items(), key=lambda item: (-item[1], item[0]))[0][0]
def _infer_override(base_fields: Any, facts: ModuleFacts) -> Dict[str, str]:
fields = set(facts.fields)
default_names = {
getattr(base_fields, field.name)
for field in dataclasses.fields(base_fields)
if getattr(base_fields, field.name) is not None
}
override = {}
for field in dataclasses.fields(base_fields):
default_name = getattr(base_fields, field.name)
if default_name is None or default_name in fields:
continue
candidates = [
candidate for candidate in _candidate_aliases(field.name, fields) if candidate not in default_names
]
if candidates:
override[field.name] = candidates[0]
return override
def _candidate_aliases(field_name: str, fields: set[str]) -> List[str]:
    """List attribute names in ``fields`` that plausibly correspond to ``field_name``.

    A name qualifies when it contains ``field_name``, when the two names
    contain one another after underscores are removed, or when it contains the
    singular form of ``field_name``. Results keep first-seen order (direct
    matches in sorted order, then singular matches) without duplicates.
    """
    normalized = field_name.replace("_", "")
    # Drop exactly one plural "s"; ``str.rstrip("s")`` would strip every
    # trailing "s" and turn e.g. "class" into "cla".
    singular = field_name[:-1] if field_name.endswith("s") else field_name
    ordered = sorted(fields)

    aliases = []
    for field in ordered:
        compact = field.replace("_", "")
        # ``compact`` is empty for underscore-only names; without the guard
        # ``"" in normalized`` would make such a name match every field.
        if field_name in field or normalized in compact or (compact and compact in normalized):
            aliases.append(field)
    if singular:
        aliases.extend(field for field in ordered if singular in field)
    return list(dict.fromkeys(aliases))
def _config_has_key(config: Any, key_path: Union[str, Tuple[str, ...]]) -> bool:
    """Report whether ``key_path`` resolves on ``config`` via attribute access.

    ``key_path`` is either a single attribute name or a tuple of names followed
    one after another. A ``None`` config never has any key.
    """
    if config is None:
        return False
    keys = (key_path,) if isinstance(key_path, str) else key_path
    node = config
    for key in keys:
        try:
            node = getattr(node, key)
        except AttributeError:
            return False
    return True
def _config_get(config: Any, key_path: Union[str, Tuple[str, ...]]) -> Any:
    """Fetch the value at ``key_path`` from ``config`` by chained attribute access.

    Raises ``AttributeError`` when any step of the path is missing.
    """
    keys = (key_path,) if isinstance(key_path, str) else key_path
    value = config
    for key in keys:
        value = getattr(value, key)
    return value
def _get_attr_path(root: Any, path: str) -> Any:
    """Resolve a dotted ``path`` starting at ``root``; return None if it cannot be followed.

    Numeric segments index into lists and tuples; every other segment is an
    attribute lookup. Empty segments, out-of-range indices and missing
    attributes all produce None.
    """
    node = root
    for segment in path.split("."):
        if not segment:
            return None
        if isinstance(node, (list, tuple)) and segment.isdigit():
            position = int(segment)
            if position >= len(node):
                return None
            node = node[position]
        elif hasattr(node, segment):
            node = getattr(node, segment)
        else:
            return None
    return node
def _has_attr_path(root: Any, path: str) -> bool:
    """Return True when ``path`` resolves from ``root`` to a value other than None."""
    resolved = _get_attr_path(root, path)
    return resolved is not None
def _existing_paths(root: Any, candidates: Tuple[str, ...]) -> Tuple[str, ...]:
    """Keep, in their original order, the candidate paths that resolve on ``root``."""
    found = []
    for candidate in candidates:
        if _has_attr_path(root, candidate):
            found.append(candidate)
    return tuple(found)
def _join_path(prefix: str, suffix: str) -> str:
    """Join two dotted-path fragments; an empty ``suffix`` leaves ``prefix`` untouched."""
    if not suffix:
        return prefix
    return ".".join((prefix, suffix))
def _layer_path_candidates(root: Any, module_paths: Tuple[str, ...]) -> Tuple[str, ...]:
    """Find layer-container paths underneath each of ``module_paths``.

    For every module path that resolves on ``root``, each known container
    suffix is probed; hits are returned as full dotted paths, in discovery
    order and without duplicates.
    """
    layer_suffixes = ("blocks", "layers", "encoder.layers", "model.layers")
    # A dict is used as an insertion-ordered set.
    found: Dict[str, None] = {}
    for base_path in module_paths:
        base_module = _get_attr_path(root, base_path)
        if base_module is None:
            continue
        for layer_suffix in layer_suffixes:
            if _has_attr_path(base_module, layer_suffix):
                found.setdefault(_join_path(base_path, layer_suffix))
    return tuple(found)
def _wildcard_numeric_path(path: str) -> str:
    """Replace every all-digit segment of a dotted ``path`` with ``*``."""
    segments = path.split(".")
    for position, segment in enumerate(segments):
        if segment.isdigit():
            segments[position] = "*"
    return ".".join(segments)
def _linear_parallel_kind(path: str) -> Optional[str]:
    """Classify a module path by its last segment.

    Returns ``COLWISE_LINEAR`` for up/gate-style leaves, ``ROWWISE_LINEAR`` for
    down-style leaves, and None for any other leaf name.
    """
    leaf_name = path.rpartition(".")[2]
    if leaf_name in ("linear_fc1", "fc1", "gate_proj", "up_proj"):
        return COLWISE_LINEAR
    if leaf_name in ("linear_fc2", "fc2", "down_proj"):
        return ROWWISE_LINEAR
    return None
def _collect_visual_linear_mappings(
    root: Any,
    visual_module_paths: Tuple[str, ...],
) -> Tuple[Dict[str, str], Dict[str, str]]:
    """Map weighted linear layers under the visual modules to their parallel kind.

    Returns ``(merger_mapping, mlp_mapping)``. Keys are module names with
    numeric segments wildcarded; values come from ``_linear_parallel_kind``.
    Modules outside ``visual_module_paths``, without a ``weight`` field, or
    with an unrecognised leaf name are ignored. Both mappings are empty when
    ``root`` has no ``named_modules``.
    """
    merger_mapping: Dict[str, str] = {}
    mlp_mapping: Dict[str, str] = {}
    if not hasattr(root, "named_modules"):
        return merger_mapping, mlp_mapping

    # Only strict descendants of a visual module match ("<path>." prefix).
    visual_prefixes = tuple(f"{path}." for path in visual_module_paths)
    for name, module in root.named_modules():
        if not name.startswith(visual_prefixes):
            continue
        if "weight" not in _module_field_names(module):
            continue
        kind = _linear_parallel_kind(name)
        if kind is None:
            continue
        pattern = _wildcard_numeric_path(name)
        # The MLP check takes precedence over the merger check.
        if ".mlp." in pattern:
            target = mlp_mapping
        elif ".merger." in pattern or ".deepstack_merger_list." in pattern:
            target = merger_mapping
        else:
            continue
        target[pattern] = kind
    return merger_mapping, mlp_mapping
def _infer_model_family(model_type: Optional[str], has_visual: bool) -> Optional[str]:
    """Pick a model-family label.

    The two Qwen3-VL model types map to ``"qwen3_vl"``; any other model with a
    visual module maps to ``"default"``; everything else yields None.
    """
    if model_type in ("qwen3_vl", "qwen3_vl_moe"):
        return "qwen3_vl"
    return "default" if has_visual else None
def _candidate_expert_key_paths() -> Tuple[Union[str, Tuple[str, ...]], ...]:
    """Enumerate config locations that may hold the expert count.

    Top-level keys come first as plain strings, followed by the same keys
    nested under ``text_config`` and then ``llm_config`` as two-element tuples.
    """
    base_keys = (
        "num_experts",
        "num_local_experts",
        "n_routed_experts",
        "num_routing_experts",
        "moe_num_experts",
        "expert_num",
    )
    paths: List[Union[str, Tuple[str, ...]]] = list(base_keys)
    for container in ("text_config", "llm_config"):
        paths.extend((container, key) for key in base_keys)
    return tuple(paths)
def _expert_key_to_profile_value(
    key_path: Union[str, Tuple[str, ...]],
) -> Union[str, List[str]]:
    """Convert a key path to its profile form: strings pass through, tuples become lists."""
    return key_path if isinstance(key_path, str) else list(key_path)
def _display_key(key_path: Union[str, Tuple[str, ...]]) -> str:
    """Render a key path for display: tuples are dot-joined, anything else is returned as is."""
    if isinstance(key_path, tuple):
        return ".".join(key_path)
    return key_path
def _collect_expert_fields(config: Any) -> Dict[str, Any]:
    """Gather every known expert-count key that exists on ``config``.

    The result maps the dotted display key to a dict holding the
    profile-formatted key (``"profile_key"``) and the configured ``"value"``.
    """
    return {
        _display_key(candidate): {
            "profile_key": _expert_key_to_profile_value(candidate),
            "value": _config_get(config, candidate),
        }
        for candidate in _candidate_expert_key_paths()
        if _config_has_key(config, candidate)
    }
def _pick_expert_key(expert_fields: Dict[str, Any]) -> Optional[Union[str, List[str]]]:
    """Choose the profile key of the most preferred expert-count entry.

    Top-level keys win over ``text_config`` keys, which win over ``llm_config``
    keys. If none of the known keys is present, the first entry of
    ``expert_fields`` is used. An empty mapping yields None.
    """
    if not expert_fields:
        return None

    top_level = (
        "num_experts",
        "num_local_experts",
        "n_routed_experts",
        "num_routing_experts",
        "moe_num_experts",
        "expert_num",
    )
    under_text_config = (
        "text_config.num_experts",
        "text_config.num_local_experts",
        "text_config.n_routed_experts",
        "text_config.num_routing_experts",
        "text_config.moe_num_experts",
        "text_config.expert_num",
    )
    under_llm_config = (
        "llm_config.num_experts",
        "llm_config.num_local_experts",
        "llm_config.n_routed_experts",
        "llm_config.num_routing_experts",
        "llm_config.moe_num_experts",
        "llm_config.expert_num",
    )
    for display_key in (*top_level, *under_text_config, *under_llm_config):
        if display_key in expert_fields:
            return expert_fields[display_key]["profile_key"]

    first_entry = next(iter(expert_fields.values()))
    return first_entry["profile_key"]
def _scan_module_tree(root: Any) -> Tuple[List[ModuleFacts], List[ModuleFacts], List[ModuleFacts]]:
    """Sort the named submodules of ``root`` into attention-, MoE- and MLP-like lists.

    A module may land in several lists. Each module is summarised once and the
    same ``ModuleFacts`` object is shared by every list it belongs to.
    """
    attention_like: List[ModuleFacts] = []
    moe_like: List[ModuleFacts] = []
    mlp_like: List[ModuleFacts] = []
    if not hasattr(root, "named_modules"):
        return attention_like, moe_like, mlp_like

    for name, module in root.named_modules():
        if not name:
            # Entries with an empty name are skipped (presumably the root itself — confirm).
            continue
        class_name = type(module).__name__.lower()
        leaf_name = name.rsplit(".", maxsplit=1)[-1].lower()
        module_facts = _module_facts(name, module)
        if "attn" in leaf_name or "attention" in class_name:
            attention_like.append(module_facts)
        if _is_moe_like_module(module_facts):
            moe_like.append(module_facts)
        if (
            "mlp" in leaf_name
            or "mlp" in class_name
            or _module_has_any(module, ("gate_proj", "up_proj", "down_proj"))
        ):
            mlp_like.append(module_facts)
    return attention_like, moe_like, mlp_like


def _candidate_if_set(value: Any, source: str, confidence: str = "medium") -> Optional[CandidateField]:
    """Wrap ``value`` in a ``CandidateField``; falsy values produce None."""
    return CandidateField(value, source, confidence) if value else None


def _first_scanned_path(paths: Tuple[str, ...]) -> Optional[CandidateField]:
    """Turn the first path found by the module-tree scan into a candidate, if there is one."""
    return CandidateField(paths[0], "module_tree_scan", "medium") if paths else None


def _override_candidate(base_fields: Any, modules: Tuple[ModuleFacts, ...]) -> Optional[CandidateField]:
    """Infer field-name overrides from the first of ``modules``; None when nothing differs."""
    if not modules:
        return None
    override = _infer_override(base_fields, modules[0])
    return CandidateField(override, modules[0].path, "medium") if override else None


def inspect_model_structure(
    model: Any, hf_config: Optional[Any] = None
) -> Tuple[ModelStructureFacts, ProfileCandidate]:
    """Inspect a model and propose profile settings for it.

    Args:
        model: Object to inspect. If it has ``unwrap()``, the unwrapped object
            is scanned; otherwise ``model`` itself is.
        hf_config: Configuration object; when falsy, ``model.hf_config`` is
            used if present.

    Returns:
        A pair ``(facts, candidate)``: the raw structural observations and the
        profile suggestions derived from them (each suggestion carries its
        source and a confidence label).
    """
    config = hf_config or getattr(model, "hf_config", None)
    root = model.unwrap() if hasattr(model, "unwrap") else model

    attention_like, moe_like, mlp_like = _scan_module_tree(root)

    model_type = getattr(config, "model_type", None)
    mla_like_attention = tuple(entry for entry in attention_like if _is_mla_like_attention(entry))

    recipe_matches = []
    if mla_like_attention:
        recipe_matches.append("deepseek_like_mla")
    if moe_like:
        recipe_matches.append("standard_moe")

    expert_keys = _collect_expert_fields(config)
    visual_module_paths = _existing_paths(
        root,
        (
            "visual",
            "vision_tower",
            "vision_model",
            "model.visual",
            "model.vision_tower",
        ),
    )
    language_module_paths = _existing_paths(
        root,
        ("language_model", "text_model", "llm", "model", "model.language_model"),
    )
    visual_layer_paths = _layer_path_candidates(root, visual_module_paths)
    language_layer_paths = _layer_path_candidates(root, language_module_paths)
    if _has_attr_path(root, "layers"):
        # A bare top-level "layers" is appended last, without duplicating an existing entry.
        language_layer_paths = tuple(dict.fromkeys((*language_layer_paths, "layers")))
    visual_merger_mapping, visual_mlp_mapping = _collect_visual_linear_mappings(root, visual_module_paths)
    if visual_module_paths:
        recipe_matches.append("visual_language")

    structure = ModelStructureFacts(
        model_type=model_type,
        num_hidden_layers=getattr(config, "num_hidden_layers", None),
        hidden_size=getattr(config, "hidden_size", None),
        num_attention_heads=getattr(config, "num_attention_heads", None),
        num_key_value_heads=getattr(config, "num_key_value_heads", None),
        intermediate_size=getattr(config, "intermediate_size", None),
        expert_fields=expert_keys,
        attention_like_modules=tuple(attention_like),
        moe_like_modules=tuple(moe_like),
        mlp_like_modules=tuple(mlp_like),
        visual_module_paths=visual_module_paths,
        language_module_paths=language_module_paths,
        visual_layers_path_candidates=visual_layer_paths,
        language_layers_path_candidates=language_layer_paths,
        visual_merger_linear_mapping=visual_merger_mapping,
        visual_mlp_linear_mapping=visual_mlp_mapping,
        known_recipe_matches=tuple(recipe_matches),
    )

    # A registered profile, when available, takes precedence over scan results.
    profile = get_model_profile(model_type) if model_type else None
    registered_mla_module_name = profile.mla_module_name if profile and profile.mla_module_name else None
    # NOTE(review): with a registered MLA name the override below is inferred
    # from the first attention-like module of any class — confirm this is intended.
    mla_modules_for_candidate = (
        structure.attention_like_modules if registered_mla_module_name else mla_like_attention
    )
    mla_module_name = registered_mla_module_name or _pick_module_name(mla_modules_for_candidate)
    moe_module_name = (
        profile.moe_module_name
        if profile and profile.moe_module_name
        else _pick_module_name(structure.moe_like_modules)
    )

    mla_override = _override_candidate(MlaFieldNames(), mla_modules_for_candidate)
    moe_override = _override_candidate(MoEFieldNames(), structure.moe_like_modules)
    model_family = _infer_model_family(model_type, bool(structure.visual_module_paths))
    # Both visual layer fields are filled from the same first candidate path.
    visual_layers = _first_scanned_path(structure.visual_layers_path_candidates)

    candidate = ProfileCandidate(
        model_type=_candidate_if_set(model_type, "hf_config.model_type", "high"),
        moe_module_name=_candidate_if_set(moe_module_name, "registered profile or moe-like scan"),
        moe_num_experts_key=(
            CandidateField(_pick_expert_key(expert_keys), "hf_config expert key scan", "medium")
            if expert_keys
            else None
        ),
        moe_field_names_override=moe_override,
        moe_gate_returns_raw_logits=CandidateField(False, "safe default", "low") if moe_module_name else None,
        mla_module_name=_candidate_if_set(mla_module_name, "registered profile or attention-like scan"),
        mla_field_names_override=mla_override,
        model_family=_candidate_if_set(model_family, "model_type and visual module scan"),
        visual_module_path=_first_scanned_path(structure.visual_module_paths),
        language_module_path=_first_scanned_path(structure.language_module_paths),
        visual_layers_path_str=visual_layers,
        visual_layers_module_path=visual_layers,
        language_layers_path_str=_first_scanned_path(structure.language_layers_path_candidates),
        visual_merger_linear_mapping=_candidate_if_set(
            structure.visual_merger_linear_mapping, "visual linear scan"
        ),
        visual_mlp_linear_mapping=_candidate_if_set(structure.visual_mlp_linear_mapping, "visual linear scan"),
        recipe=(
            CandidateField(structure.known_recipe_matches[0], "structure recipe match", "medium")
            if structure.known_recipe_matches
            else None
        ),
    )
    return structure, candidate