import importlib.util
import itertools
import os
import shutil
import pathlib
from concurrent.futures import Executor, Future, ThreadPoolExecutor
import pytest
import torch
import triton
import triton.language as tl
from triton._internal_testing import is_hip
# Helper JIT function: hands back its argument incremented by one.
# function_1 binds it to FN in the `else` arm of its constexpr branch.
@triton.jit
def function_0(i):
    return i + 1
# Mid-level helper called by the kernels below: increments `i`, then forwards
# it to whichever helper the constexpr branch binds to FN (function_2 while
# `cond` is True, function_0 otherwise).
@triton.jit
def function_1(i):
    i = i + 1
    cond: tl.constexpr = True
    if cond:
        FN: tl.constexpr = function_2
    else:
        FN: tl.constexpr = function_0
    return FN(i)
# Helper JIT function bound to FN by function_1 while its `cond` is True;
# increments the argument and returns it.
@triton.jit
def function_2(i):
    i = i + 1
    return i
# Combine function handed to REDUCE_OR_SCAN in kernel_with_combine_fn.
# COMBINE_OP is a textual placeholder: test_combine_fn_change swaps it for a
# concrete expression through `_unsafe_update_src` before reading a cache key.
@triton.jit
def combine_fn(a, b):
    return COMBINE_OP
# Entry-point kernel used by the cache-key and specialization tests: bumps `i`,
# routes it through function_1 and stores the result to X.
# BLOCK is accepted as a constexpr but is not referenced in the body.
@triton.jit
def kernel(X, i, BLOCK: tl.constexpr):
    i = i + 1
    i = function_1(i)
    tl.store(X, i)
# Same body as `kernel`, but declared with do_not_specialize=["i"].
# test_specialize expects exactly one compilation for this variant across all
# the values of `i` it launches.
@triton.jit(do_not_specialize=["i"])
def kernel_nospec(X, i, BLOCK: tl.constexpr):
    i = i + 1
    i = function_1(i)
    tl.store(X, i)
# Same body as `kernel`, but declared with do_not_specialize_on_alignment=["i"].
# test_specialize expects two compilations for this variant across the values
# of `i` it launches.
@triton.jit(do_not_specialize_on_alignment=["i"])
def kernel_nospec_on_alignment(X, i, BLOCK: tl.constexpr):
    i = i + 1
    i = function_1(i)
    tl.store(X, i)
# Template kernel for test_combine_fn_change. REDUCE_OR_SCAN is a textual
# placeholder that the test replaces with "tl.reduce" or "tl.associative_scan"
# through `_unsafe_update_src`; the chosen primitive is applied to
# arange(0, BLOCK) along axis 0 with combine_fn and the result is stored to X.
@triton.jit
def kernel_with_combine_fn(X, BLOCK: tl.constexpr):
    i = tl.arange(0, BLOCK)
    i = REDUCE_OR_SCAN(i, 0, combine_fn)
    tl.store(X, i)
def apply_src_change(target, old, new, to_modify):
    """Return ``target.cache_key`` as observed while ``to_modify`` is patched.

    Every occurrence of ``old`` in ``to_modify.src`` is replaced by ``new``,
    the cache key of ``target`` is read, and the exact original source text is
    put back afterwards, also when reading the key raises.

    Args:
        target: JIT function whose ``cache_key`` is read.
        old: Substring of ``to_modify.src`` to replace.
        new: Replacement text.
        to_modify: JIT function whose source is patched for the duration of
            the call.

    Returns:
        The cache key of ``target`` computed against the patched source.
    """
    call_chain = (kernel, function_0, function_1, function_2)

    def reset_hashes():
        # Clear the `hash` attribute of every function in the call chain so
        # the next `cache_key` read is derived from the current sources.
        for fn in call_chain:
            fn.hash = None

    original_src = to_modify.src
    reset_hashes()
    to_modify._unsafe_update_src(original_src.replace(old, new))
    try:
        return target.cache_key
    finally:
        # Put back the saved text rather than doing a reverse replace: the
        # latter would also rewrite any `new` snippet that was already present
        # in the original source.
        to_modify._unsafe_update_src(original_src)
        # Whatever hashes were computed above describe the patched source;
        # drop them so later readers do not pick up stale values.
        reset_hashes()
def test_nochange():
    """Replacing a snippet of function_1 with itself must keep kernel's cache key."""
    key_before = kernel.cache_key
    key_after = apply_src_change(kernel, 'i + 1', 'i + 1', function_1)
    assert key_before == key_after
def test_toplevel_change():
    """Editing function_1, which kernel calls directly, must change kernel's cache key."""
    key_before = kernel.cache_key
    key_after = apply_src_change(kernel, 'i + 1', 'i + 2', function_1)
    assert key_before != key_after
def test_nested1_change():
    """Editing function_2, reached through function_1, must change kernel's cache key."""
    key_before = kernel.cache_key
    key_after = apply_src_change(kernel, 'i + 1', 'i + 2', function_2)
    assert key_before != key_after
def test_nested2_change():
    """Editing function_0, referenced from function_1's `else` arm, must change kernel's cache key."""
    key_before = kernel.cache_key
    key_after = apply_src_change(kernel, 'i + 1', 'i + 2', function_0)
    assert key_before != key_after
def test_combine_fn_change():
    """Each (primitive, combine expression) pairing must produce a distinct cache key.

    The placeholders in combine_fn and kernel_with_combine_fn are filled in
    textually for every combination; the template sources are restored after
    each key is read, whether or not reading it succeeded.
    """
    combine_template = combine_fn.src
    kernel_template = kernel_with_combine_fn.src
    primitives = ["tl.reduce", "tl.associative_scan"]
    expressions = ["a + b", "a * b"]

    observed_keys = set()
    for primitive, expression in itertools.product(primitives, expressions):
        combine_fn._unsafe_update_src(combine_template.replace("COMBINE_OP", expression))
        kernel_with_combine_fn._unsafe_update_src(kernel_template.replace("REDUCE_OR_SCAN", primitive))
        try:
            cache_key = kernel_with_combine_fn.cache_key
        finally:
            combine_fn._unsafe_update_src(combine_template)
            kernel_with_combine_fn._unsafe_update_src(kernel_template)

        assert cache_key not in observed_keys
        observed_keys.add(cache_key)
# Constexpr helper returning a fixed boolean. test_constexpr_fn_change flips
# the returned literal by patching this function's source text.
@triton.constexpr_function
def constexpr_flag_fn():
    return False
# JIT function that evaluates constexpr_flag_fn() into a constexpr and stores
# it to `out`; its cache key is inspected by test_constexpr_fn_change.
@triton.jit
def constexpr_fn_user(out):
    a: tl.constexpr = constexpr_flag_fn()
    tl.store(out, a)
def test_constexpr_fn_change():
    """Patching constexpr_flag_fn must change constexpr_fn_user's cache key.

    The helper's source is restored in a ``finally`` block so a failure while
    reading the patched key cannot leave the module-level function modified.
    After restoring, the cache key must match the baseline again.
    """
    baseline = constexpr_fn_user.cache_key

    orig_src = constexpr_flag_fn.src
    new_src = orig_src.replace("False", "True")
    constexpr_flag_fn._unsafe_update_src(new_src)
    # The user's hash has to be reset by hand so the next cache_key read takes
    # the patched helper into account.
    constexpr_fn_user.hash = None
    try:
        updated = constexpr_fn_user.cache_key
    finally:
        constexpr_flag_fn._unsafe_update_src(orig_src)
        constexpr_fn_user.hash = None

    assert baseline != updated
    assert constexpr_fn_user.cache_key == baseline
# Constexpr function whose body calls into torch.cuda; test_invalid_constexpr_fn
# expects reading its cache_key to raise RuntimeError.
@triton.constexpr_function
def invalid_constexpr_fn():
    return torch.cuda.get_device_capability()
def test_invalid_constexpr_fn():
    """Reading the cache key of invalid_constexpr_fn must raise RuntimeError."""
    expect_failure = pytest.raises(RuntimeError)
    with expect_failure:
        invalid_constexpr_fn.cache_key
def write_and_load_module(temp_file: pathlib.Path, code, num_extra_lines):
    """Write ``code`` to ``temp_file`` behind filler comment lines and import it.

    Args:
        temp_file: Destination path of the generated module; overwritten.
        code: Python source text placed after the filler lines.
        num_extra_lines: Number of ``# extra line`` comment lines to prepend,
            shifting every line number in ``code`` by that amount.

    Returns:
        The freshly executed module object.
    """
    filler = '# extra line\n' * num_extra_lines
    temp_file.write_text(filler + code)

    module_spec = importlib.util.spec_from_file_location("module.name", str(temp_file))
    loaded_module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(loaded_module)
    return loaded_module
def test_changed_line_numbers_invalidate_cache(tmp_path: pathlib.Path):
    """The same kernel source at a different line offset must get a different cache key.

    The kernel module is written twice, once as-is and once shifted down by a
    single comment line, and the two resulting cache keys are compared.
    """
    from textwrap import dedent

    # The function body must be indented relative to `def`, otherwise loading
    # the generated module fails before any cache key can be computed.
    code = dedent("""
        import triton
        @triton.jit
        def test_kernel(i):
            i = i + 1
    """)

    temp_file0 = tmp_path / "test_changed_line_numbers_invalidate_cache0.py"
    orig_mod = write_and_load_module(temp_file0, code, 0)
    orig_cache_key = orig_mod.test_kernel.cache_key

    temp_file1 = tmp_path / "test_changed_line_numbers_invalidate_cache1.py"
    updated_mod = write_and_load_module(temp_file1, code, 1)
    updated_cache_key = updated_mod.test_kernel.cache_key

    assert orig_cache_key != updated_cache_key
def test_reuse(device, fresh_triton_cache):
    """Ten identical launches of `kernel` must invoke the JIT cache hook exactly once."""
    hook_calls = 0

    def count_hook_call(*args, **kwargs):
        nonlocal hook_calls
        hook_calls += 1

    triton.knobs.runtime.jit_cache_hook = count_hook_call
    out = torch.empty(1, dtype=torch.int32, device=device)
    for _ in range(10):
        kernel[(1, )](out, 1, BLOCK=1024)
    assert hook_calls == 1
@pytest.mark.parametrize('mode', ['enable', 'disable', 'disable_on_alignment'])
def test_specialize(mode, device, fresh_triton_cache):
    """Count JIT cache hook invocations while launching with i = 1, 2, 4, ..., 32.

    Expected counts per mode: 3 for the default kernel, 1 with
    ``do_not_specialize`` and 2 with ``do_not_specialize_on_alignment``.
    """
    hook_calls = 0

    def count_hook_call(*args, **kwargs):
        nonlocal hook_calls
        hook_calls += 1

    triton.knobs.runtime.jit_cache_hook = count_hook_call
    out = torch.empty(1, dtype=torch.int32, device=device)

    # mode -> (kernel variant to launch, expected number of hook invocations)
    cases = {
        'enable': (kernel, 3),
        'disable': (kernel_nospec, 1),
        'disable_on_alignment': (kernel_nospec_on_alignment, 2),
    }
    launcher, expected_calls = cases[mode]
    for value in (1, 2, 4, 8, 16, 32):
        launcher[(1, )](out, value, BLOCK=512)
    assert hook_calls == expected_calls
def test_annotation(device):
    """Launching an int32-annotated kernel with 1, 8, 16 and 17 must leave three cache entries."""

    @triton.jit
    def kernel(X, i: tl.int32):
        tl.store(X, i)

    x = torch.empty(1, dtype=torch.int32, device=device)
    device_index = getattr(torch, device).current_device()

    for value in (1, 8, 16, 17):
        kernel[(1, )](x, value)
    assert len(kernel.device_caches[device_index][0]) == 3
GLOBAL_DEFAULT_ARG = 1


def test_kernel_default_arg(device):
    """Rebinding the global used as a default argument must not affect the kernel.

    The kernel is launched before and after GLOBAL_DEFAULT_ARG is rebound; both
    launches are expected to store 1 and to share a single cache entry. The
    module-level value is restored on exit so the test starts from the same
    state if it is executed again in the same process.
    """
    global GLOBAL_DEFAULT_ARG
    saved_default = GLOBAL_DEFAULT_ARG

    @triton.jit
    def kernel(X, i: tl.constexpr = GLOBAL_DEFAULT_ARG):
        tl.store(X, i)

    try:
        x = torch.empty(1, dtype=torch.int32, device=device)
        kernel[(1, )](x)
        assert x == torch.ones_like(x)

        # Rebinding the global after the kernel was defined is expected to
        # have no effect on the stored value.
        GLOBAL_DEFAULT_ARG = 2
        kernel[(1, )](x)
        assert x == torch.ones_like(x)

        device = getattr(torch, device).current_device()
        assert len(kernel.device_caches[device][0]) == 1
    finally:
        GLOBAL_DEFAULT_ARG = saved_default
GLOBAL_VAR = tl.constexpr(1)


def test_kernel_global_var_change(device):
    """Rebinding a global read by an already-compiled kernel must raise on the next launch.

    The module-level value is restored on exit so the test starts from the
    same state if it is executed again in the same process.
    """
    global GLOBAL_VAR
    saved_value = GLOBAL_VAR

    @triton.jit
    def kernel(X):
        tl.store(X, GLOBAL_VAR)

    try:
        x = torch.empty(1, dtype=torch.int32, device=device)
        kernel[(1, )](x)
        assert x == torch.ones_like(x)

        GLOBAL_VAR = 2
        with pytest.raises(RuntimeError) as e:
            kernel[(1, )](x)
        assert "global variable" in str(e.value).lower()
    finally:
        GLOBAL_VAR = saved_value
GLOBAL = 42


def test_local_shadows_global():
    """A kernel-local named GLOBAL must hide the module-level variable.

    The kernel is launched once for each module-level value (42, then 43); the
    test passes as long as neither launch raises.
    """
    global GLOBAL

    @triton.jit
    def kernel():
        _, GLOBAL = 0, 0
        a = GLOBAL

    for module_value in (42, 43):
        GLOBAL = module_value
        kernel[(1, )]()
CONSTEXPR_GLOBAL = tl.constexpr(42)


def test_local_does_not_shadow_global():
    """A global read *before* a same-named local is assigned is still tracked.

    The kernel reads CONSTEXPR_GLOBAL first and only then assigns a local of
    that name, so rebinding the module-level value between launches is
    expected to raise RuntimeError.
    """
    global CONSTEXPR_GLOBAL

    @triton.jit
    def kernel():
        a = CONSTEXPR_GLOBAL
        _, CONSTEXPR_GLOBAL = 0, 0

    grid = (1, )

    CONSTEXPR_GLOBAL = tl.constexpr(42)
    kernel[grid]()

    CONSTEXPR_GLOBAL = tl.constexpr(43)
    with pytest.raises(RuntimeError):
        kernel[grid]()
CONFLICTING_GLOBAL = tl.constexpr(0)


# Shared callee for both kernels in the test below; reads CONFLICTING_GLOBAL.
@triton.jit
def conflicting_global_inner():
    a = CONFLICTING_GLOBAL


def test_conflicting_global_in_inner_function():
    """Two kernels sharing a callee must not disagree on a global's value.

    kernel1 is launched with the original value, the global is rebound, and
    launching kernel2 (which calls the same inner function) is expected to
    raise. The module-level value is restored on exit so the test starts from
    the same state if it is executed again in the same process.
    """
    global CONFLICTING_GLOBAL
    saved_value = CONFLICTING_GLOBAL

    @triton.jit
    def kernel1():
        a = CONFLICTING_GLOBAL
        conflicting_global_inner()

    @triton.jit
    def kernel2():
        a = CONFLICTING_GLOBAL
        conflicting_global_inner()

    try:
        kernel1[(1, )]()

        CONFLICTING_GLOBAL = 1
        with pytest.raises(RuntimeError) as e:
            kernel2[(1, )]()
        assert "Global variable CONFLICTING_GLOBAL has value" in str(e.value)
    finally:
        CONFLICTING_GLOBAL = saved_value
def test_use_builtin():
    """A kernel that calls the builtin `float` must launch twice without raising."""

    @triton.jit
    def kernel():
        a = float(0)

    for _ in range(2):
        kernel[(1, )]()
def test_no_cache_module_as_global():
    """Referencing the `tl` module from a kernel must leave `used_global_vals` empty."""

    @triton.jit
    def kernel():
        tl.arange(0, 16)

    kernel[(1, )]()
    tracked_globals = kernel.used_global_vals
    assert not tracked_globals
BUILTIN_AS_GLOBAL = tl.int32


def test_cache_builtin_as_global():
    """Rebinding a global that aliases a Triton dtype must raise on the next launch.

    The module-level value is restored on exit so the test starts from the
    same state if it is executed again in the same process.
    """
    global BUILTIN_AS_GLOBAL
    saved_value = BUILTIN_AS_GLOBAL

    @triton.jit
    def kernel():
        x = BUILTIN_AS_GLOBAL

    try:
        kernel[(1, )]()

        BUILTIN_AS_GLOBAL = tl.int64
        with pytest.raises(RuntimeError) as e:
            kernel[(1, )]()
        assert "global variable" in str(e.value).lower()
    finally:
        BUILTIN_AS_GLOBAL = saved_value
def test_cache_closure():
    """Mutating a constexpr captured by a JIT closure must raise on the next launch."""

    def make_closure(cst):

        @triton.jit
        def closure():
            tl.full((16, ), cst, dtype=tl.int32)

        return closure

    captured = tl.constexpr(42)
    jit_closure = make_closure(captured)
    jit_closure[(1, )]()

    # Change the captured constant in place after the first compilation.
    captured.value = 43
    with pytest.raises(RuntimeError) as excinfo:
        jit_closure[(1, )]()

    expected_message = "cst has changed since we compiled this kernel, from constexpr[42] to constexpr[43]"
    assert expected_message in str(excinfo.value)
# Empty JIT callee used by test_no_cache_callable.
@triton.jit
def no_cache_callable_inner():
    pass


def test_no_cache_callable():
    """Calling another JIT function from a kernel must leave `used_global_vals` empty."""

    @triton.jit
    def kernel():
        no_cache_callable_inner()

    kernel[(1, )]()
    tracked_globals = kernel.used_global_vals
    assert not tracked_globals
def test_constexpr_cache_invalidation_recreated(device):
    """A kernel re-created around a fresh constexpr must store that constexpr's value.

    The helper builds a new constexpr and a new kernel on every call; it is run
    twice with 123 and twice with 1234, and each run must read back the value
    it was given.
    """

    def test_run(val):
        VAL = tl.constexpr(val)

        @triton.jit
        def kernel(out):
            tl.store(out, VAL)

        out = torch.zeros(1, device=device)
        kernel[(1, )](out)
        return out.item()

    for expected in (123, 123, 1234, 1234):
        assert test_run(expected) == expected
def test_jit_warmup_cache(device) -> None:
    """Warming up with dtypes and then with matching tensors must keep one cache entry."""

    @triton.jit
    def kernel_add(a, b, o, N: tl.constexpr):
        idx = tl.arange(0, N)
        tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx))

    tensors = [torch.randn(32, dtype=torch.float32, device=device) for _ in range(3)]
    args = [*tensors, 32]
    device_index = getattr(torch, device).current_device()

    def cache_size():
        # Looked up afresh on every call instead of holding on to the dict.
        return len(kernel_add.device_caches[device_index][0])

    assert cache_size() == 0
    kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
    assert cache_size() == 1

    # Two further warmups with concrete tensors are expected to reuse the entry.
    for _ in range(2):
        kernel_add.warmup(*args, grid=(1, ))
        assert cache_size() == 1
def test_jit_debug(device) -> None:
    """Launching with debug=False and debug=True must yield two entries with different TTIR."""

    @triton.jit
    def kernel(tmp):
        tl.device_assert(tl.load(tmp) == 1, "tmp == 1")

    device_index = getattr(torch, device).current_device()
    tmp = torch.tensor([1], dtype=torch.int32, device=device_index)

    def cache_size():
        return len(kernel.device_caches[device_index][0])

    assert cache_size() == 0
    for expected_entries, debug_flag in enumerate((False, True), start=1):
        kernel[(1, )](tmp, debug=debug_flag)
        assert cache_size() == expected_entries

    first_bin, second_bin = kernel.device_caches[device_index][0].values()
    assert first_bin.asm['ttir'] != second_bin.asm['ttir']
# Element-wise add over N elements: o[idx] = a[idx] + b[idx].
# Called from the kernel in test_jit_noinline, which also toggles this
# function's `noinline` attribute.
@triton.jit
def add_fn(a, b, o, N: tl.constexpr):
    idx = tl.arange(0, N)
    tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx))
def test_jit_noinline(device) -> None:
    """Setting `noinline` on the callee must change the TTIR of the calling kernel.

    The kernel is warmed up once, the callee is switched to noinline with the
    hashes and the device cache reset by hand, and a second warmup must
    produce different TTIR.
    """

    @triton.jit
    def kernel_add_device(a, b, o, N: tl.constexpr):
        add_fn(a, b, o, N)

    device_index = getattr(torch, device).current_device()
    warmup_args = (torch.float32, torch.float32, torch.float32, 32)

    def cached_bins():
        return list(kernel_add_device.device_caches[device_index][0].values())

    assert len(kernel_add_device.device_caches[device_index][0]) == 0
    kernel_add_device.warmup(*warmup_args, grid=(1, ))
    bins = cached_bins()
    assert len(bins) == 1
    inline_ttir = bins[0].asm['ttir']

    add_fn.noinline = True
    add_fn.hash = None
    kernel_add_device.hash = None
    kernel_add_device.device_caches[device_index][0].clear()

    kernel_add_device.warmup(*warmup_args, grid=(1, ))
    bins = cached_bins()
    assert len(bins) == 1
    noinline_ttir = bins[0].asm['ttir']

    assert inline_ttir != noinline_ttir
def test_memory_leak() -> None:
    """Define a masked copy kernel; the visible body never launches it.

    NOTE(review): despite the name, nothing here measures memory. Only the
    decoration of the kernel is exercised; confirm whether that is intended.
    """

    @triton.jit
    def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
        # The `xnumel` argument is overwritten with a fixed bound of 10.
        xnumel = 10
        xoffset = tl.program_id(0) * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:]
        xmask = xindex < xnumel
        x0 = xindex
        tmp0 = tl.load(in_ptr0 + (x0), xmask)
        tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp0, xmask)
def test_preload(device, fresh_triton_cache) -> None:
    """Preloading from recorded specialization data must reproduce the warmed-up kernel.

    Steps: warm up kernel_add and record its specialization data, wipe the
    cache directory and the device cache, preload from the recorded data,
    check that a further warmup does not invoke the cache hook, and finally
    check that preloading the data into a different kernel raises.
    """

    @triton.jit
    def kernel_add(a, b, o, N: tl.constexpr, type: tl.constexpr):
        idx = tl.arange(0, N)
        tl.device_assert(idx < 32, "idx < 32")
        tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx))

    @triton.jit
    def kernel_sub(a, b, o, N: tl.constexpr, type: tl.constexpr):
        idx = tl.arange(0, N)
        tl.device_assert(idx < 32, "idx < 32")
        tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx))

    device_index = getattr(torch, device).current_device()
    warmup_args = (torch.float32, torch.float32, torch.float32, 32, tl.float32)

    recorded_spec = None

    def record_specialization(*args, **kwargs):
        nonlocal recorded_spec
        recorded_spec = kwargs["compile"]["specialization_data"]

    triton.knobs.runtime.jit_cache_hook = record_specialization
    warmed_up = kernel_add.warmup(*warmup_args, grid=(1, ))
    expected_hash = warmed_up.hash
    assert recorded_spec is not None

    # Remove both the on-disk cache directory and the in-memory device cache.
    shutil.rmtree(fresh_triton_cache)
    kernel_add.device_caches[device_index][0].clear()

    preloaded = kernel_add.preload(recorded_spec)
    assert preloaded.hash == expected_hash
    assert len(kernel_add.device_caches[device_index][0]) == 1

    hook_calls = 0

    def count_hook_call(*args, **kwargs):
        nonlocal hook_calls
        hook_calls += 1

    triton.knobs.runtime.jit_cache_hook = count_hook_call
    final_kernel = kernel_add.warmup(*warmup_args, grid=(1, ))
    assert hook_calls == 0
    assert len(kernel_add.device_caches[device_index][0]) == 1
    assert final_kernel.hash == expected_hash

    # Specialization data recorded for kernel_add must be rejected by kernel_sub.
    with pytest.raises(RuntimeError, match="Specialization data is for"):
        kernel_sub.preload(recorded_spec)
def test_hooks(device, fresh_triton_cache) -> None:
    """The cache hook and post-compile hook must receive consistent compile metadata.

    After one warmup the test checks that both hooks saw the same
    specialization data, that the compile was flagged as a warmup, that the
    reported key is present in the device cache, and that the reported
    function name is the kernel's qualified name.
    """

    @triton.jit
    def kernel_add(a, b, o, N: tl.constexpr, type: tl.constexpr):
        idx = tl.arange(0, N)
        tl.device_assert(idx < 32, "idx < 32")
        tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx))

    specialization_data = None
    is_warmup = False
    key = 0
    name = None

    def cache_hook(*args, **kwargs):
        nonlocal specialization_data, is_warmup, key, name
        compile_info = kwargs["compile"]
        specialization_data = compile_info["specialization_data"]
        is_warmup = compile_info["is_warmup"]
        key = compile_info["key"]
        name = kwargs["fn"].name

    specialization_data_compiled = None

    def compiled_hook(*args, **kwargs):
        nonlocal specialization_data_compiled
        specialization_data_compiled = kwargs["compile"]["specialization_data"]

    runtime_knobs = triton.knobs.runtime
    runtime_knobs.jit_cache_hook = cache_hook
    runtime_knobs.jit_post_compile_hook = compiled_hook

    kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, tl.float32, grid=(1, ))

    assert specialization_data is not None and specialization_data_compiled == specialization_data
    assert is_warmup is True
    device_index = getattr(torch, device).current_device()
    assert key in kernel_add.device_caches[device_index][0]
    assert name == "test_hooks.<locals>.kernel_add"
@pytest.mark.skipif(reason="within_2g is a HIP specific optimization", condition=not is_hip())
def test_within_2gb(device, fresh_triton_cache) -> None:
    """Check which pointer arguments get the 32-bit pointer-range attribute.

    The scenario runs once with AMDGCN_USE_BUFFER_OPS set to "1" (argument 0
    is expected to be tagged) and once with "0" (nothing is expected to be
    tagged). The environment variable is put back exactly as it was found,
    including being removed again if it was not set before the test.
    """
    env_name = "AMDGCN_USE_BUFFER_OPS"
    # None records "was not set", so the variable is not leaked into the
    # environment of later tests.
    saved_env = os.environ.get(env_name)
    try:
        scenarios = (("1", [(0, )]), ("0", []))
        for use_buffer_ops, pointer_range in scenarios:
            os.environ[env_name] = use_buffer_ops

            @triton.jit
            def kernel_add(a):
                tl.load(a)

            pointer_range_32 = None

            def cache_hook(*args, **kwargs):
                nonlocal pointer_range_32
                # Keys of the first config whose attribute list carries the
                # 32-bit pointer-range marker.
                pointer_range_32 = [
                    k for k, v in kwargs["compile"]["configs"][0].items() if ["tt.pointer_range", 32] in v
                ]

            triton.knobs.runtime.jit_cache_hook = cache_hook

            # Warmup from a dtype only.
            kernel_add.warmup(torch.float32, grid=(1, ))
            assert pointer_range_32 == pointer_range

            # int8 tensor with 2**31 elements: the marker must be absent.
            kernel_add[(1, 0)](torch.empty(2**31, dtype=torch.int8, device=device))
            assert len(pointer_range_32) == 0

            # int8 tensor with 2**31 - 1 elements: same expectation as warmup.
            kernel_add[(1, 0)](torch.empty(2**31 - 1, dtype=torch.int8, device=device))
            assert pointer_range_32 == pointer_range
    finally:
        if saved_env is None:
            os.environ.pop(env_name, None)
        else:
            os.environ[env_name] = saved_env
def test_async_compile(device, fresh_triton_cache):
    """Warm up four specializations under AsyncCompileMode, then launch them.

    The kernel stores its constexpr argument `a` to `Y`; each launch is
    followed by a check of element [0, 0] of the output tensor. The final
    launch uses a constexpr value (2) that was not warmed up.
    """

    @triton.jit
    def kernel(Y, a: tl.constexpr):
        tl.store(Y, a)

    with (
            ThreadPoolExecutor(2) as pool,
            triton.AsyncCompileMode(pool),
    ):
        a = torch.empty((16, 16), device=device)
        b = torch.empty((16, 16), dtype=torch.int32, device=device)
        # Two output dtypes times two constexpr values.
        kernel.warmup(a, 0, grid=(1, ))
        kernel.warmup(a, 1, grid=(1, ))
        kernel.warmup(b, 0, grid=(1, ))
        kernel.warmup(b, 1, grid=(1, ))
        # NOTE(review): other tests in this file read
        # `device_caches[<device index>][0]`; confirm `kernel.cache[device]`
        # is the intended accessor here.
        assert len(kernel.cache[device]) == 0
        kernel[(1, )](b, 1)
        assert b[0, 0] == 1
        kernel[(1, )](b, 0)
        assert b[0, 0] == 0
        kernel[(1, )](a, 0)
        assert a[0, 0] == 0
        kernel[(1, )](a, 1)
        assert a[0, 0] == 1
        # Specialization that was never warmed up.
        kernel[(1, )](a, 2)
        assert a[0, 0] == 2