import importlib
import inspect
import logging
import os
import sys
import torch
import torch_npu
from torch import _TorchCompileWrapper
from torch._dynamo import optimize
from torch._dynamo.utils import tensortype_to_dtype
from torch._dynamo.variables.base import VariableTracker
from torch._dynamo.variables.constant import ConstantVariable
from torch._dynamo.variables.ctx_manager import AutocastModeVariable
from torch._dynamo.variables.functions import SkipFunctionVariable
from torch._dynamo.variables.lists import TupleVariable
from torch._dynamo.variables.tensor import TensorVariable
from torch._dynamo.variables.torch import (
TorchCtxManagerClassVariable,
TorchInGraphFunctionVariable,
)
from torch._dynamo.variables.user_defined import UserDefinedClassVariable
from torch_npu.dynamo import _get_global_npu_backend
# Module-level switch; it is not read by any code visible in this file chunk
# (presumably consulted by other torch_npu modules).
use_jit_script = False
# Logger for this module, named after the module itself.
log = logging.getLogger(name=__name__)
class NPUTorchCtxManagerClassVariable(TorchCtxManagerClassVariable):
    """Context-manager class variable used for the NPU autocast classes.

    Its ``call_function`` builds the traced autocast-mode variable through
    ``NPUAutocastModeVariable.create``.
    """

    def call_function(self, tx, args, kwargs):
        # The instruction translator ``tx`` is not needed: the result depends
        # only on the wrapped class and the traced call arguments.
        factory = NPUAutocastModeVariable.create
        return factory(self.value, args, kwargs)
class NPUAutocastModeVariable(AutocastModeVariable):
    """Factory for ``AutocastModeVariable`` objects created from NPU autocast calls."""

    @staticmethod
    def create(func, args, kwargs):
        """Build an ``AutocastModeVariable`` from a traced autocast call.

        ``func`` is the autocast class being called and ``args``/``kwargs``
        are the traced call arguments. They are bound to ``func``'s signature
        (defaults applied) and reduced to the list
        ``[device_type, dtype, enabled, cache_enabled]``. Any
        ``VariableTracker`` among them is converted via
        ``as_python_constant()``. For ``torch_npu.npu.amp.autocast`` the device
        type is not read from the arguments but fixed to ``"npu"``.

        Side effect: the caller's ``kwargs`` dict is emptied in place.
        """
        bound = inspect.signature(func).bind(*args, **kwargs)
        bound.apply_defaults()
        # Everything needed now lives in ``bound``; the caller's dict is cleared.
        kwargs.clear()

        def resolve(key):
            # torch_npu's autocast class implies its device type.
            if key == "device_type" and func in [torch_npu.npu.amp.autocast]:
                return "npu" if func is torch_npu.npu.amp.autocast else "cpu"
            value = bound.arguments[key]
            if isinstance(value, VariableTracker):
                return value.as_python_constant()
            return value

        target_values = [
            resolve(key)
            for key in ("device_type", "dtype", "enabled", "cache_enabled")
        ]
        return AutocastModeVariable(target_values, initial_values=None, **kwargs)
def UserDefinedClassVariable__new__(cls, value, **kwargs):
    """Replacement ``__new__`` installed on ``UserDefinedClassVariable``.

    - NPU autocast classes become ``NPUTorchCtxManagerClassVariable``.
    - NPU stream classes and the ``torch_npu.npu`` legacy tensor-type classes
      become ``TorchInGraphFunctionVariable``.
    - Any other class is allocated by the saved original ``__new__``
      (stored as ``__new__raw``).
    """
    autocast_classes = [
        torch.npu.amp.autocast,
        torch_npu.npu.amp.autocast,
        torch.npu.amp.autocast_mode.autocast,
        torch_npu.npu.amp.autocast_mode.autocast,
    ]
    if value in autocast_classes:
        return NPUTorchCtxManagerClassVariable(value, **kwargs)

    in_graph_classes = [
        torch.npu.Stream,
        torch_npu.npu.Stream,
        torch.npu.streams.Stream,
        torch_npu.npu.streams.Stream,
        torch_npu.npu.BoolTensor,
        torch_npu.npu.ByteTensor,
        torch_npu.npu.CharTensor,
        torch_npu.npu.DoubleTensor,
        torch_npu.npu.FloatTensor,
        torch_npu.npu.HalfTensor,
        torch_npu.npu.IntTensor,
        torch_npu.npu.LongTensor,
        torch_npu.npu.ShortTensor,
        torch_npu.npu.BFloat16Tensor,
    ]
    if value in in_graph_classes:
        return TorchInGraphFunctionVariable(value, **kwargs)

    # Not an NPU special case: defer to the original allocator.
    return cls.__new__raw(cls)
def SkipFunctionVariable__new__(cls, value, reason=None, **kwargs):
    """Replacement ``__new__`` installed on ``SkipFunctionVariable``.

    The NPU ``stream`` helper functions are turned into
    ``TorchInGraphFunctionVariable`` so they are traced in-graph. ``reason``
    is accepted for signature compatibility but not used here. All other
    values are allocated by the saved original ``__new__`` (``__new__raw``).
    """
    npu_stream_functions = (
        torch.npu.stream,
        torch_npu.npu.stream,
        torch_npu.npu.utils.stream,
    )
    if value not in npu_stream_functions:
        return cls.__new__raw(cls)
    return TorchInGraphFunctionVariable(value, **kwargs)
def TensorVariable_call_method(self, tx, name, args, kwargs):
    """Replacement ``TensorVariable.call_method`` that special-cases NPU ``type()``.

    A ``tensor.type()`` call on an NPU tensor, with no positional arguments
    and no ``dtype`` keyword, is folded into a constant string
    ``"torch.npu.<TensorType>"`` (e.g. ``"torch.npu.FloatTensor"``).
    Everything else is forwarded to the original implementation saved as
    ``TensorVariable.call_method_raw``. That includes other methods,
    ``type(dtype)`` casts, and dtypes with no entry in
    ``tensortype_to_dtype``.
    """
    if (
        name == "type"
        and self.dtype is not None
        and not args
        # ``type(dtype=...)`` is a cast, not a query; let the original handle it.
        and "dtype" not in kwargs
        and isinstance(self.device, torch.device)
        and self.device.type == "npu"
    ):
        # None when the dtype is not listed in ``tensortype_to_dtype``; fall
        # back to the original method rather than raising StopIteration.
        tensortype = next(
            (k for k, v in tensortype_to_dtype.items() if self.dtype in v), None
        )
        if tensortype is not None:
            return ConstantVariable.create(f"torch.npu.{tensortype.__name__}")
    return TensorVariable.call_method_raw(self, tx, name, args, kwargs)
class _InductorNpuRegistry:
    """Class-level state for loading the NPU inductor backend.

    The requested backend comes from the ``TORCHINDUCTOR_NPU_BACKEND``
    environment variable (``"default"`` when unset). ``torch_npu._inductor``
    is imported on first use and asked to ``_load_backend()`` again whenever
    the requested backend changes.
    """

    # True while registration is switched off via ``disable_register``.
    _disabled_register = False
    # Backend name in effect at the last load, or None if never loaded.
    _loaded_backend = None

    @classmethod
    def register_inductor_npu(cls):
        """Load or reload the NPU inductor backend if the requested one changed."""
        if cls._disabled_register:
            return
        wanted = os.getenv("TORCHINDUCTOR_NPU_BACKEND", "default")
        if cls._loaded_backend == wanted:
            return
        if "torch_npu._inductor" in sys.modules:
            sys.modules["torch_npu._inductor"]._load_backend()
        else:
            importlib.import_module("torch_npu._inductor")
        cls._loaded_backend = wanted

    @classmethod
    def disable_register(cls):
        """Make ``register_inductor_npu`` a no-op until re-enabled."""
        cls._disabled_register = True

    @classmethod
    def enable_register(cls):
        """Allow ``register_inductor_npu`` to load backends again."""
        cls._disabled_register = False

    @classmethod
    def has_initialized(cls):
        """Return True once any backend has been loaded."""
        return cls._loaded_backend is not None
def is_inductor_npu_initialized():
    """Return True once an NPU inductor backend has been loaded."""
    registry = _InductorNpuRegistry
    return registry.has_initialized()
def disable_register_inductor_npu():
    """Turn NPU inductor backend registration off (later register calls do nothing)."""
    registry = _InductorNpuRegistry
    registry.disable_register()
def enable_register_inductor_npu():
    """Turn NPU inductor backend registration back on."""
    registry = _InductorNpuRegistry
    registry.enable_register()
def register_inductor_npu():
    """Load the NPU inductor backend named by ``TORCHINDUCTOR_NPU_BACKEND`` if needed."""
    registry = _InductorNpuRegistry
    registry.register_inductor_npu()
def _resolve_npu_backend_from_wrapper(wrapper) -> str:
    """Pick the NPU backend name for a compile wrapper.

    Precedence: ``wrapper.config["npu_backend"]``, then
    ``torch._inductor.config.npu_backend``, then the
    ``TORCHINDUCTOR_NPU_BACKEND`` environment variable (default
    ``"default"``). ``None``, ``""`` and ``"default"`` count as "not set"
    for the first two sources.
    """
    unset = (None, "", "default")
    chosen = wrapper.config.get("npu_backend")
    if chosen in unset:
        chosen = getattr(torch._inductor.config, "npu_backend", None)
    if chosen in unset:
        chosen = os.getenv("TORCHINDUCTOR_NPU_BACKEND", "default")
    return chosen
class _NpuBackendScope:
    """Apply a resolved NPU backend for one compile invocation and restore the env.

    On entry ``TORCHINDUCTOR_NPU_BACKEND`` is set to ``backend`` and
    ``register_inductor_npu()`` is called. On exit the variable is restored
    to its previous value, or removed if it was previously unset. The
    backend that was loaded is not unloaded on exit.
    """

    def __init__(self, backend: str):
        self.backend = backend
        # Value of TORCHINDUCTOR_NPU_BACKEND before __enter__ (None if unset).
        self._old_env = None

    def __enter__(self):
        self._old_env = os.environ.get("TORCHINDUCTOR_NPU_BACKEND")
        os.environ["TORCHINDUCTOR_NPU_BACKEND"] = self.backend
        try:
            register_inductor_npu()
        except BaseException:
            # ``__exit__`` is not called when ``__enter__`` raises, so undo the
            # env change here instead of leaking the scoped backend.
            self._restore_env()
            raise
        return self

    def __exit__(self, exc_type, exc, tb):
        self._restore_env()
        # Never suppress exceptions raised inside the ``with`` body.
        return False

    def _restore_env(self):
        """Put TORCHINDUCTOR_NPU_BACKEND back to its pre-entry state."""
        if self._old_env is None:
            os.environ.pop("TORCHINDUCTOR_NPU_BACKEND", None)
        else:
            os.environ["TORCHINDUCTOR_NPU_BACKEND"] = self._old_env
def patch_inductor_wrapper():
    """Monkey-patch torch's inductor compile wrapper and ``ConfigModule`` for NPU.

    Replaces ``_TorchCompileInductorWrapper.__init__``, ``apply_options`` and
    ``__call__``, and ``ConfigModule.get_config_copy``, with wrappers that
    delegate to the saved originals. The wrappers:

    * register NPU-specific option entries (``npu_backend``,
      ``enable_shape_handling``, ``shape_handling_configs``,
      ``shape_handling_dict``) on ``torch._inductor.config``;
    * trigger NPU shape-handling patching when ``enable_shape_handling`` is set;
    * apply the resolved NPU backend (see ``_resolve_npu_backend_from_wrapper``)
      around each compile call via ``_NpuBackendScope``.

    Not idempotent: calling it twice wraps the already-wrapped methods again.
    """
    from typing import Any, Literal, Optional
    from torch import _TorchCompileInductorWrapper
    from torch.utils._config_module import _ConfigEntry, Config, ConfigModule

    # Keep the unpatched implementations so the wrappers below can delegate.
    src_apply_options = _TorchCompileInductorWrapper.apply_options
    src_init = _TorchCompileInductorWrapper.__init__
    src_get_config_copy = ConfigModule.get_config_copy
    src_call = _TorchCompileInductorWrapper.__call__

    def new_apply_options(self, options: Optional[dict[str, Any]]):
        # With shape handling requested, make sure the NPU inductor backend is
        # registered before calling its ``patch_shape_handling`` hook.
        if options is not None and options.get("enable_shape_handling", False):
            if not is_inductor_npu_initialized():
                register_inductor_npu()
            torch_npu._inductor.patch_shape_handling()
        src_apply_options(self, options)

    def new_get_config_copy(self) -> dict[str, Any]:
        ori_dict = src_get_config_copy(self)
        # The patch applies to every ConfigModule; only the inductor config
        # gets the NPU entries.
        if self is not torch._inductor.config:
            return ori_dict
        NpuBackendType = Literal["default", "mlir", "dvm"]
        # For each missing NPU option: put its default in the returned copy and
        # register a matching ``_ConfigEntry`` in ``self._config``.
        if "npu_backend" not in ori_dict:
            ori_dict["npu_backend"] = "default"
            self._config["npu_backend"] = _ConfigEntry(
                Config(default="default", value_type=NpuBackendType)
            )
        if "enable_shape_handling" not in ori_dict:
            ori_dict["enable_shape_handling"] = False
            self._config["enable_shape_handling"] = _ConfigEntry(
                Config(default=False, value_type=bool)
            )
        if "shape_handling_configs" not in ori_dict:
            ori_dict["shape_handling_configs"] = None
            self._config["shape_handling_configs"] = _ConfigEntry(
                Config(default=None, value_type=list)
            )
        if "shape_handling_dict" not in ori_dict:
            ori_dict["shape_handling_dict"] = None
            self._config["shape_handling_dict"] = _ConfigEntry(
                Config(default=None, value_type=dict)
            )
        return ori_dict

    def new_init(self, mode, options, dynamic):
        src_init(self, mode, options, dynamic)
        backend = _resolve_npu_backend_from_wrapper(self)
        # For the "mlir"/"dvm" backends, load the backend and recover all NPU
        # streams on the current device.
        # NOTE(review): nesting reconstructed from a flattened source; confirm
        # which statements belong inside the ``with`` scope.
        if backend=="mlir" or backend=="dvm":
            with _NpuBackendScope(backend):
                device_id = torch_npu.npu.current_device()
                torch_npu._C._recovery_all_npu_stream(device_id)

    def new_call(self, model_, inputs_):
        # Apply the resolved backend (wrapper options > global config > env)
        # only for the duration of this compile call.
        backend = _resolve_npu_backend_from_wrapper(self)
        with _NpuBackendScope(backend):
            return src_call(self, model_, inputs_)

    _TorchCompileInductorWrapper.__call__ = new_call
    _TorchCompileInductorWrapper.apply_options = new_apply_options
    _TorchCompileInductorWrapper.__init__ = new_init
    ConfigModule.get_config_copy = new_get_config_copy
    # Run once now so the NPU entries are registered on the inductor config.
    torch._inductor.config.get_config_copy()
def patch_dynamo_optimize():
src_optimize = optimize
def npu_optimize(*args, **kwargs):
backend = None
if "backend" in kwargs:
backend = kwargs["backend"]
elif len(args) == 1:
backend = args[0]
backend_name = None
if isinstance(backend, str):
backend_name = backend
elif isinstance(backend, _TorchCompileWrapper):
backend_name = backend.compiler_name
if backend_name == "npu":
_get_global_npu_backend(backend_name)
return src_optimize(*args, **kwargs)
torch._dynamo.optimize = npu_optimize
def patch_base_schedulernode():
from torch._inductor.scheduler import BaseSchedulerNode, ExternKernelSchedulerNode
original_get_read_write_buffer_accesses = (
BaseSchedulerNode.get_read_write_buffer_accesses
)
def new_get_read_write_buffer_accesses(
self_instance, include_reads: bool, include_writes: bool
) -> dict[str, int]:
if isinstance(self_instance, ExternKernelSchedulerNode):
return {}
return original_get_read_write_buffer_accesses(
self_instance, include_reads, include_writes
)
BaseSchedulerNode.get_read_write_buffer_accesses = (
new_get_read_write_buffer_accesses
)
def patch_user_defined_class_variable():
import functools
original_method = UserDefinedClassVariable._in_graph_classes
@staticmethod
@functools.lru_cache(None)
def patched_in_graph_classes():
result = original_method()
result.add(torch.npu.Event)
result.add(torch.npu.Stream)
return result
UserDefinedClassVariable._in_graph_classes = patched_in_graph_classes
def fake_record_stream(self, s):
    """Stand-in that dynamo traces in place of ``Tensor.record_stream``.

    The traced call is expected to be swapped for a real implementation by a
    pass in your compile backend. Called on a ``FakeTensor`` (during
    tracing) it does nothing. Called on any other object it raises
    ``RuntimeError``, because no real replacement was installed.
    """
    if not isinstance(self, torch._subclasses.fake_tensor.FakeTensor):
        raise RuntimeError(
            "tensor.record_stream is not supported on torch.compile! "
            "You should write a pass to replace torch.npu.fake_record_stream to an actual function in FX graph "
            "before aot_autograd."
        )
def patch_record_stream():
    """Make dynamo trace ``Tensor.record_stream`` as ``torch.npu.fake_record_stream``.

    Publishes ``fake_record_stream`` on ``torch.npu``, then installs a
    ``TensorVariable.method_record_stream`` handler. The handler emits an
    in-graph call to it with ``(tensor, stream)``.
    """
    torch.npu.fake_record_stream = fake_record_stream

    def method_record_stream(self, s):
        current = torch._dynamo.symbolic_convert.InstructionTranslator.current_tx()
        target = torch._dynamo.variables.TorchInGraphFunctionVariable(
            torch.npu.fake_record_stream
        )
        return target.call_function(current, [self, s], {})

    tensor_variable_cls = torch._dynamo.variables.tensor.TensorVariable
    tensor_variable_cls.method_record_stream = method_record_stream
def patch_variable_builder():
    """Teach dynamo's ``VariableBuilder._wrap`` to wrap ``torch.npu.Event`` objects.

    An NPU event is guarded by identity (``ID_MATCH``) and stored as a user
    object weakref. It is represented in the graph by a
    ``get_user_object_from_id(id(event))`` proxy wrapped in an
    ``EventVariable``. Any other value goes to the original ``_wrap``.
    """
    original_wrap = torch._dynamo.variables.builder.VariableBuilder._wrap

    def _patch_wrapper(self, value):
        if not isinstance(value, torch.npu.Event):
            return original_wrap(self, value)

        self.install_guards(torch._dynamo.guards.GuardBuilder.ID_MATCH)
        torch._dynamo.utils.store_user_object_weakref(value)
        proxy_args = (id(value),)
        event_proxy = self.tx.output.create_proxy(
            "call_function",
            torch._dynamo.utils.get_user_object_from_id,
            proxy_args,
            {},
        )
        torch._dynamo.utils.set_example_value(event_proxy.node, value)
        return torch._dynamo.variables.ctx_manager.EventVariable(
            event_proxy,
            value,
            source=self.source,
        )

    torch._dynamo.variables.builder.VariableBuilder._wrap = _patch_wrapper
def patch_builtin_variable():
origin_call_id = torch._dynamo.variables.builtin.BuiltinVariable.call_id
def _wrap_call_id(self, tx, *args):
if torch._dynamo.variables.builtin.istype(
args[0], torch._dynamo.variables.ctx_manager.EventVariable
):
return torch._dynamo.variables.ConstantVariable.create(id(args[0].value))
return origin_call_id(self, tx, *args)
torch._dynamo.variables.builtin.BuiltinVariable.call_id = _wrap_call_id
def patch_event_variable_python_type():
"""
Add the 'python_type' method to the EventVariable class.
"""
def python_type(self):
return type(self.value)
if "python_type" not in torch._dynamo.variables.ctx_manager.EventVariable.__dict__:
torch._dynamo.variables.ctx_manager.EventVariable.python_type = python_type
def add_dynamo_methods():
    """Install all NPU-specific dynamo/inductor patches.

    The original ``UserDefinedClassVariable.__new__``,
    ``SkipFunctionVariable.__new__`` and ``TensorVariable.call_method`` are
    saved as ``__new__raw`` / ``call_method_raw`` and replaced with the NPU
    versions. Then every ``patch_*`` helper is run.

    Repeat calls are no-ops. Running it again would save the already-patched
    methods as the "raw" ones, making the NPU overrides delegate to
    themselves, and would wrap the inductor wrapper a second time.
    """
    if TensorVariable.call_method is TensorVariable_call_method:
        return

    UserDefinedClassVariable.__new__raw = UserDefinedClassVariable.__new__
    UserDefinedClassVariable.__new__ = UserDefinedClassVariable__new__
    SkipFunctionVariable.__new__raw = SkipFunctionVariable.__new__
    SkipFunctionVariable.__new__ = SkipFunctionVariable__new__
    TensorVariable.call_method_raw = TensorVariable.call_method
    TensorVariable.call_method = TensorVariable_call_method
    patch_dynamo_optimize()
    patch_inductor_wrapper()
    patch_base_schedulernode()
    patch_user_defined_class_variable()
    patch_record_stream()
    patch_event_variable_python_type()
    patch_variable_builder()
    patch_builtin_variable()