import torch
import triton
import triton.language as tl
def test_add(x0, x1):
    """
    Check a Triton element-wise add kernel against the PyTorch result.

    Steps:
    1. Compute the reference result with PyTorch (torch_ref).
    2. Compute the same sum with a Triton kernel (triton_cal).
    3. Hand both tensors to accuracy_comparison, which raises on mismatch.

    Args:
        x0: First operand tensor. Its element count and dtype define the
            output tensor.
        x1: Second operand tensor. It is read with x0's element count, so it
            is expected to hold at least that many elements (normally the same
            shape as x0).

    The kernel addresses elements by linear offset from each base pointer,
    so inputs are presumably expected to be contiguous -- TODO confirm for
    callers that pass strided views.
    """

    def torch_func(x0, x1):
        # Reference implementation: plain PyTorch addition.
        return x0 + x1

    @triton.jit
    def triton_kernel_add(
        out_ptr0,
        in_ptr0,
        in_ptr1,
        n_elements,
        XS: tl.constexpr
    ):
        # XS is the (power-of-two) block width; lanes at or beyond
        # n_elements are masked off so they are neither read nor written.
        idx = tl.arange(0, XS)
        mask = idx < n_elements
        tmp0 = tl.load(in_ptr0 + idx, mask=mask)
        tmp1 = tl.load(in_ptr1 + idx, mask=mask)
        tmp2 = tmp0 + tmp1
        tl.store(out_ptr0 + idx, tmp2, mask=mask)

    def triton_func(x0, x1):
        y0 = torch.empty_like(x0)
        n_elements = x0.numel()
        if n_elements == 0:
            # Nothing to compute; a zero-length tl.arange is not valid.
            return y0
        # tl.arange needs a power-of-two extent, so round the block width up
        # and rely on the mask inside the kernel for the padded tail.
        block_size = 1 << (n_elements - 1).bit_length()
        # A single program instance processes the whole vector.
        triton_kernel_add[1, 1, 1](y0, x0, x1, n_elements, XS=block_size)
        return y0

    torch_ref = torch_func(x0, x1)
    triton_cal = triton_func(x0, x1)
    accuracy_comparison(triton_cal, torch_ref)
    print(
        f"== dtype:{triton_cal.dtype} == The accuracy comparison between triton_result and torch_result was successful.")
def accuracy_comparison(y_cal, y_ref):
    """
    Compare a computed tensor against a reference with a dtype-specific rule.

    Strategy per dtype:
    - float16 / bfloat16: torch.testing.assert_close with rtol=atol=1e-3
      (bfloat16 values are upcast to float32 before comparing).
    - float32: torch.testing.assert_close with rtol=atol=1e-4.
    - int8 / int16 / int32 / int64 / uint32: exact equality (torch.equal).
    - bool: exact equality, evaluated on CPU copies.

    Args:
        y_cal: Tensor produced by the implementation under test.
        y_ref: Reference tensor; must have the same dtype as y_cal.

    Raises:
        AssertionError: If the dtypes differ or the values do not match.
        ValueError: If the dtype is not one of those listed above.
    """
    # Explicit raises (rather than `assert` statements) keep the checks
    # active even when Python runs with -O.
    if y_cal.dtype != y_ref.dtype:
        raise AssertionError(f"dtype mismatch: {y_cal.dtype} vs {y_ref.dtype}")
    tensor_dtype = y_cal.dtype

    # Compare on whatever device the computed result lives on instead of
    # forcing a specific accelerator, so the helper also works on hosts
    # where that backend is unavailable.
    y_ref = y_ref.to(y_cal.device)

    # torch.uint32 is not defined in older PyTorch releases; look it up
    # lazily so its absence does not break the other integer dtypes.
    exact_dtypes = [torch.int64, torch.int32, torch.int16, torch.int8]
    uint32_dtype = getattr(torch, "uint32", None)
    if uint32_dtype is not None:
        exact_dtypes.append(uint32_dtype)

    if tensor_dtype == torch.float16:
        torch.testing.assert_close(y_ref, y_cal, rtol=1e-3, atol=1e-3, equal_nan=True)
    elif tensor_dtype == torch.bfloat16:
        torch.testing.assert_close(
            y_ref.to(torch.float32),
            y_cal.to(torch.float32),
            rtol=1e-3,
            atol=1e-3,
            equal_nan=True,
        )
    elif tensor_dtype == torch.float32:
        torch.testing.assert_close(y_ref, y_cal, rtol=1e-4, atol=1e-4, equal_nan=True)
    elif tensor_dtype in exact_dtypes:
        if not torch.equal(y_cal, y_ref):
            raise AssertionError(f"Integer tensors are not equal for dtype {tensor_dtype}")
    elif tensor_dtype == torch.bool:
        # Boolean results are compared on the host.
        if not torch.equal(y_cal.cpu(), y_ref.cpu()):
            raise AssertionError("Boolean tensors are not equal")
    else:
        raise ValueError(f'Invalid or unsupported tensor dtype: {tensor_dtype}')
if __name__ == "__main__":
    # Vector length and the half-open range [low, high) used for the
    # integer operands.
    N = 1024
    low = 1
    high = 100

    # Every operand pair is built up front, in the order
    # float -> integer -> bool, before any test is executed.
    test_cases = []

    # Floating-point operands: sampled with torch.rand, then moved to the NPU.
    for label, float_dtype in (
        ('fp32', torch.float32),
        ('fp16', torch.float16),
        ('bf16', torch.bfloat16),
    ):
        lhs = torch.rand((N,), dtype=float_dtype).npu()
        rhs = torch.rand((N,), dtype=float_dtype).npu()
        test_cases.append((label, lhs, rhs))

    # Integer operands: sampled with torch.randint from [low, high).
    for label, int_dtype in (
        ('i64', torch.int64),
        ('i32', torch.int32),
        ('i16', torch.int16),
        ('i8', torch.int8),
    ):
        lhs = torch.randint(low=low, high=high, size=(N,), dtype=int_dtype).npu()
        rhs = torch.randint(low=low, high=high, size=(N,), dtype=int_dtype).npu()
        test_cases.append((label, lhs, rhs))

    # Boolean operands: random 0/1 integers converted to bool.
    lhs = torch.randint(low=0, high=2, size=(N,)).bool().npu()
    rhs = torch.randint(low=0, high=2, size=(N,)).bool().npu()
    test_cases.append(('i1', lhs, rhs))

    # Run the add test once per dtype.
    for dtype_name, x0, x1 in test_cases:
        print(f"Running test for {dtype_name}...")
        test_add(x0, x1)