import os
import sys
import textwrap
import uuid
import warnings
from pathlib import Path
warnings.filterwarnings(
action="ignore",
message=r"The PyTorch API of nested tensors is in prototype stage*",
category=UserWarning,
module=r"torch\.nested",
)
warnings.filterwarnings(
action="ignore",
message=r"Tip: In future versions of Ray, Ray will no longer override accelerator visible "
r"devices env var if num_gpus=0 or num_gpus=None.*",
category=FutureWarning,
module=r"ray\._private\.worker",
)
import numpy as np
import ray
import torch
from tensordict import TensorDict
parent_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(parent_dir))
import transfer_queue as tq
from transfer_queue.metadata import BatchMeta
os.environ["RAY_DEDUP_LOGS"] = "0"
os.environ["RAY_DEBUG"] = "1"
def demonstrate_batch_meta_schema():
"""
Demonstrate BatchMeta basic usage.
"""
print("=" * 80)
print("BatchMeta - Fine-Grained Metadata in Field Level")
print("=" * 80)
print("field_schema stores metadata for each field:")
print("- dtype: Data type (torch.float32, torch.int64, etc.)")
print("- shape: Shape of ONE sample's data")
print("- is_nested: Whether the field uses nested/ragged tensors")
print("- is_non_tensor: Whether the field is non-tensor data")
print("[Example 1] Creating field schema entry for input_ids...")
batch = BatchMeta(
global_indexes=[0, 1, 2],
partition_ids=["train_0"] * 3,
field_schema={
"input_ids": {"dtype": torch.int64, "shape": (512,), "is_nested": False, "is_non_tensor": False},
},
)
print("✓ Created: BatchMeta with field 'input_ids'")
print(f" input_ids schema: {batch.field_schema['input_ids']}")
print(f" Is ready: {batch.is_ready}")
print(" Note: Shape (512,) means ONE sample has 512 tokens (no batch dimension)")
print("[Example 2] Creating field schema entry for attention_mask...")
batch2 = BatchMeta(
global_indexes=[0, 1, 2],
partition_ids=["train_0"] * 3,
field_schema={
"attention_mask": {"dtype": torch.int64, "shape": (512,), "is_nested": False, "is_non_tensor": False},
},
)
print("✓ Created: BatchMeta with field 'attention_mask'")
print(f" attention_mask schema: {batch2.field_schema['attention_mask']}")
print(f" Is ready: {batch2.is_ready}")
print("[Example 3] Checking field readiness...")
ready_batch = BatchMeta(
global_indexes=[0, 1, 2],
partition_ids=["train_0"] * 3,
field_schema={
"input_ids": {"dtype": torch.int64, "shape": (512,), "is_nested": False, "is_non_tensor": False},
"attention_mask": {"dtype": torch.int64, "shape": (512,), "is_nested": False, "is_non_tensor": False},
},
production_status=np.array([1, 1, 1], dtype="int8"),
)
print(f" input_ids field exists: {'input_ids' in ready_batch.field_schema}")
print(f" attention_mask field exists: {'attention_mask' in ready_batch.field_schema}")
print(f" not-ready batch is_ready: {batch.is_ready}")
print(f" ready batch is_ready: {ready_batch.is_ready}")
print("[Example 4] Accessing sample view and individual field by key...")
view = ready_batch.samples[0]
print(f" batch.samples[0].fields -> {view.fields}")
print(f" batch.samples[0].fields['input_ids'] -> {view.fields['input_ids']}")
print(f" batch.samples[0].fields['input_ids']['dtype'] -> {view.fields['input_ids']['dtype']}")
def demonstrate_batch_meta_operations():
"""
Demonstrate BatchMeta construction and operations.
Covers: manual creation, add_fields, select_fields, select_samples,
reorder, chunk, concat, union, extra_info, custom_meta.
"""
print("=" * 80)
print("BatchMeta - Construction & Operations")
print("=" * 80)
print("BatchMeta uses a columnar layout:")
print("- global_indexes: list[int] - unique IDs across ALL partitions")
print("- partition_ids: list[str] - which partition each sample belongs to")
print("- field_schema: dict[str, dict] - field metadata")
print("- Operations: add_fields, select_fields, select_samples, reorder, chunk, concat, union")
def make_batch(global_indexes, fields=None):
if fields is None:
fields = ["input_ids", "attention_mask", "responses"]
schema = {
"input_ids": {"dtype": torch.int64, "shape": (512,), "is_nested": False, "is_non_tensor": False},
"attention_mask": {"dtype": torch.int64, "shape": (512,), "is_nested": False, "is_non_tensor": False},
"responses": {"dtype": torch.int64, "shape": (128,), "is_nested": False, "is_non_tensor": False},
}
return BatchMeta(
global_indexes=global_indexes,
partition_ids=["train_0"] * len(global_indexes),
field_schema={k: v for k, v in schema.items() if k in fields},
)
print("[Example 1] Creating a BatchMeta with input_ids and attention_mask...")
batch = BatchMeta(
global_indexes=[0, 1, 2, 3, 4],
partition_ids=["train_0"] * 5,
field_schema={
"input_ids": {"dtype": torch.int64, "shape": (512,), "is_nested": False, "is_non_tensor": False},
"attention_mask": {"dtype": torch.int64, "shape": (512,), "is_nested": False, "is_non_tensor": False},
},
)
print(f"✓ Created: {len(batch)} samples")
print(f" Global indexes: {batch.global_indexes}, Fields: {batch.field_names}")
print(f" Is ready: {batch.is_ready}")
print("[Example 2] Adding new fields via add_fields(TensorDict)...")
new_data = TensorDict(
{"responses": torch.randint(0, 1000, (5, 128)), "log_probs": torch.randn(5, 128)},
batch_size=5,
)
batch.add_fields(new_data)
print(f"✓ Added fields: ['responses', 'log_probs']. Now has: {batch.field_names}")
print(f" Is ready: {batch.is_ready} (add_fields sets all to READY by default)")
print("[Example 3] Adding batch-level extra_info and sample-level custom_meta...")
batch.extra_info["epoch"] = 1
batch.extra_info["batch_idx"] = 0
batch.update_custom_meta([{"uid": f"prompt@{i}", "session_id": "session@0"} for i in range(5)])
print(f" Extra info: {batch.get_all_extra_info()}")
print(f" custom_meta[0]: {batch.custom_meta[0]}")
print("[Example 4] Selecting specific fields...")
selected = batch.select_fields(["input_ids", "responses"])
print(f"✓ Selected: {selected.field_names} (original: {batch.field_names})")
print("[Example 5] Selecting specific samples...")
selected_samples = batch.select_samples([0, 2, 4])
print(f"✓ Selected samples at [0,2,4]: indexes={selected_samples.global_indexes}")
print("[Example 6] Reordering samples...")
print(f" Before: {batch.global_indexes}")
batch.reorder([4, 3, 2, 1, 0])
print(f" After: {batch.global_indexes}")
print("[Example 7] Chunking a batch into parts...")
batch_for_chunk = make_batch(list(range(10)))
chunks = batch_for_chunk.chunk(3)
print(f"✓ Split into {len(chunks)} chunks:")
for i, chunk in enumerate(chunks):
print(f" Chunk {i}: {len(chunk)} samples, indexes={chunk.global_indexes}")
print("[Example 8] Concatenating batches...")
batch1 = make_batch(list(range(3)))
batch2 = make_batch(list(range(3, 6)))
concatenated = BatchMeta.concat([batch1, batch2])
print(f"✓ Concatenated {len(batch1)} + {len(batch2)} = {len(concatenated)} samples")
print(f" Global indexes: {concatenated.global_indexes}")
print("[Example 9] Unioning batches with same global_indexes but different fields...")
batch_a = make_batch(list(range(3)), fields=["input_ids", "attention_mask"])
batch_b = make_batch(list(range(3)), fields=["attention_mask", "responses"])
print(f" BatchA fields: {batch_a.field_names}, BatchB fields: {batch_b.field_names}")
unioned = batch_a.union(batch_b)
print(f"✓ Unioned fields: {unioned.field_names} (same global_indexes={unioned.global_indexes})")
print(" Note: 'attention_mask' was present in both; other's definition is kept.")
print("[Example 10] Creating an empty BatchMeta...")
empty = BatchMeta.empty()
print(f"✓ Empty: size={empty.size}, is_ready={empty.is_ready}")
print("=" * 80)
print("concat vs union:")
print(" - concat: Combines batches with SAME field structure (append rows)")
print(" - union: Merges batches with SAME global_indexes (append columns/fields)")
print("=" * 80)
def demonstrate_real_workflow():
"""
Demonstrate a realistic workflow with actual TransferQueue interaction.
"""
print("=" * 80)
print("Real Workflow: Interacting with TransferQueue")
print("=" * 80)
if not ray.is_initialized():
ray.init(namespace="TransferQueueTutorial")
tq.init()
tq_client = tq.get_client()
print("[Step 1] Putting data into TransferQueue...")
input_ids = torch.randint(0, 1000, (8, 512))
attention_mask = torch.ones(8, 512)
data_batch = TensorDict(
{
"input_ids": input_ids,
"attention_mask": attention_mask,
},
batch_size=8,
)
partition_id = "demo_partition"
batch_meta = tq_client.put(data=data_batch, partition_id=partition_id)
print(f"✓ Put {data_batch.batch_size[0]} samples into partition '{partition_id}', got BatchMeta back {batch_meta}.")
print("[Step 2] [Optional] Setting sample-level custom_meta...")
custom_meta = [
{"uid": uuid.uuid4().hex[:4], "session_id": uuid.uuid4().hex[:4], "model_version": 0}
for _ in range(batch_meta.size)
]
batch_meta.update_custom_meta(custom_meta)
print(f"✓ Set custom_meta into BatchMeta: {batch_meta.get_all_custom_meta()}")
tq_client.set_custom_meta(batch_meta)
print("✓ Successful to store custom_meta into TQ controller. Now you can retrieve the custom_meta from anywhere.")
print("[Step 3] Try to get metadata from TransferQueue from other places...")
batch_meta = tq_client.get_meta(
data_fields=["input_ids", "attention_mask"],
batch_size=8,
partition_id=partition_id,
task_name="demo_task",
)
print("✓ Got BatchMeta from TransferQueue:")
print(f" Number of samples: {len(batch_meta)}")
print(f" Global indexes: {batch_meta.global_indexes}")
print(f" Field names: {batch_meta.field_names}")
print(f" Partition IDs: {batch_meta.partition_ids}")
print(f" Custom Meta: {batch_meta.get_all_custom_meta()}")
print("[Step 4] Retrieve samples with specific fields..")
selected_meta = batch_meta.select_fields(["input_ids"])
print("✓ Selected 'input_ids' field only:")
print(f" Field names in new BatchMeta: {selected_meta.field_names}")
print(f" Samples still have same global indexes: {selected_meta.global_indexes}")
retrieved_data = tq_client.get_data(selected_meta)
print(f" Retrieved data keys: {list(retrieved_data.keys())}")
print("[Step 5] Select specific samples from the retrieved BatchMeta...")
partial_meta = batch_meta.select_samples([0, 2, 4, 6])
print("✓ Selected samples at indices [0, 2, 4, 6]:")
print(f" Global indexes in new BatchMeta: {partial_meta.global_indexes}")
print(f" Number of samples: {len(partial_meta)}")
retrieved_data = tq_client.get_data(partial_meta)
print(f" Retrieved data samples: {retrieved_data}, all the data samples: {data_batch}")
print("[Step 6] Demonstrate chunk operation...")
chunks = batch_meta.chunk(2)
print(f"✓ Chunked into {len(chunks)} parts:")
for i, chunk in enumerate(chunks):
print(f" Chunk {i}: {len(chunk)} samples, indexes={chunk.global_indexes}")
chunk_data = tq_client.get_data(chunk)
print(f" Chunk {i}: Retrieved chunk data: {chunk_data}")
tq_client.clear_partition(partition_id=partition_id)
tq.close()
ray.shutdown()
print("✓ Partition cleared and resources cleaned up")
def main():
"""Main function to run the tutorial."""
print("=" * 80)
print(
textwrap.dedent(
"""
TransferQueue Tutorial 3: Metadata System
This script introduces the metadata system in TransferQueue, which tracks
the structure and state of data:
Key Concepts:
- BatchMeta stores global_indexes, partition_ids, and field_schema directly
- field_schema: dict[field_name, {dtype, shape, is_nested, is_non_tensor}]
- custom_meta: list[dict] aligned with global_indexes (one dict per sample)
- Metadata operations: chunk, concat, union, select_fields, select_samples, reorder
- batch.samples[i] returns a lazy view with .fields -> field_schema (read-only)
"""
)
)
print("=" * 80)
try:
demonstrate_batch_meta_schema()
demonstrate_batch_meta_operations()
demonstrate_real_workflow()
print("=" * 80)
print("Tutorial Complete!")
print("=" * 80)
print("Key Takeaways:")
print("1. BatchMeta uses columnar storage")
print("2. Construct BatchMeta with: BatchMeta(global_indexes=[...], partition_ids=[...], field_schema={...})")
print("3. BatchMeta operations: chunk, concat, union, select_fields, select_samples, reorder")
print("4. extra_info is batch-level; custom_meta is sample-level (list[dict])")
print("5. Store custom_meta via TQ controller: tq_client.set_custom_meta(batch_meta)")
ray.shutdown()
print("\n✓ Cleanup complete")
except Exception as e:
print(f"Error during tutorial: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
if __name__ == "__main__":
main()