import dataclasses
from typing import Any, Dict, Iterable, List, Optional, Tuple
from tensor_cast.model_config import MlaFieldNames, MoEFieldNames
@dataclasses.dataclass(frozen=True)
class ProfileValidationIssue:
    """One immutable finding recorded while checking a model profile.

    ``severity`` defaults to ``"error"`` when the caller does not supply it.
    """

    # Name (or dotted / indexed path) of the profile field the finding is about.
    field: str
    # Human-readable description of what is wrong with that field.
    message: str
    # Free-form severity label; "error" is the default.
    severity: str = "error"
@dataclasses.dataclass(frozen=True)
class ProfileValidationReport:
    """Immutable outcome of validating one model profile.

    Holds the profile's ``model_type`` and every issue that was collected.
    Only issues whose severity is ``"error"`` make the report fail.
    """

    # Model type of the profile that was validated.
    model_type: str
    # Every finding collected during validation, in the order it was recorded.
    issues: Tuple[ProfileValidationIssue, ...]

    @property
    def passed(self) -> bool:
        """Return True when no collected issue has severity ``"error"``."""
        return all(entry.severity != "error" for entry in self.issues)

    def raise_for_errors(self) -> None:
        """Raise ``ValueError`` summarising all error-severity issues.

        Returns ``None`` without raising when there are no such issues.

        Raises:
            ValueError: if at least one issue has severity ``"error"``; the
                message lists each one as ``field: message`` joined by "; ".
        """
        failures = [
            f"{entry.field}: {entry.message}"
            for entry in self.issues
            if entry.severity == "error"
        ]
        if failures:
            details = "; ".join(failures)
            raise ValueError(f"Invalid ModelProfile for {self.model_type!r}: {details}")
def _validate_non_empty_string(
    issues: List[ProfileValidationIssue],
    field: str,
    value: Optional[str],
    required: bool = False,
) -> None:
    """Record an issue when *value* is not an acceptable string.

    Args:
        issues: List that receives any new finding (mutated in place).
        field: Field name to report the finding under.
        value: Value to check.
        required: When True, a ``None`` value is reported as "must be set";
            when False, ``None`` is accepted silently.
    """
    if value is not None:
        # Anything that is not a str, or is blank after stripping, is rejected.
        usable = isinstance(value, str) and bool(value.strip())
        if not usable:
            issues.append(ProfileValidationIssue(field, "must be a non-empty string"))
    elif required:
        issues.append(ProfileValidationIssue(field, "must be set"))
def _field_values(field_names: Any) -> Iterable[Tuple[str, Any]]:
    """Return ``(name, value)`` pairs for a dict or a dataclass.

    A dict yields its items view, a dataclass yields one pair per declared
    field, and any other object yields an empty tuple.
    """
    if isinstance(field_names, dict):
        return field_names.items()
    if dataclasses.is_dataclass(field_names):
        return (
            (spec.name, getattr(field_names, spec.name))
            for spec in dataclasses.fields(field_names)
        )
    # Neither a mapping nor a dataclass: nothing to enumerate.
    return ()
def _validate_field_name_values(
    issues: List[ProfileValidationIssue],
    prefix: str,
    field_names: Any,
    base_fields: Any,
) -> None:
    """Check every entry of a field-name mapping against a reference dataclass.

    Args:
        issues: List that receives any new finding (mutated in place).
        prefix: Label prepended (as ``prefix.name``) to each reported field.
        field_names: Dict or dataclass whose entries are checked; other
            objects contribute no entries.
        base_fields: Dataclass whose declared field names are the only
            names accepted.
    """
    known = {spec.name for spec in dataclasses.fields(base_fields)}
    for name, candidate in _field_values(field_names):
        if name not in known:
            problem = "must be a supported field name"
        elif candidate is None or (isinstance(candidate, str) and candidate.strip()):
            # None and non-blank strings are both acceptable values.
            continue
        else:
            problem = "must be None or a non-empty string"
        issues.append(ProfileValidationIssue(f"{prefix}.{name}", problem))
def _field_names_override_as_dict(attr_name: str, override: Any) -> Any:
    """Return *override* as a plain dict, converting dataclass values.

    Falsy values and dicts are returned unchanged (same object).

    Raises:
        TypeError: if *override* is truthy but neither a dict nor a dataclass.
    """
    if not override or isinstance(override, dict):
        return override
    if not dataclasses.is_dataclass(override):
        raise TypeError(f"{attr_name} must be a dict")
    return {
        spec.name: getattr(override, spec.name)
        for spec in dataclasses.fields(override)
    }


def normalize_profile(profile: Any) -> Any:
    """Convert dataclass-valued field-name overrides on *profile* to dicts.

    Both ``moe_field_names_override`` and ``mla_field_names_override`` are
    handled the same way: a dataclass value is replaced, in place on
    *profile*, by a ``{field name: value}`` dict; falsy values and dicts are
    left untouched.

    Args:
        profile: Object exposing the two override attributes; it is mutated.

    Returns:
        The same *profile* object.

    Raises:
        TypeError: if an override is truthy but neither a dict nor a
            dataclass. The message names the offending attribute.
    """
    for attr_name in ("moe_field_names_override", "mla_field_names_override"):
        current = getattr(profile, attr_name)
        converted = _field_names_override_as_dict(attr_name, current)
        # Only assign when a conversion actually happened, so profiles whose
        # overrides are already normalised are never written to.
        if converted is not current:
            setattr(profile, attr_name, converted)
    return profile
def validate_profile(profile: Any) -> ProfileValidationReport:
    """Check *profile* and return a report listing every problem found.

    Nothing is raised for invalid data: each problem is appended to the
    report's issues so the caller can inspect them or call
    ``raise_for_errors()``.

    Args:
        profile: Object exposing the attributes read below (``model_type``,
            the ``*_module_name`` attributes, ``model_family``,
            ``moe_num_experts_key``, the field-name overrides,
            ``mla_module_class_type``, ``patch_method``,
            ``custom_expert_module_type``) and a ``build_mla_config()`` method.

    Returns:
        A ``ProfileValidationReport`` holding ``str(profile.model_type)`` and
        the collected issues as a tuple.
    """
    issues: List[ProfileValidationIssue] = []
    # model_type is the only name that must always be set; the others may be
    # None but, when given, must be non-blank strings.
    _validate_non_empty_string(issues, "model_type", profile.model_type, required=True)
    _validate_non_empty_string(issues, "moe_module_name", profile.moe_module_name)
    _validate_non_empty_string(issues, "mtp_block_module_name", profile.mtp_block_module_name)
    _validate_non_empty_string(issues, "mla_module_name", profile.mla_module_name)
    _validate_non_empty_string(issues, "model_family", profile.model_family)
    # A truthy moe_module_name is what this function treats as "MoE enabled".
    if profile.moe_module_name:
        # The experts key may be a single string or a list of strings.
        if isinstance(profile.moe_num_experts_key, str):
            _validate_non_empty_string(
                issues,
                "moe_num_experts_key",
                profile.moe_num_experts_key,
                required=True,
            )
        elif isinstance(profile.moe_num_experts_key, list):
            if not profile.moe_num_experts_key:
                issues.append(
                    ProfileValidationIssue(
                        "moe_num_experts_key",
                        "list must not be empty when MoE is enabled",
                    )
                )
            # Each entry is reported under its own indexed name.
            for index, key in enumerate(profile.moe_num_experts_key):
                _validate_non_empty_string(
                    issues,
                    f"moe_num_experts_key[{index}]",
                    key,
                    required=True,
                )
        else:
            issues.append(
                ProfileValidationIssue(
                    "moe_num_experts_key",
                    "must be a string or a list of strings",
                )
            )
        # NOTE(review): this check is assumed to belong to the MoE-enabled
        # branch -- confirm the intended nesting.
        # With no override, a default MoEFieldNames() is checked against itself.
        _validate_field_name_values(
            issues,
            "moe_field_names_override",
            profile.moe_field_names_override or MoEFieldNames(),
            MoEFieldNames(),
        )
    # A truthy mla_module_name is what this function treats as "MLA enabled".
    if profile.mla_module_name:
        if profile.mla_module_class_type is None:
            issues.append(ProfileValidationIssue("mla_module_class_type", "must be set when MLA is enabled"))
        # NOTE(review): the config build is assumed to belong to the
        # MLA-enabled branch -- confirm the intended nesting.
        try:
            mla_config = profile.build_mla_config()
        except ValueError as exc:
            # A ValueError from the build is reported against the override field.
            issues.append(ProfileValidationIssue("mla_field_names_override", str(exc)))
        else:
            # build_mla_config() may return None; fall back to default names then.
            _validate_field_name_values(
                issues,
                "mla_field_names_override",
                mla_config.field_names if mla_config is not None else MlaFieldNames(),
                MlaFieldNames(),
            )
    # Optional hooks only need to be callable when they are provided.
    if profile.patch_method is not None and not callable(profile.patch_method):
        issues.append(ProfileValidationIssue("patch_method", "must be callable"))
    if profile.custom_expert_module_type is not None and not callable(profile.custom_expert_module_type):
        issues.append(ProfileValidationIssue("custom_expert_module_type", "must be callable"))
    return ProfileValidationReport(
        model_type=str(profile.model_type),
        issues=tuple(issues),
    )
def _normalize_override_for_review(value: Any, base_fields: Any) -> Any:
    """Reduce a field-name override to the entries that differ from defaults.

    Args:
        value: Override to reduce. A dataclass is first converted with
            ``dataclasses.asdict``; anything that is not a dict afterwards is
            returned unchanged.
        base_fields: Dataclass instance supplying the default value of each
            field.

    Returns:
        For dict input, a new dict without entries whose value is ``None`` or
        equal to the default of the same name; otherwise *value* itself.
    """
    candidate = dataclasses.asdict(value) if dataclasses.is_dataclass(value) else value
    if not isinstance(candidate, dict):
        return candidate
    baseline = {
        spec.name: getattr(base_fields, spec.name)
        for spec in dataclasses.fields(base_fields)
    }
    trimmed = {}
    for key, item in candidate.items():
        # Keys unknown to the baseline compare against None and are kept.
        if item is not None and item != baseline.get(key):
            trimmed[key] = item
    return trimmed
def profile_to_review_dict(profile: Any) -> Dict[str, Any]:
    """Build a compact dict of a profile's non-default settings for review.

    Iterates the dataclass fields of *profile* and keeps only informative
    entries. Skipped are: ``None`` values; the two MoE boolean flags when
    ``False``; ``custom_expert_module_type`` when MoE is disabled or its name
    ends with ``.MoeExpertMLP``; ``mla_module_class_type`` when MLA is
    disabled; field-name overrides that reduce to nothing; and
    ``moe_num_experts_key`` when it equals ``"num_experts"``.

    Dataclass instances are converted with ``dataclasses.asdict`` and
    callables are replaced by their dotted name.

    Args:
        profile: Dataclass instance to summarise.

    Returns:
        Mapping of field name to the reviewable value.
    """
    data: Dict[str, Any] = {}
    for field in dataclasses.fields(profile):
        value = getattr(profile, field.name)
        if value is None:
            continue
        if field.name in {"moe_gate_returns_raw_logits", "moe_route_after_dp_transform"} and value is False:
            continue
        if field.name == "custom_expert_module_type" and (
            not profile.moe_module_name or _callable_name(value).endswith(".MoeExpertMLP")
        ):
            continue
        if field.name == "mla_module_class_type" and not profile.mla_module_name:
            continue
        # is_dataclass() is also True for dataclass *types*, which asdict()
        # rejects with TypeError. Class-valued fields must therefore fall
        # through to the callable branch and be rendered by name.
        if dataclasses.is_dataclass(value) and not isinstance(value, type):
            value = dataclasses.asdict(value)
        elif callable(value) and not isinstance(value, (str, bytes)):
            value = _callable_name(value)
        if field.name.endswith("_field_names_override"):
            base_fields = MoEFieldNames() if field.name.startswith("moe_") else MlaFieldNames()
            value = _normalize_override_for_review(value, base_fields)
            # An override that matches the defaults carries no information.
            if not value:
                continue
        if field.name == "moe_num_experts_key" and value == "num_experts":
            continue
        data[field.name] = value
    return data
def _callable_name(value: Any) -> str:
    """Return a dotted ``module.name`` label for a callable.

    Non-callables are rendered with ``str(value)``. Callables that carry no
    ``__name__`` of their own (for example ``functools.partial`` objects or
    instances defining ``__call__``) are labelled with the module and name of
    their type instead of raising ``AttributeError``.
    """
    if not callable(value) or isinstance(value, (str, bytes)):
        return str(value)
    name = getattr(value, "__name__", None)
    if name is None:
        # No own name: describe the object by its type.
        kind = type(value)
        return f"{kind.__module__}.{kind.__name__}"
    return f"{value.__module__}.{name}"