"""NPU Graph Operator Handler -- base class, global registry, and utilities.

This module contains the core building blocks of the NPU Graph operator
handler framework:

- :class:`NpuGraphOpHandler`          -- abstract base class for handlers.
- :data:`_NPU_GRAPH_OP_HANDLERS`      -- global registry (dict).
- :func:`register_npu_graph_handler`  -- class-decorator for registration.

Design Contracts
----------------
1. **Stateless Handler** -- All hook methods are ``@classmethod``; the
   registry stores *class objects*, never instances.
2. **Function Replacement Consistency** -- When ``prepare_capture``
   substitutes ``func`` with ``actual_func``, the handler registered for
   ``actual_func.__name__`` **must** implement a compatible
   ``update_args`` (typically guaranteed by a shared intermediate base).

See Also
--------
``torch_npu/npu/graphs.py`` for the template-method skeleton that
consumes this registry.
"""

import logging
from copy import deepcopy

import torch
from torch_npu._C import _weak_ref_tensor as TensorWeakRef

logger = logging.getLogger(__name__)


class NpuGraphOpHandler:
    r"""Base class for NPU Graph operator handlers.

    Subclasses override ``@classmethod`` hooks to customize capture / update
    behavior for specific operators, while the framework keeps
    stream / event / task-group orchestration in a common template.

    **Stateless by design** -- All hook methods are ``@classmethod`` (first
    parameter is ``cls``, not ``self``).  There is no instance; the global
    registry stores **class objects** directly.  This structurally prevents
    storing mutable per-invocation state.  Class-level constants (e.g.
    ``UPDATE_SPECS``) are accessible via ``cls``.

    **Declarative update_args via UPDATE_SPECS**: Subclasses declare which
    update keys they consume and where those keys live in args/kwargs via the
    ``UPDATE_SPECS`` class attribute. The base ``update_args`` implementation
    walks the spec and applies updates uniformly. Subclasses normally do not
    need to override ``update_args``.

    Schema:

    .. code-block:: python

        UPDATE_SPECS: Dict[op_name, List[Tuple[Literal["arg", "kwarg"], int | str, key_name]]]

    Example:

    .. code-block:: python

        @register_npu_graph_handler(["my_op", "my_op.default"])
        class MyHandler(NpuGraphOpHandler):
            UPDATE_SPECS = {
                "my_op":         [("arg", 2, "batch")],
                "my_op.default": [("arg", 2, "batch")],
            }
    """

    # Default: no updatable slots for any op.  Subclasses should assign their
    # own dict rather than mutating this shared one in place.
    UPDATE_SPECS = {}

    @classmethod
    def prepare_capture(cls, func, args, kwargs):
        r"""Prepare operator call before graph-task recording.

        This hook runs **before** ``graph_task_group_begin`` and can be used
        for operator-specific preprocessing such as workspace allocation,
        output pre-allocation, or switching from ``.default`` to ``.out``
        overloads.

        .. note:: **Function Replacement Contract**

            If ``actual_func`` differs from ``func``, ensure the handler
            registered for ``actual_func.__name__`` implements a compatible
            ``update_args``.  The recommended approach is to share a common
            base class that defines ``update_args`` (see ``_IFAv1Base`` /
            ``_IFAv2Base``).

        Args:
            func (OpOverload): Original operator callable.
            args (tuple): Original arguments.
            kwargs (dict): Original keyword arguments.

        Returns:
            tuple[Callable, tuple, dict]: ``(actual_func, args, kwargs)`` to
            execute during recording.  The base implementation returns the
            inputs unchanged.
        """
        return func, args, kwargs

    @classmethod
    def postprocess_result(cls, result, kwargs):
        r"""Post-process operator return value after recording.

        Called after ``graph_task_group_end`` and dispatch-record creation.

        Args:
            result: Raw return value from ``actual_func(*args, **kwargs)``.
            kwargs (dict): Current keyword arguments (may contain ``"out"``).

        Returns:
            Final value returned to the Python caller.  The base
            implementation returns ``result`` unchanged.
        """
        return result

    @classmethod
    def get_update_specs(cls, op_name):
        r"""Return per-op update specs.

        Args:
            op_name (str): Operator dispatch name (e.g. ``"_npu_paged_attention.default"``).

        Returns:
            List of ``(loc, idx_or_name, key)`` tuples. ``loc`` is ``"arg"`` or
            ``"kwarg"``; ``idx_or_name`` is an int index for ``"arg"`` or a
            string name for ``"kwarg"``; ``key`` is the user-facing update key.
            Empty list if this op is not in ``UPDATE_SPECS``.
        """
        return cls.UPDATE_SPECS.get(op_name, [])

    @classmethod
    def update_args(cls, dispatch_record, update_input):
        r"""Apply operator-specific updates by walking ``UPDATE_SPECS``.

        Default implementation reads the spec for this op and assigns the
        matching key's value from ``update_input`` to the recorded args/kwargs
        slot. Subclasses normally do not need to override this; declare
        ``UPDATE_SPECS`` instead.

        Behavior per spec entry:

        - Keys absent from ``update_input`` are skipped.
        - ``"arg"`` entries assign ``dispatch_record.args[idx]`` by index, so
          ``args`` must support item assignment (e.g. a list).  An index
          beyond ``len(args)`` is skipped silently.
        - ``"kwarg"`` entries assign ``dispatch_record.kwargs[name]``.
        - Any other ``loc`` value is a malformed spec; the entry is skipped
          and a warning is logged so the typo does not go unnoticed.

        Args:
            dispatch_record (_GraphDispatchRecord): Recorded operator call.
            update_input (dict): User-provided update payload.
        """
        op_name = dispatch_record.op_cache_entry.__name__
        for loc, idx_or_name, key in cls.get_update_specs(op_name):
            if key not in update_input:
                continue
            new_value = update_input[key]
            if loc == "arg":
                # Out-of-range positions are tolerated: the recorded call may
                # have been made with fewer positional args than the spec lists.
                if len(dispatch_record.args) > idx_or_name:
                    dispatch_record.args[idx_or_name] = new_value
            elif loc == "kwarg":
                dispatch_record.kwargs[idx_or_name] = new_value
            else:
                logger.warning(
                    "%s: ignoring UPDATE_SPECS entry for '%s' with unknown "
                    "location %r (expected 'arg' or 'kwarg'); update key '%s' "
                    "was not applied",
                    cls.__name__, op_name, loc, key,
                )

    @classmethod
    def record_wrap_kwarg(cls, key, value, tensor_param_names):
        r"""Convert a kwarg value into record-time storage representation.

        Called only during the **capture** phase (creating the dispatch
        record).  The purpose of ``TensorWeakRef`` conversion is to avoid the
        Python-side record holding strong references to NPU tensors, letting
        the C++ graph runtime manage tensor memory lifetimes.

        .. note::

            The **update** phase uses direct assignment for kwargs
            (``record.kwargs[key] = update_input[key]``), consistent with how
            ``update_args`` handles arguments.  Update is a short-lived
            "assign -> replay" flow where weak-ref conversion is unnecessary.

        Logic:

        - ``None`` -> ``None`` (fast path)
        - ``list`` / ``tuple`` (including namedtuple subclasses) ->
          element-wise: NPU Tensor -> ``TensorWeakRef``, else -> ``deepcopy``.
          The container type is preserved.  Only NPU tensors are wrapped;
          CPU tensors use ``deepcopy`` because ``TensorWeakRef`` from
          torch_npu._C is only valid for NPU tensors.
        - Single value where ``key`` in ``tensor_param_names`` and value is an
          NPU Tensor -> ``TensorWeakRef``
        - Everything else -> ``deepcopy``

        Args:
            key (str): Kwarg name.
            value: Kwarg value.
            tensor_param_names (list[str]): Tensor-typed kwarg names parsed
                from operator schema.

        Returns:
            Stored value for the dispatch record.
        """
        if value is None:
            return None

        def _is_npu_tensor(t):
            return torch.is_tensor(t) and "npu" in str(t.device)

        if isinstance(value, (list, tuple)):
            wrapped = [
                TensorWeakRef(t) if _is_npu_tensor(t) else deepcopy(t)
                for t in value
            ]
            container = type(value)
            # namedtuple constructors take one positional arg per field, not a
            # single iterable; ``_make`` builds one from an iterable instead.
            if isinstance(value, tuple) and hasattr(container, "_make"):
                return container._make(wrapped)
            return container(wrapped)

        if key in tensor_param_names and _is_npu_tensor(value):
            return TensorWeakRef(value)

        return deepcopy(value)


# ---------------------------------------------------------------------------
#  Global Registry
# ---------------------------------------------------------------------------

# Operator name (as produced by ``func.__name__``) -> handler *class*.
_NPU_GRAPH_OP_HANDLERS = {}


def register_npu_graph_handler(op_names):
    r"""Register an operator handler via class decorator.

    The decorated class itself (not an instance) is stored in the global
    registry.  All hook methods must be ``@classmethod``.

    If a name is already registered to a *different* class, a warning is
    logged and the new class replaces the old one.  Re-registering the same
    class under a name is a silent no-op.

    Args:
        op_names (str or iterable of str): Operator names resolved from
            ``func.__name__`` in ``__torch_dispatch__`` (for example,
            ``"my_op"``, ``"my_op.default"``, ``"my_op.out"``).  A single
            string registers one name.  Any other iterable (list, tuple, set,
            generator, ...) is expanded into its names once, when this
            function is called.

    Returns:
        A class decorator that returns the class unchanged.

    Example::

        @register_npu_graph_handler(["my_op", "my_op.default"])
        class MyHandler(NpuGraphOpHandler):
            ...
    """
    # Normalise eagerly so one-shot iterators are consumed exactly once, even
    # if the returned decorator is applied more than once.
    if isinstance(op_names, str):
        names = (op_names,)
    else:
        try:
            names = tuple(op_names)
        except TypeError:
            # Not iterable: treat it as a single name and register it unchanged.
            names = (op_names,)

    def decorator(cls):
        for name in names:
            if name in _NPU_GRAPH_OP_HANDLERS and _NPU_GRAPH_OP_HANDLERS[name] is not cls:
                existing = _NPU_GRAPH_OP_HANDLERS[name].__name__
                logger.warning(
                    "NpuGraphOpHandler for '%s' is being overridden: %s -> %s",
                    name, existing, cls.__name__,
                )
            _NPU_GRAPH_OP_HANDLERS[name] = cls   # store class, not instance
        return cls
    return decorator