"""Cache metadata structures for paged-attention initialization."""
from dataclasses import dataclass
from typing import Callable, List, Optional, Union
import torch
@dataclass
class CacheEntry:
    """Metadata describing a single cache attached to an attention layer.

    Attributes:
        cache_name: Name identifying this cache.
        attn_type: Attention type label; also used as the grouping key
            when ``manager_key`` is not set (see ``group_key``).
        dim: Trailing dimension(s) of the cache: one int, or a list/tuple
            of ints that ``cache_dim_numel`` multiplies together.
        num_head: Head count for this cache.
        dtype: Element dtype of the cache tensor.
        needs_block: Block-related flag; its exact semantics are defined
            by the consumer of this metadata (not visible here).
        block_size: Optional block size.
        manager_key: Optional explicit grouping key that overrides
            ``attn_type`` in ``group_key``.
        tensor_setter: Optional callback that receives a tensor.
        sliding_window: Optional sliding-window size.
        tensor: Optional tensor associated with this entry.
    """

    cache_name: str
    attn_type: str
    dim: Union[int, List[int]]
    num_head: int
    dtype: torch.dtype
    needs_block: bool
    block_size: Optional[int] = None
    manager_key: Optional[str] = None
    tensor_setter: Optional[Callable[[torch.Tensor], None]] = None
    sliding_window: Optional[int] = None
    tensor: Optional[torch.Tensor] = None

    @property
    def group_key(self) -> str:
        """Manager grouping key for cache allocation and metadata tables.

        Returns:
            ``manager_key`` when it is set, otherwise ``attn_type``.
        """
        return self.manager_key if self.manager_key is not None else self.attn_type

    def cache_dim_numel(self) -> int:
        """Return flattened element count for a cache entry's trailing dim.

        Returns:
            The product of all entries of ``dim`` when it is a list or
            tuple (``1`` for an empty sequence), or ``dim`` itself when it
            is a single int.
        """
        # Tuples (and tuple subclasses such as torch.Size) are shapes too;
        # treating them as scalars would return the tuple, not a count.
        dims = self.dim if isinstance(self.dim, (list, tuple)) else [self.dim]
        numel = 1
        for cur_dim in dims:
            numel *= cur_dim
        return numel
@dataclass
class LayerCacheInfo:
    """Cache metadata for one transformer layer.

    Pairs a layer index with the list of cache entries recorded for that
    layer.
    """

    # Index of the layer these cache entries belong to.
    layer_idx: int
    # Cache entries recorded for this layer.
    caches: List[CacheEntry]
@dataclass
class ModelCacheInfo:
    """Whole-model cache metadata.

    Attributes:
        num_layers: Number of layers described; ``merge`` sets it to the
            length of the merged ``layer_infos``.
        layer_infos: Per-layer cache metadata, in layer order.
        is_mla_backend: Backend flag; must be identical on both operands
            of ``merge``.
    """

    num_layers: int
    layer_infos: List[LayerCacheInfo]
    is_mla_backend: bool = False

    def merge(self, other: "ModelCacheInfo") -> "ModelCacheInfo":
        """Merge two cache-info objects into one complete model description.

        The layers of ``other`` are appended after the layers of ``self``
        and their ``layer_idx`` is shifted by ``len(self.layer_infos)``.
        Neither operand is modified: shifted layers are shallow copies, so
        each copy still shares its ``caches`` list with the original.

        Args:
            other: Cache info whose layers follow this object's layers.

        Returns:
            A new ``ModelCacheInfo`` covering the layers of both operands.

        Raises:
            ValueError: If the two operands disagree on ``is_mla_backend``.
        """
        # Function-scope import: the module itself only imports ``dataclass``.
        from dataclasses import replace

        if self.is_mla_backend != other.is_mla_backend:
            raise ValueError(
                "is_mla_backend mismatch across merged cache infos: "
                f"{self.is_mla_backend} vs {other.is_mla_backend}"
            )

        merged_layer_infos = list(self.layer_infos)
        layer_idx_offset = len(merged_layer_infos)
        # Re-index copies instead of mutating ``other`` in place; otherwise
        # ``other`` is corrupted and a repeated merge would shift twice.
        merged_layer_infos.extend(
            replace(layer_info, layer_idx=layer_info.layer_idx + layer_idx_offset)
            for layer_info in other.layer_infos
        )
        return ModelCacheInfo(
            num_layers=len(merged_layer_infos),
            layer_infos=merged_layer_infos,
            is_mla_backend=self.is_mla_backend,
        )