import triton
import triton.language as tl
import torch
import torch_npu


@triton.jit
def add_kernel(
    x_ptr, y_ptr, out_ptr,  # input/output buffer pointers
    n_elements,             # number of valid elements
    BLOCK_SIZE: tl.constexpr,
):
    """Compute ``out[i] = x[i] + y[i]`` for ``i < n_elements``.

    Program instance ``k`` (along axis 0) covers the index range
    ``[k * BLOCK_SIZE, (k + 1) * BLOCK_SIZE)``. Indices at or beyond
    ``n_elements`` are masked out of both loads and the store, so a final
    partial chunk never reads or writes past the end of the buffers.
    """
    block_start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = block_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    lhs = tl.load(x_ptr + idx, mask=in_bounds)
    rhs = tl.load(y_ptr + idx, mask=in_bounds)
    tl.store(out_ptr + idx, lhs + rhs, mask=in_bounds)


def main(n: int = 1024, block_size: int = 128, device: str = "npu") -> bool:
    """Smoke-test ``add_kernel`` by adding two vectors on the accelerator.

    Builds ``x = [0, 1, ..., n-1]`` and ``y = [1, 1, ..., 1]`` as float32
    tensors, runs ``add_kernel`` to compute ``out = x + y``, and compares the
    result with the eager PyTorch sum.

    Args:
        n: Number of elements. A value that is not a multiple of
            ``block_size`` exercises the kernel's tail mask.
        block_size: Elements handled per Triton program. Passed as the
            kernel's ``BLOCK_SIZE`` constexpr, so it must be a power of two
            (required by ``tl.arange``).
        device: Torch device string for the tensors. Defaults to ``"npu"``,
            the backend this file imports ``torch_npu`` for.

    Returns:
        True if the kernel output matches ``x + y``, False otherwise.

    Raises:
        ValueError: If ``n`` is not positive or ``block_size`` is not a
            positive power of two.
    """
    if n <= 0:
        raise ValueError(f"n must be positive, got {n}")
    if block_size <= 0 or block_size & (block_size - 1):
        raise ValueError(
            f"block_size must be a positive power of two, got {block_size}"
        )

    # Allocate input/output tensors on the target device.
    x = torch.arange(n, device=device, dtype=torch.float32)
    y = torch.ones(n, device=device, dtype=torch.float32)
    out = torch.empty_like(x)

    # One program per block_size chunk; the kernel masks the final partial chunk.
    grid = (triton.cdiv(n, block_size),)

    add_kernel[grid](
        x, y, out,
        n,
        BLOCK_SIZE=block_size,
    )

    # Verify against the eager PyTorch reference.
    if torch.allclose(out, x + y):
        print("✅ Triton works! Result is correct.")
        return True

    print("❌ Wrong result!")
    print(out[:10])
    return False


if __name__ == "__main__":
    # Report failure through the exit status so shell/CI callers can detect it.
    raise SystemExit(0 if main() else 1)