"""
FLOPS factory: centralized registry, model registration, and public entry.
All supported models are explicitly registered here for visibility.
"""
from typing import Dict, Type, List
from transformers import PretrainedConfig
from mindspeed_llm.fsdp2.utils.logging import get_logger
from .flops_base import BaseFlopsEstimator, UnknownModelFlopsEstimator, get_device_flops
from .qwen3_flops import Qwen3MoeFlopsEstimator, Qwen3DenseFlopsEstimator
logger = get_logger(__name__)
class FlopsFactory:
    """
    Registry that maps a config's ``model_type`` string to a FLOPS estimator class.

    Estimator classes are added with :meth:`register_model` and resolved with
    :meth:`get_model_estimator`.
    """

    # model_type string -> estimator class. Stored on the class, so it is
    # shared by every caller of the factory.
    _registry: Dict[str, Type[BaseFlopsEstimator]] = {}

    @classmethod
    def register_model(cls, model_type: str):
        """
        Build a decorator that registers an estimator class under ``model_type``.

        The decorator stores the class in the registry (replacing any entry
        already stored under the same ``model_type``), logs the registration
        through ``logger.info_rank0``, and returns the class unchanged.
        """
        def _register(estimator_class: Type[BaseFlopsEstimator]):
            # Store first, then log: the registry is updated even if logging fails.
            cls._registry[model_type] = estimator_class
            message = f"Registered MFU-capable model: {model_type} -> {estimator_class.__name__}"
            logger.info_rank0(message)
            return estimator_class

        return _register

    @classmethod
    def get_model_estimator(cls, config: PretrainedConfig) -> BaseFlopsEstimator:
        """
        Instantiate the estimator registered for ``config.model_type``.

        When no estimator is registered for that model type,
        ``UnknownModelFlopsEstimator`` is used instead. The chosen class is
        constructed with ``config`` as its only argument.
        """
        model_type = config.model_type
        try:
            chosen_cls = cls._registry[model_type]
        except KeyError:
            # Unregistered model type: fall back to the placeholder estimator.
            chosen_cls = UnknownModelFlopsEstimator
        return chosen_cls(config)
# Explicit table of every model type that supports MFU calculation.
# Registration happens at import time, in the order listed.
for _model_type, _estimator_cls in (
    ("qwen3_moe", Qwen3MoeFlopsEstimator),
    ("qwen3", Qwen3DenseFlopsEstimator),
):
    FlopsFactory.register_model(_model_type)(_estimator_cls)
# Drop the loop temporaries so the module namespace stays unchanged.
del _model_type, _estimator_cls
class FlopsCounter:
    """Public entry point for MFU/FLOPS calculation."""

    def __init__(self, config: PretrainedConfig):
        """Resolve the estimator for ``config`` through ``FlopsFactory.get_model_estimator``."""
        self.flops_estimator = FlopsFactory.get_model_estimator(config)

    def estimate_flops(
        self,
        batch_seqlens: List[int],
        delta_time: float,
    ) -> tuple[float, float]:
        """
        Compute the achieved FLOPS for a batch together with the device peak.

        Args:
            batch_seqlens: Sequence lengths of the batch. Their sum is passed
                to the estimator as the total token count.
            delta_time: Forwarded unchanged to the estimator's
                ``calculate_achieved_flops``.

        Returns:
            ``(achieved_flops, peak_flops)``: the estimator's result and the
            value of ``get_device_flops()``. Both were documented as TFLOPS;
            the actual units are set by those helpers, so confirm there.
        """
        total_tokens = sum(batch_seqlens)
        estimator = self.flops_estimator
        achieved_flops = estimator.calculate_achieved_flops(
            total_tokens, batch_seqlens, delta_time
        )
        # The device peak is queried after the estimator call, as in the original.
        return achieved_flops, get_device_flops()