import math
import torch
import triton
import triton.language as tl
import pytest
@triton.jit
def if_tensor_kernel(
    kv_start_idx,
    output_ptr,
):
    """Copy one element per program from ``kv_start_idx`` to ``output_ptr``.

    The copy is guarded by ``if kv_start_idx:``, i.e. the argument itself is
    used as the branch condition.

    Args:
        kv_start_idx: Source argument. It is both tested for truthiness and
            used as the base address of the load. The caller in this file
            passes a torch tensor -- presumably the condition acts as a
            "was a tensor supplied" check; TODO confirm how the backend
            evaluates it.
        output_ptr: Destination base address; element ``pid`` is written.
    """
    pid = tl.program_id(0)  # index of this program along launch axis 0
    # Nothing is loaded or stored when the condition is false.
    if kv_start_idx:
        value = tl.load(kv_start_idx + pid)
        tl.store(output_ptr + pid, value)
def test_kernel():
    """Launch ``if_tensor_kernel`` over 8 programs and check the copied values.

    The input is ``arange(8)`` as float32 on the 'npu' device and the output
    buffer starts as zeros; after the launch the output must match a fresh
    ``arange(8)``.
    """
    n = 8
    tensor_kwargs = {"dtype": torch.float32, "device": 'npu'}

    kv_start_idx = torch.arange(n, **tensor_kwargs)
    output1 = torch.zeros(n, **tensor_kwargs)
    expected = torch.arange(n, **tensor_kwargs)

    # One program per element.
    grid = (n,)
    if_tensor_kernel[grid](kv_start_idx, output1)

    matches = torch.allclose(output1, expected)
    assert matches, f"Output {output1} != Expected {expected}"
    print(f"RESULT: output1 = {output1}")
    print("✅ Test passed!")
@triton.jit
def mul_if_block_kernel(
    value,
    value_stride0,
    value_stride1,
    output,
    output_stride0,
    output_stride1,
    lengths,
    bs,
    dim,
    max_seq_len,
    DIM_SIZE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr
):
    """Move one (BLOCK_SIZE, DIM_SIZE) tile from ``value`` to ``output``.

    The launch grid is read on three axes: axis 0 selects the batch entry,
    axis 1 selects the branch (0 -> "reverse" branch, anything else -> "copy"
    branch) and axis 2 selects the tile along the first block dimension.
    Programs exit early when their batch index or tile start is out of range.

    Args:
        value: Base address of the input.
        value_stride0: Stride used to advance ``value`` to the current batch
            entry; also used as the first block-pointer stride.
        value_stride1: Stride used to offset the input base by
            ``reverse_len`` in the copy branch.
        output: Base address of the output.
        output_stride0: Stride used to advance ``output`` to the current
            batch entry; also used as the first block-pointer stride.
        output_stride1: Stride used to offset the output base in both
            branches.
        lengths: Base address of per-batch lengths; ``lengths[batch_idx]`` is
            loaded as ``reverse_len``.
        bs: Number of batch entries; programs with ``batch_idx >= bs`` return.
        dim: Size of the second dimension of the block-pointer shapes.
        max_seq_len: Upper bound used for ``copy_len`` and for the copy-branch
            block-pointer shape.
        DIM_SIZE: Compile-time tile width (second block dimension).
        BLOCK_SIZE: Compile-time tile height (first block dimension).
    """
    batch_idx = tl.program_id(0)
    if batch_idx >= bs:
        return
    need_reverse = tl.program_id(1) == 0
    block_idx = tl.program_id(2)
    block_start = block_idx * BLOCK_SIZE
    reverse_len = tl.load(lengths + batch_idx)
    # Rows left over after the first ``reverse_len`` rows.
    copy_len = max_seq_len - reverse_len
    # Rebase both pointers onto the current batch entry.
    value = value + batch_idx * value_stride0
    output = output + batch_idx * output_stride0
    if need_reverse:
        # Tiles that start at or beyond ``reverse_len`` have nothing to do.
        if block_start >= reverse_len:
            return
        # NOTE(review): the row stride here is value_stride0, the same stride
        # already used above to step between batch entries; value_stride1
        # looks like the intended row stride -- confirm.
        value_block_ptr = tl.make_block_ptr(
            base=value,
            shape=(reverse_len, dim),
            strides=(value_stride0, 1),
            offsets=(block_start, 0),
            block_shape=(BLOCK_SIZE, DIM_SIZE),
            order=(0, 1)
        )
        # Out-of-range elements on either dimension are read as zero.
        block_values = tl.load(value_block_ptr, boundary_check=(0, 1), padding_option="zero")
        # NOTE(review): the base is shifted by (reverse_len - 1) rows of
        # output_stride1, but the strides are used as passed (not negated), so
        # nothing in this block flips the row order -- confirm the intent.
        output_block_ptr = tl.make_block_ptr(
            base=output + (reverse_len - 1) * output_stride1,
            shape=(reverse_len, dim),
            strides=(output_stride0, 1),
            offsets=(block_start, 0),
            block_shape=(BLOCK_SIZE, DIM_SIZE),
            order=(0, 1)
        )
        tl.store(output_block_ptr, block_values, boundary_check=(0, 1))
    else:
        # Tiles that start at or beyond ``copy_len`` have nothing to do.
        if block_start >= copy_len:
            return
        # Source and destination both start ``reverse_len`` rows into the
        # batch entry.
        # NOTE(review): the declared shape is (max_seq_len, dim) even though
        # the base is already shifted by reverse_len; copy_len may be the
        # intended row bound for the boundary check -- confirm.
        value_block_ptr = tl.make_block_ptr(
            base=value + reverse_len * value_stride1,
            shape=(max_seq_len, dim),
            strides=(value_stride0, 1),
            offsets=(block_start, 0),
            block_shape=(BLOCK_SIZE, DIM_SIZE),
            order=(0, 1)
        )
        block_values = tl.load(value_block_ptr, boundary_check=(0, 1), padding_option="zero")
        output_block_ptr = tl.make_block_ptr(
            base=output + reverse_len * output_stride1,
            shape=(max_seq_len, dim),
            strides=(output_stride0, 1),
            offsets=(block_start, 0),
            block_shape=(BLOCK_SIZE, DIM_SIZE),
            order=(0, 1)
        )
        tl.store(output_block_ptr, block_values, boundary_check=(0, 1))
def ref_reverse(value: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """Return a copy of *value* with each batch entry's leading rows reversed.

    Args:
        value: 3-D tensor; the first axis indexes batch entries and the
            second axis is the one that gets reversed.
        lengths: Per-batch count of leading rows to reverse; indexed once per
            batch entry and converted with ``int``.

    Returns:
        A new tensor. Rows ``[0, lengths[b])`` of entry ``b`` are flipped
        along the second axis of *value*; everything else is copied
        unchanged. Entries whose length is not positive are left as-is.
    """
    batch_count, _, _ = value.shape
    result = value.clone()
    for row in range(batch_count):
        prefix_len = int(lengths[row])
        if prefix_len <= 0:
            continue
        result[row, :prefix_len] = torch.flip(value[row, :prefix_len], dims=(0,))
    return result
@pytest.mark.parametrize("bs,max_seq_len,dim,BLOCK_SIZE", [
    (2, 32, 16, 8),
])
def test_reverse_sequence_kernel(bs, max_seq_len, dim, BLOCK_SIZE):
    """Launch ``mul_if_block_kernel`` on random float32 input on the "npu" device.

    The grid is ``(bs, 2, ceil(max_seq_len / BLOCK_SIZE))`` and the per-batch
    lengths are the first ``bs`` entries of
    ``[0, max_seq_len // 2, max_seq_len, max_seq_len // 3, ...]``.

    NOTE(review): the test only issues the launch; ``output`` is never
    compared against a reference result.
    """
    npu = "npu"
    value = torch.randn(bs, max_seq_len, dim, device=npu, dtype=torch.float32)

    # Build the candidate lengths first, then keep one per batch entry.
    candidate_lengths = [0, max_seq_len // 2, max_seq_len]
    candidate_lengths = candidate_lengths + [max_seq_len // 3] * max(0, bs - 3)
    lengths = torch.tensor(candidate_lengths, device=npu, dtype=torch.int32)[:bs]

    output = torch.empty_like(value)

    seq_blocks = math.ceil(max_seq_len / BLOCK_SIZE)
    launcher = mul_if_block_kernel[(bs, 2, seq_blocks)]
    launcher(
        value, value.stride(0), value.stride(1),
        output, output.stride(0), output.stride(1),
        lengths, bs, dim, max_seq_len,
        DIM_SIZE=dim, BLOCK_SIZE=BLOCK_SIZE
    )
# Script entry point: running this file directly executes only test_kernel().
# test_reverse_sequence_kernel takes parametrized arguments and is not called here.
if __name__ == "__main__":
    test_kernel()