"""
测试多模型配置功能
"""
import pytest
from akg_agents.core_v2.config import ModelConfig, get_settings, AKGSettings
from akg_agents.core_v2.llm import create_llm_client
class TestMultiModelConfig:
    """Tests for ``ModelConfig`` and multi-level model loading via ``get_settings``."""

    def test_model_config_creation(self):
        """``ModelConfig`` keeps the values it is constructed with."""
        config = ModelConfig(
            base_url="https://api.openai.com/v1",
            api_key="sk-test",
            model_name="gpt-4",
            temperature=0.0,
        )
        assert config.base_url == "https://api.openai.com/v1"
        assert config.model_name == "gpt-4"
        assert config.temperature == 0.0

    def test_model_config_from_dict(self):
        """``ModelConfig.from_dict`` builds a config from a plain mapping."""
        data = {
            "base_url": "https://api.deepseek.com/v1",
            "api_key": "sk-test",
            "model_name": "deepseek-chat",
            "temperature": 0.5,
        }
        config = ModelConfig.from_dict(data)
        assert config.model_name == "deepseek-chat"
        assert config.temperature == 0.5

    def test_settings_with_multiple_models(self, monkeypatch):
        """Per-level ``AKG_AGENTS_<LEVEL>_*`` variables populate ``settings.models``."""
        # NOTE(review): this relies on get_settings() re-reading the environment
        # on every call; if it caches its result, these overrides would be
        # ignored — confirm against the config module.
        monkeypatch.setenv("AKG_AGENTS_COMPLEX_BASE_URL", "https://api.openai.com/v1")
        monkeypatch.setenv("AKG_AGENTS_COMPLEX_API_KEY", "sk-complex")
        monkeypatch.setenv("AKG_AGENTS_COMPLEX_MODEL_NAME", "gpt-4")
        monkeypatch.setenv("AKG_AGENTS_STANDARD_BASE_URL", "https://api.deepseek.com/v1")
        monkeypatch.setenv("AKG_AGENTS_STANDARD_API_KEY", "sk-standard")
        monkeypatch.setenv("AKG_AGENTS_STANDARD_MODEL_NAME", "deepseek-chat")
        monkeypatch.setenv("AKG_AGENTS_FAST_BASE_URL", "https://api.openai.com/v1")
        monkeypatch.setenv("AKG_AGENTS_FAST_API_KEY", "sk-fast")
        monkeypatch.setenv("AKG_AGENTS_FAST_MODEL_NAME", "gpt-3.5-turbo")

        settings = get_settings()

        expected_model_names = {
            "complex": "gpt-4",
            "standard": "deepseek-chat",
            "fast": "gpt-3.5-turbo",
        }
        for level, model_name in expected_model_names.items():
            assert level in settings.models
            assert settings.models[level].model_name == model_name
class TestLLMClientFactory:
    """Tests for the ``create_llm_client`` factory function."""

    def test_create_client_with_level(self, monkeypatch):
        """A client created by model level uses that level's configured model."""
        monkeypatch.setenv("AKG_AGENTS_STANDARD_BASE_URL", "https://api.openai.com/v1")
        monkeypatch.setenv("AKG_AGENTS_STANDARD_API_KEY", "sk-test")
        monkeypatch.setenv("AKG_AGENTS_STANDARD_MODEL_NAME", "gpt-4")
        client = create_llm_client(model_level="standard")
        assert client.provider.model_name == "gpt-4"

    def test_create_client_with_direct_params(self):
        """Explicit ``model_name``/``base_url``/``api_key`` are passed to the provider."""
        client = create_llm_client(
            model_name="deepseek-chat",
            base_url="https://api.deepseek.com/v1",
            api_key="sk-test",
        )
        assert client.provider.model_name == "deepseek-chat"
        assert "deepseek" in str(client.provider.client.base_url)

    def test_backward_compatibility(self, monkeypatch):
        """Legacy single-model ``AKG_AGENTS_*`` variables still configure "standard"."""
        # Drop any per-level "standard" variables inherited from the caller's
        # shell, so only the legacy variables below can supply the value.
        for suffix in ("BASE_URL", "API_KEY", "MODEL_NAME"):
            monkeypatch.delenv(f"AKG_AGENTS_STANDARD_{suffix}", raising=False)
        monkeypatch.setenv("AKG_AGENTS_BASE_URL", "https://api.deepseek.com/v1")
        monkeypatch.setenv("AKG_AGENTS_API_KEY", "sk-test")
        monkeypatch.setenv("AKG_AGENTS_MODEL_NAME", "deepseek-chat")
        settings = get_settings()
        assert "standard" in settings.models
        assert settings.models["standard"].model_name == "deepseek-chat"
class TestModelLevelFallback:
    """Tests for the model-level fallback in ``AKGSettings.resolve_model_level``.

    Documented resolution order:
    requested level -> ``default_model`` -> "complex" -> "standard" -> "fast".
    """

    @staticmethod
    def _config(model_name, base_url="http://test"):
        """Return a minimal ``ModelConfig`` identified by *model_name*."""
        return ModelConfig(base_url=base_url, api_key="key", model_name=model_name)

    @staticmethod
    def _make_settings(default_model, **models):
        """Build ``AKGSettings`` whose ``models`` holds exactly *models*.

        ``models`` is reset to an empty dict first, so entries the constructor
        may have pre-populated (presumably from the environment — not visible
        here, TODO confirm) cannot leak into the fallback under test.
        """
        settings = AKGSettings()
        settings.models = {}
        settings.models.update(models)
        settings.default_model = default_model
        return settings

    def test_resolve_model_level_direct_match(self):
        """A requested level that exists is returned as-is."""
        settings = self._make_settings("standard", standard=self._config("gpt-4"))
        level, config = settings.resolve_model_level("standard")
        assert level == "standard"
        assert config.model_name == "gpt-4"

    def test_resolve_model_level_fallback_to_default(self):
        """A missing requested level falls back to ``default_model``."""
        settings = self._make_settings("standard", standard=self._config("gpt-4"))
        level, config = settings.resolve_model_level("fast")
        assert level == "standard"
        assert config.model_name == "gpt-4"

    def test_resolve_model_level_fallback_to_order(self):
        """If neither the request nor the default exists, the fixed order is used."""
        settings = self._make_settings("fast", complex=self._config("gpt-4o"))
        level, config = settings.resolve_model_level("coder")
        assert level == "complex"
        assert config.model_name == "gpt-4o"

    def test_resolve_model_level_no_fallback_available(self):
        """With no models configured at all, resolution raises ``ValueError``."""
        settings = self._make_settings("standard")
        with pytest.raises(ValueError, match="not found and no fallback available"):
            settings.resolve_model_level("fast")

    def test_resolve_model_level_fallback_priority(self):
        """Fallback priority: request -> default -> complex -> standard -> fast."""
        # "complex" wins over "fast" when both are present.
        settings = self._make_settings(
            "standard",
            fast=self._config("fast-model", base_url="http://fast"),
            complex=self._config("complex-model", base_url="http://complex"),
        )
        level, config = settings.resolve_model_level("coder")
        assert level == "complex"
        assert config.model_name == "complex-model"

        # "standard" wins over "fast" when "complex" is absent.
        settings = self._make_settings(
            "complex",
            fast=self._config("fast-model", base_url="http://fast"),
            standard=self._config("standard-model", base_url="http://standard"),
        )
        level, config = settings.resolve_model_level("coder")
        assert level == "standard"
        assert config.model_name == "standard-model"

        # "fast" is the last resort.
        settings = self._make_settings(
            "standard",
            fast=self._config("fast-model", base_url="http://fast"),
        )
        level, config = settings.resolve_model_level("coder")
        assert level == "fast"
        assert config.model_name == "fast-model"
if __name__ == "__main__":
    # pytest.main returns the session exit code; propagate it so running this
    # file directly exits non-zero when any test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))