import pytest
from akg_agents.op.utils.config_utils import check_task_config
class TestTaskConfig:
    """Tests for the task-configuration validator ``check_task_config``.

    Most config tables below are written as ``(framework, dsl, backend, arch)``
    tuples, whereas ``check_task_config`` takes ``(framework, backend, arch,
    dsl)``; each test reorders the fields at the call site. Every table is fed
    through ``pytest.mark.parametrize`` so each combination is reported as its
    own test case, instead of a loop that stops at the first failing entry.
    """

    # Parametrize field names shared by the (framework, dsl, backend, arch) tables.
    CONFIG_FIELDS = "framework, dsl, backend, arch"

    # Combinations check_task_config must accept. Shared by test_valid_configs
    # and test_all_valid_combinations so the two tables cannot drift apart.
    VALID_CONFIGS = (
        ("torch", "triton_ascend", "ascend", "ascend910b4"),
        ("torch", "triton_cuda", "cuda", "a100"),
    )

    # Framework names that must be rejected with "Unsupported framework".
    INVALID_FRAMEWORK_CONFIGS = (
        ("invalid_framework", "triton_ascend", "ascend", "ascend910b4"),
        ("pytorch", "triton_cuda", "cuda", "a100"),
    )

    # Backend names that must be rejected with "does not support backend".
    INVALID_BACKEND_CONFIGS = (
        ("torch", "triton_ascend", "invalid_backend", "ascend910b4"),
    )

    # Architectures that must be rejected with "does not support arch".
    INVALID_ARCH_CONFIGS = (
        ("torch", "triton_cuda", "cuda", "invalid_arch"),
    )

    # DSL names that must be rejected with "does not support dsl".
    INVALID_DSL_CONFIGS = (
        ("torch", "invalid_impl", "cuda", "a100"),
    )

    # A DSL paired with the other vendor's backend/arch; expected to be
    # rejected with "does not support dsl".
    MISMATCHED_CONFIGS = (
        ("torch", "triton_cuda", "ascend", "ascend910b4"),
        ("torch", "triton_ascend", "cuda", "a100"),
    )

    # Combinations whose combination is unsupported; only ValueError is
    # asserted, not a specific message.
    NONEXISTENT_CONFIGS = (
        ("torch", "triton_ascend", "ascend", "ascend310p3"),
    )

    @pytest.mark.level0
    @pytest.mark.parametrize(CONFIG_FIELDS, VALID_CONFIGS)
    def test_valid_configs(self, framework, dsl, backend, arch):
        """Each valid combination must pass validation without raising."""
        try:
            check_task_config(framework, backend, arch, dsl)
        except ValueError as e:
            pytest.fail(f"有效配置验证未通过: {framework} + {backend} + {arch} + {dsl}, 错误: {e}")
        print(f"有效配置: {framework} + {backend} + {arch} + {dsl}")

    @pytest.mark.level0
    @pytest.mark.parametrize(CONFIG_FIELDS, INVALID_FRAMEWORK_CONFIGS)
    def test_invalid_framework(self, framework, dsl, backend, arch):
        """An unknown framework must raise ValueError("Unsupported framework ...")."""
        with pytest.raises(ValueError, match="Unsupported framework"):
            check_task_config(framework, backend, arch, dsl)
        # Outside the `with`: a statement after the raising call would never run.
        print(f"正确捕获无效框架错误: {framework}")

    @pytest.mark.level0
    @pytest.mark.parametrize(CONFIG_FIELDS, INVALID_BACKEND_CONFIGS)
    def test_invalid_backend(self, framework, dsl, backend, arch):
        """An unknown backend must raise ValueError("... does not support backend ...")."""
        with pytest.raises(ValueError, match="does not support backend"):
            check_task_config(framework, backend, arch, dsl)
        print(f"正确捕获无效后端错误: {framework} + {backend}")

    @pytest.mark.level0
    @pytest.mark.parametrize(CONFIG_FIELDS, INVALID_ARCH_CONFIGS)
    def test_invalid_arch(self, framework, dsl, backend, arch):
        """An unknown architecture must raise ValueError("... does not support arch ...")."""
        with pytest.raises(ValueError, match="does not support arch"):
            check_task_config(framework, backend, arch, dsl)
        print(f"正确捕获无效架构错误: {backend} + {arch}")

    @pytest.mark.level0
    @pytest.mark.parametrize(CONFIG_FIELDS, INVALID_DSL_CONFIGS)
    def test_invalid_dsl(self, framework, dsl, backend, arch):
        """An unknown DSL (implementation type) must raise ValueError("... does not support dsl ...")."""
        with pytest.raises(ValueError, match="does not support dsl"):
            check_task_config(framework, backend, arch, dsl)
        print(f"正确捕获无效实现类型错误: {dsl}")

    @pytest.mark.level0
    @pytest.mark.parametrize(CONFIG_FIELDS, MISMATCHED_CONFIGS)
    def test_mismatched_combinations(self, framework, dsl, backend, arch):
        """A DSL paired with an incompatible backend/arch must be rejected on the DSL."""
        with pytest.raises(ValueError, match="does not support dsl"):
            check_task_config(framework, backend, arch, dsl)
        print(f"正确捕获不匹配组合错误: {framework} + {backend} + {arch} + {dsl}")

    @pytest.mark.level0
    @pytest.mark.parametrize(CONFIG_FIELDS, NONEXISTENT_CONFIGS)
    def test_nonexistent_combinations(self, framework, dsl, backend, arch):
        """An unsupported combination must raise ValueError (message not pinned)."""
        with pytest.raises(ValueError):
            check_task_config(framework, backend, arch, dsl)
        print(f"正确捕获不存在组合错误: {framework} + {backend} + {arch} + {dsl}")

    @pytest.mark.level0
    @pytest.mark.parametrize(
        "framework",
        ["", None, "MindSpore"],
        ids=["empty", "none", "MindSpore"],
    )
    def test_edge_cases(self, framework):
        """Empty, None, or unsupported framework values must raise ValueError."""
        with pytest.raises(ValueError):
            check_task_config(framework, "ascend", "ascend910b4", "triton_ascend")
        print("正确捕获边界情况错误")

    # NOTE: unlike the tables above, these tuples are already in call order:
    # (framework, backend, arch, dsl).
    @pytest.mark.level0
    @pytest.mark.parametrize(
        "framework, backend, arch, dsl",
        [
            ("", "ascend", "ascend910b4", "triton_ascend"),
            (None, "ascend", "ascend910b4", "triton_ascend"),
            ("torch", "", "ascend910b4", "triton_ascend"),
            ("torch", None, "ascend910b4", "triton_ascend"),
            ("torch", "ascend", "", "triton_ascend"),
            ("torch", "ascend", None, "triton_ascend"),
            ("torch", "ascend", "ascend910b4", ""),
            ("torch", "ascend", "ascend910b4", None),
            (None, None, None, None),
        ],
        ids=[
            "framework-empty",
            "framework-none",
            "backend-empty",
            "backend-none",
            "arch-empty",
            "arch-none",
            "dsl-empty",
            "dsl-none",
            "all-none",
        ],
    )
    def test_empty_values_for_each_param(self, framework, backend, arch, dsl):
        """Every parameter must be rejected when it is empty or None.

        Guards against the validator silently falling back to an implicit
        default. AttributeError is accepted alongside ValueError, presumably
        because a None value can fail inside check_task_config before an
        explicit validation check is reached (TODO: confirm against the
        validator).
        """
        with pytest.raises((ValueError, AttributeError)):
            check_task_config(framework, backend, arch, dsl)
        print("正确捕获每个参数为空的情况")

    @pytest.mark.level0
    @pytest.mark.parametrize(CONFIG_FIELDS, VALID_CONFIGS)
    def test_all_valid_combinations(self, framework, dsl, backend, arch):
        """Every entry of the shared valid-combination table must be accepted."""
        try:
            check_task_config(framework, backend, arch, dsl)
        except ValueError as e:
            pytest.fail(f"有效组合验证未通过: {framework} + {backend} + {arch} + {dsl}, 错误: {e}")
        print(f"有效组合: {framework} + {backend} + {arch} + {dsl}")