from __future__ import annotations

from collections.abc import Generator

try:
    from ..model_configs import resolve_configs
    from ..shape_grids import ELEM_TOKENS_GRID
except ImportError:
    from model_configs import resolve_configs
    from shape_grids import ELEM_TOKENS_GRID

from .base import TheoryShapeRow


def generate_split_qkv_rmsnorm_rope_rows(
    model_names: list[str] | None,
) -> Generator[TheoryShapeRow, None, None]:
    """Generate replayable vLLM-Ascend fused QKV+RMSNorm+RoPE rows.

    The custom op consumes a fused QKV input whose last dimension is
    q_hidden + 2 * kv_hidden, then returns Q, K, and V tensors. The profiled
    CSV only records the fused input plus the rope cache width; replay creates
    positions and a full cos/sin cache from that metadata.

    Args:
        model_names: Model names forwarded unchanged to ``resolve_configs``.
            ``None`` is passed through as-is (NOTE(review): presumably means
            "all known models" — confirm against ``resolve_configs``).

    Yields:
        One ``TheoryShapeRow`` per unique ``(tokens, q_hidden, kv_hidden,
        rope_dim)`` combination, taken over every non-MLA config, each of its
        ``tp_sizes``, and every token count in ``ELEM_TOKENS_GRID``. Inputs
        are ``[(tokens, q_hidden + 2 * kv_hidden), (rope_dim,)]``. Outputs are
        ``[(tokens, q_hidden), (tokens, kv_hidden), (tokens, kv_hidden)]``.
        All tensors are tagged ``DT_BF16`` / ``ND``.

    A TP size is skipped when the attention heads or the KV heads cannot be
    split evenly across ranks. No valid tensor-parallel layout produces such
    shapes.
    """

    # Different models / TP sizes frequently share identical per-rank shapes,
    # so deduplicate on the shape-defining values. input_hidden is fully
    # derived from q_hidden and kv_hidden, so it need not be part of the key.
    seen: set[tuple[int, int, int, int]] = set()
    for cfg in resolve_configs(model_names):
        # MLA configs are excluded from this op's rows.
        if cfg.is_mla():
            continue

        for tp in cfg.tp_sizes:
            if cfg.num_attention_heads % tp != 0:
                continue

            local_q_heads = cfg.num_attention_heads // tp

            # KV heads are either sharded evenly across ranks, or (when there
            # are fewer KV heads than ranks) replicated so each rank holds
            # exactly one. Both cases need exact divisibility, as in vLLM's
            # QKVParallelLinear; anything else is not a real layout.
            if cfg.num_kv_heads >= tp:
                if cfg.num_kv_heads % tp != 0:
                    # e.g. 6 KV heads on TP=4: floor division would invent a
                    # 1-head shard that no deployment actually runs.
                    continue
                local_kv_heads = cfg.num_kv_heads // tp
            elif tp % cfg.num_kv_heads == 0:
                local_kv_heads = 1
            else:
                continue

            q_hidden = local_q_heads * cfg.head_dim
            kv_hidden = local_kv_heads * cfg.head_dim
            # Fused QKV width on this rank: Q followed by K and V.
            input_hidden = q_hidden + 2 * kv_hidden
            # The rotary width is taken to be the full head_dim here.
            rope_dim = cfg.head_dim
            for tokens in ELEM_TOKENS_GRID:
                key = (tokens, q_hidden, kv_hidden, rope_dim)
                if key in seen:
                    continue
                seen.add(key)
                yield TheoryShapeRow(
                    # Only the rope cache width is recorded; replay expands it.
                    [(tokens, input_hidden), (rope_dim,)],
                    [(tokens, q_hidden), (tokens, kv_hidden), (tokens, kv_hidden)],
                    extra_values={
                        "Input Data Types": "DT_BF16;DT_BF16",
                        "Input Formats": "ND;ND",
                        "Output Data Types": "DT_BF16;DT_BF16;DT_BF16",
                        "Output Formats": "ND;ND;ND",
                    },
                )