import argparse
import asyncio
import logging
import heapq
import os
import sys
import threading
import time
import uuid
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import httpx
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
# Largest native Python int on this platform (sys.maxsize).
MAX_INT: int = sys.maxsize
# Optional request headers that pin a request to specific prefill / decode
# pods; each value is expected in "<ip>:<port>" form.
HEADER_PREFILL_POD_ADDRESS_PORT: str = "x-openfuyao-prefill-pod-address-port"
HEADER_DECODE_POD_ADDRESS_PORT: str = "x-openfuyao-decode-pod-address-port"
@dataclass
class ServiceRequest:
    """Request parameters used when forwarding to backend services.

    Consumed by the prefill and decode forwarding helpers; one instance
    describes a single outbound call (including its retry policy).
    """
    # Pooled AsyncClient of the target node (created with that node's base_url).
    client: httpx.AsyncClient
    # "host:port" key of the prefiller involved in this request; used to
    # fetch / record aborted request ids for that prefiller.
    prefiller_key: str
    # API path on the backend, e.g. "/v1/completions".
    endpoint: str
    # JSON request body sent to the backend.
    payload: dict
    # Sent upstream as the X-Request-Id header.
    request_id: str
    # Total number of attempts (the helpers loop while attempt < max_retries).
    max_retries: int = 3
    # Seconds to sleep between failed attempts.
    retry_interval: float = 1
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# The kubernetes client is optional: when it is missing, K8S_AVAILABLE is
# False and VllmServiceDiscovery disables automatic discovery.
try:
    from kubernetes import client as k8s_client
    from kubernetes import config as k8s_config
except ImportError:
    k8s_client = k8s_config = None
    K8S_AVAILABLE = False
    logger.warning(
        "kubernetes package not available. Install with: pip install kubernetes"
    )
else:
    K8S_AVAILABLE = True
class VllmServiceDiscovery:
    """Periodically discover the endpoints of prefiller / decoder pods in a namespace.

    Pods are selected with ``args.prefiller_labels`` / ``args.decoder_labels``
    in ``args.namespace``. Only pods that are Running, not terminating and
    have an assigned IP are published. A daemon thread refreshes the snapshot
    every ``args.discovery_interval`` seconds.

    If the ``kubernetes`` package is missing or no Kubernetes configuration
    can be loaded, discovery is disabled. In that case the static endpoint
    lists from the command line (``args.prefiller_instances`` /
    ``args.decoder_instances``, if present) are published instead, so the
    proxy still works outside a cluster.
    """

    def __init__(self, args):
        self.prefiller_labels = args.prefiller_labels
        self.decoder_labels = args.decoder_labels
        self.namespace = args.namespace
        self.discovery_interval = args.discovery_interval
        self.prefiller_container = args.prefiller_container
        self.decoder_container = args.decoder_container
        self.prefiller_port_name = args.prefiller_port_name
        self.decoder_port_name = args.decoder_port_name
        self._prefiller_endpoints: List[Tuple[str, int]] = []
        self._decoder_endpoints: List[Tuple[str, int]] = []
        self._lock = threading.Lock()
        self._running = False
        # Set by stop() to wake the discovery thread immediately instead of
        # letting it finish a full discovery_interval sleep.
        self._stop_event = threading.Event()
        self._thread: Optional[threading.Thread] = None
        if not K8S_AVAILABLE:
            logger.warning("K8s client not available. Service discovery disabled.")
            self._use_static_endpoints(args)
            return
        try:
            try:
                k8s_config.load_incluster_config()
            except Exception:
                # Not running inside a pod: fall back to the local kubeconfig.
                k8s_config.load_kube_config()
        except Exception:
            logger.warning(
                "Could not load Kubernetes configuration; "
                "automatic service discovery will not run."
            )
            self._use_static_endpoints(args)
            return
        self._core_v1 = k8s_client.CoreV1Api()
        self.start()

    def _use_static_endpoints(self, args) -> None:
        """Publish the command-line (host, port) lists when discovery is disabled."""
        prefillers = list(getattr(args, "prefiller_instances", None) or [])
        decoders = list(getattr(args, "decoder_instances", None) or [])
        with self._lock:
            self._prefiller_endpoints = prefillers
            self._decoder_endpoints = decoders
        logger.info(
            "Using static endpoints: %d prefillers, %d decoders",
            len(prefillers),
            len(decoders),
        )

    def start(self) -> None:
        """Run one discovery pass synchronously, then start the refresh thread.

        No-op when already running or when discovery is disabled.
        """
        if self._running:
            return
        if not K8S_AVAILABLE or not hasattr(self, "_core_v1"):
            return
        self._running = True
        self._stop_event.clear()
        # Populate the snapshot before returning, so a router that syncs right
        # after construction does not start with empty pools for a whole
        # discovery interval.
        self._discovery_once()
        self._thread = threading.Thread(target=self._discovery_loop, daemon=True)
        self._thread.start()
        logger.info("K8s vLLM service discovery started.")

    def stop(self) -> None:
        """Signal the discovery thread to exit and wait up to 5 s for it."""
        self._running = False
        self._stop_event.set()
        if self._thread:
            self._thread.join(timeout=5)
        logger.info("K8s vLLM service discovery stopped.")

    def get_prefiller_instances(self) -> List[Tuple[str, int]]:
        """Return a copy of the current prefiller (ip, port) list."""
        with self._lock:
            return list(self._prefiller_endpoints)

    def get_decoder_instances(self) -> List[Tuple[str, int]]:
        """Return a copy of the current decoder (ip, port) list."""
        with self._lock:
            return list(self._decoder_endpoints)

    @staticmethod
    def _resolve_container_port(pod, container_name: str, port_name: str) -> int:
        """Best-effort extraction of the target container port.

        Preference order on container ``container_name`` (or the first
        container when it is not found):

        1. The port named ``port_name``, compared case-insensitively.
        2. The container's first declared port.
        3. The fallback: 8100 for pods whose name contains "prefill",
           8200 otherwise.
        """
        default_port = 8100 if "prefill" in pod.metadata.name.lower() else 8200
        wanted_name = (port_name or "").lower()
        try:
            containers = pod.spec.containers or []
            if not containers:
                return default_port
            target = None
            if container_name:
                target = next(
                    (c for c in containers if c.name == container_name), None
                )
            if target is None:
                target = containers[0]
            if not target.ports:
                return default_port
            for port in target.ports:
                name = getattr(port, "name", None)
                if name and name.lower() == wanted_name:
                    logger.debug("Found named port '%s' with value %s for container %s",
                                 port.name, port.container_port, target.name)
                    return port.container_port
            logger.debug("Using first port %s for container %s",
                         target.ports[0].container_port, target.name)
            return target.ports[0].container_port
        except Exception as exc:
            logger.warning(
                "Failed to resolve container port for pod %s: %s",
                pod.metadata.name,
                exc,
            )
            return default_port

    def _list_serving_endpoints(
        self, label_selector: str, container_name: str, port_name: str
    ) -> List[Tuple[str, int]]:
        """List (ip, port) for pods matching ``label_selector`` that can serve traffic."""
        pods = self._core_v1.list_namespaced_pod(
            namespace=self.namespace, label_selector=label_selector
        )
        endpoints: List[Tuple[str, int]] = []
        for pod in pods.items:
            if pod.status.phase != "Running":
                continue
            # Pods being deleted keep phase "Running" while they shut down;
            # routing new requests to them would fail mid-flight.
            if getattr(pod.metadata, "deletion_timestamp", None) is not None:
                continue
            ip = pod.status.pod_ip
            if not ip:
                continue
            port = self._resolve_container_port(pod, container_name, port_name)
            endpoints.append((ip, port))
        return endpoints

    def _discovery_once(self) -> None:
        """Run one discovery pass and publish the new snapshot under the lock.

        On any error the previous snapshot is kept and the error is logged.
        """
        if not K8S_AVAILABLE or not hasattr(self, "_core_v1"):
            return
        try:
            prefiller_eps = self._list_serving_endpoints(
                self.prefiller_labels, self.prefiller_container, self.prefiller_port_name
            )
            decoder_eps = self._list_serving_endpoints(
                self.decoder_labels, self.decoder_container, self.decoder_port_name
            )
        except Exception as exc:
            logger.error("Service discovery iteration failed: %s", exc)
            return
        with self._lock:
            self._prefiller_endpoints = prefiller_eps
            self._decoder_endpoints = decoder_eps
        logger.debug(
            "Service discovery: %d prefillers, %d decoders",
            len(prefiller_eps),
            len(decoder_eps),
        )

    def _discovery_loop(self) -> None:
        # start() already ran the first pass, so wait before each refresh.
        # Event.wait returns True as soon as stop() sets the event.
        while not self._stop_event.wait(self.discovery_interval):
            self._discovery_once()
class ServerState:
    """Connection and load bookkeeping for one backend node (prefiller or decoder)."""

    def __init__(self, host: str, port: int):
        self.host, self.port = host, port
        self.url = f"http://{host}:{port}"
        # One pooled client per node, rooted at the node's URL.
        # timeout=None disables all httpx timeouts for calls on this client.
        connection_limits = httpx.Limits(
            max_connections=100000,
            max_keepalive_connections=100000,
        )
        self.client = httpx.AsyncClient(
            base_url=self.url,
            limits=connection_limits,
            timeout=None,
        )
        # Load counters, adjusted by the owning pool when work is charged or
        # released. NOTE(review): active_requests is not updated anywhere in
        # the visible code.
        self.active_tokens = 0
        self.active_kv_cache = 0
        self.active_requests = 0
        # Request ids recorded as aborted for this node; drained and sent with
        # the next prefill call.
        self.aborted_requests = set()
class BackendPool:
    """Manage a homogeneous group of backend nodes and provide load-balanced selection.

    A node's priority is its outstanding token cost. When ``weight_kv_cache``
    is set, 0.3 x its outstanding KV-cache cost is added.

    Selection scans the node map for the lowest priority (ties broken by key),
    so it always reflects the current counters. Earlier heap-based versions
    kept stale duplicate entries that grew without bound and could pick a
    loaded node.

    All public methods hold ``self._lock``.
    """

    def __init__(self, role: str, weight_kv_cache: bool):
        self.role = role
        self.weight_kv_cache = weight_kv_cache
        self.servers: Dict[str, "ServerState"] = {}
        self._lock = threading.Lock()

    def sync_from_endpoints(self, endpoints: List[Tuple[str, int]]) -> None:
        """Reconcile current nodes with the given (host, port) endpoint list.

        New endpoints get a fresh ServerState; nodes no longer listed are
        dropped. Later release/abort calls for a dropped key are ignored.
        """
        with self._lock:
            desired = {f"{host}:{port}" for host, port in endpoints}
            for host, port in endpoints:
                key = f"{host}:{port}"
                if key not in self.servers:
                    self.servers[key] = ServerState(host, port)
                    logger.info("Added %s node %s", self.role, key)
            for key in set(self.servers) - desired:
                # NOTE(review): the dropped node's AsyncClient is not closed
                # here because in-flight requests may still hold and use it.
                del self.servers[key]
                logger.info("Removed %s node %s", self.role, key)

    def _priority_for(self, server: "ServerState") -> float:
        value = float(server.active_tokens)
        if self.weight_kv_cache:
            value += server.active_kv_cache * 0.3
        return value

    def _least_loaded_key(self) -> str:
        """Return the key with the lowest priority; caller must hold the lock."""
        if not self.servers:
            raise RuntimeError(f"No {self.role} servers available")
        return min(
            self.servers,
            key=lambda k: (self._priority_for(self.servers[k]), k),
        )

    def _charge(self, server: "ServerState", token_cost: float) -> None:
        server.active_tokens += token_cost
        if self.weight_kv_cache:
            server.active_kv_cache += token_cost

    def choose(self, token_cost: float) -> str:
        """Pick the least-loaded node, charge it ``token_cost``, and return its key.

        Raises:
            RuntimeError: if the pool is empty.
        """
        with self._lock:
            key = self._least_loaded_key()
            self._charge(self.servers[key], token_cost)
            return key

    def choose_specific(self, host: str, port: int, token_cost: float) -> str:
        """Use a caller-specified node, if present, and charge it.

        Raises:
            RuntimeError: if ``host:port`` is not registered.
        """
        key = f"{host}:{port}"
        with self._lock:
            server = self.servers.get(key)
            if server is None:
                raise RuntimeError(
                    f"{self.role.capitalize()} server {host}:{port} not registered"
                )
            self._charge(server, token_cost)
            return key

    def release_tokens(self, key: str, token_cost: float) -> None:
        """Subtract ``token_cost`` from the node's token load (never below 0)."""
        with self._lock:
            server = self.servers.get(key)
            if server is None:
                return
            server.active_tokens = max(0, server.active_tokens - token_cost)

    def release_kv(self, key: str, token_cost: float) -> None:
        """Subtract ``token_cost`` from the node's KV-cache load (never below 0)."""
        with self._lock:
            server = self.servers.get(key)
            if server is None:
                return
            server.active_kv_cache = max(0, server.active_kv_cache - token_cost)

    def peek_next_key(self) -> str:
        """Return the key of the node that would be selected next, without charging it."""
        with self._lock:
            return self._least_loaded_key()

    def record_aborted(self, key: str, request_id: str) -> None:
        """Remember ``request_id`` as aborted on node ``key`` (ignored if unknown)."""
        with self._lock:
            server = self.servers.get(key)
            if server is not None:
                server.aborted_requests.add(request_id)

    def take_aborted(self, key: str):
        """Atomically drain and return the aborted ids of node ``key`` (empty set if unknown)."""
        with self._lock:
            server = self.servers.get(key)
            if server is None:
                return set()
            ids = set(server.aborted_requests)
            server.aborted_requests.clear()
            return ids
class LoadBalanceRouter:
    """Top-level proxy router: discovery + prefill / decode pools + routing helpers.

    Owns one BackendPool per role. The prefill pool is created with
    ``weight_kv_cache=True``; the decode pool with ``False``. Endpoint
    snapshots are pulled from the VllmServiceDiscovery instance on demand.
    """

    def __init__(self, discovery: VllmServiceDiscovery):
        self._discovery = discovery
        self._prefill_pool = BackendPool("prefiller", weight_kv_cache=True)
        self._decode_pool = BackendPool("decoder", weight_kv_cache=False)
        self._req_id_lock = asyncio.Lock()
        # Serializes re-syncs of both pools from discovery.
        self._instance_lock = threading.Lock()
        self._sync_from_discovery()

    @property
    def prefillers(self) -> Dict[str, ServerState]:
        """Live key -> ServerState map of the prefill pool (not a copy)."""
        return self._prefill_pool.servers

    @property
    def decoders(self) -> Dict[str, ServerState]:
        """Live key -> ServerState map of the decode pool (not a copy)."""
        return self._decode_pool.servers

    def _sync_from_discovery(self) -> None:
        """Reconcile both pools with the latest discovery snapshot."""
        with self._instance_lock:
            prefill_snapshot = self._discovery.get_prefiller_instances()
            self._prefill_pool.sync_from_endpoints(prefill_snapshot)
            decode_snapshot = self._discovery.get_decoder_instances()
            self._decode_pool.sync_from_endpoints(decode_snapshot)
            logger.debug(
                "Synced instances: prefiller=%d, decoder=%d",
                len(self.prefillers),
                len(self.decoders),
            )

    def refresh_instances(self) -> None:
        """Re-sync the pools from discovery (called by the periodic refresh task)."""
        self._sync_from_discovery()

    def mark_prefill_aborted(self, prefiller_key: str, request_id: str) -> None:
        """Queue ``request_id`` as aborted on the given prefiller."""
        self._prefill_pool.record_aborted(prefiller_key, request_id)

    def acquire_aborted_for_prefiller(self, prefiller_key: str):
        """Drain and return the queued aborted ids for the given prefiller."""
        return self._prefill_pool.take_aborted(prefiller_key)

    async def generate_request_id(self) -> str:
        """Return a fresh random request id (UUID4 string)."""
        async with self._req_id_lock:
            request_id = str(uuid.uuid4())
        return request_id

    def next_metrics_prefiller_key(self) -> str:
        """Key of the prefiller the pool would select next (used for /metrics)."""
        return self._prefill_pool.peek_next_key()

    # -- prefill pool --------------------------------------------------------
    def select_prefiller(self, cost: float) -> str:
        return self._prefill_pool.choose(cost)

    def select_prefiller_by_address(self, host: str, port: int,
                                    cost: float) -> str:
        return self._prefill_pool.choose_specific(host, port, cost)

    def release_prefiller_tokens(self, key: str, cost: float) -> None:
        self._prefill_pool.release_tokens(key, cost)

    def release_kv(self, key: str, cost: float) -> None:
        """Release KV cache usage on a prefiller node."""
        self._prefill_pool.release_kv(key, cost)

    # -- decode pool ---------------------------------------------------------
    def select_decoder(self, cost: float) -> str:
        return self._decode_pool.choose(cost)

    def select_decoder_by_address(self, host: str, port: int,
                                  cost: float) -> str:
        return self._decode_pool.choose_specific(host, port, cost)

    def release_decoder_tokens(self, key: str, cost: float) -> None:
        self._decode_pool.release_tokens(key, cost)
# Process-wide router; assigned by lifespan() at application startup and read
# by the request handlers. None until the app has started.
router: Optional[LoadBalanceRouter] = None
# Parsed command-line arguments; assigned in the __main__ block before
# uvicorn starts the app, and read by lifespan() and the handlers.
global_args = None
async def stop_proxy_server(discovery: VllmServiceDiscovery,
                            state: LoadBalanceRouter,
                            refresh_task: asyncio.Task) -> None:
    """Shut the proxy down in order.

    1. Cancel and await the refresh task.
    2. Stop discovery.
    3. Close every backend HTTP client, prefillers first, then decoders.
    """
    refresh_task.cancel()
    try:
        await refresh_task
    except asyncio.CancelledError:
        pass  # expected: the task was just cancelled
    discovery.stop()
    backend_nodes = [*state.prefillers.values(), *state.decoders.values()]
    for node in backend_nodes:
        await node.client.aclose()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: build the router, keep it refreshed, clean up on exit.

    The router is published through the module-level ``router`` global. A
    background task re-syncs it with discovery every
    ``global_args.discovery_interval`` seconds. On shutdown,
    ``stop_proxy_server`` tears everything down.
    """
    global router

    async def _resync_forever():
        # Refresh failures are logged and never stop the loop.
        while True:
            await asyncio.sleep(global_args.discovery_interval)
            try:
                router.refresh_instances()
            except Exception as exc:
                logger.error("Failed to refresh instances: %s", exc)

    discovery = VllmServiceDiscovery(global_args)
    router = LoadBalanceRouter(discovery)
    logger.info(
        "Proxy initialized: %d prefiller nodes, %d decoder nodes.",
        len(router.prefillers),
        len(router.decoders),
    )
    resync_task = asyncio.create_task(_resync_forever())
    try:
        yield
    finally:
        await stop_proxy_server(discovery, router, resync_task)
async def _listen_for_disconnect(request: Request) -> None:
    """Return once the ASGI server reports ``http.disconnect`` for *request*.

    Every other message received in the meantime is discarded.
    """
    message_type = None
    while message_type != "http.disconnect":
        message_type = (await request.receive())["type"]
def with_cancellation(handler):
    """Cancel the route *handler* if the client disconnects before it returns.

    The handler and a disconnect listener run as concurrent tasks. Whichever
    finishes first wins and the other is cancelled. If the client
    disconnected first, ``None`` is returned.

    The listener consumes ASGI ``receive`` messages. The request body is
    therefore read (and cached on the Request object) *before* the listener
    starts. Otherwise the listener could swallow body chunks that the
    handler's own ``request.json()`` / ``request.body()`` still needs.

    The decorated endpoint must receive the Request as the ``request``
    keyword argument.
    """
    import functools

    @functools.wraps(handler)
    async def wrapper(*args, **kwargs):
        request: Request = kwargs["request"]
        # Drain the body first; later request.json()/body() calls reuse it.
        await request.body()
        task_handler = asyncio.create_task(handler(*args, **kwargs))
        task_cancel = asyncio.create_task(_listen_for_disconnect(request))
        try:
            done, _ = await asyncio.wait(
                [task_handler, task_cancel], return_when=asyncio.FIRST_COMPLETED
            )
        finally:
            # Also runs if this wrapper itself is cancelled, so neither task
            # is left running unattended.
            for task in (task_handler, task_cancel):
                if not task.done():
                    task.cancel()
        if task_handler in done:
            return task_handler.result()
        return None

    return wrapper
# FastAPI application. lifespan() builds the discovery/router on startup and
# stops background work and closes backend clients on shutdown.
app = FastAPI(lifespan=lifespan)
async def _forward_prefill_once(req: ServiceRequest):
    """Send the prefill-only request (single, non-streaming response) to a prefiller.

    The payload is copied and rewritten:

    - ``kv_transfer_params`` asks for remote decode and carries the ids
      previously recorded as aborted on this prefiller (``aborted_request``).
    - Streaming is disabled and ``stream_options`` is dropped.
    - ``max_tokens`` is forced to 1.

    At most ``max(1, req.max_retries)`` attempts are made, sleeping
    ``req.retry_interval`` seconds between them.

    Returns:
        The successful ``httpx.Response``.

    Raises:
        httpx.RequestError / httpx.HTTPStatusError: from the last attempt.
            Whenever this function raises (including cancellation), the
            drained aborted ids are queued again, so the next prefill call
            to this node still delivers them.
    """
    aborted_ids = router.acquire_aborted_for_prefiller(req.prefiller_key)
    body = req.payload.copy()
    body["kv_transfer_params"] = {
        "do_remote_decode": True,
        "do_remote_prefill": False,
        "remote_engine_id": None,
        "remote_block_ids": None,
        "remote_host": None,
        "remote_port": None,
        "aborted_request": list(aborted_ids),
    }
    body["stream"] = False
    body["max_tokens"] = 1
    body.pop("stream_options", None)
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": req.request_id,
    }
    attempts = max(1, req.max_retries)
    try:
        for attempt in range(1, attempts + 1):
            try:
                resp = await req.client.post(req.endpoint,
                                             json=body,
                                             headers=headers)
                resp.raise_for_status()
                return resp
            except (httpx.RequestError, httpx.HTTPStatusError) as exc:
                logger.warning("Prefiller call failed on attempt %d for %s: %s",
                               attempt, req.endpoint, exc)
                if attempt >= attempts:
                    logger.error(
                        "Prefiller failed after %d attempts requesting to %s.",
                        attempt,
                        req.endpoint,
                    )
                    raise
                await asyncio.sleep(req.retry_interval)
    except BaseException:
        # The abort notices were never confirmed as delivered; put them back.
        for aborted_id in aborted_ids:
            router.mark_prefill_aborted(req.prefiller_key, aborted_id)
        raise
async def _stream_decoder(req: ServiceRequest):
    """Stream the decoder's response bytes, retrying only before any byte is relayed.

    At most ``max(1, req.max_retries)`` attempts are made, sleeping
    ``req.retry_interval`` seconds between them. Once a chunk has been
    yielded, any failure ends the stream (after logging) instead of
    retrying: a retry would replay the response from the start, and the
    client would receive duplicated output.

    Yields:
        Raw response bytes from the decoder.

    Raises:
        The last attempt's exception, if every attempt fails before any
        data was relayed.
    """
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": req.request_id,
    }
    attempts = max(1, req.max_retries)
    for attempt in range(1, attempts + 1):
        has_stream_started = False
        try:
            async with req.client.stream(
                "POST", req.endpoint, json=req.payload, headers=headers
            ) as resp:
                resp.raise_for_status()
                async for chunk in resp.aiter_bytes():
                    has_stream_started = True
                    yield chunk
            return
        except Exception as exc:
            if has_stream_started:
                logger.error(
                    "Decoder streaming interrupted after data started flowing: %s",
                    exc,
                )
                return
            if isinstance(exc, (httpx.RequestError, httpx.HTTPStatusError)):
                logger.warning("Decoder stream attempt %d failed for %s: %s",
                               attempt, req.endpoint, exc)
            else:
                logger.warning("Decoder stream attempt %d raised %s for %s",
                               attempt, type(exc).__name__, req.endpoint)
            if attempt >= attempts:
                logger.error(
                    "Decoder failed after %d attempts streaming to %s.",
                    attempt,
                    req.endpoint,
                )
                raise
            await asyncio.sleep(req.retry_interval)
def _parse_node_overrides(
    request: Request,
) -> Optional[Tuple[Tuple[str, int], Tuple[str, int]]]:
    """Parse prefill / decode node overrides from headers, if both are present.

    Each header value is ``<host>:<port>``. The port is taken after the
    *last* colon, so IPv6 addresses (optionally in ``[...]``) are accepted.

    Returns:
        ``((prefill_host, prefill_port), (decode_host, decode_port))``, or
        ``None`` when either header is missing or empty.

    Raises:
        HTTPException(400): a header is malformed or the port is outside
            1..65535.
    """
    prefill_header = request.headers.get(HEADER_PREFILL_POD_ADDRESS_PORT)
    decode_header = request.headers.get(HEADER_DECODE_POD_ADDRESS_PORT)
    if not prefill_header or not decode_header:
        return None

    def _split_host_port(value: str) -> Tuple[str, int]:
        host, sep, port = value.strip().rpartition(":")
        host = host.strip().strip("[]")
        if not sep or not host:
            raise ValueError(f"expected <host>:<port>, got {value!r}")
        port_number = int(port.strip())
        if not 0 < port_number < 65536:
            raise ValueError(f"port out of range in {value!r}")
        return host, port_number

    try:
        return _split_host_port(prefill_header), _split_host_port(decode_header)
    except (ValueError, AttributeError) as exc:
        logger.error(
            "Failed to parse override headers (%s, %s): %s",
            prefill_header,
            decode_header,
            exc,
        )
        raise HTTPException(
            status_code=400,
            detail="Invalid node override headers: expected <ip>:<port> format.",
        ) from exc
async def _read_request_and_costs(
    request: Request,
) -> Tuple[dict, bool, str, float, float, str]:
    """Read the request body once and derive routing metadata and cost estimates.

    Returns:
        ``(body, is_stream, media_type, estimated_prefill_cost,
        estimated_decode_cost, req_id)``

    Raises:
        HTTPException(400): the body is not valid JSON or is not a JSON object.
    """
    raw_bytes = await request.body()
    try:
        body = await request.json()
    except ValueError as exc:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
        logger.warning("Rejecting request with invalid JSON body: %s", exc)
        raise HTTPException(
            status_code=400, detail="Request body must be valid JSON."
        ) from exc
    if not isinstance(body, dict):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object."
        )
    req_len = len(raw_bytes)
    is_stream = body.get("stream", False)
    media_type = "text/event-stream" if is_stream else "application/json"
    # Cost heuristics based on the raw body size in bytes.
    # NOTE(review): bytes / 4 presumably approximates a token count, and the
    # linear coefficients look empirically fitted; their origin is not
    # visible here.
    request_length_cost = req_len / 4.0
    estimated_prefill_cost = request_length_cost * 0.0345 + 120.0745
    estimated_decode_cost = float(req_len)
    req_id = await router.generate_request_id()
    return body, is_stream, media_type, estimated_prefill_cost, estimated_decode_cost, req_id
def _choose_prefiller(
    overrides: Optional[Tuple[Tuple[str, int], Tuple[str, int]]],
    prefill_cost: float,
    req_id: str,
) -> str:
    """Pick a prefiller node and charge it, honoring explicit overrides when provided.

    Raises:
        HTTPException(400): the overridden prefiller is not registered.
        HTTPException(503): no prefiller is registered at all.
    """
    if overrides is not None:
        (phost, pport), _ = overrides
        try:
            prefiller_key = router.select_prefiller_by_address(
                phost, pport, prefill_cost
            )
        except RuntimeError as exc:
            logger.error("Requested prefiller %s:%d not available: %s",
                         phost, pport, exc)
            raise HTTPException(
                status_code=400,
                detail=(
                    "Requested prefiller node is not registered with the proxy. "
                    f"Check {HEADER_PREFILL_POD_ADDRESS_PORT} header."
                ),
            ) from exc
        logger.info("Using specified prefiller %s:%d for request %s",
                    phost, pport, req_id)
        return prefiller_key
    try:
        return router.select_prefiller(prefill_cost)
    except RuntimeError as exc:
        # Empty pool: report "service unavailable" rather than a bare 500.
        logger.error("No prefiller available for request %s: %s", req_id, exc)
        raise HTTPException(
            status_code=503,
            detail="No prefiller nodes are currently available.",
        ) from exc
async def _invoke_prefiller_and_update_body(
    api_path: str,
    body: dict,
    prefiller_key: str,
    prefill_cost: float,
    req_id: str,
) -> None:
    """Call the selected prefiller and patch *body* in place with its KV transfer params.

    The prefiller's token charge is always released once the call finishes,
    whether it succeeds, fails or is cancelled. If the call fails, the
    KV-cache charge is released too: the caller stops before choosing a
    decoder, so nothing will consume that KV.
    """
    prefiller = router.prefillers[prefiller_key]
    prefill_req = ServiceRequest(
        client=prefiller.client,
        prefiller_key=prefiller_key,
        endpoint=api_path,
        payload=body,
        request_id=req_id,
        max_retries=global_args.max_retries,
        retry_interval=global_args.retry_interval,
    )
    try:
        resp = await _forward_prefill_once(prefill_req)
        kv_params = resp.json().get("kv_transfer_params", {}) or {}
    except BaseException:
        router.release_kv(prefiller_key, prefill_cost)
        raise
    finally:
        router.release_prefiller_tokens(prefiller_key, prefill_cost)
    if kv_params:
        body["kv_transfer_params"] = kv_params
def _choose_decoder(
    overrides: Optional[Tuple[Tuple[str, int], Tuple[str, int]]],
    decode_cost: float,
    req_id: str,
) -> str:
    """Pick a decoder node and charge it, honoring explicit overrides when provided.

    Raises:
        HTTPException(400): the overridden decoder is not registered.
        HTTPException(503): no decoder is registered at all.
    """
    if overrides is not None:
        _, (dhost, dport) = overrides
        try:
            decoder_key = router.select_decoder_by_address(
                dhost, dport, decode_cost
            )
        except RuntimeError as exc:
            logger.error("Requested decoder %s:%d not available: %s",
                         dhost, dport, exc)
            raise HTTPException(
                status_code=400,
                detail=(
                    "Requested decoder node is not registered with the proxy. "
                    f"Check {HEADER_DECODE_POD_ADDRESS_PORT} header."
                ),
            ) from exc
        logger.info("Using specified decoder %s:%d for request %s",
                    dhost, dport, req_id)
        return decoder_key
    try:
        return router.select_decoder(decode_cost)
    except RuntimeError as exc:
        # Empty pool: report "service unavailable" rather than a bare 500.
        logger.error("No decoder available for request %s: %s", req_id, exc)
        raise HTTPException(
            status_code=503,
            detail="No decoder nodes are currently available.",
        ) from exc
def _build_streaming_response(
    api_path: str,
    body: dict,
    req_id: str,
    prefiller_key: str,
    decoder_key: str,
    prefill_cost: float,
    decode_cost: float,
    is_stream: bool,
    media_type: str,
) -> StreamingResponse:
    """Construct the StreamingResponse that forwards bytes from decoder to client.

    Load accounting in ``generate()``:

    - The prefiller's KV charge is released on the first non-empty decoder
      chunk, or at the end if that never happens. It is released exactly
      once, never twice.
    - The request is queued as aborted on the prefiller if streaming raised,
      or if the stream was cancelled (client disconnect) before any data
      arrived.
    - The decoder's token charge is always released.
    """
    decoder = router.decoders[decoder_key]
    # The prefiller may already have been removed by discovery while the
    # prefill was awaited, so avoid a lookup; its URL is "http://" + key.
    logger.debug(
        "Routing request %s via %s (prefill) -> %s (decode)",
        req_id,
        f"http://{prefiller_key}",
        decoder.url,
    )

    async def generate():
        kv_released = False
        completed = False
        errored = False
        try:
            stream_req = ServiceRequest(
                client=decoder.client,
                prefiller_key=prefiller_key,
                endpoint=api_path,
                payload=body,
                request_id=req_id,
                max_retries=global_args.max_retries,
                retry_interval=global_args.retry_interval,
            )
            async for chunk in _stream_decoder(stream_req):
                if chunk and not kv_released:
                    # First real bytes from the decoder: drop the prefiller's
                    # KV charge (presumably the KV has been pulled by now).
                    router.release_kv(prefiller_key, prefill_cost)
                    kv_released = True
                yield chunk
            completed = True
        except Exception as exc:
            errored = True
            logger.error(
                "Error streaming from decoder %s for request %s: %s. "
                "The aborted request will be recorded for KV release.",
                decoder.url,
                req_id,
                exc,
            )
        finally:
            # Also runs on cancellation / aclose() when the client disconnects.
            if errored or (not completed and not kv_released):
                router.mark_prefill_aborted(prefiller_key, req_id)
            if not kv_released:
                router.release_kv(prefiller_key, prefill_cost)
            router.release_decoder_tokens(decoder_key, decode_cost)

    headers: Dict[str, str] = {}
    if is_stream:
        headers["Cache-Control"] = "no-cache"
        headers["Connection"] = "keep-alive"
        headers["X-Accel-Buffering"] = "no"
    return StreamingResponse(generate(), media_type=media_type, headers=headers)
async def _proxy_openai_request(api_path: str, request: Request):
    """Core handler that proxies OpenAI-style completion/chat requests.

    Flow:

    1. Read the body and estimate costs.
    2. Pick a prefiller (honoring override headers).
    3. Run the prefill-only request and merge its ``kv_transfer_params``
       into the body.
    4. Pick a decoder.
    5. Stream the decoder's response back to the client.

    HTTPExceptions (client-facing errors) propagate unchanged. Any other
    error is logged with its traceback and re-raised.
    """
    try:
        (
            body,
            is_stream,
            media_type,
            prefill_cost,
            decode_cost,
            req_id,
        ) = await _read_request_and_costs(request)
        overrides = _parse_node_overrides(request)
        prefiller_key = _choose_prefiller(overrides, prefill_cost, req_id)
        await _invoke_prefiller_and_update_body(
            api_path, body, prefiller_key, prefill_cost, req_id
        )
        try:
            decoder_key = _choose_decoder(overrides, decode_cost, req_id)
        except BaseException:
            # The prefill already succeeded but no decoder will consume its
            # KV: queue the request as aborted on the prefiller and drop the
            # KV charge.
            router.mark_prefill_aborted(prefiller_key, req_id)
            router.release_kv(prefiller_key, prefill_cost)
            raise
        return _build_streaming_response(
            api_path=api_path,
            body=body,
            req_id=req_id,
            prefiller_key=prefiller_key,
            decoder_key=decoder_key,
            prefill_cost=prefill_cost,
            decode_cost=decode_cost,
            is_stream=is_stream,
            media_type=media_type,
        )
    except HTTPException:
        raise
    except Exception as exc:
        logger.exception("Error in proxy completion handler for %s: %s", api_path, exc)
        raise
async def _forward_metrics(prefiller_key: str, request: Request,
                           max_retries: int, retry_interval: float):
    """Fetch /metrics from one prefiller node and relay it to the client.

    The upstream body is fetched completely, with up to
    ``max(1, max_retries)`` attempts, before the response starts. This means:

    - An upstream failure surfaces as an error instead of a truncated 200.
    - A retry can never duplicate data already sent.

    The upstream Content-Type is passed through. If it is missing, the
    Prometheus text-format type is used.

    Raises:
        httpx.RequestError / httpx.HTTPStatusError: from the last attempt.
    """
    node = router.prefillers[prefiller_key]
    headers = dict(request.headers)
    # Let httpx compute Host / Content-Length for the upstream request.
    headers.pop("host", None)
    headers.pop("content-length", None)
    attempts = max(1, max_retries)
    for attempt in range(1, attempts + 1):
        try:
            resp = await node.client.get("/metrics", headers=headers)
            resp.raise_for_status()
            break
        except (httpx.RequestError, httpx.HTTPStatusError) as exc:
            logger.warning(
                "Metrics attempt %s failed for prefiller %s: %s",
                attempt,
                prefiller_key,
                exc,
            )
            if attempt >= attempts:
                logger.error(
                    "Metrics exhausted retries (%s) on prefiller %s",
                    max_retries,
                    prefiller_key,
                )
                raise
            await asyncio.sleep(retry_interval)
    payload = resp.content
    media_type = resp.headers.get(
        "content-type", "text/plain; version=0.0.4; charset=utf-8"
    )

    async def metrics_body():
        yield payload

    return StreamingResponse(metrics_body(), media_type=media_type)
@app.post("/v1/completions")
@with_cancellation
async def completions(request: Request):
    # Thin route: all proxy logic lives in _proxy_openai_request.
    api_path = "/v1/completions"
    return await _proxy_openai_request(api_path, request)
@app.post("/v1/chat/completions")
@with_cancellation
async def chat_completions(request: Request):
    # Thin route: all proxy logic lives in _proxy_openai_request.
    api_path = "/v1/chat/completions"
    return await _proxy_openai_request(api_path, request)
@app.get("/listEndPoints")
async def listEndPoints():
    """Return current prefiller/decoder endpoints and their approximate load."""
    # Before lifespan() has built the router, report an empty topology.
    if not router:
        return {
            "status": "initializing",
            "prefill_nodes": [],
            "decode_nodes": [],
        }

    def _describe(node):
        return {
            "endpoint": node.url,
            "host": node.host,
            "port": node.port,
            "active_tokens": node.active_tokens,
        }

    prefill_view = list(map(_describe, router.prefillers.values()))
    decode_view = list(map(_describe, router.decoders.values()))
    return {"status": "ok", "prefill_nodes": prefill_view, "decode_nodes": decode_view}
@app.get("/metrics")
async def metrics(request: Request):
    """Proxy /metrics to the prefiller that would currently be chosen next."""
    # Errors (including "no prefiller available") are logged and re-raised.
    try:
        target_key = router.next_metrics_prefiller_key()
        logger.debug("Forwarding metrics to prefiller %s", target_key)
        response = await _forward_metrics(
            target_key,
            request,
            max_retries=global_args.max_retries,
            retry_interval=global_args.retry_interval,
        )
    except Exception as exc:
        logger.error("Error in metrics proxy: %s", exc)
        raise
    return response
def _build_arg_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--prefiller-hosts",
type=str,
nargs="+",
default=["localhost"])
parser.add_argument("--prefiller-ports",
type=int,
nargs="+",
default=[8001])
parser.add_argument("--decoder-hosts",
type=str,
nargs="+",
default=["localhost"])
parser.add_argument("--decoder-ports",
type=int,
nargs="+",
default=[8002])
parser.add_argument(
"--prefiller-labels",
type=str,
default="app=prefill",
help="K8s label selector for prefiller pods "
"(comma-separated, AND semantics: a=b,c=d).",
)
parser.add_argument(
"--decoder-labels",
type=str,
default="app=decode",
help="K8s label selector for decoder pods "
"(comma-separated, AND semantics: a=b,c=d).",
)
parser.add_argument("--prefiller-container",
type=str,
default="prefill-engine")
parser.add_argument("--decoder-container",
type=str,
default="decode-engine")
parser.add_argument("--prefiller-port-name",
type=str,
default="prefill-port")
parser.add_argument("--decoder-port-name",
type=str,
default="decode-port")
parser.add_argument("--namespace", type=str, default="default")
parser.add_argument("--discovery-interval",
type=int,
default=10,
help="K8s service discovery interval (seconds).")
parser.add_argument(
"--max-retries",
type=int,
default=3,
help="Maximum retry attempts when calling backend.",
)
parser.add_argument(
"--retry-interval",
type=float,
default=0.001,
help="Fixed interval (seconds) to wait between retry attempts.",
)
return parser
def parse_args():
    """Parse CLI arguments and derive the static (host, port) endpoint lists.

    Adds ``prefiller_instances`` and ``decoder_instances`` (lists of
    ``(host, port)`` tuples) to the returned namespace.

    Raises:
        ValueError: if a role's host list and port list differ in length.
    """
    args = _build_arg_parser().parse_args()
    roles = ("prefiller", "decoder")
    for role in roles:
        hosts = getattr(args, f"{role}_hosts")
        ports = getattr(args, f"{role}_ports")
        if len(hosts) != len(ports):
            raise ValueError(
                f"Number of {role} hosts must match number of {role} ports")
    for role in roles:
        pairs = list(zip(getattr(args, f"{role}_hosts"), getattr(args, f"{role}_ports")))
        setattr(args, f"{role}_instances", pairs)
    return args
if __name__ == "__main__":
    # uvicorn is only needed when this module is executed as a script.
    import uvicorn

    # lifespan() and the handlers read this module-level global.
    global_args = parse_args()
    uvicorn.run(app, port=global_args.port, host=global_args.host)