import dataclasses
from typing import Optional
import torch
@dataclasses.dataclass
class SpecDecodeMetadata:
    """Shape metadata for speculative decode logits selection.

    The helpers in this module require both counts to be positive and expect
    ``logits_indices`` to hold exactly
    ``num_active_requests * (num_speculative_tokens + 1)`` entries, i.e. one
    window of ``num_speculative_tokens + 1`` indices per active request.
    """

    # Row indices used to pick hidden states for the LM head. Consumers reshape
    # this to (num_active_requests, num_speculative_tokens + 1).
    logits_indices: torch.Tensor
    # Number of requests taking part in the speculative decode step.
    num_active_requests: int
    # Number of speculative tokens per request; the per-request window is this
    # value plus one.
    num_speculative_tokens: int
@dataclasses.dataclass
class SamplingMetadata:
    """Batch-level sampling inputs shared by every request.

    This is a simplified description: sampling parameters such as ``top_k``
    are assumed to be the same across all requests.
    """

    # Shape (batch_size + 1,): the start location of each request in the query
    # tensor. If not set, the request inputs have the same length, indicated by
    # the input_ids shape: (batch_size, query_length, hidden_size).
    query_start_loc: Optional[torch.Tensor] = None

    # Token positions to keep. The default is a 0-dim tensor holding -1; the
    # helpers in this module only treat tensors with ndim > 0 as an explicit
    # selection, so the default means "nothing selected explicitly".
    selected_token_indices: Optional[torch.Tensor] = dataclasses.field(
        default_factory=lambda: torch.tensor(-1, dtype=torch.long)
    )

    # When set, the speculative-decode path is taken before
    # selected_token_indices or query_start_loc are consulted.
    spec_decode_metadata: Optional[SpecDecodeMetadata] = None

    # Shared top-k setting.
    # NOTE(review): not read by the greedy sampler visible in this file.
    top_k: Optional[int] = None
def _index_select_token_dim(hidden_states: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
indices = indices.to(hidden_states.device)
if indices.ndim == 0:
indices = indices.unsqueeze(0)
if hidden_states.ndim == 3:
return hidden_states.index_select(1, indices)
return hidden_states.index_select(0, indices)
def _has_explicit_selected_token_indices(indices: Optional[torch.Tensor]) -> bool:
return indices is not None and indices.ndim > 0
def _select_logits_for_sampling(logits: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
indices = indices.to(logits.device)
selected_rows = indices.numel()
if indices.ndim == 0:
selected_rows = 1
if logits.ndim == 3 and logits.size(1) == selected_rows:
return logits.reshape(-1, logits.size(-1))
if logits.ndim == 2 and logits.size(0) == selected_rows:
return logits
logits = _index_select_token_dim(logits, indices)
if logits.ndim == 3:
return logits.reshape(-1, logits.size(-1))
return logits
def _spec_window_size(spec_metadata: SpecDecodeMetadata) -> int:
    """Return the per-request window size, ``num_speculative_tokens + 1``.

    Raises:
        ValueError: If ``num_speculative_tokens`` or ``num_active_requests``
            is not positive. The speculative-token count is checked first.
    """
    draft_tokens = spec_metadata.num_speculative_tokens
    if draft_tokens <= 0:
        raise ValueError("num_speculative_tokens must be positive for spec decode metadata")

    if spec_metadata.num_active_requests <= 0:
        raise ValueError("num_active_requests must be positive for spec decode metadata")

    return draft_tokens + 1
def _proposal_logits_indices(spec_metadata: SpecDecodeMetadata) -> torch.Tensor:
    """Return the last index of each request's window.

    ``logits_indices`` is reshaped to ``(num_active_requests, window)`` with
    ``window = num_speculative_tokens + 1``, and the final column is returned,
    giving one index per active request.

    Raises:
        ValueError: If the metadata counts are not positive (raised by
            ``_spec_window_size``).
    """
    window = _spec_window_size(spec_metadata)
    per_request = spec_metadata.logits_indices.reshape(spec_metadata.num_active_requests, window)
    return per_request[:, -1]
def _validate_spec_decode_indices(spec_metadata: SpecDecodeMetadata) -> None:
    """Check that ``logits_indices`` has one full window per active request.

    Raises:
        ValueError: If the metadata counts are not positive, or if the number
            of elements in ``logits_indices`` differs from
            ``num_active_requests * (num_speculative_tokens + 1)``.
    """
    required = spec_metadata.num_active_requests * _spec_window_size(spec_metadata)
    found = spec_metadata.logits_indices.numel()
    if found == required:
        return

    raise ValueError(
        "Spec decode logits_indices length must equal num_active_requests * "
        f"(num_speculative_tokens + 1), got {found} for "
        f"{spec_metadata.num_active_requests} requests and {spec_metadata.num_speculative_tokens} "
        f"speculative tokens; expected {required}."
    )
def select_lm_head_hidden_states(
    hidden_states: torch.Tensor,
    sampling_metadata: Optional[SamplingMetadata],
    mode: str = "target",
) -> torch.Tensor:
    """Pick the hidden states that should be fed to the LM head.

    Args:
        hidden_states: Hidden states to select from.
        sampling_metadata: Selection metadata. ``None`` returns
            ``hidden_states`` untouched.
        mode: Only consulted when speculative-decode metadata is present.
            ``"target"`` selects every entry of ``logits_indices``;
            ``"proposal"`` selects the last index of each request's window.

    Returns:
        With speculative-decode metadata: the selected rows of the hidden
        states, which are first flattened to 2-D if they were 3-D. Otherwise:
        the hidden states gathered by ``selected_token_indices`` when those
        are explicit, or the input unchanged.

    Raises:
        ValueError: If the speculative-decode metadata is inconsistent, or if
            ``mode`` is neither ``"target"`` nor ``"proposal"`` while
            speculative-decode metadata is present.
    """
    if sampling_metadata is None:
        return hidden_states

    spec_metadata = sampling_metadata.spec_decode_metadata
    if spec_metadata is None:
        token_indices = sampling_metadata.selected_token_indices
        if _has_explicit_selected_token_indices(token_indices):
            return _index_select_token_dim(hidden_states, token_indices)
        return hidden_states

    # Speculative decode indexes rows of the flattened (tokens, hidden) layout.
    flat_states = hidden_states
    if flat_states.ndim == 3:
        flat_states = flat_states.reshape(-1, flat_states.size(-1))

    # Index validation runs before the mode check, so inconsistent metadata is
    # reported first.
    _validate_spec_decode_indices(spec_metadata)

    if mode == "target":
        row_indices = spec_metadata.logits_indices
    elif mode == "proposal":
        row_indices = _proposal_logits_indices(spec_metadata)
    else:
        raise ValueError(f"Unsupported lm_head selection mode: {mode}")
    return _index_select_token_dim(flat_states, row_indices)
class Sampler(torch.nn.Module):
    """Greedy (argmax) sampler with a speculative-decode path."""

    def forward(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, **kwargs) -> torch.Tensor:
        """Return greedy next tokens.

        Spec decode models TensorCast dependency/shape flow only: verification rows
        are reduced with greedy argmax for target+bonus tokens, not probability
        acceptance/rejection sampling.

        Args:
            hidden_states: Tensor that is treated as logits; its last
                dimension is reduced with argmax.
            sampling_metadata: Controls which rows are sampled. Speculative
                decode metadata is handled first, then explicit
                ``selected_token_indices``, then ``query_start_loc``. With
                none of these set, the last position of each row of a 3-D
                input is used.
            **kwargs: Accepted and ignored.

        Returns:
            Token ids from argmax. For speculative decode the shape is
            ``(num_active_requests, num_speculative_tokens + 1)`` for
            verification rows and ``(num_active_requests, 1)`` for proposal
            rows. Otherwise the argmax keeps a trailing dimension of size 1.

        Raises:
            ValueError: If the speculative-decode counts are not positive, or
                the flattened row count matches neither the verification row
                count nor the number of active requests.
            AssertionError: If no selection information is given and
                ``hidden_states`` is not 3-D.
        """
        spec_metadata = sampling_metadata.spec_decode_metadata
        if spec_metadata is not None:
            logits = hidden_states.reshape(-1, hidden_states.size(-1))
            num_active_requests = spec_metadata.num_active_requests
            num_speculative_tokens = spec_metadata.num_speculative_tokens
            spec_window = _spec_window_size(spec_metadata)
            verification_rows = num_active_requests * spec_window

            if logits.size(0) == verification_rows:
                # One window per request: the first num_speculative_tokens rows
                # are the target positions, the final row is the bonus token.
                verification_logits = logits.reshape(num_active_requests, spec_window, logits.size(-1))
                target_logits = verification_logits[:, :num_speculative_tokens, :]
                bonus_logits = verification_logits[:, num_speculative_tokens, :]
                target_tokens = torch.argmax(target_logits, dim=-1)
                bonus_tokens = torch.argmax(bonus_logits, dim=-1, keepdim=True)
                return torch.cat([target_tokens, bonus_tokens], dim=-1)

            if logits.size(0) == num_active_requests:
                # Proposal step: one row, hence one token, per request.
                return torch.argmax(logits, dim=-1, keepdim=True)

            raise ValueError(
                "Spec decode logits rows must match active requests or "
                f"verification rows, got {logits.size(0)} rows for {num_active_requests} "
                f"requests and {num_speculative_tokens} speculative tokens. "
                f"Expected {verification_rows} verification rows or {num_active_requests} proposal rows."
            )

        selected_token_indices = sampling_metadata.selected_token_indices
        if _has_explicit_selected_token_indices(selected_token_indices):
            logits = _select_logits_for_sampling(hidden_states, selected_token_indices)
        elif sampling_metadata.query_start_loc is None:
            # Uniform-length batch: sample from the last position of each row.
            assert hidden_states.ndim == 3
            logits = hidden_states[:, -1, :]
        else:
            query_start_loc = sampling_metadata.query_start_loc.to(hidden_states.device)
            # reshape rather than view: view raises on non-contiguous input,
            # and the other flattening sites in this module use reshape too.
            hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
            # query_start_loc[1:] is each request's end offset, so subtracting
            # one addresses the last token of each request.
            logits = hidden_states.index_select(0, query_start_loc[1:] - 1)

        return torch.argmax(logits, dim=-1, keepdim=True)