import itertools
from enum import auto, Enum
from typing import Optional
from blinker import signal
import serving_cast.stime as stime
class RequestState(Enum):
INITIAL = auto()
LEAVES_CLIENT = auto()
ARRIVES_SERVER = auto()
PREFILLING = auto()
PREFILL_DONE = auto()
RECOMPUTATION = auto()
KVS_TRANSFERRING = auto()
DECODING = auto()
DECODE_DONE = auto()
COMPLETED = DECODE_DONE
class Request:
id_counter = itertools.count()
def __init__(self, **kwargs):
super().__init__()
given_id = kwargs.get("id")
if given_id is not None:
if isinstance(given_id, int):
self.id = given_id
else:
raise ValueError("Request.__init__ failed: given id should be int")
else:
self.id = next(self.id_counter)
self.model_name: Optional[str] = kwargs.get("model_name")
self.num_input_tokens: int = kwargs.get("num_input_tokens", 0)
self.num_output_tokens: int = kwargs.get("num_output_tokens", 0)
self._state: RequestState = RequestState.INITIAL
self.state_change_signal = signal(f"state_changed_{self.id}")
self.before_prefill_done_signal = signal(f"before_prefill_done_{self.id}")
self.kvs_transferring_signal = signal(f"kvs_transferring_{self.id}")
self.prefill_done_signal = signal(f"prefill_done_{self.id}")
self.decode_done_signal = signal(f"decode_done_{self.id}")
self.num_decoded_tokens: int = 0
self.num_current_max_new_tokens = self.num_input_tokens
self.seq_len = 0
self.query_len = 0
self.leaves_client_time = 0
self.arrives_server_time = 0
self.prefill_done_time = 0
self.decode_done_time = 0
self.prefill_done_time_already_recorded = False
self.decode_done_time_already_recorded = False
self.need_kv_transfer = False
self.kv_transfer_done = False
def __str__(self) -> str:
ttft = ""
tpot = ""
total = ""
if self.state.value >= RequestState.PREFILL_DONE.value:
ttft = f", ttft={self.time_to_first_token():.3f}"
if self.state.value >= RequestState.DECODE_DONE.value:
tpot = f", tpot={self.time_per_output_token():.3f}"
total = f", total={self.serving_time():.3f}"
res = (
f"Request(id={self.id}, model_name={self.model_name}, state={self.state}{ttft}{tpot}{total}, "
f"num_decoded={self.num_decoded_tokens},num_inputs={self.num_input_tokens}, "
f"num_outputs={self.num_output_tokens})"
)
return res
@property
def state(self):
return self._state
@state.setter
def state(self, new_state):
old_state = self._state
if new_state == RequestState.PREFILL_DONE:
self.before_prefill_done_signal.send(self)
self._state = new_state
self.state_change_signal.send(self, old_state=old_state, new_state=new_state)
if new_state == RequestState.DECODE_DONE:
if not self.decode_done_time_already_recorded:
self.decode_done_time = stime.now()
self.decode_done_time_already_recorded = True
self.decode_done_signal.send(self)
elif new_state == RequestState.PREFILL_DONE:
if not self.prefill_done_time_already_recorded:
self.prefill_done_time = stime.now()
self.prefill_done_time_already_recorded = True
self.prefill_done_signal.send(self)
elif new_state == RequestState.ARRIVES_SERVER:
self.arrives_server_time = stime.now()
elif new_state == RequestState.LEAVES_CLIENT:
self.leaves_client_time = stime.now()
elif new_state == RequestState.KVS_TRANSFERRING:
self.kvs_transferring_time = stime.now()
self.kvs_transferring_signal.send(self)
def time_to_first_token(self):
return self.prefill_done_time - self.leaves_client_time
def time_per_output_token(self):
if self.num_output_tokens == 1:
return 0
return (self.decode_done_time - self.prefill_done_time) / (self.num_output_tokens - 1)
def serving_time(self):
return self.decode_done_time - self.leaves_client_time