from abc import ABC, abstractmethod
from typing import Any
class BaseSampler(ABC):
"""Base class for samplers that control how data is consumed from TransferQueue.
A sampler defines the logic for selecting which samples to retrieve from the
available samples, and which should be labeled as consumed (will never be retrieved in the future).
Based on this abstraction, users can implement various data consumption strategies
for different training scenarios, such as sequential sampling, grouped sampling for
reinforcement learning, or custom sampling patterns.
The sampler interface provides a clean separation between data production status
(handled by TransferQueueController) and data consumption strategy (implemented by samplers).
This allows users to customize data consumption behavior without modifying the TransferQueue codes.
Available Samplers:
- **SequentialSampler**: Default sampler, selects samples sequentially without replacement
- **GRPOGroupNSampler**: A sampler that performs sampling on continuous N samples only when all of them are ready.
It assumes the N samples associated with the same prompt are stored contiguously
- **RankAwareSampler**: Rank-aware sampling for distributed training where each rank retrieves data independently.
This sampler will guarantee ranks of the same DP group consume identical samples.
NOTE: Always return both sampled and consumed indexes (may be identical).
"""
def __init__(self):
self._states: dict[Any, Any] = {}
@abstractmethod
def sample(
self,
ready_indexes: list[int],
batch_size: int,
*args: Any,
**kwargs: Any,
) -> tuple[list[int], list[int]]:
"""Sample a batch of indices from the ready indices.
Args:
ready_indexes: List of global indices for which all required fields of the
corresponding samples have been produced, and the samples are not labeled as
consumed in the corresponding task.
batch_size: Number of samples to select
*args: Additional positional arguments for specific sampler implementations
**kwargs: Additional keyword arguments for specific sampler implementations
Returns:
List of sampled global indices of length batch_size
List of global indices of length batch_size that should be labeled as consumed
(will never be retrieved in the future)
Raises:
ValueError: If batch_size is invalid or ready_indexes is insufficient
"""
raise NotImplementedError("Subclasses must implement sample")
def __call__(self, *args: Any, **kwargs: Any) -> tuple[list[int], list[int]]:
return self.sample(*args, **kwargs)
def has_cached_result(
self,
partition_id: str,
task_name: str,
sampling_config: dict[str, Any] | None = None,
) -> bool:
"""Check whether the sampler has a cached sampling result for the given context.
This is used by the controller in polling mode to determine if a previously
computed sampling result is already available, so that it can skip waiting
for more data and directly proceed to return the cached result.
The check is based on the ``_states`` cache structure:
``_states[partition_id][task_name][dp_rank][batch_index]``.
Args:
partition_id: The partition identifier.
task_name: The consumer task name.
sampling_config: Optional sampling configuration dict that may contain
``dp_rank`` and ``batch_index`` keys used to locate the cached result.
Returns:
True if a cached result exists for the specified
``(partition_id, task_name, dp_rank, batch_index)`` combination,
False otherwise. Also returns False if ``dp_rank`` is not provided
in ``sampling_config``.
"""
sampling_config = sampling_config or {}
dp_rank = sampling_config.get("dp_rank", None)
batch_index = sampling_config.get("batch_index", None)
if dp_rank is None:
return False
states = self._states.get(partition_id, {}).get(task_name, {})
return dp_rank in states and batch_index in states[dp_rank]
def clear_cache(self, partition_id: str):
"""Clear cached states.
This method removes any cached sampling results that include the specified
global indexes, ensuring that future sampling operations do not reference
stale data.
Args:
partition_id: The partition ID associated with the task.
"""
if partition_id in self._states.keys():
self._states.pop(partition_id)