import functools
import itertools
import operator
import os
import re
import textwrap
from enum import Enum
from torch._inductor.utils import DeferredLineBase, LineContext
from typing import List, Set, Iterable, Callable, Sequence
from typing import (
Optional,
Union,
Tuple,
Any,
cast,
Dict
)
import math
import sympy
import torch
from torch._inductor import config, ir
from torch.utils._ordered_set import OrderedSet
from torch._inductor.codegen.common import (
IndentedBuffer,
SizeArg,
DeferredLine,
ArgName
)
from torch._inductor.codegen.common import free_symbol_is_type
from torch._inductor.codegen.simd import CantSplit, DisableReduction, EnableReduction
from torch._inductor.codegen.triton import (
IndexingOptions,
triton_reshape,
TritonCSEVariable,
triton_compute_type
)
from torch._inductor.ops_handler import OpsHandler
from torch._inductor.codegen.triton import (
TritonKernel,
TritonKernelOverrides,
IterationRangesRoot,
IterationRangesEntry,
CSEVariable,
BlockPtrOptions,
triton_acc_type,
constant_repr,
is_welford_reduction, FixedTritonConfig,
prefix_is_reduction, upcast_acc_dtype,
get_kernel_category_by_source_code,
get_fused_kernel_name
)
from torch._inductor.codegen.triton_utils import config_of, signature_of, signature_to_meta, should_unwrap_unspec_arg
from torch._inductor.dtype_propagation import DtypePropagationOpsHandler
from torch._inductor.runtime.hints import DeviceProperties
from torch._inductor.runtime.runtime_utils import next_power_of_2
from torch._inductor.scheduler import SchedulerNode
from torch._inductor.utils import (
Placeholder,
get_bounds_index_expr,
upcast_compute_type,
sympy_product,
)
from torch._inductor.utils import sympy_index_symbol, generate_assert
from torch._inductor.utils import sympy_subs
from torch._inductor.virtualized import (
V,
StoreMode,
ReductionType,
_ops as ops,
)
from torch.utils import _pytree as pytree
from torch.utils._sympy.functions import FloorDiv, Identity, ModularIndexing
from torch.utils._sympy.numbers import int_oo
from torch.utils._sympy.symbol import SymT, symbol_is_type
from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
from torch._inductor.bounds import ValueRangeAnalysis
from torch._inductor.runtime import triton_heuristics
from .. import config as npu_config
from .kernel_analysis import (
IndexAnalysis,
ReductionAnalysis,
collect_stride_sorted_vars_from_indexings,
collect_stride_sorted_vars_from_nodes,
)
from .npu_kernel_features import NumelList
from .triton_utils import NPUKernelType
from .. import npu_triton_heuristics
def flatten(nums):
    """Return a new flat list with every nested ``list`` in *nums* expanded.

    Only ``list`` instances are descended into; any other element (tuples
    included) is kept as a single item. Element order is preserved.
    """
    flat = []
    for item in nums:
        if not isinstance(item, list):
            flat.append(item)
            continue
        flat.extend(flatten(item))
    return flat
# Keep a handle on the upstream generator so the NPU variant can extend its output.
origin_gen_common_triton_imports = torch._inductor.codegen.triton.gen_common_triton_imports


@functools.lru_cache(None)
def gen_common_triton_imports():
    """Return the import header placed at the top of generated Triton kernel source.

    The header is the upstream Inductor import block followed by the NPU
    imports. The result is cached, so the text is built only once per process.

    ``npu_triton_helpers`` is imported as a module because generated code
    refers to it by that name (``frexp`` emits ``npu_triton_helpers.frexp(...)``);
    the ``from ... import libdevice, ...`` line alone does not bind that name.
    """
    imports = IndentedBuffer()
    imports.splice(origin_gen_common_triton_imports())
    # The NPU lines come last, so their ``triton_heuristics``/``libdevice``/
    # ``tl_math`` bindings take precedence over any same-named upstream ones.
    imports.splice(
        """
        import torch
        import torch_npu
        from torch_npu._inductor import npu_triton_heuristics as triton_heuristics
        from torch_npu._inductor import npu_triton_helpers
        from torch_npu._inductor.npu_triton_helpers import libdevice, extension, math as tl_math
        """
    )
    return imports.getvalue()
def patch_gen_common_triton_ext_imports():
    """Point Inductor's ``gen_common_triton_imports`` bindings at the NPU version.

    The attribute is rebound on the ``triton``, ``select_algorithm`` and
    ``triton_combo_kernel`` modules.
    """
    from torch._inductor import select_algorithm
    from torch._inductor.codegen import triton, triton_combo_kernel

    for patched_module in (triton, select_algorithm, triton_combo_kernel):
        patched_module.gen_common_triton_imports = gen_common_triton_imports
class NPUTritonKernelOverrides(TritonKernelOverrides):
@staticmethod
def exp(x):
return f"tl_math.exp({x})"
@staticmethod
def sqrt(x):
return f"tl_math.sqrt({x})"
@staticmethod
def tanh(x):
return f"tl_math.tanh({x})"
@staticmethod
def rsqrt(x):
return f"tl.rsqrt({x})"
@staticmethod
def floor(x):
return f"tl_math.floor({x})"
@staticmethod
def erf(x):
return f"tl_math.erf({x})"
@staticmethod
def ceil(x):
return f"tl_math.ceil({x})"
@staticmethod
def masked(mask, body, other):
if mask is not None and torch.version.hip is not None:
mask = V.kernel.cse.generate(
V.kernel.compute,
f"{mask}.to(tl.int1)",
dtype=torch.bool,
)
nodes = body.graph.find_nodes(op="output")
need_where = False
for node in nodes:
for arg in node.args:
if arg.target != "load" or should_unwrap_unspec_arg(arg.args[1]):
need_where = True
break
value = None if need_where else other
current_subblock = None
for block_name, body_var in body.body.subblocks.items():
if body_var == body:
current_subblock = block_name
break
before_subblock_axis = V.kernel.current_subblock_axis
current_subblock_axis = body.body.masked_indexing.get(current_subblock, {})
V.kernel.current_subblock_axis = before_subblock_axis | current_subblock_axis
with V.kernel.mask_loads(mask, value=value) as new_mask:
result = body()
V.kernel.current_subblock_axis = before_subblock_axis
if need_where:
if result.bounds.is_bool:
other = bool(other)
other = V.kernel.cse.generate(
V.kernel.compute,
f"tl.full({result}.shape, {constant_repr(other)}, {result}.dtype)",
bounds=ValueRanges.wrap(other),
dtype=result.dtype,
)
ret = ops.where(new_mask, result, other)
else:
ret = result
ret.mask_vars.discard(new_mask)
return ret
@classmethod
def index_expr(cls, expr, dtype):
indexing = V.kernel.indexing(expr, block_ptr=False, is_index_expr=True)
if not isinstance(indexing, IndexingOptions):
raise TypeError(f"not a IndexingOptions : {indexing}")
index_dtype = torch.int32 if V.kernel.index_dtype == "tl.int32" else torch.int64
dtype = dtype if dtype not in (torch.int32, torch.int64) else index_dtype
var = V.kernel.cse.generate(
V.kernel.compute,
indexing.index_str,
bounds=get_bounds_index_expr(expr),
dtype=dtype,
)
if dtype not in (torch.int32, torch.int64):
var = V.kernel.cse.generate(
V.kernel.compute,
cls.to_dtype(var, dtype),
dtype=upcast_compute_type(dtype),
)
else:
dtype = index_dtype
for index_var in expr.free_symbols:
if symbol_is_type(index_var, SymT.TMP):
dtype = torch.promote_types(
dtype, V.kernel.cse.varname_map[index_var.name].dtype
)
if dtype != index_dtype:
var = V.kernel.cse.generate(
V.kernel.compute,
cls.to_dtype(var, index_dtype),
dtype=index_dtype,
)
var.mask_vars = indexing.mask_vars
return var
@staticmethod
def frexp(x):
cache_key = f"frexp({x})"
cse_val = V.kernel.cse.try_get(cache_key)
if cse_val:
return cse_val
mantissa = V.kernel.cse.newvar(dtype=x.dtype)
exponent = V.kernel.cse.newvar(dtype=torch.int32)
V.kernel.compute.writeline(
f"{mantissa}, {exponent} = npu_triton_helpers.frexp({x})"
)
V.kernel.cse.put(cache_key, (mantissa, exponent))
return (mantissa, exponent)
def group_fn(self, sizes):
    """Normalise *sizes* into a tuple of group sizes.

    Falsy entries become ``1``, list entries are flattened and wrapped in a
    ``NumelList``, and every other entry is passed through unchanged.
    """

    def _as_group(size):
        if not size:
            return 1
        if not isinstance(size, list):
            return size
        flat = flatten(size)
        return NumelList(tuple(flat)) if isinstance(flat, list) else flat

    return tuple(_as_group(size) for size in sizes)
def get_allow_dynamic():
    """Report whether the active Inductor graph has any entries in ``var_to_range``.

    Returns ``False`` when no graph or shape environment can be found, or when
    the shape environment has no ``var_to_range`` attribute.
    """
    from torch._inductor.dependencies import V as V_DYNAMIC

    graph = getattr(V_DYNAMIC, "graph", None)
    if graph is None:
        return False

    sizevars = getattr(graph, "sizevars", None)
    shape_env = getattr(graph, "shape_env", None)
    if not shape_env:
        # Fall back to the shape environment held by the graph's sizevars.
        shape_env = getattr(sizevars, "shape_env", None)

    if not shape_env or not hasattr(shape_env, "var_to_range"):
        return False
    return len(shape_env.var_to_range) > 0
@staticmethod
def select_index_dtype(node_schedule, numel, reduction_numel):
return "tl.int32"
class IterationRangesEntryNPUIndex(IterationRangesEntry):
def __init__(
self,
*args, **kwargs):
super().__init__(*args, **kwargs)
self.is_tiling_axis = False
self.is_split_axis = False
self.indexing_code = IndentedBuffer()
self.sorted_order = None
self.tiling_order = None
self.split_order = None
self.var_directions = {}
self.directions = []
self.codegen = self._codegen
self.is_no_loop_axis = False
def _codegen_mask(self):
allow_dynamic = get_allow_dynamic()
codegen_mask = self.is_tiling_axis and (not self.is_no_loop_axis or allow_dynamic)
if V.kernel.is_unified_simt_kernel():
codegen_mask = self.is_tiling_axis
if codegen_mask:
BLOCK_NAME = f"{self.name.upper()}BLOCK"
upper = f"min({BLOCK_NAME}+{self.symbol()}_offset, {self.name}_numel)" if self.is_split_axis else f"{self.name}_numel"
line = f"{self.name}_mask = {self.name} < {upper}"
self.writeline(line)
for var in self.var_directions.keys():
line = f"{var.name}_mask = {var.name} < {upper}"
self.writeline(line)
else:
pass
def get_axis_direction(self):
if self.directions:
return f"[{','.join(self.directions)}]"
tiling_axis = [x.symbol() for x in self.kernel.tiling_axis]
rev_orders = [x for x in self.kernel.golden_var_list if x in tiling_axis]
self.directions = ["None"] * len(tiling_axis)
if len(tiling_axis) != len(rev_orders):
raise RuntimeError(f"assert tiling len={len(tiling_axis)}, not equal to golden varlist len ={len(rev_orders)}")
var_orders = list(reversed(rev_orders))
index = var_orders.index(self.symbol())
self.directions[index] = ":"
return f"[{','.join(self.directions)}]"
def _codegen(self):
self.indexing_code.clear()
index = None
if not self.is_tiling_axis:
return self.name
direction = self.get_axis_direction()
index = f"{self.name} = {self.codegen_index(direction)}"
for var, dir_index in self.var_directions.items():
line = f"{var.name} = {self.codegen_index(dir_index)}"
self.writeline(line)
if self.prefix == 'r':
if V.kernel.inside_reduction and V.kernel.current_node \
and isinstance(V.kernel.current_node, SchedulerNode) \
and V.kernel.current_node.node \
and V.kernel.current_node.node.data \
and isinstance(V.kernel.current_node.node.data, ir.Reduction):
reduction_type = V.kernel.current_node.node.data.reduction_type
if reduction_type in {"argmax", "argmin"}:
self.writeline(f"{self.parent.prefix}index = "
f"{self.codegen_index(None)}")
if index:
self.writeline(index)
self._codegen_mask()
return self.name
def writeline(self, line):
self.indexing_code.writeline(line)
def codegen_index(self, direction):
BLOCK_NAME = f"{self.name.upper()}BLOCK"
BLOCK_NAME_SUB = f"{BLOCK_NAME}_SUB"
index = None
if self.prefix == 'r':
if V.kernel.persistent_reduction:
index = f"base_{self.name}"
else:
index = f"(loop_{self.name} * {BLOCK_NAME_SUB}) + base_{self.name}"
else:
if self.is_no_loop_axis:
index = f"base_{self.name}"
elif self.is_split_axis:
offset = f"{self.symbol()}_offset"
index = f"{offset} + (loop_{self.name} * {BLOCK_NAME_SUB}) + base_{self.name}"
else:
index = f"(loop_{self.name} * {BLOCK_NAME_SUB}) + base_{self.name}"
if len(V.kernel.tiling_axis) > 1 and direction is not None:
index += direction
return index
def codegen_header(self, code):
lines = []
BLOCK_NAME = f"{self.name.upper()}BLOCK"
BLOCK_NAME_SUB = f"{BLOCK_NAME}_SUB"
dtype_cast_str = ""
if V.kernel.index_dtype == "tl.int64" and npu_config.is_ascend950:
dtype_cast_str = ".to(tl.int64)"
if self.is_split_axis:
lines.append(f"{self.symbol()}_offset = tl.program_id({self.split_order}){dtype_cast_str} * {BLOCK_NAME}")
if self.is_no_loop_axis:
lines.append(f"base_{self.name}= tl.arange(0, {BLOCK_NAME_SUB}){dtype_cast_str}")
elif self.is_tiling_axis:
lines.append(f"base_{self.name}= tl.arange(0, {BLOCK_NAME_SUB}){dtype_cast_str}")
block = f"{BLOCK_NAME}" if self.is_split_axis else f"{self.symbol()}_numel"
lines.append(f"loops_{self.name} = ({block} + {BLOCK_NAME_SUB} - 1) // {BLOCK_NAME_SUB}")
else:
pass
code.writelines(lines)
def precomputed_args(self):
precomputed_args: List[sympy.Expr] = []
if isinstance(self.expr, (sympy.Symbol, sympy.Integer)):
return precomputed_args
if not isinstance(self.expr, (FloorDiv, ModularIndexing)):
raise RuntimeError("assert isinstance(self.expr, (FloorDiv, ModularIndexing)), type(self.expr)")
for arg in self.expr.args[1:]:
if not isinstance(arg, (sympy.Integer, sympy.Symbol)):
symbols = arg.free_symbols
if len(symbols) > 0 and all(
symbol_is_type(s, SymT.SIZE) for s in symbols
):
precomputed_args.append(arg)
return precomputed_args
def __eq__(self, other):
return self.name == other.name
class IterationRangesRootNPUIndex(IterationRangesRoot):
    """Range-tree root that creates ``IterationRangesEntryNPUIndex`` entries.

    Entries removed through ``remove_entry`` are parked on
    ``V.kernel.range_tree_nodes_removed`` and may be handed out again by
    ``lookup``.
    """

    def __init__(
        self,
        name: str,
        numel: sympy.Expr,
        prefix: str,
        index: int,
        kernel: TritonKernel,
        pid_cache=None,
        *,
        is_loop: bool,
        tensor_dim: Optional[int],
        grid_dim: Optional[int],
    ):
        # Forward everything unchanged; ``has_zdim`` is always passed as False.
        super().__init__(
            name,
            numel,
            prefix,
            index,
            kernel,
            pid_cache,
            is_loop=is_loop,
            tensor_dim=tensor_dim,
            grid_dim=grid_dim,
            has_zdim=False,
        )

    def __repr__(self):
        return f"IterationRangesRootNPUIndex({self.name!r}, {self.numel}, ...)"

    def _npu_range_expr(self, divisor, length):
        """Build the expression that keys a ``(divisor, length)`` entry.

        ``FloorDiv`` is used when ``divisor * length`` is statically known to
        equal this root's numel, ``ModularIndexing`` otherwise.
        """
        root_index = sympy_index_symbol(f"{self.prefix}index")
        if V.graph.sizevars.statically_known_equals(divisor * length, self.numel):
            return FloorDiv(root_index, divisor)
        return ModularIndexing(root_index, divisor, length)

    def remove_entry(self, name):
        """Detach the entry for symbol ``name`` from this tree.

        The symbol is dropped from ``var_ranges`` and ``var_list``. If the
        kernel knows it, the node is moved to
        ``V.kernel.range_tree_nodes_removed`` and its expression is dropped
        from ``self.nodes``.

        Args:
            name: sympy.Symbol representing the dimension to remove
        """
        self.var_ranges.pop(name, None)
        if name in self.var_list:
            self.var_list.remove(name)

        if name in V.kernel.range_tree_nodes:
            node = V.kernel.range_tree_nodes[name]
            V.kernel.range_tree_nodes_removed[name] = node
            del V.kernel.range_tree_nodes[name]
            self.nodes.pop(node.expr, None)

    def duplicated_check(self, divisor, length):
        """Return True when no entry exists yet for ``(divisor, length)``.

        Nothing is created or modified by this check.
        """
        return self._npu_range_expr(divisor, length) not in self.nodes

    def lookup(self, divisor, length):
        """Return the entry for ``(divisor, length)``, creating it if needed.

        A matching previously removed node is reused in preference to creating
        a new one. A newly created node is also registered on the kernel and
        in ``var_list`` / ``var_ranges``.
        """
        expr = self._npu_range_expr(divisor, length)
        if expr in self.nodes:
            return self.nodes[expr]

        revived = self._find_removed_node(divisor, length)
        if revived is not None:
            self.nodes[expr] = revived
            npu_config.log.debug(
                "lookup: reusing removed node %s for divisor=%s, length=%s",
                revived.symbol(), divisor, length
            )
            return revived

        created = IterationRangesEntryNPUIndex(
            f"{self.prefix}{next(V.kernel.iter_vars_count)}",
            divisor,
            length,
            expr,
            self,
        )
        symbol = created.symbol()
        V.kernel.range_tree_nodes[symbol] = created
        self.var_list.append(symbol)
        self.var_ranges[symbol] = length
        self.nodes[expr] = created
        npu_config.log.debug(
            "lookup: created new node %s for divisor=%s, length=%s",
            created.symbol(), divisor, length
        )
        return created

    def _find_removed_node(self, divisor, length):
        """Find a removed range tree node matching the given divisor and length."""
        candidates = V.kernel.range_tree_nodes_removed.values()
        return next(
            (
                node
                for node in candidates
                if node.parent is self
                and node.divisor == divisor
                and node.length == length
            ),
            None,
        )

    @classmethod
    def is_compatible(
        cls,
        groups: Iterable[sympy.Expr],
        lengths: Sequence[Sequence[sympy.Expr]],
        reduction_numel: sympy.Expr = sympy.S.One
    ):
        """Return True when ``lengths`` can be split over the flattened ``groups``.

        If the reduction lengths are empty and the group product equals the
        pointwise product times ``reduction_numel``, the reduction lengths are
        filled in with ``[reduction_numel]`` before the split is attempted.
        """
        sizevars = V.graph.sizevars
        reduction_missing = len(lengths[1]) == 0
        if reduction_missing and sizevars.statically_known_equals(
            sympy_product(groups),
            sympy_product(lengths[0]) * reduction_numel,
        ):
            lengths = (lengths[0], [reduction_numel])

        try:
            groups = flatten(groups)
            NPUIndexTritonKernel._split_iteration_ranges(groups, lengths)
        except CantSplit:
            return False
        return True
class NPUIndexTritonKernel(TritonKernel):
    """NPU variant of Inductor's ``TritonKernel``.

    ``overrides`` selects ``NPUTritonKernelOverrides`` as the op-override table.
    """
    overrides = NPUTritonKernelOverrides
def __init__(
    self,
    tiling: Dict[str, sympy.Expr],
    min_elem_per_thread=0,
    optimize_mask=True,
    fixed_config: Optional[FixedTritonConfig] = None,
    **kwargs,
):
    """Initialise the base kernel, then reset the NPU-specific bookkeeping."""
    super().__init__(
        tiling=tiling,
        min_elem_per_thread=min_elem_per_thread,
        optimize_mask=optimize_mask,
        fixed_config=fixed_config,
        **kwargs,
    )

    # Plain state flags.
    self.first_node = True
    self.inside_high_order_reduction = False
    self.npu_kernel_type = NPUKernelType.SIMD

    # Axis classification; all start empty.
    self.split_axis = []
    self.tiling_axis = []
    self.sorted_axis = []
    self.low_dims = set()
    self.current_subblock_axis = set()

    # Range-tree bookkeeping (removed and substituted nodes/expressions).
    self.range_tree_nodes_removed: Dict[sympy.Symbol, IterationRangesEntry] = {}
    self.range_tree_nodes_substituted = {}
    self.expr_substituted = {}

    # Analysis results; left unset here.
    self.index_analysis = {}
    self.golden_var_list = None
    self.reduce_analysis = None
    self.load_store_indexing = None

    self.prefix: IndentedBuffer = IndentedBuffer()
def _get_grid_type(self) -> type[triton_heuristics.GridExpr]:
    """Return the grid class used for this kernel (``GridNpu``)."""
    grid_cls = npu_triton_heuristics.GridNpu
    return grid_cls
def scan(
    self,
    dtypes: tuple[torch.dtype, ...],
    combine_fn: Callable[
        [tuple[CSEVariable, ...], tuple[CSEVariable, ...]], tuple[CSEVariable, ...]
    ],
    values: tuple[CSEVariable, ...],
) -> tuple[CSEVariable, ...]:
    """NPU override for ops.scan codegen.

    Upstream TritonKernel.scan assumes the reduction/scan dimension is the
    last dimension in the broadcasted tensor (layout [X, R]) and uses
    `dim = triton_tensor_ndim() - num_reduction_dims`.

    NPU index codegen may produce either an R-first ([R, X...]) or
    R-last ([X..., R]) dense layout depending on how `golden_var_list` is
    derived for the current kernel. We must scan along the actual reduction
    ("r") axis position in the broadcasted dense tensor, instead of
    hard-coding an axis.

    Args:
        dtypes: dtype of each scanned value; upcast with
            ``upcast_compute_type`` before use.
        combine_fn: combine function of the scan; lifted into a helper for
            ``tl.associative_scan`` and also called directly to fold the
            running accumulators.
        values: CSE variables to scan, one per dtype.

    Returns:
        One result variable per input value, with ``mask_vars`` set to the
        filtered range-tree masks.
    """
    assert self.inside_reduction
    assert not self.cooperative_reduction, "TODO"
    # One "<prefix>mask" name per range tree, filtered by the kernel and sorted.
    masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees)
    self.filter_masks(masks)
    masks = sorted(masks)
    assert not self._load_mask, "ops.scan not supported inside ops.masked"
    broadcasted_values: list[CSEVariable] = []
    accumulators: list[CSEVariable] = []
    dtypes = tuple(upcast_compute_type(dtype) for dtype in dtypes)
    cse_compute = functools.partial(self.cse.generate, self.compute)
    combine_helper_fn = self._lift_helper(combine_fn, len(values), dtypes)
    # Result deliberately discarded.
    # NOTE(review): presumably called so golden_var_list is populated before it is read — confirm.
    _ = self.dense_size_list()
    golden = list(self.golden_var_list or [])
    # Dense dimension order is taken to be the reverse of golden_var_list.
    dense_vars = list(reversed(golden))
    reduction_dims = [
        i
        for i, v in enumerate(dense_vars)
        if str(v).startswith("r")
    ]
    if reduction_dims:
        # Scan along the last dense dimension whose variable name starts with "r".
        dim = reduction_dims[-1]
    else:
        # No "r" variable found: fall back to the upstream formula.
        dim = self.triton_tensor_ndim() - self.num_reduction_dims
    dense_ndim = len(self.dense_size_list())
    scan_axis_sym = dense_vars[dim] if 0 <= dim < len(dense_vars) else None
    scan_axis = (
        self.range_tree_nodes.get(scan_axis_sym) if scan_axis_sym in self.range_tree_nodes else None
    )
    # Axis name used to build generated identifiers; defaults to "r" when unresolved.
    scan_axis_name = (
        getattr(scan_axis, "name", None) or (str(scan_axis_sym) if scan_axis_sym is not None else "r")
    )
    rbase_sym = f"base_{scan_axis_name}"
    rblock_sym = f"{scan_axis_name.upper()}BLOCK_SUB"
    # Offset of the current loop iteration along the scan axis (text of generated code).
    if getattr(scan_axis, "is_split_axis", False):
        roffset_expr = f"{scan_axis_name}_offset + (loop_{scan_axis_name} * {rblock_sym})"
    else:
        roffset_expr = f"(loop_{scan_axis_name} * {rblock_sym})"
    # Reshape the base index so it broadcasts along `dim` of the dense shape.
    reshape_sizes = ["1"] * dense_ndim
    if 0 <= dim < dense_ndim:
        reshape_sizes[dim] = rblock_sym
    rbase_broadcast = (
        f"tl.broadcast_to({rbase_sym}.reshape({', '.join(reshape_sizes)}), {self.dense_size_str()})"
        if dense_ndim > 1
        else rbase_sym
    )
    for value, dtype in zip(values, dtypes):
        # Cast to the compute dtype, then broadcast to the dense shape.
        value_dtype = self.cse.generate(
            self.compute,
            f"{value}.to({triton_compute_type(dtype)})",
            dtype=dtype,
        )
        value = self.cse.generate(
            self.compute,
            f"tl.broadcast_to({value_dtype}, {self.dense_size_str()})",
            dtype=dtype,
        )
        broadcasted_values.append(value)
        acc_type = triton_acc_type(dtype)
        if not self.persistent_reduction:
            # Looped reduction: keep a running accumulator with size 1 along `dim`,
            # initialised in the kernel body with NaN (floating point) or -1.
            accumulator = self.cse.newvar(dtype=dtype)
            reduced_size = self.dense_size_list()
            reduced_size[dim] = "1"
            reduced_size = f"[{', '.join(reduced_size)}]"
            default = "float('nan')" if dtype.is_floating_point else "-1"
            self.body.writeline(
                f"{accumulator} = tl.full({reduced_size}, {default}, {acc_type})"
            )
            accumulators.append(accumulator)
    def csv(vs):
        # "a, b," form: renders as a tuple even with a single element.
        return " ".join(f"{v}," for v in vs)
    def cse_multiple(line, in_values, in_masks, in_dtypes):
        # Emit one multi-result assignment and cache every result separately.
        # Returns a list on a cache hit and a tuple otherwise.
        n = len(in_values)
        cache_keys = [f"{line}, {i}, {in_masks}" for i in range(n)]
        if all(self.cse.contains(cache_key) for cache_key in cache_keys):
            return [self.cse.get(cache_key) for cache_key in cache_keys]
        result_vars = [self.cse.newvar(dtype=_dtype) for _dtype in in_dtypes]
        self.compute.writeline(f"{csv(result_vars)} = {line}")
        for result_var, cache_key in zip(result_vars, cache_keys):
            if in_masks:
                result_var.mask_vars = in_masks
            self.cse.put(cache_key, result_var)
        return tuple(result_vars)
    # Scan within the current block along `dim`.
    partial_scan_vars = cse_multiple(
        f"tl.associative_scan(({csv(broadcasted_values)}), {dim}, {combine_helper_fn})",
        values,
        masks,
        dtypes,
    )
    if not self.persistent_reduction:
        # Pick the scan value at the last position of the block along `dim`.
        partial_reduce_vars = [
            cse_compute(
                f"triton_helpers.select_one(({partial_scan_var}), ({rbase_broadcast}) == ({rblock_sym} - 1), dim={dim}, keep_dims=True)",
                dtype=upcast_compute_type(partial_scan_var.dtype),
            )
            for partial_scan_var in partial_scan_vars
        ]
        accs_next = combine_fn(tuple(accumulators), tuple(partial_reduce_vars))
        full_scan_vars = combine_fn(tuple(accumulators), partial_scan_vars)
        # When the loop offset is 0 the block-local value is used as-is;
        # otherwise the value combined with the accumulator is used.
        result_vars = [
            cse_compute(
                f"tl.where(({roffset_expr}) > 0, {full_scan}, {partial_scan})",
                dtype=partial_scan.dtype,
            )
            for full_scan, partial_scan in zip(full_scan_vars, partial_scan_vars)
        ]
        for acc_next, accumulator, partial_reduce in zip(
            accs_next, accumulators, partial_reduce_vars
        ):
            # Update the accumulator with the same offset-0 special case.
            self.compute.writeline(
                f"{accumulator} = tl.where(({roffset_expr}) > 0, {acc_next}, {partial_reduce})"
            )
    else:
        # Persistent reduction: the block-local scan is the full scan.
        result_vars = partial_scan_vars
    for result_var in result_vars:
        assert isinstance(result_var, TritonCSEVariable)
        result_var.mask_vars = OrderedSet(masks)
    return tuple(result_vars)
def patch_triton_hash(self):
    """Return the SHA-256 hex digest of ``"<triton_key>-<backend hash>"``.

    The backend is created for the active driver's current target.
    """
    import hashlib
    from triton.compiler.compiler import triton_key, make_backend
    from triton.runtime.driver import driver

    target = driver.active.get_current_target()
    backend = make_backend(target)
    key_bytes = f"{triton_key()}-{backend.hash()}".encode("utf-8")
    return hashlib.sha256(key_bytes).hexdigest()
def numof_tiling_axis(self):
    """Return how many tiling axes this kernel currently has."""
    axes = self.tiling_axis
    return len(axes)
def codegen_range_tree(self):
    """Intentionally emit nothing for the range tree in this kernel."""
    return None
def initialize_range_tree(self, pid_cache):
    """Create one ``IterationRangesRootNPUIndex`` per active prefix.

    The active prefixes are the last ``len(self.numels)`` characters of
    ``"wvtzyxr"``. Reduction prefixes get no grid dimension and are marked as
    loops unless the kernel is a persistent reduction.
    """
    all_prefixes = "wvtzyxr"
    grid_dims = "xyztvw"
    no_r_dim = not self.inside_reduction or self.numels["r"] == 1
    active_prefixes = all_prefixes[-len(self.numels):]

    if self.no_x_dim:
        candidate_dims = "r"
    elif no_r_dim:
        candidate_dims = "xyztvw"
    else:
        candidate_dims = "xyztvwr"
    # Keep only dimensions whose prefix is active, in candidate order.
    tensor_dims = "".join(p for p in candidate_dims if p in active_prefixes)

    for position, prefix in enumerate(active_prefixes):
        is_reduction = prefix_is_reduction(prefix)
        tensor_dim = tensor_dims.find(prefix) if prefix in tensor_dims else None
        grid_dim = None if is_reduction else grid_dims.find(prefix)
        # Entries without a grid dimension are indexed by their position instead.
        index = position if grid_dim is None else grid_dim
        root = IterationRangesRootNPUIndex(
            f"{prefix}index",
            self.numels[prefix],
            prefix,
            index,
            self,
            pid_cache=pid_cache,
            is_loop=is_reduction and not self.persistent_reduction,
            tensor_dim=tensor_dim,
            grid_dim=grid_dim,
        )
        self.range_trees.append(root)
def codegen_reduction_numels(self, buffer) -> None:
    """Reject kernels with more than one reduction range tree; emit nothing.

    Raises:
        AssertionError: if more than one range tree is a reduction tree.
    """
    reduction_tree_count = sum(1 for tree in self.range_trees if tree.is_reduction)
    if reduction_tree_count > 1:
        raise AssertionError("Currently npu don't support multi-reduction ranges trees, e.g, r0, r1.")
def get_axis_dtype(self, axis):
    """Return the dtype of a scheduled node associated with ``axis``.

    First pass: the first scheduled node whose ``indexing_map`` contains the
    axis symbol. Second pass (only if the first found nothing): the first node
    with an indexing key that resolves to a live or removed range-tree entry
    sharing ``axis.parent``. Returns ``None`` if ``axis`` is ``None`` or no
    node matches.
    """
    if axis is None:
        return None

    def _lookup_dim(sym) -> Optional["IterationRangesEntryNPUIndex"]:
        # Live nodes take priority over removed ones.
        live = self.range_tree_nodes.get(sym)
        return live if live is not None else self.range_tree_nodes_removed.get(sym)

    def _iter_candidate_syms(key):
        # The key itself first, then its free symbols for compound expressions.
        yield key
        if isinstance(key, sympy.Expr) and not isinstance(key, sympy.Symbol):
            yield from key.free_symbols

    def _scheduled_nodes():
        # Skip the reduction enable/disable markers in the schedule.
        for scheduled in self.node_schedule:
            if scheduled not in (EnableReduction, DisableReduction):
                yield scheduled

    dtype = None
    for node in _scheduled_nodes():
        if axis.symbol() in node._body.indexing_map:
            dtype = V.graph.get_dtype(node.node.name)
            break
    if dtype is not None:
        return dtype

    for node in _scheduled_nodes():
        for key in node._body.indexing_map:
            dim = next(
                (
                    found
                    for found in map(_lookup_dim, _iter_candidate_syms(key))
                    if found is not None
                ),
                None,
            )
            if dim is not None and dim.parent == axis.parent:
                return V.graph.get_dtype(node.node.name)
    return None
def create_inductor_meta(self):
    """Build the ``inductor_meta`` dict embedded in the generated kernel.

    It combines the sorted names of mutated arguments, the axis layout
    (split/tiling/no-loop orders and axis names) and the common Inductor
    metadata from ``TritonKernel.inductor_meta_common()``.
    """
    kernel_args = self.args
    mutated = set()
    for mutation in self.mutations:
        if mutation in kernel_args.input_buffers:
            mutated.add(kernel_args.input_buffers[mutation])
        # In-place buffers only count while they have not been removed.
        inplace_is_live = (
            mutation in kernel_args.inplace_buffers
            and mutation not in V.graph.removed_buffers
            and mutation not in self.removed_buffers
        )
        if inplace_is_live:
            mutated.add(kernel_args.inplace_buffers[mutation].inner_name)
        if mutation in kernel_args.output_buffers:
            mutated.add(kernel_args.output_buffers[mutation])
    mutated_arg_names = sorted(mutated)

    tiling_orders = [axis.sorted_order for axis in self.tiling_axis]
    no_loop_orders = [axis.sorted_order for axis in self.tiling_axis if axis.is_no_loop_axis]
    split_orders = [axis.sorted_order for axis in self.split_axis]
    sorted_names = [axis.name for axis in self.sorted_axis]
    if self.split_axis:
        split_axis_dtype = self.get_axis_dtype(self.split_axis[0])
    else:
        split_axis_dtype = None

    return {
        "grid_type": self._get_grid_type().__name__,
        "autotune_hints": set(self.autotune_hints),
        "kernel_name": str(Placeholder.DESCRIPTIVE_NAME),
        "mutated_arg_names": mutated_arg_names,
        "backend_hash": self.patch_triton_hash(),
        "split_axis": split_orders,
        "tiling_axis": tiling_orders,
        "no_loop_axis": no_loop_orders,
        "axis_names": sorted_names,
        "low_dims": self.low_dims,
        "numof_reduction_axis": self.numof_reduction_axis(),
        "split_axis_dtype": split_axis_dtype,
        "dual_reduction": self.numof_reduction_axis() > 1,
        "npu_kernel_type": str(self.npu_kernel_type),
        "traced_graph_hash": "TRACED_GRAPH_HASH",
        "traced_graph_dir": "TRACED_GRAPH_DIR",
        **TritonKernel.inductor_meta_common()
    }
def get_size_hints(self):
    """Return one integer size hint per sorted axis.

    With no range-tree nodes, the raw ``self.numels`` values are returned.
    Otherwise each axis contributes the hint of its length
    (``ModularIndexing`` axes) or of its expression with every root index
    replaced by that root's numel. An axis whose hint cannot be resolved
    falls back to ``32``.
    """
    if not self.range_tree_nodes:
        return list(self.numels.values())

    # Loop-invariant: the root-index -> numel substitution is the same for every axis.
    root_numels = {sympy_index_symbol(r.name): r.numel for r in self.range_trees}

    size_hints = []
    for node in self.sorted_axis:
        if isinstance(node.expr, ModularIndexing):
            numel_expr = node.length
        else:
            numel_expr = node.expr.subs(root_numels)
        try:
            hint_val = int(V.graph.sizevars.shape_env.size_hint(numel_expr))
        except Exception as exc:
            # Best-effort by design: keep the fallback, but record the failure
            # so it is not lost silently.
            npu_config.log.debug(
                "get_size_hints: no size hint for %s (%s); using fallback 32",
                numel_expr, exc
            )
            hint_val = 32
        size_hints.append(hint_val)
    return size_hints
def add_numel_to_call_args(self, name, call_args, arg_types):
    """Append one numel argument (and its Python type) per sorted axis.

    Plain integers and symbols are passed as-is; any other expression is
    routed through ``wrapper_code.generate_node_numel_expr``.
    """

    def _axis_numel(node):
        if isinstance(node.expr, ModularIndexing):
            return node.length
        # Express the axis in terms of the root numels.
        root_numels = {sympy_index_symbol(tree.name): tree.numel for tree in self.range_trees}
        return node.expr.subs(root_numels)

    for node in self.sorted_axis:
        arg = _axis_numel(node)
        if not isinstance(arg, (sympy.Integer, sympy.Symbol)):
            arg = V.graph.wrapper_code.generate_node_numel_expr(name, node, arg)
        call_args.append(arg)
        arg_types.append(type(arg))
def gen_numel_args(self, signature, triton_meta, triton_meta_signature, argdefs):
    """Declare a ``<axis>_numel`` kernel argument for every sorted axis.

    In static mode the numel becomes a constexpr recorded in
    ``triton_meta["constants"]``; otherwise it is a runtime ``SizeArg`` added
    to the signature and to the signature metadata.
    """
    for node in self.sorted_axis:
        arg_name = f"{node.name}_numel"
        if npu_config.inductor_static_mode:
            argdefs.append(ArgName(arg_name, is_constexpr=True))
            triton_meta["constants"][arg_name] = node.length
            continue

        size_arg = SizeArg(arg_name, node.length)
        signature.append(size_arg)
        triton_meta_signature[arg_name] = signature_of(
            size_arg, size_dtype=self.index_dtype
        )
        argdefs.append(ArgName(arg_name))
def add_autotune_args(self, argdefs):
    """Append the constexpr block-size arguments chosen by autotuning.

    Every split axis gets ``<AXIS>BLOCK``. Every tiling axis gets
    ``<AXIS>BLOCK_SUB``, except persistent-reduction axes and no-loop axes
    whose length simplifies to a static integer.
    """
    argdefs.extend(
        ArgName(f"{axis.name.upper()}BLOCK", is_constexpr=True)
        for axis in self.split_axis
    )

    for axis in self.tiling_axis:
        may_be_static = (
            (axis.name[0] == 'r' and self.persistent_reduction)
            or axis.is_no_loop_axis
        )
        if may_be_static:
            simplified = V.graph.sizevars.simplify(axis.length)
            if isinstance(simplified, (sympy.Integer, int)):
                continue
        argdefs.append(ArgName(f"{axis.name.upper()}BLOCK_SUB", is_constexpr=True))
def _get_heuristic(self):
    """Return the heuristic name for this kernel.

    One of ``"persistent_reduction"``, ``"reduction"`` or ``"pointwise"``.

    Raises:
        RuntimeError: if the kernel is a persistent reduction but is not
            inside a reduction.
    """
    if not self.persistent_reduction:
        return "reduction" if self.inside_reduction else "pointwise"
    if self.inside_reduction:
        return "persistent_reduction"
    raise RuntimeError("assert self.inside_reduction to be true")
def get_kernel_name(self, src_code, node_schedule, kernel):
    """Return the kernel name for ``src_code``.

    A name already recorded for identical source in the wrapper is reused.
    Otherwise a new ``triton_<category>_<fused name>_<suffix>`` name is built.
    ``kernel`` is not used.
    """
    wrapper = V.graph.wrapper_code
    if src_code in wrapper.src_to_kernel:
        return wrapper.src_to_kernel[src_code]

    fused_name = ""
    if config.triton.descriptive_names:
        fused_name = get_fused_kernel_name(node_schedule, config.triton.descriptive_names)
    # Only the first three characters of the category are used.
    kernel_category = get_kernel_category_by_source_code(src_code)[:3]
    name_parts = ["triton", kernel_category, fused_name, wrapper.get_next_kernel_suffix()]
    return "_".join(name_parts)
def codegen_kernel(self, name=None):
    """Generate the full source text of this kernel and return it.

    Args:
        name: explicit function name. When ``None`` the common import header
            is emitted first and the ``Placeholder.KERNEL_NAME`` placeholder
            is used as the function name.

    Returns:
        The generated source as a string: optional imports, helper functions,
        the ``@triton_heuristics.<heuristic>`` / ``@triton.jit`` decorators,
        the function definition and, when benchmarking is enabled, the
        benchmark harness.
    """
    code = IndentedBuffer()
    size_hints = self.get_size_hints()
    heuristics = self._get_heuristic()
    if name is None:
        code.splice(gen_common_triton_imports())
        if config.benchmark_kernel:
            code.splice(self.imports_for_benchmark_kernel())
    argdefs, _, signature, _ = self.args.python_argdefs()
    # Swap precomputed size symbols in the signature for their recorded replacements.
    for i, arg in enumerate(signature):
        if isinstance(arg, SizeArg):
            symbol = cast(sympy.Symbol, arg.expr)
            if symbol in V.graph.sizevars.inv_precomputed_replacements:
                signature[i] = SizeArg(
                    arg.name, V.graph.sizevars.inv_precomputed_replacements[symbol]
                )
    triton_meta_signature = signature_to_meta(signature, size_dtype=self.index_dtype, argdefs=argdefs)
    triton_meta = {
        "signature": triton_meta_signature,
        "device":
            DeviceProperties.create(
                V.graph.get_current_device_or_throw()
            ),
        "constants": {},
        "mix_mode": "aiv"
    }
    inductor_meta = self.create_inductor_meta()
    num_gb = None
    if config.benchmark_kernel or config.profile_bandwidth:
        num_gb = self.estimate_kernel_num_bytes() / 1e9
        inductor_meta["kernel_num_gb"] = num_gb
    # Numel arguments first, then the autotuned block-size arguments.
    self.gen_numel_args(signature, triton_meta, triton_meta_signature, argdefs)
    if self.is_unified_simt_kernel():
        triton_meta["configs"] = [config_of(signature)]
    self.add_autotune_args(argdefs)
    # Fill self.body: kernels without range-tree nodes take the scalar path.
    if len(self.range_tree_nodes) == 0:
        self.write_scalar()
    else:
        self.codegen_body()
    for helper in self.helper_functions:
        code.writeline("")
        code.splice(helper)
    if self.inside_reduction:
        reduction_hint = self.features.get_reduction_hint()
        heuristics_line = f"""
@triton_heuristics.{heuristics}(
size_hints={size_hints},
reduction_hint={reduction_hint},
filename=__file__,
triton_meta={triton_meta!r},
inductor_meta={inductor_meta!r}
)
@triton.jit
"""
    else:
        # A tile hint is only emitted when there are exactly two size hints.
        tile_hint = ""
        if len(size_hints) == 2:
            if len(signature) == 4:
                tile_hint = "tile_hint=TileHint.SQUARE,"
            else:
                tile_hint = "tile_hint=TileHint.DEFAULT,"
        heuristics_line = f"""
@triton_heuristics.{heuristics}(
size_hints={size_hints!r}, {tile_hint}
filename=__file__,
triton_meta={triton_meta!r},
inductor_meta={inductor_meta!r},
min_elem_per_thread={self.min_elem_per_thread}
)
@triton.jit
"""
    code.splice(heuristics_line)
    code.writeline(
        f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(x.full_name() for x in argdefs)}):"
    )
    with code.indent():
        # Static numel/block constants, then argument aliases, then the body.
        self.codegen_static_numels(code)
        for old, new in self.args.aliases():
            code.writeline(f"{old} = {new}")
        code.splice(self.body)
    if config.benchmark_kernel:
        code.splice(self.codegen_kernel_benchmark(num_gb))
    return code.getvalue()
def codegen_static_numels(self, code):
    """Write compile-time constants for axes whose length is a static integer.

    Persistent-reduction axes get ``<AXIS>BLOCK_SUB``. No-loop axes get both
    ``<axis>_numel`` and ``<AXIS>BLOCK_SUB``. For unified SIMT kernels the
    ``BLOCK_SUB`` value is rounded up to the next power of two.
    """

    def _static_int(length):
        # Integer value of the simplified length, or None if it is not static.
        simplified = V.graph.sizevars.simplify(length)
        if isinstance(simplified, (sympy.Integer, int)):
            return int(simplified)
        return None

    for symbol in self.reduction_axis_list():
        if not (symbol.name[0] == "r" and self.persistent_reduction):
            continue
        node = self.range_tree_nodes[symbol]
        val = _static_int(node.length)
        if val is None:
            continue
        if self.is_unified_simt_kernel():
            val = next_power_of_2(val)
        code.writeline(f"{node.name.upper()}BLOCK_SUB: tl.constexpr = {val}")

    for axis in self.sorted_axis:
        if not axis.is_no_loop_axis:
            continue
        val = _static_int(axis.length)
        if val is None:
            continue
        code.writeline(f"{axis.name}_numel = {val}")
        block_sub = next_power_of_2(val) if self.is_unified_simt_kernel() else val
        code.writeline(f"{axis.name.upper()}BLOCK_SUB: tl.constexpr = {block_sub}")
def lowest_axis_variable(self):
    """Return the last tiling axis, or ``None`` when there are none."""
    return self.tiling_axis[-1] if len(self.tiling_axis) != 0 else None
def is_isolated_symbol(self, input_str, range_val):
    """Return True if ``input_str`` mentions the axis as a whole word.

    The axis name and the name of every symbol in ``range_val.var_directions``
    are each searched with word boundaries on both sides.
    """
    names = [range_val.name]
    names.extend(var.name for var in range_val.var_directions)
    word_patterns = [r'\b' + re.escape(name) + r'\b' for name in names]
    return any(re.search(pattern, input_str) for pattern in word_patterns)
def _axis_used_in_coordinate_transforms(self, axis_name):
    """Return True if any coordinate transform names ``axis_name`` as its unified variable.

    A missing or empty ``coordinate_transforms`` attribute yields ``False``.
    """
    transforms = getattr(self, 'coordinate_transforms', None)
    if not transforms:
        return False
    for transform_info in transforms.values():
        if transform_info.get('unified_var') == axis_name:
            return True
    return False
def _iter_codegen_lines(self):
    """Yield every line of the loads, compute, post-loop-store and stores buffers.

    Buffers are visited in that order. ``DeferredLine`` entries in the two
    store buffers are yielded as their ``.line`` attribute.
    """
    yield from self.loads._lines
    yield from self.compute._lines
    for store_buffer in (self.post_loop_store, self.stores):
        for entry in store_buffer._lines:
            if isinstance(entry, DeferredLine):
                yield entry.line
            else:
                yield entry
def _emit_coordinate_transforms(self):
"""
Emit coordinate transformation code into the kernel body.
"""
if not hasattr(self, 'coordinate_transforms') or not self.coordinate_transforms:
return
self.body.writeline("# Coordinate transformation for different expansion patterns")
for name, transform_info in self.coordinate_transforms.items():
self.body.writeline(f"# Transform for {name}")
transform_lines = self._generate_coordinate_transform_lines(
transform_info['unified_var'],
transform_info['dims'],
transform_info['numels'],
)
for line in transform_lines:
self.body.writeline(line)
def find_axis_in_load_store(self, range_val):
    """
    Decide whether an axis is referenced by the pending generated code.
    Args:
        range_val: Axis entry to look for; a falsy value yields False.
    Returns:
        True if the axis name is the unified variable of a coordinate
        transform, or if ``is_isolated_symbol`` matches any line produced by
        ``_iter_codegen_lines``.
    """
    if not range_val:
        return False
    if self._axis_used_in_coordinate_transforms(range_val.name):
        return True
    for pending_line in self._iter_codegen_lines():
        if self.is_isolated_symbol(pending_line, range_val):
            return True
    return False
def write_scalar(self):
    """
    Splice the indexing, load, compute and store buffers into ``self.body``
    without any loop structure, then clear the pending buffers.
    """
    for section in (self.indexing_code, self.loads, self.compute, self.stores):
        self.body.splice(section)
    for pending in (
        self.loads,
        self.compute,
        self.stores,
        self.post_loop_store,
        self.prefix,
    ):
        pending.clear()
def codegen_body(self):
    """
    Flush the buffered loads, compute and stores into ``self.body``, wrapped
    in the loop nest derived from ``self.sorted_axis``.
    Nothing is emitted when the load, store, compute and post-loop-store
    buffers are all empty. On the first call (``self.first_node`` is true)
    every axis header is emitted and the nest is generated starting at axis 0;
    on later calls generation re-enters the nest at the sorted order of the
    last tiling axis. Before returning, the pending buffers are cleared, the
    CSE cache is invalidated except for ``self.outside_loop_vars`` and
    ``self.first_node`` is set to False.
    """
    if not (
        self.loads
        or self.stores
        or self.compute
        or self.post_loop_store
    ):
        return

    def write_pointwise():
        # Innermost payload: coordinate transforms first, then indexing,
        # loads, compute and stores.
        self._emit_coordinate_transforms()
        self.body.splice(self.indexing_code)
        self.body.splice(self.loads)
        self.body.splice(self.compute)
        self.body.splice(self.stores)

    def collect_store_unified_vars():
        """
        Collect unified axis names that are directly referenced by Store indices.
        Returns:
            A set of unified axis names used by Store indices.
        """
        if not hasattr(self, 'store_unified_indexing') or not self.store_unified_indexing:
            return set()
        store_unified_vars = set()
        for store_index in self.store_unified_indexing:
            store_unified_vars.update(str(sym) for sym in store_index.free_symbols)
        return store_unified_vars

    def build_sub_axis_to_unified_var(store_unified_vars):
        """
        Build a mapping from sub-axis names to their unified axis names.
        Args:
            store_unified_vars: Unified axis names that are directly used by Store indices.
        Returns:
            A dictionary mapping each sub-axis name to its unified axis name.
        """
        if not hasattr(self, 'different_expansions') or not self.different_expansions:
            return {}
        if not hasattr(self, 'coordinate_transforms') or not self.coordinate_transforms:
            return {}
        sub_axis_to_unified_var = {}
        for expansion_list in self.different_expansions.values():
            for exp in expansion_list:
                transform_info = self.coordinate_transforms.get(exp['name'])
                if not transform_info:
                    continue
                unified_var = transform_info.get('unified_var')
                # Only sub-axes whose unified axis is used by a Store index are mapped.
                if unified_var not in store_unified_vars:
                    continue
                for dim in exp['dims']:
                    sub_axis_to_unified_var[dim] = unified_var
        return sub_axis_to_unified_var

    store_unified_vars = collect_store_unified_vars()
    sub_axis_to_unified_var = build_sub_axis_to_unified_var(store_unified_vars)

    def codegen_range(index):
        # Recursively emit the loop (if any) for sorted_axis[index] and
        # everything nested inside it.
        def is_1d_reduction():
            # True when "r" is the only entry of self.numels and is statically known to exceed 1.
            return V.graph.sizevars.statically_known_gt(self.numels["r"], 1) and len(self.numels) == 1

        def loop_body(index, indexing_code, is_last_axis, do_indent=True):
            # Emit the body of one axis level: optional indexing code, then
            # either the pointwise payload (last axis) or the next axis level.
            if do_indent:
                self.body.do_indent()
            if indexing_code:
                self.body.splice(indexing_code)
            if is_last_axis:
                write_pointwise()
            else:
                codegen_range(index + 1)
            if do_indent:
                self.body.do_unindent()

        def should_skip_loop_for_sub_axis(range_val):
            """
            Check if this axis is a sub-axis that should skip loop generation.
            Returns:
                True if should skip loop generation, False otherwise
            """
            axis_name = range_val.name
            unified_var = sub_axis_to_unified_var.get(axis_name)
            if unified_var is None:
                return False
            return True

        # NOTE(review): the bound is checked against range_tree_nodes while the
        # lookup below indexes sorted_axis — confirm the two always have
        # compatible lengths.
        if index < 0 or index >= len(self.range_tree_nodes):
            return
        range_val = self.sorted_axis[index]
        if should_skip_loop_for_sub_axis(range_val):
            # Sub-axes covered by a unified axis get no loop and no indexing
            # code of their own; recurse straight to the next axis.
            loop_body(index, None, is_last_axis=False, do_indent=False)
            return
        numof_tilings = len(self.tiling_axis)
        last_tiling = range_val.is_tiling_axis and numof_tilings >= 1 and range_val.tiling_order == len(
            self.tiling_axis) - 1
        is_last_axis = index == len(self.sorted_axis) - 1
        indexing_code = getattr(range_val, "indexing_code")
        reduction_1d = is_1d_reduction()
        do_indent = False
        if range_val.is_tiling_axis and last_tiling:
            # Last tiling axis: loop only when the axis is actually referenced
            # by the pending code and is not a no-loop axis.
            do_indent = False
            have_load_store = self.find_axis_in_load_store(range_val)
            if not have_load_store:
                indexing_code = None
            need_axis_loop = have_load_store and (not range_val.is_no_loop_axis)
            # A persistent reduction axis ("r" prefix) gets no loop.
            if (range_val.prefix != 'r' or not self.persistent_reduction) and need_axis_loop:
                self.body.splice(self.prefix)
                self.body.writeline(f"for loop_{range_val.name} in range(loops_{range_val.name}):")
                do_indent = True
            loop_body(index, indexing_code, is_last_axis, do_indent)
            # Stores deferred until after the loop are flushed here.
            self.body.splice(self.post_loop_store)
            self.post_loop_store.clear()
        elif range_val.is_tiling_axis:
            # Tiling axis that is not the last one.
            do_indent = False
            have_load_store = self.find_axis_in_load_store(range_val)
            if not have_load_store:
                do_indent = False
                indexing_code = None
            if not range_val.is_no_loop_axis:
                do_indent = True
                self.body.writeline(f"for loop_{range_val.name} in range(loops_{range_val.name}):")
            loop_body(index, indexing_code, is_last_axis, do_indent=do_indent)
        elif not is_last_axis:
            # Non-tiling outer axis: a plain loop over the axis itself, bounded
            # by the block range when the axis is the split axis.
            do_indent = True
            if range_val.is_split_axis:
                offset = f"{range_val.name}_offset"
                self.body.writeline(f"for {range_val.name} in range({offset}, "
                                    f"min({offset} + {range_val.name.upper()}BLOCK, {range_val.name}_numel)):")
            else:
                self.body.writeline(f"for {range_val.name} in range({range_val.name}_numel):")
            if not reduction_1d and self.persistent_reduction:
                # The prefix is emitted once, inside this loop, then cleared.
                self.body.do_indent()
                self.body.splice(self.prefix)
                self.prefix.clear()
                self.body.do_unindent()
            loop_body(index, indexing_code, is_last_axis, do_indent=do_indent)
        else:
            write_pointwise()

    if self.first_node:
        for node in self.sorted_axis:
            node.codegen_header(self.body)
    # Rotate trailing non-tiling axes to the front until the last axis is a
    # tiling axis.
    # NOTE(review): this does not terminate if sorted_axis holds no tiling
    # axis, and indexes [-1] on an empty list — confirm callers guarantee at
    # least one tiling axis.
    while True:
        if not self.sorted_axis[-1].is_tiling_axis:
            x = self.sorted_axis[-1]
            self.sorted_axis.pop(-1)
            self.sorted_axis.insert(0, x)
        else:
            break
    if self.first_node:
        codegen_range(0)
    else:
        # Re-enter the existing nest at the last tiling axis, restoring the
        # indentation of the enclosing levels around the call.
        last_axis_order = self.tiling_axis[-1].sorted_order
        if self.persistent_reduction and self.numof_reduction_axis() > 1:
            last_axis_order = last_axis_order - self.numof_reduction_axis() + 1
        for _ in range(last_axis_order):
            self.body.do_indent()
        codegen_range(last_axis_order)
        for _ in range(last_axis_order):
            self.body.do_unindent()
    self.cse.invalidate(self.outside_loop_vars)
    self.loads.clear()
    self.compute.clear()
    self.stores.clear()
    self.post_loop_store.clear()
    self.prefix.clear()
    self.first_node = False
def triton_tensor_ndim(self):
    """
    Return the rank used for tensors in the generated kernel.
    Returns:
        The number of tiling axes when there is at most one reduction axis.
        With several reduction axes: 1 if the reduction is not contiguous,
        otherwise the tiling axis count minus the reduction axis count plus 1.
    """
    reduction_axes = self.numof_reduction_axis()
    if reduction_axes <= 1:
        return len(self.tiling_axis)
    if not self.is_contiguous_reduction():
        return 1
    return len(self.tiling_axis) - reduction_axes + 1
def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable):
if not self.inside_reduction:
raise RuntimeError("assert self.inside_reduction")
self.inside_reduction = False
indexing = self.indexing(index, block_ptr=True)
self.inside_reduction = True
var = self.args.output(name)
if isinstance(indexing, BlockPtrOptions):
self.post_loop_store.writeline(
DeferredLine(
name,
self.codegen_block_ptr_store_line(
name,
indexing,
indexing.format(var),
value,
f", boundary_check={indexing.boundary_check()!r}",
),
)
)
else:
if not isinstance(indexing, IndexingOptions):
raise RuntimeError("assert isinstance(indexing, IndexingOptions)")
line = f"tl.store({var} + ({indexing.index_str} ), {value}, {indexing.mask_str})"
if self.numof_reduction_axis() > 1:
line = f"tl.store({var} + ({indexing.index_str} + tl.arange(0,1) ), {value}, {indexing.mask_str})"
self.post_loop_store.writeline(
DeferredLine(name, line)
)
def store(
    self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
) -> None:
    """
    Queue a store of ``value`` into the output buffer ``name``.
    Args:
        name: Name of the output buffer.
        index: Index expression of the store.
        value: Variable to store.
        mode: None for a plain store, or "atomic_add".
    Raises:
        NotImplementedError: For any other ``mode``.
    """
    var = self.args.output(name)
    index_analyze = IndexAnalysis(self, index, is_store_index=True)
    index_analyze.analyze_index()
    indexing = self.indexing(index, dense_indexing=True, block_ptr=mode is None, index_analyze=index_analyze)
    value_str = f"{value}"
    if index_analyze.need_permute:
        # The statement generated by the analysis is appended to the value expression.
        value_str = f"{value_str}{index_analyze.generate_statement()}"
    advance_block_ptr = None
    if isinstance(indexing, BlockPtrOptions):
        # index_str / mask_str are only read in the non-block-pointer branches
        # below; upstream BlockPtrOptions does not appear to provide them.
        block_ptr, advance_block_ptr, other = self.codegen_block_ptr(
            name, var, indexing
        )
        line = self.codegen_block_ptr_store_line(
            name, indexing, block_ptr, value, other
        )
    elif mode is None:
        if self.numof_reduction_axis() > 1:
            # With several reduction axes a one-element arange is added to the offset.
            line = f"tl.store({var} + ({indexing.index_str} + tl.arange(0,1) ), {value_str}, {indexing.mask_str})"
        else:
            line = f"tl.store({var} + ({indexing.index_str}), {value_str}, {indexing.mask_str})"
    elif mode == "atomic_add":
        line = f"tl.atomic_add({var} + ({indexing.index_str}), {value_str}, {indexing.mask_str})"
    else:
        raise NotImplementedError(f"store mode={mode}")
    self.stores.writeline(DeferredLine(name, line))
    if advance_block_ptr:
        self.stores.writeline(advance_block_ptr)
    if not self.inside_reduction:
        self.outside_loop_vars.add(value)
def contains_cat_node(self):
node = self.current_node
if node is not None and isinstance(node, SchedulerNode):
cat_node = node.node.data
if cat_node is not None and hasattr(cat_node, "inner_fn_str") and "ops.cat" in cat_node.inner_fn_str():
return True
for node in self.node_schedule:
if node in (EnableReduction, DisableReduction):
continue
cat_node = node.node.data
if cat_node is not None and hasattr(cat_node, "inner_fn_str") and "ops.cat" in cat_node.inner_fn_str():
return True
return False
def find_reduction_node(self):
node = self.current_node
if node is not None and isinstance(node, SchedulerNode):
reduction = node.node.data
if reduction is not None and isinstance(reduction, ir.Reduction):
return reduction
for node in self.node_schedule:
if node in (EnableReduction, DisableReduction):
continue
reduction = node.node.data
if reduction is not None and isinstance(reduction, ir.Reduction):
return reduction
return None
def find_indirect_axis(self, indirect_var):
indirect_key = sympy_index_symbol(indirect_var)
for node in self.node_schedule:
if node in (EnableReduction, DisableReduction):
continue
if indirect_key in node._body.indirect_replacements:
return node._body.indirect_replacements[indirect_key]
return None
def get_template_shape_offset(self, analyzer):
    """
    Build per-axis shape, offset and reshape descriptions for a template.
    Axes are visited in the reverse of ``analyzer.all_var_list``; golden
    variables missing from that list are inserted after their predecessor in
    the reversed golden order, which switches the reshape type to
    'broadcast_to'.
    Args:
        analyzer: Index analysis object exposing ``all_var_list``.
    Returns:
        A tuple ``(axis_shape, tiling_offset, reshape_type, reshape_list)``.
    """
    ordered_vars = list(reversed(analyzer.all_var_list))
    reshape_type = None
    previous = None
    for golden_var in reversed(self.golden_var_list):
        if golden_var not in ordered_vars:
            position = ordered_vars.index(previous) + 1 if previous else 0
            ordered_vars.insert(position, golden_var)
            reshape_type = 'broadcast_to'
        previous = golden_var

    axis_shape, tiling_offset, reshape_list = [], [], []
    for axis_key in ordered_vars:
        axis = self.range_tree_nodes[axis_key]
        block = f"{axis.name.upper()}BLOCK"
        block_sub = f"{block}_SUB"
        numel_expr = str(axis_key) + "_numel"
        if not axis.is_tiling_axis:
            # Non-tiling axes are addressed by the axis variable itself.
            offset_expr = f"{axis.name}"
            if not reshape_type:
                reshape_type = 'reshape'
            reshape_list.append("1")
        else:
            offset_expr = "0"
            if axis.is_split_axis:
                base = f"{axis.name}_offset"
                offset_expr = f"{base} + (loop_{axis.name} * {block_sub})"
                numel_expr = f"min({block} + {base}, {numel_expr})"
            else:
                persistent_axis = axis.name[0] == 'r' and self.persistent_reduction
                if not (axis.is_no_loop_axis or persistent_axis):
                    offset_expr = f"(loop_{axis.name} * {block_sub})"
            reshape_list.append(block_sub)
        axis_shape.append(numel_expr)
        tiling_offset.append(offset_expr)
    return axis_shape, tiling_offset, reshape_type, reshape_list
def get_template_offset(self, analyzer):
axis_start_offset = []
axis_end_offset = []
reshape_list = []
reshape_type = ""
for axis_key in reversed(analyzer.all_var_list):
if symbol_is_type(axis_key, SymT.TMP):
axis_start_offset.append('0')
continue
axis = self.range_tree_nodes[axis_key]
BLOCK_NAME = f"{axis.name.upper()}BLOCK"
BLOCK_NAME_SUB = f"{BLOCK_NAME}_SUB"
start_offset = "0"
end_offset = f"{axis.name}_numel"
if axis.is_tiling_axis:
if axis.is_split_axis:
offset = f"{axis.name}_offset"
start_offset = f"{offset} + (loop_{axis.name} * {BLOCK_NAME_SUB})"
end_offset = f"min({BLOCK_NAME} + {offset}, {axis.name}_numel)"
else:
axis_persistent_reduction = axis.name[0] == 'r' and self.persistent_reduction
if (not axis.is_no_loop_axis) and (not axis_persistent_reduction):
start_offset = f"(loop_{axis.name} * {BLOCK_NAME_SUB})"
reshape_list.append(BLOCK_NAME_SUB)
else:
start_offset = f"{axis.name}"
end_offset = f"{axis.name}_numel"
reshape_list.append("1")
reshape_type = "reshape"
axis_start_offset.append(start_offset)
axis_end_offset.append(end_offset)
return axis_start_offset, axis_end_offset, reshape_type, reshape_list
def parse_golden_from_load_store_index(self):
    """
    Collect stride-sorted variables for this kernel.
    Returns:
        The result of ``collect_stride_sorted_vars_from_indexings`` on
        ``self.load_store_indexing`` when that attribute is present and
        non-empty; otherwise the result of
        ``collect_stride_sorted_vars_from_nodes`` on the node schedule, with
        the reduction markers passed as the second argument.
    """
    indexings = getattr(self, "load_store_indexing", None)
    if not indexings:
        return collect_stride_sorted_vars_from_nodes(
            self.node_schedule,
            (EnableReduction, DisableReduction),
        )
    return collect_stride_sorted_vars_from_indexings(indexings)
def get_reduction_layout_var_list(self):
    """
    Return the variable order used for the reduction layout.
    Returns:
        A tuple: the stride-sorted variables when there are several reduction
        axes and that list is non-empty; else the golden variable list
        (selected on demand); else the symbols of the tiling axes, or an
        empty tuple when there are none.
    """
    if self.numof_reduction_axis() > 1:
        by_stride = self.parse_golden_from_load_store_index()
        if by_stride:
            return tuple(by_stride)
    if not self.golden_var_list:
        self.select_golden_varlist()
    if self.golden_var_list:
        return tuple(self.golden_var_list)
    if not self.tiling_axis:
        return ()
    return tuple([axis.symbol() for axis in self.tiling_axis])
def load_store_index_in_all_tiling_list(self):
    """
    Check whether any load/store index covers every tiling axis.
    Each index has ``V.graph.sizevars.var_to_val`` substituted into it and is
    analysed with ``IndexAnalysis``.
    Returns:
        True if ``all_tiling_in_var_list`` holds for the variable list of at
        least one index, False otherwise.
    """
    covered = False
    for raw_index in self.load_store_indexing:
        concrete_index = raw_index.subs(V.graph.sizevars.var_to_val)
        analysis = IndexAnalysis(self, concrete_index)
        if not covered:
            covered = self.all_tiling_in_var_list(analysis.var_list)
    return covered
def all_tiling_in_var_list(self, var_list):
    """
    Check that every entry of ``self.tiling_axis`` is contained in ``var_list``.
    Args:
        var_list: Collection to test membership against.
    Returns:
        True if all tiling axes are members (also when there are none).
    """
    for axis in self.tiling_axis:
        if axis not in var_list:
            return False
    return True
def _parse_expansion_from_index(self, index):
"""
Parse index Expression to Extract Expansion Patterns.
CRITICAL FIX (v1.4):
Previous versions tried to derive dimension sizes from strides or
factorization, which is fragile and error-prone.
The correct approach: directly query V.kernel.range_tree_nodes,
which already stores each axis's size (length) as an authoritative
data source. No guessing needed.
Data source:
V.kernel.range_tree_nodes is a Dict[sympy.Symbol, IterationRangesEntryNPUIndex]
Each IterationRangesEntryNPUIndex has a .length attribute = axis size.
These nodes are created during scan() -> index_preparation() -> lookup(),
so they are already available when this function is called.
Args:
index: SymPy expression representing the index
Returns:
Dict mapping prefix to expansion info:
{
'x': {'dims': ['x3', 'x4'], 'numels': [32, 20]},
'y': {'dims': ['y1'], 'numels': [80]}
}
Example:
index = x4 + 20*y1 + 1600*x3 + 51200*z0
Direct query:
V.kernel.range_tree_nodes[x3].length = 32
V.kernel.range_tree_nodes[x4].length = 20
Returns:
{'x': {'dims': ['x3', 'x4'], 'numels': [32, 20]}}
"""
result = {}
symbols = index.free_symbols
symbol_groups = {}
for sym in symbols:
sym_str = str(sym)
if len(sym_str) > 0:
prefix = sym_str[0]
if prefix not in symbol_groups:
symbol_groups[prefix] = []
symbol_groups[prefix].append(sym)
for prefix, syms in symbol_groups.items():
if len(syms) >= 2:
strides = []
for sym in syms:
coeff = index.coeff(sym)
if coeff is not None and coeff != 0:
strides.append((sym, coeff))
if len(strides) >= 2:
def get_sort_key(item):
sym, stride = item
if hasattr(stride, 'free_symbols') and stride.free_symbols:
if sym in V.kernel.range_tree_nodes:
return V.kernel.range_tree_nodes[sym].sorted_order or 0
elif sym in V.kernel.range_tree_nodes_removed:
return V.kernel.range_tree_nodes_removed[sym].sorted_order or 0
else:
return str(sym)
else:
return stride if stride else 0
sorted_strides = sorted(strides, key=get_sort_key)
dims = []
numels = []
for sym, _stride in sorted_strides:
if sym in V.kernel.range_tree_nodes:
size = V.kernel.range_tree_nodes[sym].length
elif sym in V.kernel.range_tree_nodes_removed:
size = V.kernel.range_tree_nodes_removed[sym].length
else:
npu_config.log.warning(
"Missing range tree node for symbol %s while parsing expansion",
sym,
)
continue
dims.append(str(sym))
numels.append(int(size))
if len(dims) == len(sorted_strides):
result[prefix] = {'dims': dims, 'numels': numels}
return result
def _detect_different_expansions(self):
"""
Detect if Same Output Dimension has Different Expansion Patterns.
Returns:
Dict mapping prefix to list of Expansion info if Different Patterns Exist, or None
"""
expansions = {}
for idx, index in enumerate(self.load_store_indexing):
parsed = self._parse_expansion_from_index(index)
for prefix, info in parsed.items():
if prefix not in expansions:
expansions[prefix] = []
expansions[prefix].append({
'name': f'index_{idx}',
'dims': info['dims'],
'numels': info['numels']
})
different = {}
for prefix, expansion_list in expansions.items():
numels_set = set(tuple(e['numels']) for e in expansion_list)
if len(numels_set) > 1:
different[prefix] = expansion_list
npu_config.log.info(
"Detected different expansions for prefix %s: %s",
prefix,
[
{
"name": e["name"],
"dims": e["dims"],
"numels": e["numels"],
}
for e in expansion_list
],
)
return different if different else None
def _find_store_unified_axis(self, prefix, dims_to_remove, total_size):
    """
    Find the single Store-side axis that can stand in for a group of sub-axes.
    A Store index contributes a candidate when exactly one of its symbols
    starts with ``prefix``, is not in ``dims_to_remove`` and maps to a range
    tree node whose length equals ``total_size``.
    Args:
        prefix: Dimension prefix of the expansion group.
        dims_to_remove: Names of the sub-axes being replaced.
        total_size: Length the unified axis must have.
    Returns:
        The unique matching node, or None when Store indexing is unavailable,
        nothing matches, or the match is ambiguous.
    """
    store_indexing = getattr(self, "store_unified_indexing", None)
    if not store_indexing:
        npu_config.log.info(
            "Skip special path for prefix %s because Store unified indexing is unavailable",
            prefix,
        )
        return None

    def candidates(store_index):
        found = []
        for sym in sorted(store_index.free_symbols, key=str):
            if sym.name.startswith(prefix) and sym.name not in dims_to_remove:
                node = self.range_tree_nodes.get(sym) or self.range_tree_nodes_removed.get(sym)
                if node is not None and node.length == total_size:
                    found.append(node)
        return found

    matched_axes = {}
    for store_index in store_indexing:
        nodes = candidates(store_index)
        if len(nodes) == 1:
            matched_axes[nodes[0].name] = nodes[0]

    match_count = len(matched_axes)
    if match_count == 1:
        unified_axis = next(iter(matched_axes.values()))
        npu_config.log.info(
            "Selected Store unified axis for prefix %s: %s (length=%s)",
            prefix,
            unified_axis.name,
            unified_axis.length,
        )
        return unified_axis
    if match_count > 1:
        npu_config.log.warning(
            "Ambiguous Store unified axis for prefix %s: %s",
            prefix,
            sorted(matched_axes.keys()),
        )
    else:
        npu_config.log.info(
            "Skip special path for prefix %s because no valid Store unified axis exists",
            prefix,
        )
    return None
def _get_range_tree_for_prefix(self, prefix):
"""
Get the IterationRangesRoot for the Given Prefix.
The range_trees are created during initialize_range_tree() and contain one
IterationRangesRootNPUIndex per prefix (x, y, z, r).
Args:
prefix: Dimension prefix ('x', 'y', 'z', 'r')
Returns:
IterationRangesRootNPUIndex or None if not found
"""
for tree in self.range_trees:
if hasattr(tree, 'prefix') and tree.prefix == prefix:
return tree
return None
def _create_unified_dimension(self, prefix, expansion_list):
"""
Create Unified Dimension Using Lookup.
Args:
prefix: Dimension prefix (e.g., 'x')
expansion_list: List of expansion info dicts
Returns:
Unified IterationRangesEntry or None if failed
"""
total_size = 1
for numel in expansion_list[0]['numels']:
total_size *= numel
expected_size = self.numels.get(prefix)
if expected_size is not None and total_size != expected_size:
npu_config.log.warning(
"Adjust unified size for prefix %s from %s to %s",
prefix,
total_size,
expected_size,
)
total_size = expected_size
root = self._get_range_tree_for_prefix(prefix)
if root is None:
npu_config.log.warning(
"No range tree found for prefix %s while creating unified dimension",
prefix,
)
return None
unified_axis = root.lookup(1, total_size)
return unified_axis
def _generate_coordinate_transform_code(self, unified_var, dims, numels):
"""
Generate Coordinate Transformation Code.
Args:
unified_var: Unified variable name (e.g., 'x')
dims: List of dimension names (e.g., ['x3', 'x4'])
numels: List of dimension sizes (e.g., [32, 20])
Returns:
String containing transformation code
"""
code_lines = []
remaining = unified_var
for i in range(len(dims) - 1, 0, -1):
factor = math.prod(numels[:i])
code_lines.append(f"{dims[i]} = {remaining} // {factor}")
remaining = f"({remaining} % {factor})"
code_lines.append(f"{dims[0]} = {remaining}")
return '\n'.join(code_lines)
def _generate_coordinate_transform_lines(self, unified_var, dims, numels):
"""
Generate coordinate transformation code lines.
Args:
unified_var: Unified variable name (e.g., 'x2')
dims: List of dimension names participating in the transform
numels: List of dimension sizes
Returns:
A list of code lines implementing the coordinate transform.
"""
return self._generate_coordinate_transform_code(
unified_var, dims, numels
).split('\n')
def _log_select_golden_varlist_inputs(self):
    """
    Log, at debug level, how many load/store indexing expressions are
    recorded before the golden variables are selected.
    """
    indexing_count = len(self.load_store_indexing)
    npu_config.log.debug(
        "select_golden_varlist load_store_indexing=%s",
        indexing_count,
    )
def _build_guarded_expansions(self, different_expansions):
"""
Filter different expansions to cases with a valid Store-side unified axis.
Args:
different_expansions: Expansion info detected from load/store indexing.
Returns:
A dictionary containing only guarded expansion groups.
"""
guarded_expansions = {}
if not different_expansions:
return guarded_expansions
for prefix, expansion_list in different_expansions.items():
dims_to_remove = set()
for exp in expansion_list:
dims_to_remove.update(exp['dims'])
total_size = 1
for size in expansion_list[0]['numels']:
total_size *= size
unified_axis = self._find_store_unified_axis(prefix, dims_to_remove, total_size)
if unified_axis is None:
continue
guarded_expansions[prefix] = {
"expansion_list": expansion_list,
"dims_to_remove": dims_to_remove,
"unified_axis": unified_axis,
}
return guarded_expansions
def _remove_expansion_dims_from_axes(self, dims_to_remove):
"""
Remove sub-axis dimensions from active tiling/sorted axes.
Args:
dims_to_remove: Dimension names to remove from active structure axes.
"""
old_len = len(self.tiling_axis)
self.tiling_axis = [
axis for axis in self.tiling_axis
if axis.name not in dims_to_remove
]
if hasattr(self, 'sorted_axis') and self.sorted_axis:
old_len = len(self.sorted_axis)
self.sorted_axis = [
axis for axis in self.sorted_axis
if axis.name not in dims_to_remove
]
for i, axis in enumerate(self.sorted_axis):
axis.sorted_order = i
self.low_dims = set()
def _append_unified_axis_to_active_axes(self, unified_axis):
"""
Add the unified axis back into active tiling and sorted axes.
Args:
unified_axis: The Store-side unified axis.
"""
if not unified_axis:
return
unified_axis.is_tiling_axis = True
unified_axis.is_no_loop_axis = False
if unified_axis not in self.tiling_axis:
unified_axis.tiling_order = len(self.tiling_axis)
self.tiling_axis.append(unified_axis)
else:
unified_axis.tiling_order = self.tiling_axis.index(unified_axis)
if hasattr(self, 'sorted_axis') and self.sorted_axis is not None and unified_axis not in self.sorted_axis:
unified_axis.sorted_order = len(self.sorted_axis)
self.sorted_axis.append(unified_axis)
elif hasattr(self, 'sorted_axis') and self.sorted_axis is not None:
unified_axis.sorted_order = self.sorted_axis.index(unified_axis)
def _clear_sub_axis_states(self, dims_to_remove):
"""
Clear structure state for sub-axes that are kept only for indexing.
Args:
dims_to_remove: Dimension names whose structure state should be cleared.
"""
for dim_name in dims_to_remove:
matched = False
for container_name, container in (
("range_tree_nodes", self.range_tree_nodes),
("range_tree_nodes_removed", self.range_tree_nodes_removed),
):
for key, node in container.items():
if str(key) != dim_name:
continue
node.is_tiling_axis = False
node.is_no_loop_axis = False
node.tiling_order = None
node.directions = []
node.var_directions = {}
matched = True
if not matched:
npu_config.log.warning(
"Failed to clear sub-axis state for %s",
dim_name,
)
def _update_coordinate_transforms(self, prefix, expansion_list, unified_axis):
"""
Build coordinate transform metadata for a guarded expansion group.
Args:
prefix: Dimension prefix of the guarded expansion group.
expansion_list: Expansion metadata list.
unified_axis: The Store-side unified axis.
"""
unified_var_name = unified_axis.name if unified_axis else prefix
for exp in expansion_list:
self.coordinate_transforms[exp['name']] = {
'unified_var': unified_var_name,
'dims': exp['dims'],
'numels': exp['numels'],
}
def _apply_guarded_expansions(self, guarded_expansions):
    """
    Switch the kernel structure over to unified axes.
    For each guarded group the sub-axes are removed from the active axes, the
    unified axis is added, the sub-axis state is cleared and the coordinate
    transform metadata is recorded. Afterwards the golden variable list is
    rebuilt from the symbols of the resulting tiling axes.
    Args:
        guarded_expansions: Guarded expansion groups keyed by prefix.
    """
    self.coordinate_transforms = {}
    for prefix, group in guarded_expansions.items():
        sub_dims = group["dims_to_remove"]
        unified = group["unified_axis"]
        self._remove_expansion_dims_from_axes(sub_dims)
        self._append_unified_axis_to_active_axes(unified)
        self._clear_sub_axis_states(sub_dims)
        self._update_coordinate_transforms(prefix, group["expansion_list"], unified)
        npu_config.log.info(
            "Apply unified-axis special path for prefix %s: unified_axis=%s, removed_dims=%s",
            prefix,
            unified.name if unified else None,
            sorted(sub_dims),
        )
    self.golden_var_list = tuple([axis.symbol() for axis in self.tiling_axis])
    tiling_names = [axis.name for axis in self.tiling_axis]
    sorted_names = [axis.name for axis in self.sorted_axis] if self.sorted_axis else []
    npu_config.log.info(
        "Unified-axis final structure: golden=%s, tiling_axis=%s, sorted_axis=%s",
        self.golden_var_list,
        tiling_names,
        sorted_names,
    )
def _select_golden_varlist_normal_case(self):
"""
Run the original golden var selection logic for non-special cases.
"""
longest = None
maximum_length = 0
for index in self.load_store_indexing:
index = index.subs(V.graph.sizevars.var_to_val)
analyze = IndexAnalysis(self, index)
if len(analyze.var_list) > maximum_length and self.all_tiling_in_var_list(analyze.var_list):
longest = analyze.var_list
maximum_length = len(longest)
if not longest:
if self.find_reduction_node is not None and self.load_store_index_in_all_tiling_list() is False:
self.golden_var_list = self.parse_golden_from_load_store_index()
else:
self.golden_var_list = tuple([x.symbol() for x in self.tiling_axis]) if self.tiling_axis else []
else:
self.golden_var_list = tuple([x for x in longest if x in self.tiling_axis]) if self.tiling_axis else []
if self.golden_var_list is not None and self.tiling_axis:
golden_list = list(self.golden_var_list)
for x in self.tiling_axis:
sym = x.symbol() if hasattr(x, 'symbol') else x
if sym not in golden_list:
golden_list.append(sym)
self.golden_var_list = tuple(golden_list)
def select_golden_varlist(self):
    """
    Select ``self.golden_var_list`` for this kernel.
    When load/store indices expand the same prefix differently and a
    Store-side unified axis exists, the unified-axis special path is applied;
    otherwise the normal selection runs. ``self.different_expansions`` is set
    to the guarded expansion lists keyed by prefix, or to None.
    Raises:
        RuntimeError: If no golden variable list was selected.
    """
    self.golden_var_list = None
    self._log_select_golden_varlist_inputs()
    different_expansions = self._detect_different_expansions()
    guarded_expansions = self._build_guarded_expansions(different_expansions)
    # Only expansion groups that passed the guard are kept from here on.
    different_expansions = {
        prefix: info["expansion_list"]
        for prefix, info in guarded_expansions.items()
    } or None
    self.different_expansions = different_expansions
    if different_expansions:
        self._apply_guarded_expansions(guarded_expansions)
    else:
        self._select_golden_varlist_normal_case()
    if self.golden_var_list is None:
        raise RuntimeError("assert self.golden_var_list is not None")
    npu_config.log.info(
        "Selected golden vars: golden=%s, tiling_axis=%s, different_expansions=%s",
        self.golden_var_list,
        [x.name for x in self.tiling_axis] if self.tiling_axis else [],
        different_expansions,
    )
def dense_size_list(self) -> List[str]:
    """
    Return the block size names that make up the dense tensor shape.
    For kernels with a reduction node the list comes from the reduction
    analysis (its post-reduction list when the reduction is contiguous).
    Otherwise one ``<AXIS>BLOCK_SUB`` name is produced per golden variable,
    in reversed golden order; the tiling axis symbols are used when no golden
    variable list is available after selection.
    Returns:
        List of size names.
    """
    if self.find_reduction_node() is not None:
        if not self.reduce_analysis:
            self.reduce_analysis = ReductionAnalysis(self)
        if self.is_contiguous_reduction():
            return self.reduce_analysis.dense_post_reduction_list()
        return self.reduce_analysis.dense_size_list()
    if not self.golden_var_list:
        self.select_golden_varlist()
    # Either the non-empty golden list or a fresh list: never None.
    golden_var_list = self.golden_var_list if self.golden_var_list else [x.symbol() for x in self.tiling_axis]
    return [
        f"{self.range_tree_nodes[var].name.upper()}BLOCK_SUB"
        for var in reversed(golden_var_list)
    ]
def is_contiguous_reduction(self):
    """
    Tell whether the reduction axes occupy adjacent positions.
    Only kernels with more than one reduction axis can qualify. Positions are
    taken from the reversed stride-sorted variable list, falling back to the
    golden variable list (selected on demand); variables whose name starts
    with 'r' count as reduction variables.
    Returns:
        True if the reduction variables form one contiguous run; False
        otherwise, including when no reduction variable is found.
    """
    if self.numof_reduction_axis() <= 1:
        return False
    stride_sorted_var_list = self.parse_golden_from_load_store_index()
    if not stride_sorted_var_list:
        if not self.golden_var_list:
            self.select_golden_varlist()
        stride_sorted_var_list = list(self.golden_var_list) if self.golden_var_list else []
    positions = {
        position
        for position, var in enumerate(reversed(stride_sorted_var_list))
        if var.name[0] == 'r'
    }
    # max()/min() reject an empty set; no reduction variable means no
    # contiguous run.
    if not positions:
        return False
    return len(positions) == max(positions) - min(positions) + 1
def dense_size_str(self):
    """
    Return the dense tensor shape as source text.
    Returns:
        The reduction analysis' shape string for kernels with a reduction
        node; otherwise the entries of ``dense_size_list`` joined with ", "
        inside square brackets.
    """
    if self.find_reduction_node() is None:
        joined = ', '.join(self.dense_size_list())
        return f"[{joined}]"
    if not self.reduce_analysis:
        self.reduce_analysis = ReductionAnalysis(self)
    return self.reduce_analysis.dense_size_str()
def reduction_resize(self, value, dim):
    """
    Build the expression that gives a reduced value its expected shape.
    Args:
        value: Source text of the reduced value.
        dim: Index of the reduced dimension.
    Returns:
        A ``promote_to_tensor`` call for rank-1 kernels; otherwise a
        ``reshape`` to the dense shape with the reduced dimension set to 1,
        padded with extra 1s after it for contiguous reductions.
    """
    if self.triton_tensor_ndim() == 1:
        return f"triton_helpers.promote_to_tensor({value})"
    shape = self.dense_size_list()
    shape[dim] = "1"
    if self.is_contiguous_reduction():
        # Pad up to the number of golden variables.
        missing = len(self.golden_var_list) - len(shape)
        for step in range(missing):
            shape.insert(dim + step + 1, "1")
    return f"{value}.reshape({', '.join(shape)})"
def reduction_dim(self):
    """
    Return the reduced dimension reported by the reduction analysis,
    creating the analysis first if it does not exist yet.
    """
    analysis = self.reduce_analysis
    if not analysis:
        analysis = self.reduce_analysis = ReductionAnalysis(self)
    return analysis.reduced_dim
def is_unified_simt_kernel(self):
    """
    Tell whether the kernel type is SIMT-only or mixed SIMD/SIMT.
    """
    simt_kinds = (NPUKernelType.SIMT_ONLY, NPUKernelType.SIMD_SIMT_MIX)
    return any(self.npu_kernel_type == kind for kind in simt_kinds)
def filter_masks(self, mask_vars, index_vars=None):
    """
    Drop, in place, the masks that the generated code does not need.
    Masks that are ``TritonCSEVariable`` instances are kept unless index
    variables are given and do not cover the current sub-block axes. Any
    other mask is kept only if its text is ``<axis>mask`` or starts with
    ``<axis>_`` for an axis that still requires masking.
    Args:
        mask_vars: Set of masks; modified in place.
        index_vars: Optional variables of the index the masks belong to.
    """
    variable_mask_vars = {mask for mask in mask_vars if isinstance(mask, TritonCSEVariable)}
    normal_mask_vars = mask_vars.copy() - variable_mask_vars
    subblock_axis = V.kernel.current_subblock_axis
    save_variable_mask = True
    if index_vars and subblock_axis:
        save_variable_mask = subblock_axis.issubset({str(var) for var in index_vars})

    # Collect the axes that still need a mask.
    masked_axis_name = []
    for node in self.sorted_axis:
        allow_dynamic = get_allow_dynamic()
        is_persistent_reduction_axis = self.persistent_reduction and node.is_reduction
        if self.is_unified_simt_kernel():
            unmasked = not node.is_tiling_axis
        else:
            unmasked = (
                (not node.is_tiling_axis)
                or is_persistent_reduction_axis
                or (node.is_no_loop_axis and not allow_dynamic)
            )
        if unmasked:
            continue
        if save_variable_mask and node.name in subblock_axis:
            continue
        masked_axis_name.append(node.name)

    if not save_variable_mask:
        for mask_var in variable_mask_vars:
            mask_vars.discard(mask_var)

    for mask_var in normal_mask_vars:
        mask_text = str(mask_var)
        still_needed = any(
            mask_text == f"{axis_name}mask" or mask_text.startswith(f"{axis_name}_")
            for axis_name in masked_axis_name
        )
        if not still_needed:
            mask_vars.discard(mask_var)
def numof_reduction_axis(self):
    """
    Count the variables of the last range tree.
    Returns:
        The length of its ``var_list``, or 0 when the last range tree is None.
    """
    last_tree = self.range_trees[-1]
    return 0 if last_tree is None else len(last_tree.var_list)
def reduction_axis_list(self):
    """
    Return the variables of the last range tree.
    Returns:
        Its ``var_list``, or an empty list when the last range tree is None.
    """
    last_tree = self.range_trees[-1]
    return [] if last_tree is None else last_tree.var_list
# Emit Triton code for a reduction over `value` and return the CSE variable
# that holds the reduced result.
#
# Args:
#   dtype: result dtype; used here only to pick the `.to(...)` cast in the
#       looped bool case.
#   src_dtype: dtype of the values being reduced; selects accumulator types.
#   reduction_type: reduction name ("sum", "max", "argmax", ...).
#   value: CSE variable (or tuple of them) to reduce.
# Returns:
#   The result variable; also memoised in `self.cse.reduction_cache`.
# Raises:
#   RuntimeError: when called outside a reduction, or for welford reductions.
def reduction(
self,
dtype: torch.dtype,
src_dtype: torch.dtype,
reduction_type: ReductionType,
value: Union[CSEVariable, Tuple[CSEVariable, ...]],
) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
if not self.inside_reduction:
raise RuntimeError("assert self.inside_reduction")
# Collect "<axis>_mask" names; a persistent reduction over exactly one
# reduction axis skips axes whose name starts with "r".
if self.persistent_reduction and self.numof_reduction_axis() == 1:
masks = {f"{node.symbol()}_mask" for node in self.sorted_axis if node.name[0] != "r"}
else:
masks = {f"{node.symbol()}_mask" for node in self.sorted_axis}
self.filter_masks(masks)
# Sorted so the generated mask expression is deterministic.
masks = sorted(masks)
if self._load_mask:
masks.append(self._load_mask)
# NOTE(review): `reduction_range_prefix` is never read below.
reduction_range_prefix = self.range_trees[-1].prefix
# Lazily create the reduction analysis helper on first use.
if not self.reduce_analysis:
self.reduce_analysis = ReductionAnalysis(self)
dense_size_str = self.dense_size_str()
# Gather the range-tree axes whose name textually appears in any recorded
# load/store index (substring match on the string forms).
axis_list = []
for index in self.load_store_indexing:
for axis in V.kernel.range_tree_nodes.keys():
if str(axis) in str(index) and (axis not in axis_list):
axis_list.append(axis)
# Fewer referenced axes than dense dimensions: broadcast the value to the
# dense shape first.
if len(axis_list) < len(self.dense_size_list()):
value = self._map_tuple_or_scalar(
lambda v: self.cse.generate(
self.compute, f"tl.broadcast_to({v}, {dense_size_str})", dtype=v.dtype,
),
value,
)
# Reshape to the dense shape unless this is a single-axis persistent
# reduction; `len(dense_size_str) > 2` skips a shape string of two
# characters or fewer (presumably "[]" -- TODO confirm).
if len(dense_size_str) > 2 and (not self.persistent_reduction or self.numof_reduction_axis() != 1):
value = self._map_tuple_or_scalar(
lambda v: self.cse.generate(
self.compute, f"tl.reshape({v}, {dense_size_str})", dtype=v.dtype,
),
value,
)
# Declared up front because the closures below read them; both are assigned
# later in this function before the closures are called.
dim: int
root_op: str
# Build the final `<module>.<reduction_type>(value, dim)` expression;
# "any"/"prod" go through triton_helpers and "max" propagates NaN.
def final_reduction(value):
use_helper = reduction_type in {"any", "prod"}
module = "triton_helpers" if use_helper else "tl"
if reduction_type in {"max"}:
return self.reduction_resize(f"{module}.{reduction_type}({value}, {dim}, propagate_nan=True)", dim)
return self.reduction_resize(f"{module}.{reduction_type}({value}, {dim})", dim)
# Emit the `<root_op>_with_index` call into `buffer` and assign the resized
# index result to `result_var`.
def final_argreduce(buffer, result_var, value, index):
buffer.splice(
f"""\
_, {result_var}_tmp = triton_helpers.{root_op}_with_index({value}, {index}, {dim})
{result_var} = {self.reduction_resize(f'{result_var}_tmp', dim)}
"""
)
# Look up the range-tree node for the reduced dimension via the reversed
# `golden_var_list`.
def get_reduction_axis():
reduce_dim = self.reduction_dim()
axis_key = self.golden_var_list[::-1][reduce_dim]
reduced_axis = self.range_tree_nodes[axis_key]
return reduced_axis
# The cache key uses `value` after the broadcast/reshape steps above.
cache_key = (src_dtype, reduction_type, value)
if cache_key in self.cse.reduction_cache:
return self.cse.reduction_cache[cache_key]
dim = self.reduction_dim()
acc_type = triton_acc_type(src_dtype)
torch_acc_type = upcast_acc_dtype(src_dtype)
result_var: Any = self.cse.newvar(dtype=torch_acc_type)
result_var.mask_vars = {var for var in masks if var[0] != "r"}
# NOTE(review): `cond` is never an empty string, so the `if not cond` guard
# in where_cond cannot trigger; with no masks this yields "().reshape(...)".
# Confirm whether an empty mask list can reach the looped branch.
cond = f"({' & '.join(masks)}).reshape({dense_size_str})"
def where_cond(tval, fval):
if not cond:
return tval
return TritonKernelOverrides.where(cond, tval, fval)
if self.persistent_reduction:
# Persistent path: the reduction is emitted directly into `self.compute`.
masked_value = value
if reduction_type in {"argmax", "argmin", "max", "min"}:
if reduction_type == "argmax" or reduction_type == "argmin":
reduce_axis = get_reduction_axis()
broadcast_string: str
reshape_str = self.reduce_analysis.get_reduce_dim_reshape(reduce_axis)
# Broadcast the reduced axis' index symbol to the value's shape so it can
# be paired with the values in the arg-reduction.
broadcast_string = f"tl.broadcast_to({reduce_axis.symbol()}.reshape({reshape_str}), {masked_value}.shape)"
accumulator_index = str(
self.cse.generate(
self.compute,
broadcast_string,
dtype=torch.int64
)
)
root_op = {"argmax": "max", "argmin": "min"}[reduction_type]
final_argreduce(
self.compute, result_var, masked_value, accumulator_index
)
elif reduction_type == "max" or reduction_type == "min":
result_var = self.cse.generate(
self.compute, final_reduction(masked_value), dtype=masked_value.dtype,
)
elif reduction_type == "welford_reduce":
raise RuntimeError("assert False, welford_reduction and is not supported now..")
elif reduction_type == "welford_combine":
raise RuntimeError("assert False, welford_combine and is not supported now..")
else:
result_var = self.cse.generate(
self.compute, final_reduction(masked_value), dtype=masked_value.dtype,
)
else:
# Looped path: an accumulator is initialised in `self.prefix`, updated in
# `self.compute`, and reduced in `self.post_loop_store`.
accumulator = self.cse.namedvar(f"_{result_var}", dtype=torch_acc_type)
default = ir.Reduction.default_accumulator(reduction_type, src_dtype)
default = self._map_tuple_or_scalar(constant_repr, default)
if not isinstance(default, tuple):
self.prefix.writeline(
f"{accumulator} = tl.full({self.dense_size_str()}, {default}, {acc_type})"
)
if reduction_type in {"argmax", "argmin"}:
accumulator_index = f"_{result_var}_index"
# The index accumulator starts at the int64 maximum.
long_max = torch.iinfo(torch.int64).max
self.prefix.writeline(
f"{accumulator_index} = tl.full({self.dense_size_str()}, {long_max}, tl.int64)"
)
root_op = {"argmax": "max", "argmin": "min"}[reduction_type]
self.compute.splice(
f"""\
{accumulator}_next, {accumulator_index}_next = triton_helpers.{root_op}imum_with_index(
{accumulator}, {accumulator_index}, {value}, {get_reduction_axis().name}
)
{accumulator} = {where_cond(f'{accumulator}_next', accumulator)}
{accumulator_index} = {where_cond(f'{accumulator_index}_next', accumulator_index)}
"""
)
final_argreduce(self.post_loop_store, result_var, accumulator, accumulator_index)
elif is_welford_reduction(reduction_type):
raise RuntimeError("assert False, welford_reduction and is not supported now..")
else:
combine_fn = ir.get_reduction_combine_fn(reduction_type, src_dtype)
updated = combine_fn(accumulator, value)
self.compute.writeline(
f"{accumulator} = {where_cond(updated, accumulator)}"
)
if src_dtype == torch.bool:
# Bool sources: the accumulator is cast to int8 for the final reduction and
# the result cast to the compute type of `dtype`.
accumulator = f"{accumulator}.to(tl.int8)"
result_type = triton_compute_type(dtype)
self.post_loop_store.writeline(
f"{result_var} = {final_reduction(accumulator)}.to({result_type})"
)
else:
self.post_loop_store.writeline(
f"{result_var} = {final_reduction(accumulator)}"
)
self.cse.reduction_cache[cache_key] = result_var
# Record the result(s) in `outside_loop_vars`.
if isinstance(result_var, tuple):
self.outside_loop_vars |= set(result_var)
else:
self.outside_loop_vars.add(result_var)
return result_var
# Emit a `tl.load` for buffer `name` at `index` and return the CSE variable
# holding the loaded value.
#
# Args:
#   name: graph buffer name to read.
#   index: sympy index expression into that buffer.
# Returns:
#   The cached stored value when `name` is in the store cache, otherwise the
#   variable produced by the generated load (possibly reshaped/permuted).
# Raises:
#   RuntimeError: if the generated variable is not a TritonCSEVariable.
def load(self, name: str, index: sympy.Expr):
var = self.args.input(name)
original_index = index
store_cache = self.cse.store_cache
# A value stored earlier in this kernel is reused instead of reloaded.
if name in store_cache:
result_var = store_cache[name]
return result_var
index_analyze = IndexAnalysis(self, index)
nddma_switch = npu_config.nddma_switch
index_analyze.analyze_index(nddma=nddma_switch)
indirect_indexing = self.is_indirect_indexing(index)
# NOTE(review): `index_analyze` is not forwarded here, so `indexing()` builds
# its own analysis; the local one is only consulted for `need_permute` below.
indexing = self.indexing(index, nddma=nddma_switch, block_ptr=True)
has_rindex = indexing.has_rindex()
has_tmpmask = indexing.has_tmpmask()
# NOTE(review): `is_coalesced` is computed but never read in this method.
is_coalesced = any(
i == 1 for i in self.get_strides_of_load(original_index).values()
)
ep = ""
# Masked non-bool loads that involve a tmp mask or an r-index get an explicit
# fill value.
if (
(has_tmpmask or has_rindex)
and V.graph.get_dtype(name) != torch.bool
and indexing.has_mask()
):
other = ", other=0.0"
else:
other = ""
advance_block_ptr = None
append_broadcast = None
dtype = V.graph.get_dtype(name)
if V.graph.is_unspec_arg(name):
# Unspecialised args are used directly, without a load.
line = var
else:
if isinstance(indexing, BlockPtrOptions):
block_ptr, advance_block_ptr, other = self.codegen_block_ptr(
name, var, indexing, other
)
line = f"tl.load({block_ptr}{other}{ep})"
line = triton_reshape(
line, indexing.block_shape, indexing.reshape_suffix
)
elif isinstance(original_index, sympy.Integer):
# Constant index: load one element, then reshape it to an all-ones shape
# with one entry per tiling axis (at least one entry).
line = f"tl.load({var} + tl.arange(0,1) + ({original_index}))"
full_list = ["1"] * (len(self.tiling_axis) if self.tiling_axis else 1)
append_broadcast = f"[{', '.join(full_list)} ]"
else:
index_str = indexing.index_str
mask_str = indexing.mask_str
line = f"tl.load({var} + ({index_str}), {mask_str}{ep}{other})"
# Choose the buffer that receives the load: `compute` when a tmp mask is
# involved, `prefix` for loads inside a looped reduction that use neither
# indirect indexing nor an r-index, and `loads` otherwise.
if has_tmpmask:
load_buffer = self.compute
elif (
self.inside_reduction
and self.range_trees[-1].is_loop
and not indirect_indexing
and not has_rindex
):
load_buffer = self.prefix
else:
load_buffer = self.loads
result_var = self.cse.generate(load_buffer, line, dtype=dtype)
if not (isinstance(result_var, TritonCSEVariable)):
raise RuntimeError("assert isinstance(result_var, TritonCSEVariable)")
result_var.mask_vars = indexing.mask_vars
if append_broadcast and append_broadcast != '[]':
line = f"tl.reshape({result_var}, {append_broadcast})"
result_var = self.cse.generate(load_buffer, line, dtype=dtype)
elif index_analyze.need_permute:
# Append the statement produced by the index analysis to the loaded value;
# this line always goes to `self.loads`.
line = f"{result_var}{index_analyze.generate_statement()}"
result_var = self.cse.generate(self.loads, line, dtype=dtype)
if advance_block_ptr:
load_buffer.writeline(advance_block_ptr)
# Outside a reduction, or with no r-mask and no r-index, the value is
# recorded in `outside_loop_vars`.
if not self.inside_reduction or (not indexing.has_rmask() and not has_rindex):
self.outside_loop_vars.add(result_var)
return result_var
def prepare_indexing(
    self,
    index: sympy.Expr,
    index_analyze,
    is_index_expr=False,
    nddma=False
):
    """Normalise ``index`` and hand it to ``codegen_indexing``.

    Precomputed size replacements are substituted into ``index``; when the
    expression contains ``floor``/``ceiling`` atoms they are substituted once
    more and every ``ceiling`` atom built only from size symbols is swapped
    for its precomputed-size lookup. An ``Identity`` wrapper is unwrapped,
    ``index_analyze.analyze_index(nddma)`` is run, and the result of
    ``self.codegen_indexing`` is returned.

    ``is_index_expr`` is accepted for interface compatibility and is not read
    in this method.
    """
    sizevars = V.graph.sizevars
    index = sympy_subs(index, sizevars.precomputed_replacements)

    if index.atoms(sympy.floor) or index.atoms(sympy.ceiling):
        index = index.subs(sizevars.precomputed_replacements)
        # The atom set is taken once, before any replacement below.
        for ceil_atom in index.atoms(sympy.ceiling):
            atom_symbols = ceil_atom.free_symbols
            if not atom_symbols:
                continue
            only_sizes = all(
                symbol_is_type(sym, (SymT.SIZE, SymT.PRECOMPUTED_SIZE))
                for sym in atom_symbols
            )
            if only_sizes:
                hoisted = {ceil_atom: sizevars.lookup_precomputed_size(ceil_atom)}
                index = sympy_subs(index, hoisted)

    simp_index = index.args[0] if isinstance(index, Identity) else index
    index_analyze.analyze_index(nddma)
    return self.codegen_indexing(simp_index)
def replace_index_vars(self, index, index_analyze):
    """Apply the variable replacements recorded by ``index_analyze``.

    The nddma replacements win when present. Otherwise the plain
    ``var_replacements`` are applied, but only if nddma processing has not
    happened. With nothing to apply, ``index`` is returned unchanged.
    """
    nddma_map = index_analyze.nddma_var_replacements
    if nddma_map:
        return sympy_subs(index, nddma_map)
    if index_analyze.processed_nddma:
        return index
    plain_map = index_analyze.var_replacements
    if plain_map:
        return sympy_subs(index, plain_map)
    return index
def index_to_str(self, index: sympy.Expr) -> str:
    """Render ``index`` as kernel source text.

    A list is rendered element by element (recursively) and wrapped in
    ``[...]`` with ``", "`` separators. Anything else is passed through
    ``rename_indexing`` and then ``kexpr``.
    """
    if not isinstance(index, list):
        return self.kexpr(self.rename_indexing(index))
    rendered = [self.index_to_str(item) for item in index]
    return "[" + ", ".join(rendered) + "]"
# Build the index string and mask set for a `tl.load()` / `tl.store()`.
#
# Args:
#   index: sympy index expression.
#   copy_shape: when given, a constant index is expanded to
#       "<copy_shape>.shape" instead of the dense size string.
#   dense_indexing, block_ptr: accepted but not read in this body.
#   override_mask: when truthy, replaces the per-variable masks.
#   nddma: forwarded to the index analysis.
#   index_analyze: optional pre-built IndexAnalysis; one is created if missing.
#   is_index_expr: forwarded to IndexAnalysis and prepare_indexing.
# Returns:
#   An IndexingOptions instance.
# Raises:
#   RuntimeError: if a free symbol of the index is not a sympy.Symbol.
def indexing(
self,
index: sympy.Expr,
*,
copy_shape=None,
dense_indexing=False,
override_mask=None,
nddma=False,
block_ptr=False,
index_analyze=None,
is_index_expr=False
) -> Union[IndexingOptions, BlockPtrOptions]:
"""
Compute the index and mask to pass to tl.load() or tl.store()
"""
if not index_analyze:
index_analyze = IndexAnalysis(self, index, is_index_expr=is_index_expr)
index_analyze.analyze_index(nddma)
index = self.prepare_indexing(index, index_analyze, is_index_expr, nddma=nddma)
# NOTE(review): these two assignments are overwritten further down before
# being read.
index_vars = index.free_symbols
has_rindex = False
# NOTE(review): the substitutions below repeat what prepare_indexing already
# applied, except that ceiling atoms are matched here by symbol-name prefix
# ("s"/"ps").
index = sympy_subs(index, V.graph.sizevars.precomputed_replacements)
if len(index.atoms(sympy.floor)) or len(index.atoms(sympy.ceiling)):
index = index.subs(V.graph.sizevars.precomputed_replacements)
if len(index.atoms(sympy.ceiling)):
for a in index.atoms(sympy.ceiling):
symbols = a.free_symbols
if len(symbols) > 0 and all(
s.name.startswith("s") or s.name.startswith("ps") for s in symbols
):
replacements = {a: V.graph.sizevars.lookup_precomputed_size(a)}
index = sympy_subs(index, replacements)
index = self.replace_index_vars(index, index_analyze)
index_vars = index.free_symbols
has_rindex = False
mask_vars: Set[str] = set()
# Classify each free symbol by name prefix: "r*" marks an r-index, "tmp*"
# inherits the masks of its CSE variable, "s"/"ps"/"i" need no mask, and
# anything else contributes "<name>_mask".
for var in index_vars:
if not (isinstance(var, sympy.Symbol)):
raise RuntimeError("assert isinstance(var, sympy.Symbol)")
has_rindex = has_rindex or var.name.startswith("r")
if override_mask:
pass
elif var.name.startswith("tmp"):
cse_var = self.cse.varname_map[var.name]
mask_vars.update(cse_var.mask_vars)
elif var.name.startswith(("s", "ps", "i")):
pass
else:
mask_vars.add(f"{var.name}_mask")
expand_str = None
index_str = self.index_to_str(index)
if isinstance(index, sympy.Integer):
# Constant index: a non-zero value becomes a `tl.full` of the expand shape
# (index dtype depends on `npu_config.is_ascend950`); zero becomes
# `tl.arange(0,1)`. No masks are attached in either case.
expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str()
if (index != 0):
if npu_config.is_ascend950:
index_str = f"tl.full({expand_str}, {index_str}, {V.kernel.index_dtype})"
else:
index_str = f"tl.full({expand_str}, {index_str}, tl.int32)"
else:
index_str = f"tl.arange(0,1)"
return IndexingOptions(index_str, OrderedSet(), expand_str, has_rindex, index)
if override_mask:
mask_vars = {override_mask}
if self._load_mask:
mask_vars.add(self._load_mask)
self.filter_masks(mask_vars, index_vars)
return IndexingOptions(index_str, mask_vars, expand_str, has_rindex, index)
def codegen_indexing(self, expr: sympy.Expr):
    """Simplify ``expr`` and trigger codegen for every range-tree node it uses.

    For each free symbol (visited in string-sorted order) that is a key of
    ``self.range_tree_nodes``, the node's precomputed args are replaced by
    their precomputed-size lookups inside ``node.expr`` and ``node.codegen()``
    is called. The simplified expression is returned.
    """
    expr = V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges())
    for sym in sorted(expr.free_symbols, key=str):
        if sym not in self.range_tree_nodes:
            continue
        node = self.range_tree_nodes[sym]
        precomputed = {
            ps: V.graph.sizevars.lookup_precomputed_size(ps)
            for ps in node.precomputed_args()
        }
        if precomputed:
            node.expr = sympy_subs(node.expr, precomputed)
        node.codegen()
    return expr
def split_and_set_ranges(self, lengths: Sequence[Sequence[sympy.Expr]]):
    """Map the kernel's range-tree numels onto ``lengths``.

    The group sizes are the ``numel`` of every range tree; outside a
    reduction the last group is forced to one. The groups, ``lengths`` and
    ``self.set_ranges`` are forwarded to
    ``self.map_kernel_groups_to_node_sizes`` and its result is returned.
    """
    kernel_groups = [tree.numel for tree in self.range_trees]
    outside_reduction = not self.inside_reduction
    if outside_reduction:
        kernel_groups[-1] = sympy.S.One
    return self.map_kernel_groups_to_node_sizes(
        kernel_groups, lengths, self.set_ranges
    )
@staticmethod
def _split_iteration_ranges(
    groups: Iterable[sympy.Expr], lengths: Sequence[Sequence[sympy.Expr]]
):
    """Split the kernel iteration groups so they line up with ``lengths``.

    Args:
        groups: numel of each kernel group.
        lengths: per node, the sizes of its iteration variables.

    Returns:
        ``(new_ranges, return_getters_groups)``. ``new_ranges[i]`` lists the
        sizes carved out of group ``i``. ``return_getters_groups`` mirrors
        ``lengths`` and holds, for every size, a callable that maps the flat
        list of new variables to the index expression for that size.

    Raises:
        CantSplit: when a size does not divide the group it is taken from,
            or when the groups are exhausted before all sizes are placed.
        RuntimeError: when a group is left with a size hint other than 1.
    """
    sv = V.graph.sizevars
    # `groups` is only promised to be an Iterable but is traversed twice.
    groups = list(groups)
    new_ranges: List[List[sympy.Expr]] = [[] for _ in groups]
    remaining = [sv.simplify(g) for g in groups]
    for i, group in enumerate(remaining):
        if isinstance(group, (list, tuple)):
            remaining[i] = NumelList(group).numels()
    var_count = itertools.count()

    def add_range(i, expr):
        # Carve `expr` out of group `i` and return the new variable's flat id.
        expr = sv.simplify(expr)
        if not sv.statically_known_equals(
            remaining[i], expr
        ) and not sv.statically_known_multiple_of(remaining[i], expr):
            raise CantSplit
        remaining[i] = FloorDiv(remaining[i], expr)
        new_ranges[i].append(expr)
        return next(var_count)

    def make_combined(strides, index_list):
        # Getter computing sum(stride * flat_vars[index]) over the pairs.
        def getter(flat_vars):
            expr = sympy.Integer(0)
            for stride, index in zip(strides, index_list):
                expr = stride * flat_vars[index] + expr
            return expr

        return getter

    def add_multiple_range(size, return_getters):
        # `size` spans several consecutive groups: consume whole groups,
        # starting at `current_group`, until `size` is used up.
        index_list = []
        stride_list = []
        group = current_group
        remained_size = size
        while (
            group < len(remaining)
            and sv.statically_known_gt(remaining[group], 1)
            and sv.statically_known_gt(remained_size, 1)
        ):
            group_size = remaining[group]
            if not sv.statically_known_multiple_of(remained_size, group_size):
                raise CantSplit
            index_list.append(add_range(group, group_size))
            remained_size = FloorDiv(remained_size, group_size)
            stride_list.append(remained_size)
            group = group + 1
        if remained_size != 1:
            raise CantSplit
        return_getters.append(make_combined(stride_list, index_list))

    def safe_size_hint_is_one(expr):
        # A size hint that cannot be turned into an int counts as "not one".
        try:
            return int(sv.size_hint(expr)) == 1
        except (TypeError, ValueError):
            return False

    return_getters_groups = []
    current_group = 0
    for length_group in lengths:
        return_getters = []
        for size in length_group:
            if sv.statically_known_equals(size, 1):
                return_getters.append(lambda _: sympy.S.Zero)
                continue
            # Skip over groups that have nothing left to hand out.
            while current_group < len(remaining) and (
                sv.statically_known_equals(remaining[current_group], 1)
                or safe_size_hint_is_one(remaining[current_group])
            ):
                current_group += 1
            # Check exhaustion before indexing `remaining`, so running out of
            # groups is reported as CantSplit rather than IndexError.
            if current_group >= len(remaining):
                raise CantSplit
            if sv.statically_known_gt(size, remaining[current_group]):
                add_multiple_range(size, return_getters)
            else:
                return_getters.append(
                    operator.itemgetter(add_range(current_group, size))
                )
        return_getters_groups.append(return_getters)
    if not all(V.graph.sizevars.size_hint(s) == 1 for s in remaining):
        raise RuntimeError("assert all(V.graph.sizevars.size_hint(s) == 1 for s in remaining)")
    return new_ranges, return_getters_groups
# Context-manager entry. Defines the local `CSEProxy` class, whose static
# methods read this kernel through the enclosing `self`, and later installs an
# instance of it as the active ops handler.
def __enter__(self):
class CSEProxy:
# Executed while the class body runs: `self` is the enclosing method's
# argument, so this assigns `name` on the kernel, not on CSEProxy.
self.name = "CSEProxy"
# Class attribute holding a ValueRangeAnalysis instance; read by
# `_bound_variable` via `CSEProxy.vr_analysis`.
vr_analysis = ValueRangeAnalysis()
# Fallback for ops that CSEProxy does not define explicitly: returns a wrapper
# that calls the same-named op on `parent_handler` and passes every leaf of
# its result through the kernel's CSE.
@staticmethod
def __getattr__(name: str) -> Callable[..., CSEVariable]:
def inner(*args, **kwargs):
bounds = CSEProxy._bound_variable(name, *args, **kwargs)
value = getattr(parent_handler, name)(*args, **kwargs)
dtype_handler = DtypePropagationOpsHandler()
# Position of the leaf being processed; used to pick the matching entry
# when the propagated dtype is a list/tuple.
output_idx = 0
def do_cse(v):
# Decide whether the current device is served by the triton backend; with
# no current device this is treated as False.
if V.graph.current_device is not None:
device_str = V.graph.get_current_device_or_throw().type
triton_backend = (
config.cpu_backend == "triton"
if device_str == "cpu"
else config.cuda_backend == "triton"
)
else:
triton_backend = False
# A dtype is only propagated for the triton backend; "masked" reuses the
# dtype of the value itself.
if triton_backend:
if name == "masked":
output_dtype = value.dtype
else:
output_dtype = getattr(
dtype_handler,
name,
)(*args, **kwargs)
else:
output_dtype = None
csevar = V.kernel.cse.generate(
V.kernel.compute,
v,
bounds=bounds,
dtype=output_dtype,
)
nonlocal output_idx
# Optional runtime check: emit a tl.static_assert on the variable's dtype.
if (
config.test_configs.runtime_triton_dtype_assert
and triton_backend
):
from torch._inductor.codegen.triton import triton_type
if isinstance(output_dtype, (list, tuple)):
output_dtype = output_dtype[output_idx]
V.kernel.compute.writeline(
f"tl.static_assert({csevar}.dtype == {triton_type(output_dtype)})"
)
output_idx += 1
csevar.update_on_args(name, args, kwargs)
return csevar
return pytree.tree_map(do_cse, value)
return inner
@staticmethod
def _bound_variable(name, *args, **kwargs):
    """
    Work out the value range for the result of op ``name``.

    Template kernels always get an unknown range. When the op is the target
    of the FX node being interpreted, the bound recorded for that node is
    forwarded. Otherwise, if ``config.compute_all_bounds`` is set and
    ``ValueRangeAnalysis`` implements the op, the bound is computed from the
    bounds of ``args``. Every other case yields an unknown range.
    """
    from torch._inductor.select_algorithm import TritonTemplateKernel

    if isinstance(V.kernel, TritonTemplateKernel):
        return ValueRanges.unknown()

    fx_node = V.interpreter.current_node
    if fx_node.target == name and self.node_to_bounds is not None:
        if not isinstance(self.node_to_bounds, dict):
            raise RuntimeError("assert isinstance(self.node_to_bounds, dict)")
        return self.node_to_bounds.get(fx_node, ValueRanges.unknown())

    if not (config.compute_all_bounds and hasattr(ValueRangeAnalysis, name)):
        return ValueRanges.unknown()

    # Targets containing these fragments are not analysed.
    for fragment in ("set_indirect", "reduction", "scan"):
        if fragment in fx_node.target:
            return ValueRanges.unknown()
    if kwargs:
        raise RuntimeError("assert not kwargs")

    def arg_to_bound(x):
        if isinstance(x, CSEVariable):
            return x.bounds
        if isinstance(x, sympy.Expr):
            return bound_sympy(x)
        return x

    arg_bounds = [arg_to_bound(arg) for arg in args]
    return getattr(CSEProxy.vr_analysis, name)(*arg_bounds)
# Turn CSE variable `var` into a sympy symbol usable as an index into a
# dimension of extent `size`, optionally wrapping negative values and emitting
# bounds checks.
#
# Args:
#   var: CSE variable holding the index values.
#   size: extent of the indexed dimension (int or sympy expression).
#   check: forwarded to the parent handler and to `generate_assert`.
#   wrap_neg: when True, possibly-negative values get `size` added.
# Raises:
#   RuntimeError: if `size` is neither an int nor a sympy.Expr.
@staticmethod
def indirect_indexing(
var: CSEVariable,
size: Union[sympy.Expr, int],
check: bool = True,
wrap_neg=True,
):
if isinstance(size, int):
size = sympy.Integer(size)
if not (isinstance(size, sympy.Expr)):
raise RuntimeError("assert isinstance(size, sympy.Expr), size")
current_node = V.interpreter.current_node
# Nodes flagged "indirect_template" skip wrapping and checks: the variable
# name is returned directly as an index symbol.
if current_node and current_node.meta.get("indirect_template", False):
sympy_var = sympy_index_symbol(str(var))
return sympy_var
if var.bounds.lower < 0:
if wrap_neg:
stm = ops.add(var, ops.index_expr(size, torch.long))
# When the range also reaches non-negative values, only the negative
# entries take the shifted value.
if var.bounds.upper >= 0:
lt = ops.lt(var, 0)
stm = ops.where(lt, stm, var)
else:
stm = var
new_bounds = ValueRanges.unknown()
# With known bounds and a numeric size, the new range is the negative part
# shifted by `size`, unioned with the non-negative part when there is one.
if var.bounds != ValueRanges.unknown() and isinstance(
size, sympy.Number
):
neg_bounds = var.bounds & ValueRanges(-int_oo, -1)
new_bounds = ValueRanges(
neg_bounds.lower + size, neg_bounds.upper + size
)
if var.bounds.upper >= 0:
pos = var.bounds & ValueRanges(0, int_oo)
new_bounds = new_bounds | pos
var = self.cse.generate(self.compute, stm, bounds=new_bounds)
sympy_var = parent_handler.indirect_indexing(var, size, check)
if generate_assert(check):
# Only request the checks that the known bounds do not already establish.
assert_lower = not (var.bounds.lower >= 0)
assert_upper = not isinstance(size, sympy.Number) or not (
var.bounds.upper < size
)
self.check_bounds(sympy_var, size, assert_lower, assert_upper)
return sympy_var
@staticmethod
def check_bounds(
    expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
):
    """Forward the bounds check unchanged to the enclosing kernel."""
    kernel_check = self.check_bounds
    return kernel_check(expr, size, lower, upper)
@staticmethod
def load(name: str, index: sympy.Expr) -> CSEVariable:
    """Load buffer ``name`` at ``index`` through the enclosing kernel.

    A buffer whose store was invalidated is added to
    ``V.kernel.must_keep_buffers``. Indices containing TMP symbols go to
    ``indirect_load``. For other loads ``num_load`` is incremented when the
    buffer was not in the store cache and the returned variable's
    ``use_count`` is 1.
    """
    if name in self.cse.invalidated_stores:
        V.kernel.must_keep_buffers.add(name)
    if free_symbol_is_type(index, SymT.TMP):
        return self.indirect_load(name, index)
    # Membership is sampled before the kernel load runs, as in the original.
    served_from_cache = name in self.cse.store_cache
    loaded = self.load(name, index)
    if not served_from_cache and loaded.use_count == 1:
        self.num_load += 1
    return loaded
@staticmethod
def _update_store_cache(name: str, value: CSEVariable):
    """Record ``value`` as the cached content of ``name`` and of every buffer
    reported by ``get_mutations()`` on the current node's output ``name``."""
    self.cse.store_cache[name] = value
    if not (self.current_node and name in V.graph.name_to_buffer):
        return
    output_buf = self.current_node.get_output(name)
    for mutated_name in output_buf.get_mutations():
        self.cse.store_cache[mutated_name] = value
@staticmethod
def store(
    name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
) -> None:
    """Register a store to ``name`` and emit it unless the buffer was removed.

    The store cache is refreshed only for plain stores (``mode is None``).
    Returns whatever the kernel's ``store`` returns, or ``None`` for a
    removed buffer.
    """
    self.store_buffer_names.add(name)
    if mode is None:
        CSEProxy._update_store_cache(name, value)
    if name in V.graph.removed_buffers:
        return None
    return self.store(name, index, value, mode=mode)
@staticmethod
def store_reduction(name: str, index: sympy.Expr, value: CSEVariable):
    """Register a reduction store to ``name`` and emit it via the kernel.

    Raises:
        RuntimeError: when ``name`` is listed in ``V.graph.removed_buffers``.
    """
    self.store_buffer_names.add(name)
    CSEProxy._update_store_cache(name, value)
    if name in V.graph.removed_buffers:
        raise RuntimeError("store_reduction")
    return self.store_reduction(name, index, value)
@staticmethod
def reduction(
    dtype: torch.dtype,
    src_dtype: torch.dtype,
    reduction_type: ReductionType,
    value: Union[CSEVariable, Tuple[CSEVariable, ...]],
) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
    """Count the reduction on the kernel, then delegate to its ``reduction``."""
    self.num_reduction += 1
    reduced = self.reduction(dtype, src_dtype, reduction_type, value)
    return reduced
@staticmethod
def scan(
    dtypes: Tuple[torch.dtype, ...],
    combine_fn: Callable[
        [Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]],
        Tuple[CSEVariable, ...],
    ],
    values: Tuple[CSEVariable, ...],
) -> Tuple[CSEVariable, ...]:
    """Delegate the scan to the enclosing kernel with unchanged arguments."""
    kernel_scan = self.scan
    return kernel_scan(dtypes, combine_fn, values)
@staticmethod
def sort(
    dtypes: Tuple[torch.dtype, ...],
    values: Tuple[CSEVariable, ...],
    stable: bool,
    descending: bool,
) -> Tuple[CSEVariable, ...]:
    """Delegate the sort to the enclosing kernel with unchanged arguments."""
    kernel_sort = self.sort
    return kernel_sort(dtypes, values, stable, descending)
@staticmethod
def bucketize(
    values: CSEVariable,
    boundaries: Tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
    boundary_indices: CSEVariable,
    indexing_dtype: torch.dtype,
    right: bool,
    sorter: Optional[Tuple[str, sympy.Expr]] = None,
    sorter_indices: Optional[CSEVariable] = None,
) -> CSEVariable:
    """Delegate bucketize to the enclosing kernel, passing every argument
    positionally in declaration order."""
    forwarded = (
        values,
        boundaries,
        boundary_indices,
        indexing_dtype,
        right,
        sorter,
        sorter_indices,
    )
    return self.bucketize(*forwarded)
# Insert `src` into `dst` with `extension.insert_slice`, at `offset` along the
# last entry of the reversed `golden_var_list`, after rewriting lines that
# were already emitted so they use a dedicated index and mask of length `size`.
# Returns the CSE variable produced for the insert_slice expression.
@staticmethod
def cat_insert_slice(dst: CSEVariable, src: CSEVariable, offset, size, output_size) -> CSEVariable:
# Offsets, sizes and strides have one entry per tiling axis; only the last
# entry differs from the defaults ("0" / "<AXIS>BLOCK_SUB" / "1").
insert_offsets = ["0"] * len(self.tiling_axis)
insert_offsets[-1] = str(offset)
insert_sizes = [f"{axis.name.upper()}BLOCK_SUB" for axis in self.golden_var_list[::-1]]
insert_sizes[-1] = str(size)
output_sizes = [f"{axis.name.upper()}BLOCK_SUB" for axis in self.golden_var_list[::-1]]
strides = ["1"] * len(self.tiling_axis)
from torch._inductor.utils import triton_type
# Replace the first emitted line that reads exactly "<dst> = <output_size>.0"
# with a zero-filled tensor of the output block shape; bfloat16 sources use
# float32 for it.
for index, line in enumerate(self.compute._lines):
if line == f"{dst} = {output_size}.0":
if src.dtype == torch.bfloat16:
dtype = triton_type(torch.float32)
else:
dtype = triton_type(src.dtype)
self.compute._lines[index] = f"{dst} = tl.zeros(({', '.join(output_sizes)}, ), dtype={dtype})"
break
axis_name = self.golden_var_list[::-1][-1].name
# Textually redirect the first occurrence of the axis name (and of its mask)
# in not-yet-rewritten load lines to the size-specific index/mask names.
for index, line in enumerate(self.loads._lines):
if "_index_" not in line:
new_line = line.replace(axis_name, f"{axis_name}_index_{size}", 1)
new_line = new_line.replace(f"{axis_name}_mask", f"{axis_name}_index_mask_{size}", 1)
self.loads._lines[index] = new_line
# Same rewrite for compute lines, which additionally replaces the axis'
# BLOCK_SUB name with `size`; the tl.zeros line is left alone.
for index, line in enumerate(self.compute._lines):
if "_index_" not in line and 'tl.zeros' not in line:
new_line = line.replace(axis_name, f"{axis_name}_index_{size}", 1)
new_line = new_line.replace(f"{axis_name}_mask", f"{axis_name}_index_mask_{size}", 1)
new_line = new_line.replace(f"{axis_name.upper()}BLOCK_SUB", f"{size}", 1)
self.compute._lines[index] = new_line
# Define the size-specific index (an arange placed on the last axis), its
# mask and its numel in the kernel body.
suffix = ["None"] * len(self.tiling_axis)
suffix[-1] = ":"
line = f"{axis_name}_index_{size}= tl.arange(0, {size})[{','.join(suffix)}]"
if line not in self.body._lines:
self.body.writeline(line)
mask_line = f"{axis_name}_index_mask_{size}= {axis_name}_index_{size} < {size}"
self.body.writeline(mask_line)
numel_line = f"{axis_name}_index_{size}_numel = {size}"
self.body.writeline(numel_line)
code = f"tl.broadcast_to({src}, [{', '.join(insert_sizes)}])"
src = self.cse.generate(self.compute, code, dtype=src.dtype)
insert_slice = f"extension.insert_slice({dst}, {src}, [{', '.join(insert_offsets)}], " \
f"[{', '.join(insert_sizes)}], [{', '.join(strides)}])"
return self.cse.generate(self.compute, insert_slice, dtype=src.dtype)
# Store `src` into output buffer `dst` at `store_offset_index`, restricting
# the last axis of the reversed `golden_var_list` to `size` elements by
# AND-ing an extra mask into the store and into loads already emitted.
# Returns `src`. `output_buffer_index` is not read in this body.
@staticmethod
def cat_store(dst: CSEVariable, src: CSEVariable, size, store_offset_index, output_buffer_index) -> CSEVariable:
# Without a golden_var_list the store is emitted unmasked.
if self.golden_var_list is None:
indexing = self.indexing(store_offset_index, block_ptr=True)
code = f"tl.store({self.args.output(dst)} + {indexing.index_str}, {src}, None)"
self.stores.writeline(code)
return src
axis = self.range_tree_nodes[self.golden_var_list[::-1][-1]]
axis_name = axis.name
line = f"{axis_name}_index_mask_{size} = {axis_name} < {size}"
if line not in self.indexing_code._lines:
self.indexing_code.writeline(line)
# True when `axis_name` occurs in `line` as a whole word.
def contains_axis_name(line, axis_name):
words = re.findall(r'\b\w+\b', line)
return any(var for var in words if var == axis_name)
# Append the size mask to load lines that reference the axis: the final
# character (taken to be the closing ")") is dropped and re-added after the
# mask, and a leading "None & " is removed.
for index, line in enumerate(self.loads._lines):
if "_index_mask_" not in line and contains_axis_name(line, axis_name):
new_line = f"{line[:-1]} & {axis_name}_index_mask_{size})"
new_line = new_line.replace("None & ", "")
self.loads._lines[index] = new_line
# Same for loads emitted into `compute`; there the mask is inserted before a
# trailing ", other=0.0)" when the line ends with one.
for index, line in enumerate(self.compute._lines):
if "load" in line and "_index_mask_" not in line and contains_axis_name(line, axis_name):
if line.endswith(", other=0.0)"):
tail_len = len(", other=0.0)")
new_line = f"{line[:-tail_len]} & {axis_name}_index_mask_{size}, other=0.0)"
else:
new_line = f"{line[:-1]} & {axis_name}_index_mask_{size})"
new_line = new_line.replace("None & ", "")
self.compute._lines[index] = new_line
indexing = self.indexing(store_offset_index, block_ptr=True)
if "None" in indexing.mask_str:
mask = f"{axis_name}_index_mask_{size}"
else:
mask = f"{axis_name}_index_mask_{size} & {indexing.mask_str}"
code = f"tl.store({self.args.output(dst)} + {indexing.index_str}, {src}, {mask})"
# Remove the first "<axis>_mask &" from the store line.
code = code.replace(f"{axis_name}_mask &", "", 1)
self.stores.writeline(code)
return src
# Emit an `extension.custom("__builtin_index_select", ...)` call that gathers
# from buffer `src_name`, falling back to indirect_indexing plus a plain load
# whenever one of the preconditions checked below does not hold.
# Returns the CSE variable holding the gathered value, reshaped to the
# per-axis BLOCK_SUB shape.
@staticmethod
def index_select(src_name: str, weight_index: CSEVariable, indirect_var, set_indirect, bound, index_select_type) -> CSEVariable:
npu_config.log.debug(f"index_select: {src_name}, {weight_index}, {indirect_var}, {set_indirect}, {bound}, {index_select_type}")
from torch._inductor.utils import triton_type
# Fallback: regular indirect indexing followed by a load through V.ops.
def fallback_index_select_load(reason):
npu_config.log.info(f"fallback index_select to tl.load reason: {reason}, bound: {bound}")
new_indirect_var = V.ops.indirect_indexing(indirect_var, bound)
new_index = sympy_subs(weight_index, {sympy_index_symbol(indirect_var.name): sympy_index_symbol(new_indirect_var.name)})
return V.ops.load(src_name, new_index)
# The index must be indirect; for 'embedding' it must also contain a term
# with integer coefficient 1 whose symbol is the first entry of
# golden_var_list.
def is_correct_weight_index():
if not self.is_indirect_indexing(weight_index):
return False
if index_select_type != 'embedding':
return True
embedding_axis = None
for key, coeff in weight_index.as_coefficients_dict().items():
if isinstance(key, sympy.Integer):
continue
if isinstance(coeff, sympy.Integer) and int(coeff) == 1:
embedding_axis = key
break
if not embedding_axis or embedding_axis not in self.golden_var_list:
return False
if self.golden_var_list.index(embedding_axis) != 0:
return False
return True
# Every variable of the analysed output index must be an atom.
def check_output_index(output_index_analyzer):
for var in output_index_analyzer.all_var_list:
if not var.is_Atom:
return False
return True
current_node = V.interpreter.current_node
if current_node and current_node.meta.get("multi_indirect_index", False):
log_info = "ir is multi_indirect_index"
return fallback_index_select_load(log_info)
if not is_correct_weight_index():
# NOTE(review): "not invalid" reads as a double negative; the message
# presumably means the index is not valid.
log_info = f"{str(weight_index)} not invalid"
return fallback_index_select_load(log_info)
var = self.args.input(src_name)
dtype = V.graph.get_dtype(src_name)
indice_index = self.find_indirect_axis(set_indirect)
if indice_index is None:
log_info = f"{str(set_indirect)} can't find indexing"
return fallback_index_select_load(log_info)
indirect_output_analyzer = IndexAnalysis(self, weight_index)
indirect_output_vars = tuple(reversed(indirect_output_analyzer.all_var_list))
# The gather dimension is the position of the first TMP symbol in the
# reversed variable list.
gather_dim = next(
(dim for dim, var in enumerate(indirect_output_vars) if symbol_is_type(var, SymT.TMP)), None)
if gather_dim is None:
log_info = f"gather_dim is None, weight_index: {weight_index}"
return fallback_index_select_load(log_info)
src_stride = tuple(reversed(indirect_output_analyzer.all_stride_list))
# The last stride of the reversed stride list must be 1.
if src_stride[-1] != 1:
log_info = f"src_stride: {src_stride} not fit"
return fallback_index_select_load(log_info)
# Substitute the indirect variable by the expression of its index axis to
# obtain the index of the produced output.
output_index = sympy_subs(weight_index, {sympy_index_symbol(indirect_var.name): indice_index})
output_index_analyzer = IndexAnalysis(self, output_index)
indice_analyzer = IndexAnalysis(self, indice_index)
output_index_analyzer.analyze_index()
indice_analyzer.analyze_index()
if not check_output_index(output_index_analyzer):
log_info = f"check_output_index: {output_index} not fit"
return fallback_index_select_load(log_info)
start_offsets, _, _, _ = self.get_template_offset(indirect_output_analyzer)
_, end_offsets, _, value_shapes = self.get_template_offset(output_index_analyzer)
_, _, _, indice_shape = self.get_template_offset(indice_analyzer)
# A constant index axis contributes a unit dimension at the gather position.
if isinstance(indice_index, (sympy.Integer, int)):
indice_shape = ["1"]
end_offsets.insert(gather_dim, "1")
value_shapes.insert(gather_dim, "1")
# Rank consistency checks between offsets, strides and index shape.
if len(start_offsets) != len(src_stride):
log_info = f"src_stride: {src_stride}, starts_offset: {start_offsets} don't fit"
return fallback_index_select_load(log_info)
if len(end_offsets) != len(indice_shape) + len(start_offsets) - 1:
log_info = f"end_offsets: {end_offsets}, starts_offset: {start_offsets}, value_shapes: {indice_shape} don't fit"
return fallback_index_select_load(log_info)
start_offset_val = ", ".join(start_offsets)
end_offset_val = ", ".join(end_offsets)
shape_val = ",".join(value_shapes)
indice_shape = ", ".join(indice_shape)
indirect_var = self.cse.generate(self.compute,
f"tl.reshape({indirect_var}, ({indice_shape}, ))",
dtype=dtype)
out_triton_type = triton_type(dtype)
# With a cat node and an 'embedding' select, the last output dimension is
# replaced by the first source stride.
if self.contains_cat_node() and index_select_type == "embedding":
shape_val = ",".join(value_shapes[:-1] + [str(src_stride[0])])
out_var = self.cse.generate(self.compute,
f"tl.full(({shape_val}, ), 0, dtype={out_triton_type})",
dtype=dtype)
line = f"extension.custom(\"__builtin_index_select\", {var}, {indirect_var}, dim={gather_dim}, bound={bound}, end_offset=({end_offset_val}, ), start_offset=({start_offset_val}, ), src_stride={src_stride}, out={out_var})"
index_select_var = self.cse.generate(self.compute,
line,
dtype=dtype)
# Reshape the result to the BLOCK_SUB size of every axis, taken in reversed
# golden_var_list order.
output_shapes = []
for axis_key in reversed(self.golden_var_list):
axis = self.range_tree_nodes[axis_key]
BLOCK_NAME_SUB = f"{axis.name.upper()}BLOCK_SUB"
output_shapes.append(BLOCK_NAME_SUB)
output_shapes_vals = ", ".join(output_shapes)
if self.contains_cat_node() and index_select_type == "embedding":
output_shapes_vals = ",".join(output_shapes[:-1] + [str(src_stride[0])])
line = f"tl.reshape({index_select_var}, ({output_shapes_vals}, ))"
index_select_var = self.cse.generate(self.compute,
line,
dtype=dtype)
return index_select_var
@staticmethod
def gather_template(src_name: str, gather_index: CSEVariable, indirect_var: CSEVariable, set_indirect: str, index_boundary: int):
    """Emit an ``extension.gather_out_to_ub`` call reading from ``src_name``.

    The gather dimension is the position of the first TMP symbol in the
    reversed variable list of ``gather_index``. When the shape helper reports
    a reshape type, the indirect variable is reshaped before the gather and
    the gathered value is reshaped afterwards with all ``"1"`` entries
    dropped.

    Returns:
        The CSE variable generated for the final expression.

    Raises:
        RuntimeError: if no indirect axis is found for ``set_indirect`` or
            ``gather_index`` has no TMP symbol.
    """
    src_ptr = self.args.input(src_name)
    dtype = V.graph.get_dtype(src_name)
    indice_index = self.find_indirect_axis(set_indirect)
    if not indice_index:
        raise RuntimeError("assert indirect indice_index is not None")
    indice_index_analyzer = IndexAnalysis(self, indice_index)
    indirect_output_analyzer = IndexAnalysis(self, gather_index)
    reversed_vars = tuple(reversed(indirect_output_analyzer.all_var_list))

    indirect_dim = None
    for position, candidate in enumerate(reversed_vars):
        if symbol_is_type(candidate, SymT.TMP):
            indirect_dim = position
            break
    if indirect_dim is None:
        raise RuntimeError(f"indirect_dim is None in {gather_index}")

    src_stride = tuple(reversed(indirect_output_analyzer.all_stride_list))
    axis_shape, tiling_offset, reshape_type, reshape_list = self.get_template_shape_offset(indice_index_analyzer)
    axis_shape_val = ", ".join(axis_shape)
    tiling_offset_val = ", ".join(tiling_offset)

    def gather_call(index_operand):
        # Both branches emit the same call; only the index operand differs.
        return (
            f"extension.gather_out_to_ub({src_ptr}, {index_operand}, {index_boundary}, {indirect_dim}, "
            f"{src_stride}, ({axis_shape_val},), ({tiling_offset_val},))"
        )

    if not reshape_type:
        final_line = gather_call(indirect_var)
    else:
        before_gather_shape = ",".join(reshape_list)
        reshape_var = self.cse.generate(
            self.compute,
            f"tl.{reshape_type}({indirect_var}, ({before_gather_shape}))",
            dtype=dtype,
        )
        gather_var = self.cse.generate(
            self.compute, gather_call(reshape_var), dtype=dtype
        )
        after_gather_shape = ",".join(
            [entry for entry in reshape_list if entry != "1"]
        )
        final_line = f"tl.reshape({gather_var}, ({after_gather_shape}))"
    return self.cse.generate(self.compute, final_line, dtype=dtype)
@staticmethod
def indexput_template(name: str, output_index: CSEVariable, store_var: CSEVariable, set_indirect: str, boundary: int):
    """Emit an ``extension.index_put`` call writing ``store_var`` into ``name``.

    A direct (non-indirect) ``output_index`` is handled by the kernel's plain
    ``store``. Otherwise the first TMP symbol in the reversed variable list of
    ``output_index`` is the indirect variable; it is replaced by the indirect
    axis expression to derive the start/end offsets of the write.

    Returns:
        The result of ``self.store`` for direct indices, otherwise the CSE
        variable generated for the ``index_put`` expression.

    Raises:
        RuntimeError: if no indirect axis is found for ``set_indirect`` or
            ``output_index`` has no TMP symbol.
    """
    indirect_indexing = self.is_indirect_indexing(output_index)
    if not indirect_indexing:
        return self.store(name, output_index, store_var, None)
    var_ptr = self.args.output(name)
    dtype = V.graph.get_dtype(name)
    indirect_axis = self.find_indirect_axis(set_indirect)
    if not (indirect_axis):
        raise RuntimeError("assert indirect indirect_axis is not None")
    indirect_output_analyzer = IndexAnalysis(self, output_index)
    indirect_output_vars = tuple(reversed(indirect_output_analyzer.all_var_list))
    # The default must be a pair: unpacking a bare None would raise TypeError
    # before the explicit check below could report the real problem.
    indirect_dim, indirect_var = next(
        (
            (dim, var)
            for dim, var in enumerate(indirect_output_vars)
            if symbol_is_type(var, SymT.TMP)
        ),
        (None, None),
    )
    if indirect_dim is None:
        raise RuntimeError(f"indirect_dim is None in {output_index}")
    output_codegen_index = sympy_subs(output_index, {sympy_index_symbol(indirect_var.name): indirect_axis})
    output_index_analyzer = IndexAnalysis(self, output_codegen_index)
    dst_stride = tuple(reversed(indirect_output_analyzer.all_stride_list))
    start_offset, end_offset, reshape_type, reshape_list = self.get_template_offset(output_index_analyzer)
    start_offset_val = ", ".join(start_offset)
    end_offset_val = ", ".join(end_offset)
    if reshape_type:
        # Reshape the stored value to the shape reported by the offset helper.
        before_indexput_shape = ",".join(reshape_list)
        line = f"tl.reshape({store_var}, ({before_indexput_shape}))"
        store_var = self.cse.generate(self.compute, line, dtype=dtype)
    line = f"extension.index_put({var_ptr}, {indirect_var}, {store_var}, {indirect_dim}, {boundary}, ({end_offset_val},), ({start_offset_val},), {dst_stride})"
    return self.cse.generate(self.compute, line, dtype=dtype)
@staticmethod
def scatter_template(name: str, output_index: CSEVariable, store_var: CSEVariable, set_indirect: str, boundary: int):
    """Emit an ``extension.scatter_ub_to_out`` call writing into ``name``.

    The scatter dimension and indirect variable come from the first TMP
    symbol in the reversed variable list of ``output_index``. When the shape
    helper reports a reshape type, both the indirect variable and the stored
    value are reshaped before the scatter.

    Returns:
        The CSE variable generated for the scatter expression.

    Raises:
        RuntimeError: if no indirect axis is found for ``set_indirect`` or
            ``output_index`` has no TMP symbol.
    """
    var = self.args.output(name)
    dtype = V.graph.get_dtype(name)
    indice_index = self.find_indirect_axis(set_indirect)
    if not (indice_index):
        raise RuntimeError("assert indirect indice_index is not None")
    indice_index_analyzer = IndexAnalysis(self, indice_index)
    indirect_output_analyzer = IndexAnalysis(self, output_index)
    indirect_output_vars = tuple(reversed(indirect_output_analyzer.all_var_list))
    # The default must be a pair: unpacking a bare None would raise TypeError
    # before the explicit check below could report the real problem.
    indirect_dim, indirect_var = next(
        (
            (dim, candidate)
            for dim, candidate in enumerate(indirect_output_vars)
            if symbol_is_type(candidate, SymT.TMP)
        ),
        (None, None),
    )
    if indirect_dim is None:
        raise RuntimeError(f"indirect_dim is None in {output_index}")
    dst_stride = tuple(reversed(indirect_output_analyzer.all_stride_list))
    axis_shape, tiling_offset, reshape_type, reshape_list = self.get_template_shape_offset(indice_index_analyzer)
    tiling_offset_val = ", ".join(tiling_offset)
    axis_shape_val = ", ".join(axis_shape)
    if reshape_type:
        before_scatter_shape = ",".join(reshape_list)
        line = f"tl.{reshape_type}({indirect_var}, ({before_scatter_shape}))"
        indirect_var = self.cse.generate(self.compute, line, dtype=dtype)
        line = f"tl.reshape({store_var}, ({before_scatter_shape}))"
        store_var = self.cse.generate(self.compute, line, dtype=dtype)
    line = f"extension.scatter_ub_to_out({var}, {store_var}, {indirect_var}, {boundary}, {indirect_dim}, {dst_stride}, ({axis_shape_val}, ), ({tiling_offset_val}, ))"
    return self.cse.generate(self.compute, line, dtype=dtype)
# Identity function whose annotations state that a CSEProxy is usable as an
# OpsHandler[CSEVariable]; it is not called in this block.
def _typecheck_CSEProxy(h: CSEProxy) -> OpsHandler[CSEVariable]:
return h
# Remainder of __enter__: run the base-class entry, require `self.overrides`,
# and build the parent handler that the CSEProxy methods above delegate to.
super().__enter__()
if not (self.overrides):
raise RuntimeError("assert self.overrides")
parent_handler = self.overrides()
# Install the proxy as ops handler and this kernel as kernel handler; both
# are registered on `self.exit_stack`.
self.exit_stack.enter_context(V.set_ops_handler(CSEProxy()))
self.exit_stack.enter_context(V.set_kernel_handler(self))
return self