"""NPU Graph Operator Handler -- base class, global registry, and utilities.
This module contains the core building blocks of the NPU Graph operator
handler framework:
- :class:`NpuGraphOpHandler` -- abstract base class for handlers.
- :data:`_NPU_GRAPH_OP_HANDLERS` -- global registry (dict).
- :func:`register_npu_graph_handler` -- class-decorator for registration.
Design Contracts
----------------
1. **Stateless Handler** -- All hook methods are ``@classmethod``; the
registry stores *class objects*, never instances.
2. **Function Replacement Consistency** -- When ``prepare_capture``
substitutes ``func`` with ``actual_func``, the handler registered for
``actual_func.__name__`` **must** implement a compatible
``update_args`` (typically guaranteed by a shared intermediate base).
See Also
--------
``torch_npu/npu/graphs.py`` for the template-method skeleton that
consumes this registry.
"""
import logging
from copy import deepcopy
import torch
from torch_npu._C import _weak_ref_tensor as TensorWeakRef
logger = logging.getLogger(__name__)
class NpuGraphOpHandler:
    r"""Base class for NPU Graph operator handlers.

    Subclasses override ``@classmethod`` hooks to customize capture / update
    behavior for specific operators, while the framework keeps
    stream / event / task-group orchestration in a common template.

    **Stateless by design** -- All hook methods are ``@classmethod`` (first
    parameter is ``cls``, not ``self``). There is no instance; the global
    registry stores **class objects** directly. This structurally prevents
    storing mutable per-invocation state. Class-level constants (e.g.
    ``UPDATE_SPECS``) are accessible via ``cls``.

    **Declarative update_args via UPDATE_SPECS**: Subclasses declare which
    update keys they consume and where those keys live in args/kwargs via the
    ``UPDATE_SPECS`` class attribute. The base ``update_args`` implementation
    walks the spec and applies updates uniformly. Subclasses normally do not
    need to override ``update_args``.

    Schema:

    .. code-block:: python

        UPDATE_SPECS: Dict[op_name, List[Tuple[Literal["arg", "kwarg"], int | str, key_name]]]

    Example:

    .. code-block:: python

        @register_npu_graph_handler(["my_op", "my_op.default"])
        class MyHandler(NpuGraphOpHandler):
            UPDATE_SPECS = {
                "my_op": [("arg", 2, "batch")],
                "my_op.default": [("arg", 2, "batch")],
            }
    """

    # Maps operator dispatch name -> list of (loc, idx_or_name, key) tuples.
    # The base class consumes no update keys; subclasses replace this mapping.
    UPDATE_SPECS = {}

    @classmethod
    def prepare_capture(cls, func, args, kwargs):
        r"""Prepare operator call before graph-task recording.

        This hook runs **before** ``graph_task_group_begin`` and can be used
        for operator-specific preprocessing such as workspace allocation,
        output pre-allocation, or switching from ``.default`` to ``.out``
        overloads. The base implementation is the identity: it hands back
        its three inputs untouched.

        .. note:: **Function Replacement Contract**

            If ``actual_func`` differs from ``func``, ensure the handler
            registered for ``actual_func.__name__`` implements a compatible
            ``update_args``. The recommended approach is to share a common
            base class that defines ``update_args`` (see ``_IFAv1Base`` /
            ``_IFAv2Base``).

        Args:
            func (OpOverload): Original operator callable.
            args (tuple): Original arguments.
            kwargs (dict): Original keyword arguments.

        Returns:
            tuple[Callable, tuple, dict]: ``(actual_func, args, kwargs)`` to
            execute during recording.
        """
        return func, args, kwargs

    @classmethod
    def postprocess_result(cls, result, kwargs):
        r"""Post-process operator return value after recording.

        Called after ``graph_task_group_end`` and dispatch-record creation.
        The base implementation returns ``result`` unchanged and does not
        look at ``kwargs``.

        Args:
            result: Raw return value from ``actual_func(*args, **kwargs)``.
            kwargs (dict): Current keyword arguments (may contain ``"out"``).

        Returns:
            Final value returned to the Python caller.
        """
        return result

    @classmethod
    def get_update_specs(cls, op_name):
        r"""Return per-op update specs.

        Args:
            op_name (str): Operator dispatch name
                (e.g. ``"_npu_paged_attention.default"``).

        Returns:
            List of ``(loc, idx_or_name, key)`` tuples. ``loc`` is ``"arg"`` or
            ``"kwarg"``; ``idx_or_name`` is an int index for ``"arg"`` or a
            string name for ``"kwarg"``; ``key`` is the user-facing update key.
            Empty list if this op is not in ``UPDATE_SPECS``.
        """
        return cls.UPDATE_SPECS.get(op_name, [])

    @classmethod
    def update_args(cls, dispatch_record, update_input):
        r"""Apply operator-specific updates by walking ``UPDATE_SPECS``.

        Default implementation reads the spec for this op (looked up by
        ``dispatch_record.op_cache_entry.__name__``) and assigns the matching
        key's value from ``update_input`` to the recorded args/kwargs slot.
        Subclasses normally do not need to override this; declare
        ``UPDATE_SPECS`` instead.

        Entries are skipped silently when:

        - ``key`` is absent from ``update_input``;
        - ``loc`` is ``"arg"`` and the recorded ``args`` has no slot at that
          index;
        - ``loc`` is neither ``"arg"`` nor ``"kwarg"``.

        ``"kwarg"`` entries are always assigned, creating the kwarg if the
        record did not already hold it.

        Args:
            dispatch_record (_GraphDispatchRecord): Recorded operator call.
                ``args`` must support item assignment.
            update_input (dict): User-provided update payload.
        """
        op_name = dispatch_record.op_cache_entry.__name__
        for loc, idx_or_name, key in cls.get_update_specs(op_name):
            if key not in update_input:
                continue
            if loc == "arg":
                # Only overwrite a positional slot that was actually recorded.
                if len(dispatch_record.args) > idx_or_name:
                    dispatch_record.args[idx_or_name] = update_input[key]
            elif loc == "kwarg":
                dispatch_record.kwargs[idx_or_name] = update_input[key]

    @classmethod
    def record_wrap_kwarg(cls, key, value, tensor_param_names):
        r"""Convert a kwarg value into record-time storage representation.

        Called only during the **capture** phase (creating the dispatch
        record). The purpose of ``TensorWeakRef`` conversion is to avoid the
        Python-side record holding strong references to NPU tensors, letting
        the C++ graph runtime manage tensor memory lifetimes.

        .. note::

            The **update** phase uses direct assignment for kwargs
            (``record.kwargs[key] = update_input[key]``), consistent with the
            original implementation and with how ``update_args`` handles
            arguments. Update is a short-lived "assign -> replay"
            flow where weak-ref conversion is unnecessary.

        Logic:

        - ``None`` -> ``None`` (fast path)
        - ``list`` / ``tuple`` -> element-wise: NPU Tensor -> ``TensorWeakRef``,
          else -> ``deepcopy``. Only NPU tensors are wrapped; CPU tensors use
          ``deepcopy`` because ``TensorWeakRef`` from torch_npu._C is only
          valid for NPU tensors. The result has the same container type as
          ``value``; namedtuple values are rebuilt field by field. ``key`` and
          ``tensor_param_names`` are not consulted on this path.
        - Single value where ``key`` in ``tensor_param_names`` and value is an
          NPU Tensor -> ``TensorWeakRef``
        - Everything else -> ``deepcopy``

        Args:
            key (str): Kwarg name.
            value: Kwarg value.
            tensor_param_names (list[str]): Tensor-typed kwarg names parsed
                from operator schema.

        Returns:
            Stored value for the dispatch record.
        """
        if value is None:
            return None

        def _is_npu_tensor(t):
            # Compare the device *type* rather than substring-matching the
            # printed device string.
            return torch.is_tensor(t) and t.device.type == "npu"

        if isinstance(value, (list, tuple)):
            wrapped = [
                TensorWeakRef(item) if _is_npu_tensor(item) else deepcopy(item)
                for item in value
            ]
            container = type(value)
            # namedtuple constructors take one positional argument per field,
            # so ``container(wrapped)`` would raise; ``_make`` is the
            # documented way to build one from an iterable.
            if isinstance(value, tuple) and hasattr(container, "_make"):
                return container._make(wrapped)
            return container(wrapped)

        if key in tensor_param_names and _is_npu_tensor(value):
            return TensorWeakRef(value)
        return deepcopy(value)
# Global registry: operator dispatch name -> handler class (never an instance).
_NPU_GRAPH_OP_HANDLERS = {}


def register_npu_graph_handler(op_names):
    r"""Register an operator handler via class decorator.

    The decorated class itself (not an instance) is stored in the global
    registry. All hook methods must be ``@classmethod``.

    Registering a name that already maps to a *different* class replaces the
    entry and logs a warning; re-registering the same class is silent.

    Args:
        op_names (str or iterable of str): Operator names resolved from
            ``func.__name__`` in ``__torch_dispatch__`` (for example,
            ``"my_op"``, ``"my_op.default"``, ``"my_op.out"``). A single
            string is treated as one name; any other iterable (list, tuple,
            set, generator, ...) is expanded into its elements.

    Returns:
        A class decorator that returns the decorated class unchanged.

    Raises:
        TypeError: If ``op_names`` is neither a string nor iterable.

    Example::

        @register_npu_graph_handler(["my_op", "my_op.default"])
        class MyHandler(NpuGraphOpHandler):
            ...
    """
    # Normalise eagerly: a bare string is ONE name (not an iterable of
    # characters), and a one-shot iterable such as a generator must not be
    # exhausted by the first use of the returned decorator.
    if isinstance(op_names, str):
        names = (op_names,)
    else:
        names = tuple(op_names)

    def decorator(cls):
        for name in names:
            if name in _NPU_GRAPH_OP_HANDLERS:
                previous = _NPU_GRAPH_OP_HANDLERS[name]
                # Only a genuine replacement is worth a warning.
                if previous is not cls:
                    logger.warning(
                        "NpuGraphOpHandler for '%s' is being overridden: %s -> %s",
                        name, previous.__name__, cls.__name__,
                    )
            _NPU_GRAPH_OP_HANDLERS[name] = cls
        return cls

    return decorator