import itertools
from typing import List
from serving_cast.engine import Engine, EngineLoadBalancer
from serving_cast.request import Request, RequestState
import serving_cast.stime as stime
logger = stime.get_logger(__name__)
class Instance:
id_counter = itertools.count()
def __init__(self, instance_config):
self.id = next(self.id_counter)
num_devices = instance_config.num_devices_per_instance
dp_size = instance_config.parallel_config.dp_size
device_type = instance_config.device_type
if num_devices % dp_size != 0:
raise ValueError(
"In instance __init__, num_devices must be divisible by dp_size,but got num_devices = %d, dp_size = %d",
num_devices,
dp_size,
)
self.engines: List[Engine] = [
Engine(instance_config=instance_config, device_type=device_type, dp_rank=i) for i in range(dp_size)
]
self.load_balancer = EngineLoadBalancer(self.engines)
def handle(self, request: Request):
logger.debug("Instance %d handling %s", self.id, request)
if request.state not in [
RequestState.ARRIVES_SERVER,
RequestState.KVS_TRANSFERRING,
]:
raise ValueError(
"Instance.handle failed, request.state should be ARRIVES_SERVER or KVS_TRANSFERRING, but get %s",
request.state,
)
if request.state == RequestState.ARRIVES_SERVER:
request.state = RequestState.PREFILLING
else:
request.state = RequestState.DECODING
engine = self.load_balancer.select(request)
engine.handle(request)
def get_work_load(self):
return sum(engine.get_work_load() for engine in self.engines)
def shutdown(self):
for engine in self.engines:
engine.shutdown()
class InstanceLoadBalancer:
def __init__(self, instances: List[Instance]):
self.instances = instances
def select(self, request: Request) -> Instance:
work_loads = [instance.get_work_load() for instance in self.instances]
min_value = min(work_loads)
min_index = work_loads.index(min_value)
return self.instances[min_index]