import os
import sys
import textwrap
import warnings
from pathlib import Path

# Ray is configured through these environment variables. They are set BEFORE
# ``import ray`` because Ray evaluates RAY_DEDUP_LOGS while its modules are
# being imported; assigning it afterwards leaves the driver with the default.
# NOTE(review): RAY_DEBUG is moved together with it for the same reason --
# confirm against the Ray version in use.
os.environ["RAY_DEDUP_LOGS"] = "0"
os.environ["RAY_DEBUG"] = "1"

# Silence known-noisy warnings; the filters must be registered before the
# libraries that emit them are imported.
warnings.filterwarnings(
    action="ignore",
    message=r"The PyTorch API of nested tensors is in prototype stage*",
    category=UserWarning,
    module=r"torch\.nested",
)
warnings.filterwarnings(
    action="ignore",
    message=r"Tip: In future versions of Ray, Ray will no longer override accelerator visible "
    r"devices env var if num_gpus=0 or num_gpus=None.*",
    category=FutureWarning,
    module=r"ray\._private\.worker",
)

import ray  # noqa: E402
import torch  # noqa: E402
from tensordict import TensorDict  # noqa: E402

# Make the directory two levels above this file importable so that
# ``transfer_queue`` resolves when the script is run from a source checkout.
parent_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(parent_dir))

import transfer_queue as tq  # noqa: E402
def demonstrate_partition_isolation():
    """Feature 1: Different partitions are isolated - data doesn't interfere.

    Puts one batch into the ``train`` partition and another into the ``val``
    partition, then reads each partition back under its own task name and
    prints the metadata and tensors that were returned.

    Ray is initialized if it is not already running. TransferQueue is closed
    and Ray is shut down when the function exits, including when a step
    raises, so a failed run does not leave a cluster behind.
    """
    print("=" * 80)
    print("Feature 1: Partition Isolation")
    print("=" * 80)
    print("\nDifferent partitions are completely isolated - data doesn't interfere between partitions")

    if not ray.is_initialized():
        ray.init(namespace="TransferQueueTutorial")
    try:
        tq.init()
        try:
            tq_client = tq.get_client()

            print("\n[Partition 1] Putting training data...")
            train_data = TensorDict(
                {
                    "input_ids": torch.tensor([[1, 2, 3], [4, 5, 6]]),
                    "labels": torch.tensor([[0], [1]]),
                },
                batch_size=2,
            )
            tq_client.put(data=train_data, partition_id="train")
            print(" ✓ Training data added to 'train' partition")

            print("\n[Partition 2] Putting validation data...")
            val_data = TensorDict(
                {
                    "input_ids": torch.tensor([[7, 8, 9], [10, 11, 12]]),
                    "labels": torch.tensor([[2], [3]]),
                },
                batch_size=2,
            )
            tq_client.put(data=val_data, partition_id="val")
            print(" ✓ Validation data added to 'val' partition")

            print("\n[Retrieving from 'train' partition]")
            train_meta = tq_client.get_meta(
                data_fields=["input_ids", "labels"], batch_size=2, partition_id="train", task_name="train_task"
            )
            retrieved_train_data = tq_client.get_data(train_meta)
            print(f" ✓ Got BatchMeta={train_meta} from partition 'train'")
            print(
                f" ✓ Retrieved Data: input_ids={retrieved_train_data['input_ids']}, "
                f"labels={retrieved_train_data['labels']}"
            )

            print("\n[Retrieving from 'val' partition]")
            val_meta = tq_client.get_meta(
                data_fields=["input_ids", "labels"], batch_size=2, partition_id="val", task_name="val_task"
            )
            retrieved_val_data = tq_client.get_data(val_meta)
            print(f" ✓ Got BatchMeta={val_meta} from partition 'val'")
            print(
                f" ✓ Retrieved Data: input_ids={retrieved_val_data['input_ids']}, "
                f"labels={retrieved_val_data['labels']}"
            )

            print("\n[Verification]")
            print(" ✓ Data isolation: 'train' and 'val' partitions are completely independent")

            tq_client.clear_partition(partition_id="train")
            tq_client.clear_partition(partition_id="val")
        finally:
            # Only reached once tq.init() succeeded, so there is something to close.
            tq.close()
    finally:
        ray.shutdown()
def demonstrate_dynamic_expansion():
    """Feature 2: Dynamic expansion - can add rows and columns anytime.

    Runs three steps against the ``dynamic`` partition:

    1. put 2 samples with ``field1``/``field2``;
    2. put 3 more samples with the same fields and request metadata for all 5;
    3. put a new ``field3`` column, passing the metadata returned by step 1
       instead of a partition id.

    Ray is initialized if it is not already running. TransferQueue is closed
    and Ray is shut down when the function exits, including when a step
    raises.
    """
    print("\n" + "=" * 80)
    print("Feature 2: Dynamic Expansion - Flexible Row/Column Addition")
    print("=" * 80)
    print("\nPartitions dynamically expand to accommodate new data (rows and columns)")

    if not ray.is_initialized():
        ray.init(namespace="TransferQueueTutorial")
    try:
        tq.init()
        try:
            tq_client = tq.get_client()

            print("\n[Step 1] Adding initial data (2 samples, 2 fields)...")
            data1 = TensorDict(
                {
                    "field1": torch.tensor([[1, 2], [3, 4]]),
                    "field2": torch.tensor([[5, 6], [7, 8]]),
                },
                batch_size=2,
            )
            meta1 = tq_client.put(data=data1, partition_id="dynamic")
            print(" ✓ Added 2 samples")
            # meta1 is a BatchMeta, not a count, so no " samples" suffix.
            print(f" ✓ Got BatchMeta: {meta1}")

            print("\n[Step 2] Adding more samples (expanding rows)...")
            data2 = TensorDict(
                {
                    "field1": torch.tensor([[9, 10], [11, 12], [13, 14]]),
                    "field2": torch.tensor([[15, 16], [17, 18], [19, 20]]),
                },
                batch_size=3,
            )
            meta2 = tq_client.put(data=data2, partition_id="dynamic")
            all_meta = tq_client.get_meta(
                data_fields=["field1", "field2"], batch_size=5, partition_id="dynamic", task_name="dynamic_task"
            )
            print(" ✓ Added 3 more samples (total: 5)")
            print(f" ✓ Got BatchMeta {meta2} for newly put data.")
            print(f" ✓ All BatchMeta in controller is {all_meta}")

            print("\n[Step 3] Adding new field (expanding columns)...")
            data3 = TensorDict(
                {
                    "field3": torch.tensor([[25, 26], [27, 28]]),
                },
                batch_size=2,
            )
            # Reusing meta1 targets the 2 samples written in step 1, so this
            # put widens existing rows rather than appending new ones.
            meta3 = tq_client.put(data=data3, metadata=meta1)
            print(" ✓ Added new field 'field3' to the 2 samples from Step 1")
            print(f" ✓ Got BatchMeta: {meta3} for newly put data with new field")

            print("\n[Verification]")
            print(" ✓ Rows auto-expand: Can add more samples anytime")
            print(" ✓ Columns auto-expand: Can add new fields anytime")

            tq_client.clear_partition(partition_id="dynamic")
        finally:
            # Only reached once tq.init() succeeded, so there is something to close.
            tq.close()
    finally:
        ray.shutdown()
def demonstrate_default_consumption_sample_strategy():
    """Feature 3: Default sequential sampling without replacement.

    Puts 6 samples into the ``sampling`` partition, then requests metadata
    twice as task ``A`` (3 samples each) and once as task ``B`` (2 samples),
    printing the global indexes returned by every request next to the
    indexes the tutorial expects.

    Ray is initialized if it is not already running. TransferQueue is closed
    and Ray is shut down when the function exits, including when a step
    raises.
    """
    print("\n" + "=" * 80)
    print("Feature 3: Default Sampling Strategy for Controller - No Duplicate, Sequential Samples")
    print("=" * 80)

    if not ray.is_initialized():
        ray.init(namespace="TransferQueueTutorial")
    try:
        tq.init()
        try:
            tq_client = tq.get_client()

            print("\n[Setup] Adding 6 samples...")
            all_data = TensorDict(
                {
                    "data": torch.tensor([[i] for i in range(6)]),
                },
                batch_size=6,
            )
            tq_client.put(data=all_data, partition_id="sampling")
            print(" ✓ 6 samples added")

            print("\n[Task A, Get 1] Requesting 3 samples...")
            meta1 = tq_client.get_meta(data_fields=["data"], batch_size=3, partition_id="sampling", task_name="A")
            print(f" ✓ Got samples: {meta1.global_indexes}")

            print("\n[Task A, Get 2] Requesting 3 more samples...")
            meta2 = tq_client.get_meta(data_fields=["data"], batch_size=3, partition_id="sampling", task_name="A")
            print(f" ✓ Got samples: {meta2.global_indexes}")

            print("\n[Task B, Get 1] Requesting 2 samples...")
            meta3 = tq_client.get_meta(data_fields=["data"], batch_size=2, partition_id="sampling", task_name="B")
            print(f" ✓ Got samples: {meta3.global_indexes}")

            # The lines below state the expected outcome; compare them with
            # the "Got samples" output printed above.
            print("\n[Verification]")
            print(" ✓ Same task_name: Sequential sampling, no duplicates")
            print(" ✓ First get (Task A): samples 0,1,2")
            print(" ✓ Second get (Task A): samples 3,4,5")
            print(" ✓ Different task_name: Independent consumption with other tasks")
            print(" ✓ Third get (Task B): samples 0,1")

            tq_client.clear_partition(partition_id="sampling")
        finally:
            # Only reached once tq.init() succeeded, so there is something to close.
            tq.close()
    finally:
        ray.shutdown()
def main():
    """Print the tutorial banner, run each demo in order, then print a summary.

    Any exception raised by a demo is reported together with its traceback
    and the process exits with status 1.
    """
    rule = "=" * 80
    # dedent() strips the shared indentation, so the text is printed flush left.
    intro = textwrap.dedent(
        """
        TransferQueue Tutorial 4: Understanding TransferQueueController
        This script demonstrates TransferQueueController's key features:
        1. Partition Isolation - Different partition_id creates isolated virtual partitions
        2. Dynamic Expansion - Auto-expand rows and columns, get BatchMeta anytime
        3. Sequential Sampling - Same task_name gets non-duplicate samples sequentially by default
        4. Independent Tasks - Different task_name have independent consumption tracking
        Key Concepts:
        - Partition-based organization with complete isolation
        - Dynamic scaling without pre-allocation
        - Sample strategy prevents duplicate consumption
        - Task-specific consumption tracking
        """
    )
    takeaways = (
        "Key Takeaways:",
        "1. Partitions are completely isolated - different partition_id = independent data",
        "2. Dynamic expansion - add rows/columns anytime, get fresh BatchMeta",
        "3. Sequential sampling - same task_name gets unique samples in order by default",
        "4. Independent consumption - different task_name don't interfere",
    )

    print(rule)
    print(intro)
    print(rule)

    try:
        for run_demo in (
            demonstrate_partition_isolation,
            demonstrate_dynamic_expansion,
            demonstrate_default_consumption_sample_strategy,
        ):
            run_demo()

        print("\n" + rule)
        print("Tutorial Complete!")
        print(rule)
        for line in takeaways:
            print(line)
    except Exception as exc:
        print(f"Error during tutorial: {exc}")
        import traceback

        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()