import dataclasses
from typing import Optional
import torch
@dataclasses.dataclass
class SamplingMetadata:
"""
A simplified sampling data assuming the sampling parameters like top_k/top_k are the same across all
requests.
"""
query_start_loc: Optional[torch.Tensor] = None
"""(batch_size + 1,), the start location of each request in query Tensor. If not set,
the request inputs have the same length indicated by the input_ids shape:
(batch_size, query_length, hidden_size).
"""
selected_token_indices: Optional[torch.Tensor] = dataclasses.field(
default_factory=lambda: torch.tensor(-1, dtype=torch.long)
)
top_k: Optional[int] = None
class Sampler(torch.nn.Module):
def forward(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, **kwargs) -> torch.Tensor:
if sampling_metadata.query_start_loc is None:
assert hidden_states.ndim == 3
logits = hidden_states[:, -1, :]
else:
query_start_loc = sampling_metadata.query_start_loc.to(hidden_states.device)
hidden_states = hidden_states.view(-1, hidden_states.size(-1))
logits = hidden_states.index_select(0, query_start_loc[1:] - 1)
next_tokens = torch.argmax(logits, dim=-1, keepdim=True)
return next_tokens