from __future__ import annotations
from pathlib import Path
from types import SimpleNamespace
import pytest
from msagent.cli.handlers import models as models_module
from msagent.cli.handlers.models import ModelHandler
from msagent.configs import LLMConfig, LLMProvider
def _build_llm_config(alias: str, provider: LLMProvider = LLMProvider.OPENAI) -> LLMConfig:
    """Create an ``LLMConfig`` test fixture via ``LLMConfig.model_construct``.

    Every fixture shares the same model name (``gpt-4o-mini``) and context
    window (128000); only ``alias`` and ``provider`` vary between calls.
    NOTE(review): ``model_construct`` is presumably pydantic's
    validation-free constructor; confirm if validation matters here.
    """
    fields = {
        "alias": alias,
        "provider": provider,
        "model": "gpt-4o-mini",
        "context_window": 128000,
    }
    return LLMConfig.model_construct(**fields)
def _build_session(tmp_path: Path, *, model: str = "default") -> SimpleNamespace:
    """Return a minimal stand-in for the CLI session handed to ``ModelHandler``.

    Args:
        tmp_path: Directory exposed as ``context.working_dir``.
        model: Alias stored as ``context.model`` (keyword-only).

    Returns:
        A namespace with a ``context`` (``agent``, ``working_dir``, ``model``,
        ``thread_id``) and an ``update_context`` callable that accepts keyword
        arguments and discards them.
    """

    def _discard_updates(**kwargs: object) -> None:
        # Context changes are ignored; tests that must observe them build
        # their own session with a recording callback.
        return None

    ctx = SimpleNamespace(
        agent="Profiler",
        working_dir=tmp_path,
        model=model,
        thread_id="thread-1",
    )
    return SimpleNamespace(context=ctx, update_context=_discard_updates)
def test_model_handler_groups_models_by_provider() -> None:
    """Models sharing a provider land in one group, keyed by provider name (e.g. ``"openai"``)."""
    grouped = ModelHandler._group_models_by_provider(
        [
            _build_llm_config("default", LLMProvider.OPENAI),
            _build_llm_config("fast", LLMProvider.OPENAI),
            _build_llm_config("gemini-pro", LLMProvider.GOOGLE),
        ]
    )

    # Two OpenAI models and one Google model were supplied above.
    expected_sizes = {"openai": 2, "google": 1}
    for provider_name in expected_sizes:
        assert provider_name in grouped
    for provider_name, size in expected_sizes.items():
        assert len(grouped[provider_name]) == size
def test_model_handler_format_tabbed_model_list_shows_current_and_default_markers(
    tmp_path: Path,
) -> None:
    """The rendered model list tags the active model with ``[current]`` and ``[default]``.

    The session's working directory is pytest's ``tmp_path`` rather than the
    process cwd. This keeps the test independent of where pytest is launched
    and consistent with the other handler tests in this module.
    """
    models = [
        _build_llm_config("default", LLMProvider.OPENAI),
        _build_llm_config("fast", LLMProvider.OPENAI),
    ]
    handler = ModelHandler(_build_session(tmp_path, model="default"))
    providers = handler._group_models_by_provider(models)
    provider_names = list(providers.keys())

    formatted = handler._format_tabbed_model_list(
        providers,
        provider_names,
        selected_provider_idx=0,
        selected_model_idx=0,
        current_model="default",
        default_model="default",
    )

    # The text of each fragment is at index 1 (presumably prompt_toolkit-style
    # (style, text) tuples; confirm against the handler).
    text = "".join(fragment[1] for fragment in formatted)
    assert "[current]" in text
    assert "[default]" in text
def test_model_handler_format_tabbed_model_list_shows_provider_tabs(tmp_path: Path) -> None:
    """Every provider that has models appears as a tab in the rendered list.

    The session's working directory is pytest's ``tmp_path`` rather than the
    process cwd. This keeps the test independent of where pytest is launched
    and consistent with the other handler tests in this module.
    """
    models = [
        _build_llm_config("gpt-4o", LLMProvider.OPENAI),
        _build_llm_config("gemini-pro", LLMProvider.GOOGLE),
    ]
    handler = ModelHandler(_build_session(tmp_path, model="gpt-4o"))
    providers = handler._group_models_by_provider(models)
    provider_names = list(providers.keys())

    formatted = handler._format_tabbed_model_list(
        providers,
        provider_names,
        selected_provider_idx=0,
        selected_model_idx=0,
        current_model="gpt-4o",
        default_model="gpt-4o",
    )

    # The text of each fragment is at index 1 (presumably prompt_toolkit-style
    # (style, text) tuples; confirm against the handler).
    text = "".join(fragment[1] for fragment in formatted)
    assert "openai" in text
    assert "google" in text
@pytest.mark.asyncio
async def test_model_handler_reports_no_other_models_when_only_one(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """With a single configured LLM there is nothing to switch to, so an error is printed."""
    captured_errors: list[str] = []
    console = models_module.console
    monkeypatch.setattr(console, "print_error", captured_errors.append)
    # Silence regular console output during the test.
    monkeypatch.setattr(console, "print", lambda *_args, **_kwargs: None)

    # The agent's current LLM is also the only LLM in the config.
    agent_config = SimpleNamespace(llm=SimpleNamespace(alias="default"))
    config_data = SimpleNamespace(llms=[_build_llm_config("default")])

    async def fake_load_agent_config(_agent, _working_dir):
        return agent_config

    async def fake_load_llms_config(_working_dir):
        return config_data

    for attr_name, replacement in (
        ("load_agent_config", fake_load_agent_config),
        ("load_llms_config", fake_load_llms_config),
    ):
        monkeypatch.setattr(models_module.initializer, attr_name, replacement)

    await ModelHandler(_build_session(tmp_path)).handle()

    assert "No other models available" in captured_errors
@pytest.mark.asyncio
async def test_model_handler_updates_context_on_successful_switch(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Picking a different model pushes the new alias into the session context."""
    default_model = _build_llm_config("default", LLMProvider.OPENAI)
    new_model = _build_llm_config("fast", LLMProvider.OPENAI)
    agent_config = SimpleNamespace(llm=SimpleNamespace(alias="default"))
    config_data = SimpleNamespace(llms=[default_model, new_model])

    context_updates: dict[str, object] = {}

    def record_update(**kwargs: object) -> None:
        # Capture whatever the handler writes into the session context.
        context_updates.update(kwargs)

    session = SimpleNamespace(
        context=SimpleNamespace(
            agent="Profiler",
            working_dir=tmp_path,
            model="default",
            thread_id="thread-1",
        ),
        update_context=record_update,
    )

    async def fake_load_agent_config(_agent, _working_dir):
        return agent_config

    async def fake_load_llms_config(_working_dir):
        return config_data

    async def fake_update_agent_llm(_agent, _model, _working_dir):
        # Stubbed so the real initializer.update_agent_llm is never invoked.
        return None

    async def fake_get_model_selection(_models, _current, _default):
        # Simulate the user choosing the "fast" model in the picker.
        return "fast"

    initializer_fakes = {
        "load_agent_config": fake_load_agent_config,
        "load_llms_config": fake_load_llms_config,
        "update_agent_llm": fake_update_agent_llm,
    }
    for attr_name, replacement in initializer_fakes.items():
        monkeypatch.setattr(models_module.initializer, attr_name, replacement)

    handler = ModelHandler(session)
    monkeypatch.setattr(handler, "_get_model_selection", fake_get_model_selection)

    await handler.handle()

    assert context_updates.get("model") == "fast"
@pytest.mark.asyncio
async def test_model_handler_skips_update_when_selection_canceled(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """An empty selection (user canceled the picker) leaves the session context untouched."""
    llm_choices = [
        _build_llm_config("default", LLMProvider.OPENAI),
        _build_llm_config("fast", LLMProvider.OPENAI),
    ]
    agent_config = SimpleNamespace(llm=SimpleNamespace(alias="default"))
    config_data = SimpleNamespace(llms=llm_choices)

    context_updates: dict[str, object] = {}

    def record_update(**kwargs: object) -> None:
        context_updates.update(kwargs)

    session = SimpleNamespace(
        context=SimpleNamespace(
            agent="Profiler",
            working_dir=tmp_path,
            model="default",
            thread_id="thread-1",
        ),
        update_context=record_update,
    )

    async def fake_load_agent_config(_agent, _working_dir):
        return agent_config

    async def fake_load_llms_config(_working_dir):
        return config_data

    async def fake_get_model_selection(_models, _current, _default):
        # An empty string stands in for a canceled selection.
        return ""

    for attr_name, replacement in (
        ("load_agent_config", fake_load_agent_config),
        ("load_llms_config", fake_load_llms_config),
    ):
        monkeypatch.setattr(models_module.initializer, attr_name, replacement)

    handler = ModelHandler(session)
    monkeypatch.setattr(handler, "_get_model_selection", fake_get_model_selection)

    await handler.handle()

    assert not context_updates
@pytest.mark.asyncio
async def test_model_handler_handles_exception_gracefully(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
    """A failure while loading the agent config is reported via ``print_error``, not raised."""
    captured_errors: list[str] = []
    console = models_module.console
    monkeypatch.setattr(console, "print_error", captured_errors.append)
    monkeypatch.setattr(console, "print", lambda *_args, **_kwargs: None)

    async def failing_load_agent_config(_agent, _working_dir):
        raise RuntimeError("config load failed")

    monkeypatch.setattr(models_module.initializer, "load_agent_config", failing_load_agent_config)

    # If the handler let the exception escape, this await would fail the test.
    await ModelHandler(_build_session(tmp_path)).handle()

    matching = [message for message in captured_errors if "Error switching models" in message]
    assert matching