from __future__ import annotations
import torch
import triton
import triton.language as tl
import pytest
import numpy as np
def annotated_function(return_type=None, **arg_types):
    """Return a decorator that overwrites a function's ``__annotations__``.

    Args:
        return_type: value stored under the ``'return'`` key (``None`` if omitted).
        **arg_types: parameter-name -> annotation entries to install.

    Returns:
        A decorator that replaces the decorated function's ``__annotations__``
        with a fresh dict holding ``arg_types`` plus the ``'return'`` entry, and
        returns the very same function object (no wrapper is created).
    """

    def _apply(fn):
        # Build a new dict per decoration; any pre-existing annotations on
        # `fn` are discarded, not merged.
        new_annotations = dict(arg_types)
        new_annotations['return'] = return_type
        fn.__annotations__ = new_annotations
        return fn

    return _apply
@pytest.mark.parametrize(
    ("signed", "width"),
    [(signed, width) for signed in [False, True] for width in [8, 16, 32, 64]] + [(False, 1)],
)
def test_int_annotation(signed, width, device):
    """Check that an integer annotation on a scalar kernel argument reaches TTIR.

    ``v`` is annotated as ``tl.int{width}`` or ``tl.uint{width}`` through
    ``annotated_function``. The kernel uses ``v`` both as a pointer offset and
    as the value stored into ``X`` (a default-dtype tensor, presumably float).
    The test expects:
      * ``%v: i{width}`` in the TTIR signature,
      * ``arith.extui %v`` for unsigned widths below 64,
      * an ``arith.sitofp`` / ``arith.uitofp`` matching the signedness.
    """

    @triton.jit
    @annotated_function(X=torch.tensor, v=f"tl.{'' if signed else 'u'}int{width}")
    def _kernel(X, v):
        tl.store(X + v, v)

    offset = 3
    # The kernel writes X[v]. Allocate enough elements that the store at index
    # `offset` stays inside the buffer; a 1-element tensor would be overrun.
    out = torch.empty(offset + 1, device=device)
    h = _kernel[(1, )](out, offset)
    pfx = 'si' if signed else 'ui'
    if not signed and width < 64:
        assert "arith.extui %v" in h.asm["ttir"]
    assert f'%v: i{width}' in h.asm["ttir"]
    assert f'arith.{pfx}tofp' in h.asm["ttir"]
def test_unknown_annotation(device):
    """Launch a kernel whose signature carries a ``torch.Tensor`` annotation.

    The test name suggests Triton treats this annotation as unknown. The
    first launch passes a real tensor for ``X`` and must not raise. The second
    launch passes an int in the tensor-annotated slot. An ``AttributeError``
    from that launch is tolerated but not required, so the test passes either
    way; any other exception fails it.
    """

    @triton.jit
    def _kernel(X: torch.Tensor, N: int, BLOCK_SIZE: tl.constexpr):
        pass

    grid = (1, )
    buf = torch.empty(1, device=device)
    numel = buf.shape[0]
    _kernel[grid](buf, numel, 32)
    try:
        _kernel[grid](numel, numel, 32)
    except AttributeError:
        # Rejecting a non-tensor for the tensor-annotated slot is acceptable.
        pass
@pytest.mark.parametrize(
    ("dtype", "test_val"),
    [
        (dtype, test_val)
        for dtype in [tl.float16, tl.bfloat16, tl.float32, tl.float64]
        for test_val in [0.0, 42.0, float("inf"), float("nan")]
    ],
)
def test_float_annotation(device, dtype, test_val):
    """Check that a float annotation fixes the scalar's dtype in the kernel and TTIR.

    ``val`` is annotated with ``dtype``. Inside the kernel, ``tl.static_assert``
    confirms the traced dtype. The value is stored into a float32 buffer and
    compared against ``test_val``; ``assert_allclose`` treats NaN == NaN by
    default. The TTIR must declare ``%val`` with the matching type and, for
    non-f32 inputs, contain the extf/truncf conversion to f32.
    """

    @triton.jit
    @annotated_function(val=dtype)
    def _kernel(ptr, val):
        tl.static_assert(val.dtype == dtype)
        tl.store(ptr, val)

    out = torch.empty(1, device=device, dtype=torch.float32)
    compiled = _kernel[(1, )](out, test_val)
    np.testing.assert_allclose(out.cpu().numpy(), [test_val], atol=1e-6)

    # Expected TTIR fragments per annotated dtype: the argument declaration,
    # plus the conversion to the f32 storage type when one is needed.
    expected_snippets = ()
    for candidate, snippets in (
        (tl.float16, ("%val: f16", "arith.extf %val : f16 to f32")),
        (tl.bfloat16, ("%val: bf16", "arith.extf %val : bf16 to f32")),
        (tl.float32, ("%val: f32", )),
        (tl.float64, ("%val: f64", "arith.truncf %val : f64 to f32")),
    ):
        if dtype == candidate:
            expected_snippets = snippets
            break
    for snippet in expected_snippets:
        assert snippet in compiled.asm["ttir"]