import copy
import functools
from functools import lru_cache
import importlib
import logging
import dataclasses
import math
import os
import re
import sys
import time
import shutil
import hashlib
import csv
import uuid
import threading
from itertools import count
from typing import Any, Callable, Literal, Optional, TYPE_CHECKING, Union, List
from contextlib import contextmanager
from concurrent.futures import ThreadPoolExecutor, as_completed
import torch
from torch._logging import warning_once
import triton
from torch._dynamo.testing import rand_strided
from torch._dynamo.utils import dynamo_timed
from torch._inductor import config
from torch._inductor.compile_fx import clone_preserve_strides
from torch._inductor.runtime.autotune_cache import AutotuneCache
from torch._inductor.runtime.benchmarking import benchmarker
from torch._inductor.runtime.runtime_utils import (
create_bandwidth_info_str,
get_num_bytes,
next_power_of_2,
)
from torch._inductor.utils import triton_version_uses_attrs_dict
from torch.utils._ordered_set import OrderedSet
from torch._inductor.runtime.triton_heuristics import (
CachingAutotuner,
HeuristicType,
unique_configs,
hash_configs,
Config,
ASTSource,
_find_names,
get_first_attr,
collected_calls,
_dump_launch_params,
builtins,
NoTritonConfigsError,
TritonCompileResult,
GridExpr,
config_to_dict,
config_from_dict,
FixedGrid,
PrecomputedGrid,
SequentialComboKernelGrid,
PrecomputedGrid,
Grid1D,
Grid2D,
Grid2DWithYZOverflow,
Grid3D
)
from .symbolic_grouping import (
UnsupportedGroupedPlan,
bucketize,
build_group_representatives,
find_primary_feature_index,
is_open_bucket_group,
)
from torch._inductor.runtime.triton_heuristics import (
fixed_config,
user_autotune,
template
)
from torch._inductor.runtime.runtime_utils import triton_hash_to_path_key
from triton.compiler import CompiledKernel
from torch._inductor.triton_bundler import TritonBundler
try:
from triton.backends.compiler import GPUTarget
from triton.runtime.autotuner import OutOfResources
import torch.autograd.profiler as autograd_profiler
except ImportError:
GPUTarget = None
OutOfResources = None
autograd_profiler = None
import torch_npu
from torch_npu.utils._error_code import ErrCode, pta_error
from torch_npu._inductor.npu_compare import check_accuracy_triton
from ..codegen.tile_generator import TileGenerator
from ..codegen.triton_utils import NPUKernelType
from ..config import log
from .. import config as npu_config
from ..profiler import simple_trace_handler, mspti_batch_benchmark
# Module-level itertools.count() iterator.
# NOTE(review): presumably used to number kernels; no use is visible in this chunk.
kernel_idx = count()
class CompileThreadPool:
    """Thin wrapper around a ``ThreadPoolExecutor`` used to run compile jobs.

    The pool is sized by ``npu_config.max_precompiled_thread_num`` and is
    warmed up as part of construction.
    """

    def __init__(self):
        self.pool = ThreadPoolExecutor(max_workers=npu_config.max_precompiled_thread_num)
        self.warmup()

    def warmup(self):
        """Run one blocking task per configured worker, then release them all.

        Every task waits on the same event, so none can finish before all of
        them have been submitted; presumably this makes the executor start
        all of its threads up front.
        """
        gate = threading.Event()

        def wait_for_gate():
            gate.wait()

        pending = [
            self.submit(wait_for_gate)
            for _ in range(npu_config.max_precompiled_thread_num)
        ]
        gate.set()
        for fut in pending:
            fut.result()

    def submit(self, fn, *args, **kwargs):
        """Schedule ``fn(*args, **kwargs)`` on the pool and return its future."""
        return self.pool.submit(fn, *args, **kwargs)
# Shared pool instance; constructing it here runs CompileThreadPool.warmup() at import time.
compile_thread_pool = CompileThreadPool()
@contextmanager
def create_profiler(torch_path, wait=0, warmup=1, active=1, repeat=1, skip_first=1):
    """Context manager yielding a ``torch_npu`` profiler that records NPU activity.

    Args:
        torch_path: Directory handed to ``simple_trace_handler`` for the trace output.
        wait, warmup, active, repeat, skip_first: Forwarded unchanged to
            ``torch_npu.profiler.schedule``.

    Yields:
        The active profiler object; callers are expected to call ``step()`` on it.
    """
    npu_profiler = torch_npu.profiler
    experimental_config = npu_profiler._ExperimentalConfig(
        aic_metrics=npu_profiler.AiCMetrics.AiCoreNone,
        profiler_level=npu_profiler.ProfilerLevel.Level0,
    )
    step_schedule = npu_profiler.schedule(
        wait=wait, warmup=warmup, active=active, repeat=repeat, skip_first=skip_first
    )
    with npu_profiler.profile(
        activities=[npu_profiler.ProfilerActivity.NPU],
        record_shapes=False,
        profile_memory=False,
        with_stack=False,
        schedule=step_schedule,
        on_trace_ready=simple_trace_handler(torch_path),
        experimental_config=experimental_config,
    ) as prof:
        yield prof
def delete_file_base(base_path):
    """Recursively remove ``base_path`` if it exists; otherwise do nothing."""
    if not os.path.exists(base_path):
        return
    shutil.rmtree(base_path)
def read_device_time(torch_path, triton_only=True, return_list=True):
    """Read kernel durations from the first ``kernel_details.csv`` under ``torch_path``.

    Args:
        torch_path: Directory tree to search with ``os.walk``.
        triton_only: Accepted for interface compatibility; not used here.
        return_list: When true, return every ``Duration(us)`` value as a list
            of floats; otherwise return their arithmetic mean.

    Returns:
        A list of floats, or a single float mean.

    Raises:
        RuntimeError: No ``kernel_details.csv`` was found. ``torch_path`` is
            deleted before raising.
    """
    for dirpath, _, filenames in os.walk(torch_path):
        # File names within a directory are unique, so a membership test is
        # enough to locate the one file of interest.
        if 'kernel_details.csv' not in filenames:
            continue
        csv_path = os.path.join(dirpath, 'kernel_details.csv')
        with open(csv_path, newline='') as csvfile:
            durations = [
                float(record['Duration(us)']) for record in csv.DictReader(csvfile)
            ]
        if return_list:
            return durations
        return sum(durations) / len(durations)
    delete_file_base(torch_path)
    raise RuntimeError(f"Could not find kernel_details.csv from dir {torch_path}")
def _summarize_statistics(times, quantiles, return_mode):
    """Reduce a tensor of timings to the statistic requested by the caller.

    Args:
        times: Tensor of measured times.
        quantiles: Optional sequence of quantiles; when given it takes
            precedence over ``return_mode``.
        return_mode: ``"all"`` to return every sample as a list, otherwise the
            name of a ``torch`` reduction function applied to ``times``.

    Returns:
        A float, or a list of floats (several quantiles, or ``"all"``).
    """
    if quantiles is None:
        if return_mode == "all":
            return times.tolist()
        reducer = getattr(torch, return_mode)
        return reducer(times).item()
    points = torch.tensor(quantiles, dtype=torch.float)
    values = torch.quantile(times, points).tolist()
    # A single requested quantile is unwrapped to a bare number.
    return values[0] if len(values) == 1 else values
def do_bench_using_profiling_npu(fn, warmup=2, rep=10, grad_to_none=None, quantiles=None, return_mode="mean"):
    """Benchmark ``fn`` on the NPU using device times reported by the profiler.

    ``fn`` is run ``warmup`` times, then ``rep + 10`` more times under
    ``create_profiler``; the recorded durations are read back with
    ``read_device_time`` and summarized.

    Args:
        fn: Zero-argument callable to benchmark.
        warmup: Number of un-profiled warm-up calls.
        rep: Base number of profiled calls (ten extra calls are always added).
        grad_to_none: Accepted for interface compatibility; not used here.
        quantiles: Optional quantiles forwarded to ``_summarize_statistics``.
        return_mode: One of ``"min"``, ``"max"``, ``"mean"``, ``"median"``, ``"all"``.

    Returns:
        The summarized timing(s) as produced by ``_summarize_statistics``.

    Raises:
        RuntimeError: ``return_mode`` is not one of the supported values.
    """
    if return_mode not in ["min", "max", "mean", "median", "all"]:
        raise RuntimeError("return_mode must be one of 'min', 'max', 'mean', 'median', 'all'")
    stream = torch.npu.current_stream()
    stream.synchronize()
    for _ in range(warmup):
        fn()
    stream.synchronize()
    random_uuid = uuid.uuid4().hex
    md5_hash = hashlib.md5(random_uuid.encode()).hexdigest()
    torch_path = os.path.join(os.getcwd(), "profile_result", f"triton_{md5_hash}")
    try:
        with create_profiler(torch_path, active=8) as prof:
            stream.synchronize()
            for _ in range(rep + 10):
                fn()
                prof.step()
            stream.synchronize()
        times = read_device_time(torch_path, triton_only=False, return_list=True)
    finally:
        # Always remove the per-run profile directory, including when ``fn``
        # or the profiler raises, so failed runs do not leave traces on disk.
        delete_file_base(torch_path)
    return _summarize_statistics(torch.tensor(times), quantiles, return_mode)
@dataclasses.dataclass
class GridNpu(GridExpr):
    """Grid expression built from a config's ``split_axis`` / ``split_blocks``."""

    # Per-axis element-count expressions, indexed by the values in ``split_axis``.
    numels: List[str] = None
    # Selects the integer-division syntax of the generated expression.
    mode: Literal["python", "cpp"] = "python"

    def generate(self, meta: dict[str, int]) -> None:
        """Fill ``x_grid`` / ``y_grid`` / ``z_grid`` from ``meta``.

        Each grid dimension ``i`` is the ceil-division of the numel of axis
        ``split_axis[i]`` by ``split_blocks[i]``; dimensions beyond the number
        of split axes are ``"1"``.

        Raises:
            RuntimeError: ``meta`` lacks ``split_axis`` or ``split_blocks``.
        """
        axes = meta.get("split_axis", None)
        blocks = meta.get("split_blocks", None)
        if axes is None or blocks is None:
            raise RuntimeError(
                f"Could not get split_axis or split_blocks from meta {meta}."
            )

        def dim_expr(dim):
            if dim >= len(axes):
                return "1"
            axis = axes[dim]
            block = blocks[dim]
            # No blocking on this axis: the grid covers the full numel.
            if block is None or block == 1:
                return self.numels[axis]
            if self.mode == "python":
                return f"({self.numels[axis]} + {block} - 1) // {block}"
            return f"(({self.numels[axis]} + ({block} - 1)) / ({block}))"

        self.x_grid = dim_expr(0)
        self.y_grid = dim_expr(1)
        self.z_grid = dim_expr(2)
def is_namedtuple_isinstance(obj):
    """Return True when ``obj`` looks like a namedtuple instance.

    The check is structural: a ``tuple`` with a ``_fields`` attribute and a
    callable ``_asdict`` attribute.
    """
    if not isinstance(obj, tuple):
        return False
    if not (hasattr(obj, '_fields') and hasattr(obj, '_asdict')):
        return False
    return callable(obj._asdict)
class GridExprNpu(GridExpr):
    """Factory helpers that build grid expressions for NPU launchers."""

    @staticmethod
    def from_meta_and_set_numel(
        inductor_meta: dict[str, Any],
        cfg: Union[Config, dict[str, int]],
        numels: List[str],
        mode: Literal["python", "cpp"] = "python",
    ) -> GridExpr:
        """Instantiate the grid class named by ``inductor_meta["grid_type"]``.

        The class is looked up by name in this module's globals. ``numels`` is
        passed to the constructor only for ``GridNpu`` subclasses. The grid is
        then generated from ``cfg`` (converted to a dict when it is a ``Config``).
        """
        grid_type = inductor_meta["grid_type"]
        grid_cls = globals().get(grid_type)
        ctor_kwargs = {"inductor_meta": inductor_meta, "mode": mode}
        if issubclass(grid_cls, GridNpu):
            ctor_kwargs["numels"] = numels
        grid = grid_cls(**ctor_kwargs)
        meta = config_to_dict(cfg) if isinstance(cfg, Config) else cfg
        grid.generate(meta)
        return grid

    @staticmethod
    def from_grouped_meta_and_numel(
        inductor_meta: dict[str, Any],
        cfg: Union[Config, dict[str, int]],
        numels: List[str],
        runtime_block_names: tuple[str, ...],
        mode: Literal["python", "cpp"] = "python",
    ) -> GridExpr:
        """Build a grid whose block sizes are runtime launcher arguments.

        Each name in ``runtime_block_names`` is mapped to an axis by stripping
        a trailing ``"BLOCK"`` and lower-casing, then looking the result up in
        ``inductor_meta["axis_names"]``. The block *name* itself is embedded in
        the generated grid expression. ``cfg`` is accepted but not used here.

        Raises:
            RuntimeError: The grid type is not a ``GridNpu`` subclass, or a
                block name does not resolve to a known axis.
        """
        grid_type = inductor_meta["grid_type"]
        grid_cls = globals().get(grid_type)
        if not issubclass(grid_cls, GridNpu):
            raise RuntimeError(
                f"grouped grid path expects GridNpu-compatible grid, got {grid_type}"
            )
        grid = grid_cls(inductor_meta=inductor_meta, mode=mode, numels=numels)
        known_axes = tuple(inductor_meta.get("axis_names", ()))
        axis_indices = []
        for block_name in runtime_block_names:
            axis_name = block_name.removesuffix("BLOCK").lower()
            if axis_name not in known_axes:
                raise RuntimeError(
                    f"grouped grid path could not resolve runtime block axis {axis_name}"
                )
            axis_indices.append(known_axes.index(axis_name))

        def dim_expr(dim):
            if dim >= len(axis_indices):
                return "1"
            axis = axis_indices[dim]
            block = runtime_block_names[dim]
            if block is None or block == 1:
                return numels[axis]
            if mode == "python":
                return f"({numels[axis]} + {block} - 1) // {block}"
            return f"(({numels[axis]} + ({block} - 1)) / ({block}))"

        grid.x_grid = dim_expr(0)
        grid.y_grid = dim_expr(1)
        grid.z_grid = dim_expr(2)
        return grid
class TritonCompileResultNpu(TritonCompileResult):
    """Compile result whose launcher can use NPU-specific grid expressions."""

    def make_launcher(self):
        """Generate and return the Python ``launcher`` callable for this kernel.

        The launcher's source is assembled as text (argument list, grid
        computation, call into the binary's runner) and ``exec``-ed in a scope
        holding the binary and its launch metadata.

        Returns:
            The generated function, annotated with ``config``, ``runnable``,
            ``n_regs``, ``n_spills``, ``shared`` and ``store_cubin``; when
            ``store_cubin`` is set it also carries ``fn``, ``bin``,
            ``def_args`` and ``call_args``.

        Raises:
            RuntimeError: ``group_enabled`` is set but a runtime block name is
                missing from the launcher's argument lists.
        """
        cfg = self.config
        compile_meta = self.compile_meta
        binary = self.kernel
        fn = binary.src.fn
        binary._init_handles()

        known_constants = OrderedSet(
            arg for i, arg in enumerate(fn.arg_names) if i in fn.constexprs
        )
        # Constants whose value is None, excluding declared constexprs and
        # anything that still appears in the signature.
        none_args = OrderedSet(
            k
            for k, v in compile_meta["constants"].items()
            if v is None and k not in known_constants
        )
        none_args = none_args.difference(OrderedSet(compile_meta["signature"].keys()))

        if triton_version_uses_attrs_dict():
            call_args = fn.arg_names
            def_args = fn.arg_names
            if (
                "num_warps" in compile_meta["constants"]
                or "num_stages" in compile_meta["constants"]
            ):
                # num_warps/num_stages are dropped from the launcher's
                # parameters and replaced by their literal values in the call.
                def_args = [
                    arg
                    for arg in def_args
                    if arg not in ("num_warps", "num_stages")
                ]
                repl = {
                    k: str(compile_meta["constants"].get(k))
                    for k in ("num_warps", "num_stages")
                }
                call_args = [repl.get(arg, arg) for arg in call_args]
        else:
            call_args = [
                arg
                for i, arg in enumerate(fn.arg_names)
                if i not in fn.constexprs and arg not in none_args
            ]
            cfg_dict = config_to_dict(cfg)
            def_args = [
                name
                for name in fn.arg_names
                if name not in cfg_dict and name not in none_args
            ]

        if self.inductor_meta.get("group_enabled", False):
            runtime_block_names = tuple(
                self.inductor_meta.get("runtime_block_append_order", ())
            )
            missing_runtime_block_names = [
                name
                for name in runtime_block_names
                if name not in def_args or name not in call_args
            ]
            if missing_runtime_block_names:
                raise RuntimeError(
                    "Grouped launcher is missing runtime block args: "
                    f"{missing_runtime_block_names}"
                )

        binary_shared = (
            binary.shared if hasattr(binary, "shared") else binary.metadata.shared
        )
        # ``packed_metadata`` is optional (the scope below falls back to
        # ``binary.metadata``), so it must not be read unconditionally.
        packed_metadata = getattr(binary, "packed_metadata", None)
        if is_namedtuple_isinstance(packed_metadata):
            binary.packed_metadata = packed_metadata._asdict()

        scope = {
            "grid_meta": cfg.kwargs,
            "bin": binary,
            "launch_enter_hook": binary.__class__.launch_enter_hook,
            "launch_exit_hook": binary.__class__.launch_exit_hook,
            "metadata": (
                binary.packed_metadata
                if hasattr(binary, "packed_metadata")
                else binary.metadata
            ),
            "shared": binary_shared,
            "num_warps": (
                binary.num_warps
                if hasattr(binary, "num_warps")
                else binary.metadata.num_warps
            ),
            "cta_args": (
                (
                    binary.num_ctas,
                    *get_first_attr(binary, "cluster_dims", "clusterDims"),
                )
                if hasattr(binary, "num_ctas")
                else (
                    (binary.metadata.num_ctas, *binary.metadata.cluster_dims)
                    if hasattr(binary, "metadata")
                    else ()
                )
            ),
            "function": get_first_attr(binary, "function", "cu_function"),
            "runner": get_first_attr(binary, "run", "c_wrapper"),
        }

        if not hasattr(binary, "launch_metadata"):
            runner_args = [
                "grid_0",
                "grid_1",
                "grid_2",
                "num_warps",
                "*cta_args",
                "shared",
                "stream",
                "function",
                "launch_enter_hook",
                "launch_exit_hook",
                "metadata",
                *call_args,
            ]
        else:
            # Launch metadata is only built when an enter hook is installed.
            if binary.__class__.launch_enter_hook:
                launch_metadata = f"bin.launch_metadata((grid_0, grid_1, grid_2), stream, {', '.join(call_args)})"
            else:
                launch_metadata = "None"
            runner_args = [
                "grid_0",
                "grid_1",
                "grid_2",
                "stream",
                "function",
                "metadata",
                launch_metadata,
                "launch_enter_hook",
                "launch_exit_hook",
                *call_args,
            ]

        if "extra_launcher_args" in self.inductor_meta:
            def_args = [*def_args, *self.inductor_meta["extra_launcher_args"]]

        numels = [
            arg
            for arg in fn.arg_names
            if "_numel" in arg
        ]
        linear_mode = self.inductor_meta.get('inductor_ascend_linear_mode', 'no_linear')
        if linear_mode == 'no_linear':
            grid = GridExpr.from_meta(self.inductor_meta, cfg)
        elif self.inductor_meta.get("group_enabled", False):
            runtime_block_names = tuple(
                self.inductor_meta.get("runtime_block_arg_names", ())
            )
            grid = GridExprNpu.from_grouped_meta_and_numel(
                self.inductor_meta,
                cfg,
                numels,
                runtime_block_names=runtime_block_names,
            )
        else:
            grid = GridExprNpu.from_meta_and_set_numel(self.inductor_meta, cfg, numels)

        lines = [
            f"def launcher({', '.join(def_args)}, stream):",
            *[f"    {line}" for line in grid.prefix],
            f"    grid_0 = {grid.x_grid}",
            f"    grid_1 = {grid.y_grid}",
            f"    grid_2 = {grid.z_grid}",
            f"    runner({', '.join(runner_args)})",
        ]
        # The source is built from the kernel's own argument names and the
        # grid expressions computed above.
        exec("\n".join(lines), scope)

        launcher = scope["launcher"]
        launcher.config = cfg
        launcher.runnable = True
        launcher.n_regs = getattr(binary, "n_regs", None)
        launcher.n_spills = getattr(binary, "n_spills", None)
        launcher.shared = binary_shared
        launcher.store_cubin = self.inductor_meta.get("store_cubin", False)
        if launcher.store_cubin:
            launcher.fn = fn
            launcher.bin = binary
            if triton_version_uses_attrs_dict():
                # Argument filtering was skipped in the attrs-dict branch
                # above, so apply it here before recording the lists.
                cfg_dict = config_to_dict(cfg)
                def_args = [x for x in def_args if x not in cfg_dict]
                call_args = [
                    x
                    for x in call_args
                    if compile_meta["signature"].get(x, "constexpr") != "constexpr"
                    and x not in none_args
                ]
            launcher.def_args = def_args
            launcher.call_args = call_args
        return launcher
class NPUCachingAutotuner(CachingAutotuner):
def __init__(
    self,
    fn,
    triton_meta,
    configs,
    save_cache_hook,
    mutated_arg_names: List[str],
    optimize_mem,
    heuristic_type,
    size_hints=None,
    inductor_meta=None,
    custom_kernel=False,
    filename: Optional[str] = None,
    reset_to_zero_arg_names: Optional[List[str]] = None,
):
    """Initialize the base autotuner, then the NPU-specific bookkeeping.

    All arguments are forwarded positionally, in order, to the parent
    ``__init__``.
    """
    super().__init__(
        fn,
        triton_meta,
        configs,
        save_cache_hook,
        mutated_arg_names,
        optimize_mem,
        heuristic_type,
        size_hints,
        inductor_meta,
        custom_kernel,
        filename,
        reset_to_zero_arg_names,
    )
    self.exceptions = []
    self.fn_name = None
    # Runtime inputs and leftover configs used by the costmodel prefilter.
    self._costmodel_runtime_args = ()
    self._costmodel_runtime_kwargs = {}
    self._costmodel_fallback_configs = None
def precompile(
    self,
    warm_cache_only=False,
    reload_kernel: Optional[Callable[[], CachingAutotuner]] = None,
):
    """Prefilter configs with the costmodel, then compile them.

    Args:
        warm_cache_only: When true, record ``kernel_name``, run the compile
            worker and return without building launchers or taking the lock.
        reload_kernel: Optional callable stored on ``self._reload_kernel``
            before compiling (only on the non-warm path).
    """
    costmodel_args, costmodel_kwargs = self._resolve_costmodel_runtime_inputs()
    self._apply_costmodel_to_configs(*costmodel_args, **costmodel_kwargs)

    if not warm_cache_only:
        with self.lock:
            if reload_kernel is not None:
                self._reload_kernel = reload_kernel
            self._precompile_worker()
            self._make_launchers()
        return

    self.kernel_name = self.get_fn_name()
    self._precompile_worker()
def _make_ttir_module_from_cfg(self, cfg):
    """Lower one autotuner config to a TTIR module only.

    Stops after the frontend: no backend lowering and no launch.

    Args:
        cfg: Autotuner config providing ``kwargs``, ``num_warps`` and ``num_stages``.

    Returns:
        The module returned by ``ast_to_ttir``.

    Raises:
        RuntimeError: ``ASTSource`` is unavailable (triton too old).
    """
    compile_meta = copy.deepcopy(self.triton_meta)
    cfg_kwargs = cfg.kwargs
    constants = compile_meta["constants"]

    # Only config kwargs that are constexpr parameters become constants.
    constexpr_names = {self.fn.arg_names[idx] for idx in self.fn.constexprs}
    for key, value in cfg_kwargs.items():
        if key in constexpr_names:
            constants[key] = value

    # num_warps/num_stages declared as constexpr but not yet bound are
    # taken from the config's attributes.
    for idx in self.fn.constexprs:
        arg_name = self.fn.arg_names[idx]
        if arg_name in constants:
            continue
        if arg_name == "num_warps" or arg_name == "num_stages":
            constants[arg_name] = getattr(cfg, arg_name)

    compile_meta["num_warps"] = cfg.num_warps
    compile_meta["num_stages"] = cfg.num_stages
    debug_env_on = os.getenv("INDUCTOR_ASCEND_DEBUG", 'false').lower() in ('true', '1')
    compile_meta["debug"] = (
        debug_env_on
        and self.inductor_meta.get("assert_indirect_indexing", True)
        and not self.inductor_meta.get("is_hip", False)
    )
    compile_meta['compile_mode'] = cfg_kwargs.get('compile_mode')
    compile_meta["device_type"] = self.device_props.type
    compile_meta["cc"] = self.device_props.cc

    if not ASTSource:
        raise RuntimeError(
            "Installed triton version too old, please upgrade")

    src_attrs = None
    if compile_meta.get("configs", None):
        src_attrs = compile_meta["configs"][0]
    src = ASTSource(self.fn, compile_meta["signature"], constants, src_attrs)

    cc_warp_size = 32
    target = GPUTarget(compile_meta["device_type"], compile_meta["cc"], cc_warp_size)

    options = {
        "num_warps": compile_meta["num_warps"],
        "num_stages": compile_meta["num_stages"],
        "debug": compile_meta["debug"],
        "multibuffer": cfg_kwargs.get('multibuffer', False),
        "compile_mode": compile_meta['compile_mode'],
        "enable_vf_fusion": cfg_kwargs.get('enable_vf_fusion', False),
    }
    if compile_meta['compile_mode'] == NPUKernelType.SIMT_ONLY.compile_mode():
        options['simt_stack_limit'] = npu_config.simt_default_warp_stacksize

    from triton.compiler.code_generator import ast_to_ttir
    from triton.compiler.compiler import make_backend
    from triton._C.libtriton import ir
    from triton._C.libtriton.ascend import ir as ascend_ir

    backend = make_backend(target)
    extra_options = src.parse_options()
    # ``options`` is never empty here, so it is merged directly.
    options = backend.parse_options(dict(options, **extra_options))

    context = ir.context()
    ir.load_dialects(context)
    ascend_ir.load_dialects(context)
    return ast_to_ttir(self.fn, src, context, options, {}, {})
def _build_costmodel_runtime_args_from_size_hints(self):
    """Best-effort stand-in runtime args for when ``precompile`` runs before ``run()``.

    One value is produced per signature name: names ending in ``_numel`` get
    the integer size hint for that axis (``1`` when none can be parsed);
    every other name gets an opaque ``object()`` placeholder.

    Returns:
        A tuple of values in signature order, or ``()`` when there are no
        size hints or no signature.
    """
    hints = getattr(self, "size_hints", None)
    if hints is None:
        return ()

    lookup = {}
    if isinstance(hints, dict):
        lookup = dict(hints)
    elif isinstance(hints, (list, tuple)):
        meta = getattr(self, "inductor_meta", None)
        axes = list(meta.get("axis_names", []) or []) if isinstance(meta, dict) else []
        # Positional hints are only usable when they line up 1:1 with the axes.
        if axes and len(axes) == len(hints):
            for axis, hint in zip(axes, hints):
                lookup[axis] = hint
                lookup[f"{axis}_numel"] = hint

    names = list(self.triton_meta.get("signature", {}).keys())
    if not names:
        return ()

    def placeholder_for(name):
        key = str(name)
        if not key.endswith("_numel"):
            return object()
        hint = lookup.get(key)
        if hint is None:
            # Fall back to the bare axis name (key without "_numel").
            hint = lookup.get(key[:-6])
        parsed = self._try_parse_int_like(hint)
        return 1 if parsed is None else parsed

    return tuple(placeholder_for(name) for name in names)
def _resolve_costmodel_runtime_inputs(self):
    """Return the ``(args, kwargs)`` pair to feed the costmodel.

    Recorded runtime inputs win; otherwise positional args are synthesized
    from the size hints, and ``((), {})`` is returned when neither exists.
    """
    recorded_args = getattr(self, "_costmodel_runtime_args", ())
    recorded_kwargs = getattr(self, "_costmodel_runtime_kwargs", {})
    if not (recorded_args or recorded_kwargs):
        synthesized = self._build_costmodel_runtime_args_from_size_hints()
        return (synthesized, {}) if synthesized else ((), {})
    return recorded_args, recorded_kwargs
def _try_parse_int_like(self, value):
    """Convert ``value`` to an int, or return ``None`` if it cannot be converted.

    Booleans are normalized to ``0``/``1``; other ints are returned as-is;
    everything else goes through ``int(value)``.
    """
    if isinstance(value, int):
        # bool is an int subclass, so it is handled in this branch too.
        return int(value) if isinstance(value, bool) else value
    try:
        return int(value)
    except Exception:
        return None
def _build_ttir_arg_value_map(self, ttir_text, runtime_args, runtime_kwargs=None):
    """Map TTIR argument ids to runtime values.

    Values are bound by frontend signature name first (positional args in
    signature order, then kwargs overriding them); a TTIR argument without a
    named value falls back to the positional arg at the same index.

    Returns:
        ``{arg_id: value}``; empty when the ``tt.func public`` header or its
        ``%argN`` parameters cannot be found in ``ttir_text``.
    """
    runtime_kwargs = {} if runtime_kwargs is None else runtime_kwargs

    header = re.search(r"tt\.func\s+public\s+@\w+\((.*?)\)\s+attributes", ttir_text, re.S)
    if header is None:
        return {}
    arg_ids = [int(num) for num in re.findall(r"%arg(\d+)\s*:", header.group(1))]
    if not arg_ids:
        return {}

    sig_names = list(self.triton_meta.get("signature", {}).keys())
    # zip stops at the shorter sequence, i.e. binds only supplied positionals.
    named_values = dict(zip(sig_names, runtime_args))
    for name, value in runtime_kwargs.items():
        if name in named_values or name in sig_names:
            named_values[name] = value

    positional_count = len(runtime_args)
    bound = {}
    for pos, arg_id in enumerate(arg_ids):
        if pos < len(sig_names) and sig_names[pos] in named_values:
            bound[arg_id] = named_values[sig_names[pos]]
        elif pos < positional_count:
            bound[arg_id] = runtime_args[pos]
    return bound
def _build_costmodel_arg_bindings(self, ttir_text, runtime_args, runtime_kwargs=None):
    """Render the costmodel arg-bindings string for one TTIR module.

    The result is a comma-separated list of ``argN=<int>`` entries (values
    that are not int-like are skipped), followed by ``pid_x=0`` and, when
    the TTIR reads ``tt.get_num_programs x``, ``num_programs_x=<int>`` taken
    from ``runtime_kwargs["num_programs_x"]``, else from
    ``runtime_kwargs["grid"][0]``, else ``1``.

    Returns:
        The bindings string, or ``""`` when no argument could be mapped or
        any error occurs (best effort).
    """
    try:
        values_by_id = self._build_ttir_arg_value_map(ttir_text, runtime_args, runtime_kwargs)
        if not values_by_id:
            return ""

        parts = []
        for arg_id in sorted(values_by_id.keys()):
            as_int = self._try_parse_int_like(values_by_id[arg_id])
            if as_int is not None:
                parts.append(f"arg{arg_id}={as_int}")
        parts.append("pid_x=0")

        if "tt.get_num_programs x" in ttir_text:
            programs_x = None
            if runtime_kwargs:
                if "num_programs_x" in runtime_kwargs:
                    programs_x = self._try_parse_int_like(runtime_kwargs.get("num_programs_x"))
                elif "grid" in runtime_kwargs:
                    launch_grid = runtime_kwargs.get("grid")
                    if isinstance(launch_grid, (tuple, list)) and len(launch_grid) > 0:
                        programs_x = self._try_parse_int_like(launch_grid[0])
            if programs_x is None:
                programs_x = 1
            parts.append(f"num_programs_x={programs_x}")

        return ",".join(parts)
    except Exception:
        return ""
def _build_costmodel_items(self, runtime_args, runtime_kwargs=None):
    """Build one costmodel payload per config in ``self.configs``.

    Each payload is a dict with keys ``"config"``, ``"ttir"`` (TTIR text)
    and ``"arg_bindings"``. A config whose TTIR cannot be produced still
    gets a payload, with empty ``"ttir"`` and ``"arg_bindings"``.
    """

    def payload_for(cfg):
        try:
            ttir_text = str(self._make_ttir_module_from_cfg(cfg))
            bindings = self._build_costmodel_arg_bindings(ttir_text, runtime_args, runtime_kwargs)
        except Exception:
            return {"config": cfg, "ttir": "", "arg_bindings": ""}
        return {"config": cfg, "ttir": ttir_text, "arg_bindings": bindings}

    return [payload_for(cfg) for cfg in self.configs]
def _triton_make_ttir(self):
    """Lower the first config to TTIR and store it in ``self.compile_results``.

    Only ``self.configs[0]`` is attempted.

    Raises:
        NoTritonConfigsError: There are no configs, or lowering the first
            config failed (the message carries the error and its traceback).
    """
    if not self.configs:
        raise NoTritonConfigsError("No triton configs are available")

    def lower_to_ttir(cfg):
        """Ahead of time compile a given autotuner config."""
        try:
            module = self._make_ttir_module_from_cfg(cfg)
            log.debug(
                "Triton precompile success: %s\n%s\nmodule: %s",
                self.inductor_meta.get("kernel_name", "triton_"),
                self.fn.src,
                str(module),
            )
        except Exception as e:
            log.debug(
                "Triton precompile failed: %s\n%s",
                self.inductor_meta.get("kernel_name", "triton_"),
                self.fn.src,
            )
            raise e
        return module

    first_cfg = self.configs[0]
    lowered = []
    failure = None
    failure_stack = ""
    try:
        lowered.append(lower_to_ttir(first_cfg))
    except Exception as e:
        import traceback
        failure_stack = traceback.format_exc()
        failure = e

    if not lowered:
        raise NoTritonConfigsError(
            f"Failed to compile config {first_cfg} for make ttir {self.fn.__name__}: {failure} \nStack trace:{failure_stack}"
        )
    self.compile_results = lowered
def _apply_costmodel_to_configs(self, *args, **kwargs):
    """Prefilter ``self.configs`` with the triton-ascend costmodel before compiling.

    Configs are ranked by predicted cost (entries predicted as ``inf`` are
    dropped) and the best ``ceil(len(configs) * costmodel_ratio)`` are kept
    in ``self.configs``. The remaining ranked configs are stored in
    ``self._costmodel_fallback_configs`` (``None`` when there are none).

    Nothing changes when there is at most one config, the costmodel backend
    is disabled, the ratio is outside ``(0, 1)``, the costmodel is
    unavailable or fails, or it returns no usable ranking.
    """
    self._costmodel_fallback_configs = None

    if not self.configs or len(self.configs) <= 1:
        return
    if not bool(npu_config.enable_costmodel_backend):
        return
    ratio = float(getattr(npu_config, "costmodel_ratio", 0.25))
    if ratio <= 0.0 or ratio >= 1.0:
        return

    try:
        from triton.backends.ascend.runtime.costmodel_runtime import costmodel_bench
        payloads = self._build_costmodel_items(args, kwargs)
        bench_start = time.perf_counter()
        predicted = costmodel_bench(payloads)
        bench_elapsed = time.perf_counter() - bench_start
        log.debug("costmodel_bench elapsed time: %.6fs", bench_elapsed)
        log.debug("costmodel_bench result: %s", predicted)
    except Exception as exc:
        log.warning("Skip costmodel prefilter because costmodel path is unavailable: %s", exc)
        return

    if not isinstance(predicted, dict):
        return

    by_cost = sorted(predicted.items(), key=lambda kv: kv[1])
    ranked = [cfg for cfg, cost in by_cost if cost != float("inf")]
    wanted = max(1, math.ceil(len(self.configs) * ratio))
    if len(ranked) < wanted:
        log.warning("Config count filtered by costmodel is less than target count")
    if not ranked:
        return

    keep = min(wanted, len(ranked))
    self.configs = ranked[:keep]
    self._costmodel_fallback_configs = ranked[keep:] or None
def _precompile_configs(self, configs):
    """Compile ``configs`` one after another and return the successful results.

    Per-config failures are tolerated as long as at least one config compiles.

    Raises:
        NoTritonConfigsError: ``configs`` is empty, or every config failed
            (the message reports the last error and its traceback).
    """
    if not configs:
        raise NoTritonConfigsError("No triton configs are available")

    compiled = []
    last_exc, last_stack = None, ""
    started = time.perf_counter()
    for candidate in configs:
        try:
            result = self._precompile_config(candidate)
        except Exception as e:
            import traceback
            last_stack = traceback.format_exc()
            last_exc = e
        else:
            compiled.append(result)

    if not compiled:
        raise NoTritonConfigsError(
            f"No valid triton configs. {type(last_exc).__name__}: {last_exc} \nStack trace:{last_stack}"
        )
    log.info(f"kernel: {self.get_fn_name()} compile cost time: {time.perf_counter() - started}s")
    return compiled
def _precompile_with_costmodel_fallback(self, compile_fn):
    """Compile the selected configs, retrying with the costmodel leftovers.

    ``compile_fn`` is called with ``self.configs``. If it raises
    ``NoTritonConfigsError`` and fallback configs were recorded by the
    costmodel prefilter, it is called once more with those.

    On success ``self.compile_results`` is set and both ``self.configs`` and
    ``self._costmodel_fallback_configs`` are reset to ``None``.

    Raises:
        NoTritonConfigsError: No fallback exists (original error re-raised),
            or the fallback failed too (combined message).
    """
    selected = self.configs
    try:
        results = compile_fn(selected)
    except NoTritonConfigsError as primary_exc:
        leftovers = getattr(self, "_costmodel_fallback_configs", None)
        if not leftovers:
            raise
        log.warning(
            "No valid triton configs from costmodel-selected configs for kernel %s; "
            "retrying %d costmodel-filtered configs.",
            self.get_fn_name(),
            len(leftovers),
        )
        self.configs = leftovers
        try:
            results = compile_fn(leftovers)
        except NoTritonConfigsError as fallback_exc:
            raise NoTritonConfigsError(
                f"No valid triton configs from costmodel-selected or fallback configs. "
                f"Primary error: {primary_exc}. Fallback error: {fallback_exc}"
            ) from fallback_exc

    self.compile_results = results
    self.configs = None
    self._costmodel_fallback_configs = None
def _precompile_worker(self):
    """Compile all configs serially unless results already exist.

    When ``self.compile_results`` is already populated, each result's kernel
    hash is registered with ``TritonBundler`` and nothing is recompiled.

    Raises:
        AssertionError: Launchers already exist although nothing is compiled.
    """
    if self.compile_results:
        device = self.triton_meta.get("device", 0)
        for compiled in self.compile_results:
            TritonBundler.put(triton_hash_to_path_key(compiled.kernel.hash), device)
        return
    if self.launchers:
        raise AssertionError("Before _precompile_worker, launchers must bt empty")
    self._precompile_with_costmodel_fallback(self._precompile_configs)
def parse_triton_ascend_options(self, tiling_kwargs, options):
    """Copy every ``NPUOptions`` field present in ``tiling_kwargs`` into ``options``.

    ``options`` is updated in place and also returned.
    """
    from triton.backends.ascend.compiler import NPUOptions

    for field_name in NPUOptions.__dataclass_fields__.keys():
        if field_name in tiling_kwargs:
            options[field_name] = tiling_kwargs[field_name]
    return options
def _precompile_config(self, cfg: Config) -> TritonCompileResultNpu:
    """Ahead-of-time compile one autotuner config.

    On success the required UB size reported by the binary's metadata is
    stored on ``cfg.real_ub_size``. On failure, a "ub overflow ... requires
    N bits" message in the traceback is parsed into ``cfg.real_ub_size``
    before the error is re-raised.

    Returns:
        A ``TritonCompileResultNpu`` wrapping the compiled binary.

    Raises:
        RuntimeError: ``ASTSource`` is unavailable (triton too old).
        Exception: Compilation failed.
    """
    compile_meta = copy.deepcopy(self.triton_meta)
    cfg_kwargs = cfg.kwargs
    constants = compile_meta["constants"]

    # Only config kwargs that are constexpr parameters become constants.
    constexpr_names = {self.fn.arg_names[idx] for idx in self.fn.constexprs}
    for key, value in cfg_kwargs.items():
        if key in constexpr_names:
            constants[key] = value

    # num_warps/num_stages declared as constexpr but not yet bound are
    # taken from the config's attributes.
    for idx in self.fn.constexprs:
        arg_name = self.fn.arg_names[idx]
        if arg_name in constants:
            continue
        if arg_name == "num_warps" or arg_name == "num_stages":
            constants[arg_name] = getattr(cfg, arg_name)

    compile_meta["num_warps"] = cfg.num_warps
    compile_meta["num_stages"] = cfg.num_stages
    debug_env_on = os.getenv("INDUCTOR_ASCEND_DEBUG", 'false').lower() in ('true', '1')
    compile_meta["debug"] = (
        debug_env_on
        and self.inductor_meta.get("assert_indirect_indexing", True)
        and not self.inductor_meta.get("is_hip", False)
    )
    compile_meta['compile_mode'] = cfg_kwargs.get('compile_mode')
    compile_meta["device_type"] = self.device_props.type
    compile_meta["cc"] = self.device_props.cc

    if not ASTSource:
        raise RuntimeError("Installed triton version too old, please upgrade")

    # The attrs entry is passed only when compile_meta carries configs.
    source_args = [self.fn, compile_meta["signature"], constants]
    if compile_meta.get("configs", None):
        source_args.append(compile_meta["configs"][0])
    source = ASTSource(*source_args)

    cc_warp_size = 32
    target = GPUTarget(compile_meta["device_type"], compile_meta["cc"], cc_warp_size)

    options = {
        "num_warps": compile_meta["num_warps"],
        "num_stages": compile_meta["num_stages"],
        "debug": compile_meta["debug"],
        "compile_mode": compile_meta['compile_mode'],
    }
    options = self.parse_triton_ascend_options(cfg_kwargs, options)
    if compile_meta['compile_mode'] == NPUKernelType.SIMT_ONLY.compile_mode():
        options['simt_stack_limit'] = npu_config.simt_default_warp_stacksize

    # Timing is only sampled when it will actually be logged.
    start_time = time.perf_counter() if log.isEnabledFor(logging.DEBUG) else 0

    try:
        binary = None
        binary = triton.compile(source, target=target, options=options)
        cfg.real_ub_size = binary.metadata.required_ub_bits
    except AttributeError as e:
        log.debug(f"config: {cfg_kwargs} compile failed, cost time: {time.perf_counter() - start_time}s")
        if binary is None:
            raise Exception("Triton compilation failed") from e
        # A binary exists; the AttributeError came after compilation, so the
        # result is still returned.
        return TritonCompileResultNpu(binary, cfg, compile_meta, self.inductor_meta)
    except Exception:
        log.debug(
            "Triton compilation failed: %s\n%s\nmetadata: %s",
            self.inductor_meta.get("kernel_name", "triton_"),
            self.fn.src,
            compile_meta,
        )
        log.debug(f"config: {cfg_kwargs} compile failed, cost time: {time.perf_counter() - start_time}s")
        import traceback
        trace_text = traceback.format_exc()
        overflow = re.search(r"ub overflow.*?requires (\d+) bits", trace_text)
        if overflow:
            cfg.real_ub_size = int(overflow.group(1))
        raise

    log.debug(f"config: {cfg_kwargs} compile success, cost time: {time.perf_counter() - start_time}s")
    return TritonCompileResultNpu(binary, cfg, compile_meta, self.inductor_meta)
def _make_launchers(self):
    """Build one launcher per compile result and store them on ``self.launchers``.

    Does nothing when the launcher count already matches the number of
    compile results. Individual launcher failures are tolerated as long as
    at least one launcher can be built.

    Raises:
        RuntimeError: No launcher could be built (the message reports the
            last error and its traceback).
    """
    if len(self.launchers) == len(self.compile_results):
        return

    from torch._dynamo.device_interface import DeviceGuard

    device_interface = self.get_device_interface()
    with DeviceGuard(device_interface, self.triton_meta["device"]):
        device_interface.synchronize(device_interface.current_device())
        built = []
        last_exc, last_stack = None, ""
        for compiled in self.compile_results:
            try:
                launcher = compiled.make_launcher()
            except Exception as e:
                import traceback
                last_stack = traceback.format_exc()
                last_exc = e
            else:
                built.append(launcher)

    if not built:
        raise RuntimeError(f"No valid triton configs. {type(last_exc).__name__}: {last_exc}\n"
                           f"Stack trace: {last_stack}")
    self.launchers = built
def save_gpu_kernel(self, stream, launcher):
    """Delegate to ``save_npu_kernel`` with the same stream and launcher."""
    self.save_npu_kernel(stream, launcher)
def save_npu_kernel(self, input_stream, input_launcher):
    """Record the launcher's binary and launch parameters in ``CudaKernelParamCache``.

    The entry is keyed by ``inductor_meta["kernel_name"]``; the stored binary
    is ``launcher.bin.asm["npubin"]``, registered under bin type ``'cubin'``.
    Sets ``self.cuda_kernel_saved`` afterwards.

    Raises:
        RuntimeError: ``kernel_name`` is missing from ``inductor_meta``.
    """
    key = self.inductor_meta.get("kernel_name", None)
    if key is None:
        raise RuntimeError("assert key is not None, kernel_name can not be None")

    kernel_bin = input_launcher.bin
    metadata = kernel_bin.metadata

    # Each value prefers the attribute on the binary / metadata object and
    # falls back to the alternative location when it is absent.
    mangled_name = metadata.name if hasattr(metadata, "name") else metadata["name"]
    num_warps = (
        kernel_bin.num_warps if hasattr(kernel_bin, "num_warps") else metadata.num_warps
    )
    shared_mem = kernel_bin.shared if hasattr(kernel_bin, "shared") else metadata.shared

    params = {
        "mangled_name": mangled_name,
        "num_warps": num_warps,
        "shared_mem": shared_mem,
        "stream": input_stream,
        "meta": input_launcher.config.kwargs,
        "config": config_to_dict(input_launcher.config),
        "inductor_meta": self.inductor_meta,
        "triton_meta": self.triton_meta,
        "def_args": input_launcher.def_args,
        "call_args": input_launcher.call_args,
        "mix_mode": metadata.mix_mode,
        "parallel_mode": metadata.parallel_mode,
        "force_simt_only": metadata.force_simt_only
    }

    enable_simt = ("simt" in params["parallel_mode"]) or params["force_simt_only"]
    if npu_config.is_ascend950 and enable_simt:
        params["shared_mem_dynamic_size"] = input_launcher.bin.metadata.shared_mem_dynamic_size

    from torch._inductor.codecache import CudaKernelParamCache

    bin_type = "npubin"
    binary = input_launcher.bin.asm[bin_type]
    CudaKernelParamCache.set(key, params, binary, bin_type='cubin')
    self.cuda_kernel_saved = True
def _precompile_configs_parallel(self, configs):
    """Compile ``configs`` concurrently on the shared compile thread pool.

    If no config compiles, ``enable_vf_fusion`` is set to ``True`` in every
    config's ``kwargs`` and all configs are compiled once more.

    Returns:
        The successful compile results, in completion order.

    Raises:
        NoTritonConfigsError: ``configs`` is empty, or nothing compiled in
            either round (the message reports the error recorded for the
            first config).
    """
    if not configs:
        raise NoTritonConfigsError("No triton configs are available")

    # Per-config slots, written by the worker threads on failure.
    errors = [None] * len(configs)
    stacks = [""] * len(configs)

    def compile_one(slot, kernel_config):
        try:
            return self._precompile_config(kernel_config)
        except Exception as e:
            import traceback
            stacks[slot] = traceback.format_exc()
            errors[slot] = e
            return None

    def submit_all():
        return [
            compile_thread_pool.submit(compile_one, slot, kernel_config)
            for slot, kernel_config in enumerate(configs)
        ]

    first_round = submit_all()

    from torch._dynamo.device_interface import DeviceGuard

    device_interface = self.get_device_interface()
    compile_results = []

    def collect(futures):
        with DeviceGuard(device_interface, self.triton_meta["device"]):
            device_interface.synchronize(device_interface.current_device())
            for future in as_completed(futures):
                compiled_kernel = future.result()
                if compiled_kernel is not None:
                    compile_results.append(compiled_kernel)

    collect(first_round)

    if not compile_results:
        # Second attempt: same configs with vf fusion switched on.
        for kernel_config in configs:
            kernel_config.kwargs["enable_vf_fusion"] = True
        collect(submit_all())

    if not compile_results:
        raise NoTritonConfigsError(
            f"No valid triton configs for kernel {self.get_fn_name()}. {type(errors[0]).__name__}: {errors[0]} \nStack trace:{stacks[0]}"
        )
    return compile_results
def _precompile_worker_parallel(self):
if self.compile_results:
for result in self.compile_results:
TritonBundler.put(
triton_hash_to_path_key(result.kernel.hash),
self.triton_meta.get("device", 0),
)
return
if self.launchers:
raise AssertionError("Before _precompile_worker, launchers must bt empty")
self._precompile_with_costmodel_fallback(self._precompile_configs_parallel)
def bench(self, launcher, *args, with_profiler=False, **kwargs):
    """Measure the performance of a given launcher.

    Args:
        launcher: Compiled launcher; invoked with cloned arguments and the
            current raw stream.
        *args: Positional kernel arguments.
        with_profiler: Accepted for interface compatibility; this method
            does not read it. The profiling benchmark is selected through
            the ``profile_bandwidth_with_do_bench_using_profiling`` key of
            ``inductor_meta`` instead.
        **kwargs: Keyword kernel arguments.

    Returns:
        ``float("inf")`` for a non-custom kernel whose spill count exceeds
        ``spill_threshold`` (default 16); otherwise the value returned by
        the selected benchmark helper.
    """
    if not self.custom_kernel:
        n_spills = launcher.n_spills
        # The grouped benchmark paths in this file treat n_spills as
        # optional; comparing None with an int would raise TypeError, so a
        # missing spill count is treated as "within the threshold".
        if n_spills is not None and n_spills > self.inductor_meta.get(
            "spill_threshold", 16
        ):
            return float("inf")

    device_interface = self.get_device_interface()
    stream = device_interface.get_raw_stream(device_interface.current_device())

    def kernel_call():
        # Benchmark on clones so the caller's buffers are not clobbered.
        cloned_args, cloned_kwargs = self.clone_args(*args, **kwargs)
        self.reset_to_zero_args(*args, **kwargs)
        launcher(
            *cloned_args,
            **cloned_kwargs,
            stream=stream,
        )

    if self.inductor_meta.get("profile_bandwidth_with_do_bench_using_profiling", False):
        return do_bench_using_profiling_npu(kernel_call, rep=1)
    return benchmarker.benchmark_gpu(kernel_call, rep=1)
def _should_skip_autotune_for_determinism(self):
"""
When deterministic algorithms are enabled, skip benchmarking for
reduction kernels and use the first successfully compiled config.
All configs are still compiled so that compilation failures on NPU
do not crash the program.
Pointwise kernels are unaffected because their tiling does not change
floating-point numerics.
"""
return (
self.inductor_meta.get("are_deterministic_algorithms_enabled")
and self.heuristic_type in (
HeuristicType.REDUCTION,
HeuristicType.PERSISTENT_REDUCTION,
)
)
def autotune_to_one_config(self, *args, **kwargs):
    """Benchmark all configs and keep only the fastest launcher.

    ``benchmark_all_configs`` supplies a launcher-to-timing mapping; the
    launcher with the smallest timing becomes the sole entry of
    ``self.launchers``. The total autotune time (precompile plus benchmark,
    in nanoseconds) is stored and, when a cache hook is set, reported to it
    together with the winning config.
    """
    bench_begin_ns = time.time_ns()
    timings = self.benchmark_all_configs(*args, **kwargs)
    bench_elapsed_ns = time.time_ns() - bench_begin_ns

    fastest = builtins.min(timings, key=timings.get)
    self.launchers = [fastest]
    self.autotune_time_taken_ns = self.precompile_time_taken_ns + bench_elapsed_ns

    if self.save_cache_hook:
        self.save_cache_hook(self.launchers[0].config, self.autotune_time_taken_ns)
def get_fx_graph_dump_path(self):
    """Return the fx-graph dump directory for this kernel, or ``None``.

    The path is ``<traced_graph_dir>/<traced_graph_hash>`` built from
    ``inductor_meta``. ``None`` is returned when the hash is missing, the
    directory setting is empty, or the path does not exist.

    The result is computed once per instance and stored on the instance.
    A per-instance memo is used instead of ``functools.lru_cache`` because a
    cache on the method would hold a strong reference to every autotuner.
    """
    try:
        return self._fx_graph_dump_path_memo
    except AttributeError:
        pass

    resolved = None
    traced_graph_hash = self.inductor_meta.get("traced_graph_hash")
    if traced_graph_hash:
        dump_dir = self.inductor_meta.get("traced_graph_dir", "")
        dump_path = os.path.join(dump_dir, traced_graph_hash)
        if dump_dir != "" and os.path.exists(dump_path):
            resolved = dump_path

    self._fx_graph_dump_path_memo = resolved
    return resolved
def data_dump(self, *args, dump_path=None):
    """Save the positional arguments to ``data.pth`` in the dump directory.

    Args:
        *args: Objects to serialize with ``torch.save`` (as one tuple).
        dump_path: Target directory; when ``None`` the directory from
            ``get_fx_graph_dump_path`` is used.

    Returns:
        ``True`` after saving, ``False`` (with a warning logged) when no
        dump directory is available.
    """
    target_dir = dump_path
    if target_dir is None:
        target_dir = self.get_fx_graph_dump_path()

    if target_dir is None:
        log.warning(f"data dump for kernel {self.get_fn_name()} failed, no valid dump_path is supplied.")
        return False

    torch.save(args, os.path.join(target_dir, 'data.pth'))
    return True
def get_fn_name(self):
    """Return the kernel's display name, resolving and caching it on first use.

    A previously stored ``self.fn_name`` is returned as is. Otherwise the
    name comes from ``self.fn.fn.__name__`` (``"unknown"`` if that attribute
    chain is missing) and is then replaced by ``self.kernel_name`` when the
    instance has that attribute.
    """
    if self.fn_name is None:
        try:
            resolved = self.fn.fn.__name__
        except AttributeError:
            resolved = "unknown"
        self.fn_name = resolved
        # An explicitly assigned kernel name wins over the function name.
        if hasattr(self, 'kernel_name'):
            self.fn_name = self.kernel_name
    return self.fn_name
def is_run_debug(self):
    """Tell whether any debug mode (fx-graph dump or accuracy check) is on.

    The value of ``npu_config.dump_fx_graph or npu_config.check_accuracy``
    is evaluated on the first call and stored on the instance; later calls
    return the stored value. A per-instance memo replaces
    ``functools.lru_cache`` so the cache does not keep every autotuner
    instance alive.
    """
    try:
        return self._is_run_debug_memo
    except AttributeError:
        self._is_run_debug_memo = npu_config.dump_fx_graph or npu_config.check_accuracy
        return self._is_run_debug_memo
def maybe_run_debug(self, *args, grid_, stream, launcher, **kwargs):
    """Run whichever debug modes are enabled in ``npu_config`` for this launch.

    With ``dump_fx_graph`` set, the positional arguments are dumped through
    ``data_dump`` (after rejecting an actively capturing stream). With
    ``check_accuracy`` set, ``check_accuracy_triton`` is invoked.

    Returns:
        ``"check_accuracy"`` when the accuracy check returned a truthy
        value, otherwise ``None``.

    Raises:
        RuntimeError: If ``dump_fx_graph`` is set while the current stream
            is capturing.
    """
    name = self.get_fn_name()
    log.info(f"Try to run debug mode for kernel {name}.")

    if npu_config.dump_fx_graph:
        if torch_npu.npu.is_current_stream_capturing():
            raise RuntimeError(
                "INDUCTOR_ASCEND_CHECK_ACCURACY / INDUCTOR_ASCEND_DUMP_FX_GRAPH "
                "is not compatible with aclgraph.\n"
                "Please disable aclgraph before enabling INDUCTOR_ASCEND_CHECK_ACCURACY "
                "/ INDUCTOR_ASCEND_DUMP_FX_GRAPH, or unset these environment variables."
            )
        # The dump's success flag is intentionally not acted upon.
        self.data_dump(*args)

    accuracy_handled = npu_config.check_accuracy and check_accuracy_triton(
        *args,
        launcher=launcher,
        grid=grid_,
        stream=stream,
        inductor_meta=self.inductor_meta,
        **kwargs
    )
    if accuracy_handled:
        return "check_accuracy"

    log.info(f"No debug mode is activated for kernel {name}.")
    return None
def run(self, *args, stream, benchmark_run=False, **kwargs):
    """Launch the kernel, autotuning first when more than one launcher remains.

    Order of operations:
      1. In interpreter mode, call ``self.fn[grid]`` directly with the
         first config's kwargs (minus ``split_axis`` / ``split_blocks``).
      2. A single launcher carrying a ``fallback`` attribute is called as is.
      3. Otherwise ``autotuner`` narrows the launchers to one, optionally
         followed by coordinate-descent tuning.
      4. The kernel binary is saved and launch params dumped when requested.
      5. Debug modes may take over the launch; in that case ``None`` is
         returned without launching.
      6. The launcher is invoked, wrapped in a profiler record when the
         autograd profiler is enabled.
    """
    if self.triton_interpret:
        first_config = self.configs[0]
        args, grid = self._interpret_args_grid(args, first_config)
        meta_kwargs = copy.copy(first_config.kwargs)
        for dropped_key in ('split_axis', 'split_blocks'):
            meta_kwargs.pop(dropped_key, None)
        return self.fn[grid](*args, **kwargs, **meta_kwargs)

    if len(self.launchers) == 1 and hasattr(self.launchers[0], "fallback"):
        fallback_launcher = self.launchers[0]
        return fallback_launcher(*args, **kwargs)

    self.autotuner(*args, stream=stream, benchmark_run=benchmark_run, **kwargs)

    already_coordesc = getattr(self.launchers[0].config, "found_by_coordesc", False)
    if not already_coordesc and self.inductor_meta.get("coordinate_descent_tuning", False):
        tuned = self.coordinate_descent_tuning(self.launchers[0], *args, **kwargs)
        self.launchers = [tuned]

    (launcher,) = self.launchers

    if launcher.store_cubin and (not benchmark_run or not self.cuda_kernel_saved):
        self.save_gpu_kernel(stream, launcher)

    if self.dump_launch_params:
        _dump_launch_params(args, kwargs, launcher, self.fn.__name__)

    if self.is_run_debug() and self.heuristic_type != HeuristicType.USER_AUTOTUNE:
        _, grid = self._interpret_args_grid(args, launcher.config)
        debug_mode = self.maybe_run_debug(
            *args, grid_=grid, stream=stream, launcher=launcher, **kwargs
        )
        if debug_mode:
            log.info(f"Kernel {self.get_fn_name()} goes into {debug_mode} and return.")
            return

    if not autograd_profiler._is_profiler_enabled:
        return launcher(*args, **kwargs, stream=stream)

    record_meta = {
        "kernel_file": (self.filename or ""),
        "kernel_hash": self.kernel_hash,
        "kernel_backend": "triton",
        "stream": stream,
    }
    with torch._C._profiler._RecordFunctionFast(
        self.inductor_meta.get("kernel_name", "triton kernel"),
        args,
        record_meta,
    ):
        return launcher(*args, **kwargs, stream=stream)
def autotuner(self, *args, stream, benchmark_run=False, **kwargs):
    """Reduce ``self.launchers`` to a single launcher if it is not already.

    With no launchers, the runtime args/kwargs are stashed for the cost
    model and ``precompile`` is run (timed in nanoseconds). With several
    launchers, either the first one is kept (determinism mode, reported to
    the cache hook with a time of ``0``) or ``autotune_to_one_config``
    benchmarks them. The elapsed wall time is logged.
    """
    if len(self.launchers) == 1:
        return

    autotune_start_time = time.perf_counter()

    if len(self.launchers) == 0:
        self._costmodel_runtime_args = args
        self._costmodel_runtime_kwargs = kwargs
        precompile_begin_ns = time.time_ns()
        self.precompile()
        self.precompile_time_taken_ns = time.time_ns() - precompile_begin_ns

    if len(self.launchers) > 1:
        if not self._should_skip_autotune_for_determinism():
            self.autotune_to_one_config(*args, **kwargs)
        else:
            self.launchers = [self.launchers[0]]
            if self.save_cache_hook:
                self.save_cache_hook(self.launchers[0].config, 0)

    log.info(f"{self.get_fn_name()} benchmark elapsed time {time.perf_counter() - autotune_start_time}s")
def _interpret_args_grid(
    self, args: tuple[Any, ...], cfg: Config, runtime_blocks=None
) -> tuple[tuple[Any, ...], tuple[int, int, int]]:
    """Evaluate the launch grid for ``cfg`` and strip extra launcher args.

    The grid expression is built by ``GridExprNpu`` (grouped or regular
    constructor, depending on ``inductor_meta["group_enabled"]``) and
    evaluated against a name-to-value mapping of the kernel arguments. In
    grouped mode the runtime block names and values are appended to that
    mapping.

    Args:
        args: Positional launch arguments, including any extra launcher
            arguments at the end.
        cfg: Config whose grid is evaluated.
        runtime_blocks: Optional runtime block values (grouped mode only).

    Returns:
        ``(args_without_extra_launcher_args, grid)``.
    """
    numels = [name for name in self.fn.arg_names if "_numel" in name]
    runtime_blocks = runtime_blocks or ()
    meta = self.inductor_meta

    if meta.get("group_enabled", False):
        runtime_block_names = tuple(meta.get("runtime_block_arg_names", ()))
        grid_expr = GridExprNpu.from_grouped_meta_and_numel(
            meta,
            cfg,
            numels,
            runtime_block_names=runtime_block_names,
        )
        env_names = [
            *self.fn.arg_names,
            *meta.get("extra_launcher_args", ()),
            *runtime_block_names,
        ]
        env_values = (*args, *runtime_blocks)
    else:
        grid_expr = GridExprNpu.from_meta_and_set_numel(meta, cfg, numels)
        env_names = [
            *self.fn.arg_names,
            *meta.get("extra_launcher_args", ()),
        ]
        env_values = args

    grid = grid_expr.eval_slow(dict(zip(env_names, env_values)))

    if meta.get("extra_launcher_args"):
        # Extra launcher args sit at the tail and are not kernel arguments.
        args = args[: -len(meta["extra_launcher_args"])]
    return args, grid
class NPUSymbolicGroupedAutotuner(NPUCachingAutotuner):
def __init__(
    self,
    fn,
    triton_meta,
    configs,
    save_cache_hook,
    mutated_arg_names: List[str],
    optimize_mem,
    heuristic_type,
    size_hints=None,
    inductor_meta=None,
    custom_kernel=False,
    filename: Optional[str] = None,
    reset_to_zero_arg_names: Optional[List[str]] = None,
):
    """Initialise the base autotuner and the grouped-dispatch bookkeeping.

    All arguments are forwarded positionally to the base class. Afterwards
    the grouped plan is read from
    ``self.inductor_meta["grouped_candidate_plan"]`` and used to size the
    per-variant launcher table and the per-group "best" table.
    """
    super().__init__(
        fn,
        triton_meta,
        configs,
        save_cache_hook,
        mutated_arg_names,
        optimize_mem,
        heuristic_type,
        size_hints,
        inductor_meta,
        custom_kernel,
        filename,
        reset_to_zero_arg_names,
    )
    plan = self.inductor_meta["grouped_candidate_plan"]
    variant_order = plan["variant_order"]
    group_count = plan["group_id_count"]

    self.grouped_plan = plan
    self.variant_order = variant_order
    # variant id -> slot in ``variant_launchers``
    self.variant_id_to_index = dict(zip(variant_order, range(len(variant_order))))
    self.variant_launchers = [None] * len(variant_order)
    # group id -> (launcher, policy_id, variant_id) once tuned
    self.group_best = [None] * group_count
    # Every group id counts as reachable unless the plan narrows it down.
    self.reachable_group_ids = tuple(
        plan.get("reachable_group_ids", range(group_count))
    )
    self._grouped_runtime_args_snapshot = ()
    self._grouped_variant_launchers_initialized = False
def _precompile_variants(self):
    """Compile all variants once and map each plan variant to its launcher.

    If no launchers exist yet, the configs are compiled (through
    ``_precompile_worker_parallel`` when that attribute is available,
    otherwise ``_precompile_worker``), launchers are created, and the
    elapsed time is stored in ``precompile_time_taken_ns``. Each variant in
    ``variant_order`` is then matched to the first launcher whose config
    dict equals the variant's config; unmatched variants stay ``None``.

    Subsequent calls are no-ops.
    """
    if self._grouped_variant_launchers_initialized:
        return

    if len(self.launchers) == 0:
        start_time = time.time_ns()
        self.kernel_name = self.get_fn_name()
        if getattr(self, "_precompile_worker_parallel", None) is not None:
            self._precompile_worker_parallel()
        else:
            self._precompile_worker()
        self._make_launchers()
        self.precompile_time_taken_ns = time.time_ns() - start_time

    # Convert every launcher config once instead of once per variant.
    launcher_config_dicts = [
        (config_to_dict(launcher.config), launcher) for launcher in self.launchers
    ]
    variants = self.grouped_plan["variants"]
    for idx, variant_id in enumerate(self.variant_order):
        variant_config = variants[variant_id]["config"]
        self.variant_launchers[idx] = next(
            (
                launcher
                for config_dict, launcher in launcher_config_dicts
                if config_dict == variant_config
            ),
            None,
        )
    self._grouped_variant_launchers_initialized = True
def _runtime_feature_inputs(self, args) -> tuple[int, ...]:
if "feature_arg_indices" not in self.grouped_plan:
raise RuntimeError("grouped plan is missing feature_arg_indices")
feature_arg_indices = tuple(self.grouped_plan["feature_arg_indices"])
feature_sources = tuple(self.grouped_plan.get("feature_sources", ()))
if feature_sources and len(feature_sources) != len(feature_arg_indices):
raise RuntimeError("feature_sources and feature_arg_indices must match")
runtime_feature_inputs = []
for idxs, feature_source in zip(
feature_arg_indices,
feature_sources or ({},) * len(feature_arg_indices),
):
values = tuple(int(args[idx]) for idx in idxs)
source = feature_source.get("source")
if source in ("outer_product", "reduction_product"):
product = 1
for value in values:
product *= value
runtime_feature_inputs.append(product)
else:
if len(values) != 1:
raise RuntimeError(
f"feature source {source} expects one axis, got {values}"
)
runtime_feature_inputs.append(values[0])
return tuple(runtime_feature_inputs)
def _resolve_group_id(self, runtime_feature_inputs: tuple[int, ...]) -> int:
    """Combine per-feature bucket indices into a single group id.

    Every feature value is bucketized against its spec's ``buckets``; the
    bucket indices are combined positionally, with each feature contributing
    a radix of ``len(buckets) + 1``.

    Raises:
        RuntimeError: If the number of inputs differs from the number of
            ``group_features``, or the id reaches ``group_id_count`` (when
            the plan provides one).
    """
    specs = tuple(self.grouped_plan.get("group_features", ()))
    if len(specs) != len(runtime_feature_inputs):
        raise RuntimeError(
            "runtime_feature_inputs and group_features must have the same length"
        )

    resolved_id = 0
    radix_weight = 1
    for feature_value, feature_spec in zip(runtime_feature_inputs, specs):
        boundaries = tuple(feature_spec["buckets"])
        resolved_id += radix_weight * bucketize(feature_value, boundaries)
        radix_weight *= 1 + len(boundaries)

    limit = self.grouped_plan.get("group_id_count")
    if limit is None or resolved_id < limit:
        return resolved_id
    raise RuntimeError(
        f"group_id {resolved_id} is out of range for group_id_count {limit}"
    )
def _materialize_runtime_blocks(self, candidate, args) -> tuple[int, ...]:
runtime_block_names = tuple(self.grouped_plan.get("runtime_block_append_order", ()))
policy = self.grouped_plan["policies"][candidate["policy_id"]]
runtime_block_rules = tuple(policy.get("runtime_block_rules", ()))
primary_group_axis = self.inductor_meta.get("primary_group_axis")
if primary_group_axis is not None:
primary_block_name = f"{primary_group_axis.upper()}BLOCK"
if len(runtime_block_rules) > 1:
raise RuntimeError(
"grouped autotune V1 only supports one runtime block rule"
)
if runtime_block_rules and runtime_block_rules[0][0] != primary_block_name:
raise RuntimeError(
"grouped autotune V1 runtime block rule must target primary block "
f"{primary_block_name}"
)
static_blocks = dict(policy.get("static_blocks", ()))
resolved_blocks = dict(static_blocks)
axis_arg_indices = dict(self.grouped_plan.get("axis_arg_indices", {}))
grid_target = int(policy["grid_target"])
for block_name, rule_items in runtime_block_rules:
rule = dict(rule_items)
if rule.get("op") != "ceildiv":
raise RuntimeError(
f"runtime block rule op {rule.get('op')} is not supported"
)
axis_name = rule["axis_name"]
if axis_name not in axis_arg_indices:
raise RuntimeError(f"axis_arg_indices is missing axis {axis_name}")
axis_numel = int(args[axis_arg_indices[axis_name]])
resolved_blocks[block_name] = (axis_numel + grid_target - 1) // grid_target
missing_runtime_block_names = [
name for name in runtime_block_names if name not in resolved_blocks
]
if missing_runtime_block_names:
raise RuntimeError(
"runtime blocks are missing values for "
f"{missing_runtime_block_names}"
)
return tuple(resolved_blocks[name] for name in runtime_block_names)
def _build_runtime_launch_args(self, args, runtime_blocks: tuple[int, ...]):
return (*args, *runtime_blocks)
def _benchmark_feature_inputs_for_group(self, group_id: int) -> tuple[int, ...]:
if "benchmark_feature_inputs_by_group" not in self.grouped_plan:
raise RuntimeError(
"grouped plan is missing benchmark_feature_inputs_by_group"
)
benchmark_inputs = self.grouped_plan["benchmark_feature_inputs_by_group"]
if group_id >= len(benchmark_inputs):
raise RuntimeError(
f"group_id {group_id} is out of range for benchmark_feature_inputs_by_group"
)
return tuple(benchmark_inputs[group_id])
def _materialize_benchmark_args(self, group_id: int):
if "benchmark_axis_values_by_group" not in self.grouped_plan:
raise RuntimeError(
"grouped plan is missing benchmark_axis_values_by_group"
)
benchmark_axis_values = self.grouped_plan["benchmark_axis_values_by_group"]
if group_id >= len(benchmark_axis_values):
raise RuntimeError(
f"group_id {group_id} is out of range for benchmark_axis_values_by_group"
)
axis_env = {
axis_name: int(axis_value)
for axis_name, axis_value in benchmark_axis_values[group_id]
}
def eval_expr(expr):
if isinstance(expr, int):
return expr
if not isinstance(expr, dict):
raise RuntimeError(f"unsupported benchmark expr payload: {expr}")
if "const" in expr:
return int(expr["const"])
if "runtime_arg_index" in expr:
runtime_arg_index = int(expr["runtime_arg_index"])
if runtime_arg_index >= len(self._grouped_runtime_args_snapshot):
raise RuntimeError(
"benchmark expr runtime_arg_index is out of range: "
f"{runtime_arg_index}"
)
return self._grouped_runtime_args_snapshot[runtime_arg_index]
if "axis_name" in expr:
axis_name = expr["axis_name"]
if axis_name not in axis_env:
raise RuntimeError(
f"benchmark axis env is missing axis {axis_name}"
)
return axis_env[axis_name]
if "mul" in expr:
product = 1
for operand in expr["mul"]:
product *= eval_expr(operand)
return product
if "add" in expr:
total = 0
for operand in expr["add"]:
total += eval_expr(operand)
return total
if "floordiv" in expr:
operands = tuple(eval_expr(operand) for operand in expr["floordiv"])
if len(operands) != 2:
raise RuntimeError(
f"floordiv benchmark expr expects 2 operands: {expr}"
)
return operands[0] // operands[1]
raise RuntimeError(f"unsupported benchmark expr payload: {expr}")
def materialize_dtype(dtype):
if isinstance(dtype, str):
dtype_name = dtype.removeprefix("torch.")
return getattr(torch, dtype_name)
return dtype
def workspace_factory(count, device, dtype, zero_mode):
if hasattr(self, "_benchmark_workspace_factory"):
return self._benchmark_workspace_factory(count, device, dtype, zero_mode)
if zero_mode == "ZERO_ON_CALL":
return torch.zeros(count, device=device, dtype=dtype)
return torch.empty(count, device=device, dtype=dtype)
def materialize_spec(spec, kind):
spec_kind = spec.get("kind")
source = spec.get("source")
if source == "runtime_arg":
runtime_arg_index = int(spec["index"])
if runtime_arg_index >= len(self._grouped_runtime_args_snapshot):
raise RuntimeError(
f"{kind} runtime_arg index {runtime_arg_index} is out of range"
)
return self._grouped_runtime_args_snapshot[runtime_arg_index]
if source == "axis_expr":
return eval_expr(spec["expr"])
if spec_kind == "tensor" and source in ("buffer", "constant"):
size = tuple(eval_expr(expr) for expr in spec["size_exprs"])
stride = tuple(eval_expr(expr) for expr in spec["stride_exprs"])
dtype = materialize_dtype(spec["dtype"])
tensor_factory = getattr(self, "_benchmark_tensor_factory", None)
if tensor_factory is None:
tensor_factory = rand_strided
return tensor_factory(size, stride, dtype, spec["device"])
if spec_kind == "workspace" and source == "workspace":
count = eval_expr(spec["count_expr"])
return workspace_factory(
count,
spec["device"],
materialize_dtype(spec["dtype"]),
spec["zero_mode"],
)
raise RuntimeError(f"{kind} spec source {source} is not supported")
return tuple(
materialize_spec(spec, f"{spec.get('kind', 'unknown')} arg")
for spec in self.grouped_plan.get("ordered_arg_specs", ())
)
def _build_grouped_benchmark_entries(self):
entries = []
for group_id, candidates in enumerate(self.grouped_plan["group_to_candidates"]):
if not candidates:
continue
benchmark_args = self._materialize_benchmark_args(group_id)
for candidate in candidates:
variant_index = self.variant_id_to_index[candidate["variant_id"]]
launcher = self.variant_launchers[variant_index]
if launcher is None or not launcher.runnable:
continue
runtime_blocks = self._materialize_runtime_blocks(candidate, benchmark_args)
launch_args = self._build_runtime_launch_args(
benchmark_args, runtime_blocks
)
entries.append(
{
"group_id": group_id,
"candidate": candidate,
"launcher": launcher,
"runtime_blocks": runtime_blocks,
"launch_args": launch_args,
}
)
return tuple(entries)
def _grouped_kernel_args_prefix(self, launch_args):
return launch_args[: len(self.fn.arg_names)]
def _grouped_mspti_batch_benchmark(self, kernel_funcs):
    """Benchmark ``kernel_funcs`` with ``mspti_batch_benchmark``, filtered to "triton"."""
    name_filters = ["triton"]
    return mspti_batch_benchmark(kernel_funcs, filter_list=name_filters)
def _grouped_profile_batch_benchmark(self, kernel_funcs):
    """Benchmark ``kernel_funcs`` together under the profiler.

    All kernels are launched once per profiler step. Afterwards the first
    ``kernel_details.csv`` found under the profile directory is parsed:
    rows whose ``Name`` starts with ``triton`` are read as
    ``kernel_count`` consecutive rows per active step, and each kernel's
    ``Duration(us)`` is averaged over the active steps.

    Returns:
        List of average durations (one per kernel func), or ``[]`` when no
        ``kernel_details.csv`` was produced.

    Raises:
        RuntimeError: If the number of triton rows is not
            ``kernel_count * active steps``. The profile directory is left
            in place in that case.
    """
    wait_steps = 1
    warmup_steps = 1
    active_steps = 10
    repeat_count = 1
    skip_first_steps = 1
    total_steps = (wait_steps + warmup_steps + active_steps + skip_first_steps) * repeat_count

    stream = torch.npu.current_stream()
    run_tag = hashlib.md5(uuid.uuid4().hex.encode()).hexdigest()
    kernel_count = len(kernel_funcs)
    profile_dir = os.path.join(os.getcwd(), "profile_result", f"triton_{run_tag}")

    def remove_profile_dir():
        if os.path.exists(profile_dir):
            shutil.rmtree(profile_dir)

    with create_profiler(
        profile_dir, wait_steps, warmup_steps, active_steps, repeat_count, skip_first_steps
    ) as prof:
        stream.synchronize()
        for _ in range(total_steps):
            for kernel_fn in kernel_funcs:
                kernel_fn()
            torch.npu.synchronize()
            prof.step()
        stream.synchronize()

    import pandas as pd

    expected_rows = kernel_count * active_steps
    for root, _, files in os.walk(profile_dir):
        if 'kernel_details.csv' not in files:
            continue
        target_file = os.path.join(root, 'kernel_details.csv')
        df = pd.read_csv(target_file)
        triton_rows = df[df['Name'].str.startswith('triton', na=False)]
        if len(triton_rows) != expected_rows:
            raise RuntimeError(
                f"Expected {expected_rows} rows for triton kernels, "
                f"but got {len(triton_rows)}. This may be due to profiling "
                f"errors. Please check the profiling result at {target_file} "
                "for more details."
            )
        # Rows are laid out step by step, so kernel k of step s sits at
        # index k + kernel_count * s.
        total_durations = [
            sum(
                triton_rows.iloc[kernel_index + kernel_count * step]['Duration(us)']
                for step in range(active_steps)
            )
            for kernel_index in range(kernel_count)
        ]
        remove_profile_dir()
        return [total / active_steps for total in total_durations]

    remove_profile_dir()
    return []
def _bench_grouped_entry(self, entry, **kwargs):
launcher = entry["launcher"]
launch_args = entry["launch_args"]
if not self.custom_kernel and launcher.n_spills is not None and launcher.n_spills > self.inductor_meta.get(
"spill_threshold", 16
):
return float("inf")
device_interface = self.get_device_interface()
stream = device_interface.get_raw_stream(device_interface.current_device())
kernel_args_prefix = self._grouped_kernel_args_prefix(launch_args)
def kernel_call():
cloned_args, cloned_kwargs = self.clone_args(*launch_args, **kwargs)
self.reset_to_zero_args(*kernel_args_prefix, **kwargs)
launcher(
*cloned_args,
**cloned_kwargs,
stream=stream,
)
if self.inductor_meta.get("profile_bandwidth_with_do_bench_using_profiling", False):
return do_bench_using_profiling_npu(kernel_call, rep=1)
return benchmarker.benchmark_gpu(kernel_call, rep=1)
def _benchmark_grouped_entries(self, entries, **kwargs):
if not entries:
return ()
device_interface = self.get_device_interface()
stream = device_interface.get_raw_stream(device_interface.current_device())
kernel_funcs = []
kernel_entry_indices = []
timings = [None] * len(entries)
def kernel_call(entry):
launcher = entry["launcher"]
launch_args = entry["launch_args"]
kernel_args_prefix = self._grouped_kernel_args_prefix(launch_args)
def call_kernel():
if not launcher.runnable:
return
if launcher.config.pre_hook is not None:
pre_hook_args = {
**dict(zip(self.arg_names, launch_args)),
**launcher.config.kwargs,
}
launcher.config.pre_hook(
pre_hook_args
)
cloned_args, cloned_kwargs = self.clone_args(*launch_args, **kwargs)
self.reset_to_zero_args(*kernel_args_prefix, **kwargs)
launcher(
*cloned_args,
**cloned_kwargs,
stream=stream,
)
return call_kernel
for entry_idx, entry in enumerate(entries):
launcher = entry["launcher"]
n_spills = launcher.n_spills
if not self.custom_kernel and n_spills is not None and n_spills > self.inductor_meta.get(
"spill_threshold", 16
):
timings[entry_idx] = float("inf")
continue
kernel_funcs.append(kernel_call(entry))
kernel_entry_indices.append(entry_idx)
if not kernel_funcs:
return tuple(timings)
if not npu_config.aggresive_autotune:
return tuple(
timing
if timing is not None
else self._bench_grouped_entry(entry, **kwargs)
for entry, timing in zip(entries, timings)
)
def apply_batch_timings(batch_timings):
batch_timings = tuple(batch_timings)
if len(batch_timings) != len(kernel_funcs):
raise RuntimeError(
"grouped batch benchmark timing count does not match kernel funcs"
)
for entry_idx, timing in zip(kernel_entry_indices, batch_timings):
timings[entry_idx] = timing
return tuple(timings)
try:
return apply_batch_timings(self._grouped_mspti_batch_benchmark(kernel_funcs))
except Exception as mspti_exc:
log.warning(
"grouped mspti batch benchmark failed, switched to profiler batch: %s",
mspti_exc,
)
try:
return apply_batch_timings(
self._grouped_profile_batch_benchmark(kernel_funcs)
)
except Exception as profile_exc:
log.warning(
"grouped profiler batch benchmark failed, switched to single bench: %s",
profile_exc,
)
return tuple(
timing
if timing is not None
else self._bench_grouped_entry(entry, **kwargs)
for entry, timing in zip(entries, timings)
)
def _autotune_all_groups(self, *args, **kwargs):
self._grouped_runtime_args_snapshot = args
entries = self._build_grouped_benchmark_entries()
entry_group_ids = {entry["group_id"] for entry in entries}
missing_entry_groups = tuple(
group_id
for group_id in self.reachable_group_ids
if group_id not in entry_group_ids
)
if missing_entry_groups:
raise RuntimeError(
"reachable grouped autotune groups have no benchmark entries after "
"precompile/make_launcher filtering: "
f"{missing_entry_groups}"
)
timings = self._benchmark_grouped_entries(entries, **kwargs)
if len(timings) != len(entries):
raise RuntimeError(
"grouped benchmark timing count does not match benchmark entries"
)
best_by_group = {}
for entry, timing in zip(entries, timings):
group_id = entry["group_id"]
if group_id not in best_by_group or timing < best_by_group[group_id][0]:
candidate = entry["candidate"]
best_by_group[group_id] = (
timing,
(
entry["launcher"],
candidate["policy_id"],
candidate["variant_id"],
),
)
missing_best_groups = tuple(
group_id
for group_id in self.reachable_group_ids
if group_id not in best_by_group
)
if missing_best_groups:
raise RuntimeError(
"reachable grouped autotune groups failed to select a best candidate: "
f"{missing_best_groups}"
)
for group_id, (_, best_entry) in best_by_group.items():
self.group_best[group_id] = best_entry
def _all_reachable_groups_tuned(self):
return all(self.group_best[group_id] is not None for group_id in self.reachable_group_ids)
def ensure_grouped_autotune_ready(self, *args, **kwargs):
    """Precompile and autotune all groups once, while holding ``self.lock``.

    Does nothing when every reachable group already has a best candidate.
    """
    with self.lock:
        if not self._all_reachable_groups_tuned():
            self._precompile_variants()
            self._autotune_all_groups(*args, **kwargs)
def run(self, *args, stream, benchmark_run=False, **kwargs):
    """Dispatch the launch to the best launcher of the group ``args`` falls into.

    After making sure all groups are tuned, the runtime features of
    ``args`` are resolved to a group id, that group's best
    ``(launcher, policy_id, variant_id)`` is looked up, the runtime blocks
    for the policy are computed, and the launcher is called with ``args``
    followed by those blocks. ``benchmark_run`` is accepted but not read.

    Raises:
        RuntimeError: If the resolved group has no tuned candidate.
    """
    self.ensure_grouped_autotune_ready(*args, **kwargs)

    feature_inputs = self._runtime_feature_inputs(args)
    group_id = self._resolve_group_id(feature_inputs)
    best = self.group_best[group_id]
    if best is None:
        raise RuntimeError(
            f"runtime inputs resolved to unreachable grouped autotune group {group_id}"
        )

    launcher, policy_id, variant_id = best
    runtime_blocks = self._materialize_runtime_blocks({"policy_id": policy_id}, args)
    selected_config = self.grouped_plan["variants"][variant_id]["config"]
    log.info(
        "grouped dispatch group: kernel=%s feature_inputs=%s group_id=%s selected_config=%s runtime_blocks=%s",
        self.inductor_meta.get("kernel_name", self.fn.__name__),
        feature_inputs,
        group_id,
        selected_config,
        runtime_blocks,
    )

    launch_args = self._build_runtime_launch_args(args, runtime_blocks)
    return launcher(*launch_args, stream=stream, **kwargs)
class NPUDebugAutotuner(NPUCachingAutotuner):
    """Autotuner that also measures and prints bandwidth for matching kernels.

    Kernels whose longest discovered name matches ``regex_filter`` are run,
    benchmarked once (the measurement is cached), appended to
    ``collected_calls`` and reported through ``print``.
    """

    def __init__(
        self,
        *args,
        regex_filter="",
        with_profiler=False,
        with_bandwidth_info=True,
        **kwargs,
    ):
        """Store the debug options and initialise the base autotuner.

        Args:
            *args: Forwarded to the base class.
            regex_filter: Pattern matched (``re.match``) against the kernel
                name; non-matching kernels are skipped by ``run``.
            with_profiler: Stored on the instance; ``cached_autotune`` in
                this file passes it when constructing this class.
            with_bandwidth_info: Stored on the instance; also passed by
                ``cached_autotune``.
            **kwargs: Forwarded to the base class.
        """
        self.regex_filter = regex_filter
        # NOTE(review): these two keywords are consumed here instead of
        # being forwarded. This assumes the base __init__ does not accept
        # them (the grouped subclass in this file never forwards them) —
        # confirm against NPUCachingAutotuner.__init__.
        self.with_profiler = with_profiler
        self.with_bandwidth_info = with_bandwidth_info
        super().__init__(*args, **kwargs)
        # (ms, num_gb, gb_per_s, kernel_name) of the first measured run
        self.cached = None

    def run(self, *args, input_grid, stream):
        """Run the kernel and report its bandwidth if its name matches the filter.

        NOTE(review): ``input_grid`` is forwarded to the base ``run`` as a
        ``grid`` keyword and to ``bench`` as ``input_grid``; verify that the
        base implementations still expect these keywords.
        """
        possible_names = _find_names(self)
        kernel_name = f"{max(possible_names, key=len)}"
        if not re.match(self.regex_filter, kernel_name):
            return
        super().run(*args, grid=input_grid, stream=stream)
        (launcher,) = self.launchers

        if self.cached is None:
            ms = self.bench(launcher, *args, input_grid=input_grid)
            num_in_out_ptrs = len(
                [
                    arg_name
                    for arg_name in self.fn.arg_names
                    if arg_name.startswith("in_out_ptr")
                ]
            )
            num_gb = get_num_bytes(*args, num_in_out_args=num_in_out_ptrs) / 1e9
            gb_per_s = num_gb / (ms / 1e3)
            self.cached = (ms, num_gb, gb_per_s, kernel_name)
        else:
            ms, num_gb, gb_per_s, kernel_name = self.cached
        collected_calls.append((ms, num_gb, gb_per_s, kernel_name))
        print(
            create_bandwidth_info_str(ms, num_gb, gb_per_s, suffix=f" \t {kernel_name}")
        )
def cached_autotune(
    size_hints: Optional[List[int]],
    configs: List[Config],
    triton_meta,
    heuristic_type,
    filename=None,
    inductor_meta=None,
    custom_kernel=False,
):
    """
    A copy of triton.autotune that calls our subclass. Our subclass
    has additional debugging, error handling, and on-disk caching.

    The returned decorator wraps a kernel function in one of the NPU
    autotuners, chosen in this order: ``NPUDebugAutotuner`` (when
    ``inductor_meta["profile_bandwidth"]`` is set), ``NPUFastAutotuner``
    (``npu_config.fasta_autotune``), ``NPUSymbolicGroupedAutotuner``
    (``inductor_meta["group_enabled"]``), else ``NPUCachingAutotuner``.

    Side effects: ``mutated_arg_names`` and ``optimize_mem`` are popped from
    ``inductor_meta``; ``restore_value`` and ``reset_to_zero`` are popped
    from ``triton_meta``.

    Raises:
        RuntimeError: If more than one unique config is given without a
            ``filename``.
    """
    configs = unique_configs(configs)
    if not (len(configs) == 1 or filename):
        raise RuntimeError("assert len(configs) == 1 or filename")
    inductor_meta = {} if inductor_meta is None else inductor_meta

    disabled = inductor_meta.get("force_disable_caches", False)

    autotune_cache = None
    if (
        not disabled
        and filename is not None
        and (len(configs) > 1 or inductor_meta.get("coordinate_descent_tuning"))
        and not os.environ.get("TRITON_INTERPRET", "0") == "1"
    ):
        configs_hash = hash_configs(configs)
        autotune_cache = AutotuneCache.create(inductor_meta, filename, configs_hash)
        if autotune_cache:
            best_config = autotune_cache.read_best(inductor_meta, configs)
            if best_config:
                # A cached winner replaces the whole candidate list.
                configs = [best_config]
    elif disabled:
        log.debug("autotune caching is disabled by config.force_disable_caches")

    # Normalise to a list: the default is a tuple and ``tuple += list``
    # raises TypeError, so restore_value could not always be appended.
    mutated_arg_names = list(inductor_meta.pop("mutated_arg_names", ()))
    optimize_mem = inductor_meta.pop("optimize_mem", True)

    if "restore_value" in triton_meta:
        mutated_arg_names.extend(triton_meta.pop("restore_value"))

    reset_to_zero_arg_names: List[str] = []
    if "reset_to_zero" in triton_meta:
        reset_to_zero_arg_names.extend(triton_meta.pop("reset_to_zero"))

    def decorator(fn):
        # Keyword arguments shared by every autotuner flavour.
        common_kwargs = dict(
            triton_meta=triton_meta,
            inductor_meta=inductor_meta,
            configs=configs,
            save_cache_hook=autotune_cache and autotune_cache.save,
            mutated_arg_names=mutated_arg_names,
            reset_to_zero_arg_names=reset_to_zero_arg_names,
            optimize_mem=optimize_mem,
            heuristic_type=heuristic_type,
            size_hints=size_hints,
            custom_kernel=custom_kernel,
            filename=filename,
        )

        if inductor_meta.get("profile_bandwidth"):
            return NPUDebugAutotuner(
                fn,
                regex_filter=inductor_meta["profile_bandwidth_regex"],
                with_profiler=inductor_meta[
                    "profile_bandwidth_with_do_bench_using_profiling"
                ],
                with_bandwidth_info=True,
                **common_kwargs,
            )

        if npu_config.fasta_autotune:
            from .fasta_autotune import NPUFastAutotuner
            return NPUFastAutotuner(fn, **common_kwargs)

        if inductor_meta.get("group_enabled", False):
            return NPUSymbolicGroupedAutotuner(fn, **common_kwargs)

        return NPUCachingAutotuner(fn, **common_kwargs)

    return decorator
def patch_triton_heuristics_cached_autotune():
    """Replace inductor's ``triton_heuristics.cached_autotune`` with this module's version."""
    heuristics_module = torch._inductor.runtime.triton_heuristics
    heuristics_module.cached_autotune = cached_autotune
def brutal_prune_tiling_configs_if_fast_run(configs, inductor_meta) -> List[Config]:
    """Keep only the last N tiling configs when fast-run pruning is requested.

    N is read from the ``FAST_RUN_WITH_MAX_TILING_NUM`` environment
    variable. Pruning happens only when N is a positive integer smaller
    than ``len(configs)``; a missing, non-numeric or non-positive value
    leaves ``configs`` untouched.

    Args:
        configs: Candidate configs (any sliceable sequence).
        inductor_meta: Kernel metadata; ``kernel_name`` is used only for
            the debug log line and may be absent.

    Returns:
        The original ``configs`` object, or its trailing slice of length N.
    """
    max_num_str = os.environ.get("FAST_RUN_WITH_MAX_TILING_NUM", "-1")
    try:
        max_num = int(max_num_str)
    except ValueError:
        max_num = -1

    if max_num > 0 and len(configs) > max_num:
        configs = configs[-max_num:]
        # ``.get``: a missing kernel name must not turn a log line into a
        # KeyError.
        logging.debug("[%s], prune tiling configs to [%s]",
                      inductor_meta.get("kernel_name", "unknown"),
                      len(configs))
    return configs
def triton_config_npu_index(
    size_hints,
    inductor_meta,
    triton_meta=None,
    is_reduction=False,
    is_persistent_reduction=False,
) -> List[Config]:
    """Build the NPU index configs through the grouped or the legacy path.

    The grouped builder is used when ``inductor_meta["group_enabled"]`` is
    truthy, otherwise the legacy one. Both receive the same arguments.

    Raises:
        RuntimeError: If grouping is enabled together with
            ``extra_launcher_args``.
    """
    grouped = inductor_meta.get("group_enabled", False)
    if grouped and inductor_meta.get("extra_launcher_args"):
        raise RuntimeError(
            "grouped autotune reached runtime config with extra_launcher_args; "
            "this path is not supported for the current autogenerated Triton grouped autotune flow"
        )

    if grouped:
        build_configs = _triton_config_npu_index_grouped
    else:
        build_configs = _triton_config_npu_index_legacy
    return build_configs(
        size_hints,
        inductor_meta,
        triton_meta=triton_meta,
        is_reduction=is_reduction,
        is_persistent_reduction=is_persistent_reduction,
    )
def _triton_config_npu_index_grouped(
size_hints,
inductor_meta,
triton_meta=None,
is_reduction=False,
is_persistent_reduction=False,
) -> List[Config]:
group_features = tuple(inductor_meta.get("group_features", ()))
signature_names = tuple((triton_meta or {}).get("signature", {}).keys())
axis_names = tuple(inductor_meta.get("axis_names", ()))
axis_static_values = tuple(inductor_meta.get("axis_static_values", ()))
axis_arg_indices = {}
for axis_name in axis_names:
arg_name = f"{axis_name}_numel"
if arg_name in signature_names:
axis_arg_indices[axis_name] = signature_names.index(arg_name)
def resolve_feature_arg_indices(feature_spec) -> tuple[int, ...]:
source = feature_spec["source"]
axis_names = tuple(feature_spec["axis_names"])
if source in ("primary_axis", "axis") and len(axis_names) != 1:
raise RuntimeError(
f"group feature {feature_spec['name']} expects exactly one axis, got {axis_names}"
)
if source not in ("primary_axis", "axis", "outer_product", "reduction_product"):
raise RuntimeError(
f"group feature source {source} is not supported by feature_arg_indices v1"
)
indices = []
for axis_name in axis_names:
arg_name = f"{axis_name}_numel"
if arg_name not in signature_names:
raise RuntimeError(
f"group feature {feature_spec['name']} arg {arg_name} is missing from signature"
)
indices.append(signature_names.index(arg_name))
return tuple(indices)
feature_arg_indices = tuple(
resolve_feature_arg_indices(feature_spec) for feature_spec in group_features
)
try:
group_representatives = build_group_representatives(
group_features,
axis_names,
axis_static_values,
)
except UnsupportedGroupedPlan as exc:
raise RuntimeError(
"invalid grouped autotune plan reached runtime config; "
"the current grouped codegen guard should have rejected this plan earlier"
) from exc
group_id_count = group_representatives["group_id_count"]
reachable_group_ids = set(group_representatives["reachable_group_ids"])
runtime_block_arg_names = tuple(inductor_meta.get("runtime_block_arg_names", ()))
primary_group_axis = inductor_meta.get("primary_group_axis")
if primary_group_axis is None:
raise RuntimeError("grouped autotune plan is missing primary_group_axis")
primary_block_name = f"{primary_group_axis.upper()}BLOCK"
if runtime_block_arg_names and primary_block_name not in runtime_block_arg_names:
raise RuntimeError(
f"runtime_block_arg_names is missing primary block {primary_block_name}"
)
def benchmark_axis_env(group_id: int) -> dict[str, int]:
axis_values = group_representatives["benchmark_axis_values_by_group"][group_id]
return {axis_name: int(axis_value) for axis_name, axis_value in axis_values}
primary_feature_index = find_primary_feature_index(
group_features, primary_group_axis
)
base_size_hints = {
axis_name: int(size_hints.get(axis_name, 1))
for axis_name in axis_names
}
def representative_size_hints(group_id: int) -> dict[str, int]:
representative = dict(base_size_hints)
representative.update(benchmark_axis_env(group_id))
return representative
variant_configs = []
variant_key_to_id = {}
policies = {}
group_to_candidates = []
for group_id in range(group_id_count):
if group_id not in reachable_group_ids:
group_to_candidates.append(())
continue
legacy_configs = _triton_config_npu_index_legacy(
representative_size_hints(group_id),
inductor_meta,
triton_meta=triton_meta,
is_reduction=is_reduction,
is_persistent_reduction=is_persistent_reduction,
)
group_candidates = []
group_candidate_keys = set()
for variant_idx, legacy_cfg in enumerate(legacy_configs):
variant_cfg = strip_runtime_blocks_from_cfg(
legacy_cfg, runtime_block_arg_names
)
variant_key = repr(config_to_dict(variant_cfg))
variant_id = variant_key_to_id.get(variant_key)
if variant_id is None:
variant_id = f"v{len(variant_configs)}"
variant_key_to_id[variant_key] = variant_id
variant_configs.append(variant_cfg)
policy_id = f"p{group_id}_{variant_id}"
if len(legacy_configs) > 1:
policy_id = f"{policy_id}_{variant_idx}"
policy = build_grouped_launch_policy(
group_id=group_id,
cfg=legacy_cfg,
runtime_block_arg_names=runtime_block_arg_names,
group_features=group_features,
primary_group_axis=primary_group_axis,
primary_feature_index=primary_feature_index,
axis_env=benchmark_axis_env(group_id),
npu_num_vector_core=npu_config.num_vector_core,
)
candidate_key = (
variant_id,
tuple(policy.get("static_blocks", ())),
tuple(policy.get("runtime_block_rules", ())),
int(policy.get("grid_target", 1)),
)
if candidate_key in group_candidate_keys:
continue
group_candidate_keys.add(candidate_key)
policies[policy_id] = policy
group_candidates.append(
{
"variant_id": variant_id,
"policy_id": policy_id,
}
)
group_to_candidates.append(tuple(group_candidates))
variant_order = tuple(f"v{idx}" for idx in range(len(variant_configs)))
ordered_arg_specs = tuple(inductor_meta.get("ordered_arg_specs", ()))
tensor_arg_specs = tuple(inductor_meta.get("tensor_arg_specs", ()))
size_arg_specs = tuple(inductor_meta.get("size_arg_specs", ()))
workspace_arg_specs = tuple(inductor_meta.get("workspace_arg_specs", ()))
extra_launcher_arg_specs = tuple(inductor_meta.get("extra_launcher_arg_specs", ()))
grouped_candidate_plan = {
"variants": {
f"v{idx}": {"config": config_to_dict(cfg)}
for idx, cfg in enumerate(variant_configs)
},
"variant_order": variant_order,
"group_to_candidates": tuple(group_to_candidates),
"policies": policies,
"runtime_block_arg_count": len(runtime_block_arg_names),
"runtime_block_append_order": runtime_block_arg_names,
"group_id_count": group_id_count,
"reachable_group_ids": tuple(group_representatives["reachable_group_ids"]),
"unreachable_group_ids": tuple(group_representatives["unreachable_group_ids"]),
"group_features": tuple(group_features),
"axis_arg_indices": axis_arg_indices,
"feature_arg_indices": feature_arg_indices,
"feature_sources": tuple(
{
"name": feature_spec["name"],
"source": feature_spec["source"],
"axis_names": tuple(feature_spec["axis_names"]),
}
for feature_spec in group_features
),
"benchmark_feature_inputs_by_group": tuple(
group_representatives["benchmark_feature_inputs_by_group"]
),
"benchmark_axis_values_by_group": tuple(
group_representatives["benchmark_axis_values_by_group"]
),
"ordered_arg_specs": ordered_arg_specs,
"tensor_arg_specs": tensor_arg_specs,
"size_arg_specs": size_arg_specs,
"workspace_arg_specs": workspace_arg_specs,
"extra_launcher_arg_specs": extra_launcher_arg_specs,
}
inductor_meta["grouped_candidate_plan"] = grouped_candidate_plan
candidate_count_by_group = tuple(
(group_id, len(candidates))
for group_id, candidates in enumerate(grouped_candidate_plan["group_to_candidates"])
if candidates
)
log.info(
"grouped codegen groups: kernel=%s group_id_count=%s reachable_group_count=%s total_candidate_count=%s candidate_count_by_group=%s",
inductor_meta.get("kernel_name", "triton_"),
group_id_count,
len(grouped_candidate_plan["reachable_group_ids"]),
sum(len(candidates) for candidates in grouped_candidate_plan["group_to_candidates"]),
candidate_count_by_group,
)
return variant_configs
def extract_runtime_blocks_from_cfg(
    cfg,
    runtime_block_arg_names: tuple[str, ...],
) -> dict[str, int]:
    """Collect the runtime block sizes that ``cfg`` defines.

    ``cfg`` is serialized with ``config_to_dict``. When the serialized form
    is a dict carrying a ``"kwargs"`` entry, that nested mapping is searched;
    otherwise the serialized form itself is searched.

    Args:
        cfg: The config to inspect.
        runtime_block_arg_names: Block argument names to look for.

    Returns:
        A dict with one entry per requested name that is present in the
        config, in the order of ``runtime_block_arg_names``; each value is
        coerced with ``int``. Names the config does not define are omitted.
    """
    serialized = config_to_dict(cfg)
    has_nested_kwargs = isinstance(serialized, dict) and "kwargs" in serialized
    block_source = dict(serialized["kwargs"] if has_nested_kwargs else serialized)

    found: dict[str, int] = {}
    for name in runtime_block_arg_names:
        if name in block_source:
            found[name] = int(block_source[name])
    return found
def build_grouped_launch_policy(
    group_id: int,
    cfg,
    runtime_block_arg_names: tuple[str, ...],
    group_features,
    primary_group_axis: str,
    primary_feature_index: int,
    axis_env: dict[str, int],
    npu_num_vector_core: int,
) -> dict[str, object]:
    """Derive the launch policy of one group from a legacy config.

    The returned dict always has the keys ``"group_id"``,
    ``"static_blocks"``, ``"runtime_block_rules"`` and ``"grid_target"``.
    Three shapes are produced:

    * The config has no ``<PRIMARY>BLOCK`` entry and no runtime block names
      were requested: an empty policy with ``grid_target`` 1.
    * ``is_open_bucket_group`` is false for this group: every runtime block
      found in the config is pinned as a static block, ``grid_target`` 1.
    * Otherwise: every non-primary runtime block is pinned, the primary block
      gets a ``ceildiv`` rule, and ``grid_target`` is
      ``npu_num_vector_core`` floor-divided by the number of programs the
      pinned blocks need for the sizes in ``axis_env`` (at least 1).

    Raises:
        RuntimeError: if runtime block names were requested but the config
            lacks the primary block, or if ``axis_env`` has no entry for a
            pinned non-primary axis.
    """
    primary_block_name = f"{primary_group_axis.upper()}BLOCK"
    runtime_blocks = extract_runtime_blocks_from_cfg(cfg, runtime_block_arg_names)

    def make_policy(static_blocks, rules=(), grid_target=1):
        # Single place that fixes the key set and key order of a policy.
        return {
            "group_id": group_id,
            "static_blocks": static_blocks,
            "runtime_block_rules": rules,
            "grid_target": grid_target,
        }

    if primary_block_name not in runtime_blocks:
        if runtime_block_arg_names:
            raise RuntimeError(
                f"legacy grouped config is missing primary block {primary_block_name}"
            )
        return make_policy(())

    if not is_open_bucket_group(group_features, primary_feature_index, group_id):
        pinned_all = tuple(
            (block_name, runtime_blocks[block_name])
            for block_name in runtime_block_arg_names
            if block_name in runtime_blocks
        )
        return make_policy(pinned_all)

    pinned = []
    prior_programs = 1
    for block_name in runtime_block_arg_names:
        if block_name not in runtime_blocks:
            continue
        axis_name = block_name.removesuffix("BLOCK").lower()
        if axis_name == primary_group_axis:
            # The primary block is resolved by the runtime rule below.
            continue
        block_size = int(runtime_blocks[block_name])
        pinned.append((block_name, block_size))
        if axis_name not in axis_env:
            raise RuntimeError(
                f"benchmark axis env is missing split axis {axis_name}"
            )
        # Ceil-divide the axis extent by its block size.
        prior_programs *= (int(axis_env[axis_name]) + block_size - 1) // block_size

    primary_rule = (
        primary_block_name,
        (("op", "ceildiv"), ("axis_name", primary_group_axis)),
    )
    return make_policy(
        tuple(pinned),
        rules=(primary_rule,),
        grid_target=max(1, npu_num_vector_core // prior_programs),
    )
def _triton_config_npu_index_legacy(
    size_hints,
    inductor_meta,
    triton_meta=None,
    is_reduction=False,
    is_persistent_reduction=False,
) -> List[Config]:
    """Generate the candidate tiling configs for one kernel.

    Args:
        size_hints: Mapping from axis name to size; axes listed in
            ``inductor_meta["axis_names"]`` but missing here default to 1.
        inductor_meta: Must provide ``"split_axis"``, ``"tiling_axis"``,
            ``"low_dims"``, ``"split_axis_dtype"``, ``"axis_names"`` and
            ``"dual_reduction"``. ``"no_loop_axis"``, ``"npu_kernel_type"``,
            ``"inductor_ascend_linear_mode"`` and ``"kernel_name"`` are
            optional.
        triton_meta: Optional. When given, the keys of its ``"signature"``
            that contain ``"ptr"`` are counted and passed to the tile
            generator; when ``None`` the count is 0.
        is_reduction: Accepted for signature compatibility; not read here.
        is_persistent_reduction: Forwarded to the tile generator.

    Returns:
        The configs produced by the tile generator (or a single fallback
        config when it produces none), annotated with ``split_axis`` /
        ``split_blocks`` and passed through
        ``brutal_prune_tiling_configs_if_fast_run``.
    """
    num_warps = 1
    num_stages = 1
    configs = []
    split_axis = inductor_meta["split_axis"]
    tiling_axis = inductor_meta["tiling_axis"]
    no_loop_axis = inductor_meta.get("no_loop_axis", [])
    low_dims = inductor_meta["low_dims"]
    split_axis_dtype = inductor_meta["split_axis_dtype"]
    axis_names = inductor_meta["axis_names"]
    dual_reduction = inductor_meta["dual_reduction"]

    # triton_meta defaults to None, so it must be checked *before* it is
    # subscripted; previously the subscript ran first and raised TypeError.
    if triton_meta is not None:
        input_ptr_num = sum(1 for key in triton_meta["signature"] if "ptr" in key)
    else:
        input_ptr_num = 0

    npu_kernel_type = NPUKernelType(inductor_meta.get("npu_kernel_type", "simd"))
    axis_sizes = [size_hints.get(axis_name, 1) for axis_name in axis_names]

    use_fasta = npu_config.fasta_autotune
    if use_fasta:
        from .fasta_autotune import FastATileGenerator

        generator_cls = FastATileGenerator
    else:
        generator_cls = TileGenerator
    tile_generator = generator_cls(
        axis_sizes,
        axis_names,
        tiling_axis,
        no_loop_axis,
        split_axis,
        low_dims,
        persistent_reduction=is_persistent_reduction,
        dtype=split_axis_dtype,
        npu_kernel_type=npu_kernel_type,
        input_ptr_num=input_ptr_num,
        dual_reduction=dual_reduction,
    )

    if not use_fasta and npu_kernel_type == NPUKernelType.SIMD_SIMT_MIX:
        # Mixed mode: collect the candidates of each concrete kernel type.
        for kernel_type in (
            NPUKernelType.SIMT_ONLY,
            NPUKernelType.SIMT_TEMPLATE,
            NPUKernelType.SIMD,
        ):
            tile_generator.set_kernel_type(kernel_type)
            configs.extend(tile_generator.descend_split_tiling())
    else:
        configs = tile_generator.descend_split_tiling()

    if not configs:
        # Fallback: one config whose blocks cover each axis entirely.
        cfg = {}
        for x in split_axis:
            cfg[f"{axis_names[x].upper()}BLOCK"] = axis_sizes[x]
        for x in tiling_axis:
            cfg[f"{axis_names[x].upper()}BLOCK_SUB"] = axis_sizes[x]
        if not cfg:
            cfg["dummy"] = 1
        configs.append(Config(cfg, num_warps=num_warps, num_stages=num_stages))

    for cfg in configs:
        split_blocks = [
            cfg.kwargs[f"{axis_names[axis].upper()}BLOCK"] for axis in split_axis
        ]
        # Insertion order is kept as-is: split_axis first, then split_blocks.
        cfg.kwargs["split_axis"] = tuple(split_axis)
        cfg.kwargs["split_blocks"] = tuple(split_blocks)

    inductor_ascend_linear_mode = inductor_meta.get(
        "inductor_ascend_linear_mode", "no_linear"
    )
    if inductor_ascend_linear_mode in ("no_linear", "no_linear_loop"):
        # "no_linear" mirrors every *SUB entry; "no_linear_loop" only those
        # whose name starts with "R".
        only_r_prefixed = inductor_ascend_linear_mode == "no_linear_loop"
        for tiling_cfg in configs:
            # Iterate a snapshot because the loop adds keys to kwargs.
            for tiling, tiling_value in dict(tiling_cfg.kwargs).items():
                if not (isinstance(tiling, str) and tiling.endswith("SUB")):
                    continue
                if only_r_prefixed and not tiling.startswith("R"):
                    continue
                # removesuffix drops exactly "_SUB"; rstrip("_SUB") would
                # strip any trailing run of the characters '_', 'S', 'U', 'B'.
                tiling_cfg.kwargs[tiling.removesuffix("_SUB")] = tiling_value

    # Use the module logger like the rest of the file; the module-level
    # logging.debug() targets the root logger instead.
    log.debug(
        "[%s], generate candidate tiling count: [%s]",
        inductor_meta.get("kernel_name", "triton_"),
        len(configs),
    )
    configs = brutal_prune_tiling_configs_if_fast_run(configs, inductor_meta)
    return configs
def strip_runtime_blocks_from_cfg(
    cfg: Config,
    runtime_block_arg_names: tuple[str, ...],
) -> Config:
    """Return a copy of ``cfg`` without its runtime block arguments.

    The config is serialized with ``config_to_dict``. If the serialized form
    nests its arguments under ``"kwargs"``, those are deep-copied and the
    top-level ``num_warps`` / ``num_stages`` / ``num_ctas`` / ``maxnreg``
    entries that exist are merged in; otherwise the serialized dict itself is
    deep-copied. Every name in ``runtime_block_arg_names`` is then dropped
    (missing names are ignored) and the result is rebuilt with
    ``config_from_dict``.
    """
    launch_option_keys = ("num_warps", "num_stages", "num_ctas", "maxnreg")
    serialized = config_to_dict(cfg)

    if isinstance(serialized, dict) and "kwargs" in serialized:
        flat = copy.deepcopy(serialized["kwargs"])
        flat.update(
            (option, serialized[option])
            for option in launch_option_keys
            if option in serialized
        )
    else:
        flat = copy.deepcopy(serialized)

    for block_name in runtime_block_arg_names:
        flat.pop(block_name, None)
    return config_from_dict(flat)
def pointwise(
    size_hints,
    triton_meta,
    tile_hint=None,
    filename=None,
    min_elem_per_thread=0,
    inductor_meta=None,
):
    """Build the autotuner for a pointwise kernel.

    Candidate configs come from ``triton_config_npu_index`` and are handed to
    ``cached_autotune`` (as a fresh list) with
    ``HeuristicType.POINTWISE``. ``tile_hint`` and ``min_elem_per_thread``
    are accepted but not read by this function. A missing ``inductor_meta``
    is replaced by an empty dict.
    """
    if inductor_meta is None:
        inductor_meta = {}
    candidate_configs = triton_config_npu_index(size_hints, inductor_meta, triton_meta)
    return cached_autotune(
        size_hints,
        list(candidate_configs),
        triton_meta=triton_meta,
        inductor_meta=inductor_meta,
        heuristic_type=HeuristicType.POINTWISE,
        filename=filename,
    )
def reduction(
    size_hints,
    reduction_hint=False,
    triton_meta=None,
    filename=None,
    inductor_meta=None,
):
    """Build the autotuner for a reduction kernel.

    ``reduction_hint`` is stored into ``inductor_meta["reduction_hint"]``
    (an empty dict is created when ``inductor_meta`` is ``None``) before
    anything else happens. Candidate configs are requested from
    ``triton_config_npu_index`` with ``is_reduction=True`` and passed to
    ``cached_autotune`` as a fresh list with ``HeuristicType.REDUCTION``.

    Raises:
        RuntimeError: if ``triton_meta`` is ``None``.
    """
    if inductor_meta is None:
        inductor_meta = {}
    inductor_meta["reduction_hint"] = reduction_hint

    if triton_meta is None:
        raise RuntimeError("assert triton_meta is not None")

    candidate_configs = triton_config_npu_index(
        size_hints,
        inductor_meta=inductor_meta,
        triton_meta=triton_meta,
        is_reduction=True,
    )
    return cached_autotune(
        size_hints,
        list(candidate_configs),
        triton_meta=triton_meta,
        inductor_meta=inductor_meta,
        filename=filename,
        heuristic_type=HeuristicType.REDUCTION,
    )
def persistent_reduction(
    size_hints,
    reduction_hint=False,
    triton_meta=None,
    filename=None,
    inductor_meta=None,
):
    """Build the autotuner for a persistent-reduction kernel.

    ``reduction_hint`` is stored into ``inductor_meta["reduction_hint"]``
    (an empty dict is created when ``inductor_meta`` is ``None``). Candidate
    configs are requested from ``triton_config_npu_index`` with both
    ``is_reduction`` and ``is_persistent_reduction`` set, and passed
    unchanged to ``cached_autotune`` with
    ``HeuristicType.PERSISTENT_REDUCTION``.
    """
    if inductor_meta is None:
        inductor_meta = {}
    inductor_meta["reduction_hint"] = reduction_hint

    tiling_configs = triton_config_npu_index(
        size_hints,
        inductor_meta=inductor_meta,
        triton_meta=triton_meta,
        is_reduction=True,
        is_persistent_reduction=True,
    )
    return cached_autotune(
        size_hints,
        tiling_configs,
        triton_meta=triton_meta,
        inductor_meta=inductor_meta,
        filename=filename,
        heuristic_type=HeuristicType.PERSISTENT_REDUCTION,
    )
def foreach(
    size_hints=None,
    triton_meta=None,
    num_warps=1,
    filename=None,
    inductor_meta=None,
):
    """Build the autotuner for a triton foreach kernel.

    Candidate configs are requested from ``triton_config_npu_index`` with
    ``is_reduction=True`` and ``is_persistent_reduction=True`` and passed
    unchanged to ``cached_autotune`` with ``HeuristicType.TEMPLATE``.
    ``num_warps`` is accepted but not read by this function. A missing
    ``inductor_meta`` is replaced by an empty dict.
    """
    if inductor_meta is None:
        inductor_meta = {}

    tiling_configs = triton_config_npu_index(
        size_hints,
        inductor_meta=inductor_meta,
        triton_meta=triton_meta,
        is_reduction=True,
        is_persistent_reduction=True,
    )
    return cached_autotune(
        size_hints,
        tiling_configs,
        triton_meta=triton_meta,
        inductor_meta=inductor_meta,
        heuristic_type=HeuristicType.TEMPLATE,
        filename=filename,
    )
def benchmark_all_configs(self, *args, **kwargs):
    """Run ``_benchmark_all_configs`` under a ``dynamo_timed`` scope.

    All positional and keyword arguments are forwarded untouched and the
    result of ``_benchmark_all_configs`` is returned as-is.
    """
    timed_scope = dynamo_timed("benchmark_all_configs")
    with timed_scope:
        return self._benchmark_all_configs(*args, **kwargs)
def _benchmark_all_configs(self, *args, **kwargs):
    """Benchmark every launcher and return a ``{launcher: timing}`` dict.

    Steps:

    1. Each launcher is wrapped in a zero-argument callable and pre-run
       once. A launcher whose pre-run raises, or whose ``n_spills`` exceeds
       ``config.triton.spill_threshold`` (for non-custom kernels), is marked
       ``runnable = False`` and later gets a timing of ``inf``.
    2. All callables are timed in one batch, first through
       ``mspti_batch_benchmark`` and, if that fails or yields the wrong
       number of timings, through a profiler run whose
       ``kernel_details.csv`` is parsed.
    3. If batch timing fails, every runnable launcher is timed individually
       with ``self.bench``.

    Every timing is recorded through
    ``self.coordesc_tuner.cache_benchmark_result``.

    Raises:
        RuntimeError: if no launcher is runnable after the pre-run.
    """
    log.info(f"{self.get_fn_name()} candidate launcher count = {len(self.launchers)}")
    tilling_kernel_list = []

    # The raw stream does not depend on the launcher, so look it up once.
    # It is read by call_kernel at call time.
    device_interface = self.get_device_interface()
    stream = device_interface.get_raw_stream(device_interface.current_device())

    def kernel_call(launcher):
        # Factory so each closure binds its own launcher.
        def call_kernel():
            if not launcher.runnable:
                return
            if launcher.config.pre_hook is not None:
                launcher.config.pre_hook(
                    {**dict(zip(self.arg_names, args)), **launcher.config.kwargs}
                )
            cloned_args, cloned_kwargs = self.clone_args(*args, **kwargs)
            self.reset_to_zero_args(*args, **kwargs)
            launcher(
                *cloned_args,
                **cloned_kwargs,
                stream=stream,
            )

        return call_kernel

    for idx, launcher in enumerate(self.launchers):
        kernel_call_fn = kernel_call(launcher)
        # Always append so tilling_kernel_list stays index-aligned with
        # self.launchers; non-runnable entries are no-ops when called.
        tilling_kernel_list.append(kernel_call_fn)
        if not self.custom_kernel and launcher.n_spills > config.triton.spill_threshold:
            # Previously this returned float("inf") from the whole function,
            # handing callers a float instead of the timings dict. Exclude
            # only this launcher; it is reported with an inf timing below.
            launcher.runnable = False
            continue
        try:
            kernel_call_fn()
            torch.npu.synchronize()
            log.debug(f"PreRun [{self.fn.__name__}], index: {idx}\n tiling [{launcher.config}] success")
        except Exception as e:
            launcher.runnable = False
            log.warning(f"PreRun [{self.fn.__name__}], index: {idx}\n tiling [{launcher.config}] \n err: {e}")

    valid_tiling_length = sum(1 for launcher in self.launchers if launcher.runnable)
    if not valid_tiling_length:
        raise RuntimeError(f"All tiling for [{self.fn.__name__}] are not runnable.")

    def do_batch_benchmark(tilling_kernel_list):
        """Time all callables in one profiler session; return per-index averages."""

        def delete_file(base_path):
            if os.path.exists(base_path):
                shutil.rmtree(base_path)

        stream = torch.npu.current_stream()
        random_uuid = uuid.uuid4().hex
        md5_hash = hashlib.md5(random_uuid.encode()).hexdigest()
        tiling_length = len(tilling_kernel_list)
        autotune_path = os.path.join(os.getcwd(), "profile_result", f"triton_{md5_hash}")
        WAIT = 1
        WARMUP = 1
        ACTIVE = 10
        REPEAT = 1
        SKIP_FIRST = 1
        TOTAL_STEP = (WAIT + WARMUP + ACTIVE + SKIP_FIRST) * REPEAT
        with create_profiler(autotune_path, WAIT, WARMUP, ACTIVE, REPEAT, SKIP_FIRST) as prof:
            stream.synchronize()
            for _ in range(TOTAL_STEP):
                for fn in tilling_kernel_list:
                    fn()
                    # NOTE(review): synchronising after every kernel also
                    # guarantees a sync before each prof.step().
                    torch.npu.synchronize()
                prof.step()
            stream.synchronize()

        import pandas as pd

        for root, _, files in os.walk(autotune_path):
            for file in files:
                if file != 'kernel_details.csv':
                    continue
                target_file = os.path.join(root, file)
                df = pd.read_csv(target_file)
                triton_rows = df[df['Name'].str.startswith('triton', na=False)]
                time_cost = [0] * tiling_length
                valid_tiling_index = 0
                if len(triton_rows) != valid_tiling_length * ACTIVE:
                    # The profile directory is intentionally left on disk
                    # here: the message points the user at it.
                    raise RuntimeError(f"Expected {valid_tiling_length * ACTIVE} rows for triton kernels, but got {len(triton_rows)}. "
                                       f"This may be due to profiling errors. Please check the profiling result at {target_file} for more details.")
                for tiling_index in range(tiling_length):
                    if not self.launchers[tiling_index].runnable:
                        time_cost[tiling_index] = float('inf')
                        continue
                    # Rows are laid out step by step, one row per runnable
                    # launcher within each active step.
                    for active_index in range(ACTIVE):
                        time_cost[tiling_index] += triton_rows.iloc[valid_tiling_index + valid_tiling_length * active_index]['Duration(us)']
                    valid_tiling_index += 1
                time_cost = [total / ACTIVE for total in time_cost]
                delete_file(autotune_path)
                return time_cost
        delete_file(autotune_path)
        return []

    def try_do_benchmark_using_mspti(tilling_kernel_list):
        """Prefer mspti timing; fall back to the profiler-based batch run."""
        try:
            timinglist = mspti_batch_benchmark(tilling_kernel_list, filter_list=["triton"])
        except Exception as e:
            # Best-effort path: record why it failed, then fall back.
            log.debug("mspti batch benchmark failed, falling back to profiler: %s", e)
            timinglist = []
        if len(timinglist) != len(self.launchers):
            timinglist = do_batch_benchmark(tilling_kernel_list)
        return timinglist

    try:
        timinglist = try_do_benchmark_using_mspti(tilling_kernel_list)
        if not len(timinglist) == len(self.launchers):
            raise RuntimeError("not len(timinglist) == len(self.launchers)")
        timings = dict(zip(self.launchers, timinglist))
    except Exception as e:
        print("some cases in batch benchmark has error! Logging Exception as:")
        print(e)
        print("switched to single bench...")
        timings = {
            launcher: self.bench(launcher, *args, **kwargs) if launcher.runnable else float("inf")
            for launcher in self.launchers
        }

    for k, v in timings.items():
        self.coordesc_tuner.cache_benchmark_result(k.config, v)

    if log.isEnabledFor(logging.DEBUG):
        for k, v in sorted(timings.items(), key=lambda x: x[1]):
            log.debug(
                "[%s] [%s] benchmark time: [%f]",
                self.fn.__name__,
                k.config,
                v,
            )
    return timings
def precompile_parallel(
    self,
    warm_cache_only=False,
    reload_kernel: Optional[Callable[[], CachingAutotuner]] = None,
):
    """Precompile this kernel's configs and build its launchers.

    Does nothing when ``self.skip_precompile`` exists and is truthy.
    Otherwise the cost-model inputs are resolved and applied to the configs,
    and then one of three paths runs:

    * ``warm_cache_only``: record ``self.kernel_name``, run
      ``_precompile_worker_parallel`` and log the elapsed time; no launchers
      are built.
    * ``self.compile_results`` is non-empty: register each result with
      ``TritonBundler.put`` and build launchers from them.
    * Otherwise: run ``_precompile_worker_parallel``, build launchers and
      log the elapsed time.

    ``reload_kernel`` is accepted but not read by this function.
    """
    start_time = time.perf_counter()
    if getattr(self, "skip_precompile", False):
        return

    costmodel_args, costmodel_kwargs = self._resolve_costmodel_runtime_inputs()
    self._apply_costmodel_to_configs(*costmodel_args, **costmodel_kwargs)

    if warm_cache_only:
        self.kernel_name = self.get_fn_name()
        self._precompile_worker_parallel()
        log.info(f"kernel: {self.get_fn_name()} precompile elapsed time: {time.perf_counter() - start_time}s")
        return

    if not self.compile_results:
        self._precompile_worker_parallel()
        self._make_launchers()
        log.info(f"kernel: {self.get_fn_name()} precompile elapsed time: {time.perf_counter() - start_time}s")
        return

    # Results already exist: register them with the bundler and reuse them.
    for compiled in self.compile_results:
        bundle_key = triton_hash_to_path_key(compiled.kernel.hash)
        TritonBundler.put(bundle_key, self.triton_meta.get("device", 0))
    self._make_launchers()