"""Block pool abstractions for paged-attention KV cache management."""
from collections import deque
from typing import Deque, Iterable, List, Optional
class BlockPool:
"""Owns the global pool of KV cache blocks."""
def __init__(self, num_blocks: int, block_size: int):
if num_blocks <= 0:
raise ValueError("num_blocks must be positive")
self.num_blocks = num_blocks
self.block_size = block_size
self._free_block_queue: Deque[int] = deque(range(num_blocks))
self._null_block = self._free_block_queue.popleft()
self._free_block_set = set(self._free_block_queue)
def get_num_free_blocks(self) -> int:
"""Return the number of currently available blocks."""
return len(self._free_block_queue)
def get_block(self, block_id: int) -> int:
"""Return block metadata by block id."""
if block_id < 0 or block_id >= self.num_blocks:
raise IndexError(f"block_id {block_id} is out of range [0, {self.num_blocks})")
return block_id
def get_new_blocks(self, num_blocks: int) -> List[int]:
"""Allocate a number of blocks from the pool."""
if num_blocks < 0:
raise ValueError("num_blocks must be non-negative")
if num_blocks > len(self._free_block_queue):
raise ValueError(
f"requested {num_blocks} blocks, but only {len(self._free_block_queue)} are available"
)
allocated_blocks: List[int] = []
for _ in range(num_blocks):
block_id = self._free_block_queue.popleft()
self._free_block_set.remove(block_id)
allocated_blocks.append(block_id)
return allocated_blocks
def free_blocks(self, blocks: List[int]) -> None:
"""Recycle blocks back to the pool."""
for block_id in blocks:
self.get_block(block_id)
if block_id == self._null_block:
continue
if block_id in self._free_block_set:
raise ValueError(f"block_id {block_id} is already free")
self._free_block_queue.append(block_id)
self._free_block_set.add(block_id)
def get_null_block(self) -> Optional[int]:
"""Return the optional null block used by sparse managers."""
return self._null_block