import triton
import triton.language as tl
import numpy as np
import torch
import pytest
import test_common
def torch_add(x, y):
    """Return ``x + y``; used as the reference result for the kernel test."""
    return x + y
@triton.jit
def triton_asm_add(
    x_ptr,
    y_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Store the inline-asm elementwise sum of two buffers into ``output_ptr``.

    Each program handles one ``BLOCK_SIZE``-wide slice, selected by its
    program index along grid axis 0.

    Args:
        x_ptr: Pointer to the first operand buffer.
        y_ptr: Pointer to the second operand buffer.
        output_ptr: Pointer to the buffer that receives the result.
        n_elements: Number of valid elements; lanes at or beyond this index
            are masked out of every load and store.
        BLOCK_SIZE: Compile-time number of elements handled per program.
    """
    # Element indices covered by this program instance.
    lane_ids = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Guards the final block when it extends past the end of the buffers.
    in_bounds = lane_ids < n_elements

    lhs = tl.load(x_ptr + lane_ids, mask=in_bounds)
    rhs = tl.load(y_ptr + lane_ids, mask=in_bounds)

    # $0 is the output operand and $1/$2 are the two inputs, matching the
    # "=l,l,l" constraint string (one output, two inputs). The result is
    # typed as tl.int64 and one element is handled per asm invocation
    # (pack=1).
    # NOTE(review): the instruction is presumably a signed 64-bit add on the
    # target backend -- confirm against the backend's ISA reference.
    total = tl.inline_asm_elementwise(
        asm="\nADD.s64 $0, $1, $2\n",
        constraints="=l,l,l",
        args=[lhs, rhs],
        dtype=tl.int64,
        is_pure=True,
        pack=1,
    )
    tl.store(output_ptr + lane_ids, total, mask=in_bounds)
@pytest.mark.parametrize(
    'param_list',
    [
        ['int64', 4096, 1024],
    ],
)
def test_case(param_list):
    """Check ``triton_asm_add`` against the ``torch_add`` reference.

    Args:
        param_list: ``[dtype, length, block_size]`` where ``dtype`` is the
            name of a ``torch`` dtype attribute (e.g. ``'int64'``),
            ``length`` is the number of elements in each 1-D operand and
            ``block_size`` is the kernel's ``BLOCK_SIZE``.
    """
    dtype, length, block_size = param_list
    # Round the grid size up so a trailing partial block still gets a
    # program; the kernel masks lanes beyond ``length``. Floor division
    # would silently skip the tail whenever length % block_size != 0.
    ncore = triton.cdiv(length, block_size)

    x = test_common.generate_tensor((length,), dtype).npu()
    y = test_common.generate_tensor((length,), dtype).npu()
    res_ref = torch_add(x, y)

    # Look the dtype up by attribute name instead of eval()-ing a string.
    res_cal = torch.zeros((length,), dtype=getattr(torch, dtype)).npu()
    triton_asm_add[(ncore,)](x, y, res_cal, length, BLOCK_SIZE=block_size)

    test_common.validate_cmp(dtype, res_cal, res_ref)