import os
import sys
import textwrap
import warnings
from pathlib import Path
warnings.filterwarnings(
action="ignore",
message=r"The PyTorch API of nested tensors is in prototype stage*",
category=UserWarning,
module=r"torch\.nested",
)
warnings.filterwarnings(
action="ignore",
message=r"Tip: In future versions of Ray, Ray will no longer override accelerator visible "
r"devices env var if num_gpus=0 or num_gpus=None.*",
category=FutureWarning,
module=r"ray\._private\.worker",
)
# Ray-related environment variables are set *before* `import ray`.
# NOTE(review): Ray may capture some of these settings (log de-duplication in
# particular) into module-level constants when it is first imported, in which
# case assigning them after the import would have no effect on this driver
# process — confirm against the installed Ray version. Setting them earlier is
# harmless either way, and they are still in place before `ray.init()` below.
os.environ["RAY_DEDUP_LOGS"] = "0"  # presumably disables Ray's log de-duplication
os.environ["RAY_DEBUG"] = "1"  # presumably enables Ray's debugging support — see Ray docs

import ray
import torch
from tensordict import TensorDict

# Make the directory two levels above this file importable; presumably that is
# where the local `transfer_queue` package imported next lives.
parent_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(parent_dir))

import transfer_queue as tq

# Skip initialisation if Ray is already running in this process.
if not ray.is_initialized():
    ray.init(namespace="TransferQueueTutorial")
def demonstrate_data_workflow():
    """
    Demonstrate basic data workflow: put → get → clear.

    Writes a 4x3 ``input_ids`` tensor plus an all-ones ``attention_mask`` of the
    same shape to the ``tutorial_partition_0`` partition, fetches the metadata
    with ``get_meta`` and the data with ``get_data``, checks row by row that the
    retrieved fields equal what was written, and finally clears the partition.

    Raises:
        AssertionError: if a retrieved row differs from the written data.
        ValueError: if a retrieved field has a different number of rows than
            the original (raised by ``zip(..., strict=True)``).
    """
    print("=" * 80)
    print("Data Workflow Demo: put → get → clear")
    print("=" * 80)

    print("[Step 1] Putting data into TransferQueue...")
    tq_client = tq.get_client()
    input_ids = torch.tensor(
        [
            [1, 2, 3],
            [4, 5, 6],
            [7, 8, 9],
            [10, 11, 12],
        ]
    )
    attention_mask = torch.ones_like(input_ids)
    data_batch = TensorDict(
        {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        },
        batch_size=input_ids.size(0),
    )
    print(f" Created {data_batch.batch_size[0]} samples")
    partition_id = "tutorial_partition_0"
    tq_client.put(data=data_batch, partition_id=partition_id)
    print(f" ✓ Data written to partition: {partition_id}")

    print("[Step 2] Requesting data metadata...")
    batch_meta = tq_client.get_meta(
        data_fields=["input_ids", "attention_mask"],
        batch_size=data_batch.batch_size[0],
        partition_id=partition_id,
        task_name="tutorial_task",
    )
    print(f" ✓ Got metadata: {len(batch_meta)} samples")
    print(f" Global indexes: {batch_meta.global_indexes}")

    print("[Step 3] Retrieving actual data...")
    retrieved_data = tq_client.get_data(batch_meta)
    print(" ✓ Data retrieved successfully")
    print(f" Keys: {list(retrieved_data.keys())}")

    print("[Step 4] Verifying data integrity...")
    # Pair each field name with the tensor that was originally written.
    originals = (("input_ids", input_ids), ("attention_mask", attention_mask))
    for key, expected in originals:
        for t1, t2 in zip(retrieved_data[key], expected, strict=True):
            # Explicit raise instead of `assert`: assert statements are removed
            # under `python -O`, which would let this check pass silently.
            if not torch.equal(t1, t2):
                raise AssertionError(f"Mismatch in {key}")
    print(" ✓ Data matches original!")

    print("[Step 5] Clearing partition... (you may also use clear_samples() to clear specific samples)")
    tq_client.clear_partition(partition_id=partition_id)
    print(" ✓ Partition cleared")
def demonstrate_storage_backend_options():
    """
    Print an overview of the storage backends available to TransferQueue.

    Purely informational: writes a fixed text listing to stdout and returns
    ``None``. Nothing is configured.
    """
    separator = "=" * 80

    # (heading, bullet lines) for each backend, in display order.
    backend_table = (
        (
            "1. SimpleStorage (default)",
            (
                " - In-memory storage, fast and simple",
                " - Leveraging ZMQ for communication, with zero-copy serialization and transfer",
                " - No extra dependencies, good for development and testing",
            ),
        ),
        (
            "2. Yuanrong",
            (
                " - Ascend native distributed storage solution",
                " - Hierarchical storage interfaces including HBM/DRAM/SSD",
            ),
        ),
        (
            "3. MooncakeStore (on the way)",
            (
                " - Support multiple transmission protocols",
                " - RDMA between DRAM and HBM",
            ),
        ),
        (
            "4. Ray RDT (on the way)",
            (" - Leverage Ray's distributed object store to store data",),
        ),
        (
            "5. Custom Storage Backends",
            (
                " - Implement your own storage manager by inheriting from `StorageManager` base class",
                " - For KV based storage, you only need to provide a storage client and integrate with `KVStorageManager`",
            ),
        ),
    )

    output_lines = [
        separator,
        "Storage Backend Options",
        separator,
        "TransferQueue supports multiple storage backends:",
    ]
    for heading, bullets in backend_table:
        output_lines.append(heading)
        output_lines.extend(bullets)

    # One print of the newline-joined lines yields the same stdout text as
    # printing each line separately.
    print("\n".join(output_lines))
def main():
    """
    Run the tutorial end to end.

    Prints an introduction, calls ``tq.init()``, runs the data-workflow and
    storage-backend demos, prints a summary, then calls ``tq.close()`` and
    ``ray.shutdown()``. If any step raises, the error and its traceback are
    printed, the same cleanup is attempted on a best-effort basis, and the
    process exits with status 1 via ``sys.exit(1)``.
    """
    import traceback

    print("=" * 80)
    print(
        textwrap.dedent(
            """
            TransferQueue Tutorial 1: Core Components Introduction
            This script introduces the three core components of TransferQueue:
            1. TransferQueueController - Manages all the metadata and tracks the production and consumption states
            2. StorageBackend - Pluggable distributed storage backend that holds the actual data
            3. TransferQueueClient - Client interface for reading/writing data (user-facing API)
            Key Concepts:
            - Data is organized into logical partitions (e.g., "train", "val")
            - Each sample has multiple fields, with a global index for identification
            - Controller maintains production/consumption state tracking
            - Client is the main interface users interact with
            """
        )
    )
    print("=" * 80)
    try:
        print("Setting up TransferQueue...")
        tq.init()
        print("Demonstrating the user workflow...")
        demonstrate_data_workflow()
        demonstrate_storage_backend_options()
        print("=" * 80)
        print("Tutorial Complete!")
        print("=" * 80)
        print("Key Takeaways:")
        print("1. TransferQueue has 3 core components:")
        print(" - Controller: Manages data production/consumption state")
        print(" - StorageBackend: Persists actual data")
        print(" - Client: User-facing API (what you use)")
        print("2. Client is the main interface users interact with")
        print("3. You can swap out different storage backends easily")
        tq.close()
        ray.shutdown()
        print("\n✓ Cleanup complete")
    except Exception as e:
        print(f"Error during tutorial: {e}")
        traceback.print_exc()
        # A failure above skips (or happened during) the success-path cleanup.
        # Attempt it again so TransferQueue and Ray are not left running. Report
        # cleanup errors but do not raise them, so they cannot hide the
        # original failure.
        for cleanup in (tq.close, ray.shutdown):
            try:
                cleanup()
            except Exception as cleanup_err:
                print(f"Error during cleanup: {cleanup_err}")
        sys.exit(1)
if __name__ == "__main__":
    # Entry point: run the tutorial only when executed as a script, so that
    # importing this module does not trigger the demo.
    main()