from abc import ABC, abstractmethod
from typing import List
from serving_cast.config import Config
from serving_cast.instance import Instance, InstanceLoadBalancer
from serving_cast.request import Request, RequestState
import serving_cast.stime as stime
# Module-level logger shared by every serving class in this file.
# NOTE(review): obtained through the project's `stime` helper rather than `logging.getLogger`;
# presumably this ties log records to simulated time -- confirm against serving_cast.stime.
logger = stime.get_logger(__name__)
class Serving(ABC):
    """
    Abstract base for everything that serves inference requests.

    A request reaches a serving object either straight from the client (an initial request) or from
    another server instance -- e.g. a Prefill instance that finished prefill and hands the request
    over to a Decode instance. The serving object chooses which server instance receives the
    request, following a pre-defined policy.
    """

    def __init__(self):
        # The concurrency cap is read once, at construction, from the global configuration.
        serving_config = Config.get_instance().common_config.serving_config
        self.max_concurrency = serving_config.max_concurrency

    @abstractmethod
    def serve(self, args, **kwargs) -> None:
        """
        Serve one request; concrete subclasses provide the dispatch logic.
        """
        raise NotImplementedError

    @abstractmethod
    def get_work_load(self) -> int:
        """
        Report how many requests are being served right now.
        """
        raise NotImplementedError

    def exceed_concurrency_limit(self) -> bool:
        """
        Tell whether the current work load has reached ``max_concurrency``.

        The comparison is inclusive: a work load equal to the limit already counts as exceeded.
        """
        current_load = self.get_work_load()
        return current_load >= self.max_concurrency

    def _before_serve(self, request: Request):
        """
        Common entry step for every serving flavour: LEAVES_CLIENT --> ARRIVES_SERVER.

        Raises:
            ValueError: if the request is not in the LEAVES_CLIENT state.
        """
        current_state = request.state
        if current_state != RequestState.LEAVES_CLIENT:
            raise ValueError("request.state != RequestState.LEAVES_CLIENT")

        request.state = RequestState.ARRIVES_SERVER
        logger.debug("Start serving %s", request)
class PdDisaggregationServing(Serving):
    """
    Serving for the P/D disaggregation case.

    The overall request serving flow looks like below:
    Requests are first dispatched to a prefill server instance, then the instance dispatches the requests to
    an Engine which corresponds to a Data-Parallel partition. Then the Engine schedules the incoming requests.
    After a request has finished prefilling, it is sent to a decode server instance, which handles it in the
    same way as the prefill server instance does.
    """

    def __init__(self, prefill_instances: List[Instance], decode_instances: List[Instance]):
        """
        Build one load balancer per instance group.

        Args:
            prefill_instances: instances that receive requests coming from the client.
            decode_instances: instances that receive requests once the continue-serving callback fires.
        """
        super().__init__()
        self.prefill_instances = prefill_instances
        self.decode_instances = decode_instances
        self.prefill_balancer = InstanceLoadBalancer(prefill_instances)
        self.decode_balancer = InstanceLoadBalancer(decode_instances)

    def serve(self, request: Request) -> None:
        """
        Handle a request coming from the client side.

        Marks the request as needing a KV transfer, registers ``_continue_serve_callback`` on the
        request's ``kvs_transferring_signal`` and hands the request to the prefill instance chosen
        by the prefill balancer.

        Raises:
            ValueError: if the request is not in the LEAVES_CLIENT state (raised by ``_before_serve``).
        """
        self._before_serve(request)
        request.need_kv_transfer = True
        # The callback must be registered before the request is handed over, so that the second
        # (decode) leg is wired up by the time the signal is emitted.
        # NOTE(review): presumably the signal is emitted when prefill is done -- confirm in Request.
        request.kvs_transferring_signal.connect(self._continue_serve_callback)
        prefill_instance = self.prefill_balancer.select(request)
        prefill_instance.handle(request)

    def get_work_load(self) -> int:
        """
        Return the summed work load of all prefill instances and all decode instances.
        """
        work_load = sum(instance.get_work_load() for instance in self.prefill_instances) + sum(
            instance.get_work_load() for instance in self.decode_instances
        )
        return work_load

    def _continue_serve_callback(self, request: Request) -> None:
        """
        Continue serving: dispatch the request to a decode instance picked by the decode balancer.

        Raises:
            ValueError: if the request is not in the KVS_TRANSFERRING state.
        """
        logger.debug("Continue serving %s", request)
        if request.state != RequestState.KVS_TRANSFERRING:
            raise ValueError(
                f"In continue serving: request.state should be KVS_TRANSFERRING, but got {request.state}"
            )
        decode_instance = self.decode_balancer.select(request)
        decode_instance.handle(request)
class PdAggregationServing(Serving):
    """
    Serving for the P/D aggregation case.

    Flow of a request:
    it is first dispatched to a server instance, and that instance dispatches it to an Engine,
    which corresponds to a Data-Parallel partition.
    The Engine then schedules incoming requests into the waiting queue or the running queue.
    Once requests are in the running queue, the ModelRunner starts executing them.
    """

    def __init__(self, prefill_decode_instances: List[Instance]):
        super().__init__()
        self.prefill_decode_instances = prefill_decode_instances
        # A single balancer covers the whole group of instances.
        self.prefill_decode_balancer = InstanceLoadBalancer(prefill_decode_instances)

    def serve(self, request: Request):
        """Accept a client request and pass it to the instance chosen by the balancer."""
        self._before_serve(request)
        self.prefill_decode_balancer.select(request).handle(request)

    def get_work_load(self):
        """Return the combined work load of every instance in the group."""
        total = 0
        for member in self.prefill_decode_instances:
            total = total + member.get_work_load()
        return total