import asyncio
import json
import os
import socket
import logging
import threading
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request, HTTPException, status
from fastapi.responses import Response
import uvicorn
from motor.common.http.cert_util import CertUtil
from motor.common.utils.net import detect_family, format_address
from motor.config.node_manager import NodeManagerConfig
from motor.node_manager.core.heartbeat_manager import HeartbeatManager
from motor.common.logger import ApiAccessFilter, get_logger
from motor.common.resources.http_msg_spec import StartCmdMsg
from motor.node_manager.core.engine_manager import EngineManager
from motor.node_manager.core.daemon import Daemon
from motor.node_manager.core.api_ready_event import clear_api_ready, mark_api_ready, wait_until_api_ready
from motor.common.resources.instance import PDRole
from motor.common.utils.snapshot_utils import is_restored_from_host_side_snapshot
# Module-level logger shared by every handler and by NodeManagerAPI below.
logger = get_logger(__name__)
@asynccontextmanager
async def lifespan(application: FastAPI):
    """Lifespan context manager for the FastAPI app.

    Marks the API as ready on startup and clears the ready flag on shutdown.

    Args:
        application: The FastAPI application being started (unused here).
    """
    mark_api_ready()
    logger.info("NodeManagerAPI server is ready")
    try:
        yield
    finally:
        # Clear the flag even when an exception is thrown into the generator,
        # so the ready flag never outlives the running application.
        clear_api_ready()
# FastAPI application; the route handlers below register themselves on it.
app = FastAPI(lifespan=lifespan)
# Number of permits handed to thread_semaphore.
MAX_CONCURRENT_THREADS = 10
# Acquired by the /node-manager/start handler around its thread-offloaded work.
# NOTE(review): this is constructed at import time, while the serving event loop
# is created later in a separate thread; on Python < 3.10 asyncio primitives bind
# to the loop that is current at construction - confirm the minimum Python version.
thread_semaphore = asyncio.Semaphore(MAX_CONCURRENT_THREADS)
@app.post("/node-manager/start")
async def start_instance(request: Request):
    """Handle a start command carrying instance and role info.

    The JSON body is turned into a ``StartCmdMsg`` and handed to
    ``EngineManager.parse_start_cmd``. When the process was restored from a
    host-side snapshot only the resume preparation and endpoint update run;
    otherwise the engine is pulled through ``Daemon`` and the heartbeat and
    engine managers are started.

    Args:
        request: Incoming request whose JSON body holds the StartCmdMsg fields.

    Returns:
        An empty dict on success.

    Raises:
        HTTPException: 400 when the body is not valid JSON, cannot be turned
            into a ``StartCmdMsg``, or ``parse_start_cmd`` raises; 422 when
            ``parse_start_cmd`` returns a falsy result; 500 when pulling the
            engine fails or any other unexpected error occurs.
    """
    try:
        try:
            payload = await request.json()
            start_msg = StartCmdMsg(**payload)
        except (ValueError, TypeError) as payload_err:
            # Malformed JSON, a non-mapping body or rejected fields are client
            # errors; report them as 400 rather than the generic 500 below.
            logger.error("Failed to parse start command: %s", payload_err)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid start command payload"
            ) from payload_err

        engine_manager = EngineManager()
        async with thread_semaphore:
            try:
                parsed_ok = await asyncio.to_thread(engine_manager.parse_start_cmd, start_msg)
            except Exception as inner_err:
                logger.error("Failed to parse start command: %s", inner_err)
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid start command payload"
                ) from inner_err
            if not parsed_ok:
                raise HTTPException(
                    status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Start command validation failed"
                )

            if is_restored_from_host_side_snapshot():
                # Restored from a snapshot: prepare the resume and record the
                # start instead of pulling a new engine.
                await asyncio.to_thread(EngineManager().engine_resume_prepare, start_msg)
                HeartbeatManager().update_endpoint(start_msg)
                HeartbeatManager().set_started_after_restore(True)
                return {}

            await asyncio.to_thread(EngineManager().engine_suspend_prepare)
            try:
                await asyncio.to_thread(
                    Daemon().pull_engine,
                    PDRole(start_msg.role),
                    start_msg.endpoints,
                    start_msg.instance_id,
                    start_msg.master_dp_ip,
                    engine_manager.d2d_peer_ips,
                    start_msg.node_rank,
                )
            except Exception as pull_err:
                logger.error("Failed to pull engine: %s", pull_err)
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to start engine server"
                ) from pull_err

            HeartbeatManager().update_endpoint(start_msg)
            HeartbeatManager().start()
            engine_manager.start()
            return {}
    except HTTPException:
        # Already carries the intended status code; pass it through untouched.
        raise
    except Exception as err:
        logger.error("Unexpected error: %s", err)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An internal server error occurred"
        ) from err
@app.post("/node-manager/stop")
async def stop_instance(request: Request):
    """
    Stop all engine processes by invoking Daemon().stop in a worker thread.

    Args:
        request: Incoming request (its body is not read).

    Returns:
        A 200 response whose JSON body carries a confirmation message.

    Raises:
        HTTPException: 500 when stopping through the daemon fails.
    """
    try:
        await asyncio.to_thread(Daemon().stop)
        content = {"message": "All engine processes stopped successfully."}
        # The body is JSON, so declare it; a bare Response sends no Content-Type.
        return Response(
            status_code=status.HTTP_200_OK,
            content=json.dumps(content),
            media_type="application/json",
        )
    except Exception as err:
        logger.error("Failed to stop engines via daemon: %s", err)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to stop engine processes"
        ) from err
@app.post("/node-manager/pause")
async def pause_instance(request: Request):
    """
    PreStop hook: set all endpoints to PAUSED status.
    Pod readiness probe returns false; liveness probe remains true.
    Controller will receive PAUSED status via heartbeat and trigger
    the pause flow to Coordinator.

    Args:
        request: Incoming request (its body is not read).

    Returns:
        A 200 response whose JSON body holds ``status``, ``message`` and the
        ``engine_mgmt_addrs`` reported by the heartbeat manager.

    Raises:
        HTTPException: 500 when pausing or building the response fails.
    """
    try:
        await asyncio.to_thread(HeartbeatManager().pause_all_endpoints)
        hm = HeartbeatManager()
        engine_mgmt_addrs = hm.get_engine_mgmt_addrs()
        content = {
            "status": "ok",
            "message": "Endpoints set to PAUSED",
            "engine_mgmt_addrs": engine_mgmt_addrs,
        }
        # The body is JSON, so declare it; a bare Response sends no Content-Type.
        return Response(
            status_code=status.HTTP_200_OK,
            content=json.dumps(content),
            media_type="application/json",
        )
    except Exception as err:
        logger.error("Failed to set endpoints to PAUSED: %s", err)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to set endpoints to PAUSED"
        ) from err
@app.post("/node-manager/resume")
async def resume_instance(request: Request):
    """
    Resume instance from PAUSED back to NORMAL status.
    Used when PreStop is cancelled (e.g. rollout rollback).

    Args:
        request: Incoming request (its body is not read).

    Returns:
        A 200 response whose JSON body holds ``status`` and ``message``.

    Raises:
        HTTPException: 500 when resuming the endpoints fails.
    """
    try:
        await asyncio.to_thread(HeartbeatManager().resume_all_endpoints)
        content = {"status": "ok", "message": "Endpoints resumed to NORMAL"}
        # The body is JSON, so declare it; a bare Response sends no Content-Type.
        return Response(
            status_code=status.HTTP_200_OK,
            content=json.dumps(content),
            media_type="application/json",
        )
    except Exception as err:
        logger.error("Failed to resume endpoints: %s", err)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to resume endpoints"
        ) from err
async def _check_node_manager_ready() -> bool:
    """Report whether this node manager currently counts as ready.

    The endpoint check runs in a worker thread. When the process was restored
    from a host-side snapshot, the heartbeat manager must additionally report
    that it was started after the restore.
    """
    endpoints_ok = await asyncio.to_thread(HeartbeatManager().check_all_endpoints_normal)
    if not is_restored_from_host_side_snapshot():
        return endpoints_ok
    # Restored from a snapshot: healthy endpoints alone are not enough.
    return endpoints_ok and HeartbeatManager().is_started_after_restore()
@app.get("/node-manager/status")
async def get_instance_status():
    """
    Report the readiness of this node manager as a JSON object.

    Returns:
        ``{"status": <bool>}`` carrying the result of _check_node_manager_ready().

    Raises:
        HTTPException: 500 when the check itself fails.
    """
    try:
        all_normal = await _check_node_manager_ready()
    except Exception as err:
        logger.error("Failed to check endpoints status: %s", err)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to check endpoints status"
        ) from err
    return {"status": all_normal}
@app.get("/readiness")
async def readiness():
    """
    Readiness probe.

    Returns:
        A JSON message with status 200 when the node manager is ready.

    Raises:
        HTTPException: 503 with a ``message``/``reason`` detail when the node
            manager is not ready; 500 when the readiness check itself fails.
    """
    try:
        ready = await _check_node_manager_ready()
    except Exception as err:
        logger.error("Failed to check node manager readiness: %s", err)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to check node manager readiness"
        ) from err

    if ready:
        return {"message": "Node manager is ready"}

    # Pick the reason: a restored process that has not yet been started again
    # is reported separately from unhealthy endpoints.
    awaiting_start = is_restored_from_host_side_snapshot() and not HeartbeatManager().is_started_after_restore()
    if awaiting_start:
        why = "Not started after container snapshot restore"
    else:
        why = "Endpoints not healthy"
    raise HTTPException(
        status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        detail={"message": "Node manager is not ready", "reason": why},
    )
class NodeManagerAPI:
    """Serves the module-level FastAPI ``app`` with uvicorn from a daemon thread.

    Constructing an instance resolves the bind host and port, clears the
    API-ready flag and immediately starts the server thread.
    """

    def __init__(self, config: NodeManagerConfig = None):
        """Resolve the listen address and start the server thread.

        Args:
            config: Node manager configuration. When omitted, the host is the
                IPv6 or IPv4 wildcard (chosen from the address family of the
                POD_IP environment variable), the port is 8080 and TLS is not
                configured.
        """
        self._config = config
        if self._config and self._config.api_config.pod_ip:
            self.host = self._config.api_config.pod_ip
        else:
            # No explicit pod IP: bind the wildcard of the pod's address family.
            self.host = "::" if detect_family(os.getenv("POD_IP", "")) == socket.AF_INET6 else "0.0.0.0"
        if self._config:
            self.port = self._config.api_config.node_manager_port
        else:
            self.port = 8080
        self.server = None
        self.serve_task = None
        self._thread = None
        # Reset the flag before the thread starts; lifespan sets it on startup.
        clear_api_ready()
        self._thread = threading.Thread(target=self._serve_in_thread, daemon=True, name="nm_api_server")
        self._thread.start()

    @staticmethod
    def wait_until_ready(timeout: float = None) -> bool:
        """
        Wait until the NodeManagerAPI server is ready.
        Args:
            timeout: Maximum time to wait in seconds. None means wait indefinitely.
        Returns:
            True if the server is ready, False if timeout occurred.
        """
        return wait_until_api_ready(timeout=timeout)

    async def stop(self):
        """Async wrapper around stop_sync(); it runs stop_sync() inline."""
        self.stop_sync()

    def stop_sync(self):
        """Ask uvicorn to exit and wait up to one second for the server thread."""
        if self.server:
            self.server.should_exit = True
        if self._thread and self._thread.is_alive():
            self._thread.join(timeout=1.0)
            if self._thread.is_alive():
                logger.warning("API server thread did not stop within timeout")

    @staticmethod
    def _suppress_probe_access_logs() -> None:
        """Suppress noisy uvicorn access logs from K8s readiness/liveness probes."""
        probe_filter = ApiAccessFilter(
            {
                "/readiness": logging.ERROR,
            }
        )
        logging.getLogger("uvicorn.access").addFilter(probe_filter)

    def _serve_in_thread(self):
        """Thread target: run the uvicorn server on a private event loop.

        Raises:
            RuntimeError: when TLS is enabled but no SSL context could be created.
        """
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            self._suppress_probe_access_logs()
            config = uvicorn.Config(app, host=self.host, port=self.port, loop="asyncio")
            config.load()
            # __init__ accepts config=None, so TLS settings may be absent.
            tls_config = self._config.mgmt_tls_config if self._config else None
            if tls_config is not None and tls_config.enable_tls:
                context = CertUtil.create_ssl_context(tls_config)
                if not context:
                    raise RuntimeError("Failed to create SSL context")
                # Assigned after load() so the prepared context is the one used.
                config.ssl = context
                logger.info("Node Manager server started: https://%s", format_address(self.host, self.port))
            else:
                logger.info("Node Manager server started: http://%s", format_address(self.host, self.port))
            self.server = uvicorn.Server(config)
            loop.run_until_complete(self.server.serve())
        finally:
            # Release the loop on every exit path, including setup failures.
            try:
                loop.run_until_complete(loop.shutdown_asyncgens())
            except Exception as e:
                logger.error("Failed to shutdown server: %s", e)
            finally:
                loop.close()