import os
import time
import uuid
from typing import Any, Callable, Iterator
import torch
from omegaconf import DictConfig
from tensordict import TensorDict
from torch.utils.data import IterableDataset
from transfer_queue.client import TransferQueueClient
from transfer_queue.metadata import BatchMeta
from transfer_queue.utils.logging_utils import get_logger
TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL = float(
os.environ.get("TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL", 1)
)
logger = get_logger(__name__)
class StreamingDataset(IterableDataset):
"""Streaming dataset for distributed training with TransferQueue.
This dataset is designed to work with RankAwareSampler for distributed training
scenarios where each rank independently retrieves data through TransferQueue.
Usage Example:
>>> dataset = StreamingDataset(
... config=config,
... micro_batch_size=4,
... required_fields=["input_ids", "attention_mask"],
... partition_id="train",
... task_name="update_actor",
... dp_rank=dp_rank, # Same for all ranks in data replica group
... )
>>> dataloader = StreamingDataLoader(
... dataset,
... num_workers=2, # num_workers for data retrieval, each has a TQ client for async data retrieval
... prefetch_factor=2, # number of batches loaded in advance by each worker
... )
>>> for batch, batch_meta in dataloader:
... # batch is a TensorDict with the requested fields
... # batch_meta contains metadata for TransferQueue coordination
... pass
"""
def __init__(
self,
config: DictConfig,
batch_size: int,
micro_batch_size: int,
data_fields: list[str],
partition_id: str,
task_name: str,
dp_rank: int,
should_check_consumption_status: bool = False,
fetch_batch_fn: Callable | None = None,
process_batch_fn: Callable | None = None,
):
"""Initialize the StreamingDataset.
Args:
config: Configuration dictionary containing:
- controller.controller_info: ZMQServerInfo for the TransferQueueController
- backend.storage_backend: Storage backend type (e.g., "SimpleStorage")
- Other backend-specific configuration
batch_size: Batch size for data loading per iter.
micro_batch_size: Number of samples per micro-batch. This is the batch size
that will be requested from TransferQueue for each iteration.
data_fields: List of field names to retrieve from storage. Only these
fields will be included in the returned batch.
partition_id: Partition ID for data versioning. Different partitions can
be used for different data versions or splits (e.g., "train", "val").
task_name: Unique identifier for the training task. This is used to track
which samples have been consumed by which task.
dp_rank: The group ID of the current data group. All
ranks with the same dp_rank will receive identical samples.
should_check_consumption_status: Whether to check the consumption status of the
partition to decide when to stop iterating. Defaults to ``False``, which
means the iterator runs as an **infinite stream** — it will continuously
poll for new data and never exit on its own. This is the typical mode for
online/streaming training where producers keep feeding data indefinitely.
Set to ``True`` when the total number of samples is known in advance (i.e.
finite-dataset mode); the iterator will then stop once all samples in the
partition have been consumed.
fetch_batch_fn: Optional custom function to retrieve batch data.
If None, uses default_fetch_batch_fn function.
process_batch_fn: Optional custom function to post-process
and split data into micro-batches. If None, uses chunk_batch_fn.
Raises:
ValueError: If input parameters are invalid.
"""
if micro_batch_size < 1:
raise ValueError(f"micro_batch_size must be >= 1, got {micro_batch_size}")
if len(data_fields) < 1:
raise ValueError(f"data_fields must be a list with at least one field name, got {data_fields}")
if dp_rank < 0:
raise ValueError(f"dp_rank {dp_rank} must be greater than or equal to 0")
self.config = config
self.batch_size = batch_size
self.micro_batch_size = micro_batch_size
self.data_fields = data_fields
self.partition_id = partition_id
self.task_name = task_name
self.dp_rank = dp_rank
self.should_check_consumption_status = should_check_consumption_status
self.fetch_batch_fn = fetch_batch_fn if fetch_batch_fn else default_fetch_batch_fn
self.process_batch_fn = process_batch_fn if process_batch_fn else chunk_batch_fn
self.sampling_config = {
"dp_rank": self.dp_rank,
"task_name": self.task_name,
}
self._tq_client = None
self.buffer: list[tuple[TensorDict, BatchMeta]] = []
self.batch_index = 0
super().__init__()
def _create_client(self):
"""Create and initialize a TransferQueue client directly from config.
This method creates a ``TransferQueueClient`` using the ZMQ address and
storage backend information already present in ``self.config``. It
intentionally does **not** call ``tq.init()`` because that relies on Ray
internally (``ray.get_actor`` / ``ray.get``), which is **unsafe in
forked subprocesses** spawned by PyTorch DataLoader (``num_workers > 0``).
Creating the client directly via ZMQ avoids this issue.
"""
client_id = f"StreamingDataset_{uuid.uuid4().hex[:8]}"
controller_info = self.config.controller.zmq_info
storage_backend = self.config.backend.storage_backend
backend_config = self.config.backend[storage_backend]
self._tq_client = TransferQueueClient(client_id, controller_info)
self._tq_client.initialize_storage_manager(manager_type=storage_backend, config=backend_config)
def __iter__(self) -> Iterator[tuple[TensorDict, BatchMeta]]:
"""Iterate over the dataset, yielding batches of data.
The iteration behaviour depends on ``should_check_consumption_status``:
- **False (default — streaming mode)**: The iterator runs as an
infinite stream, continuously polling TransferQueue for new data.
It will sleep for `TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL` seconds
(default=1) when no data is available and
resume once new batches are produced. This is the standard mode for
online / streaming training pipelines where producers feed data
indefinitely.
- **True (finite-dataset mode)**: The iterator terminates once all
samples in the partition have been consumed (as reported by
``check_consumption_status``), *and* all buffered batches have been
yielded.
Yields:
Tuple[TensorDict, BatchMeta]: A tuple containing:
- TensorDict: Batch of data with the requested fields.
- BatchMeta: Corresponding metadata to interact with TransferQueue.
"""
if self._tq_client is None:
self._create_client()
assert self._tq_client is not None, "Failed to create TransferQueue client"
while (
not self.should_check_consumption_status
or not self._tq_client.check_consumption_status(self.task_name, self.partition_id)
or self.batch_index <= len(self.buffer) - 1
):
try:
if self.batch_index <= len(self.buffer) - 1:
current_data = self.buffer[self.batch_index]
self.batch_index += 1
logger.debug(f"StreamDataloader current batch index is {self.batch_index}/{len(self.buffer)}")
yield from self.process_batch_fn(*current_data, micro_batch_size=self.micro_batch_size)
else:
batch_data, batch_meta = self.fetch_batch_fn(
tq_client=self._tq_client,
data_fields=self.data_fields,
batch_size=self.batch_size,
partition_id=self.partition_id,
task_name=self.task_name,
sampling_config=self.sampling_config,
batch_index=self.batch_index,
)
if batch_data is not None:
self.buffer.append((batch_data, batch_meta))
else:
time.sleep(1)
except Exception as e:
logger.error(f"[StreamingDataset]: Error in data iteration: {e}")
raise
def reset(self):
"""Reset the dataset iterator to the beginning.
Clears the buffer and resets the batch index for a fresh iteration.
"""
self.batch_index = 0
def step(self, partition_id):
"""Switch to a new partition and reset the dataset state.
This method clears the buffer, resets the batch index, and updates the partition_id
to fetch data from a different partition (e.g., switching from "train" to "val").
Args:
partition_id: The new partition ID to switch to.
"""
self.buffer = []
self.batch_index = 0
self.partition_id = partition_id
def default_fetch_batch_fn(tq_client, data_fields, batch_size, partition_id, task_name, sampling_config, batch_index):
"""Retrieve a batch of data from TransferQueue.
This function queries the TransferQueue controller for batch metadata and retrieves
the actual data if available. It handles empty batches gracefully.
Args:
tq_client: The TransferQueueClient instance for data retrieval.
data_fields: List of field names to retrieve from the batch.
batch_size: The requested batch size.
partition_id: The partition ID for data versioning.
task_name: Unique identifier for the training task.
sampling_config: Configuration dictionary for sampling strategy.
batch_index: Current batch index for tracking consumption progress.
Returns:
tuple: A tuple containing:
- batch: TensorDict with the retrieved data, or None if batch is empty.
- batch_meta: BatchMeta object containing batch metadata.
"""
config = {**sampling_config, "batch_index": batch_index, "partition_id": partition_id}
batch_meta = tq_client.get_meta(
data_fields=data_fields,
batch_size=batch_size,
partition_id=partition_id,
task_name=task_name,
sampling_config=config,
)
if batch_meta.size == 0:
logger.debug(
f"[StreamingDataset]: Received empty batch, waiting for more data... "
f"Required batch_size={batch_size}, data_fields={data_fields},"
f"partition_id={partition_id}, task_name={task_name}."
)
return None, batch_meta
else:
batch = tq_client.get_data(batch_meta)
return batch, batch_meta
def chunk_batch_fn(td, batch_meta, micro_batch_size=1):
"""Split TensorDict into micro-batches along the batch dimension.
This function chunks a TensorDict into smaller micro-batches with the specified size,
along with corresponding metadata chunks. Handles cases where batch size is not
evenly divisible by micro_batch_size.
Args:
td: Input TensorDict with non-empty batch_size.
batch_meta: BatchMeta object to be chunked along with the TensorDict.
micro_batch_size: Target size for each micro-batch (positive integer, default: 1).
Returns:
list: List of tuples (micro_batch_td, micro_batch_meta) where each tuple
contains a TensorDict chunk and corresponding metadata chunk.
Raises:
TypeError: If td is not a TensorDict.
ValueError: If micro_batch_size is not a positive integer, batch_size is empty,
or micro_batch_size exceeds total batch size.
"""
if not isinstance(td, TensorDict):
raise TypeError(f"Expected TensorDict, got {type(td).__name__}")
if not isinstance(micro_batch_size, int) or micro_batch_size <= 0:
raise ValueError(f"micro_batch_size must be a positive integer, got {micro_batch_size}")
if len(td.batch_size) == 0:
raise ValueError("Input TensorDict must have non-empty batch_size")
total_size = td.batch_size[0]
if micro_batch_size > total_size:
raise ValueError(f"micro_batch_size ({micro_batch_size}) exceeds total batch size ({total_size})")
num_splits = (total_size + micro_batch_size - 1) // micro_batch_size
splits = []
batch_meta_list = batch_meta.chunk(num_splits)
nested_fields: dict[str, tuple[tuple, torch.layout]] = {}
dense_data: dict[str, Any] = {}
for key, value in td.items():
if isinstance(value, torch.Tensor) and value.is_nested:
nested_fields[key] = (value.unbind(0), value.layout)
else:
dense_data[key] = value
dense_td = TensorDict(dense_data, batch_size=td.batch_size) if dense_data else None
for i in range(num_splits):
start = i * micro_batch_size
end = min(start + micro_batch_size, total_size)
if dense_td is not None:
chunk_td = dense_td[start:end]
else:
chunk_td = TensorDict({}, batch_size=(end - start,))
for key, (tensors, layout) in nested_fields.items():
chunk_td[key] = torch.nested.as_nested_tensor(tensors[start:end], layout=layout)
splits.append((chunk_td, batch_meta_list[i]))
return splits