import builtins
import contextlib
import dataclasses
import functools
import inspect
import itertools
import json
import logging
import math
import operator
import os
import sys
import textwrap
import time
from collections import namedtuple
from concurrent.futures import as_completed, ThreadPoolExecutor, Future
from io import StringIO
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
from unittest.mock import patch
import sympy
from filelock import FileLock
import torch
import torch._inductor.async_compile
from torch._dynamo.utils import counters, dynamo_timed, identity, preserve_rng_state
from torch._inductor import config, ir
from torch._inductor.ir import ChoiceCaller
from torch._inductor.utils import restore_stdout_stderr, sympy_product, unique, Placeholder
from torch._inductor.virtualized import V
from torch._inductor.codegen.triton import (
texpr,
TritonScheduling,
gen_common_triton_imports,
)
from torch._inductor.codecache import PyCodeCache
from torch._inductor.autotune_process import (
TensorMeta,
TritonBenchmarkRequest,
TritonCPUBenchmarkRequest,
TritonGPUBenchmarkRequest,
)
from torch._inductor.select_algorithm import (
TritonTemplate,
TritonTemplateKernel,
VERIFY,
DEBUG,
get_mm_log_filename,
append_to_log,
get_num_workers,
NoValidChoicesError,
create_inputs_key,
create_precompile_key,
ExternKernelCaller,
TritonTemplateCaller,
AutotuneArgs,
)
from torch._inductor.codegen.common import IndentedBuffer
from torch._inductor.exc import CppCompileError
from torch.utils._ordered_set import OrderedSet
from ..profiler import tensorboard_trace_handler
log = logging.getLogger("torch._inductor")
class NPUCompileError(CppCompileError):
pass
class NPUTritonTemplate(TritonTemplate):
"""NPU-specific Triton template for kernel generation.
This class extends TritonTemplate to provide NPU-specific optimizations
and configurations for Triton kernel generation.
"""
index_counter = itertools.count()
def __init__(self, name: str, grid: Any, source: str, debug: bool = False) -> None:
"""Initialize NPU Triton template.
Args:
name: Template name for identification
grid: Grid function for kernel launch configuration
source: Triton kernel source code
debug: Enable debug mode for verbose output
"""
super().__init__(name, grid, source, debug)
def generate(
self,
input_nodes: list[ir.IRNode],
layout: ir.Layout,
num_stages: int,
num_warps: int,
prefix_args: int = 0,
suffix_args: int = 0,
epilogue_fn: Callable = identity,
subgraphs: Optional[list[ir.ComputedBuffer]] = None,
mutated_inputs: Optional[list[ir.IRNode]] = None,
call_sizes: Optional[list[sympy.Expr]] = None,
workspace_arg: Optional[Any] = None,
**kwargs: Any,
) -> Optional[ir.ChoiceCaller]:
defines = StringIO()
kwargs["ALLOW_TF32"] = "False"
for name, val in kwargs.items():
defines.write(f"{name} : tl.constexpr = {val}\n")
defines = defines.getvalue()
fake_out = ir.Buffer(name="buf_out", layout=layout)
kernel_name = f"triton_{self.name}"
numel = sympy_product(layout.size)
buffers = itertools.chain(input_nodes, (fake_out,))
if not TritonScheduling.can_use_32bit_indexing(numel, buffers):
raise NotImplementedError(
"64-bit indexing is not yet implemented for triton templates"
)
if call_sizes is None:
call_sizes = layout.size
kernel_options = {
"input_nodes": input_nodes,
"defines": defines,
"num_stages": num_stages,
"num_warps": num_warps,
"grid_fn": self.grid,
"meta": kwargs,
"call_sizes": call_sizes,
"prefix_args": prefix_args,
"suffix_args": suffix_args,
"epilogue_fn": epilogue_fn,
"subgraphs": subgraphs,
}
with (
patch.object(V.graph, "get_dtype", self._fake_get_dtype(fake_out)),
V.graph.set_current_device(layout.device),
NPUTritonTemplateKernel(
kernel_name=kernel_name,
output_node=fake_out,
workspace_arg=workspace_arg,
use_jit=False,
**kernel_options,
) as kernel,
):
try:
template = kernel.render(self.template, kwargs)
with kernel.set_subgraph_body("<STORE_OUTPUT>"):
code = template.finalize_all()
except ZeroDivisionError:
log.debug(
"ZeroDivisionError during kernel rendering for %s, "
"returning None to skip this configuration",
kernel_name,
)
return None
if self.debug:
log.debug("Generated Code:\n", code)
extra = (
"-".join(
[
*[
f"{kwarg}={repr(kwargs[kwarg])}"
for kwarg in sorted(kwargs.keys())
],
f"num_stages={num_stages}",
f"num_warps={num_warps}",
]
)
+ "-"
)
mod = PyCodeCache.load(code, extra)
input_call_args = tuple(kernel.args.input_buffers.keys())
expected_input_args = tuple(unique(x.get_name() for x in input_nodes))
assert input_call_args[: len(expected_input_args)] == expected_input_args, (
input_call_args,
expected_input_args,
)
full_input_nodes = tuple([V.graph.get_buffer(k) for k in input_call_args])
extra_args = V.graph.sizevars.size_hints(
map(sympy.expand, tuple(kernel.args.sizevars.keys())),
fallback=config.unbacked_symint_fallback,
)
kernel_hash_name = f"triton_{self.name}_{next(self.index_counter)}"
def make_kernel_render(out_node):
kernel = NPUTritonTemplateKernel(
kernel_name=str(Placeholder.KERNEL_NAME),
output_node=out_node,
workspace_arg=workspace_arg,
use_jit=False,
**kernel_options,
)
render = functools.partial(
kernel.render,
self.template,
kwargs,
)
return kernel, render
assert mod.__file__ is not None
grid = self.grid(
*V.graph.sizevars.size_hints(
call_sizes,
fallback=config.unbacked_symint_fallback,
),
kwargs,
)
bmreq_cls: type[TritonBenchmarkRequest]
if layout.device.type == "cpu":
bmreq_cls = TritonCPUBenchmarkRequest
else:
bmreq_cls = TritonGPUBenchmarkRequest
bmreq = bmreq_cls(
module_path=mod.__file__,
module_cache_key=mod.key,
kernel_name=kernel_name,
extra_args=[*extra_args, *grid],
num_stages=num_stages,
num_warps=num_warps,
matrix_instr_nonkdim=kwargs.get("matrix_instr_nonkdim", 0),
waves_per_eu=kwargs.get("waves_per_eu", 0),
kpack=kwargs.get("kpack", 2),
input_tensor_meta=TensorMeta.from_irnodes(full_input_nodes),
output_tensor_meta=TensorMeta.from_irnodes(layout),
workspace_arg=workspace_arg,
)
return TritonTemplateCaller(
kernel_hash_name,
full_input_nodes,
layout,
make_kernel_render,
extra.strip("-").replace("-", ", "),
bmreq,
log_info={
"tile_shape": str(
(
kwargs.get("BLOCK_M", -1),
kwargs.get("BLOCK_K", -1),
kwargs.get("BLOCK_N", -1),
)
),
"num_stages": num_stages,
"num_warps": num_warps,
"allow_tf32": str(kwargs.get("ALLOW_TF32", None)),
"acc_type": str(kwargs.get("ACC_TYPE", None)),
},
mutated_inputs=mutated_inputs,
workspace_arg=workspace_arg,
allowed_prologue_inps=kernel.prologue_supported_inputs.copy(),
)
class NPUTritonTemplateKernel(TritonTemplateKernel):
"""NPU-specific Triton template kernel for code generation.
This class extends TritonTemplateKernel to provide NPU-specific
kernel generation and compilation functionality.
"""
def __init__(
self,
kernel_name: str,
input_nodes: list[ir.IRNode],
output_node: ir.IRNode,
defines: str,
num_stages: int,
num_warps: int,
grid_fn: Callable,
meta: dict[str, Any],
call_sizes: list[sympy.Expr],
use_jit: bool = False,
prefix_args: int = 0,
suffix_args: int = 0,
epilogue_fn: Callable = identity,
subgraphs: Optional[list[ir.ComputedBuffer]] = None,
workspace_arg: Optional[Any] = None,
) -> None:
"""Initialize NPU Triton template kernel.
Args:
kernel_name: Name of the kernel
input_nodes: List of input IR nodes
output_node: Output IR node
defines: Kernel defines string
num_stages: Number of pipeline stages
num_warps: Number of warps
grid_fn: Grid function for launch configuration
meta: Metadata dictionary
call_sizes: Call sizes for grid computation
use_jit: Whether to use JIT compilation
prefix_args: Number of prefix arguments
suffix_args: Number of suffix arguments
epilogue_fn: Epilogue function
subgraphs: List of subgraph buffers
workspace_arg: Workspace argument
"""
super().__init__(
kernel_name,
input_nodes,
output_node,
defines,
num_stages,
num_warps,
grid_fn,
meta,
call_sizes,
use_jit,
prefix_args,
suffix_args,
epilogue_fn,
subgraphs,
workspace_arg,
)
def def_kernel(self, *argnames: str) -> str:
"""Hook called from template code to generate function def and needed args.
Args:
*argnames: Variable number of argument names
Returns:
Render hook key string
"""
assert all(isinstance(x, str) for x in argnames)
renames = IndentedBuffer(initial_indent=1)
named_args = self.input_nodes[
self.prefix_args : len(self.input_nodes) - self.suffix_args
]
assert len(argnames) == len(named_args), (
len(argnames),
len(named_args),
self.prefix_args,
len(self.input_nodes),
)
for idx, input_node in enumerate(self.input_nodes):
node_name = input_node.get_name()
if node_name in V.graph.removed_buffers:
continue
if node_name in self.prologue_fused_inputs:
continue
if idx < self.prefix_args:
self.args.input(node_name)
elif idx < len(self.input_nodes) - self.suffix_args:
name = argnames[idx - self.prefix_args]
arg_name = f"arg_{name}"
self.named_input_nodes[name] = input_node
self.args.input_buffers[node_name] = arg_name
else:
self.args.input(node_name)
for name in argnames:
input_node = self.named_input_nodes[name]
if input_node.get_name() in V.graph.removed_buffers:
continue
if input_node.get_name() in self.prologue_fused_inputs:
continue
arg_name = self.args.input_buffers[input_node.get_name()]
if input_node.get_layout().offset == 0:
renames.writeline(f"{name} = {arg_name}")
else:
offset = texpr(self.rename_indexing(input_node.get_layout().offset))
renames.writeline(f"{name} = {arg_name} + {offset}")
def hook():
arg_defs, *_ = self.args.python_argdefs()
code = IndentedBuffer()
code.splice(gen_common_triton_imports())
code.splice(self.jit_lines())
code.writeline(
f"def {self.kernel_name}({', '.join(x.full_name() for x in arg_defs)}):"
)
with code.indent():
code.splice(self.defines)
code.splice(renames.getvalue())
return code.getvalue()
assert "<DEF_KERNEL>" not in self.render_hooks
self.render_hooks["<DEF_KERNEL>"] = hook
return "<DEF_KERNEL>"
def patch_algorithm_selector() -> None:
"""Patch AlgorithmSelectorCache with NPU-specific implementations.
This function replaces the default AlgorithmSelectorCache methods with
NPU-optimized versions that include profiling and benchmarking capabilities
specific to NPU hardware.
"""
def __call__(
self,
name: str,
choices: List[ChoiceCaller],
input_nodes: list[ir.IRNode],
layout: ir.Layout,
input_gen_fns: Optional[Dict[int, Callable[[ir.Buffer], torch.Tensor]]] = None,
precompilation_timeout_seconds: int = 60 * 60,
return_multi_template: bool = False,
) -> Any:
from .codegen.catlass.catlass_kernel import CATLASSTemplateCaller
if input_gen_fns is not None or layout.device.type == "cpu":
return_multi_template = False
choices = [choice for choice in choices if choice is not None]
if mm_file_name := get_mm_log_filename():
M, K = input_nodes[-2].get_size()[:2]
N = input_nodes[-1].get_size()[-1]
append_to_log(mm_file_name, {"invoke": str((M, K, N))})
if len(choices) == 0:
backend_config = (
"max_autotune_gemm_backends"
if name != "convolution"
else "max_autotune_conv_backends"
)
raise NoValidChoicesError(
f"No choices to select, please consider adding ATEN into {backend_config} "
"config (defined in torch/_inductor/config.py) to allow at least one choice. "
)
log.debug("Max autotune selects from %s choices.", str(len(choices)))
if len(choices) == 1:
if not isinstance(choices[0], CATLASSTemplateCaller):
return choices[0].output_node()
@functools.lru_cache(None)
def make_benchmark_fn():
return self.make_benchmark_fn(choices, input_nodes, layout, input_gen_fns)
inputs_key = create_inputs_key(input_nodes)
def precompile(choices) -> Callable[[], None]:
log.debug("Starting precompilation")
def no_op(*args, **kwargs):
return
if (
precompilation_timeout_seconds is None
or precompilation_timeout_seconds <= 0
):
return no_op
num_workers = min(get_num_workers(), len(choices))
if num_workers <= 0:
return no_op
if (
sys.version_info.major == 3
and sys.version_info.minor == 11
and sys.version_info.micro <= 8
):
return no_op
timings = self.lookup(
choices,
name,
inputs_key,
benchmark=None,
)
if timings:
if len(timings) == len(choices):
log.debug("Timings found in cache, returning no_op")
return no_op
if config.search_autotune_cache and not (
config.max_autotune or config.max_autotune_gemm
):
return no_op
precompile_key = create_precompile_key(name, inputs_key, choices)
if precompile_func := self.precompile_cache.get(precompile_key):
return precompile_func
log.info(
"Multithreaded precompilation for %d choices using %d worker threads",
len(choices),
num_workers,
)
def precompile_with_captured_stdout(choice):
log.debug("Precompiling choice with captured stdout: %s", choice)
with restore_stdout_stderr():
choice.precompile()
def on_complete(future):
assert future in start_times
elapsed_times[future] = time.time() - start_times[future]
log.debug(
"Precompilation complete for future: %s, elapsed time: %.02fs",
future,
elapsed_times[future],
)
executor = ThreadPoolExecutor(max_workers=num_workers)
async_compile = torch._inductor.async_compile.AsyncCompile()
futures: dict[Future[Any], ChoiceCaller] = {}
start_times: dict[Future[Any], float] = {}
elapsed_times: dict[Future[Any], float] = {}
seen_choices: OrderedSet[ChoiceCaller] = OrderedSet()
for c in choices:
if c.hash_key() in seen_choices:
log.debug("Skipping already seen choice: %s", c)
continue
else:
seen_choices.add(c.hash_key())
if hasattr(c, "precompile"):
future = executor.submit(precompile_with_captured_stdout, c)
log.debug("Submitted precompile for choice: %s", c)
start_times[future] = time.time()
future.add_done_callback(on_complete)
futures[future] = c
@functools.lru_cache(None)
@restore_stdout_stderr()
def wait_on_futures():
counters["inductor"]["select_algorithm_precompile"] += 1
for future in as_completed(
futures,
timeout=precompilation_timeout_seconds,
):
if e := future.exception():
log.error(
"Exception %s for benchmark choice %s", e, futures[future]
)
else:
counters["inductor"]["select_algorithm_num_precompiles"] += 1
log.info(
"Precompiling benchmark choice %s took %.02fs",
futures[future],
elapsed_times[future],
)
executor.shutdown(wait=True)
self.precompile_cache[precompile_key] = wait_on_futures
return wait_on_futures
def autotune(choices):
log.debug("Starting autotuning")
with dynamo_timed(
f"{name}_template_autotuning",
log_pt2_compile_event=True,
dynamo_compile_column_us="compile_time_autotune_time_us",
):
return make_benchmark_fn()(choices)
if config.autotune_in_subproc:
from torch._inductor.autotune_process import tuning_pool
tuning_pool.initialize()
def do_autotuning(precompile_fn):
precompile_start_ts = time.time()
with dynamo_timed(
f"{name}_template_precompiling",
log_pt2_compile_event=True,
dynamo_compile_column_us="compile_time_autotune_time_us",
):
precompile_fn()
precompile_elapse = time.time() - precompile_start_ts
autotune_start_ts = time.time()
timings = self.lookup(
choices,
name,
inputs_key,
autotune,
)
autotune_elapse = time.time() - autotune_start_ts
log.debug("Autotuning elapsed time: %.02fs", autotune_elapse)
if timings and all(
not math.isfinite(timing) for timing in timings.values()
):
raise NoValidChoicesError
if make_benchmark_fn.cache_info().currsize:
counters["inductor"]["select_algorithm_autotune"] += 1
if (
make_benchmark_fn.cache_info().currsize
or log.getEffectiveLevel() == logging.DEBUG
or config.trace.log_autotuning_results
):
self.log_results(
name, input_nodes, timings, autotune_elapse, precompile_elapse
)
for feedback_fn in self.feedback_saver_fns:
feedback_fn(timings, name, input_nodes, choices)
return timings
precompile_fn = precompile(choices)
if return_multi_template and (config.max_autotune or config.max_autotune_gemm):
def get_timings():
timings = do_autotuning(precompile_fn)
min_extern_choice = float("inf")
for choice, timing in timings.items():
if isinstance(choice, ExternKernelCaller):
min_extern_choice = min(min_extern_choice, timing)
timings = {
choice: time
for choice, time in timings.items()
if (
time <= min_extern_choice
or not isinstance(choice, ExternKernelCaller)
)
}
return timings
allowed_prologue_inps: OrderedSet[str] = OrderedSet()
for c in choices:
if isinstance(c, TritonTemplateCaller):
allowed_prologue_inps |= c.allowed_prologue_inps
return torch._inductor.ir.TensorBox.create(
torch._inductor.ir.MultiTemplateBuffer(
layout,
input_nodes,
get_timings,
choices,
allowed_prologue_inps,
)
)
timings = do_autotuning(precompile_fn)
if timings == {} or choices[0] not in timings:
return choices[0].output_node()
selected_key = builtins.min(timings, key=timings.__getitem__)
selected_choice = selected_key.output_node()
log.debug("selected choice: %s", str(selected_choice))
return selected_choice
@classmethod
def make_benchmark_fn(
cls,
choices: List[ChoiceCaller],
input_nodes: list[ir.IRNode],
layout: ir.Layout,
input_gen_fns: Optional[Dict[int, Callable[[ir.Buffer], torch.Tensor]]] = None,
) -> Callable:
"""Create a benchmark function for the given choices.
Args:
choices: List of choice callers to benchmark
input_nodes: List of input IR nodes
layout: Output layout
input_gen_fns: Optional dict mapping arg indices to input generation functions
Returns:
Benchmark function that can be called with choices
"""
if input_gen_fns is None:
input_gen_fns = {}
def get_inputs(
choices: Union[List[ExternKernelCaller], List[TritonTemplateCaller]],
) -> AutotuneArgs:
unique_example_inputs = {
x.get_name(): input_gen_fns.get(i, cls.benchmark_example_value)(x)
for i, x in enumerate(input_nodes)
}
example_inputs = list(unique_example_inputs.values())
example_inputs_extern = [
(
unique_example_inputs[input_node.get_name()]
if unique_example_inputs[input_node.get_name()].is_mkldnn
else torch.as_strided(
unique_example_inputs[input_node.get_name()],
V.graph.sizevars.size_hints(
input_node.get_size(),
fallback=config.unbacked_symint_fallback,
),
V.graph.sizevars.size_hints(
input_node.get_stride(),
fallback=config.unbacked_symint_fallback,
),
V.graph.sizevars.size_hint(
input_node.get_layout().offset,
fallback=config.unbacked_symint_fallback,
),
)
)
for input_node in input_nodes
]
from .codegen.catlass.catlass_kernel import CATLASSTemplateCaller
is_group_mm = False
for choice in choices:
if isinstance(choice, CATLASSTemplateCaller) and "GroupedMatmulSliceMTla" in choice.description:
is_group_mm = True
if not is_group_mm and len(input_nodes) == 3:
example_inputs = example_inputs[1:] + [example_inputs[0]]
out = cls.benchmark_example_value(layout)
out_extern = torch.as_strided(
out, out.size(), out.stride(), V.graph.sizevars.size_hint(layout.offset)
)
expected = None
if VERIFY:
choices[0].benchmark(*example_inputs_extern, out=out_extern)
expected = out_extern.clone()
return AutotuneArgs.from_choice_args(
example_inputs,
example_inputs_extern,
out,
out_extern,
expected,
)
if DEBUG:
log.debug(f"{len(choices)} tuning requests:")
def benchmark_choice_in_current_process(
choice: ChoiceCaller, autotune_args: AutotuneArgs
) -> float:
is_extern = isinstance(choice, ExternKernelCaller)
benchmark_tensors = autotune_args.get_benchmark_tensors(is_extern)
inpts, output = benchmark_tensors.unpack()
output.zero_()
result = choice.benchmark(*inpts, out=output)
if VERIFY and autotune_args.expected is not None:
autotune_args.verify(**VERIFY)
if torch.npu.is_available():
torch.npu.synchronize()
return result
def profiling_choices_in_current_process(
choices: Union[List[ExternKernelCaller], List[TritonTemplateCaller]],
) -> Dict[Union[ExternKernelCaller, TritonTemplateCaller], float]:
inputs = get_inputs(choices)
funcs = []
for choice in choices:
is_extern = isinstance(choice, ExternKernelCaller)
benchmark_tensors = inputs.get_benchmark_tensors(is_extern)
inpts, output = benchmark_tensors.unpack()
output.zero_()
if is_extern:
algo = choice.to_callable()
fn = algo
args = tuple(inpts)
kwargs = {"out": output}
else:
fn = choice.bmreq.make_run_fn(*inpts, output_tensor=output)
args = ()
kwargs = {}
funcs.append((fn, args, kwargs))
func_times = do_batch_profiling(funcs)
return {choice: func_times[i] for i, choice in enumerate(choices)}
def do_batch_profiling(
funcs: List[Tuple[Callable, Tuple, Dict]], key: Optional[str] = None
) -> List[Optional[float]]:
import torch_npu
import shutil
import uuid
import hashlib
def delete_file(base_path):
if os.path.exists(base_path):
shutil.rmtree(base_path)
experimental_config = torch_npu.profiler._ExperimentalConfig(
aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
profiler_level=torch_npu.profiler.ProfilerLevel.Level1,
l2_cache=False,
data_simplification=False,
)
random_uuid = uuid.uuid4().hex
md5_hash = hashlib.md5(random_uuid.encode()).hexdigest()
num_funcs = len(funcs)
torch_path = os.path.join(os.getcwd(), "profile_results", md5_hash)
TOTAL_STEP = 50
l2_cache_size = 192 * (1 << 20)
buffer = torch.empty(l2_cache_size // 4, dtype=torch.int, device="npu")
buffer.zero_()
torch.npu.synchronize()
with torch_npu.profiler.profile(
activities=[torch_npu.profiler.ProfilerActivity.NPU],
on_trace_ready=tensorboard_trace_handler(torch_path),
record_shapes=False,
profile_memory=False,
with_stack=False,
with_flops=False,
with_modules=False,
experimental_config=experimental_config,
):
for fn, args, kwargs in funcs:
for _ in range(TOTAL_STEP):
buffer.zero_()
fn(*args, **kwargs)
torch.npu.synchronize()
buffer.abs_()
torch.npu.synchronize()
del buffer
import pandas as pd
for root, _, files in os.walk(torch_path):
for file in files:
if file != "kernel_details.csv":
continue
target_file = os.path.join(root, file)
df = pd.read_csv(target_file)
filter_cond = ~df["Name"].str.contains(r"zero|ZerosLike", case=False, na=False)
filter_df = df[filter_cond]
if key is not None:
key_rows = filter_df[filter_df["Name"].str.contains(key, na=False)]
else:
key_rows = filter_df
time_cost = []
last_df_index = -1
for idx, row in key_rows.iterrows():
if "absaicore" in row["Name"].lower():
time_cost.append(key_rows.loc[last_df_index + 1:idx - 1, 'Duration(us)'].sum())
last_df_index = idx
time_cost = [x / TOTAL_STEP / 1e3 for x in time_cost]
delete_file(torch_path)
return time_cost
delete_file(torch_path)
return []
def benchmark_in_current_process(
choices: Union[List[ExternKernelCaller], List[TritonTemplateCaller]],
) -> Dict[Union[ExternKernelCaller, TritonTemplateCaller], float]:
inputs = get_inputs(choices)
timings = {}
for choice in choices:
try:
timing = benchmark_choice_in_current_process(choice, inputs)
except NPUCompileError as e:
log.error(
"NPU compilation error during autotuning: \n%s. \nIgnoring this choice.",
str(e),
)
timing = float("inf")
except NotImplementedError as e:
log.warning("Not yet implemented: %s", e)
timing = float("inf")
except RuntimeError as e:
msg = str(e)
if "invalid argument" in msg:
msg += "\n\nThis may mean this NPU is too small for max_autotune mode.\n\n"
else:
if "illegal memory access" in msg:
msg += "\n\nEither error in template or triton bug.\n"
log.error(
"Runtime error during autotuning: \n%s. \nIgnoring this choice.",
msg,
)
timing = float("inf")
except AssertionError as e:
raise AssertionError(
f"Incorrect result from choice {choice}\n"
) from e
except Exception as e:
try:
from triton.runtime.autotuner import OutOfResources
if isinstance(e, OutOfResources):
log.warning(e)
timing = float("inf")
else:
raise e
except ImportError:
raise e from None
timings[choice] = timing
return timings
def benchmark_in_sub_process(
choices: Union[List[ExternKernelCaller], List[TritonTemplateCaller]],
):
from torch._inductor import autotune_process
from .codegen.catlass.catlass_kernel import CATLASSTemplateCaller
extern = [
c
for c in choices
if isinstance(c, ExternKernelCaller) or isinstance(c, CATLASSTemplateCaller)
]
triton = [c for c in choices if c not in extern]
timings = benchmark_in_current_process(extern)
timings.update(autotune_process.benchmark_in_sub_process(triton))
return timings
from .config import catlass as catlass_config
if catlass_config.catlass_bench_use_profiling:
benchmark = profiling_choices_in_current_process
else:
benchmark = (
benchmark_in_sub_process
if config.autotune_in_subproc
else benchmark_in_current_process
)
return benchmark
from torch._inductor.select_algorithm import AlgorithmSelectorCache
AlgorithmSelectorCache.__call__ = __call__
AlgorithmSelectorCache.make_benchmark_fn = make_benchmark_fn