import subprocess
import time
from pathlib import Path
from typing import Tuple, Optional
from math import isclose
import numpy as np
from loguru import logger
from ..common import get_train_sub_path, is_mindie, is_vllm
from ..config.base_config import REAL_EVALUATION
from ..config.config import (
get_settings,
DecodeContext,
PerformanceIndex,
OptimizerConfigField,
map_param_with_value,
ErrorSeverity,
)
from ..config.base_config import FOLDER_LIMIT_SIZE, REQUESTRATES
from ..config.constant import Stage
from ..optimizer.plugins.simulate import Simulator
from ..optimizer.store import DataStorage
from ..optimizer.utils import get_folder_size
from ..optimizer.health_check import (
ErrorContext,
BenchmarkHealthCheckHook,
ServiceHealthCheckHook,
ServiceHookPoint,
BenchmarkHookPoint,
HealthCheckContext,
service_health_checks_hooks,
benchmark_health_checks_hooks,
FatalError,
RetryableError,
)
class Scheduler:
def __init__(
self,
simulator,
benchmark,
data_storage: DataStorage,
bak_path: Optional[Path] = None,
retry_number: int = 3,
wait_start_time: Optional[int] = None,
):
self.simulator = simulator
self.benchmark = benchmark
self.data_storage = data_storage
self.bak_path = bak_path
self.retry_number = retry_number
self.wait_time = wait_start_time or get_settings().wait_start_time
self.current_back_path = None
self.simulate_run_info = None
self.performance_index = None
self.error_info = None
self.run_start_timestamp = None
self.first_duration = None
self.del_log = None
self.service_checks = ServiceHealthCheckHook()
self.benchmark_checks = BenchmarkHealthCheckHook()
self._register_default_checks()
def _register_default_checks(self):
"""Register default health checks (can be overridden by subclasses)"""
for name, func, priority in service_health_checks_hooks:
self.service_checks.register(name, func, priority=priority)
for name, func, priority in benchmark_health_checks_hooks:
self.benchmark_checks.register(name, func, priority=priority)
def _create_check_context(self, elapsed: float) -> HealthCheckContext:
"""Create check context"""
return HealthCheckContext(
simulator=self.simulator,
benchmark=self.benchmark,
scheduler=self,
current_time=time.time(),
elapsed_time=elapsed,
)
def _handle_error(self, error_context: ErrorContext) -> None:
"""Raise different exceptions based on error type"""
if error_context.severity == ErrorSeverity.FATAL:
logger.error(f"Fatal error: {error_context.message}")
raise FatalError(error_context.message)
else:
logger.warning(f"Retryable error: {error_context.message}")
raise RetryableError(error_context.message)
def set_back_up_path(self):
if self.bak_path:
if get_folder_size(self.bak_path) > FOLDER_LIMIT_SIZE:
self.simulator.bak_path = None
self.benchmark.bak_path = None
else:
self.current_back_path = get_train_sub_path(self.bak_path)
self.simulator.bak_path = self.current_back_path
self.benchmark.bak_path = self.current_back_path
def wait_simulate(self):
logger.debug("wait run simulator")
start_time = time.time()
for _ in range(self.wait_time):
time.sleep(1)
elapsed = time.time() - start_time
context = self._create_check_context(elapsed)
result = self.service_checks.run(ServiceHookPoint.STARTUP_POLLING, context)
if result.is_healthy:
if hasattr(self.simulator, "check_success") and self.simulator.check_success():
logger.info(f"Successfully started the {self.simulator.process} process.")
return
elif hasattr(self.simulator, "health"):
res = self.simulator.health()
if res.stage == Stage.running:
logger.info(f"Successfully started the {self.simulator.process} process.")
return
elif res.stage == Stage.start:
if int(elapsed) % 60 == 0:
logger.warning(
f"Check the service status at {elapsed} seconds. status: {res.stage}. info: {res.info}"
)
continue
elif res.stage in (Stage.stop, Stage.error):
logger.error(f"Failed to start the {self.simulator.process} process.")
raise subprocess.SubprocessError(
f"Failed in run simulator. status: {res.stage}. info: {res.info}"
)
else:
logger.warning(f" Unknown Status. status: {res.stage}. info: {res.info}")
else:
raise RuntimeError(
f"No actionable method found. the expected is check_success or health. "
f"simulator: {type(self.simulator)}"
)
else:
self._handle_error(result.error_context)
raise TimeoutError(self.wait_time)
def run_simulate(self, params: np.ndarray, params_field: Tuple[OptimizerConfigField]):
if hasattr(self.benchmark, "prepare"):
self.benchmark.prepare()
self.simulator.run(tuple(self.simulate_run_info))
self.wait_simulate()
def backup(self):
self.simulator.backup()
self.benchmark.backup()
def monitoring_status(self):
logger.debug("monitor status")
start_time = time.time()
for _ in range(get_settings().particles_time_out):
elapsed = time.time() - start_time
context = self._create_check_context(elapsed)
service_result = self.service_checks.run(ServiceHookPoint.RUNTIME_MONITOR, context)
if not service_result.is_healthy:
self._handle_error(service_result.error_context)
benchmark_result = self.benchmark_checks.run(BenchmarkHookPoint.RUNTIME_MONITOR, context)
if not benchmark_result.is_healthy:
self._handle_error(benchmark_result.error_context)
if hasattr(self.simulator, "check_success"):
if is_mindie() or is_vllm():
if self.simulator.process.poll() is not None:
raise subprocess.SubprocessError(
f"Failed in run simulator. return code: {self.simulator.process.returncode}."
)
if self.benchmark.check_success():
return
if hasattr(self.simulator, "health"):
if not isinstance(self.simulator, Simulator):
res = self.simulator.health()
if res.stage != Stage.running:
raise subprocess.SubprocessError(
f"Failed in run simulator. error: {res.stage} info: {res.info}."
)
res = self.benchmark.health()
if res.stage != Stage.running:
return
if self.run_start_timestamp and self.first_duration:
_duration = time.time() - self.run_start_timestamp
if _duration > 2 * self.first_duration:
logger.warning("The current runtime is more than twice the duration of the first run.")
time.sleep(1)
raise TimeoutError(f"{get_settings().particles_time_out}")
def run_target_server(self, params: np.ndarray, params_field: Tuple[OptimizerConfigField]):
"""
1. Start mindie simulation
2. Start benchmark test
3. Check mindie status, check benchmark status
"""
for attempt in range(self.retry_number):
try:
self.run_simulate(params, params_field)
time.sleep(1)
self.benchmark.run(tuple(self.simulate_run_info))
time.sleep(1)
self.monitoring_status()
return
except FatalError as e:
logger.error(
f"Fatal error in run_target_server (attempt {attempt + 1}/{self.retry_number}): {e}, \n"
f"simulator log: {self.simulator.run_log}, \n"
f"log last info: {self.simulator.get_last_log()}"
)
self.stop_target_server(False)
raise
except RetryableError as e:
logger.warning(
f"Retryable error in run_target_server (attempt {attempt + 1}/{self.retry_number}): {e}, \n"
f"simulator log: {self.simulator.run_log}, \n"
f"log last info: {self.simulator.get_last_log()}"
)
self.stop_target_server(False)
continue
raise ValueError(f"Failed in run_target_server after {self.retry_number} attempts")
def stop_target_server(self, del_log: bool = False):
self.simulator.stop(del_log)
self.benchmark.stop(del_log)
def save_result(self, **kwargs):
duration = None
if self.run_start_timestamp:
duration = time.time() - self.run_start_timestamp
if not self.first_duration:
self.first_duration = duration
real_evaluation = True
if REAL_EVALUATION in kwargs:
real_evaluation = kwargs.pop(REAL_EVALUATION)
self.data_storage.save(
self.performance_index,
tuple(self.simulate_run_info),
error=self.error_info,
backup=self.current_back_path,
duration=duration,
real_evaluation=real_evaluation,
**kwargs,
)
if self.bak_path:
self.backup()
self.stop_target_server()
def update_data_field(self, params_field: Tuple[OptimizerConfigField]):
if hasattr(self.simulator, "data_field"):
self.simulator.data_field = params_field
if hasattr(self.simulator, "update_command"):
self.simulator.update_command()
if hasattr(self.benchmark, "data_field"):
self.benchmark.data_field = params_field
if hasattr(self.benchmark, "update_command"):
self.benchmark.update_command()
def run(
self,
params: np.ndarray,
params_field: Tuple[OptimizerConfigField],
decode_context: Optional[DecodeContext] = None,
) -> PerformanceIndex:
"""
1. Start mindie simulation
2. Start benchmark test
3. Get benchmark test results
4. Stop mindie simulation
5. Return benchmark test results
params: 1D array whose values correspond to mindie related configurations.
"""
self.run_start_timestamp = time.time()
logger.debug("Start run in scheduler.")
self.set_back_up_path()
self.simulate_run_info = map_param_with_value(params, params_field, decode_context)
logger.info("run param info {}", {v.name: v.value for v in self.simulate_run_info})
self.error_info = None
self.del_log = True
self.performance_index = PerformanceIndex()
try:
self.update_data_field(self.simulate_run_info)
self.run_target_server(params, params_field)
time.sleep(1)
self.performance_index = self.benchmark.get_performance_index()
except Exception as e:
logger.error(
f"Failed running. bak path: {self.simulator.bak_path}. error {e}"
f"simulator log {self.simulator.run_log}, benchmark log {self.benchmark.run_log}"
)
self.error_info = e
self.del_log = False
return self.performance_index
def run_with_request_rate(
self,
params: np.ndarray,
params_field: Tuple[OptimizerConfigField],
decode_context: Optional[DecodeContext] = None,
) -> PerformanceIndex:
"""
Run the service: first run at max concurrency to get request rate,
then run based on concurrency and request rate, using the last run as the evaluation result.
params: 1D array whose values correspond to mindie related configurations.
"""
self.run_start_timestamp = time.time()
self.set_back_up_path()
self.simulate_run_info = map_param_with_value(params, params_field, decode_context)
logger.info("run param info {}", {v.name: v.value for v in self.simulate_run_info})
self.error_info = None
self.del_log = True
self.performance_index = PerformanceIndex()
try:
self.update_data_field(self.simulate_run_info)
self.run_target_server(params, params_field)
time.sleep(1)
self.performance_index = self.benchmark.get_performance_index()
self.benchmark.stop()
need_second_run = False
for _field in self.simulate_run_info:
if _field.name in REQUESTRATES:
if not isclose(_field.min, _field.max):
_field.value = _field.find_available_value(self.performance_index.throughput * 1.05)
need_second_run = True
if not need_second_run:
logger.info("REQUESTRATE is fixed (min == max), skipping second run.")
else:
logger.info(
"second run param info {}",
{v.name: v.value for v in self.simulate_run_info},
)
if hasattr(self.benchmark, "data_field"):
self.benchmark.data_field = params_field
self.benchmark.update_command()
try:
if hasattr(self.benchmark, "prepare"):
self.benchmark.prepare()
self.benchmark.run(tuple(self.simulate_run_info))
except Exception as e:
logger.error(f"Failed in Benchmark Running. error: {e}, benchmark log {self.benchmark.run_log}")
raise e
try:
self.monitoring_status()
except Exception as e:
logger.error(
f"Failed in monitoring status. error: {e}, simulator log {self.simulator.run_log}, "
f"benchmark log {self.benchmark.run_log}"
)
raise e
time.sleep(1)
self.performance_index = self.benchmark.get_performance_index()
except Exception as e:
logger.error(
f"Failed running. bak path: {self.simulator.bak_path}. error {e}"
f"simulator log {self.simulator.run_log}, benchmark log {self.benchmark.run_log}"
)
self.error_info = e
self.del_log = False
return self.performance_index