import math
from abc import ABC, abstractmethod
from dataclasses import dataclass
@dataclass(frozen=True)
class StepDecision:
    """Immutable outcome of one scheduling decision for a simulated mixed step.

    Carries the prefill and decode concurrency chosen for that step.
    """

    # Prefill concurrency selected for the step.
    p_step: int
    # Decode concurrency selected for the step.
    d_step: int
@dataclass(frozen=True)
class SchedulerState:
    """Immutable snapshot of queue depths and token budget for one decision.

    A scheduler reads these values when choosing the work for the next step.
    """

    # Decode-side queue depth (presumably requests ready to decode; confirm
    # against the caller that builds the snapshot).
    ready_decode: int
    # Prefill-side queue depth (presumably requests still awaiting prefill).
    pending_prefill: int
    # Query length of one prefill chunk (presumably in tokens; confirm).
    chunk_query_len: int
    # Token budget for a single batched step.
    max_batched_tokens: int
class Scheduler(ABC):
    """Abstract policy that selects prefill/decode work for each simulation step.

    Concrete schedulers must implement both ``decide`` and ``step_latency``.
    """

    @abstractmethod
    def decide(self, state: SchedulerState) -> StepDecision:
        """Choose how many prefill and decode requests run in the next step.

        Args:
            state: Queue and token-budget snapshot to base the decision on.

        Returns:
            The prefill and decode concurrency selected for the step.
        """

    @abstractmethod
    def step_latency(self, prefill_latency: float, decode_latency: float) -> float:
        """Fold the modeled prefill and decode latencies into one step latency.

        Args:
            prefill_latency: Modeled latency of the step's prefill work.
            decode_latency: Modeled latency of the step's decode work.

        Returns:
            The latency attributed to the whole step.
        """
class DecodeFirstWithSlack(Scheduler):
    """Default scheduler: prioritize decode, then admit prefill with a small slack budget.

    Decode requests are admitted first, capped at ``max_batched_tokens``.
    Prefill chunks are then admitted while the decode count plus the prefill
    chunk lengths stays within ``floor(max_batched_tokens * slack_ratio)``.
    """

    # Multiplier applied to max_batched_tokens to get the mixed-step limit.
    # Kept as a class attribute so subclasses can override it.
    slack_ratio = 1.15

    def decide(self, state: SchedulerState) -> StepDecision:
        """Prefer ready decode work, then admit prefill within the slack-adjusted budget.

        Args:
            state: Queue and token-budget snapshot for this step.

        Returns:
            A ``StepDecision`` whose ``d_step`` and ``p_step`` are both
            non-negative. ``p_step`` is 0 when ``chunk_query_len`` is not
            positive.
        """
        limit = math.floor(state.max_batched_tokens * self.slack_ratio)

        # Clamp at zero, as p_step already is: a negative queue depth or
        # budget must not yield a negative decode count, which would also
        # inflate the prefill budget computed below.
        d_step = max(0, min(state.ready_decode, state.max_batched_tokens))

        # A non-positive chunk length cannot be divided into the budget.
        if state.chunk_query_len <= 0:
            return StepDecision(p_step=0, d_step=d_step)

        # Each admitted decode takes one unit of the limit; the remainder is
        # split into whole chunks. Integer floor division keeps this exact.
        chunks_that_fit = (limit - d_step) // state.chunk_query_len
        p_step = max(0, min(state.pending_prefill, chunks_that_fit))
        return StepDecision(p_step=p_step, d_step=d_step)

    def step_latency(self, prefill_latency: float, decode_latency: float) -> float:
        """Model mixed prefill/decode work as overlapped in one scheduling step.

        Args:
            prefill_latency: Modeled latency of the step's prefill work.
            decode_latency: Modeled latency of the step's decode work.

        Returns:
            The larger of the two latencies, since the work is treated as
            running concurrently.
        """
        return max(prefill_latency, decode_latency)