import triton
import triton.language as tl
import torch
import torch_npu
@triton.jit
def add_kernel(
    x_ptr,
    y_ptr,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Element-wise add: write ``x[i] + y[i]`` to ``out[i]`` for ``i < n_elements``.

    Each program instance handles the index range
    ``[program_id(0) * BLOCK_SIZE, program_id(0) * BLOCK_SIZE + BLOCK_SIZE)``.
    The pointers are indexed as flat element arrays (``ptr + index``).
    Presumably the caller passes contiguous tensors; verify at call sites.
    """
    # First element index owned by this program instance.
    tile_start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = tile_start + tl.arange(0, BLOCK_SIZE)
    # The last tile may run past n_elements. Masking the loads avoids
    # out-of-bounds reads, and masking the store keeps whatever undefined
    # values were loaded into those lanes from being written back.
    in_bounds = idx < n_elements
    lhs = tl.load(x_ptr + idx, mask=in_bounds)
    rhs = tl.load(y_ptr + idx, mask=in_bounds)
    tl.store(out_ptr + idx, lhs + rhs, mask=in_bounds)
def main(n: int = 1024, block_size: int = 128, device: str = "npu") -> bool:
    """Smoke-test ``add_kernel`` on an accelerator and report the outcome.

    Builds ``x = [0, 1, ..., n-1]`` and ``y = ones(n)`` as float32 tensors on
    *device*. It then launches ``add_kernel`` with one program per
    ``block_size`` tile and compares the kernel output against PyTorch's own
    ``x + y``. A success or failure message is printed; on failure the first
    ten output elements are printed as well.

    Args:
        n: Number of elements to add. It need not be a multiple of
            ``block_size``, because the kernel masks the partial last tile.
        block_size: Elements handled per program instance, passed as the
            kernel's ``BLOCK_SIZE``. It should be a power of two, which
            ``tl.arange`` requires.
        device: Torch device string. The default ``"npu"`` relies on
            ``torch_npu`` having been imported.

    Returns:
        True if the kernel output matches the reference within
        ``torch.allclose`` default tolerances, otherwise False.

    Raises:
        ValueError: If ``n`` or ``block_size`` is not positive.
    """
    if n < 1:
        raise ValueError(f"n must be positive, got {n}")
    if block_size < 1:
        raise ValueError(f"block_size must be positive, got {block_size}")

    x = torch.arange(n, device=device, dtype=torch.float32)
    y = torch.ones(n, device=device, dtype=torch.float32)
    out = torch.empty_like(x)

    # One program instance per tile; cdiv rounds up so the tail is covered.
    grid = (triton.cdiv(n, block_size),)
    add_kernel[grid](
        x, y, out,
        n,
        BLOCK_SIZE=block_size,
    )

    # Reference result computed by PyTorch on the same device.
    expected = x + y
    ok = torch.allclose(out, expected)
    if ok:
        print("✅ Triton works! Result is correct.")
    else:
        print("❌ Wrong result!")
        print(out[:10])
    return ok
if __name__ == "__main__":
    # Run the smoke test only when executed as a script, not when imported.
    main()