import os
import sys
import textwrap
import warnings
from pathlib import Path
warnings.filterwarnings(
action="ignore",
message=r"The PyTorch API of nested tensors is in prototype stage*",
category=UserWarning,
module=r"torch\.nested",
)
warnings.filterwarnings(
action="ignore",
message=r"Tip: In future versions of Ray, Ray will no longer override accelerator visible "
r"devices env var if num_gpus=0 or num_gpus=None.*",
category=FutureWarning,
module=r"ray\._private\.worker",
)
import ray
import torch
from tensordict import TensorDict
parent_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(parent_dir))
import transfer_queue as tq
os.environ["RAY_DEDUP_LOGS"] = "0"
os.environ["RAY_DEBUG"] = "1"
if not ray.is_initialized():
ray.init(namespace="TransferQueueTutorial")
def demonstrate_kv_api():
"""
Demonstrate the Key-Value (KV) semantic API:
kv_put & kv_batch_put -> kv_list -> kv_batch_get -> kv_clear
"""
print("=" * 80)
print("Key-Value Semantic API Demo: kv_put/kv_batch_put → kv_list → kv_batch_get → kv_clear")
print("=" * 80)
print("[Step 1] Putting a single sample with kv_put...")
input_ids = torch.tensor([[1, 2, 3]])
attention_mask = torch.ones(input_ids.size())
single_sample = TensorDict(
{
"input_ids": input_ids,
"attention_mask": attention_mask,
},
batch_size=input_ids.size(0),
)
partition_id = "Train"
key = "0_0"
tag = {"global_steps": 0, "status": "running", "model_version": 0}
print(f" Inserting Key: {key}")
print(f" Fields (Columns): {list(single_sample.keys())}")
print(f" Tag (Metadata): {tag}")
tq.kv_put(key=key, partition_id=partition_id, fields=single_sample, tag=tag)
print(" ✓ kv_put success.")
print("\n[Step 2] Putting batch data with kv_batch_put...")
batch_input_ids = torch.tensor(
[
[4, 5, 6],
[7, 8, 9],
[10, 11, 12],
[13, 14, 15],
]
)
batch_attention_mask = torch.ones_like(batch_input_ids)
data_batch = TensorDict(
{
"input_ids": batch_input_ids,
"attention_mask": batch_attention_mask,
},
batch_size=batch_input_ids.size(0),
)
keys = ["1_0", "1_1", "1_2", "2_0"]
tags = [{"global_steps": 1, "status": "running", "model_version": 1} for _ in range(len(keys))]
print(f" Inserting batch of {len(keys)} samples.")
print(f" Fields (Columns): {list(data_batch.keys())}")
print(f" Tag (Metadata): {tags}")
tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=data_batch, tags=tags)
print(" ✓ kv_batch_put success.")
print("\n[Step 3] Appending new fields (Columns) to existing samples...")
batch_response = torch.tensor(
[
[4, 5, 6],
[7, 8, 9],
]
)
response_batch = TensorDict(
{
"response": batch_response,
},
batch_size=batch_response.size(0),
)
append_keys = ["1_1", "2_0"]
append_tags = [{"global_steps": 1, "status": "finish", "model_version": 1} for _ in range(len(append_keys))]
print(f" Target Keys: {append_keys}")
print(" New Field to Add is: 'response'")
print(f" The updated tags are: {append_tags}")
tq.kv_batch_put(keys=append_keys, partition_id=partition_id, fields=response_batch, tags=append_tags)
print(" ✓ Update success: Samples '1_1' and '2_0' now contain {input_ids, attention_mask, response}.")
print("\n[Step 4] Update existing tags without providing value...")
key_for_update_tags = "0_0"
tag_update = {"global_steps": 0, "status": "finish", "model_version": 0}
print(f" Target Key: {key_for_update_tags}")
print(f" The updated tag is: {tag_update}")
tq.kv_put(key=key_for_update_tags, partition_id=partition_id, fields=None, tag=tag_update)
print(f" ✓ Update success: Samples '0_0' now has tag as {tag_update}.")
print("\n[Step 5] Listing all keys and tags in partition...")
partition_info = tq.kv_list()
print(f" Found {len(partition_info.keys())} partitions: '{list(partition_info.keys())}'")
for pid, keys_and_tags in partition_info.items():
for k, t in keys_and_tags.items():
print(f"Partition: {pid}, - key='{k}' | tag={t}")
print("\n[Step 6] Retrieving specific fields (Column) with kv_batch_get...")
print(" Fetching only 'input_ids' to save bandwidth (ignoring 'attention_mask' and 'response').")
all_keys = list(partition_info[partition_id].keys())
retrieved_input_ids = tq.kv_batch_get(keys=all_keys, partition_id=partition_id, select_fields="input_ids")
print(f" ✓ Successfully retrieved only {list(retrieved_input_ids.keys())} field for all samples.")
print("\n[Step 7] Retrieving all fields with kv_batch_get...")
retrieved_all = tq.kv_batch_get(keys=all_keys, partition_id=partition_id)
print(f" Retrieved all fields for {all_keys}:")
print(f" Fields: {list(retrieved_all.keys())}")
print(
f" Note: We cannot retrieve fields {list(response_batch.keys())}, since they only available in {append_keys}"
)
print("\n[Step 8] Clearing keys from partition...")
keys_to_clear = all_keys[:2]
tq.kv_clear(keys=keys_to_clear, partition_id=partition_id)
print(f" ✓ Cleared keys: {keys_to_clear}")
partition_info_after_clear = tq.kv_list(partition_id=partition_id)
print(f" Remaining keys in partition: {list(partition_info_after_clear[partition_id].keys())}")
def main():
print("=" * 80)
print(
textwrap.dedent(
"""
TransferQueue Tutorial 2: Key-Value (KV) Semantic API
This tutorial demonstrates the KV semantic API, which provides a simple
interface for data storage and retrieval using user-defined string keys.
Key Methods:
1. (async_)kv_put - Insert/Update a multi-column sample by key, with optional metadata tag
2. (async_)kv_batch_put - Put multiple key-value pairs efficiently in batch
3. (async_)kv_batch_get - Retrieve samples (by keys), supporting column selection (by fields)
4. (async_)kv_list - List keys and tags (metadata) in a partition
5. (async_)kv_clear - Remove key-value pairs from storage
Key Features:
✓ Redis-style Semantics - Familiar KV interface (Put/Get/List) for zero learning curve
✓ Fine-grained Access - Update or retrieve specific fields (columns) within a key (row) without full op.
✓ Partition Isolation - Logical separation of storage namespaces
✓ Metadata Tags - Lightweight metadata for status tracking
✓ Pluggable Backends - Supports multiple backends
Use Cases:
- Focusing on fine-grained data access where extreme streaming performance is non-essential
- Integration with external ReplayBuffer/single-controller that manage sample dispatching
Limitations (vs low-level native APIs):
- No built-in production/consumption tracking: Users must manually check status via tags externally.
- No built-in Sampler support: Must implement data dispatch by ReplayBuffer or single-controller externally.
- Not fully streaming: Consumers must wait for single-controller to dispatch `keys`.
"""
)
)
print("=" * 80)
try:
print("Setting up TransferQueue...")
tq.init()
print("\nDemonstrating the KV semantic API...")
demonstrate_kv_api()
print("\n" + "=" * 80)
print("Tutorial Complete!")
print("=" * 80)
print("\nKey Takeaways:")
print(" 1. KV API simplifies data access with Redis-style semantics")
print(" 2. Use 'fields' parameter to get/put specific fields only")
print(" 3. Tags enable custom metadata for production status, scores, etc.")
print(" 4. Use kv_list to inspect partition contents")
tq.close()
ray.shutdown()
print("\nCleanup complete")
except Exception as e:
print(f"Error during tutorial: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
if __name__ == "__main__":
main()