"""Simple Graph Handlers (Paged Attention / MLA)."""

# Deliberately exports nothing: the handler class defined below is private
# (underscore-prefixed) and is passed to ``register_npu_graph_handler`` by
# its class decorator when this module is imported.
__all__ = []

from .npugraph_handler import NpuGraphOpHandler, register_npu_graph_handler


# Single source of truth for the ops handled by ``_SimpleGraphHandler``:
# both the decorator's registration list and the class's ``_OP_ARG_SPECS``
# are derived from this table, so the two can no longer drift apart.
#
# Maps ``op_name -> (arg_index, update_key)``: when a recorded call of
# ``op_name`` is updated, ``record.args[arg_index]`` is replaced with
# ``update_input[update_key]``.
_SIMPLE_GRAPH_OP_ARG_SPECS = {
    # NOTE(review): indices presumably match the positional slot of
    # ``context_lens`` in each op's signature -- confirm against the schemas.
    "_npu_paged_attention.default": (7, "context_lens"),
    "npu_multi_head_latent_attention.out": (5, "context_lens"),
}


# ``list(dict)`` yields the keys in insertion order, i.e. exactly the op
# names (and order) previously spelled out by hand here.
@register_npu_graph_handler(list(_SIMPLE_GRAPH_OP_ARG_SPECS))
class _SimpleGraphHandler(NpuGraphOpHandler):
    """Handler for PA (Paged Attention) and MLA operators.

    Attributes:
        _OP_ARG_SPECS (dict[str, tuple[int, str]]): Specifies
            ``op_name -> (arg_index, update_key)`` for each supported
            operator. This is the same table the class is registered with,
            so every registered op has a spec.
    """

    _OP_ARG_SPECS = _SIMPLE_GRAPH_OP_ARG_SPECS

    @classmethod
    def update_args(cls, record, update_input):
        """Overwrite one positional argument of a recorded op in place.

        The op is identified by ``record.op_cache_entry.__name__`` and looked
        up in ``_OP_ARG_SPECS``; the matching ``record.args[arg_index]`` is
        replaced with ``update_input[update_key]``.

        Args:
            record: Recorded op call. Must expose ``op_cache_entry`` (whose
                ``__name__`` names the op) and an index-assignable ``args``
                sequence (presumably a list -- a tuple would raise
                ``TypeError`` on assignment).
            update_input: Mapping from update key (e.g. ``"context_lens"``)
                to the new argument value.

        Returns:
            None. ``record.args`` is mutated in place.

        The call is a silent no-op when the op has no spec, when
        ``update_input`` lacks the op's update key, or when ``record.args``
        is too short to contain ``arg_index``.
        """
        spec = cls._OP_ARG_SPECS.get(record.op_cache_entry.__name__)
        if not spec:
            return
        arg_index, key = spec
        # Only overwrite an argument that was actually recorded; a shorter
        # ``args`` sequence is left untouched rather than raising IndexError.
        if key in update_input and arg_index < len(record.args):
            record.args[arg_index] = update_input[key]