import pytest
import triton
import triton.language as tl
from typing import NamedTuple
import torch
@triton.jit
def _tuple_increment(values):
return tl.tuple([v + 1 for v in values])
@triton.jit
def _tuple_index_func(Ptrs, values):
for i in tl.static_range(len(values)):
tl.store(Ptrs[i], values[i])
@triton.jit
def _tuple_index(_0, Ptrs, _1: tl.constexpr, values, _2, _3: tl.constexpr, _4):
values = _tuple_increment(values)
_tuple_index_func(Ptrs, values)
@pytest.mark.parametrize("size", [0, 1, 2, 3, 4])
def test_index(size, device):
vals = tuple([i + 1 for i in range(size)])
rets = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in vals])
_tuple_index[(1, )](0, rets, 0, vals, 0, 0, 0)
assert vals == tuple([x.item() - 1 for x in rets])
@triton.jit
def _tuple_assign(XPtrs, YPtrs, values):
X0, X1 = XPtrs
x0, x1, _ = values
tl.store(X0, x0)
tl.store(X1, x1)
Y0, Y1, Y2 = YPtrs
Y = Y0, Y1, Y2
y = x0, 10, x1
tl.store(Y[0], y[0])
tl.store(Y[1], y[1])
tl.store(Y[2], y[2])
@pytest.mark.interpreter
def test_assign(device):
vals = (2., 3., None)
x = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(2)])
y = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(3)])
_tuple_assign[(1, )](x, y, vals)
assert x[0] == vals[0]
assert x[1] == vals[1]
assert y[0] == vals[0]
assert y[1] == 10
assert y[2] == vals[1]
@triton.jit
def _tuple_ret(a, b):
return a + b, \
a - b, \
a * b
@pytest.mark.interpreter
def test_assign_return(device):
@triton.jit
def with_fn(X, Y, A, B, C):
x = tl.load(X)
y = tl.load(Y)
a, b, c = _tuple_ret(x, y)
tl.store(A, a)
tl.store(B, b)
tl.store(C, c)
@triton.jit
def without_fn(X, Y, A, B, C):
x = tl.load(X)
y = tl.load(Y)
a, b, c = x + y, x - y, x * y
tl.store(A, a)
tl.store(B, b)
tl.store(C, c)
x = torch.tensor([1.3], device=device, dtype=torch.float32)
y = torch.tensor([1.9], device=device, dtype=torch.float32)
a_tri = torch.tensor([0], device=device, dtype=torch.float32)
b_tri = torch.tensor([0], device=device, dtype=torch.float32)
c_tri = torch.tensor([0], device=device, dtype=torch.float32)
for kernel in [with_fn, without_fn]:
kernel[(1, )](x, y, a_tri, b_tri, c_tri, num_warps=1)
a_ref, b_ref, c_ref = x + y, x - y, x * y
assert a_tri == a_ref
assert b_tri == b_ref
assert c_tri == c_ref
@triton.jit
def _tuple_fn0(Ptr, cst2: tl.constexpr, tuple1):
tl.static_assert(tuple1[1] is None)
tl.store(Ptr + 5, cst2)
tl.store(Ptr + 6, tuple1[0])
tl.store(Ptr + 7, tl.load(tuple1[2][0]))
tl.store(Ptr + 8, tuple1[2][1][0])
tl.store(Ptr + 9, tl.load(tuple1[2][1][2]))
@triton.jit
def _tuple_serialize(Ptr, N1, tuple1, cst1: tl.constexpr, val1, tuple2):
tl.static_assert(N1 is None)
tl.static_assert(tuple1[1][1] is None)
tl.static_assert(tuple1[1][3] == 4)
tl.store(Ptr + 0, tl.load(tuple1[0]))
tl.store(Ptr + 1, tuple1[1][0])
tl.store(Ptr + 2, tl.load(tuple1[1][2]))
tl.store(Ptr + 3, cst1 + val1)
tl.store(Ptr + 4, tl.load(tuple2[0]))
_tuple_fn0(Ptr, 15, (-1, None, tuple1))
@pytest.mark.interpreter
def test_serialize(device):
x0 = torch.tensor([8], dtype=torch.int32, device=device)
x1 = torch.tensor([12], dtype=torch.int32, device=device)
y0 = torch.tensor([10], dtype=torch.int32, device=device)
z = torch.empty((10, ), dtype=torch.int32, device=device)
_tuple_serialize[(1, )](z, None, (x0, (1, None, x1, tl.constexpr(4))), 20, 1, (y0, ))
ref = torch.tensor([8, 1, 12, 21, 10, 15, -1, 8, 1, 12], device=device)
assert torch.equal(z, ref)
class Function(NamedTuple):
fn: tl.constexpr
captured: tuple
class Tensor(NamedTuple):
ptr: any
shape: tuple
stride: tuple
@triton.jit
def _namedtuple_create_func0(shape, ptr, stride):
return Tensor(shape=shape, ptr=ptr, stride=stride)
@triton.jit
def _namedtuple_create_func1(shape, ptr, stride):
tensor = Tensor(shape=shape, ptr=ptr, stride=stride)
return tensor
@triton.jit
def _namedtuple_mask_func(Tensor, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
offs_m = tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
mask = (offs_m[:, None] < Tensor.shape[0]) & (offs_n[None, :] < Tensor.shape[1])
return mask
@triton.jit
def _namedtuple_kernel(closure, _X, Y, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
offs_m = tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
X = _namedtuple_create_func0(_X.shape, _X.ptr, _X.stride)
Y = _namedtuple_create_func1(Y.shape, Y.ptr, Y.stride)
Xs = X.ptr + offs_m[:, None] * X.stride[0] + offs_n[None, :] * X.stride[1]
Ys = Y.ptr + offs_m[:, None] * Y.stride[0] + offs_n[None, :] * Y.stride[1]
x = tl.load(Xs, mask=_namedtuple_mask_func(X, BLOCK_M, BLOCK_N), other=0)
y = closure.fn(x, *closure.captured)
tl.store(Ys, y, mask=_namedtuple_mask_func(Y, BLOCK_M, BLOCK_N))
@pytest.mark.interpreter
def test_namedtuple(device):
x = torch.randn((32, 32), dtype=torch.float32, device=device)
y = torch.empty((16, 16), dtype=torch.float32, device=device)
a = torch.tensor([5.2], dtype=torch.float32, device=device)
@triton.jit
def mul(x, a):
return x * tl.load(a)
function = Function(mul, (a, ))
tx = Tensor(x, x.shape, x.stride())
ty = Tensor(y, y.shape, y.stride())
_namedtuple_kernel[(1, )](function, tx, ty, 64, 64)
assert torch.allclose(y, x[:16, :16] * a)
@pytest.mark.interpreter
def test_eq(device):
@triton.jit
def fn(ret_ptrs):
tl.store(ret_ptrs + 0, (1, 2) == (1, 2))
tl.store(ret_ptrs + 1, (1, 2) == (1, 1))
tl.store(ret_ptrs + 2, tl.tuple((1, 2)) == (1, 2))
tl.store(ret_ptrs + 3, tl.tuple((1, 2)) == (1, 3))
rets = torch.zeros((4, ), dtype=torch.int32, device=device)
fn[(1, )](rets)
assert rets[0].item() == 1
assert rets[1].item() == 0
assert rets[2].item() == 1
assert rets[3].item() == 0
@pytest.mark.interpreter
def test_add(device):
@triton.jit
def fn(ret_ptrs):
tuple0 = ((0, 1)) + (2, 3)
for i in tl.static_range(4):
tl.store(ret_ptrs + i, tuple0[i])
tuple1 = tl.tuple((4, 5)) + (6, 7)
for i in tl.static_range(4):
tl.store(ret_ptrs + 4 + i, tuple1[i])
rets = torch.zeros((8, ), dtype=torch.int32, device=device)
fn[(1, )](rets)
torch.testing.assert_close(rets.cpu(), torch.arange(8, dtype=torch.int32))
def test_passing_tuple_with_constexpr(device):
@triton.jit
def m_to_the_n(X, shape: tl.constexpr, strides, m_n):
Xs = X + tl.arange(0, shape[0])[:, None] * strides[0] + tl.arange(0, shape[1])[None, :] * strides[1]
data = tl.load(Xs)
for i in tl.range(0, m_n[1]):
data = m_n[0] * data
tl.store(Xs, data)
x = torch.arange(0, 64, device=device).reshape(8, 8)
expected_x = 8 * x.clone()
m_to_the_n[(1, )](x, x.shape, x.stride(), (2, 3))
torch.testing.assert_close(x, expected_x, rtol=0, atol=0)
def test_passing_tuple_to_make_tensor_descriptor(device, with_allocator):
@triton.jit
def m_to_the_n(X_base, shape, strides, m_n, BLOCK_DIM: tl.constexpr):
tl.static_assert(isinstance(strides[1].type, tl.constexpr_type))
X = tl.make_tensor_descriptor(
X_base,
shape=shape,
strides=strides,
block_shape=[BLOCK_DIM, BLOCK_DIM],
)
tl.static_assert(isinstance(strides[1].type, tl.constexpr_type))
data = X.load([0, 0])
for i in tl.range(0, m_n[1]):
data = m_n[0] * data
X.store([0, 0], data)
x = torch.arange(0, 16, device=device).reshape(4, 4)
expected_x = 8 * x.clone()
m_to_the_n[(1, )](x, x.size(), x.stride(), (2, 3), x.size(0))
torch.testing.assert_close(x, expected_x, rtol=0, atol=0)
def test_modifying_tuples():
@triton.jit
def set_tuple_value_at_idx():
t = tl.tuple([5, 6, 7])
t[0] = 0
with pytest.raises(triton.CompilationError):
set_tuple_value_at_idx[(1, )]()
@pytest.mark.interpreter
def test_tuple_logic():
@triton.jit
def tuple_logic_kernel():
tl.static_assert(((3, 4) or (5, 6)) == (3, 4))
tl.static_assert(((3, 4) and (5, 6)) == (5, 6))
tl.static_assert(((3, 4) and ()) == ())
tl.static_assert((() or (5, 6)) == (5, 6))
tl.static_assert(((1, 2) and (3, 4) and (5, 6)) == (5, 6))
tl.static_assert(((1, 2) or (3, 4) or (5, 6)) == (1, 2))
tl.static_assert((() and tl.program_id(0)) == ())
tuple_logic_kernel[(1, )]()
@pytest.mark.interpreter
def test_tuple_float():
@triton.jit
def _namedtuple_float_tuple_kernel():
x, y = float("-inf"), float("inf")
_namedtuple_float_tuple_kernel[(1, )]()