import os
import sympy
import torch._ops
from torch._inductor import ir
from torch._inductor import lowering
from torch._inductor.decomposition import decompositions, pw_cast_for_opmath
from torch._inductor.ir import ExpandView, TensorBox, ops_wrapper, StorageBox, View
from torch._inductor.ir import Reduction, Pointwise
from torch._inductor.lowering import sum_
from torch._inductor.utils import sympy_product
from torch._prims_common import (
is_boolean_dtype,
is_integer_dtype,
is_float_dtype,
get_computation_dtype,
ELEMENTWISE_TYPE_PROMOTION_KIND,
Number
)
from torch._inductor.lowering import (
lowerings,
make_fallback,
register_lowering,
to_dtype,
fallback_cumsum,
_validate_reduction_axis,
div as div_pt,
squeeze as squeeze_pt,
square as square_pt,
sub as sub_pt,
fallback_handler,
is_boolean_type,
make_pointwise,
_make_reduction_inner,
_validate_reduction_axis,
add_needs_realized_inputs,
add_layout_constraint,
require_channels_last,
_validate_dim as _validate_dim_pt,
get_promoted_dtype,
add as add_pt,
rsqrt as rsqrt_pt,
mul as mul_pt,
sqrt as sqrt_pt,
clone as clone_pt,
pow_recursive,
exp2 as exp2_pt
)
from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation
from torch._inductor.lowering import (unsqueeze as unsqueeze_pt, index_put_as_masked_fill, index_put_fallback, needs_fallback_due_to_atomic_add_limitations, view as view_pt, check_and_broadcast_indices, index_output_size_and_inner_fn, expand as expand_pt, clone, new_empty, scatter_fallback, full_like as full_like_pt)
from torch._inductor.virtualized import V, ops
from torch._inductor import scheduler
from torch._inductor.scheduler import Scheduler
from .. import npu_dtype_cast, _npu_dtype_cast
from . import ir as npu_ir
from .codegen.triton_utils import NPUKernelType
from .ir import IndexputTemplate, ScatterTemplate
from .lowering_override_list import LOWERING_OVERRIDE_OP
from .config import inductor_indirect_memory_mode, lowering_cat_with_concat_kernel, log, is_ascend950
from .lowering_fallback_list import FALLBACK_LIST, NPU_EXTRA_FALLBACK_LIST
from . import config as npu_config
from .lowering_fx import (
fetch_graphs,
merge_traced_graphs,
node_id,
create_fake_input,
subtract_graph,
create_fx_from_snodes_by_traced_graph,
create_compile_kwargs,
generate_fx_graph_code,
dump_fx_graph_code,
snodes_to_fx,
)
def npu_make_fallback(op, layout_constraint=None, warn=True, override_decomp=False):
if op in decompositions and not override_decomp:
raise RuntimeError(f"both a fallback and a decomp for same op: {op}")
def register_fallback(op_overload):
add_needs_realized_inputs(op_overload)
if layout_constraint is not None:
add_layout_constraint(op_overload, layout_constraint)
return register_lowering(op_overload, type_promotion_kind=None)(
fallback_handler(op_overload)
)
if isinstance(op, torch._ops.OpOverloadPacket):
for ol in op.overloads():
op_overload = getattr(op, ol)
register_fallback(op_overload)
elif isinstance(op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)):
register_fallback(op)
else:
raise RuntimeError(f"Unsupported fallback {op} with type {type(op)}")
# Rebind the module-level ``make_fallback`` (imported from upstream above) to
# the NPU variant, so later calls in this module use it.
make_fallback = npu_make_fallback
if npu_config.dump_fx_graph:
    # FX-graph dump mode: import tracing-aware helpers from lowering_fx. This
    # shadows the upstream ``_make_reduction_inner``, ``clone`` and
    # ``to_dtype`` imported at the top of the file.
    from .lowering_fx import (
        _make_reduction_inner,
        reduction_type_to_aten_fn,
        npu_compute_ancestors,
        _npu_prune_redundant_deps,
        _npu_get_unmet_dep_nodes,
        clone,
        to_dtype,
        DUMP_FX_GRAPH_LOWERING_OPS
    )
    # Add the dump-mode ops to the override list; ops in LOWERING_OVERRIDE_OP
    # have their existing lowerings deleted in _register_npu_inductor_fallbacks.
    LOWERING_OVERRIDE_OP = list(set(LOWERING_OVERRIDE_OP) | set(DUMP_FX_GRAPH_LOWERING_OPS))
    # Monkey-patch the scheduler's dependency helpers with the NPU versions.
    Scheduler.compute_ancestors = npu_compute_ancestors
    scheduler._prune_redundant_deps = _npu_prune_redundant_deps
    Scheduler._get_unmet_dep_nodes = _npu_get_unmet_dep_nodes
def make_reduction(reduction_type: str, override_return_dtype=None):
    """Return a lowering ``inner(x, axis=None, keepdims=False, *, dtype=None)``.

    ``inner`` builds a ``Reduction`` of kind ``reduction_type`` over ``axis``.
    In FX-dump mode the node also gets a generated ``node_name`` and a traced
    graph. If the result wraps a ``Reduction`` node, ``inner`` does three
    things before realizing it:

    * stores the list of non-reduced input dims as ``kept_idx`` on that node;
    * stores the list of reduced input dims as ``reduced_idx`` on that node;
    * realizes the result.
    """

    def inner(x, axis=None, keepdims=False, *, dtype=None):
        reduction_kwargs = _make_reduction_inner(
            x,
            axis=axis,
            keepdims=keepdims,
            dtype=dtype,
            override_return_dtype=override_return_dtype,
        )
        if not npu_config.dump_fx_graph:
            result = Reduction.create(
                reduction_type=reduction_type,
                input_node=x,
                **reduction_kwargs,
            )
        else:
            node_name = f'reduction_{next(node_id)}'
            # The traced graph needs explicit dims; None means "all dims".
            traced_axis = list(range(len(x.get_size()))) if axis is None else axis
            new_graph = merge_traced_graphs(
                fetch_graphs([x, traced_axis]),
                reduction_type_to_aten_fn[reduction_type],
                node_name,
                keepdim=keepdims,
            )
            result = Reduction.create(
                reduction_type=reduction_type,
                input_node=x,
                node_name=node_name,
                traced_graph=new_graph,
                **reduction_kwargs,
            )

        reduction_node = result.data.data
        if isinstance(reduction_node, Reduction):
            reduced_dims = set(_validate_reduction_axis(x, axis))
            all_dims = range(len(x.get_size()))
            kept_idx = [d for d in all_dims if d not in reduced_dims]
            reduced_idx = [d for d in all_dims if d in reduced_dims]
            # object.__setattr__ bypasses any frozen/dataclass attribute guard.
            object.__setattr__(reduction_node, "kept_idx", kept_idx)
            object.__setattr__(reduction_node, "reduced_idx", reduced_idx)
            result.realize()
        return result

    return inner
# Replace the upstream module attribute so later lookups of
# ``lowering.make_reduction`` resolve to the NPU version defined in this file.
lowering.make_reduction = make_reduction
# Short aliases for the operator namespaces used by the registrations below.
aten, tr_c10d, prims, npu = (
    torch.ops.aten,
    torch.ops.tr_c10d,
    torch.ops.prims,
    torch.ops.npu,
)
def _add_overload(input_list, output_set):
for fn in input_list:
output_set.add(fn)
if isinstance(fn, torch._ops.OpOverloadPacket):
for overload in fn.overloads():
other_fn = getattr(fn, overload)
output_set.add(other_fn)
# Install the NPU lowering overrides on top of the upstream inductor lowerings.
#
# NOTE(review): this view of the file has lost its indentation, so the extent
# of this function cannot be read off directly. The names unpacked below
# (squeeze, expand, view, unsqueeze, _validate_dim, full_like, mul, div, rsqrt,
# add, square, sub) are used by lowerings defined further down (mean, gather,
# index_put_impl_, var_mean_sum_, native_layer_norm, ...). That suggests those
# registrations are nested inside this function — confirm against the
# original file.
def _register_npu_inductor_fallbacks():
# Turn every op that already has a lowering, appears in FALLBACK_LIST, has no
# decomposition and is an operator type make_fallback accepts into an eager
# ATen fallback.
# NOTE(review): make_fallback presumably writes back into ``lowerings`` (via
# register_lowering) while this loop iterates it. That is only safe while no
# new keys are added; consider iterating ``list(lowering.lowerings)``.
for op in lowering.lowerings:
if op in FALLBACK_LIST and op not in decompositions \
and isinstance(op, (torch._ops.OpOverloadPacket, torch._ops.OpOverload, torch._ops.HigherOrderOperator)):
make_fallback(op)
log.info(f"[npu|inductor|lowering|fallback] with FALLBACK_LIST, len(lowerings): {len(lowerings)}, "
f"len(FALLBACK_LIST): {len(FALLBACK_LIST)}, make_fallback finished.")
log.info(f"[npu|inductor|lowering|fallback] len(NPU_EXTRA_FALLBACK_LIST): {len(NPU_EXTRA_FALLBACK_LIST)}")
# Delete the existing lowering of every op (packets expanded to all overloads)
# in LOWERING_OVERRIDE_OP, presumably so the NPU registrations replace them.
overload_op_set = set()
_add_overload(LOWERING_OVERRIDE_OP, overload_op_set)
for op in overload_op_set:
if op in lowerings:
del lowerings[op]
# Pick the helper lowerings used by the NPU lowerings:
# - in FX-dump mode, the versions returned by lowering_fx;
# - otherwise, the upstream lowerings imported with a ``_pt`` suffix.
if npu_config.dump_fx_graph:
from .lowering_fx import _register_npu_inductor_fallbacks_fx
(squeeze, expand, view, unsqueeze, _validate_dim, full_like, mul, div, rsqrt, add, square, sub) = _register_npu_inductor_fallbacks_fx(make_reduction)
else:
(squeeze, expand, view, unsqueeze, _validate_dim, full_like, mul, div, rsqrt, add, square, sub) = (squeeze_pt, expand_pt, view_pt, unsqueeze_pt, _validate_dim_pt, full_like_pt, mul_pt, div_pt, rsqrt_pt, add_pt, square_pt, sub_pt)
# Reductions built with the NPU make_reduction; argmax/argmin return int64.
reduce_amax = register_lowering(aten.amax)(make_reduction("max"))
reduce_amin = register_lowering(aten.amin)(make_reduction("min"))
reduce_argmax = register_lowering(aten.argmax)(
make_reduction("argmax", override_return_dtype=torch.int64)
)
reduce_argmin = register_lowering(aten.argmin)(
make_reduction("argmin", override_return_dtype=torch.int64)
)
@register_lowering(aten.max, type_promotion_kind=None)
def reduce_max(x, dim=None, keepdim=False):
    """Lower ``aten.max``.

    Without ``dim``: the maximum over all elements. With ``dim``: a
    ``(values, indices)`` pair from amax/argmax along that dim.
    """
    if dim is None:
        return reduce_amax(x, axis=None, keepdims=keepdim)
    values = reduce_amax(x, axis=dim, keepdims=keepdim)
    indices = reduce_argmax(x, axis=dim, keepdims=keepdim)
    return values, indices
@register_lowering(aten.min, type_promotion_kind=None)
def reduce_min(x, dim=None, keepdim=False):
    """Lower ``aten.min``.

    Without ``dim``: the minimum over all elements. With ``dim``: a
    ``(values, indices)`` pair from amin/argmin along that dim.
    """
    if dim is None:
        return reduce_amin(x, axis=None, keepdims=keepdim)
    values = reduce_amin(x, axis=dim, keepdims=keepdim)
    indices = reduce_argmin(x, axis=dim, keepdims=keepdim)
    return values, indices
@register_lowering(aten.mean)
def mean(x, axis=None, keepdim=False, *, dtype=None):
    """Lower ``aten.mean`` as ``sum(x, axis) / prod(size[axis])``.

    After the optional ``dtype`` conversion, fp16/bf16 inputs are summed in
    float32. The quotient is cast back to that (pre-upcast) dtype.
    """
    if dtype is not None:
        x = to_dtype(x, dtype)
    input_size = x.get_size()
    reduce_axes = _validate_reduction_axis(x, axis)
    result_dtype = x.get_dtype()
    if result_dtype in (torch.float16, torch.bfloat16):
        x = to_dtype(x, torch.float)
    total = sum_(x, reduce_axes, keepdim)
    count = sympy_product(input_size[i] for i in reduce_axes)
    # Broadcast the element count to the sum's shape so it can be a divisor.
    count_node = ExpandView.create(
        ir.IndexingConstant(index=count, dtype=x.get_dtype(), device=x.get_device()),
        list(total.get_size()),
    )
    return to_dtype(div(total, count_node), result_dtype)
@register_lowering(aten.cumsum)
def cumsum(x, axis=None, dtype=None):
    """Lower ``aten.cumsum``.

    Integer/bool inputs with no explicit ``dtype`` accumulate in int64 when
    ``is_ascend950`` is set and in int32 otherwise. A 0-d input accepts only
    ``axis`` 0 or -1 and becomes a copying dtype conversion. Every other
    input goes to ``fallback_cumsum``.

    Raises:
        ValueError: for a 0-d input with ``axis`` other than 0 or -1.
    """
    input_dtype = x.get_dtype()
    if dtype is None and (is_integer_dtype(input_dtype) or is_boolean_dtype(input_dtype)):
        dtype = torch.int64 if is_ascend950 else torch.int32
    if len(x.get_size()) != 0:
        return fallback_cumsum(x, dim=axis, dtype=dtype)
    if axis not in [0, -1]:
        raise ValueError("axis must be 0 or -1")
    return to_dtype(x, dtype or input_dtype, copy=True)
@register_lowering(npu.npu_dtype_cast, type_promotion_kind=None)
def _convert_npu_type(x: TensorBox, dtype: torch.dtype):
    """Lower ``npu.npu_dtype_cast`` to an always-copying dtype conversion."""
    converted = to_dtype(x, dtype, copy=True)
    return converted
@register_lowering(npu._npu_dtype_cast, type_promotion_kind=None)
def _convert__npu_type(x: TensorBox, dtype: torch.dtype):
    """Lower ``npu._npu_dtype_cast`` to an always-copying dtype conversion."""
    converted = to_dtype(x, dtype, copy=True)
    return converted
def lowering_index_select(x, select_dim, indices, index_select_type, traced_graph=None, node_name=None):
    """Build a Pointwise that selects slices of ``x`` along ``select_dim``.

    The output size is ``x.size[:select_dim] + indices.size + x.size[select_dim+1:]``.

    Each element is emitted through ``ops.index_select``, tagged with
    ``index_select_type``. It falls back to a plain indirect load from ``x``
    in two cases:

    * the output index already mentions ``tmp``/``indirect`` variables;
    * building the template call raises.

    ``traced_graph`` and ``node_name`` are passed to ``Pointwise.create``
    only in FX-dump mode.
    """
    assert isinstance(x, TensorBox)
    assert isinstance(indices, TensorBox)
    assert "int" in str(indices.get_dtype())
    load_x = x.make_loader()
    load_indices = indices.make_loader()
    n_index_dims = len(indices.get_size())
    in_size = x.get_size()
    out_size = [*in_size[:select_dim], *indices.get_size(), *in_size[select_dim + 1:]]

    def inner_fn(idx):
        assert len(idx) == len(out_size), f"{idx} != {out_size}"
        already_indirect = any('tmp' in str(v) or 'indirect' in str(v) for v in idx)
        index_value = load_indices(idx[select_dim:select_dim + n_index_dims])
        indirect = ops.indirect_indexing(index_value, in_size[select_dim])
        x_idx = [*idx[:select_dim], indirect, *idx[select_dim + n_index_dims:]]
        if already_indirect:
            return load_x(x_idx)
        try:
            indexer = x.data.make_indexer()
            buffer_name = x.data.get_name()
            return ops.index_select(buffer_name, indexer(x_idx), index_value, indirect, int(in_size[select_dim]), index_select_type)
        except Exception:
            # Best effort: anything the index_select template cannot express
            # falls back to a regular indirect load.
            return load_x(x_idx)

    pointwise_kwargs = dict(
        device=x.get_device(),
        dtype=x.get_dtype(),
        inner_fn=inner_fn,
        ranges=out_size,
    )
    if npu_config.dump_fx_graph:
        pointwise_kwargs.update(traced_graph=traced_graph, node_name=node_name)
    return Pointwise.create(**pointwise_kwargs)
@register_lowering(aten.embedding, type_promotion_kind=None)
def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False):
    """Lower ``aten.embedding``.

    Dispatch:

    * FX nodes tagged ``skip_lowering`` go to the ATen fallback.
    * Unless the indirect memory mode is ``SIMT_TEMPLATE``, the upstream
      ``lowering.embedding`` is used.
    * In ``SIMT_TEMPLATE`` mode, ``lowering_index_select`` along dim 0 is
      used, unless ``weight`` has a size-1 dim or is a view.

    ``padding_idx``, ``scale_grad_by_freq`` and ``sparse`` are forwarded only
    to the ATen fallback and to the FX trace.
    """
    if V.current_node.meta.get("skip_lowering", False):
        return fallback_handler(aten.embedding.default)(weight, indices, padding_idx=padding_idx, scale_grad_by_freq=scale_grad_by_freq, sparse=sparse)

    new_graph, node_name = None, None
    if npu_config.dump_fx_graph:
        input_graphs = fetch_graphs([weight, indices])
        node_name = f'embedding_{next(node_id)}'
        new_graph = merge_traced_graphs(input_graphs, torch.ops.aten.embedding.default, node_name, padding_idx=padding_idx, scale_grad_by_freq=scale_grad_by_freq, sparse=sparse)

    if inductor_indirect_memory_mode != str(NPUKernelType.SIMT_TEMPLATE):
        return lowering.embedding(weight, indices)

    use_template = 1 not in weight.get_size() and not (
        isinstance(weight, TensorBox) and isinstance(weight.data, ir.BaseView)
    )
    if use_template:
        return lowering_index_select(weight, 0, indices, 'embedding', new_graph, node_name)
    return lowering.embedding(weight, indices)
@make_pointwise
def pow_native(a, b):
    """Elementwise ``a ** b`` through the backend ``ops.pow``."""
    result = ops.pow(a, b)
    return result
# ATen fallbacks used by ``pow`` for integer / float64 computations; created
# with add_to_fallback_set=False, in Tensor_Tensor, Scalar, Tensor_Scalar order.
fallback_pow_tensor_tensor, fallback_pow_scalar, fallback_pow_tensor_scalar = (
    fallback_handler(overload, add_to_fallback_set=False)
    for overload in (aten.pow.Tensor_Tensor, aten.pow.Scalar, aten.pow.Tensor_Scalar)
)
@register_lowering(aten.pow, broadcast=True)
def pow(a, b):
    """Lower ``aten.pow`` (``a ** b``).

    Dispatch order:

    * a float exponent with an integral value is re-lowered as an int
      exponent;
    * ``b == 0.5`` becomes ``sqrt``; ``b == 1`` becomes ``clone``;
    * int exponents with ``-32 < b < 32``, or any non-negative int exponent
      on integer data, are expanded inline with ``pow_recursive``;
    * a scalar base of 1 gives ones like ``b``;
    * a scalar base of 2 with a float ``b`` gives ``exp2``;
    * integer or float64 computations use the ATen fallbacks;
    * everything else uses the backend ``ops.pow``.
    """
    # ``float.is_integer`` is False for inf/nan. The previous ``b == int(b)``
    # test raised OverflowError/ValueError for such exponents.
    if isinstance(b, float) and b.is_integer():
        return pow(a, int(b))
    elif isinstance(b, float) and b == 0.5:
        return sqrt_pt(a)
    elif isinstance(b, int) and b == 1:
        return clone_pt(a)

    # dtype of the first tensor operand.
    dtype = next(x.get_dtype() for x in (a, b) if isinstance(x, ir.TensorBox))
    is_integer_pow = is_integer_dtype(dtype)
    is_fp64_pow = (dtype == torch.float64)
    embed_exponent = isinstance(b, int) and (
        -32 < b < 32 or (is_integer_pow and b >= 0)
    )
    if embed_exponent:
        loader = a.make_loader()

        def fn(idx):
            return pow_recursive(loader(idx), b, a.get_dtype())

        return Pointwise.create(
            device=a.get_device(),
            dtype=a.get_dtype(),
            inner_fn=fn,
            ranges=a.get_size(),
        )

    if isinstance(a, Number):
        if a == 1:
            return full_like(b, 1)
        if a == 2 and is_float_dtype(b.get_dtype()):
            return exp2_pt(b)

    # Integer and float64 pow are routed to the ATen fallbacks.
    if is_integer_pow or is_fp64_pow:
        if isinstance(a, Number):
            return fallback_pow_scalar(a, b)
        elif isinstance(b, Number):
            return fallback_pow_tensor_scalar(a, b)
        else:
            return fallback_pow_tensor_tensor(a, b)

    return pow_native(a, b)
@register_lowering(aten.cat)
def cat(inputs, dim=0):
    """Lower ``aten.cat``.

    * A single input is cloned.
    * Any input with a symbolic size falls back to ATen.
    * On an NPU device with ``lowering_cat_with_concat_kernel`` enabled:
      - inputs that are ``View``s reindexed with ``ModularIndexing`` fall back;
      - a last-dim concat of >1-D inputs uses the NPU ``ConcatKernel``;
      - anything else falls back.
    * Otherwise inputs are promoted to a common dtype and joined with the
      generic ``ConcatKernel``.

    NOTE(review): the NPU ConcatKernel path does not promote mixed input
    dtypes, unlike the generic path. Confirm that ``npu_ir.ConcatKernel``
    handles this. Also, ``aten.cat`` is registered again later in this file;
    if both registrations run, the later one presumably replaces this one.
    """
    if len(inputs) == 1:
        return clone(inputs[0])

    def _fallback():
        return fallback_handler(aten.cat.default)(inputs, dim)

    def _has_symbolic_dim(shape):
        # sympy.Symbol is an Expr, so this also covers plain symbols.
        return any(isinstance(s, sympy.Expr) and len(s.free_symbols) > 0 for s in shape)

    if any(_has_symbolic_dim(inp.get_size()) for inp in inputs):
        return _fallback()

    if not (inputs[0].get_device().type == "npu" and lowering_cat_with_concat_kernel):
        concat_dim = _validate_dim(inputs[0], dim, 0)
        common_dtype = get_promoted_dtype(
            *inputs,
            type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
        )
        promoted = [to_dtype(inp, common_dtype) for inp in inputs]
        return TensorBox(ir.ConcatKernel.create(promoted, concat_dim))

    def _uses_modular_reindex(node) -> bool:
        # Unwrap TensorBox/StorageBox layers, then inspect the innermost node.
        while isinstance(node, (TensorBox, ir.StorageBox)):
            node = node.data
        return isinstance(node, ir.View) and "ModularIndexing" in node.reindex_str()

    if any(_uses_modular_reindex(inp) for inp in inputs):
        return _fallback()
    ndim = len(inputs[0].get_size())
    if ndim > 1 and dim in (-1, ndim - 1):
        return TensorBox(npu_ir.ConcatKernel.create(inputs, dim, False))
    return _fallback()
@register_lowering(aten.gather, type_promotion_kind=None)
def gather(x, dim, index, sparse_grad=False):
    """Lower ``aten.gather``: take elements of ``x`` along ``dim`` at ``index``.

    The NPU ``gather_template`` op is used only when all of these hold:

    * x is float32/float16/bfloat16;
    * neither x nor index has a size-1 dim;
    * neither x nor index is a view;
    * the gathered dim has a concrete (non-symbolic) size.

    Otherwise the upstream ``lowering.gather`` is used. ``sparse_grad`` is
    only forwarded to that upstream lowering. An empty ``index`` yields an
    empty tensor of ``index``'s size.
    """
    assert isinstance(x, TensorBox)
    if index.get_numel() == 0:
        return new_empty(x, index.get_size())
    assert index.get_dtype() == torch.int64
    size = x.get_size()
    offset = len(size) == 0
    dim = _validate_dim(x, dim, offset)
    if offset:
        # Treat a 0-d input as 1-D of size 1.
        x = expand(x, [1])
        size = [1]

    def should_use_template():
        template_x_dtypes = [torch.float32, torch.float16, torch.bfloat16]
        if x.get_dtype() not in template_x_dtypes:
            return False
        if 1 in x.get_size() or 1 in index.get_size():
            return False
        if isinstance(x, TensorBox) and isinstance(x.data, ir.BaseView):
            return False
        if isinstance(index, TensorBox) and isinstance(index.data, ir.BaseView):
            return False
        # The template takes an int bound for the gathered dim; symbolic
        # (dynamic) sizes cannot be converted, so defer those upstream.
        try:
            int(size[dim])
        except TypeError:
            return False
        return True

    if not should_use_template():
        return lowering.gather(x, dim, index, sparse_grad)

    # The template addresses x by buffer name and layout indexer, which an
    # unrealized (e.g. pointwise) x does not provide, so realize it first.
    x.realize()
    index_loader = index.make_loader()
    loader_name = x.data.get_name()
    x_loader = x.data.make_indexer()
    index_boundary = int(size[dim])

    def fn(idx):
        idx = list(idx)
        index_value = index_loader(idx)
        gather_idx = ops.indirect_indexing(index_value, size[dim])
        if len(idx) == 0:
            idx = [gather_idx]
        else:
            idx[dim] = gather_idx
        return ops.gather_template(loader_name, x_loader(idx), index_value, gather_idx, index_boundary)

    return Pointwise.create(
        device=x.get_device(),
        dtype=x.get_dtype(),
        inner_fn=fn,
        ranges=index.get_size(),
    )
def index_put_impl_(self, indices, values, accumulate, check, may_realize=False):
    """Write ``values`` into ``self`` at ``indices`` (in place); return ``self``.

    Dispatch:

    * A single-element ``values`` with a single bool/uint8 mask is lowered as
      a masked fill.
    * ATen fallbacks are used in deterministic mode, for any bool/uint8
      index, for an accumulating write whose dtype needs the atomic-add
      fallback, and when ``check_and_broadcast_indices`` raises
      NotImplementedError.
    * Otherwise the write becomes a mutation of ``self``:
      - the NPU ``IndexputTemplate`` when ``should_use_template`` accepts it;
      - an ``ir.Scatter`` otherwise (``atomic_add`` when ``accumulate``).

    Args:
        self: destination TensorBox; mutated and returned.
        indices: per-dim index TensorBoxes, or ``None`` for unindexed dims.
        values: source TensorBox.
        accumulate: add into ``self`` instead of overwriting.
        check: forwarded to ``index_output_size_and_inner_fn``.
        may_realize: if true and ``values`` reads ``self``'s buffer, realize
            ``values`` first. This is skipped when every index is a view of
            an ``aten.randperm`` result.
    """
    if may_realize:

        def try_get_name(x):
            # Name of the Buffer behind x's TensorBox/view/StorageBox layers,
            # or None when there is no Buffer underneath.
            if isinstance(x, ir.TensorBox):
                x = x.data
            if isinstance(x, ir.BaseView):
                x = x.unwrap_view()
            if isinstance(x, ir.StorageBox):
                x = x.data
            return x.get_name() if isinstance(x, ir.Buffer) else None

        def indice_slice_from_randperm(indice):
            # True-ish when indice is a view of an aten.randperm extern kernel.
            if isinstance(indice, TensorBox) and isinstance(indice.data, ir.BaseView):
                indice = indice.data.unwrap_view()
                return (
                    isinstance(indice, ir.StorageBox)
                    and isinstance(indice.data, ir.ExternKernel)
                    and getattr(indice.data, "fx_node", None)
                    and indice.data.fx_node.target == torch.ops.aten.randperm.default
                )
            return False

        # values reads from self: materialize it before the scatter writes
        # self. Skipped for randperm indices, presumably because those are
        # known to be unique.
        if try_get_name(self) in values.get_read_names() and not all(
            indice_slice_from_randperm(indice) for indice in indices
        ):
            values.realize()

    if (
        values.get_numel() == 1
        and len(indices) == 1
        and indices[0].get_dtype() in (torch.bool, torch.uint8)
    ):
        # Scalar value with a single boolean mask: lower as masked fill, with
        # the mask unsqueezed on the right up to self's rank.
        mask = indices[0]
        for _ in range(len(mask.get_size()), len(self.get_size())):
            mask = unsqueeze(mask, -1)
        return index_put_as_masked_fill(self, [mask], values, accumulate)

    if torch.are_deterministic_algorithms_enabled():
        return index_put_fallback(self, indices, values, accumulate)

    for index in indices:
        if index is not None and index.get_dtype() in (torch.bool, torch.uint8):
            return index_put_fallback(self, indices, values, accumulate)

    x_size = self.get_size()
    x_ndim = len(x_size)

    if accumulate and needs_fallback_due_to_atomic_add_limitations(self.get_dtype()):
        # A 0-d self is viewed as 1-D for the fallback and restored after.
        if x_ndim == 0:
            self = view(self, [1])
        self = index_put_fallback(self, indices, values, accumulate)
        if x_ndim == 0:
            self = view(self, [])
        return self

    values = to_dtype(values, self.get_dtype())

    try:
        indices, tensor_indices = check_and_broadcast_indices(
            indices, self.get_device()
        )
    except NotImplementedError:
        return index_put_fallback(self, indices, values, accumulate)

    indices_loaders = [i.make_loader() if i is not None else None for i in indices]

    assert isinstance(self, TensorBox)
    self.realize()

    if x_ndim == 0:
        self = view(self, [1])

    # Size of the first index tensor; it is empty for a 0-d index.
    tensor_size = list(indices[tensor_indices[0]].get_size())
    indexed_size = [x_size[i] for i in range(len(indices))]

    expected_vals_size, inner_fn = index_output_size_and_inner_fn(
        x_size,
        indices,
        tensor_indices,
        tensor_size,
        indices_loaders,
        indexed_size,
        None,
        check=check,
    )

    values = expand(values, expected_vals_size)

    def should_use_template():
        # IndexputTemplate is used only when all of these hold:
        # - the write is non-accumulating;
        # - it goes through exactly one index tensor, not on the last dim;
        # - there are no size-1 dims and no view operands.
        if accumulate:
            return False
        if len(indices_loaders) == x_ndim and indices_loaders[-1] is not None:
            return False
        # ``not tensor_size`` covers a 0-d index tensor, where
        # ``tensor_size[0]`` would raise IndexError.
        if x_ndim == 1 or 1 in x_size or not tensor_size or tensor_size[0] == 1:
            return False
        valid_indices = [indice for indice in indices if indice is not None]
        if len(valid_indices) != 1:
            return False
        if isinstance(self, TensorBox) and isinstance(self.data, ir.BaseView):
            return False
        if isinstance(valid_indices[0], TensorBox) and isinstance(valid_indices[0].data, ir.BaseView):
            return False
        if isinstance(values, TensorBox) and isinstance(values.data, ir.BaseView):
            return False
        return True

    if should_use_template():
        valid_index = next(i for i, loader in enumerate(indices_loaders) if loader is not None)
        boundary = int(x_size[valid_index])
        scatter = IndexputTemplate(
            device=self.get_device(),
            dtype=self.get_dtype(),
            inner_fn=values.make_loader(),
            ranges=expected_vals_size,
            output_indexer=inner_fn,
            scatter_mode=None,
            boundary=boundary
        )
    else:
        scatter = ir.Scatter(
            device=self.get_device(),
            dtype=self.get_dtype(),
            inner_fn=values.make_loader(),
            ranges=expected_vals_size,
            output_indexer=inner_fn,
            scatter_mode="atomic_add" if accumulate else None,
        )
    buffer = ir.ComputedBuffer(
        name=None,
        layout=ir.MutationLayoutSHOULDREMOVE(self),
        data=scatter,
    )
    buffer.name = V.graph.register_buffer(buffer)
    V.graph.register_operation(buffer)

    if x_ndim == 0:
        self = view(self, [])
    return self
@register_lowering(aten.index_put)
def index_put(x, indices, values, accumulate=False):
    """Out-of-place ``index_put``: write ``values`` into a clone of ``x``.

    Uses ``check=True`` and ``may_realize=False``, since the destination is a
    fresh clone.
    """
    destination = clone(x)
    return index_put_impl_(destination, indices, values, accumulate, check=True, may_realize=False)
@register_lowering(aten._unsafe_index_put)
def _unsafe_index_put(x, indices, values, accumulate=False):
    """Out-of-place ``_unsafe_index_put``: like ``index_put`` but ``check=False``."""
    destination = clone(x)
    return index_put_impl_(destination, indices, values, accumulate, check=False, may_realize=False)
@register_lowering(aten.index_put_, type_promotion_kind=None)
def index_put_(self, indices, values, accumulate=False):
    """In-place ``index_put_`` on ``self``.

    Uses ``check=True`` and ``may_realize=True``: ``values`` may alias
    ``self`` here, since no clone is made.
    """
    return index_put_impl_(self, indices, values, accumulate, check=True, may_realize=True)
@register_lowering(aten.scatter_reduce_, type_promotion_kind=None)
def scatter_reduce_(self, dim: int, index, src, reduce, *, include_self: bool = True):
    """Lower the in-place ``aten.scatter_reduce_.two``; returns ``self``.

    ``scatter_fallback`` is tried first and its result returned when truthy.
    Otherwise ``src`` is written into ``self`` at ``index`` along ``dim`` as a
    mutation of ``self``:

    * with the NPU ``ScatterTemplate`` when ``reduce`` is None and no operand
      has a size-1 dim or is a view;
    * else with ``ir.Scatter``.

    A scalar ``src`` is first expanded with ``full_like``. When
    ``include_self`` is false, the targeted positions of ``self`` are zeroed
    first. ``backend_reduce_str`` maps only None and "sum" (to
    "atomic_add"); any other ``reduce`` that gets past ``scatter_fallback``
    trips its assertion.
    """
    assert reduce in (None, "sum", "prod", "mean", "amax", "amin")
    assert (
        len(aten.scatter_reduce_.overloads()) == 1
        and "two" in aten.scatter_reduce_.overloads()
    ), "aten.scatter_reduce_.two is not the unique overload of aten.scatter_reduce_"
    if isinstance(src, Number):
        src = full_like(self, src)
    fallback_result = scatter_fallback(
        aten.scatter_reduce_.two,
        self,
        dim,
        index,
        src,
        reduce=reduce,
        include_self=include_self,
    )
    if fallback_result:
        return fallback_result
    assert isinstance(self, TensorBox)
    assert "int" in str(index.get_dtype())
    # 0-d operands are viewed as 1-D; self is viewed back to 0-d at the end.
    ndim = len(self.get_size())
    if ndim == 0:
        self = view(self, [1])
    if isinstance(src, TensorBox) and len(src.get_size()) == 0:
        src = view(src, [1])
    if isinstance(index, TensorBox) and len(index.get_size()) == 0:
        index = view(index, [1])
    if index.get_numel() == 0:
        return self
    dim = _validate_dim(self, dim)
    self.realize()
    index_loader = index.make_loader()
    src_loader = src.make_loader() if isinstance(src, TensorBox) else None
    # NOTE: the indexers below look ``self`` up when they are called, not when
    # they are defined. For a 0-d input, ``self`` is rebound to a 0-d view at
    # the end of this function. If they run after that point, they see an
    # empty shape — hence the ``1 if ndim == 0`` bound.
    def output_indexer(idx):
        # idx with its ``dim`` coordinate replaced by the indirect index
        # loaded from ``index`` (negative indices are not wrapped).
        shape = self.get_size()
        ndim = len(shape)
        indirect_idx = list(idx)
        indirect_idx[dim] = ops.indirect_indexing(
            index_loader(idx), 1 if ndim == 0 else shape[dim], wrap_neg=False
        )
        return indirect_idx
    def template_output_indexer(idx):
        # Same as output_indexer, plus the size of ``dim`` for the template.
        # shape[dim] would fail on a 0-d self, but the template is never
        # chosen then: self is viewed as [1] and should_use_template rejects
        # size-1 dims.
        shape = self.get_size()
        ndim = len(shape)
        indirect_idx = list(idx)
        indirect_idx[dim] = ops.indirect_indexing(
            index_loader(idx), 1 if ndim == 0 else shape[dim], wrap_neg=False
        )
        return indirect_idx, shape[dim]
    def fn(idx):
        # Value written at each position: the src element, or src itself as
        # a constant when src is not a TensorBox.
        if src_loader:
            return src_loader(idx)
        else:
            return ops.constant(src, self.get_dtype())
    def backend_reduce_str(reduce):
        # Map the reduction name to a scatter_mode.
        if reduce == "sum":
            return "atomic_add"
        else:
            assert reduce is None
            return None
    if not include_self:
        # Zero the targeted positions of self before scattering src.
        zero_out = ir.Scatter(
            device=self.get_device(),
            dtype=self.get_dtype(),
            inner_fn=lambda index: ops.constant(0, self.get_dtype()),
            ranges=index.get_size(),
            output_indexer=output_indexer,
            scatter_mode=None,
        )
        buffer = ir.ComputedBuffer(
            name=None,
            layout=ir.MutationLayoutSHOULDREMOVE(self),
            data=zero_out,
        )
        buffer.name = V.graph.register_buffer(buffer)
        V.graph.register_operation(buffer)
    def should_use_template():
        # ScatterTemplate only for reduce=None, no size-1 dims, no views.
        if reduce:
            return False
        if 1 in index.get_size() or 1 in self.get_size() or 1 in src.get_size():
            return False
        if isinstance(index, TensorBox) and isinstance(index.data, ir.BaseView):
            return False
        if isinstance(self, TensorBox) and isinstance(self.data, ir.BaseView):
            return False
        if isinstance(src, TensorBox) and isinstance(src.data, ir.BaseView):
            return False
        return True
    if should_use_template():
        scatter = ScatterTemplate(
            device=self.get_device(),
            dtype=self.get_dtype(),
            inner_fn=fn,
            ranges=index.get_size(),
            output_indexer=template_output_indexer,
            scatter_mode=backend_reduce_str(reduce),
        )
    else:
        scatter = ir.Scatter(
            device=self.get_device(),
            dtype=self.get_dtype(),
            inner_fn=fn,
            ranges=index.get_size(),
            output_indexer=output_indexer,
            scatter_mode=backend_reduce_str(reduce),
        )
    # Register the scatter as a buffer whose MutationLayoutSHOULDREMOVE layout
    # targets self.
    buffer = ir.ComputedBuffer(
        name=None,
        layout=ir.MutationLayoutSHOULDREMOVE(self),
        data=scatter,
    )
    buffer.name = V.graph.register_buffer(buffer)
    V.graph.register_operation(buffer)
    if ndim == 0:
        self = view(self, [])
    return self
@register_lowering(aten.scatter_reduce, type_promotion_kind=None)
def scatter_reduce(x, dim: int, index, src, reduction_type, **kwargs):
    """Out-of-place ``scatter_reduce``: run ``scatter_reduce_`` on a clone of ``x``."""
    target = clone(x)
    return scatter_reduce_(target, dim, index, src, reduction_type, **kwargs)
@register_lowering(aten.scatter_, type_promotion_kind=None)
def scatter_(self, dim: int, index, src, *, reduce=None):
    """Lower the in-place ``aten.scatter_``.

    With no ``reduce``, ``scatter_fallback`` for the current node's overload
    is tried first. Otherwise ``reduce`` is renamed to the
    ``scatter_reduce_`` convention ("add" -> "sum", "multiply" -> "prod")
    and the work is delegated to ``scatter_reduce_``.
    """
    assert reduce in (None, "add", "multiply")
    if reduce is None:
        overload_name = V.graph.current_node.target._overloadname
        fallback_result = scatter_fallback(
            getattr(aten.scatter_, overload_name), self, dim, index, src, reduce=reduce
        )
        if fallback_result is not None:
            return fallback_result
    reduce = {"add": "sum", "multiply": "prod"}.get(reduce, reduce)
    return scatter_reduce_(self, dim, index, src, reduce)
@register_lowering(aten.scatter, type_promotion_kind=None)
def scatter(x, dim: int, index, src, **kwargs):
    """Out-of-place ``scatter``: run ``scatter_`` on a clone of ``x``."""
    target = clone(x)
    return scatter_(target, dim, index, src, **kwargs)
def var_mean_sum_(x, axis, correction, keepdim, return_mean):
    """Two-pass variance ``sum((x - mean)^2, axis) / denom``.

    ``denom`` is ``N`` when ``correction`` is 0, else ``max(N - correction, 0)``.
    ``correction`` defaults to 1 when None.

    Returns ``(var,)``, or ``(var, mean)`` when ``return_mean``. The mean is
    squeezed over ``axis`` unless ``keepdim``.
    """
    if correction is None:
        correction = 1
    input_size = x.get_size()
    reduce_axes = _validate_reduction_axis(x, axis)
    centre = mean(x, reduce_axes, keepdim=True)
    if return_mean:
        centre.realize()
    squared_dev_sum = sum_(square(sub(x, centre)), reduce_axes, keepdim)
    denom = sympy_product(input_size[i] for i in reduce_axes)
    if correction:
        denom = sympy.Max(denom - correction, 0)
    denom_node = ExpandView.create(
        ir.IndexingConstant(index=denom, dtype=x.get_dtype(), device=x.get_device()),
        list(squared_dev_sum.get_size()),
    )
    variance = div(squared_dev_sum, denom_node)
    if not return_mean:
        return (variance,)
    if not keepdim:
        centre = squeeze(centre, reduce_axes)
    return variance, centre
def var_mean_helper_(x, *, axis, correction, keepdim, return_mean):
    """Variance (and optionally mean) computed in the computation dtype.

    ``x`` is upcast with ``get_computation_dtype`` and the results are cast
    back to the input dtype. Returns the variance alone, or ``(var, mean)``
    when ``return_mean``.
    """
    out_dtype = x.get_dtype()
    upcast = to_dtype(x, get_computation_dtype(out_dtype), copy=False)
    results = var_mean_sum_(
        x=upcast,
        axis=axis,
        correction=correction,
        keepdim=keepdim,
        return_mean=return_mean,
    )
    results = tuple(to_dtype(r, out_dtype, copy=False) for r in results)
    if return_mean:
        return results
    return results[0]
@register_lowering(aten.var_mean)
def var_mean(x, axis=None, *, correction=None, keepdim=False):
    """Lower ``aten.var_mean``; returns ``(var, mean)``."""
    return var_mean_helper_(x, return_mean=True, axis=axis, correction=correction, keepdim=keepdim)
@register_lowering([aten.var, prims.var])
def var_(x, axis=None, *, correction=None, keepdim=False):
    """Lower ``aten.var`` / ``prims.var``; returns the variance only."""
    return var_mean_helper_(x, return_mean=False, axis=axis, correction=correction, keepdim=keepdim)
@register_lowering(aten.index, type_promotion_kind=None)
def index(x, indices):
    """Lower ``aten.index`` (advanced indexing).

    ``lowering_index_select`` is used only when all of these hold:

    * exactly one index tensor is given;
    * it is an integer tensor, not a bool/uint8 mask;
    * it does not index the last dim;
    * ``x`` is not 1-D and has no size-1 dim;
    * neither ``x`` nor the index tensor is a view.

    Everything else goes to the upstream ``lowering.index``.
    """
    if npu_config.dump_fx_graph:
        input_graphs = fetch_graphs([x, indices])
        node_name = f'index_{next(node_id)}'
        new_graph = merge_traced_graphs(input_graphs, aten.index, node_name)
    else:
        new_graph = None
        node_name = None

    # (position, tensor) of each explicit index; None marks an unindexed dim.
    present = [(pos, ind) for pos, ind in enumerate(indices) if ind is not None]

    def should_use_template():
        x_size = x.get_size()
        if len(x_size) == 1 or 1 in x_size:
            return False
        if len(present) != 1:
            return False
        select_dim, index_tensor = present[0]
        if select_dim == len(x_size) - 1:
            return False
        # Bool/uint8 indices are masks, not positions. A bool index would trip
        # lowering_index_select's int-dtype assertion, and a uint8 one would
        # be read as positions.
        if index_tensor.get_dtype() in (torch.bool, torch.uint8):
            return False
        if isinstance(x, TensorBox) and isinstance(x.data, ir.BaseView):
            return False
        if isinstance(index_tensor, TensorBox) and isinstance(index_tensor.data, ir.BaseView):
            return False
        return True

    if should_use_template():
        select_dim, index_tensor = present[0]
        return lowering_index_select(x, select_dim, index_tensor, 'index_select', new_graph, node_name)
    return lowering.index(x, indices)
@register_lowering(aten.cat)
def cat(inputs, dim=0):
    """Lower ``aten.cat`` with the generic ``ir.ConcatKernel``.

    A single input is cloned. Otherwise the inputs are promoted to a common
    dtype (DEFAULT elementwise promotion) and concatenated along the
    validated ``dim``.

    NOTE(review): an earlier ``cat`` in this file is also registered for
    ``aten.cat``. It has the dynamic-shape fallback and the NPU ConcatKernel
    path behind ``lowering_cat_with_concat_kernel``. If both registrations
    run, this later one presumably replaces it. Confirm which lowering is
    intended.
    """
    if len(inputs) == 1:
        return clone(inputs[0])
    concat_dim = _validate_dim(inputs[0], dim, 0)
    common_dtype = get_promoted_dtype(
        *inputs,
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )
    promoted = [to_dtype(inp, common_dtype) for inp in inputs]
    return TensorBox(ir.ConcatKernel.create(promoted, concat_dim))
@register_lowering(aten.native_layer_norm)
def native_layer_norm(
    x,
    normalized_shape,
    weight=None,
    bias=None,
    eps=1e-5
):
    """Lower ``aten.native_layer_norm``; returns ``(output, mean, rstd)``.

    On Ascend950, fp16/bf16 inputs use the ATen fallback. Otherwise:

    * mean and variance are taken over the trailing ``len(normalized_shape)``
      dims (keepdim, correction 0);
    * ``output = (x - mean) * rsqrt(var + eps)``;
    * ``output`` is then multiplied by ``weight`` and shifted by ``bias``
      when those are given.
    """
    if is_ascend950 and x.dtype in (torch.bfloat16, torch.float16):
        return fallback_handler(aten.native_layer_norm.default)(x, normalized_shape, weight, bias, eps)
    if not isinstance(normalized_shape, (list, tuple)):
        normalized_shape = (normalized_shape,)
    rank = len(x.get_size())
    reduce_dims = list(range(rank - len(normalized_shape), rank))
    variance, mu = var_mean_helper_(
        x=x,
        axis=reduce_dims,
        correction=0,
        keepdim=True,
        return_mean=True
    )
    centred = sub(x, mu)
    # eps as a constant broadcast to var's shape.
    eps_node = ExpandView.create(
        ir.IndexingConstant(index=eps, dtype=variance.get_dtype(), device=variance.get_device()),
        variance.get_size(),
    )
    rstd = rsqrt(add(variance, eps_node))
    output = mul(centred, rstd)
    if weight is not None:
        output = mul(output, weight)
    if bias is not None:
        output = add(output, bias)
    return output, mu, rstd
@register_lowering(triton_kernel_wrapper_mutation)
def triton_kernel_wrap_(
    *,
    kernel_idx,
    constant_args_idx,
    grid,
    tma_descriptor_metadata,
    kwargs,
):
    """Lower a mutating user-defined Triton kernel launch.

    Constant arguments are fetched from the kernel side table and merged
    over ``kwargs``; constants win on key clashes. A
    ``UserDefinedTritonKernel`` node is then constructed without keeping a
    reference (its constructor presumably registers it with the graph).
    Returns the ``TensorBox`` entries of ``kwargs``.
    """
    from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table

    kernel_args = dict(kwargs)
    kernel_args.update(kernel_side_table.get_constant_args(constant_args_idx))
    ir.UserDefinedTritonKernel(
        kernel_idx=kernel_idx,
        grid=grid,
        tma_descriptor_metadata=tma_descriptor_metadata,
        kernel_args=kernel_args,
    )
    return {name: arg for name, arg in kwargs.items() if isinstance(arg, TensorBox)}