import itertools
import operator
from typing import Any
import sympy
import torch
from torch._dynamo.utils import defake
from torch._inductor import config, graph as inductor_graph, metrics
from torch._inductor.utils import clone_preserve_strides
from torch._inductor.virtualized import NullHandler, V
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
from torch.fx.node import Node
from torch.utils._sympy.numbers import int_oo
LazyString = inductor_graph.LazyString
OrderedSet = inductor_graph.OrderedSet
Pointwise = inductor_graph.Pointwise
Reduction = inductor_graph.Reduction
StorageBox = inductor_graph.StorageBox
TensorBox = inductor_graph.TensorBox
constrain_to_fake_tensors = inductor_graph.constrain_to_fake_tensors
constrain_to_fx_strides = inductor_graph.constrain_to_fx_strides
fallback_handler = inductor_graph.fallback_handler
fallback_node_due_to_unsupported_type = (
inductor_graph.fallback_node_due_to_unsupported_type
)
gather_origins = inductor_graph.gather_origins
ir = inductor_graph.ir
is_magic_method = inductor_graph.is_magic_method
log = inductor_graph.log
make_channels_last_strides_for = inductor_graph.make_channels_last_strides_for
needs_realized_inputs = inductor_graph.needs_realized_inputs
resolve_unbacked_bindings = inductor_graph.resolve_unbacked_bindings
GraphLowering = inductor_graph.GraphLowering
def patch_codegen_with_cpp_wrapper():
"""
patch codegen for cpp wrapper, add npu for codegen_with_cpp_wrapper function
"""
def npu_codegen_with_cpp_wrapper(self) -> tuple[str, list[tuple[int, Node]]]:
# add "npu" support
if any(device in self.device_types for device in ["cuda", "xpu", "npu"]):
if config.triton.autotune_at_compile_time:
# If autotune_at_compile_time is True, we can do the codegen in one-pass
return self.codegen()
else:
# first pass
self.cpp_wrapper = False
compiled = self.compile_to_module().call
def materialize(
x: torch.SymInt | torch.SymFloat | torch.Tensor,
) -> int | float | torch.Tensor:
if x is None:
return None
elif isinstance(x, (torch.SymInt, torch.SymFloat)):
# Need concrete value to run dynamic shapes and tune the result
return x.node.hint
elif isinstance(x, FakeTensor):
return defake(x)
else:
if not isinstance(x, torch.Tensor):
raise AssertionError(
"Unknown type when creating real inputs" + str(type(x))
)
return x
tracing_context = torch._guards.TracingContext.try_get()
if tracing_context is not None and not isinstance(
V.real_inputs, NullHandler
):
if tracing_context.output_strides:
tracing_context.output_strides.clear()
params_flat = [
param
for param in tracing_context.params_flat # type: ignore[union-attr]
if param is not None
]
real_inputs = [
materialize(x)
for x in itertools.chain(params_flat, V.real_inputs)
]
else:
# In the backward pass, V.real_inputs is not OrderedSet.
# Generating random inputs based on self.example_inputs sometimes can be problematic,
# e.g. illegal memory access. A comprehensive fix is to autotune in a separate process.
real_inputs = [
materialize(x) # type:ignore[arg-type]
for x in (
self.example_inputs # type:ignore[union-attr]
if isinstance(V.real_inputs, NullHandler)
else V.real_inputs
)
]
if self.mutated_inputs:
mutated_input_idxs = [
idx
for idx, name in enumerate(self.graph_inputs)
if name in self.mutated_inputs
and isinstance(real_inputs[idx], torch.Tensor)
]
for idx in mutated_input_idxs:
# clone mutated Tensor inputs to avoid mutating them in
# the first pass of the CPP wrapper-based compilation, as
# this will lead to a side effect on the example inputs:
# e.g. if torch.compile(f)(x) if called on input-mutating
# f, the inputs x will be mutated twice in the process:
# once here, and again when running the compiled model;
# this will also lead to a numerically incorrect output
mutated_inp = real_inputs[idx]
if not isinstance(mutated_inp, torch.Tensor):
raise AssertionError
real_inputs[idx] = clone_preserve_strides(mutated_inp)
del mutated_inp
with torch.utils._python_dispatch._disable_current_modes():
compiled(real_inputs)
del real_inputs
# second pass
self.cpp_wrapper = True
self.removed_buffers.clear()
self.removed_operations.clear()
self.inplaced_to_remove.clear()
V.graph.sizevars.precomputed_replacements.clear()
V.graph.sizevars.inv_precomputed_replacements.clear()
metrics.reset()
with config.patch({"triton.autotune_at_compile_time": False}):
return self.codegen()
else:
# cpu
return self.codegen()
from torch._inductor.graph import GraphLowering
GraphLowering.codegen_with_cpp_wrapper = npu_codegen_with_cpp_wrapper
def patch_count_bytes():
def count_bytes(self):
total_bytes = 0
node_counts = []
node_runtimes = []
for node in self.scheduler.nodes:
try:
num_bytes = node.get_read_write_buffers_sizes()
except AssertionError:
num_bytes = 0
total_bytes += num_bytes
node_counts.append((node, num_bytes // 4))
node_runtimes.append((node, node.get_estimated_runtime()))
return total_bytes, node_counts, node_runtimes
torch._inductor.graph.GraphLowering.count_bytes = count_bytes
def patch_run_node():
"""
patch run node
"""
def run_node_npu(self, n: torch.fx.Node) -> object:
def debug(msg: str) -> None:
log.debug("lowering %s %s", LazyString(n.format_node), msg)
from torch._inductor.compiler_bisector import CompilerBisector
buffer_watermark = len(self.buffers)
operation_watermark = len(self.operations)
# origins: OrderedSet[Union[Node, ir.IRNode]] = OrderedSet([n])
origins: OrderedSet[Any] = OrderedSet([n])
is_call_function = n.op == "call_function"
if is_call_function:
args, kwargs = self.fetch_args_kwargs_from_env(n)
origins |= gather_origins(args, kwargs)
with (
ir.IRNode.current_origins(origins),
self.set_current_node(n),
V.set_current_node(n),
):
if (
n.op == "call_function"
and n.target is not operator.getitem
and (
fallback_node_due_to_unsupported_type(n)
or CompilerBisector.disable_subsystem(
"inductor", "lowerings", lambda: repr(n)
)
)
):
debug("fallback_handler")
result = fallback_handler(n.target, add_to_fallback_set=False)(
*args, # type: ignore[possibly-undefined]
**kwargs, # type: ignore[possibly-undefined]
)
elif (
n.op == "call_function"
and n.target is torch.ops.higher_order.triton_kernel_wrapper_mutation
and config.triton_kernel_default_layout_constraint != "flexible_layout"
):
debug("user_defined_triton_kernel_layout_constraints")
if (
config.triton_kernel_default_layout_constraint
== "needs_fixed_stride_order"
):
old_args = args # type: ignore[possibly-undefined]
old_kwargs = kwargs # type: ignore[possibly-undefined]
if arg_kwarg_vals := n.meta.get("arg_kwarg_vals"):
inp_args = arg_kwarg_vals[0]
inp_kwargs = arg_kwarg_vals[1]
args, kwargs = constrain_to_fake_tensors(
args, kwargs, inp_args, inp_kwargs
)
else:
args, kwargs = constrain_to_fx_strides(n, *args, **kwargs) # type: ignore[index]
result = self.call_function(n.target, args, kwargs) # type: ignore[arg-type]
self.propagate_mutation(n, old_args, old_kwargs, args, kwargs) # type: ignore[possibly-undefined]
else:
raise RuntimeError(
f"Unknown triton_kernel_default_layout_constraint: {config.triton_kernel_default_layout_constraint}"
)
elif is_magic_method(n.target):
# TODO: this is sus, it probably should be handled in the
# lowerings themselves similarly to sym_size/sym-stride
# https://github.com/pytorch/pytorch/issues/127789
debug("is_magic_method")
if isinstance(
n.meta["val"], (torch.SymInt, torch.SymFloat, torch.SymBool)
):
result = n.meta["val"].node.expr
else:
result = super(GraphLowering, self).run_node(n)
else:
debug("")
result = super(GraphLowering, self).run_node(n)
# require the same stride order for dense outputs,
# 1. user-land view() will not throw because inductor
# output different strides than eager
# long term the solution is to make view() always succeed
# with infallible strides.
# 2: as_strided ops, we need make sure its input has same size/stride with
# eager model to align with eager behavior.
as_strided_ops = [
torch.ops.aten.as_strided.default,
torch.ops.aten.as_strided_.default,
torch.ops.aten.as_strided_scatter.default,
torch.ops.aten.resize.default,
torch.ops.aten.resize_as.default,
]
is_output = any(user.op == "output" for user in n.users)
is_user_visible = n in self.user_visible_output_strides
is_input_for_as_strided = any(
user.target in as_strided_ops for user in n.users
)
if n.meta.get("inductor_realize_to_strides", False) and isinstance(
result, TensorBox
):
result.realize()
strides = n.meta["val"].stride()
sym_strides = torch._inductor.utils.any_is_symbolic(*strides)
if result.maybe_get_stride() != strides and not sym_strides:
stride_order = ir.get_stride_order(strides)
result = ir.ExternKernel.require_stride_order(result, stride_order)
if (
is_output
and isinstance(result, TensorBox)
and isinstance(result.data, ir.BaseView)
):
# Realize so that outputs are correctly aliased
result.realize()
if (is_output or is_input_for_as_strided) and isinstance(
n.meta["val"], torch.Tensor
):
if is_user_visible:
strides = self.user_visible_output_strides.get(n)
else:
strides = n.meta["val"].stride()
if strides is not None and len(strides) > 0:
allow_padding = (
config.pad_outputs or not is_user_visible
) and not is_input_for_as_strided
dense = torch._prims_common.is_non_overlapping_and_dense(
n.meta["val"]
)
unbacked_symbols_in_strides = (
len(free_unbacked_symbols(strides)) > 0
)
if (
not unbacked_symbols_in_strides
and dense
and len(result.get_size()) == 4
and n in self.nodes_prefer_channels_last
and not is_user_visible
and not is_input_for_as_strided
):
strides = ir.FlexibleLayout.stride_ordered_for_memory_format(
result.get_size(), torch.channels_last
)
if not unbacked_symbols_in_strides and len(strides):
# To avoid converting possible view ops to a copy kernel, we use the previous
# require_exact_strides to handle views. But ultimately it's better to require
# the right strides at the tensor definition.
if n.meta["val"]._is_view() or isinstance(
result.data, ir.BaseView
):
result = ir.ExternKernel.require_stride_order(
result,
ir.get_stride_order(strides),
allow_padding=allow_padding,
)
else:
strides = [
s.node.expr if isinstance(s, torch.SymInt) else s
for s in strides
]
result = ir.ExternKernel.require_exact_strides(
result, strides, allow_padding=allow_padding
)
# Realize if (1) any user need inputs realized, or (2) there is
# already too many reads and rematerializing can be bad.
num_users = len(OrderedSet(n.users))
if num_users > 1 and isinstance(result, TensorBox):
for user in n.users:
if user.target in needs_realized_inputs:
result.realize_hint()
# This inclusion is somewhat controversial (from
# discussion between Horace, Natalia, and Elias).
# Currently, it's not very clear why this is helpful.
# The general idea here is that even though a node may
# have FlexibleLayout, we still often *treat* it as if
# it was contiguous. This appears to sometimes result in
# suboptimal behavior.
#
# When we do a better job selecting layout, we should
# revisit this.
need_fixed_layout = [
torch.ops.aten.convolution_backward.default,
torch.ops.aten.mm.default,
torch.ops.aten._int_mm.default,
]
need_fixed_channels_last_layout = []
if not self.layout_opt:
need_fixed_layout.append(torch.ops.aten.convolution.default)
if torch._C._has_mkldnn:
need_fixed_layout += [
torch.ops.mkldnn._linear_pointwise.default,
torch.ops.mkldnn._linear_pointwise.binary,
torch.ops.aten.mkldnn_rnn_layer.default,
torch.ops.onednn.qlinear_pointwise.default,
torch.ops.onednn.qlinear_pointwise.tensor,
torch.ops.onednn.qlinear_pointwise.binary,
torch.ops.onednn.qlinear_pointwise.binary_tensor,
]
need_fixed_channels_last_layout += [
torch.ops.mkldnn._convolution_pointwise.default,
torch.ops.mkldnn._convolution_pointwise.binary,
torch.ops.mkldnn._convolution_pointwise_.binary,
torch.ops.mkldnn._convolution_transpose_pointwise.default,
torch.ops.onednn.qconv2d_pointwise.default,
torch.ops.onednn.qconv2d_pointwise.binary,
]
if torch._C.has_mkl:
need_fixed_layout += [torch.ops.mkl._mkl_linear.default]
if user.target in need_fixed_layout:
result = ir.ExternKernel.require_stride_order(
result,
ir.get_stride_order(n.meta["val"].stride()),
allow_padding=True,
)
if (
user.target in need_fixed_channels_last_layout
and n is user.args[0]
):
result = ir.ExternKernel.require_stride_order(
result,
ir.get_stride_order(
make_channels_last_strides_for(n.meta["val"].shape)
),
)
if user.op == "output":
if isinstance(result.data.data, (Pointwise, Reduction)):
result.realize()
# TODO(jansel): introduce a store vs inline choice
result.mark_reuse(len(n.users))
# Realize if the IRNode already has accumulated lots of reads
if isinstance(result, TensorBox) and result.has_exceeded_max_reads():
# Prevent excessive accumulation in a computed buffer, when
# there are multiple branches each with small number of memory
# reads, but they converge to a user.
result.realize_hint()
# Realize if a Pointwise has too much stuff to be inlined.
# As this may cause RecursionError during Inductor's evaluation.
if isinstance(result, TensorBox) and isinstance(result.data, StorageBox):
curr = result.data.data
if isinstance(curr, Pointwise):
# Use inner fn as a rough proxy. Good enough.
if curr.has_large_inner_fn(threshold=100):
result.realize()
from .config import lowering_axis_count
if lowering_axis_count and len(curr.ranges) >= lowering_axis_count:
result.realize()
# This is not complete, but it doesn't have to be: origin_node
# tracking is best effort. The logic here critically relies on direct
# TensorBox -> StorageBox denoting a non-view; we don't bother trying
# to get views to work. Feel free to add any extra cases as needed.
#
# Note: we can't YOLO tree_map over this result, because if there are
# buffers or a view involved, we might not be able to validly assign
# the origin_node here.
if isinstance(result, TensorBox) and isinstance(result.data, ir.StorageBox):
if isinstance(result.data.data, ir.Loops):
result.data.data._post_init_setattr("origin_node", n)
elif isinstance(result.data.data, ir.Buffer):
result.data.data._post_init_setattr("origin_node", n)
if isinstance(result.data.data, ir.ComputedBuffer) and isinstance(
result.data.data.data, ir.Loops
):
result.data.data.data._post_init_setattr("origin_node", n)
# Not really multi-output, can straightforwardly recurse in
elif (
isinstance(result.data.data, ir.MultiOutput)
and not result.data.data.indices
):
if isinstance(result.data.data.inputs[0], ir.Buffer):
result.data.data.inputs[0]._post_init_setattr("origin_node", n)
self.register_users_of(result)
new_unbacked_defs = OrderedSet[sympy.Symbol]()
for buf in self.buffers[buffer_watermark:]:
new_unbacked_defs |= buf.get_unbacked_symbol_defs()
for op in self.operations[operation_watermark:]:
new_unbacked_defs |= op.get_unbacked_symbol_defs()
def format_new_defs() -> str:
r = [
f"unbacked_symbol_defs={buf.get_unbacked_symbol_defs()} in:\n{buf}\n"
for buf in self.buffers[buffer_watermark:]
]
r.extend(
f"unbacked_symbol_defs={op.get_unbacked_symbol_defs()} in:\n{op}\n"
for op in self.operations[operation_watermark:]
)
return "***\n".join(r)
if n.op != "placeholder":
# Note [Backwards runtime asserts]
# Backwards poses an interesting problem for deferred runtime
# asserts. In the easy case, we may solely close over data
# dependent sized tensors, and there are no binding sites for
# unbacked SymInts. In this case, we can just drop all the
# runtime asserts on the floor: no non-placeholder bindings, no
# problem.
#
# However, it is *possible* for a fresh runtime assert to show up
# between forwards and backwards. Right now, the freezing process
# that happens when we lower forwards means that we will freeze
# runtime asserts, and then the moment the backwards lowering
# process attempts to add a new deferred runtime assert, we will
# fail. Let's say you remove that assert. Now when we get here,
# we need to make sure we actually emit these asserts (because we
# can't emit them in forwards, we already compiled it). So we
# have to do something here. But we don't want to reemit ALL
# deferred runtime asserts, we only want to emit the NEW ones.
# Therefore needing some sort of stratification in the ShapeEnv.
# This is all doable, it just hasn't been done yet.
shape_env = V.graph.sizevars.shape_env
def make_assert(expr, msg: str) -> None:
assert_op = ir.AssertScalar(expr, msg)
self.register_buffer(assert_op, set_name=True)
self.register_operation(assert_op)
for i0 in new_unbacked_defs:
ras = self.ras_by_symbol.pop(i0, [])
# NB: size-like not needed, we won't retrace
vr = shape_env.var_to_range[i0]
if not shape_env._default_unspecified_value_range().issubset(vr):
def is_convertible(s) -> bool:
if s in (int_oo, -int_oo):
return False
try:
int(s)
return True
except TypeError:
return False
if is_convertible(vr.lower):
make_assert(i0 >= vr.lower, f"{i0} >= {vr.lower}")
if is_convertible(vr.upper):
make_assert(i0 <= vr.upper, f"{i0} <= {vr.upper}")
for ra in ras:
fvs = free_unbacked_symbols(ra.expr)
missing = fvs - self.bound_unbacked_symbols
if missing:
i1 = min(missing, key=str)
self.ras_by_symbol.setdefault(i1, []).append(ra)
else:
make_assert(ra.expr, f"{ra.expr}")
self.bound_unbacked_symbols |= new_unbacked_defs
unbacked_bindings = resolve_unbacked_bindings(
V.graph.sizevars.shape_env, n.meta.get("unbacked_bindings", {})
)
# When we do lowering, it is possible we reallocate unbacked SymInts.
# So we need to line up the unbacked SymInts when performing the test
# here
#
# In principle, we could permit lowering to introduce MORE unbacked
# SymInts: as long as all the old unbacked ones are accounted for,
# it's fine for inductor to introduce extra calls to item()/unbacked()
# whatever. This actually happens in practice when an unbacked SymInt
# gets memoized away; naively, when Inductor reprocesses a kernel, it
# doesn't know that the memo still applies, and ends up allocating a
# new symbol. However, this is generally a bad thing: we may still
# end up needing to test equalities on the symbols, and a fresh
# symbol is likely to hit lots of GuardOnDataDependent errors that
# we already know facts for.
renamed_unbacked_bindings = OrderedSet(
V.fake_mode.shape_env.unbacked_renamings.get(s, s)
for s in unbacked_bindings.keys()
)
return result
torch._inductor.graph.GraphLowering.run_node = run_node_npu