"""NPU Graph Operator Handler -- base class, global registry, and utilities.
This module contains the core building blocks of the NPU Graph operator
handler framework:
- :class:`NpuGraphOpHandler` -- abstract base class for handlers.
- :data:`_NPU_GRAPH_OP_HANDLERS` -- global registry (dict).
- :func:`register_npu_graph_handler` -- class-decorator for registration.
Design Contracts
----------------
1. **Stateless Handler** -- All hook methods are ``@classmethod``; the
registry stores *class objects*, never instances.
2. **Function Replacement Consistency** -- When ``prepare_capture``
substitutes ``func`` with ``actual_func``, the handler registered for
``actual_func.__name__`` **must** implement a compatible
``update_args`` (typically guaranteed by a shared intermediate base).
See Also
--------
``torch_npu/npu/graphs.py`` for the template-method skeleton that
consumes this registry.
"""
import logging
from copy import deepcopy
import torch
from torch_npu._C import _weak_ref_tensor as TensorWeakRef
logger = logging.getLogger(__name__)
class NpuGraphOpHandler:
r"""Base class for NPU Graph operator handlers.
Subclasses override ``@classmethod`` hooks to customize capture / update
behavior for specific operators, while the framework keeps
stream / event / task-group orchestration in a common template.
**Stateless by design** -- All hook methods are ``@classmethod`` (first
parameter is ``cls``, not ``self``). There is no instance; the global
registry stores **class objects** directly. This structurally prevents
storing mutable per-invocation state. Class-level constants (e.g.
``_OP_ARG_SPECS``) are accessible via ``cls``.
Users should inherit this class and override the needed hooks:
.. code-block:: python
@register_npu_graph_handler(["my_op", "my_op.default"])
class MyHandler(NpuGraphOpHandler):
@classmethod
def update_args(cls, record, update_input):
if "batch" in update_input and len(record.args) >= 3:
record.args[2] = update_input["batch"]
"""
@classmethod
def prepare_capture(cls, func, args, kwargs):
r"""Prepare operator call before graph-task recording.
This hook runs **before** ``graph_task_group_begin`` and can be used
for operator-specific preprocessing such as workspace allocation,
output pre-allocation, or switching from ``.default`` to ``.out``
overloads.
.. note:: **Function Replacement Contract**
If ``actual_func`` differs from ``func``, ensure the handler
registered for ``actual_func.__name__`` implements a compatible
``update_args``. The recommended approach is to share a common
base class that defines ``update_args`` (see ``_IFAv1Base`` /
``_IFAv2Base``).
Args:
func (OpOverload): Original operator callable.
args (tuple): Original arguments.
kwargs (dict): Original keyword arguments.
Returns:
tuple[Callable, tuple, dict]: ``(actual_func, args, kwargs)`` to
execute during recording.
"""
return func, args, kwargs
@classmethod
def postprocess_result(cls, result, kwargs):
r"""Post-process operator return value after recording.
Called after ``graph_task_group_end`` and dispatch-record creation.
Args:
result: Raw return value from ``actual_func(*args, **kwargs)``.
kwargs (dict): Current keyword arguments (may contain ``"out"``).
Returns:
Final value returned to the Python caller.
"""
return result
@classmethod
def update_args(cls, dispatch_record, update_input):
r"""Apply operator-specific indexed-arg updates.
Framework-level kwargs updates are handled by the dispatch skeleton.
Override this hook only when update values must be applied to
arguments by index.
Args:
dispatch_record (_GraphDispatchRecord): Recorded operator call.
Args can be modified via ``dispatch_record.args[i]``.
update_input (dict): User-provided update payload.
"""
pass
@classmethod
def record_wrap_kwarg(cls, key, value, tensor_param_names):
r"""Convert a kwarg value into record-time storage representation.
Called only during the **capture** phase (creating the dispatch
record). The purpose of ``TensorWeakRef`` conversion is to avoid the
Python-side record holding strong references to NPU tensors, letting
the C++ graph runtime manage tensor memory lifetimes.
.. note::
The **update** phase uses direct assignment for kwargs
(``record.kwargs[key] = update_input[key]``), consistent with the
original implementation and with how ``update_args`` handles
arguments. Update is a short-lived "assign -> replay"
flow where weak-ref conversion is unnecessary.
Logic (consistent with original, with list/tuple generalisation):
- ``None`` -> ``None`` (fast path)
- ``list`` / ``tuple`` -> element-wise: NPU Tensor -> ``TensorWeakRef``,
else -> ``deepcopy`` (replaces old hardcoded ``if key == "out"``
that assumed exactly 2 Tensors). Only NPU tensors are wrapped;
CPU tensors use ``deepcopy`` because ``TensorWeakRef`` from
torch_npu._C is only valid for NPU tensors.
- Single value where ``key`` in ``tensor_param_names`` and value is an
NPU Tensor -> ``TensorWeakRef``
- Everything else -> ``deepcopy``
Args:
key (str): Kwarg name.
value: Kwarg value.
tensor_param_names (list[str]): Tensor-typed kwarg names parsed
from operator schema.
Returns:
Stored value for the dispatch record.
"""
if value is None:
return None
def _is_npu_tensor(t):
return torch.is_tensor(t) and "npu" in str(t.device)
if isinstance(value, (list, tuple)):
wrapped = [
TensorWeakRef(t) if _is_npu_tensor(t) else deepcopy(t)
for t in value
]
return type(value)(wrapped)
if key in tensor_param_names and _is_npu_tensor(value):
return TensorWeakRef(value)
return deepcopy(value)
_NPU_GRAPH_OP_HANDLERS = {}
def register_npu_graph_handler(op_names):
r"""Register an operator handler via class decorator.
The decorated class itself (not an instance) is stored in the global
registry. All hook methods must be ``@classmethod``.
Args:
op_names (str or list[str]): Operator names resolved from
``func.__name__`` in ``__torch_dispatch__`` (for example,
``"my_op"``, ``"my_op.default"``, ``"my_op.out"``).
Returns:
A class decorator.
Example::
@register_npu_graph_handler(["my_op", "my_op.default"])
class MyHandler(NpuGraphOpHandler):
...
"""
def decorator(cls):
names = op_names if isinstance(op_names, (list, tuple)) else [op_names]
for name in names:
if name in _NPU_GRAPH_OP_HANDLERS:
existing = _NPU_GRAPH_OP_HANDLERS[name].__name__
logger.warning(
f"NpuGraphOpHandler for '{name}' is being overridden: "
f"{existing} -> {cls.__name__}"
)
_NPU_GRAPH_OP_HANDLERS[name] = cls
return cls
return decorator