import torch
import triton
import triton.language as tl
import pytest
def test_decorator_with_def(device):
def triton_heuristics_pointwise(**kwargs):
def decorator(func):
return func
return decorator
@triton_heuristics_pointwise(inductor_meta={'backend_hash': 'def0aeffabe53b3f8'}, )
@triton.jit
def kernel():
pass
try:
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
except Exception as e:
pytest.fail(f"triton compile failed with error: {e}")
def test_triton_heuristic(device):
N = 1023
src = torch.empty(N, device=device)
dst = torch.zeros(N, device=device)
do_bench = lambda kernel, quantiles: triton.testing.do_bench(kernel, quantiles=quantiles, warmup=1, rep=1)
@triton.autotune(configs=[triton.Config(kwargs={'BLOCK_SIZE': 32})], key=['N'], do_bench=do_bench)
@triton.heuristics({'EVEN_N': lambda nargs: nargs['N'] % 2 == 0})
@triton.heuristics({'EVEN_src': lambda nargs: nargs['src'].data_ptr() % 2 == 0})
@triton.jit
def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr, EVEN_N: tl.constexpr, EVEN_src: tl.constexpr):
tl.store(dst, EVEN_N)
tl.store(dst + 1, EVEN_src)
grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), )
_kernel[grid](dst, src, N=N)
assert dst[0].item() == 0.0
assert dst[1].item() == 1.0
assert _kernel.base_fn.__name__ == "_kernel"