算子测试编写指南
文件结构
写一个算子的测试需要 1 个测试文件:
tests/${op_name}/
└── test_${op_name}.py # 测试文件(必须)
说明:使用 pytest 测试框架,测试文件命名为 test_${op_name}.py。
测试文件结构
必须包含的 2 个部分
1. 接口测试
验证算子已注册可用,调用方式正确:
def test_${op_name}_interface():
"""接口测试:验证算子已注册可用"""
import torch
import torch_npu
import ops_multimodal_fusion
x = torch.randn(32, 64).npu()
result = torch.ops.ops_multimodal_fusion.${op_name}(x)
# 验证输出形状
assert result.shape == x.shape
# 验证输出在 NPU 上
assert result.device.type == 'npu'
2. 精度测试
将 NPU 计算结果与 CPU 参考实现对比:
@pytest.mark.parametrize("shape", [(32, 64), (128, 256), (1024, 1024)])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
def test_${op_name}_accuracy(shape, dtype):
"""精度测试:与 CPU 参考实现对比"""
import torch
import torch_npu
import ops_multimodal_fusion
# 创建测试数据
x = torch.randn(shape, dtype=dtype)
x_npu = x.npu()
# NPU 计算
result_npu = torch.ops.ops_multimodal_fusion.${op_name}(x_npu)
# CPU 参考计算(使用 PyTorch 内置函数或自定义实现)
with torch.no_grad():
result_cpu = torch.abs(x) # 对于 abs 算子
# 对比结果
rtol = 1e-5 if dtype == torch.float32 else 1e-3
atol = 1e-5 if dtype == torch.float32 else 1e-3
assert torch.allclose(result_npu.cpu(), result_cpu, rtol=rtol, atol=atol)
完整示例
test_abs.py
import torch
import torch_npu
import pytest
# 确保算子库已加载
import ops_multimodal_fusion
class TestAbs:
"""abs 算子测试类"""
def test_abs_interface(self):
"""接口测试:验证算子已注册可用"""
x = torch.randn(32, 64).npu()
result = torch.ops.ops_multimodal_fusion.abs(x)
assert result.shape == x.shape
assert result.device.type == 'npu'
@pytest.mark.parametrize("shape", [
(32, 64),
(128, 256),
(1024, 1024),
(1, 1),
(10,),
])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
def test_abs_accuracy(self, shape, dtype):
"""精度测试:与 CPU 参考实现对比"""
# 创建测试数据
x = torch.randn(shape, dtype=dtype)
x_npu = x.npu()
# NPU 计算
result_npu = torch.ops.ops_multimodal_fusion.abs(x_npu)
# CPU 参考计算
result_cpu = torch.abs(x)
# 对比结果
rtol = 1e-5 if dtype == torch.float32 else 1e-3
atol = 1e-5 if dtype == torch.float32 else 1e-3
assert torch.allclose(result_npu.cpu(), result_cpu, rtol=rtol, atol=atol)
def test_abs_negative_values(self):
"""边界测试:验证负数处理"""
x = torch.tensor([-1.0, -2.0, -3.0]).npu()
result = torch.ops.ops_multimodal_fusion.abs(x)
expected = torch.tensor([1.0, 2.0, 3.0])
assert torch.allclose(result.cpu(), expected)
def test_abs_zero(self):
"""边界测试:验证零值处理"""
x = torch.zeros(10).npu()
result = torch.ops.ops_multimodal_fusion.abs(x)
assert torch.allclose(result.cpu(), torch.zeros(10))
pytest 常用功能
参数化测试
@pytest.mark.parametrize("shape", [(32, 64), (128, 256)])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
def test_${op_name}(shape, dtype):
# 测试逻辑
测试类组织
class Test${OpName}:
def test_interface(self):
# 接口测试
def test_accuracy(self):
# 精度测试
def test_boundary(self):
# 边界测试
标记测试
@pytest.mark.skip(reason="待实现")
def test_future_feature():
pass
@pytest.mark.slow
def test_large_shape():
pass
运行测试
运行单个算子测试
pytest tests/${op_name}/ -v
运行所有测试
pytest tests/ -v
运行特定测试
pytest tests/${op_name}/test_${op_name}.py::Test${OpName}::test_accuracy -v
带详细输出
pytest tests/ -v -s
测试最佳实践
- 接口测试:验证算子可调用,输出形状正确
- 精度测试:使用参数化覆盖多种 shape 和 dtype
- 边界测试:测试特殊值(零、负数、极大/极小值)
- 性能测试:可选,测量算子执行时间
参考示例
完整示例:tests/abs/test_abs.py