import torch
from tensordict import TensorDict
from transfer_queue.dataloader.streaming_dataset import StreamingDataset
from transfer_queue.metadata import BatchMeta
from transfer_queue.utils.logging_utils import get_logger
logger = get_logger(__name__)
def _identity_collate_fn(data: tuple[TensorDict, BatchMeta]) -> tuple[TensorDict, BatchMeta]:
"""Identity collate function for TransferQueue.
This function acts as a pass-through, preserving the `(TensorDict, BatchMeta)`
structure yielded by `StreamingDataset`. It prevents PyTorch from attempting
to stack or modify the already-batched data.
"""
return data
class StreamingDataLoader(torch.utils.data.DataLoader):
"""StreamingDataLoader interface for TransferQueue.
This DataLoader wraps StreamingDataset and provides a PyTorch DataLoader
interface for distributed training with streaming data access.
Key Features:
- Compatible with PyTorch training loops (for loop iteration)
- Works with StreamingDataset for streaming data access
- Supports distributed training via RankAwareSampler coordination
Note:
This DataLoader is typically used with StreamingDataset which manages
batch size internally. The standard PyTorch DataLoader batch_size
parameter is set to None because batching is handled by the dataset
in coordination with TransferQueue's sampling logic.
Example:
>>> dataset = StreamingDataset(
... config=config,
... micro_batch_size=4,
... required_fields=["input_ids", "attention_mask"],
... partition_id="train",
... task_name="update_actor",
... data_replica_group=0,
... data_replica_rank=0,
... data_replica_world_size=1,
... )
>>> dataloader = StreamingDataLoader(dataset, num_workers=0)
>>> for batch, batch_meta in dataloader:
... # batch: TensorDict with requested fields
... # batch_meta: Metadata for TransferQueue coordination
... loss = model(batch)
... loss.backward()
"""
def __init__(
self,
dataset: StreamingDataset,
num_workers: int = 0,
collate_fn=None,
pin_memory: bool = False,
worker_init_fn=None,
multiprocessing_context=None,
prefetch_factor: int | None = None,
persistent_workers: bool = False,
pin_memory_device: str = "",
):
"""Initialize the StreamingDataLoader.
Args:
dataset: StreamingDataset instance.
num_workers: Number of subprocesses for data loading.
collate_fn: Function to collate samples into batches.
pin_memory: If True, pin memory for GPU transfer.
worker_init_fn: Worker initialization function.
multiprocessing_context: Multiprocessing context.
prefetch_factor: Number of batches to prefetch per worker.
persistent_workers: Keep workers alive between epochs.
pin_memory_device: Device for pin_memory.
Note:
This DataLoader is designed to work with StreamingDataset which handles
batch size internally via the micro_batch_size parameter. The batch_size
parameter in PyTorch DataLoader is set to None because batching is managed
by the StreamingDataset in coordination with RankAwareSampler.
"""
self.dataset: StreamingDataset = dataset
if collate_fn is None:
final_collate_fn = _identity_collate_fn
else:
final_collate_fn = collate_fn
super().__init__(
dataset=dataset,
batch_size=None,
shuffle=None,
sampler=None,
batch_sampler=None,
num_workers=num_workers,
collate_fn=final_collate_fn,
pin_memory=pin_memory,
drop_last=False,
timeout=0,
worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context,
generator=None,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
pin_memory_device=pin_memory_device,
)
def reset(self):
"""Reset the dataset iterator to the beginning.
Clears the buffer and resets the batch index for a fresh iteration.
"""
self.dataset.reset()
def step(self, partition_id):
"""Switch to a new partition and reset the dataset state.
This method clears the buffer, resets the batch index, and updates the partition_id
to fetch data from a different partition (e.g., switching from "train" to "val").
Args:
partition_id: The new partition ID to switch to.
"""
self.dataset.step(partition_id)
def get_buffer(self):
"""Get the current buffer from the underlying dataset.
Returns the batch buffer maintained by StreamingDataset, which stores
pre-fetched batches for efficient data access.
Returns:
list: Buffer containing pre-fetched (TensorDict, BatchMeta) tuples.
"""
return self.dataset.buffer