import itertools
import torch
from torch._inductor.virtualized import ops, OpsValue, V
from torch._inductor.ir import log, Layout
from torch._inductor import config
def patch_fallback_kernel_codegen():
    """Replace ``FallbackKernel.codegen`` with the NPU variant ``codegen_npu``.

    The patch is applied by assigning to the class attribute, so it affects
    every ``FallbackKernel`` instance in the process.
    """

    def codegen_npu(self, wrapper) -> None:
        """Emit wrapper code for a fallback kernel.

        Chooses between ``wrapper.generate_fallback_kernel_with_runtime_lookup``
        and ``wrapper.generate_fallback_kernel`` based on
        ``self.use_runtime_dispatch``, which this method may update first:

        * ``aten`` kernels: under ``V.graph.cpp_wrapper`` runtime dispatch is
          always enabled; a warning is logged when the op is listed in
          ``inductor_fallback_ops``.
        * ``_quantized`` kernels: only the kernel type is validated.
        * any other namespace under ``V.graph.cpp_wrapper``: runtime dispatch
          is enabled unless the kernel is registered in
          ``config.aot_inductor.custom_ops_to_c_shims``.

        Args:
            wrapper: The wrapper code generator that receives the output.

        Raises:
            AssertionError: If an ``aten`` or ``_quantized`` kernel is not a
                ``torch._ops.OpOverload``.
        """
        kernel = self.op_overload
        if kernel.namespace == "aten":
            if not isinstance(kernel, torch._ops.OpOverload):
                raise AssertionError(f"kernel should be OpOverload, but got {type(kernel)}")
            if V.graph.cpp_wrapper:
                from torchgen.aoti.fallback_ops import inductor_fallback_ops

                # Runtime dispatch is forced for every aten fallback here,
                # whether or not the op is listed in inductor_fallback_ops.
                self.use_runtime_dispatch = True
                if str(kernel) in inductor_fallback_ops:
                    log.warning(
                        "%s is using proxy executor as fallback instead of aoti shim.",
                        kernel,
                    )
        elif kernel.namespace == "_quantized":
            if not isinstance(kernel, torch._ops.OpOverload):
                # Same diagnostic as the aten branch so the failure is actionable.
                raise AssertionError(f"kernel should be OpOverload, but got {type(kernel)}")
        elif V.graph.cpp_wrapper:
            self.use_runtime_dispatch = (
                kernel not in config.aot_inductor.custom_ops_to_c_shims
            )

        # NOTE(review): when no branch above assigned it, use_runtime_dispatch
        # is expected to already exist on the instance -- confirm it is
        # initialised by the class.
        if (
            V.graph.cpp_wrapper
            and isinstance(kernel, torch._ops.OpOverload)
            and not self.use_runtime_dispatch
        ):

            def is_number(t: torch.JitType) -> bool:
                """Return True for a NumberType, looking through OptionalType."""
                if isinstance(t, torch.OptionalType):
                    return is_number(t.getElementType())
                return isinstance(t, torch.NumberType)

            args, kwargs = self.unflatten_args(self.inputs, self.constant_args)
            # Positional values followed by kwarg values in the C++ kernel's
            # order, so they can be zipped against the schema arguments.
            args_iter = itertools.chain(
                args,
                (
                    self.get_kwargs_value(k, **kwargs)
                    for k in self.ordered_kwargs_for_cpp_kernel
                ),
            )
            # Switch to runtime dispatch when a Python complex value is bound
            # to a (possibly optional) number-typed schema argument.
            # NOTE(review): presumably the direct path cannot represent
            # complex scalars -- confirm against upstream Inductor.
            self.use_runtime_dispatch = any(
                isinstance(v, complex) and is_number(a.real_type)
                for v, a in zip(args_iter, kernel._schema.arguments)
            )

        self.codegen_comment(wrapper)
        if self.use_runtime_dispatch:
            exported_args = self.export_extern_kernel_node()
            wrapper.generate_fallback_kernel_with_runtime_lookup(
                self.get_name(),
                self.python_kernel_name,
                lambda: [*self.codegen_args(), *self.codegen_kwargs()],
                self.op_overload,
                exported_args,
                # Use mutation_outputs when outputs is empty/falsy.
                self.outputs if self.outputs else self.mutation_outputs,
            )
        else:
            wrapper.generate_fallback_kernel(self)
            if isinstance(self.layout, Layout):
                self.codegen_size_asserts(wrapper)
                self.codegen_alignment_asserts(wrapper)

        self.codegen_unbacked_symbol_defs(wrapper)

    from torch._inductor.ir import FallbackKernel

    FallbackKernel.codegen = codegen_npu
def patch_extern_kernel_codegen_size_asserts():
    """Wrap ``ExternKernel.codegen_size_asserts`` with an NPU skip list.

    The method that is on the class when this function runs is captured and
    used for every kernel whose FX target is not listed in the NPU config
    attribute ``skip_specific_stride_asserts``.
    """
    from torch._inductor.ir import ExternKernel
    from . import config as npu_config

    upstream_size_asserts = ExternKernel.codegen_size_asserts

    def npu_codegen_size_asserts(self, wrapper):
        """Emit size asserts, or a marker comment for skip-listed targets."""
        fx_node = getattr(self, 'fx_node', None)

        # A target is skipped only when the configured value is a list or
        # tuple that contains it; any other configured type disables skipping.
        skip_requested = False
        if fx_node and fx_node.target:
            configured = getattr(npu_config, 'skip_specific_stride_asserts', [])
            if isinstance(configured, (list, tuple)):
                skip_requested = fx_node.target in configured

        if not skip_requested:
            upstream_size_asserts(self, wrapper)
            return

        # Nothing is written when size asserts are disabled or the cpp
        # wrapper is in use.
        if not config.size_asserts or V.graph.cpp_wrapper:
            return

        from torch._inductor.utils import sympy_product

        # Zero-element outputs get no marker comment.
        if sympy_product(self.get_size()) == 0:
            return

        wrapper.writeline(
            f"# NPU: Skipping stride assertion for {fx_node.target}"
        )

    ExternKernel.codegen_size_asserts = npu_codegen_size_asserts