from __future__ import annotations
import logging
import time
from dataclasses import dataclass
import torch
from executor.online.kv_transfer import AscendKVReceiver, KVPoll
from executor.core.forward_data_info import MTPInfo
logger = logging.getLogger(__name__)
@dataclass
class DecodeRequest:
    """Per-request bookkeeping record shared by the decode-side queues.

    Wraps the scheduler request together with the state the prealloc and
    transfer queues attach to it while its KV cache is being received.
    """

    # The wrapped request object; the queues read attributes such as
    # ``bootstrap_room``, ``bootstrap_addr`` and ``request_id`` from it.
    req: object
    # KV receiver created for this request (None until one is attached).
    kv_receiver: object | None = None
    # Set once the receiver has been polled as ``KVPoll.WaitingForInput``.
    waiting_for_input: bool = False
    # Slot index in the metadata pool; -1 means no slot is currently held.
    metadata_buffer_index: int = -1
class DecodePreallocQueue:
    """Staging queue for decode requests awaiting handshake and local resources.

    Flow implemented by this class:

    1. ``add`` wraps a request in a ``DecodeRequest`` with a fresh
       ``AscendKVReceiver`` and appends it to ``queue``. If the prefill dp
       rank or the prefill parallel info is not yet known, the request is
       additionally parked in ``pending_reqs``.
    2. ``_resolve_pending_reqs`` retries the parallel-info lookup and the
       dp-rank query (bounded by ``MAX_RETRIES``) and initialises the
       receiver once both are known.
    3. ``pop_preallocated`` reserves KV-cache slots and a metadata slot for
       every request whose receiver reports ``KVPoll.WaitingForInput``,
       sends the metadata and hands the request to ``transfer_queue``.
    """

    # Number of failed attempts tolerated before the affected requests are
    # aborted. Applied per prefill address for the parallel-info lookup and
    # per bootstrap room for the dp-rank query.
    MAX_RETRIES = 15
    # Minimum number of seconds between two parallel-info attempts against
    # the same prefill address.
    RETRY_INTERVAL_S = 1.0

    def __init__(
        self,
        kv_transfer_manager,
        kv_cache_manager,
        metadata_pool,
        transfer_queue,
        running_requests,
        num_reserved_decode_tokens: int,
        max_prefill_tokens: int,
        tp_cpu_group=None,
    ):
        """Store collaborators and initialise empty queues and retry state.

        Args:
            kv_transfer_manager: Provides ``try_ensure_parallel_info``,
                ``prefill_info_table`` and ``query_prefill_dp_ranks``.
            kv_cache_manager: Provides ``allocate_slots`` and
                ``get_block_ids``.
            metadata_pool: Provides ``available_size`` and ``alloc``.
            transfer_queue: Receives preallocated requests via ``add`` and
                terminally failed ones via its ``terminal_failed`` list.
            running_requests: Stored as-is; not read by this class.
            num_reserved_decode_tokens: Forwarded as ``reserved_tokens`` to
                ``allocate_slots``.
            max_prefill_tokens: Requests with more prompt tokens than this
                are dropped with finish reason ``"prompt_too_long"``.
            tp_cpu_group: Group passed to ``poll_and_all_reduce``.
        """
        self.kv_transfer_manager = kv_transfer_manager
        self.kv_cache_manager = kv_cache_manager
        self.metadata_pool = metadata_pool
        self.transfer_queue = transfer_queue
        self.running_requests = running_requests
        self.num_reserved_decode_tokens = num_reserved_decode_tokens
        self.max_prefill_tokens = max_prefill_tokens
        self.tp_cpu_group = tp_cpu_group

        # Every request that has not yet been preallocated or dropped.
        self.queue: list[DecodeRequest] = []
        # Subset of ``queue`` whose receiver has not been initialised yet.
        self.pending_reqs: list[DecodeRequest] = []

        # Retry bookkeeping for the parallel-info lookup, keyed by address.
        self._ensure_retry_count: dict[str, int] = {}
        self._ensure_last_attempt_time: dict[str, float] = {}
        # Retry bookkeeping for the dp-rank query, keyed by bootstrap room.
        self._query_retry_count: dict[int, int] = {}

    def add(self, req) -> None:
        """Enqueue ``req``; initialise its receiver now if possible.

        The receiver is initialised immediately only when the prefill dp
        rank is already resolvable and the prefill parallel info is
        available. Otherwise the request is parked in ``pending_reqs``.
        """
        decode_req = self._create_receiver_and_enqueue(req)
        prefill_dp_rank = self._resolve_prefill_dp_rank(req)
        # Short-circuit keeps the original call pattern: the parallel-info
        # lookup is only attempted when the rank is already known.
        if prefill_dp_rank is not None and self.kv_transfer_manager.try_ensure_parallel_info(
            req.bootstrap_addr
        ):
            decode_req.kv_receiver.init(prefill_dp_rank)
            return
        self.pending_reqs.append(decode_req)

    def _resolve_prefill_dp_rank(self, req):
        """Return the prefill dp rank for ``req`` if known locally, else None.

        Uses ``req.disagg_prefill_dp_rank`` when it is a non-negative value;
        otherwise falls back to rank 0 when the cached prefill info for the
        request's bootstrap address reports ``dp_size == 1``.
        """
        explicit_rank = req.disagg_prefill_dp_rank
        if explicit_rank is not None and explicit_rank >= 0:
            return explicit_rank

        prefill_info = self.kv_transfer_manager.prefill_info_table.get(req.bootstrap_addr)
        if prefill_info is not None and prefill_info.dp_size == 1:
            # With a single dp rank there is nothing to query.
            return 0
        return None

    def _create_receiver_and_enqueue(self, req) -> DecodeRequest:
        """Create a receiver for ``req``, wrap both and append to ``queue``."""
        kv_receiver = AscendKVReceiver(
            req.bootstrap_room, self.kv_transfer_manager, req.bootstrap_addr
        )
        decode_req = DecodeRequest(req=req, kv_receiver=kv_receiver)
        self.queue.append(decode_req)
        return decode_req

    def _ensure_prefill_info(self, addr_to_reqs: dict[str, list[DecodeRequest]]):
        """Try to obtain prefill parallel info for each bootstrap address.

        Attempts for one address are spaced at least ``RETRY_INTERVAL_S``
        apart. After ``MAX_RETRIES`` failed attempts every request grouped
        under that address is aborted via its receiver, and the retry state
        of the address is discarded so that later requests start with a
        fresh retry budget and the bookkeeping dicts do not grow forever.

        Args:
            addr_to_reqs: Pending requests grouped by bootstrap address.

        Returns:
            ``(ready, remaining)``: ``ready`` maps each address whose info
            is now available to its requests; ``remaining`` lists the
            requests that must stay pending.
        """
        ready: dict[str, list[DecodeRequest]] = {}
        remaining: list[DecodeRequest] = []
        now = time.monotonic()

        for bootstrap_addr, decode_reqs in addr_to_reqs.items():
            last_attempt = self._ensure_last_attempt_time.get(bootstrap_addr)
            if last_attempt is not None and now - last_attempt < self.RETRY_INTERVAL_S:
                # Throttled: too soon since the previous attempt.
                remaining.extend(decode_reqs)
                continue

            self._ensure_last_attempt_time[bootstrap_addr] = now
            if self.kv_transfer_manager.try_ensure_parallel_info(bootstrap_addr):
                self._forget_ensure_state(bootstrap_addr)
                ready[bootstrap_addr] = decode_reqs
                continue

            count = self._ensure_retry_count.get(bootstrap_addr, 0) + 1
            if count < self.MAX_RETRIES:
                self._ensure_retry_count[bootstrap_addr] = count
                remaining.extend(decode_reqs)
                continue

            logger.warning(
                "prefill %s parallel-info unreachable after %d retries; "
                "aborting %d pending request(s)",
                bootstrap_addr, self.MAX_RETRIES, len(decode_reqs),
            )
            for decode_req in decode_reqs:
                decode_req.kv_receiver.abort()
            # Drop the stale counter/timestamp; keeping them would abort any
            # later request to this address after a single failed attempt.
            self._forget_ensure_state(bootstrap_addr)

        return ready, remaining

    def _forget_ensure_state(self, bootstrap_addr: str) -> None:
        """Discard the parallel-info retry bookkeeping for one address."""
        self._ensure_retry_count.pop(bootstrap_addr, None)
        self._ensure_last_attempt_time.pop(bootstrap_addr, None)

    def _resolve_pending_reqs(self) -> None:
        """Initialise receivers of pending requests whose dp rank is known.

        Requests are grouped by bootstrap address. For addresses whose
        parallel info is available, the dp rank is taken from local
        information or, failing that, queried from the transfer manager in
        one batch per address. A room whose query keeps failing is aborted
        after ``MAX_RETRIES`` attempts and its retry counter is removed.
        """
        if not self.pending_reqs:
            return

        addr_to_reqs: dict[str, list[DecodeRequest]] = {}
        for decode_req in self.pending_reqs:
            addr_to_reqs.setdefault(decode_req.req.bootstrap_addr, []).append(decode_req)

        ready_addrs, remaining = self._ensure_prefill_info(addr_to_reqs)

        resolved: list[tuple[DecodeRequest, int]] = []
        for bootstrap_addr, decode_reqs in ready_addrs.items():
            need_query: list[DecodeRequest] = []
            for decode_req in decode_reqs:
                prefill_dp_rank = self._resolve_prefill_dp_rank(decode_req.req)
                if prefill_dp_rank is not None:
                    resolved.append((decode_req, prefill_dp_rank))
                else:
                    need_query.append(decode_req)

            if not need_query:
                continue

            rooms = [decode_req.req.bootstrap_room for decode_req in need_query]
            room_to_rank = self.kv_transfer_manager.query_prefill_dp_ranks(bootstrap_addr, rooms)
            for decode_req in need_query:
                room = decode_req.req.bootstrap_room
                prefill_dp_rank = room_to_rank.get(room)
                if prefill_dp_rank is not None and int(prefill_dp_rank) >= 0:
                    self._query_retry_count.pop(room, None)
                    resolved.append((decode_req, int(prefill_dp_rank)))
                    continue

                count = self._query_retry_count.get(room, 0) + 1
                if count < self.MAX_RETRIES:
                    self._query_retry_count[room] = count
                    remaining.append(decode_req)
                    continue

                # Give up on this room and drop its counter so the dict does
                # not accumulate one entry per aborted request.
                self._query_retry_count.pop(room, None)
                logger.warning(
                    "request %s: prefill_dp_rank query failed for room=%s "
                    "after %d retries; aborting",
                    decode_req.req.request_id,
                    room,
                    self.MAX_RETRIES,
                )
                decode_req.kv_receiver.abort()

        # Publish the new pending list before initialising receivers, as the
        # original code did.
        self.pending_reqs = remaining
        for decode_req, prefill_dp_rank in resolved:
            decode_req.kv_receiver.init(prefill_dp_rank)

    def _update_handshake_waiters(self) -> None:
        """Poll receivers that have not yet reached ``WaitingForInput``.

        A receiver reporting ``KVPoll.WaitingForInput`` flags its request as
        ready for preallocation. One reporting ``KVPoll.Failed`` marks the
        request finished with reason ``"error"`` and records it in
        ``transfer_queue.terminal_failed``.
        """
        for decode_req in self.queue:
            if decode_req.waiting_for_input:
                continue

            poll = decode_req.kv_receiver.poll_and_all_reduce(group=self.tp_cpu_group)
            if poll == KVPoll.WaitingForInput:
                decode_req.waiting_for_input = True
            elif poll == KVPoll.Failed:
                logger.warning(
                    "handshake failed for room=%s addr=%s",
                    decode_req.req.bootstrap_room, decode_req.req.bootstrap_addr,
                )
                decode_req.req.is_finished = True
                decode_req.req.finish_reason = "error"
                self.transfer_queue.terminal_failed.append(decode_req)

    def pop_preallocated(self, next_n: int = 0) -> tuple[list[DecodeRequest], list[DecodeRequest]]:
        """Preallocate resources for handshake-complete requests.

        Args:
            next_n: Extra tokens per step; used to size ``num_new_tokens``
                (``1 + next_n``) and ``lookahead_tokens``
                (``max(next_n - 1, 0)``). Presumably the number of
                speculative tokens -- TODO confirm against callers.

        Returns:
            ``(preallocated, failed)``: requests handed to the transfer
            queue in this call, and requests found already finished, which
            are removed from the queue. Requests dropped for exceeding
            ``max_prefill_tokens`` are reported only through
            ``transfer_queue.terminal_failed``.
        """
        self._resolve_pending_reqs()
        self._update_handshake_waiters()

        # Only pay for the O(len(queue)) count when the message is emitted.
        if (self.queue or self.pending_reqs) and logger.isEnabledFor(logging.DEBUG):
            n_wfi = sum(1 for r in self.queue if r.waiting_for_input)
            logger.debug(
                "pop_preallocated: queue=%d pending=%d wfi=%d",
                len(self.queue), len(self.pending_reqs), n_wfi,
            )

        preallocated: list[DecodeRequest] = []
        failed: list[DecodeRequest] = []
        remaining: list[DecodeRequest] = []

        for decode_req in self.queue:
            req = decode_req.req
            room = req.bootstrap_room

            if req.is_finished:
                failed.append(decode_req)
                continue
            if not decode_req.waiting_for_input:
                remaining.append(decode_req)
                continue

            num_tokens = int(req.input_ids.numel())
            if num_tokens > self.max_prefill_tokens:
                logger.warning(
                    "Dropping room=%s from decode prealloc: prompt_tokens=%d "
                    "exceeds max_prefill_tokens=%d",
                    room, num_tokens, self.max_prefill_tokens,
                )
                req.is_finished = True
                req.finish_reason = "prompt_too_long"
                req.prompt_tokens = num_tokens
                self.transfer_queue.terminal_failed.append(decode_req)
                continue

            # Check the metadata pool first so KV slots are never reserved
            # for a request that cannot get a metadata slot.
            if self.metadata_pool.available_size() <= 0:
                logger.debug("preallocate skip room=%s: no metadata slot available", room)
                remaining.append(decode_req)
                continue

            slots_ok = self.kv_cache_manager.allocate_slots(
                request_id=req.request_id,
                computed_tokens=num_tokens,
                num_new_tokens=1 + next_n,
                lookahead_tokens=max(next_n - 1, 0),
                reserved_tokens=self.num_reserved_decode_tokens,
            )
            if not slots_ok:
                logger.debug("preallocate skip room=%s: KV cache slot allocation failed", room)
                remaining.append(decode_req)
                continue

            slot = self.metadata_pool.alloc()
            decode_req.metadata_buffer_index = slot
            req.metadata_buffer_index = slot
            dst_block_ids = self.kv_cache_manager.get_block_ids(req.request_id)
            decode_req.kv_receiver.send_metadata(slot, dst_block_ids)
            self.transfer_queue.add(decode_req)
            preallocated.append(decode_req)

        self.queue = remaining
        return preallocated, failed
class DecodeTransferQueue:
    """Queue of preallocated decode requests whose KV transfer is in flight.

    ``pop_transferred`` polls each waiting request's receiver. Completed
    transfers have their metadata committed onto the request and are
    returned; failed ones are collected in ``terminal_failed``. In both
    cases the request's metadata slot is returned to the pool.
    """

    def __init__(self, metadata_pool, tp_cpu_group=None):
        """Store collaborators and start with empty lists.

        Args:
            metadata_pool: Provides ``read(index)`` and ``free(index)``.
            tp_cpu_group: Group passed to ``poll_and_all_reduce``.
        """
        self.metadata_pool = metadata_pool
        self.tp_cpu_group = tp_cpu_group
        # Requests whose transfer has not resolved yet.
        self.waiting: list[DecodeRequest] = []
        # Requests that ended in an error; this class only appends to it.
        self.terminal_failed: list[DecodeRequest] = []

    def add(self, decode_req: DecodeRequest) -> None:
        """Start tracking ``decode_req`` as an in-flight transfer."""
        self.waiting.append(decode_req)

    def _read_metadata(self, decode_req: DecodeRequest) -> dict | None:
        """Read the metadata stored in the request's metadata slot."""
        return self.metadata_pool.read(decode_req.metadata_buffer_index)

    def _commit_metadata(self, decode_req: DecodeRequest) -> bool:
        """Copy transferred metadata onto the request.

        Returns:
            True iff the request is resolved (committed or errored), in
            which case the caller removes it from the waiting list. False
            when the metadata or its ``output_bootstrap_room`` field is not
            available yet.
        """
        meta = self._read_metadata(decode_req)
        if meta is None:
            return False

        req = decode_req.req
        actual_room = meta.get("output_bootstrap_room")
        if actual_room is None:
            return False
        if actual_room != req.bootstrap_room:
            # The slot holds data for a different room: fail the request
            # rather than commit foreign metadata.
            logger.warning(
                "metadata bootstrap_room mismatch: got %s expected %s",
                actual_room, req.bootstrap_room,
            )
            req.is_finished = True
            req.finish_reason = "error"
            return True

        output_id = meta.get("output_id")
        if output_id is not None:
            req.output_id_list.append(output_id)

        kv_len = meta.get("kv_len")
        if kv_len is not None:
            req.computed_len = int(kv_len)

        mtp_spec_tokens = meta.get("mtp_spec_tokens")
        if mtp_spec_tokens is not None:
            req.mtp_info = MTPInfo(
                spec_tokens=torch.tensor(mtp_spec_tokens, dtype=torch.long),
            )

        req.is_prefill_done = True
        return True

    def _release_metadata(self, decode_req: DecodeRequest) -> None:
        """Free the request's metadata slot and mark it as released.

        Both the ``DecodeRequest`` and the wrapped request carry a copy of
        the slot index; both are reset to -1 so that neither keeps pointing
        at a slot that has been returned to the pool. Calling this again for
        the same request does not free anything a second time.
        """
        idx = decode_req.metadata_buffer_index
        if idx is not None and idx >= 0:
            self.metadata_pool.free(idx)
        decode_req.metadata_buffer_index = -1
        decode_req.req.metadata_buffer_index = -1

    def pop_transferred(self) -> list[DecodeRequest]:
        """Poll all waiting transfers and return those that completed.

        Returns:
            Requests whose transfer succeeded and whose metadata was
            committed without error. Failed requests are appended to
            ``terminal_failed`` instead.

        Raises:
            ValueError: If a receiver reports an unknown poll state.
        """
        ready: list[DecodeRequest] = []
        resolved: list[DecodeRequest] = []
        still_waiting: list[DecodeRequest] = []

        for decode_req in self.waiting:
            poll = decode_req.kv_receiver.poll_and_all_reduce(group=self.tp_cpu_group)

            if poll == KVPoll.Failed:
                logger.warning(
                    "request %s: KV transfer failed (room=%s) — marking as error",
                    decode_req.req.request_id, decode_req.req.bootstrap_room,
                )
                decode_req.req.is_finished = True
                decode_req.req.finish_reason = "error"
                decode_req.kv_receiver.clear()
                self.terminal_failed.append(decode_req)
                resolved.append(decode_req)
            elif poll == KVPoll.Success:
                if not self._commit_metadata(decode_req):
                    # Metadata not readable yet; look again on a later call.
                    still_waiting.append(decode_req)
                    continue
                decode_req.kv_receiver.clear()
                if decode_req.req.finish_reason == "error":
                    self.terminal_failed.append(decode_req)
                else:
                    ready.append(decode_req)
                resolved.append(decode_req)
            elif poll in (KVPoll.Bootstrapping, KVPoll.WaitingForInput, KVPoll.Transferring):
                still_waiting.append(decode_req)
            else:
                raise ValueError(f"Unexpected poll case: {poll}")

        # Slots are released only after every receiver has been polled, as
        # in the original implementation.
        for decode_req in resolved:
            self._release_metadata(decode_req)

        self.waiting = still_waiting
        return ready