import time
from enum import Enum
from pydantic import BaseModel, Field, PrivateAttr
import anyio
from motor.common.resources.instance import PDRole
from motor.coordinator.domain.scheduling_constraint import SchedulingConstraint
from motor.coordinator.tracer.tracing import TraceObj
from motor.coordinator.models.constants import OpenAIField
class RequestType(Enum):
    """Closed set of request protocol identifiers.

    Each member's value is the lowercase string form of its name; look a
    member up from that string with ``RequestType("openai")``.
    """

    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    TRITON = "triton"
    TGI = "tgi"
    VLLM = "vllm"
    MINDIE = "mindie"
class ReqState(Enum):
    """Lifecycle states a request can be in.

    The values are human-readable labels. Member order is the declaration
    order below and is what iteration over the enum yields.
    """

    # Initial state.
    ARRIVE = "Arrive"

    # NOTE(review): the E_/P_/D_ prefixes presumably correspond to the
    # E / P (prefill) / D (decode) roles -- confirm against PDRole.
    E_SCHEDULING = "E_Scheduling"
    E_ALLOCATED = "E_Allocated"

    P_SCHEDULING = "P_Scheduling"
    P_ALLOCATED = "P_Allocated"
    PREFILL_END = "Prefill End"
    FIRST_TOKEN_FINISH = "First Token Finish"

    D_SCHEDULING = "D_Scheduling"
    D_ALLOCATED = "D_Allocated"
    DECODE_END = "Decode End"

    # Failure / retry outcomes.
    INVALID = "Invalid"
    TIMEOUT = "Timeout"
    EXCEPTION = "Exception"
    RECOMPUTE = "Recompute"
class RequestInfo(BaseModel):
    """Per-request record: payload, routing metadata, state timeline and cancel scopes.

    ``status`` maps each state the request has entered to the wall-clock time
    (``time.time()``) at which it was entered; ``state`` is the latest one.
    Up to three ``anyio.CancelScope`` objects (one per P/D/E role) are kept as
    private attributes so the request can be cancelled as a whole.
    """

    req_id: str = Field(..., description="Request ID generated by RequestManager")
    req_data: dict = Field(..., description="Request json content")
    req_len: int = Field(..., description="Request body length")
    token_ids: list[int] | None = Field(
        default=None,
        description="Prompt token ids tokenized once at routing (KV affinity); reused for "
        "prefill load accounting so load and affinity share the same token unit",
    )
    kv_affinity_debug: dict | None = Field(
        default=None,
        exclude=True,
        description="Per-endpoint (matched_tokens, load_cost, prefill_cost) cached by the "
        "kv_cache_affinity policy at selection; the worker forwards prefill_cost for the "
        "scheduler's global fresh-load re-rank and logs matched/load for the committed endpoint. "
        "Keyed by (instance_id, endpoint_id) tuples, so excluded from serialization.",
    )
    api: str = Field(..., description="API need to be forwarded")
    entry_api: str = Field(
        default="",
        description="Original client HTTP path (e.g. v1/chat/completions); unchanged on recompute retry",
    )
    recompute_engine_mode: str | None = Field(
        default=None,
        description="Set to 'completions' when recompute retry uses v1/completions body",
    )
    client_expects_token_ids: bool = Field(
        default=False,
        description="True if the original client request contains a 'return_token_ids' parameter",
    )
    client_expects_chat_shape: bool = Field(
        default=False,
        description="True if the original client request was chat completions (messages present at ingress)",
    )
    state: ReqState = Field(default=ReqState.ARRIVE, description="Request current status")
    # Mutable defaults use default_factory (as trace_obj already does) so each
    # instance visibly gets its own container.
    status: dict[ReqState, float] = Field(default_factory=dict, description="Request status time")
    trace_obj: TraceObj = Field(default_factory=TraceObj, description="Tracing object")

    # One cancel scope per role; populated through set_cancel_scope().
    _p_cancel_scope: anyio.CancelScope | None = PrivateAttr(default=None)
    _d_cancel_scope: anyio.CancelScope | None = PrivateAttr(default=None)
    _e_cancel_scope: anyio.CancelScope | None = PrivateAttr(default=None)

    prompt_tokens_details: dict = Field(default_factory=dict, description="prefill prompt_tokens_details")
    prompt_token_ids: list = Field(default_factory=list, description="prefill prompt_token_ids")
    cached_token_ids: list = Field(default_factory=list, description="Cached token_ids")
    p_instance_id: int | None = Field(
        default=None, description="P instance ID set by metaserver handler in CDP-like modes"
    )
    scheduling_constraint: SchedulingConstraint | None = Field(
        default=None,
        description="Internal pin-to-instance constraint (e.g. precision probe); not from client API",
    )

    def __init__(self, **data):
        """Validate ``data`` and stamp the arrival time.

        The ``ReqState.ARRIVE`` entry of ``status`` is always set to the
        current time, replacing any value passed in through ``data``.
        """
        super().__init__(**data)
        self.status[ReqState.ARRIVE] = time.time()

    @property
    def is_cancelled(self) -> bool:
        """Return True if any registered cancel scope has had ``cancel()`` called.

        Always returns a real ``bool``; with no scope registered the result is
        ``False`` rather than ``None``.
        """
        scopes = (self._p_cancel_scope, self._d_cancel_scope, self._e_cancel_scope)
        return any(scope is not None and scope.cancel_called for scope in scopes)

    def effective_entry_api(self) -> str:
        """Path used for client-contract checks (Chat vs Completion); falls back to ``api``."""
        return self.entry_api or self.api

    def update_state(self, new_state: ReqState) -> None:
        """Move the request to ``new_state`` and record the time of the transition."""
        self.state = new_state
        self.status[new_state] = time.time()

    def update_prompt_tokens_details(self, prompt_tokens_details: dict) -> None:
        """Replace the stored ``prompt_tokens_details`` with the given dict."""
        self.prompt_tokens_details = prompt_tokens_details

    def set_cancel_scope(self, cancel_scope: anyio.CancelScope, role: PDRole) -> None:
        """Register ``cancel_scope`` for ``role``.

        Only ``ROLE_P``, ``ROLE_D`` and ``ROLE_E`` are stored; any other role
        is ignored.
        """
        if role == PDRole.ROLE_P:
            self._p_cancel_scope = cancel_scope
        elif role == PDRole.ROLE_D:
            self._d_cancel_scope = cancel_scope
        elif role == PDRole.ROLE_E:
            self._e_cancel_scope = cancel_scope

    def cancel_scope(self) -> None:
        """Cancel every registered scope that has not been cancelled yet (P, then D, then E)."""
        for scope in (self._p_cancel_scope, self._d_cancel_scope, self._e_cancel_scope):
            if scope and not scope.cancel_called:
                scope.cancel()

    def update_token_id_cache(self, chunk_json: dict) -> None:
        """Accumulate the ``return_token_ids`` fields of one response chunk.

        - Root ``prompt_token_ids``: stored in ``self.prompt_token_ids`` once,
          i.e. only while that list is still empty.
        - ``choices[0].prompt_token_ids`` (Completion stream): used instead
          when the root field is absent or null.
        - ``choices[0].token_ids``: appended to ``self.cached_token_ids`` when
          it is a list.

        A missing, empty or non-list ``choices``, or a first choice that is
        not a dict, is skipped rather than raising.
        """
        choices = chunk_json.get(OpenAIField.CHOICES) or []
        first_choice = None
        if isinstance(choices, (list, tuple)) and choices and isinstance(choices[0], dict):
            first_choice = choices[0]

        prompt_ids = chunk_json.get(OpenAIField.PROMPT_TOKEN_IDS)
        if prompt_ids is None and first_choice is not None:
            prompt_ids = first_choice.get(OpenAIField.PROMPT_TOKEN_IDS)
        # Prompt ids are captured only once; later chunks never overwrite them.
        if isinstance(prompt_ids, (list, tuple)) and not self.prompt_token_ids:
            self.prompt_token_ids = list(prompt_ids)

        if first_choice is None:
            return
        output_ids = first_choice.get(OpenAIField.TOKEN_IDS)
        if isinstance(output_ids, list):
            self.cached_token_ids.extend(output_ids)