"""
Tutorial 5: Streaming DataLoader for Distributed Training
This script demonstrates how to use StreamingDataset and StreamingDataLoader
for efficient streaming data loading in distributed training scenarios.
Key Components:
- StreamingDataset: PyTorch IterableDataset that integrates with TransferQueue
- StreamingDataLoader: DataLoader wrapper that yields (batch, batch_meta) tuples
- RankAwareSampler: Enables DP group coordination for consistent
sampling across multiple ranks
Use Cases:
- Distributed training with multiple DP groups
- Fine-grained micro-batch-level data retrieval
"""
import os
import sys
import textwrap
import time
import warnings
from pathlib import Path
# Set before `import ray` below. Per Ray's logging documentation, "0" turns off
# de-duplication of repeated log lines coming from different workers.
os.environ["RAY_DEDUP_LOGS"] = "0"
# Ignore the nested-tensor "prototype stage" UserWarning attributed to torch.nested.
warnings.filterwarnings(
action="ignore",
message=r"The PyTorch API of nested tensors is in prototype stage*",
category=UserWarning,
module=r"torch\.nested",
)
# Ignore Ray's FutureWarning tip about accelerator-visible-devices env vars
# (attributed to ray._private.worker).
warnings.filterwarnings(
action="ignore",
message=r"Tip: In future versions of Ray, Ray will no longer override accelerator visible "
r"devices env var if num_gpus=0 or num_gpus=None.*",
category=FutureWarning,
module=r"ray\._private\.worker",
)
import ray
import torch
from omegaconf import OmegaConf
from tensordict import TensorDict
parent_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(parent_dir))
import transfer_queue as tq
from transfer_queue import (
RankAwareSampler,
StreamingDataLoader,
StreamingDataset,
)
def setup_transfer_queue():
    """Start Ray if needed and initialize TransferQueue for streaming retrieval.

    Prints two usage notes, then builds an OmegaConf config that selects
    ``RankAwareSampler`` as the controller sampler, enables ``polling_mode``,
    and requests a ``SimpleStorage`` backend with two data storage units.
    The config is handed to ``tq.init``.
    """
    # Only start a Ray runtime when none is active in this process.
    if not ray.is_initialized():
        ray.init(namespace="TransferQueueTutorial")

    print("[Setup]: Setup TransferQueue components")

    sampler_note = (
        "Note: Using RankAwareSampler when each rank retrieves data independently. It guarantees that "
        "The same DP rank receives the same sample indices."
    )
    print(sampler_note)

    polling_note = (
        "Note: When using streaming data retrieval, please set polling_mode=True when initializing "
        "TransferQueueController. In polling_mode, the controller will return empty BatchMeta when "
        "available data cannot meet the consumption requirements. User side need to retry later."
    )
    print(polling_note)

    controller_settings = {
        "sampler": RankAwareSampler,
        "polling_mode": True,
    }
    backend_settings = {"SimpleStorage": {"num_data_storage_units": 2}}

    # The sampler entry is a class object rather than a primitive value, hence
    # the allow_objects flag on the OmegaConf container.
    tq_config = OmegaConf.create(
        {
            "controller": controller_settings,
            "backend": backend_settings,
        },
        flags={"allow_objects": True},
    )
    tq.init(tq_config)
@ray.remote(num_cpus=0.1)
def generate_worker(rank_id: int, num_samples: int = 20):
    """
    Ray remote task that produces samples and puts them into TransferQueue.

    One single-row TensorDict is written per sample into the "train"
    partition. Each TensorDict carries two fields:
    - "input_ids": a (1, 16) long tensor filled with the sample's sequence ID
    - "meta_idx": a one-element tensor holding the same sequence ID

    Args:
        rank_id: Identifier of this producer; also used to offset sequence IDs
        num_samples: Number of samples to put

    Note:
        The sequence ID is ``i + rank_id * 10000`` for ``i`` in
        ``range(num_samples)``. IDs from different producers therefore do not
        collide as long as ``num_samples`` does not exceed 10000.
    """
    tq.init()
    client = tq.get_client()

    # Every producer owns a disjoint block of 10000 IDs starting at this offset.
    id_offset = rank_id * 10000

    for sample_idx in range(num_samples):
        seq_id = sample_idx + id_offset
        fields = {
            "input_ids": torch.full((1, 16), seq_id, dtype=torch.long),
            "meta_idx": torch.tensor([seq_id]),
        }
        sample = TensorDict(fields, batch_size=1)
        print(f"[Generate Worker@{rank_id}]: Putting sample {seq_id} into TransferQueue")
        client.put(sample, partition_id="train")

    print(f"[Generate Worker@{rank_id}]: Complete putting samples into TransferQueue")
@ray.remote(num_cpus=0.1)
def update_worker(
    rank_id: int,
    dp_rank: int,
    max_steps: int = 5,
):
    """
    Ray remote task that consumes batches through StreamingDataLoader.

    The task fetches the TransferQueue config from the controller actor,
    builds a StreamingDataset over the "train" partition that requests only
    the "meta_idx" field, wraps it in a StreamingDataLoader, and iterates
    until ``max_steps`` batches were consumed or the loader is exhausted.
    After each batch it sleeps for five seconds.

    Args:
        rank_id: Global rank of this worker; used only in log output
        dp_rank: Data-parallel rank passed to the StreamingDataset. According
            to the setup notes printed by this tutorial, workers that share a
            dp_rank are expected to receive the same samples.
        max_steps: Maximum number of batches to consume

    Returns:
        dict: ``{"dp_rank": dp_rank, "consumed_ids": [...]}``, where
        ``consumed_ids`` is the flattened list of "meta_idx" values seen.

    Note:
        The loader yields ``(batch, batch_meta)`` tuples; only ``batch`` is
        used here.
    """
    tq.init()

    # Look up the controller actor and ask it for the active TransferQueue config.
    controller_handle = ray.get_actor("TransferQueueController", namespace="transfer_queue")
    tq_config = ray.get(controller_handle.get_config.remote())

    stream_dataset = StreamingDataset(
        config=tq_config,
        batch_size=2,
        micro_batch_size=2,
        data_fields=["meta_idx"],
        partition_id="train",
        task_name="update_task",
        dp_rank=dp_rank,
        should_check_consumption_status=False,
    )
    print(f"[Update Worker@{rank_id}] StreamingDataset created successfully")

    loader = StreamingDataLoader(
        dataset=stream_dataset,
        num_workers=2,
        prefetch_factor=2,
    )
    print(
        f"[Update Worker@{rank_id}] StreamingDataLoader ready, enabling data pre-fetching through num_workers "
        f"and prefetch_factor."
    )

    print(f"[Update Worker@{rank_id}] Starting data consumption...")
    consumed_ids = []
    step = 0
    for batch, _batch_meta in loader:
        # Flatten the "meta_idx" tensor into a plain list of Python ints.
        ids = batch["meta_idx"].view(-1).tolist()
        print(f"[Update Worker@{rank_id}]: dp_rank {dp_rank} retrieved samples: {ids}")
        consumed_ids += ids

        # Stand-in for the time a real training step would take.
        time.sleep(5)

        step += 1
        if step < max_steps:
            continue
        print(f"[Update Worker@{rank_id}] Reached max steps ({max_steps}), stopping...")
        break

    # Drop the local reference to the loader before returning
    # (presumably so its prefetch workers can be released; confirm).
    del loader

    print(f"[Update Worker@{rank_id}] Completed {step} steps, consumed {len(consumed_ids)} samples")
    return {
        "dp_rank": dp_rank,
        "consumed_ids": consumed_ids,
    }
def start_all_generate_actors():
    """
    Submit the producer tasks and return their Ray object references.

    Two ``generate_worker`` tasks are launched with rank IDs 0 and 1; each is
    asked to put 20 samples.
    """
    producer_count = 2
    return [generate_worker.remote(rank_id=producer_rank, num_samples=20) for producer_rank in range(producer_count)]
def start_all_update_actors():
    """
    Submit the consumer tasks and return their Ray object references.

    Four ``update_worker`` tasks are launched with global ranks 0-3. Ranks 0
    and 1 are given dp_rank 0, ranks 2 and 3 are given dp_rank 1. The topology
    is printed before the tasks are submitted.
    """
    rank_ids = [0, 1, 2, 3]
    dp_rank = [0, 0, 1, 1]

    print("Training topology configuration:")
    print(f" - Total ranks: {len(rank_ids)}")
    print(f" - Data parallel rank: {len(set(dp_rank))}")

    # Pair every global rank with the DP rank it belongs to.
    return [
        update_worker.remote(rank_id=global_rank, dp_rank=group)
        for global_rank, group in zip(rank_ids, dp_rank)
    ]
def main():
    """
    Run the streaming data loading tutorial end to end.

    Steps:
    1. Print the introduction banner
    2. Initialize TransferQueue via ``setup_transfer_queue``
    3. Submit the producer tasks, then the consumer tasks
    4. Wait for the producers, then for the consumers
    5. Print how many samples each consumer reported, followed by the
       key takeaways
    """
    rule = "=" * 80

    # NOTE(review): the banner below says "Tutorial 6" while the module
    # docstring says "Tutorial 5" -- confirm which number is intended.
    print(rule)
    print(
        textwrap.dedent(
            """
            TransferQueue Tutorial 6: StreamingDataLoader for Distributed Training
            This tutorial demonstrates the StreamingDataLoader interface for distributed
            training scenarios. It showcases how to use StreamingDataset and StreamingDataLoader
            to efficiently consume micro-batch of samples from TransferQueue with proper coordination
            across multiple training ranks.
            Key Concepts:
            - StreamingDataset: PyTorch IterableDataset that integrates with TransferQueue
            - StreamingDataLoader: DataLoader wrapper yielding (batch, batch_meta) tuples
            - RankAwareSampler: Enables correct data consumption across DP ranks
            - DP Rank: Ranks that should receive identical data samples
            """
        )
    )
    print(rule)

    print("\n[Phase 1] Setting up TransferQueue infrastructure...")
    print(
        "\nIn real-world usage, please export the environment variable of TQ_PRE_ALLOC_SAMPLE_NUM to "
        "global_batch_size to make sure consumers can accurately determine consumption status even before "
        "producers have generated the samples."
    )
    setup_transfer_queue()

    print("\n[Phase 2] Starting data generation...")
    producer_refs = start_all_generate_actors()

    print("\n[Phase 3] Starting data consumption...")
    consumer_refs = start_all_update_actors()

    print("\n[Phase 4] Waiting for actors to complete...")
    print(rule)

    # Block on the producers first, then collect the consumers' result dicts.
    ray.get(producer_refs)
    print("✓ All generation actors completed")
    summaries = ray.get(consumer_refs)
    print("✓ All update actors completed")

    print("\n" + rule)
    print("Results Summary")
    print(rule)
    for result in summaries:
        print(f" DP Rank {result['dp_rank']}: consumed {len(result['consumed_ids'])} samples")

    print("\n" + rule)
    print("Tutorial Complete!")
    print(rule)
    print("Key Takeaways:")
    takeaways = (
        "1. StreamingDataset provides PyTorch IterableDataset interface for TransferQueue",
        "2. StreamingDataLoader wraps the dataset and yields (batch, batch_meta) tuples",
        "3. Ranks with the same DP rank receive identical samples",
        "4. The system enables efficient streaming capabilities",
    )
    for takeaway in takeaways:
        print(takeaway)
# Run the tutorial only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()