import inspect
import os
import sys
import logging
from typing import Any, Optional, TYPE_CHECKING
import importlib
import torch
import torch_npu
from torch import _TorchCompileWrapper
from torch._dynamo import optimize
from torch._dynamo.utils import tensortype_to_dtype
from torch._dynamo.variables.base import VariableTracker
from torch._dynamo.variables.constant import ConstantVariable
from torch._dynamo.variables.ctx_manager import AutocastModeVariable
from torch._dynamo.variables.functions import SkipFunctionVariable
from torch._dynamo.variables.lists import TupleVariable
from torch._dynamo.variables.streams import StreamContextVariable, StreamVariable
from torch._dynamo.variables.tensor import TensorVariable
from torch._dynamo.variables.torch import (
TorchCtxManagerClassVariable,
TorchInGraphFunctionVariable,
)
from torch._dynamo.variables.user_defined import UserDefinedClassVariable
from torch_npu.dynamo import _get_global_npu_backend
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
# Module-level logger for this patch module.
log = logging.getLogger(__name__)
# Module-level flag; it is not read anywhere in this section (presumably
# consulted elsewhere — TODO confirm).
use_jit_script = False
class NPUTorchCtxManagerClassVariable(TorchCtxManagerClassVariable):
def call_function(self, tx, args, kwargs):
return NPUAutocastModeVariable.create(self.value, args, kwargs)
class NPUAutocastModeVariable(AutocastModeVariable):
@staticmethod
def create(func, args, kwargs):
bound_args = inspect.signature(func).bind(*args, **kwargs)
bound_args.apply_defaults()
target_values = []
kwargs.clear()
for key in ["device_type", "dtype", "enabled", "cache_enabled"]:
if key == "device_type" and func in [
torch_npu.npu.amp.autocast,
]:
arg = "npu" if func is torch_npu.npu.amp.autocast else "cpu"
else:
arg = bound_args.arguments[key]
if isinstance(arg, VariableTracker):
target_values.append(arg.as_python_constant())
else:
target_values.append(arg)
var = AutocastModeVariable(target_values, initial_values=None, **kwargs)
return var
def UserDefinedClassVariable__new__(cls, value, **kwargs):
    """Replacement ``__new__`` for ``UserDefinedClassVariable``.

    NPU autocast classes become ``NPUTorchCtxManagerClassVariable`` and the
    legacy ``torch_npu.npu.*Tensor`` types become
    ``TorchInGraphFunctionVariable``. Anything else is allocated through the
    saved original ``__new__`` (stored as ``__new__raw``).
    """
    autocast_classes = (
        torch.npu.amp.autocast,
        torch_npu.npu.amp.autocast,
        torch.npu.amp.autocast_mode.autocast,
        torch_npu.npu.amp.autocast_mode.autocast,
    )
    if value in autocast_classes:
        return NPUTorchCtxManagerClassVariable(value, **kwargs)

    npu = torch_npu.npu
    legacy_tensor_types = (
        npu.BoolTensor,
        npu.ByteTensor,
        npu.CharTensor,
        npu.DoubleTensor,
        npu.FloatTensor,
        npu.HalfTensor,
        npu.IntTensor,
        npu.LongTensor,
        npu.ShortTensor,
        npu.BFloat16Tensor,
    )
    if value in legacy_tensor_types:
        return TorchInGraphFunctionVariable(value, **kwargs)

    return cls.__new__raw(cls)
def SkipFunctionVariable__new__(cls, value, reason=None, **kwargs):
    """Replacement ``__new__`` for ``SkipFunctionVariable``.

    The NPU ``stream`` helpers are traced as ``TorchInGraphFunctionVariable``
    instead of being skipped (``reason`` is dropped in that case). Anything
    else is allocated through the saved original ``__new__`` (``__new__raw``).
    """
    in_graph_stream_fns = (
        torch.npu.stream,
        torch_npu.npu.stream,
        torch_npu.npu.utils.stream,
    )
    if value not in in_graph_stream_fns:
        return cls.__new__raw(cls)
    return TorchInGraphFunctionVariable(value, **kwargs)
def TensorVariable_call_method(self, tx, name, args, kwargs):
if (
name == "type"
and self.dtype is not None
and len(args) == 0
and isinstance(self.device, torch.device)
and self.device.type == "npu"
):
tensortype = next(k for k, v in tensortype_to_dtype.items() if self.dtype in v)
constant_result = ConstantVariable.create(f"torch.npu.{tensortype.__name__}")
if len(args) == 1:
return constant_result.getitem_const(args[0])
elif args:
return TupleVariable([constant_result.getitem_const(a) for a in args])
return constant_result
else:
return TensorVariable.call_method_raw(self, tx, name, args, kwargs)
class _InductorNpuRegistry:
    """Class-level bookkeeping for loading the inductor NPU backend.

    The wanted backend name is read from ``TORCHINDUCTOR_NPU_BACKEND``
    (default ``"default"``); loading happens only when that name differs from
    the one recorded by the last load.
    """

    _disabled_register = False  # when True, register_inductor_npu() does nothing
    _loaded_backend = None  # backend name recorded after the last load, or None

    @classmethod
    def register_inductor_npu(cls):
        """Import ``torch_npu._inductor`` or call its ``_load_backend()`` if needed."""
        if cls._disabled_register:
            return
        wanted = os.getenv("TORCHINDUCTOR_NPU_BACKEND", "default")
        if cls._loaded_backend == wanted:
            return
        module_name = "torch_npu._inductor"
        if module_name in sys.modules:
            sys.modules[module_name]._load_backend()
        else:
            importlib.import_module(module_name)
        cls._loaded_backend = wanted

    @classmethod
    def disable_register(cls):
        """Turn later ``register_inductor_npu`` calls into no-ops."""
        cls._disabled_register = True

    @classmethod
    def enable_register(cls):
        """Re-allow ``register_inductor_npu`` to load backends."""
        cls._disabled_register = False

    @classmethod
    def has_initialized(cls):
        """Return True once any backend name has been recorded as loaded."""
        return cls._loaded_backend is not None
def is_inductor_npu_initialized():
    """Report whether the inductor NPU registry has recorded a loaded backend."""
    registry = _InductorNpuRegistry
    return registry.has_initialized()
def disable_register_inductor_npu():
    """Make subsequent inductor NPU registrations no-ops."""
    registry = _InductorNpuRegistry
    registry.disable_register()
def enable_register_inductor_npu():
    """Allow inductor NPU registrations to load backends again."""
    registry = _InductorNpuRegistry
    registry.enable_register()
def register_inductor_npu():
    """Load the inductor NPU backend named by ``TORCHINDUCTOR_NPU_BACKEND`` if needed."""
    registry = _InductorNpuRegistry
    registry.register_inductor_npu()
def _resolve_npu_backend_from_wrapper(wrapper) -> str:
    """Pick the npu backend name for a compile wrapper.

    Priority: ``wrapper.config["npu_backend"]``, then
    ``torch._inductor.config.npu_backend``, then the
    ``TORCHINDUCTOR_NPU_BACKEND`` env var (default ``"default"``). ``None``,
    ``""`` and ``"default"`` count as "not set" for the first two sources.
    """
    unset = (None, "", "default")
    chosen = wrapper.config.get("npu_backend")
    if chosen in unset:
        # Only consult the global inductor config when the wrapper has no choice.
        chosen = getattr(torch._inductor.config, "npu_backend", None)
        if chosen in unset:
            chosen = os.getenv("TORCHINDUCTOR_NPU_BACKEND", "default")
    return chosen
class _NpuBackendScope:
    """Apply a resolved npu backend for one compile invocation and restore env.

    On entry ``TORCHINDUCTOR_NPU_BACKEND`` is set to ``backend`` and
    ``register_inductor_npu()`` is called; on exit the variable is restored to
    its previous value, or removed if it was previously unset. Exceptions from
    the body are never suppressed.
    """

    def __init__(self, backend: str):
        self.backend = backend
        self._old_env = None  # env value seen on __enter__ (None = unset)

    def __enter__(self):
        self._old_env = os.environ.get("TORCHINDUCTOR_NPU_BACKEND")
        os.environ["TORCHINDUCTOR_NPU_BACKEND"] = self.backend
        try:
            register_inductor_npu()
        except BaseException:
            # __exit__ is not invoked when __enter__ raises, so undo the env
            # change here rather than leaking the backend into the process.
            self._restore_env()
            raise
        return self

    def __exit__(self, exc_type, exc, tb):
        self._restore_env()
        return False

    def _restore_env(self):
        """Put ``TORCHINDUCTOR_NPU_BACKEND`` back to the value saved on entry."""
        if self._old_env is None:
            os.environ.pop("TORCHINDUCTOR_NPU_BACKEND", None)
        else:
            os.environ["TORCHINDUCTOR_NPU_BACKEND"] = self._old_env
def patch_inductor_wrapper():
    """Monkey-patch inductor's compile wrapper and config module for npu backends.

    - ``ConfigModule.get_config_copy``: when called on ``torch._inductor.config``
      and no ``npu_backend`` key exists yet, adds ``npu_backend="default"`` to
      the returned copy and registers a matching ``str`` config entry.
    - ``_TorchCompileInductorWrapper.__init__``: after the original init,
      resolves the backend; for ``"mlir"`` / ``"dvm"`` it enters an
      ``_NpuBackendScope`` immediately (``"mlir"`` also recovers all NPU
      streams on the current device).
    - ``_TorchCompileInductorWrapper.__call__``: runs each compile inside an
      ``_NpuBackendScope`` for the resolved backend.

    Finally calls ``get_config_copy()`` once so the config entry exists now.
    """
    from typing import Any

    from torch import _TorchCompileInductorWrapper
    from torch.utils._config_module import Config, ConfigModule
    from torch_npu._compat.utils import make_config_entry

    original_init = _TorchCompileInductorWrapper.__init__
    original_call = _TorchCompileInductorWrapper.__call__
    original_get_config_copy = ConfigModule.get_config_copy

    def new_get_config_copy(self) -> dict[str, Any]:
        config_copy = original_get_config_copy(self)
        if self is torch._inductor.config and "npu_backend" not in config_copy:
            config_copy["npu_backend"] = "default"
            # Lazily register the option on the inductor config module itself.
            self._config["npu_backend"] = make_config_entry(
                Config(default="default", value_type=str),
                name="npu_backend",
            )
        return config_copy

    def new_init(self, mode, options, dynamic, name=None):
        # Only forward ``name`` when given, keeping the original positional call shape.
        init_args = (mode, options, dynamic) if name is None else (mode, options, dynamic, name)
        original_init(self, *init_args)
        backend = _resolve_npu_backend_from_wrapper(self)
        if backend == "mlir":
            with _NpuBackendScope(backend):
                log.info("Running MLIR backend")
                device_id = torch_npu.npu.current_device()
                torch_npu._C._recovery_all_npu_stream(device_id)
        elif backend == "dvm":
            with _NpuBackendScope(backend):
                log.info("Running dvm backend")

    def new_call(self, model_, inputs_):
        with _NpuBackendScope(_resolve_npu_backend_from_wrapper(self)):
            return original_call(self, model_, inputs_)

    ConfigModule.get_config_copy = new_get_config_copy
    _TorchCompileInductorWrapper.__init__ = new_init
    _TorchCompileInductorWrapper.__call__ = new_call
    torch._inductor.config.get_config_copy()
def patch_dynamo_optimize():
    """Wrap ``torch._dynamo.optimize`` so the "npu" backend is prepared first.

    The backend is taken from the ``backend`` kwarg, or from the sole
    positional argument. A string is used as-is and a ``_TorchCompileWrapper``
    contributes its ``compiler_name``. When that name is ``"npu"``,
    ``_get_global_npu_backend`` is called before delegating to the original
    ``optimize``.
    """
    original_optimize = optimize

    def npu_optimize(*args, **kwargs):
        if "backend" in kwargs:
            backend = kwargs["backend"]
        else:
            backend = args[0] if len(args) == 1 else None

        if isinstance(backend, str):
            backend_name = backend
        elif isinstance(backend, _TorchCompileWrapper):
            backend_name = backend.compiler_name
        else:
            backend_name = None

        if backend_name == "npu":
            _get_global_npu_backend(backend_name)
        return original_optimize(*args, **kwargs)

    torch._dynamo.optimize = npu_optimize
def patch_builtin_variable():
    """Teach Dynamo's builtin ``id()`` handler about stream events.

    ``id(event)`` with exactly one ``EventVariable`` argument is folded to a
    constant holding ``id()`` of the wrapped event object. Every other call,
    including zero-argument or multi-argument ``id(...)``, goes to the
    original ``BuiltinVariable.call_id``.
    """
    origin_call_id = torch._dynamo.variables.builtin.BuiltinVariable.call_id

    def _wrap_call_id(self, tx, *args):
        # Check the arity first: ``args[0]`` on an empty tuple would raise
        # IndexError here instead of letting the original handler report it.
        if len(args) == 1 and torch._dynamo.variables.builtin.istype(
            args[0], torch._dynamo.variables.streams.EventVariable
        ):
            return torch._dynamo.variables.ConstantVariable.create(id(args[0].value))
        return origin_call_id(self, tx, *args)

    torch._dynamo.variables.builtin.BuiltinVariable.call_id = _wrap_call_id
def patch_stream_event_variable_python_type():
    """Make Dynamo's stream/event variables report their real runtime type.

    The generic ``StreamVariable`` / ``EventVariable`` report ``torch.Stream``
    and ``torch.Event``. NPU subclasses have Python methods that call
    ``super(NpuType, self)``, so when Dynamo traces or inlines those methods
    (for example under profiler wrappers) it must see the actual subclass.
    After this patch, ``python_type()`` returns ``type(self.value)``.
    """
    streams_module = torch._dynamo.variables.streams

    def python_type(self):
        return type(self.value)

    for variable_cls in (streams_module.StreamVariable, streams_module.EventVariable):
        variable_cls.python_type = python_type
class NpuStreamContextVariable(StreamContextVariable):
    """NPU stream context that also records stream switches in the FX graph.

    ``enter`` emits ``device_interface.set_stream(<target stream>)`` and
    ``exit`` emits ``device_interface.set_stream(<stream captured in create>)``
    before delegating to ``StreamContextVariable``.
    """

    @staticmethod
    def create(
        tx: "InstructionTranslator",
        stream_to_enter: "StreamVariable",
        **kwargs: dict[str, Any],
    ) -> "NpuStreamContextVariable":
        """Build the context, capturing the current stream as a graph node.

        Args:
            tx: The active instruction translator.
            stream_to_enter: Stream the context switches to.
            **kwargs: Forwarded to the constructor.
        """
        from torch._dynamo.device_interface import get_interface_for_device
        from torch._dynamo.variables.builder import wrap_fx_proxy_cls

        interface = get_interface_for_device(stream_to_enter.device)
        # Emit a current_stream() call at creation time so exit() can switch back to it.
        current_stream_proxy = tx.output.create_proxy(
            "call_function",
            interface.current_stream,
            (None,),
            {},
        )
        previous_stream = wrap_fx_proxy_cls(StreamVariable, tx, current_stream_proxy)
        return NpuStreamContextVariable(
            stream_to_enter,
            current_stream=previous_stream,
            device_interface=interface,
            **kwargs,
        )

    def __init__(
        self,
        stream: Optional["StreamVariable"],
        current_stream: Optional["StreamVariable"] = None,
        device_interface: Any | None = None,
        **kwargs: Any,
    ) -> None:
        # Both attributes are set before the base initializer runs.
        self.current_stream = current_stream
        self.device_interface = device_interface
        super().__init__(stream, **kwargs)

    def _emit_set_stream(self, tx: "InstructionTranslator", stream_var: "StreamVariable") -> None:
        """Record ``device_interface.set_stream(stream_var)`` in the FX graph."""
        tx.output.create_proxy(
            "call_function",
            self.device_interface.set_stream,
            (stream_var.as_proxy(),),
            {},
        )

    def enter(
        self, tx: "InstructionTranslator", *args: VariableTracker
    ) -> VariableTracker:
        if self.get_stream():
            self._emit_set_stream(tx, self.get_stream())
        # Extra positional args are not forwarded to the base enter().
        return super().enter(tx)

    def exit(
        self, tx: "InstructionTranslator", *args: VariableTracker
    ) -> VariableTracker:
        if self.get_stream():
            self._emit_set_stream(tx, self.current_stream)
        return super().exit(tx, *args)
def patch_npu_stream_context():
    """Register an in-graph handler for the NPU device interface's ``stream``.

    Calls to ``get_interface_for_device("npu").stream`` traced by
    ``TorchInGraphFunctionVariable`` are turned into
    ``NpuStreamContextVariable`` objects.
    """
    from torch._dynamo.device_interface import get_interface_for_device

    def _handle_npu_device_interface_stream(self, tx, stream):
        return NpuStreamContextVariable.create(tx, stream)

    handlers = TorchInGraphFunctionVariable._get_handlers()
    npu_stream_fn = get_interface_for_device("npu").stream
    handlers[npu_stream_fn] = _handle_npu_device_interface_stream
def patch_user_defined_class_variable():
    """Add ``torch.npu.Event`` and ``torch.npu.Stream`` to Dynamo's in-graph classes.

    Replaces ``UserDefinedClassVariable._in_graph_classes`` with a cached
    static function that extends the upstream result.
    """
    import functools

    upstream_in_graph_classes = UserDefinedClassVariable._in_graph_classes

    def patched_in_graph_classes():
        classes = upstream_in_graph_classes()
        # NOTE: this mutates the set object returned upstream (which may
        # itself be cached), so the additions are visible there too.
        classes.add(torch.npu.Event)
        classes.add(torch.npu.Stream)
        return classes

    UserDefinedClassVariable._in_graph_classes = staticmethod(
        functools.lru_cache(None)(patched_in_graph_classes)
    )
def add_dynamo_methods():
    """Install all NPU-specific Dynamo/Inductor monkey patches.

    Saves the original ``UserDefinedClassVariable.__new__``,
    ``SkipFunctionVariable.__new__`` and ``TensorVariable.call_method`` under
    ``*raw`` names (the NPU replacements delegate to those names), installs
    the replacements, then applies each ``patch_*`` helper exactly once.

    Calling this again after a successful run is a no-op. Without that guard,
    a second run would store the already-patched methods as the "raw"
    originals, so the replacements would delegate to themselves.
    """
    if getattr(add_dynamo_methods, "_applied", False):
        return
    UserDefinedClassVariable.__new__raw = UserDefinedClassVariable.__new__
    UserDefinedClassVariable.__new__ = UserDefinedClassVariable__new__
    SkipFunctionVariable.__new__raw = SkipFunctionVariable.__new__
    SkipFunctionVariable.__new__ = SkipFunctionVariable__new__
    TensorVariable.call_method_raw = TensorVariable.call_method
    TensorVariable.call_method = TensorVariable_call_method
    patch_dynamo_optimize()
    patch_inductor_wrapper()
    # Applied once: a second call would only wrap the cached function again.
    patch_user_defined_class_variable()
    patch_stream_event_variable_python_type()
    patch_builtin_variable()
    patch_npu_stream_context()
    add_dynamo_methods._applied = True