"""
Basic functionality tests for Coordinator server
Using FastAPI TestClient for testing
"""
import json
import pytest
from fastapi.testclient import TestClient
from fastapi.responses import JSONResponse
from fastapi import HTTPException
from unittest.mock import patch, MagicMock, AsyncMock
from fastapi import FastAPI
from motor.common.standby.standby_manager import StandbyRole, StandbyManager
from motor.coordinator.api_server.management_server import ManagementServer
from motor.coordinator.domain.probe import RoleHeartbeatResult
from motor.coordinator.api_server.inference_server import InferenceServer, _validate_anthropic_request
from motor.coordinator.domain.request_manager import RequestManager
from motor.config.coordinator import CoordinatorConfig, RateLimitConfig
from motor.coordinator.domain import InstanceReadiness
from motor.common.http.key_encryption import encrypt_api_key, set_default_key_encryption_by_name
from motor.coordinator.models.constants import OpenAIField
from motor.coordinator.middleware.fastapi_middleware import (
SimpleRateLimitMiddleware,
create_simple_rate_limit_middleware,
)
def _copy_routes(
    src: FastAPI,
    dst: FastAPI,
    skip_paths: list | None = None,
) -> None:
    """Re-register the routes of ``src`` on ``dst`` (test helper).

    Args:
        src: Application whose ``routes`` are read.
        dst: Application that receives each route through ``add_api_route``.
        skip_paths: Paths that must not be copied; ``None`` copies everything.

    Routes that expose no ``path`` or no ``endpoint`` attribute are ignored,
    and a route without ``methods`` is registered as GET only.
    """
    excluded = set(skip_paths) if skip_paths else set()
    for candidate in src.routes:
        route_path = getattr(candidate, "path", None)
        if route_path is not None and route_path not in excluded:
            handler = getattr(candidate, "endpoint", None)
            if handler is not None:
                verbs = getattr(candidate, "methods", None)
                dst.add_api_route(route_path, handler, methods=list(verbs) if verbs else ["GET"])
def _openai_is_stream(body_json: dict) -> bool:
    """Tell whether a request body asks for a streamed response (test helper).

    A missing stream field means ``False``. String values count as enabled
    only when they read "true", "1" or "yes" (case-insensitive); any other
    value is judged by its truthiness.
    """
    try:
        raw_flag = body_json[OpenAIField.STREAM]
    except KeyError:
        return False
    if not isinstance(raw_flag, str):
        return bool(raw_flag)
    return raw_flag.lower() in ("true", "1", "yes")
def create_unified_app_for_test(
    mgmt: ManagementServer,
    inf: InferenceServer,
    rate_limit_config: RateLimitConfig | None = None,
) -> FastAPI:
    """Build a single app carrying both management and inference routes (test helper).

    Args:
        mgmt: Management server; supplies the lifespan and the management routes.
        inf: Inference server; supplies the inference routes and the request manager.
        rate_limit_config: When given and enabled, a ``SimpleRateLimitMiddleware``
            is installed on the merged app.

    Returns:
        The merged FastAPI application.
    """
    merged = FastAPI(lifespan=mgmt.lifespan)
    _copy_routes(mgmt.management_app, merged)
    _copy_routes(inf.app, merged)

    # Only fill in the request manager when the app state does not carry one yet.
    if not getattr(merged.state, "request_manager", None):
        merged.state.request_manager = inf._request_manager

    if not rate_limit_config or not rate_limit_config.enable_rate_limit:
        return merged

    # The factory result is used only as the source of the rate limiter instance.
    prebuilt = create_simple_rate_limit_middleware(
        app=merged,
        max_requests=rate_limit_config.max_requests,
        window_size=rate_limit_config.window_size,
    )
    middleware_options = {
        "rate_limiter": prebuilt.rate_limiter,
        "skip_paths": rate_limit_config.skip_paths,
        "error_message": rate_limit_config.error_message,
        "error_status_code": rate_limit_config.error_status_code,
    }
    merged.add_middleware(SimpleRateLimitMiddleware, **middleware_options)
    return merged
class _TestServerShell:
    """Test-only facade pairing a ManagementServer with an InferenceServer (replaces CoordinatorServer)."""

    def __init__(self, config: CoordinatorConfig | None = None) -> None:
        """Create both servers; the inference side falls back to a default config when none is given."""
        self._mgmt = ManagementServer(config)
        effective_config = config if config else CoordinatorConfig()
        self._inf = InferenceServer(
            effective_config,
            request_manager=RequestManager(effective_config),
        )
        self.coordinator_config = self._mgmt.coordinator_config

    @property
    def management_app(self) -> FastAPI:
        """FastAPI app of the management server."""
        return self._mgmt.management_app

    @property
    def inference_app(self) -> FastAPI:
        """FastAPI app of the inference server."""
        return self._inf.app

    @property
    def instance_manager(self):
        """Instance manager currently held by the management server."""
        return self._mgmt.instance_manager

    @instance_manager.setter
    def instance_manager(self, value) -> None:
        """Swap the management server's instance manager (tests inject mocks here)."""
        self._mgmt.instance_manager = value

    @property
    def _daemon_liveness(self):
        """Management server's daemon liveness object, exposed so tests can patch read_role_and_heartbeat."""
        return self._mgmt._daemon_liveness

    @property
    def lifespan(self):
        """Lifespan handler of the management server."""
        return self._mgmt.lifespan

    def setup_rate_limiting(self, rate_limit_config: RateLimitConfig | None = None) -> None:
        """Forward rate-limit setup to the inference server."""
        self._inf.setup_rate_limiting(rate_limit_config=rate_limit_config)

    def create_unified_app(
        self,
        rate_limit_config: RateLimitConfig | None = None,
    ) -> FastAPI:
        """Return one app that merges the management and inference routes."""
        return create_unified_app_for_test(self._mgmt, self._inf, rate_limit_config)

    def _copy_routes(
        self,
        src: FastAPI,
        dst: FastAPI,
        skip_paths: list | None = None,
    ) -> None:
        """Method-style wrapper around the module-level ``_copy_routes`` helper."""
        _copy_routes(src, dst, skip_paths)

    def _openai_is_stream(self, body_json: dict) -> bool:
        """Method-style wrapper around the module-level ``_openai_is_stream`` helper."""
        return _openai_is_stream(body_json)
class TestCoordinatorServer:
"""Mock test class for Coordinator server"""
def setup_method(self):
    """Build a mocked coordinator with API-key auth enabled and create the test clients.

    Patches ``InstanceManager``, ``handle_request`` and
    ``InferenceServer._is_available`` so no real backend is needed, then
    exposes ``mgmt_client`` / ``openai_client`` plus a valid plain API key.
    The patchers are stored on ``self`` so ``teardown_method`` can stop them.
    """
    # Management side: the instance manager always reports that requirements are met.
    self._im_patcher = patch("motor.coordinator.api_server.management_server.InstanceManager")
    im_mock_cls = self._im_patcher.start()
    im_instance = MagicMock()
    im_instance.has_required_instances.return_value = True
    im_instance.get_required_instances_status.return_value = InstanceReadiness.REQUIRED_MET
    im_instance.refresh_instances = AsyncMock(return_value=None)
    im_mock_cls.return_value = im_instance

    async def mock_handle_request(request, config, scheduler=None, request_manager=None):
        """Echo the request back as a JSON envelope instead of running inference."""
        import hashlib

        try:
            body_json = await request.json()
        except Exception:
            # Fall back to decoding the raw body; an unparsable body becomes {}.
            try:
                request_body = await request.body()
                body_json = json.loads(request_body.decode("utf-8"))
            except Exception:
                body_json = {}

        input_data = ""
        if "prompt" in body_json:
            input_data = str(body_json["prompt"])
        elif "messages" in body_json:
            input_data = json.dumps(body_json["messages"], ensure_ascii=False)

        is_stream = body_json.get("stream", False)
        if isinstance(is_stream, str):
            is_stream = is_stream.lower() in ("true", "1", "yes")

        # "/chat/completions" also ends with "/completions", so the more
        # specific suffix has to be tested first or it can never match.
        url_path = request.url.path
        if url_path.endswith("/chat/completions"):
            request_type = "chat_completions"
        elif url_path.endswith("/completions"):
            request_type = "completions"
        else:
            request_type = "openai"

        request_id = f"req-{hashlib.md5(str(body_json).encode()).hexdigest()[:8]}"
        response_data = {
            "request_id": request_id,
            "status": "success",
            "data": {"input_data": input_data, "is_stream": bool(is_stream), "request_type": request_type},
        }
        return JSONResponse(content=response_data)

    self._handle_request_patcher = patch(
        "motor.coordinator.api_server.inference_server.handle_request", side_effect=mock_handle_request
    )
    self._handle_request_patcher.start()
    self._is_available_patcher = patch(
        "motor.coordinator.api_server.inference_server.InferenceServer._is_available",
        new_callable=AsyncMock,
        return_value=True,
    )
    self._is_available_patcher.start()

    # API-key auth: the server stores encrypted keys, the tests send the plain ones.
    coordinator_config = CoordinatorConfig()
    coordinator_config.api_key_config.enable_api_key = True
    coordinator_config.api_key_config.encryption_algorithm = "PBKDF2_SHA256"
    set_default_key_encryption_by_name(coordinator_config.api_key_config.encryption_algorithm)
    plain_key1 = "sk-test123456789"
    plain_key2 = "sk-coordinator2024"
    encrypted_key1 = encrypt_api_key(plain_key1)
    encrypted_key2 = encrypt_api_key(plain_key2)
    coordinator_config.api_key_config.valid_keys = {encrypted_key1, encrypted_key2}
    self.coordinator_config = coordinator_config

    self.coordinator_server = _TestServerShell(config=coordinator_config)
    self.coordinator_server.setup_rate_limiting()

    # Inference side: always available, scheduler reports no instances.
    inf = self.coordinator_server._inf
    inf._is_available = AsyncMock(return_value=True)
    _mock_scheduler = MagicMock()
    _mock_scheduler.get_available_instances = AsyncMock(return_value={})
    inf._get_scheduler_client = lambda: _mock_scheduler

    mgmt_app = self.coordinator_server.management_app
    inference_app = self.coordinator_server.inference_app
    if not getattr(inference_app.state, "request_manager", None):
        inference_app.state.request_manager = inf._request_manager
    self.mgmt_client = TestClient(mgmt_app)
    self.openai_client = TestClient(inference_app)
    self.valid_api_key = "sk-test123456789"
def teardown_method(self):
"""Teardown test fixtures"""
try:
if hasattr(self, "_im_patcher"):
self._im_patcher.stop()
if hasattr(self, "_handle_request_patcher"):
self._handle_request_patcher.stop()
if hasattr(self, "_is_available_patcher"):
self._is_available_patcher.stop()
except Exception:
pass
def test_liveness_endpoints(self):
    """Each management probe (/liveness, /startup, /readiness) answers 200 with status "ok"."""
    probes = (
        ("/liveness", "Liveness probe"),
        ("/startup", "Startup probe"),
        ("/readiness", "Readiness check"),
    )
    for endpoint, label in probes:
        reply = self.mgmt_client.get(endpoint)
        assert reply.status_code == 200, f"{label} failed: {reply.status_code}"
        payload = reply.json()
        assert payload["status"] == "ok", f"{label} status abnormal: {payload}"
def test_readiness_endpoints_fail_when_instance_manager_not_ready(self):
    """Readiness still answers 200 but reports ``ready`` False when the instance manager returns NONE."""
    not_ready_manager = MagicMock(
        **{"get_required_instances_status.return_value": InstanceReadiness.NONE}
    )
    self.coordinator_server.instance_manager = not_ready_manager

    reply = self.mgmt_client.get("/readiness")

    assert reply.status_code == 200
    payload = reply.json()
    assert payload["message"] == "Coordinator is ok"
    assert payload["ready"] is False
def test_readiness_endpoints_fail_when_instance_manager_ready(self):
    """With the default mock from setup_method, readiness answers 200 and reports ``ready`` True."""
    reply = self.mgmt_client.get("/readiness")

    assert reply.status_code == 200
    payload = reply.json()
    assert payload["message"] == "Coordinator is ok"
    assert payload["ready"] is True
def test_readiness_endpoints_fail_when_enable_standby_is_master_but_instance_not_ready(self):
    """Standby enabled, role master, instances missing: 200 with "Coordinator is master" and ``ready`` False."""
    not_ready_manager = MagicMock(
        **{"get_required_instances_status.return_value": InstanceReadiness.NONE}
    )
    self.coordinator_server.instance_manager = not_ready_manager

    # Switch master/standby mode on in the config and in the already-built probe.
    self.coordinator_config.standby_config.enable_master_standby = True
    self.coordinator_server._mgmt._readiness_probe._enable_master_standby = True
    standby = StandbyManager(self.coordinator_config)
    standby.current_role = StandbyRole.MASTER

    master_heartbeat = RoleHeartbeatResult(is_master=True, heartbeat_stale=False, orphaned=False)
    with patch.object(
        self.coordinator_server._daemon_liveness,
        "read_role_and_heartbeat",
        return_value=master_heartbeat,
    ):
        reply = self.mgmt_client.get("/readiness")

    assert reply.status_code == 200
    payload = reply.json()
    assert payload["message"] == "Coordinator is master"
    assert payload["ready"] is False
def test_readiness_endpoints_fail_when_enable_standby_is_standby(self):
    """Standby enabled and role is not master: readiness answers 503 "Coordinator is not master"."""
    # Switch master/standby mode on in the config and in the already-built probe.
    self.coordinator_config.standby_config.enable_master_standby = True
    self.coordinator_server._mgmt._readiness_probe._enable_master_standby = True
    standby = StandbyManager(self.coordinator_config)
    standby.current_role = StandbyRole.STANDBY

    standby_heartbeat = RoleHeartbeatResult(is_master=False, heartbeat_stale=False, orphaned=False)
    with patch.object(
        self.coordinator_server._daemon_liveness,
        "read_role_and_heartbeat",
        return_value=standby_heartbeat,
    ):
        reply = self.mgmt_client.get("/readiness")

    assert reply.status_code == 503, f"Readiness check failed: {reply.status_code}"
    payload = reply.json()
    assert payload["detail"] == "Coordinator is not master"
def test_readiness_endpoints_fail_when_enable_standby_is_master(self):
    """Standby enabled, role master, instances ready: 200, status "ok", ``ready`` True."""
    # Switch master/standby mode on in the config and in the already-built probe.
    self.coordinator_config.standby_config.enable_master_standby = True
    self.coordinator_server._mgmt._readiness_probe._enable_master_standby = True
    standby = StandbyManager(self.coordinator_config)
    standby.current_role = StandbyRole.MASTER

    master_heartbeat = RoleHeartbeatResult(is_master=True, heartbeat_stale=False, orphaned=False)
    with patch.object(
        self.coordinator_server._daemon_liveness,
        "read_role_and_heartbeat",
        return_value=master_heartbeat,
    ):
        reply = self.mgmt_client.get("/readiness")

    assert reply.status_code == 200
    payload = reply.json()
    assert payload["status"] == "ok"
    assert payload["message"] == "Coordinator is master"
    assert payload["ready"] is True
def test_root_endpoints(self):
    """The management root path reports the service name and version."""
    reply = self.mgmt_client.get("/")
    assert reply.status_code == 200

    info = reply.json()
    assert info["service"] == "Motor Coordinator Management Server"
    assert info["version"] == "1.0.0"
def test_list_models_exception(self):
    """GET /v1/models answers 503 under the default fixture (no ``aigw_model`` is set by this test)."""
    reply = self.openai_client.get("/v1/models")
    assert reply.status_code == 503
def test_list_models_ok(self):
    """With ``aigw_model`` set, GET /v1/models answers 200 and the first entry reports zero P and D instances."""
    self.coordinator_config.aigw_model = {"k": "v"}

    reply = self.openai_client.get("/v1/models")
    assert reply.status_code == 200

    models = reply.json()["data"]
    print(models[0])
    assert models is not None
    first_model = models[0]
    assert first_model["p_instances_num"] == 0
    assert first_model["d_instances_num"] == 0
def test_openai_completions_api(self):
    """POST several completion payloads to /v1/completions and check the response envelope."""
    model = "text-davinci-003"
    scenarios = [
        {
            "name": "Basic completion request",
            "data": {
                "model": model,
                "prompt": "Write a poem about spring",
                "max_tokens": 100,
                "temperature": 0.7,
            },
        },
        {
            "name": "Completion request with stop tokens",
            "data": {
                "model": model,
                "prompt": "Differences between Python lists and tuples:",
                "max_tokens": 200,
                "temperature": 0.8,
                "stop": ["\n\n", "Summary"],
            },
        },
        {"name": "Minimal parameter completion request", "data": {"model": model, "prompt": "Hello"}},
    ]
    auth_headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.valid_api_key}"}

    for scenario in scenarios:
        reply = self.openai_client.post("/v1/completions", json=scenario["data"], headers=auth_headers)
        assert reply.status_code == 200, f"Completions API failed: {reply.status_code}"

        envelope = reply.json()
        for field in ("request_id", "status"):
            assert field in envelope, f"Response missing {field}"
        assert "data" in envelope, "Response missing data field"
        for field in ("input_data", "is_stream", "request_type"):
            assert field in envelope["data"], f"Response data missing {field}"
def test_openai_chat_completions_api(self):
    """POST several chat payloads to /v1/chat/completions and check the response envelope."""
    model = "gpt-3.5-turbo"

    def user(text):
        return {"role": "user", "content": text}

    scenarios = [
        {
            "name": "Basic chat completion request",
            "data": {
                "model": model,
                "messages": [user("Hello, please introduce yourself")],
                "max_tokens": 100,
                "temperature": 0.7,
            },
        },
        {
            "name": "Multi-turn conversation chat completion request",
            "data": {
                "model": model,
                "messages": [
                    user("What is machine learning?"),
                    {"role": "assistant", "content": "Machine learning is a branch of artificial intelligence..."},
                    user("Can you give an example?"),
                ],
                "max_tokens": 200,
                "temperature": 0.8,
            },
        },
        {
            "name": "Chat completion request with system message",
            "data": {
                "model": model,
                "messages": [
                    {"role": "system", "content": "You are a professional Python programming assistant"},
                    user("Please write a Python implementation of quicksort"),
                ],
                "max_tokens": 500,
                "temperature": 0.5,
            },
        },
        {
            "name": "Minimal parameter chat completion request",
            "data": {"model": model, "messages": [user("Hello")]},
        },
    ]
    auth_headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.valid_api_key}"}

    for scenario in scenarios:
        reply = self.openai_client.post("/v1/chat/completions", json=scenario["data"], headers=auth_headers)
        assert reply.status_code == 200, f"Chat Completions API failed: {reply.status_code}"

        envelope = reply.json()
        for field in ("request_id", "status"):
            assert field in envelope, f"Response missing {field}"
        assert "data" in envelope, "Response missing data field"
        for field in ("input_data", "is_stream", "request_type"):
            assert field in envelope["data"], f"Response data missing {field}"
def test_streaming_requests(self):
    """Requests sent with ``stream`` True come back with ``is_stream`` set, for both endpoints."""
    auth_headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.valid_api_key}"}
    streaming_cases = (
        (
            "/v1/completions",
            {
                "model": "text-davinci-003",
                "prompt": "Write a Python function to calculate the Fibonacci sequence",
                "max_tokens": 200,
                "temperature": 0.7,
                "stream": True,
            },
            "Streaming completion request failed",
        ),
        (
            "/v1/chat/completions",
            {
                "model": "gpt-3.5-turbo",
                "messages": [
                    {"role": "user", "content": "Please explain the basic concepts of deep learning in detail"}
                ],
                "max_tokens": 300,
                "temperature": 0.7,
                "stream": True,
            },
            "Streaming chat completion request failed",
        ),
    )
    for endpoint, payload, failure_label in streaming_cases:
        reply = self.openai_client.post(endpoint, json=payload, headers=auth_headers)
        assert reply.status_code == 200, f"{failure_label}: {reply.status_code}"
        envelope = reply.json()
        assert envelope["data"]["is_stream"], "Stream flag not set correctly"
def test_error_handling(self):
    """Malformed or incomplete requests are expected to end in 400, 422 or 500."""
    auth_headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.valid_api_key}"}
    error_codes = [400, 422, 500]

    # Body that is not JSON at all.
    malformed = self.openai_client.post("/v1/completions", content="invalid json", headers=auth_headers)
    assert malformed.status_code in error_codes, f"Invalid JSON handling exception: {malformed.status_code}"

    # Completion request without a "model" field.
    missing_model = self.openai_client.post("/v1/completions", json={"prompt": "test"}, headers=auth_headers)
    assert missing_model.status_code in error_codes, (
        f"Missing field handling exception: {missing_model.status_code}"
    )

    # Chat request whose "messages" is a string instead of a list.
    bad_messages = self.openai_client.post(
        "/v1/chat/completions",
        json={"model": "gpt-3.5-turbo", "messages": "invalid messages"},
        headers=auth_headers,
    )
    assert bad_messages.status_code in error_codes, (
        f"Invalid chat completion handling exception: {bad_messages.status_code}"
    )
def test_rate_limiting(self):
"""Test rate limiting functionality"""
for i in range(30):
test_data = {
"model": "gpt-3.5-turbo",
"messages": [{"role": "user", "content": f"This is the {i + 1}th rate limiting test request"}],
"max_tokens": 10,
}
response = self.openai_client.post(
"/v1/chat/completions",
json=test_data,
headers={"Content-Type": "application/json", "Authorization": f"Bearer {self.valid_api_key}"},
)
if response.status_code == 429:
break
assert True, "Rate limiting test completed"
def test_api_key_validation(self):
    """Exercise API-key auth: missing key gives 401, unknown key 403, valid keys 200, probes stay open."""
    accepted_keys = ["sk-test123456789", "sk-coordinator2024"]
    rejected_key = "sk-invalid-key"
    chat_payload = {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "test"}], "max_tokens": 10}

    def send(endpoint, payload, authorization=None):
        """POST ``payload`` as JSON, adding an Authorization header only when one is given."""
        headers = {"Content-Type": "application/json"}
        if authorization is not None:
            headers["Authorization"] = authorization
        return self.openai_client.post(endpoint, json=payload, headers=headers)

    # No Authorization header at all.
    reply = send("/v1/chat/completions", chat_payload)
    assert reply.status_code == 401, f"Expected 401, got: {reply.status_code}"
    assert "detail" in reply.json(), "Error response missing detail field"

    # Well-formed bearer token carrying an unknown key.
    reply = send("/v1/chat/completions", chat_payload, f"Bearer {rejected_key}")
    assert reply.status_code == 403, f"Expected 403, got: {reply.status_code}"
    assert "detail" in reply.json(), "Error response missing detail field"

    # Every configured key is accepted.
    for key in accepted_keys:
        reply = send("/v1/chat/completions", chat_payload, f"Bearer {key}")
        assert reply.status_code == 200, f"Valid API Key request failed: {reply.status_code}"
        assert "request_id" in reply.json(), "Response missing request_id"

    # Key sent without the "Bearer " prefix: only the set of tolerated outcomes is pinned.
    reply = send("/v1/chat/completions", chat_payload, accepted_keys[0])
    assert reply.status_code in [200, 401, 403], f"Unexpected status code: {reply.status_code}"

    # Management probes are reachable without any key.
    for path in ["/startup", "/readiness"]:
        reply = self.mgmt_client.get(f"{path}")
        assert reply.status_code == 200, f"Skip path {path} returned non-200 status code: {reply.status_code}"

    # The completions endpoint enforces the same rules.
    completion_payload = {"model": "text-davinci-003", "prompt": "test", "max_tokens": 10}
    reply = send("/v1/completions", completion_payload)
    assert reply.status_code == 401, f"Expected 401, got: {reply.status_code}"
    reply = send("/v1/completions", completion_payload, f"Bearer {accepted_keys[0]}")
    assert reply.status_code == 200, f"Completions API with valid API Key request failed: {reply.status_code}"
class TestFastAPIMiddleware:
"""Test FastAPI middleware functionality"""
def setup_method(self):
    """Import the middleware toolkit lazily, keep handles on ``self`` and silence alarm reporting."""
    from fastapi import FastAPI
    from fastapi.testclient import TestClient
    from motor.coordinator.middleware.fastapi_middleware import (
        SimpleRateLimitMiddleware,
        SimpleRateLimitConfig,
        load_rate_limit_config,
        create_simple_rate_limit_middleware,
    )
    from motor.coordinator.middleware.rate_limiter import SimpleRateLimiter

    # A fresh app per test so routes registered by one test never reach another.
    self.app = FastAPI()
    self.TestClient = TestClient

    # Classes and factories under test, reachable from every test method.
    self.SimpleRateLimitMiddleware = SimpleRateLimitMiddleware
    self.SimpleRateLimitConfig = SimpleRateLimitConfig
    self.SimpleRateLimiter = SimpleRateLimiter
    self.load_rate_limit_config = load_rate_limit_config
    self.create_simple_rate_limit_middleware = create_simple_rate_limit_middleware

    # Replace ControllerApiClient.report_alarms with a stub that returns True.
    self._report_alarms_patcher = patch(
        "motor.coordinator.api_client.controller_api_client.ControllerApiClient.report_alarms",
        return_value=True,
    )
    self._report_alarms_patcher.start()
def teardown_method(self):
self._report_alarms_patcher.stop()
def test_simple_rate_limit_config(self):
"""Test SimpleRateLimitConfig dataclass"""
config = self.SimpleRateLimitConfig()
assert config.enabled is True, "Default enabled should be True"
assert config.max_requests == 100, "Default max_requests should be 100"
assert config.window_size == 60, "Default window_size should be 60"
assert config.scope == "per_ip", "Default scope should be per_ip"
assert config.skip_paths is not None, "skip_paths should be initialized"
assert "/liveness" in config.skip_paths, "/liveness should be in skip_paths"
def test_load_rate_limit_config_default(self):
"""Test load_rate_limit_config with default values"""
import os
original_enabled = os.getenv("RATE_LIMIT_ENABLED")
original_max = os.getenv("RATE_LIMIT_MAX_REQUESTS")
original_window = os.getenv("RATE_LIMIT_WINDOW_SIZE")
try:
if "RATE_LIMIT_ENABLED" in os.environ:
del os.environ["RATE_LIMIT_ENABLED"]
if "RATE_LIMIT_MAX_REQUESTS" in os.environ:
del os.environ["RATE_LIMIT_MAX_REQUESTS"]
if "RATE_LIMIT_WINDOW_SIZE" in os.environ:
del os.environ["RATE_LIMIT_WINDOW_SIZE"]
config = self.load_rate_limit_config()
assert config.enabled, "Should use default enabled=True"
assert config.max_requests == 100, "Should use default max_requests=100"
assert config.window_size == 60, "Should use default window_size=60"
finally:
if original_enabled:
os.environ["RATE_LIMIT_ENABLED"] = original_enabled
if original_max:
os.environ["RATE_LIMIT_MAX_REQUESTS"] = original_max
if original_window:
os.environ["RATE_LIMIT_WINDOW_SIZE"] = original_window
def test_load_rate_limit_config_from_env(self):
"""Test load_rate_limit_config from environment variables"""
import os
original_enabled = os.getenv("RATE_LIMIT_ENABLED")
original_max = os.getenv("RATE_LIMIT_MAX_REQUESTS")
original_window = os.getenv("RATE_LIMIT_WINDOW_SIZE")
original_scope = os.getenv("RATE_LIMIT_SCOPE")
original_skip_paths = os.getenv("RATE_LIMIT_SKIP_PATHS")
try:
os.environ["RATE_LIMIT_ENABLED"] = "false"
os.environ["RATE_LIMIT_MAX_REQUESTS"] = "200"
os.environ["RATE_LIMIT_WINDOW_SIZE"] = "30"
os.environ["RATE_LIMIT_SCOPE"] = "global"
os.environ["RATE_LIMIT_SKIP_PATHS"] = "/liveness,/health"
config = self.load_rate_limit_config()
assert not config.enabled, "Should load enabled from env"
assert config.max_requests == 200, "Should load max_requests from env"
assert config.window_size == 30, "Should load window_size from env"
assert config.scope == "global", "Should load scope from env"
assert "/liveness" in config.skip_paths, "Should load skip_paths from env"
assert "/health" in config.skip_paths, "Should load skip_paths from env"
finally:
if original_enabled:
os.environ["RATE_LIMIT_ENABLED"] = original_enabled
elif "RATE_LIMIT_ENABLED" in os.environ:
del os.environ["RATE_LIMIT_ENABLED"]
if original_max:
os.environ["RATE_LIMIT_MAX_REQUESTS"] = original_max
elif "RATE_LIMIT_MAX_REQUESTS" in os.environ:
del os.environ["RATE_LIMIT_MAX_REQUESTS"]
if original_window:
os.environ["RATE_LIMIT_WINDOW_SIZE"] = original_window
elif "RATE_LIMIT_WINDOW_SIZE" in os.environ:
del os.environ["RATE_LIMIT_WINDOW_SIZE"]
if original_scope:
os.environ["RATE_LIMIT_SCOPE"] = original_scope
elif "RATE_LIMIT_SCOPE" in os.environ:
del os.environ["RATE_LIMIT_SCOPE"]
if original_skip_paths:
os.environ["RATE_LIMIT_SKIP_PATHS"] = original_skip_paths
elif "RATE_LIMIT_SKIP_PATHS" in os.environ:
del os.environ["RATE_LIMIT_SKIP_PATHS"]
def test_load_rate_limit_config_from_file(self):
"""Test load_rate_limit_config from file"""
import os
import json
import tempfile
config_data = {
"enabled": False,
"max_requests": 300,
"window_size": 45,
"scope": "per_ip",
"error_message": "Custom error message",
"error_status_code": 429,
}
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
json.dump(config_data, f)
config_file = f.name
try:
config = self.load_rate_limit_config(config_file=config_file)
assert not config.enabled, "Should load enabled from file"
assert config.max_requests == 300, "Should load max_requests from file"
assert config.window_size == 45, "Should load window_size from file"
assert config.error_message == "Custom error message", "Should load error_message from file"
finally:
os.unlink(config_file)
def test_rate_limit_middleware_skip_paths(self):
    """Paths listed in ``skip_paths`` bypass the limiter; other paths stay subject to it."""

    @self.app.get("/test")
    async def test_endpoint():
        return {"status": "ok"}

    @self.app.get("/liveness")
    async def liveness_endpoint():
        return {"status": "healthy"}

    # One request per 60 s window, so repeated hits would trip an unskipped path.
    strict_limiter = self.SimpleRateLimiter(max_requests=1, window_size=60)
    wrapped_app = self.SimpleRateLimitMiddleware(
        app=self.app,
        rate_limiter=strict_limiter,
        skip_paths=["/liveness"],
    )
    client = self.TestClient(wrapped_app)

    for _ in range(5):
        probe = client.get("/liveness")
        assert probe.status_code == 200, "Liveness endpoint should not be rate limited"

    first = client.get("/test")
    assert first.status_code == 200, "First request should succeed"
    second = client.get("/test")
    assert second.status_code in [200, 429], "Second request may be rate limited"
def test_rate_limit_middleware_error_handling(self):
    """When the limiter itself raises, the request is still served and counted as allowed."""

    @self.app.get("/test")
    async def test_endpoint():
        return {"status": "ok"}

    # Limiter whose is_allowed raises on every call.
    broken_limiter = MagicMock()
    broken_limiter.is_allowed = MagicMock(side_effect=Exception("Test error"))

    wrapped_app = self.SimpleRateLimitMiddleware(app=self.app, rate_limiter=broken_limiter, skip_paths=[])
    reply = self.TestClient(wrapped_app).get("/test")

    assert reply.status_code == 200, "Should allow request when error occurs"
    assert wrapped_app.stats["allowed_requests"] > 0, "Should increment allowed_requests on error"
def test_create_simple_rate_limit_middleware(self):
"""Test create_simple_rate_limit_middleware function"""
middleware = self.create_simple_rate_limit_middleware(app=self.app, max_requests=50, window_size=30)
assert middleware is not None, "Middleware should be created"
assert middleware.rate_limiter.max_requests == 50, "Should set max_requests"
assert middleware.rate_limiter.window_size == 30, "Should set window_size"
assert middleware.skip_paths is not None, "Should set skip_paths"
def test_rate_limit_middleware_stats(self):
    """The middleware's ``stats`` dict counts total and allowed requests and records a start time."""

    @self.app.get("/test")
    async def test_endpoint():
        return {"status": "ok"}

    roomy_limiter = self.SimpleRateLimiter(max_requests=10, window_size=60)
    wrapped_app = self.SimpleRateLimitMiddleware(app=self.app, rate_limiter=roomy_limiter, skip_paths=[])
    client = self.TestClient(wrapped_app)

    request_count = 5
    for _ in range(request_count):
        client.get("/test")

    counters = wrapped_app.stats
    assert counters["total_requests"] >= 5, "Should track total requests"
    assert counters["allowed_requests"] >= 5, "Should track allowed requests"
    assert "start_time" in counters, "Should track start time"
class TestCoordinatorServerAdvanced:
"""Advanced functionality test class for Coordinator server"""
def setup_method(self):
    """Build a mocked coordinator whose scheduler connection is stubbed, plus a management client.

    Patches ``InstanceManager``, ``handle_request`` and
    ``InferenceServer._is_available``; the patchers are kept on ``self`` so
    ``teardown_method`` can stop them.
    """
    # Management side: the instance manager always reports that requirements are met.
    self._im_patcher = patch("motor.coordinator.api_server.management_server.InstanceManager")
    instance_manager_cls = self._im_patcher.start()
    fake_manager = MagicMock()
    fake_manager.has_required_instances.return_value = True
    fake_manager.get_required_instances_status.return_value = InstanceReadiness.REQUIRED_MET
    fake_manager.refresh_instances = AsyncMock(return_value=None)
    instance_manager_cls.return_value = fake_manager

    async def mock_handle_request(request, config, scheduler=None, request_manager=None):
        """Echo the request back as a JSON envelope instead of running inference."""
        import hashlib

        try:
            payload = await request.json()
        except Exception:
            # Fall back to decoding the raw body; an unparsable body becomes {}.
            try:
                raw_body = await request.body()
                payload = json.loads(raw_body.decode("utf-8"))
            except Exception:
                payload = {}

        if "prompt" in payload:
            echoed_input = str(payload["prompt"])
        elif "messages" in payload:
            echoed_input = json.dumps(payload["messages"], ensure_ascii=False)
        else:
            echoed_input = ""

        stream_flag = payload.get("stream", False)
        if isinstance(stream_flag, str):
            stream_flag = stream_flag.lower() in ("true", "1", "yes")

        # NOTE(review): "/chat/completions" also ends with "/completions", so the
        # second branch never matches and chat requests are labelled "completions".
        # Order kept as-is to preserve current behaviour - confirm before changing.
        if request.url.path.endswith("/completions"):
            kind = "completions"
        elif request.url.path.endswith("/chat/completions"):
            kind = "chat_completions"
        else:
            kind = "openai"

        digest = hashlib.md5(str(payload).encode()).hexdigest()[:8]
        envelope = {
            "request_id": f"req-{digest}",
            "status": "success",
            "data": {"input_data": echoed_input, "is_stream": bool(stream_flag), "request_type": kind},
        }
        return JSONResponse(content=envelope)

    self._handle_request_patcher = patch(
        "motor.coordinator.api_server.inference_server.handle_request", side_effect=mock_handle_request
    )
    self._handle_request_patcher.start()
    self._is_available_patcher = patch(
        "motor.coordinator.api_server.inference_server.InferenceServer._is_available",
        new_callable=AsyncMock,
        return_value=True,
    )
    self._is_available_patcher.start()

    config = CoordinatorConfig()
    config.api_key_config.enable_api_key = True
    config.api_key_config.valid_keys = {"sk-test123456789", "sk-coordinator2024"}
    self.coordinator_server = _TestServerShell(config=config)

    # Stub the scheduler connection: connect/disconnect are awaitable no-ops, no client is returned.
    self.coordinator_server._mgmt._scheduler_connection = MagicMock()
    self.coordinator_server._mgmt._scheduler_connection.ensure_connected = AsyncMock()
    self.coordinator_server._mgmt._scheduler_connection.get_client.return_value = None
    self.coordinator_server._mgmt._scheduler_connection.disconnect = AsyncMock()
    self.coordinator_server.setup_rate_limiting()

    # Inference side: always available, scheduler reports no instances.
    inference = self.coordinator_server._inf
    inference._is_available = AsyncMock(return_value=True)
    scheduler_stub = MagicMock()
    scheduler_stub.get_available_instances = AsyncMock(return_value={})
    inference._get_scheduler_client = lambda: scheduler_stub
    if not getattr(inference.app.state, "request_manager", None):
        inference.app.state.request_manager = inference._request_manager

    self.mgmt_client = TestClient(self.coordinator_server.management_app)
    self.valid_api_key = "sk-test123456789"
def teardown_method(self):
"""Teardown test fixtures"""
try:
if hasattr(self, "_im_patcher"):
self._im_patcher.stop()
if hasattr(self, "_handle_request_patcher"):
self._handle_request_patcher.stop()
if hasattr(self, "_is_available_patcher"):
self._is_available_patcher.stop()
except Exception:
pass
def test_refresh_instances_valid_request(self):
    """Post a well-formed "add" event and expect HTTP 200 with a success envelope."""
    endpoint = {"id": 0, "ip": "192.168.1.1", "business_port": "8080", "mgmt_port": "18080"}
    instance = {
        "job_name": "test-job",
        "model_name": "test-model",
        "id": 1,
        "role": "prefill",
        "endpoints": {"192.168.1.1": {"0": endpoint}},
    }
    payload = {"event": "add", "instances": [instance]}

    resp = self.mgmt_client.post("/instances/refresh", json=payload)

    assert resp.status_code == 200, f"Refresh instances failed: {resp.status_code}"
    envelope = resp.json()
    assert envelope["status"] == "success", f"Refresh instances status abnormal: {envelope}"
    # The envelope must carry both the request id and the data section.
    expected_fields = (
        ("request_id", "Response missing request_id"),
        ("data", "Response missing data field"),
    )
    for field, message in expected_fields:
        assert field in envelope, message
def test_refresh_instances_empty_body(self):
    """Posting an empty JSON object to /instances/refresh is expected to yield HTTP 400."""
    status = self.mgmt_client.post("/instances/refresh", json={}).status_code
    assert status == 400, f"Expected 400 for empty body, got: {status}"
def test_refresh_instances_invalid_json(self):
    """A body that is not parseable JSON is expected to produce 400, 422 or 500."""
    resp = self.mgmt_client.post(
        "/instances/refresh",
        content="invalid json",
        headers={"Content-Type": "application/json"},
    )
    status = resp.status_code
    assert status in (400, 422, 500), f"Expected 400/422/500 for invalid JSON, got: {status}"
def test_refresh_instances_invalid_event_msg(self):
    """An unknown event name with a non-list ``instances`` value is expected to yield HTTP 400."""
    malformed = {"event": "INVALID_EVENT", "instances": "not a list"}
    status = self.mgmt_client.post("/instances/refresh", json=malformed).status_code
    assert status == 400, f"Expected 400 for invalid format, got: {status}"
def test_refresh_instances_no_body(self):
    """A POST to /instances/refresh sent with ``content=None`` is expected to yield HTTP 400."""
    outcome = self.mgmt_client.post("/instances/refresh", content=None)
    code = outcome.status_code
    assert code == 400, f"Expected 400 for no body, got: {code}"
def test_create_unified_app(self):
    """The unified app must answer both the liveness probe and an authenticated chat completion."""
    app = self.coordinator_server.create_unified_app()
    assert app is not None, "Unified app should be created"

    client = TestClient(app)

    liveness = client.get("/liveness")
    assert liveness.status_code == 200, "Liveness endpoint should be available in unified app"

    chat_body = {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "test"}]}
    auth_headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.valid_api_key}"}
    chat = client.post("/v1/chat/completions", json=chat_body, headers=auth_headers)
    assert chat.status_code == 200, "Chat completions endpoint should be available in unified app"
def test_create_unified_app_with_rate_limit_disabled(self):
    """A server whose config has rate limiting switched off must still build a unified app."""
    config = CoordinatorConfig()
    config.rate_limit_config.enable_rate_limit = False

    server = _TestServerShell(config=config)
    server.instance_manager = MagicMock()

    app = server.create_unified_app()
    assert app is not None, "Unified app should be created even with rate limit disabled"
def test_create_unified_app_with_custom_rate_limit_config(self):
    """``create_unified_app`` must accept an explicit ``rate_limit_config`` and return an app."""
    from motor.config.coordinator import RateLimitConfig

    limits = RateLimitConfig()
    # Enabled limiter: 50 requests per 30-unit window (units are defined by RateLimitConfig).
    limits.enable_rate_limit = True
    limits.max_requests = 50
    limits.window_size = 30

    app = self.coordinator_server.create_unified_app(rate_limit_config=limits)
    assert app is not None, "Unified app should be created with custom rate limit config"
def test_copy_routes_skip_paths(self):
    """Routes listed in ``skip_paths`` must not be copied, while all other routes are."""
    from fastapi import FastAPI

    source = FastAPI()

    @source.get("/test")
    async def test():
        return {"status": "ok"}

    @source.get("/docs")
    async def docs():
        return {"status": "docs"}

    # The destination has its built-in docs/openapi routes disabled so that a
    # 404 on /docs can only mean the route was not copied.
    destination = FastAPI(docs_url=None, redoc_url=None, openapi_url=None)
    self.coordinator_server._copy_routes(source, destination, skip_paths=["/docs"])

    client = TestClient(destination)
    copied = client.get("/test")
    assert copied.status_code == 200, "/test route should be copied"
    skipped = client.get("/docs")
    assert skipped.status_code == 404, "/docs route should be skipped"
def test_validate_openai_request_invalid_model(self):
    """A chat completion request that omits ``model`` is expected to be answered with HTTP 400."""
    auth_headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.valid_api_key}"}
    client = TestClient(self.coordinator_server.inference_app)
    resp = client.post(
        "/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "test"}]},
        headers=auth_headers,
    )
    assert resp.status_code == 400, f"Expected 400 for missing model, got: {resp.status_code}"
def test_validate_openai_request_invalid_messages(self):
    """A ``messages`` value that is a string instead of a list is expected to yield HTTP 400."""
    auth_headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.valid_api_key}"}
    client = TestClient(self.coordinator_server.inference_app)
    resp = client.post(
        "/v1/chat/completions",
        json={"model": "gpt-3.5-turbo", "messages": "not a list"},
        headers=auth_headers,
    )
    assert resp.status_code == 400, f"Expected 400 for invalid messages, got: {resp.status_code}"
def test_validate_openai_request_empty_messages(self):
    """An empty ``messages`` list is expected to be rejected with HTTP 400."""
    auth_headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.valid_api_key}"}
    body = {"model": "gpt-3.5-turbo", "messages": []}
    client = TestClient(self.coordinator_server.inference_app)
    status = client.post("/v1/chat/completions", json=body, headers=auth_headers).status_code
    assert status == 400, f"Expected 400 for empty messages, got: {status}"
def test_validate_openai_request_invalid_message_format(self):
    """A ``messages`` list whose entry is a plain string (not an object) is expected to yield HTTP 400."""
    auth_headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.valid_api_key}"}
    client = TestClient(self.coordinator_server.inference_app)
    resp = client.post(
        "/v1/chat/completions",
        json={"model": "gpt-3.5-turbo", "messages": ["not a dict"]},
        headers=auth_headers,
    )
    assert resp.status_code == 400, f"Expected 400 for invalid message format, got: {resp.status_code}"
def test_validate_openai_request_missing_role_or_content(self):
    """A message that has a ``role`` but no ``content`` is expected to be rejected with HTTP 400."""
    auth_headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.valid_api_key}"}
    client = TestClient(self.coordinator_server.inference_app)
    resp = client.post(
        "/v1/chat/completions",
        json={"model": "gpt-3.5-turbo", "messages": [{"role": "user"}]},
        headers=auth_headers,
    )
    assert resp.status_code == 400, f"Expected 400 for missing content, got: {resp.status_code}"
def test_validate_openai_request_invalid_role(self):
    """A message whose ``role`` is "invalid_role" is expected to be rejected with HTTP 400."""
    auth_headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.valid_api_key}"}
    body = {"model": "gpt-3.5-turbo", "messages": [{"role": "invalid_role", "content": "test"}]}
    client = TestClient(self.coordinator_server.inference_app)
    status = client.post("/v1/chat/completions", json=body, headers=auth_headers).status_code
    assert status == 400, f"Expected 400 for invalid role, got: {status}"
def test_handle_openai_request_unavailable_instances(self):
    """With ``_is_available`` stubbed to return False, a chat completion is expected to get HTTP 503."""
    inference = self.coordinator_server._inf
    inference._is_available = AsyncMock(return_value=False)

    auth_headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.valid_api_key}"}
    body = {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "test"}]}
    client = TestClient(self.coordinator_server.inference_app)
    resp = client.post("/v1/chat/completions", json=body, headers=auth_headers)
    assert resp.status_code == 503, f"Expected 503 for unavailable instances, got: {resp.status_code}"
def test_handle_openai_request_with_prompt(self):
    """A /v1/completions request must succeed and echo its ``prompt`` back as ``input_data``."""
    auth_headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.valid_api_key}"}
    client = TestClient(self.coordinator_server.inference_app)
    resp = client.post(
        "/v1/completions",
        json={"model": "text-davinci-003", "prompt": "Hello world"},
        headers=auth_headers,
    )
    assert resp.status_code == 200, f"Completions API failed: {resp.status_code}"
    echoed = resp.json()["data"]["input_data"]
    assert echoed == "Hello world", "Prompt should be extracted correctly"
def test_handle_openai_request_empty_input(self):
    """A /v1/completions request carrying only ``model`` is expected to be rejected with HTTP 400."""
    auth_headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.valid_api_key}"}
    body = {"model": "text-davinci-003"}
    client = TestClient(self.coordinator_server.inference_app)
    status = client.post("/v1/completions", json=body, headers=auth_headers).status_code
    assert status == 400, f"Expected 400 for missing prompt/messages, got: {status}"
def test_openai_is_stream(self):
    """``_openai_is_stream`` must be truthy for True / "true" and falsy for False or a missing key."""
    is_stream = self.coordinator_server._openai_is_stream
    assert is_stream({"stream": True})
    assert not is_stream({"stream": False})
    assert not is_stream({})
    assert is_stream({"stream": "true"})
def test_refresh_instances_with_complex_endpoints(self):
    """An instance with several endpoints spread over two hosts must be accepted (HTTP 200, success)."""
    first_host = {
        "0": {"id": 0, "ip": "192.168.1.3", "business_port": "8080", "mgmt_port": "18080"},
        "1": {"id": 1, "ip": "192.168.1.3", "business_port": "8081", "mgmt_port": "18081"},
    }
    second_host = {
        "2": {"id": 2, "ip": "192.168.1.4", "business_port": "9000", "mgmt_port": "19000"},
    }
    instance = {
        "job_name": "test-job",
        "model_name": "test-model",
        "id": 3,
        "role": "prefill",
        "endpoints": {"192.168.1.3": first_host, "192.168.1.4": second_host},
    }
    payload = {"event": "add", "instances": [instance]}

    resp = self.mgmt_client.post("/instances/refresh", json=payload)

    assert resp.status_code == 200, f"Refresh instances failed: {resp.status_code}"
    envelope = resp.json()
    assert envelope["status"] == "success", f"Refresh instances status abnormal: {envelope}"
def test_refresh_instances_with_non_dict_endpoints(self):
    """An instance whose ``endpoints`` value is a string is expected to be rejected with HTTP 400."""
    instance = {
        "job_name": "test-job",
        "model_name": "test-model",
        "id": 4,
        "role": "prefill",
        "endpoints": "not a dict",
    }
    payload = {"event": "add", "instances": [instance]}
    status = self.mgmt_client.post("/instances/refresh", json=payload).status_code
    assert status == 400, f"Expected 400 for invalid endpoints format, got: {status}"
def test_refresh_instances_with_non_dict_endpoint_data(self):
    """A per-host endpoint entry that is a string instead of a mapping is expected to yield HTTP 400."""
    instance = {
        "job_name": "test-job",
        "model_name": "test-model",
        "id": 5,
        "role": "prefill",
        "endpoints": {"192.168.1.5": "not a dict"},
    }
    payload = {"event": "add", "instances": [instance]}
    status = self.mgmt_client.post("/instances/refresh", json=payload).status_code
    assert status == 400, f"Expected 400 for invalid endpoint data format, got: {status}"
def test_timeout_handler(self):
    """An ordinary chat completion must still return HTTP 200 (the timeout handling must not interfere)."""
    auth_headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.valid_api_key}"}
    body = {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "test"}]}
    client = TestClient(self.coordinator_server.inference_app)
    resp = client.post("/v1/chat/completions", json=body, headers=auth_headers)
    assert resp.status_code == 200, "Timeout handler should not block normal requests"
def test_verify_api_key_skip_paths(self):
    """Build a test client for the inference app.

    NOTE(review): this test sends no request and asserts nothing, so it only
    proves that the client can be constructed. The API-key skip paths are not
    exercised here; a real check still needs to be added.
    """
    inference_app = self.coordinator_server.inference_app
    TestClient(inference_app)
def test_lifespan_context_manager(self):
    """An app built with the coordinator's lifespan handler must still serve requests.

    NOTE(review): ``TestClient`` runs lifespan startup/shutdown only when it
    is used as a context manager (``with TestClient(app) as client``). This
    test does not do that, so the lifespan body itself is presumably never
    entered here; confirm whether entering it is safe before switching.
    """
    app = FastAPI(lifespan=self.coordinator_server.lifespan)
    client = TestClient(app)
    response = client.get("/")
    # The app registers no route for "/", so a functioning app answers 404.
    # This replaces the former ``assert True``, which could never fail.
    assert response.status_code == 404, (
        f"Expected 404 from an app without routes, got: {response.status_code}"
    )
def test_setup_rate_limiting_with_disabled_config(self):
    """``setup_rate_limiting`` must not raise when handed a config with rate limiting disabled.

    NOTE(review): the closing assertion is a no-op; the test passes as long
    as no exception propagates.
    """
    from motor.config.coordinator import RateLimitConfig

    limits = RateLimitConfig()
    limits.enable_rate_limit = False

    server = _TestServerShell(config=CoordinatorConfig())
    server.instance_manager = MagicMock()
    server.setup_rate_limiting(rate_limit_config=limits)

    assert True, "Setup rate limiting with disabled config works correctly"
def test_setup_rate_limiting_with_exception(self):
    """``setup_rate_limiting`` must not propagate an error raised by the middleware factory.

    NOTE(review): the closing assertion is a no-op; the test passes as long
    as no exception propagates.
    """
    factory_path = "motor.coordinator.middleware.fastapi_middleware.create_simple_rate_limit_middleware"
    with patch(factory_path) as failing_factory:
        failing_factory.side_effect = Exception("Test exception")

        server = _TestServerShell(config=CoordinatorConfig())
        server.instance_manager = MagicMock()
        server.setup_rate_limiting()

        assert True, "Setup rate limiting handles exceptions correctly"
def test_create_unified_app_with_exception(self):
    """``create_unified_app`` must still return an app while the middleware factory is patched to raise."""
    factory_path = "motor.coordinator.middleware.fastapi_middleware.create_simple_rate_limit_middleware"
    with patch(factory_path) as failing_factory:
        failing_factory.side_effect = Exception("Test exception")
        app = self.coordinator_server.create_unified_app()
        assert app is not None, "Unified app should be created even with exceptions"
def test_copy_routes_with_exception(self):
    """``_copy_routes`` must tolerate a route without an endpoint and must not copy it.

    The source app's route list is temporarily replaced by a single mock
    route that has a real path and method set but ``endpoint = None``.
    """
    src_app = FastAPI()

    @src_app.get("/test")
    async def test():
        return {"status": "ok"}

    dst_app = FastAPI()

    bad_route = MagicMock()
    bad_route.path = "/bad"
    bad_route.methods = {"GET"}
    bad_route.endpoint = None

    with patch.object(src_app.router, "routes", new=[bad_route]):
        self.coordinator_server._copy_routes(src_app, dst_app)

    # Replaces the former ``assert True``: verify the endpoint-less route was
    # really left out instead of only checking that nothing raised.
    copied_paths = {getattr(route, "path", None) for route in dst_app.routes}
    assert "/bad" not in copied_paths, "Route without an endpoint must not be copied"
def test_handle_openai_request_json_decode_error(self):
    """A chat completion whose body is not valid JSON is expected to produce 400, 422 or 500."""
    auth_headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.valid_api_key}"}
    client = TestClient(self.coordinator_server.inference_app)
    resp = client.post("/v1/chat/completions", content="invalid json", headers=auth_headers)
    status = resp.status_code
    assert status in (400, 422, 500), f"Expected error for invalid JSON, got: {status}"
def test_handle_openai_request_general_exception(self):
    """With ``handle_request`` patched to raise, a chat completion is expected to get HTTP 500."""
    handler_path = "motor.coordinator.api_server.inference_server.handle_request"
    with patch(handler_path, new_callable=AsyncMock) as failing_handler:
        failing_handler.side_effect = Exception("Test exception")

        auth_headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.valid_api_key}"}
        body = {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "test"}]}
        client = TestClient(self.coordinator_server.inference_app)
        resp = client.post("/v1/chat/completions", json=body, headers=auth_headers)

        assert resp.status_code == 500, f"Expected 500 for exception, got: {resp.status_code}"
class TestFastAPIMiddlewareAdvanced:
    """Tests for the simple rate-limit middleware, its config loader and its factory function."""

    def setup_method(self):
        """Create a fresh FastAPI app, expose the middleware helpers on ``self`` and stub alarm reporting."""
        from motor.coordinator.middleware.fastapi_middleware import (
            SimpleRateLimitMiddleware,
            SimpleRateLimitConfig,
            load_rate_limit_config,
            create_simple_rate_limit_middleware,
        )
        from motor.coordinator.middleware.rate_limiter import SimpleRateLimiter
        from fastapi import FastAPI
        from fastapi.testclient import TestClient

        self.app = FastAPI()
        self.SimpleRateLimitMiddleware = SimpleRateLimitMiddleware
        self.SimpleRateLimitConfig = SimpleRateLimitConfig
        self.load_rate_limit_config = load_rate_limit_config
        self.create_simple_rate_limit_middleware = create_simple_rate_limit_middleware
        self.SimpleRateLimiter = SimpleRateLimiter
        self.TestClient = TestClient

        # ``report_alarms`` on the controller API client is replaced for the
        # duration of each test and always returns True.
        alarm_target = "motor.coordinator.api_client.controller_api_client.ControllerApiClient.report_alarms"
        self._report_alarms_patcher = patch(alarm_target, return_value=True)
        self._report_alarms_patcher.start()

    def teardown_method(self):
        """Undo the ``report_alarms`` patch installed by ``setup_method``."""
        self._report_alarms_patcher.stop()

    def _register_ok_route(self):
        """Register ``GET /test`` on the app; it answers ``{"status": "ok"}``."""

        @self.app.get("/test")
        async def test_endpoint():
            return {"status": "ok"}

    def _wrap_app(self, rate_limiter, skip_paths):
        """Wrap the app in the rate-limit middleware and return ``(middleware, client)``."""
        middleware = self.SimpleRateLimitMiddleware(app=self.app, rate_limiter=rate_limiter, skip_paths=skip_paths)
        return middleware, self.TestClient(middleware)

    def test_rate_limit_middleware_extract_request_data(self):
        """A request passing through the middleware must be counted in ``stats["total_requests"]``."""
        self._register_ok_route()
        limiter = self.SimpleRateLimiter(max_requests=10, window_size=60)
        middleware, client = self._wrap_app(limiter, [])

        resp = client.get("/test")

        assert resp.status_code == 200, "Request should succeed"
        assert middleware.stats["total_requests"] > 0, "Should extract request data"

    def test_rate_limit_middleware_create_rate_limit_headers(self):
        """An allowed response must carry at least one of the ``X-RateLimit-*`` headers."""
        self._register_ok_route()
        limiter = self.SimpleRateLimiter(max_requests=10, window_size=60)
        _, client = self._wrap_app(limiter, [])

        resp = client.get("/test")

        assert resp.status_code == 200, "Request should succeed"
        assert "X-RateLimit-Remaining" in resp.headers or "X-RateLimit-Limit" in resp.headers, (
            "Should create rate limit headers"
        )

    def test_rate_limit_middleware_dispatch_exception(self):
        """If the limiter itself raises, the request must still be let through and counted as allowed."""
        self._register_ok_route()
        broken_limiter = MagicMock()
        broken_limiter.is_allowed = MagicMock(side_effect=Exception("Test error"))
        middleware, client = self._wrap_app(broken_limiter, [])

        resp = client.get("/test")

        assert resp.status_code == 200, "Should allow request when error occurs"
        assert middleware.stats["allowed_requests"] > 0, "Should increment allowed_requests on error"

    def test_rate_limit_middleware_should_skip_path(self):
        """Paths listed in ``skip_paths`` must never be limited, even with a budget of one request."""
        self._register_ok_route()

        @self.app.get("/liveness")
        async def liveness_endpoint():
            return {"status": "healthy"}

        limiter = self.SimpleRateLimiter(max_requests=1, window_size=60)
        _, client = self._wrap_app(limiter, ["/liveness"])

        first_probe = client.get("/liveness")
        assert first_probe.status_code == 200, "Liveness endpoint should not be rate limited"
        second_probe = client.get("/liveness")
        assert second_probe.status_code == 200, "Liveness endpoint should still not be rate limited"

        first_call = client.get("/test")
        assert first_call.status_code == 200, "First request should succeed"
        second_call = client.get("/test")
        # Either outcome is accepted for the second limited request.
        assert second_call.status_code in [200, 429], "Second request may be rate limited"

    def test_load_rate_limit_config_file_not_found(self):
        """Loading from a path that does not exist must fall back to an enabled default config."""
        loaded = self.load_rate_limit_config(config_file="/nonexistent/config.json")
        assert loaded is not None, "Should return default config when file not found"
        assert loaded.enabled, "Should use default enabled value"

    def test_load_rate_limit_config_invalid_json(self):
        """Loading from a file with unparseable content must still return a config object."""
        import tempfile
        import os

        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as handle:
            handle.write("invalid json content")
            broken_path = handle.name

        try:
            loaded = self.load_rate_limit_config(config_file=broken_path)
            assert loaded is not None, "Should return default config when JSON is invalid"
        finally:
            # The file was created with delete=False, so remove it explicitly.
            os.unlink(broken_path)

    def test_simple_rate_limit_config_post_init(self):
        """A default ``SimpleRateLimitConfig`` must pre-populate ``skip_paths`` with the probe paths."""
        defaults = self.SimpleRateLimitConfig()
        assert defaults.skip_paths is not None, "skip_paths should be initialized"
        assert "/liveness" in defaults.skip_paths, "/liveness should be in skip_paths"
        assert "/ready" in defaults.skip_paths, "/ready should be in skip_paths"
        assert "/metrics" in defaults.skip_paths, "/metrics should be in skip_paths"

    def test_create_simple_rate_limit_middleware_defaults(self):
        """The factory called with only ``app`` must produce a limiter of 100 requests per 60-unit window."""
        built = self.create_simple_rate_limit_middleware(app=self.app)
        assert built is not None, "Middleware should be created"
        limiter = built.rate_limiter
        assert limiter.max_requests == 100, "Should use default max_requests"
        assert limiter.window_size == 60, "Should use default window_size"
@pytest.mark.asyncio
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
async def test_run_combined_mode(monkeypatch):
    """``ManagementServer.run`` must complete when ``uvicorn.Server`` is replaced by a no-op stand-in.

    TLS is disabled for both the inference and the management side so that no
    certificates are needed.
    """
    from motor.coordinator.api_server.management_server import ManagementServer
    from motor.config.coordinator import CoordinatorConfig

    class DummyServer:
        """Stand-in for ``uvicorn.Server`` whose ``serve`` returns immediately."""

        def __init__(self, *args, **kwargs):
            self.should_exit = False

        async def serve(self):
            return

    def make_dummy_server(*args, **kwargs):
        # Constructor arguments are deliberately ignored.
        return DummyServer()

    monkeypatch.setattr("motor.coordinator.api_server.management_server.uvicorn.Server", make_dummy_server)

    config = CoordinatorConfig()
    config.infer_tls_config.enable_tls = False
    config.mgmt_tls_config.enable_tls = False

    server = ManagementServer(config=config)
    await server.run()
@pytest.mark.asyncio
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
async def test_run_split_mode(monkeypatch):
    """``ManagementServer.run`` must complete and create at most two stand-in servers.

    NOTE(review): the configuration is the same as in the combined-mode test;
    nothing visible here selects a "split" mode - confirm the intent.
    """
    from motor.coordinator.api_server.management_server import ManagementServer
    from motor.config.coordinator import CoordinatorConfig

    created = []

    class DummyServer:
        """Stand-in for ``uvicorn.Server`` that records every instance that gets built."""

        def __init__(self, *args, **kwargs):
            self.should_exit = False
            created.append(self)

        async def serve(self):
            return

    def make_dummy_server(*args, **kwargs):
        # Constructor arguments are deliberately ignored.
        return DummyServer()

    monkeypatch.setattr("motor.coordinator.api_server.management_server.uvicorn.Server", make_dummy_server)

    config = CoordinatorConfig()
    config.infer_tls_config.enable_tls = False
    config.mgmt_tls_config.enable_tls = False

    server = ManagementServer(config=config)
    await server.run()

    # Zero, one or two servers are all accepted.
    assert len(created) in (0, 1, 2)
class TestValidateAnthropicRequest:
    """Direct tests of ``_validate_anthropic_request`` (no HTTP layer involved)."""

    @staticmethod
    def _hello():
        """Return a fresh single-message conversation so tests never share a list."""
        return [{"role": "user", "content": "Hello"}]

    @staticmethod
    def _rejected(payload, **options):
        """Validate *payload*, require an HTTP 400 rejection and return the raised exception."""
        with pytest.raises(HTTPException) as caught:
            _validate_anthropic_request(payload, **options)
        error = caught.value
        assert error.status_code == 400
        return error

    def test_valid_messages_request(self):
        """A minimal request with model, messages and max_tokens passes."""
        payload = {"model": "claude-3", "messages": self._hello(), "max_tokens": 100}
        _validate_anthropic_request(payload, require_max_tokens=True)

    def test_valid_messages_request_with_complex_content(self):
        """Message content given as a list of text/image blocks passes."""
        blocks = [
            {"type": "text", "text": "Hello"},
            {"type": "image", "source": {"type": "base64", "media_type": "image/jpeg", "data": "..."}},
        ]
        payload = {
            "model": "claude-3",
            "messages": [{"role": "user", "content": blocks}],
            "max_tokens": 100,
        }
        _validate_anthropic_request(payload, require_max_tokens=True)

    def test_missing_model(self):
        """A request without ``model`` is rejected with 400 and the detail names the field."""
        error = self._rejected({"messages": self._hello(), "max_tokens": 100})
        assert "model" in error.detail.lower()

    def test_empty_model(self):
        """An empty-string ``model`` is rejected with 400."""
        self._rejected({"model": "", "messages": self._hello(), "max_tokens": 100})

    def test_missing_messages(self):
        """A request without ``messages`` is rejected with 400 and the detail names the field."""
        error = self._rejected({"model": "claude-3", "max_tokens": 100})
        assert "messages" in error.detail.lower()

    def test_empty_messages(self):
        """An empty ``messages`` list is rejected with 400."""
        self._rejected({"model": "claude-3", "messages": [], "max_tokens": 100})

    def test_messages_not_a_list(self):
        """A string in place of the ``messages`` list is rejected with 400."""
        self._rejected({"model": "claude-3", "messages": "not a list", "max_tokens": 100})

    def test_missing_max_tokens(self):
        """With ``require_max_tokens=True`` a request lacking ``max_tokens`` is rejected with 400."""
        error = self._rejected(
            {"model": "claude-3", "messages": self._hello()},
            require_max_tokens=True,
        )
        assert "max_tokens" in error.detail.lower()

    def test_max_tokens_zero(self):
        """``max_tokens`` of 0 is rejected with 400."""
        self._rejected({"model": "claude-3", "messages": self._hello(), "max_tokens": 0})

    def test_max_tokens_negative(self):
        """``max_tokens`` of -1 is rejected with 400."""
        self._rejected({"model": "claude-3", "messages": self._hello(), "max_tokens": -1})

    def test_max_tokens_not_integer(self):
        """A non-numeric string as ``max_tokens`` is rejected with 400."""
        self._rejected({"model": "claude-3", "messages": self._hello(), "max_tokens": "abc"})

    def test_count_tokens_no_max_tokens_required(self):
        """With ``require_max_tokens=False`` a request without ``max_tokens`` passes."""
        payload = {"model": "claude-3", "messages": self._hello()}
        _validate_anthropic_request(payload, require_max_tokens=False)

    def test_count_tokens_missing_model(self):
        """``model`` stays mandatory even when ``max_tokens`` is not required."""
        self._rejected({"messages": self._hello()}, require_max_tokens=False)

    def test_count_tokens_missing_messages(self):
        """``messages`` stays mandatory even when ``max_tokens`` is not required."""
        self._rejected({"model": "claude-3"}, require_max_tokens=False)

    def test_system_as_string_accepted(self):
        """A plain-string ``system`` prompt passes."""
        payload = {
            "model": "claude-3",
            "messages": self._hello(),
            "max_tokens": 100,
            "system": "You are a helpful assistant",
        }
        _validate_anthropic_request(payload, require_max_tokens=True)

    def test_tool_definitions_accepted(self):
        """A request carrying ``tools`` and ``tool_choice`` passes."""
        payload = {
            "model": "claude-3",
            "messages": [{"role": "user", "content": "What is the weather?"}],
            "max_tokens": 100,
            "tools": [{"name": "get_weather", "input_schema": {"type": "object"}}],
            "tool_choice": {"type": "auto"},
        }
        _validate_anthropic_request(payload, require_max_tokens=True)
class TestAnthropicEndpoints:
"""Integration tests for Anthropic API endpoints."""
def setup_method(self):
    """Build a coordinator test server with mocked collaborators and an authenticated inference client.

    The instance manager class, ``handle_request`` and
    ``InferenceServer._is_available`` are patched before the server shell is
    created; the patchers are kept on ``self`` so teardown can stop them.
    """
    # Instance manager: every instance it creates reports that requirements are met.
    self._im_patcher = patch("motor.coordinator.api_server.management_server.InstanceManager")
    manager_cls = self._im_patcher.start()
    manager = MagicMock()
    manager.has_required_instances.return_value = True
    manager.get_required_instances_status.return_value = InstanceReadiness.REQUIRED_MET
    manager.refresh_instances = AsyncMock(return_value=None)
    manager_cls.return_value = manager

    async def mock_handle_request(request, config, scheduler=None, request_manager=None):
        """Answer every request with a canned Anthropic-style message."""
        import hashlib

        try:
            body_json = await request.json()
        except Exception:
            body_json = {}

        if "messages" in body_json:
            # Result is discarded; the call only has to succeed.
            json.dumps(body_json["messages"], ensure_ascii=False)

        # The stream flag is normalised but not used in the canned response.
        is_stream = body_json.get("stream", False)
        if isinstance(is_stream, str):
            is_stream = is_stream.lower() in ("true", "1", "yes")

        digest = hashlib.md5(str(body_json).encode()).hexdigest()[:8]
        request_id = f"req-{digest}"
        return JSONResponse(
            content={
                "id": f"msg_{request_id}",
                "type": "message",
                "role": "assistant",
                "content": [{"type": "text", "text": "Hello! How can I help you?"}],
                "model": body_json.get("model", "unknown"),
                "stop_reason": "end_turn",
                "usage": {"input_tokens": 10, "output_tokens": 7},
            }
        )

    self._handle_request_patcher = patch(
        "motor.coordinator.api_server.inference_server.handle_request",
        side_effect=mock_handle_request,
    )
    self._handle_request_patcher.start()

    self._is_available_patcher = patch(
        "motor.coordinator.api_server.inference_server.InferenceServer._is_available",
        new_callable=AsyncMock,
        return_value=True,
    )
    self._is_available_patcher.start()

    # Server under test: API-key checking enabled with a single valid key.
    config = CoordinatorConfig()
    config.api_key_config.enable_api_key = True
    config.api_key_config.valid_keys = {"sk-test123456789"}
    self.coordinator_server = _TestServerShell(config=config)
    self.coordinator_server.setup_rate_limiting()

    inference = self.coordinator_server._inf
    inference._is_available = AsyncMock(return_value=True)

    scheduler_stub = MagicMock()
    scheduler_stub.get_available_instances = AsyncMock(return_value={})

    def provide_scheduler():
        return scheduler_stub

    inference._get_scheduler_client = provide_scheduler

    # Make sure the app state carries a request manager.
    app_state = inference.app.state
    if not getattr(app_state, "request_manager", None):
        app_state.request_manager = inference._request_manager

    self.inference_client = TestClient(self.coordinator_server.inference_app)
    self.valid_api_key = "sk-test123456789"
    self.auth_headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {self.valid_api_key}",
    }
def teardown_method(self):
    """Stop each patcher stored on the instance, ignoring missing ones and stop failures."""
    attribute_names = ("_im_patcher", "_handle_request_patcher", "_is_available_patcher")
    for active in (getattr(self, attribute, None) for attribute in attribute_names):
        if active is None:
            continue
        try:
            active.stop()
        except Exception:
            # Cleanup is best-effort; keep going with the remaining patchers.
            pass
def test_anthropic_messages_basic(self):
    """A minimal /v1/messages request must return 200 with an assistant message carrying content and usage."""
    payload = {
        "model": "claude-3",
        "messages": [{"role": "user", "content": "Hello"}],
        "max_tokens": 100,
    }
    resp = self.inference_client.post("/v1/messages", json=payload, headers=self.auth_headers)

    assert resp.status_code == 200, f"Expected 200, got: {resp.status_code}"
    message = resp.json()
    assert message["type"] == "message"
    assert message["role"] == "assistant"
    assert len(message["content"]) > 0
    assert "usage" in message
def test_anthropic_messages_with_system_prompt(self):
    """A /v1/messages request that includes a string ``system`` prompt must return 200."""
    payload = {
        "model": "claude-3",
        "messages": [{"role": "user", "content": "Hello"}],
        "max_tokens": 100,
        "system": "You are a helpful assistant",
    }
    resp = self.inference_client.post("/v1/messages", json=payload, headers=self.auth_headers)
    assert resp.status_code == 200
def test_anthropic_messages_with_tools(self):
    """A /v1/messages request that declares ``tools`` and ``tool_choice`` must return 200."""
    payload = {
        "model": "claude-3",
        "messages": [{"role": "user", "content": "What is the weather?"}],
        "max_tokens": 100,
        "tools": [{"name": "get_weather", "input_schema": {"type": "object"}}],
        "tool_choice": {"type": "auto"},
    }
    resp = self.inference_client.post("/v1/messages", json=payload, headers=self.auth_headers)
    assert resp.status_code == 200
def test_anthropic_messages_streaming(self):
    """A /v1/messages request with ``stream`` set to True must return 200."""
    payload = {
        "model": "claude-3",
        "messages": [{"role": "user", "content": "Hello"}],
        "max_tokens": 100,
        "stream": True,
    }
    resp = self.inference_client.post("/v1/messages", json=payload, headers=self.auth_headers)
    assert resp.status_code == 200
def test_anthropic_count_tokens(self):
    """A /v1/messages/count_tokens request with model and messages only must return 200."""
    payload = {"model": "claude-3", "messages": [{"role": "user", "content": "Hello"}]}
    resp = self.inference_client.post("/v1/messages/count_tokens", json=payload, headers=self.auth_headers)
    assert resp.status_code == 200
def test_anthropic_messages_missing_required_field(self):
    """A /v1/messages request sent without ``max_tokens`` is expected to be rejected with 400."""
    payload = {"model": "claude-3", "messages": [{"role": "user", "content": "Hello"}]}
    resp = self.inference_client.post("/v1/messages", json=payload, headers=self.auth_headers)
    assert resp.status_code == 400
def test_anthropic_messages_auth_required(self):
    """POST /v1/messages without an Authorization header and expect HTTP 401."""
    # Only Content-Type is sent, so the request carries no credentials at all.
    anonymous_headers = {"Content-Type": "application/json"}
    payload = {
        "model": "claude-3",
        "messages": [{"role": "user", "content": "Hello"}],
        "max_tokens": 100,
    }
    reply = self.inference_client.post(
        "/v1/messages", json=payload, headers=anonymous_headers
    )
    assert reply.status_code == 401
def test_anthropic_messages_invalid_auth(self):
    """POST /v1/messages with a bogus bearer token and expect HTTP 403."""
    bad_headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer invalid-key",
    }
    payload = {
        "model": "claude-3",
        "messages": [{"role": "user", "content": "Hello"}],
        "max_tokens": 100,
    }
    reply = self.inference_client.post(
        "/v1/messages", json=payload, headers=bad_headers
    )
    assert reply.status_code == 403
def test_anthropic_count_tokens_missing_model(self):
    """POST /v1/messages/count_tokens with no ``model`` key and expect HTTP 400."""
    payload_without_model = {"messages": [{"role": "user", "content": "Hello"}]}
    reply = self.inference_client.post(
        "/v1/messages/count_tokens",
        json=payload_without_model,
        headers=self.auth_headers,
    )
    assert reply.status_code == 400
def test_anthropic_count_tokens_with_max_tokens_ok(self):
    """POST /v1/messages/count_tokens with an additional ``max_tokens``; expect HTTP 200."""
    payload = {
        "model": "claude-3",
        "messages": [{"role": "user", "content": "Hello"}],
        # Extra field for this endpoint; the request is still expected to pass.
        "max_tokens": 100,
    }
    reply = self.inference_client.post(
        "/v1/messages/count_tokens", json=payload, headers=self.auth_headers
    )
    assert reply.status_code == 200
def test_openai_endpoints_still_work(self):
    """Check that /v1/chat/completions and /v1/completions both return HTTP 200."""
    # Chat-completions request first; it must pass before the second call is made.
    chat_payload = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello"}],
    }
    chat_reply = self.inference_client.post(
        "/v1/chat/completions", json=chat_payload, headers=self.auth_headers
    )
    assert chat_reply.status_code == 200

    # Plain completions request.
    completion_payload = {"model": "text-davinci-003", "prompt": "Hello"}
    completion_reply = self.inference_client.post(
        "/v1/completions", json=completion_payload, headers=self.auth_headers
    )
    assert completion_reply.status_code == 200
def test_service_unavailable(self):
    """POST /v1/messages while the availability check reports False; expect HTTP 503.

    The inference server's ``_is_available`` is replaced with an AsyncMock
    returning False for the duration of the request only.
    """
    unavailable = AsyncMock(return_value=False)
    payload = {
        "model": "claude-3",
        "messages": [{"role": "user", "content": "Hello"}],
        "max_tokens": 100,
    }
    # patch.object restores the original attribute on exit, so the forced
    # "unavailable" state cannot leak into other tests if the server object
    # is shared between them.
    with patch.object(self.coordinator_server._inf, "_is_available", unavailable):
        response = self.inference_client.post(
            "/v1/messages", json=payload, headers=self.auth_headers
        )
    assert response.status_code == 503