import torch
import triton
import triton.language as tl
def test_add(input0, input1):
    """Check a single-program Triton element-wise add against eager PyTorch.

    The Triton result is computed by one program instance (grid ``(1,)``)
    that covers ``input0.shape[0]`` elements, and is compared with
    ``input0 + input1`` using ``torch.testing.assert_close``.

    Args:
        input0: First operand. Assumed to be a 1-D, contiguous device tensor,
            since the kernel indexes elements linearly from the base pointer
            (TODO confirm against callers).
        input1: Second operand; presumably the same shape and dtype as
            ``input0``.

    Raises:
        AssertionError: If the Triton output does not match the eager result.
    """

    def torch_func(a, b):
        # Eager reference implementation.
        return a + b

    @triton.jit
    def triton_kernel_add(out_ptr0, in_ptr0, in_ptr1, n_elements, XS: tl.constexpr):
        # XS is the power-of-two block extent; n_elements is the real length.
        idx = tl.arange(0, XS)
        # Lanes past the end of the data are masked off so that a padded
        # block never reads or writes out of bounds.
        in_bounds = idx < n_elements
        tmp0 = tl.load(in_ptr0 + idx, mask=in_bounds)
        tmp1 = tl.load(in_ptr1 + idx, mask=in_bounds)
        tmp2 = tmp0 + tmp1
        tl.store(out_ptr0 + idx, tmp2, mask=in_bounds)

    def triton_func(a, b):
        n = a.shape[0]
        y0 = torch.empty_like(a)
        if n == 0:
            # Nothing to compute; an empty tl.arange range is not valid.
            return y0
        # tl.arange needs a power-of-two extent, so round the block size up
        # and rely on the mask inside the kernel for the tail.
        block = triton.next_power_of_2(n)
        triton_kernel_add[(1,)](y0, a, b, n, XS=block)
        return y0

    torch_ref = torch_func(input0, input1)
    triton_cal = triton_func(input0, input1)
    torch.testing.assert_close(triton_cal, torch_ref)
def test_or(input0, input1):
    """Check a single-program Triton element-wise OR against eager PyTorch.

    The Triton result is computed by one program instance (grid ``(1,)``)
    that covers ``input0.shape[0]`` elements, and is compared with
    ``input0 | input1`` using ``torch.testing.assert_close``.

    Args:
        input0: First operand. Assumed to be a 1-D, contiguous device tensor
            of a dtype that supports ``|``, since the kernel indexes elements
            linearly from the base pointer (TODO confirm against callers).
        input1: Second operand; presumably the same shape and dtype as
            ``input0``.

    Raises:
        AssertionError: If the Triton output does not match the eager result.
    """

    def torch_func(a, b):
        # Eager reference implementation.
        return a | b

    @triton.jit
    def triton_kernel_or(out_ptr0, in_ptr0, in_ptr1, n_elements, XS: tl.constexpr):
        # XS is the power-of-two block extent; n_elements is the real length.
        idx = tl.arange(0, XS)
        # Lanes past the end of the data are masked off so that a padded
        # block never reads or writes out of bounds.
        in_bounds = idx < n_elements
        tmp0 = tl.load(in_ptr0 + idx, mask=in_bounds)
        tmp1 = tl.load(in_ptr1 + idx, mask=in_bounds)
        tmp2 = tmp0 | tmp1
        tl.store(out_ptr0 + idx, tmp2, mask=in_bounds)

    def triton_func(a, b):
        n = a.shape[0]
        y0 = torch.empty_like(a)
        if n == 0:
            # Nothing to compute; an empty tl.arange range is not valid.
            return y0
        # tl.arange needs a power-of-two extent, so round the block size up
        # and rely on the mask inside the kernel for the tail.
        block = triton.next_power_of_2(n)
        triton_kernel_or[(1,)](y0, a, b, n, XS=block)
        return y0

    torch_ref = torch_func(input0, input1)
    triton_cal = triton_func(input0, input1)
    torch.testing.assert_close(triton_cal, torch_ref)
def test_inductor_add(a, b):
    """Run ``x + y`` through ``torch.compile`` and verify it against eager.

    Args:
        a: First operand tensor.
        b: Second operand tensor.

    Raises:
        AssertionError: If the compiled result differs from the eager result.
    """

    def torch_add(x, y):
        res = x + y
        return res

    compiled_func = torch.compile(torch_add)
    compiled_out = compiled_func(a, b)
    # Compare with the uncompiled function so the test actually checks the
    # compiled output rather than only exercising the compile path.
    eager_out = torch_add(a, b)
    torch.testing.assert_close(compiled_out, eager_out)
if __name__ == "__main__":
    # Assigned but never read in the visible code.
    test_case_is_inductor = False

    # Element count of every test vector and the value range for integers.
    N = 1024
    low = 1
    high = 100

    # NOTE(review): ``.npu()`` is not a stock torch.Tensor method; presumably
    # an NPU backend is loaded into torch before this runs - confirm.
    test_cases = []

    # Floating-point operands: uniform samples from torch.rand.
    float_dtypes = (
        ('fp32', torch.float32),
        ('fp16', torch.float16),
        ('bf16', torch.bfloat16),
    )
    for label, float_dtype in float_dtypes:
        lhs = torch.rand((N,), dtype=float_dtype).npu()
        rhs = torch.rand((N,), dtype=float_dtype).npu()
        test_cases.append((label, lhs, rhs))

    # Integer operands: samples from torch.randint over [low, high).
    int_dtypes = (
        ('i64', torch.int64),
        ('i32', torch.int32),
        ('i16', torch.int16),
        ('i8', torch.int8),
    )
    for label, int_dtype in int_dtypes:
        lhs = torch.randint(low=low, high=high, size=(N,), dtype=int_dtype).npu()
        rhs = torch.randint(low=low, high=high, size=(N,), dtype=int_dtype).npu()
        test_cases.append((label, lhs, rhs))

    # Boolean operands: 0/1 integers converted to bool.
    lhs = torch.randint(low=0, high=2, size=(N,)).bool().npu()
    rhs = torch.randint(low=0, high=2, size=(N,)).bool().npu()
    test_cases.append(('i1', lhs, rhs))

    for dtype_name, x0, x1 in test_cases:
        print(f"Running test for {dtype_name}...")
        if dtype_name == 'i1':
            # Boolean inputs go through the OR test only.
            test_or(x0, x1)
            continue
        test_inductor_add(x0, x1)
        test_add(x0, x1)