import os
import subprocess
import time
from importlib import resources
from typing import Any, Callable
import ray
import torch
from omegaconf import DictConfig, OmegaConf
from tensordict import TensorDict
from tensordict.tensorclass import NonTensorStack
from transfer_queue.client import TransferQueueClient
from transfer_queue.controller import TransferQueueController
from transfer_queue.metadata import KVBatchMeta
from transfer_queue.sampler import *
from transfer_queue.sampler import BaseSampler
from transfer_queue.storage.bootstrap import StorageBootstrapProvider
from transfer_queue.utils.logging_utils import get_logger
from transfer_queue.utils.yuanrong_utils import cleanup_yuanrong_resources
from transfer_queue.utils.zmq_utils import process_zmq_server_info
# Module-level logger used by every function in this file.
logger = get_logger(__name__)
# Per-process singletons. All three start as None, are populated by init() /
# _init_from_existing(), and are reset to None by close().
# Cached TransferQueueClient of this process.
_TQ_CLIENT: Any = None
# Dict mapping storage backend name -> resources returned by its bootstrap provider.
_TQ_STORAGE: Any = None
# Ray actor handle of the "TransferQueueController" named actor.
_TQ_CONTROLLER: Any = None
def _maybe_create_tq_client(conf: DictConfig | None = None) -> TransferQueueClient:
    """Return the process-wide TransferQueueClient, creating it on first use.

    Args:
        conf: Configuration providing ``controller.zmq_info`` and the ``backend``
            section. When omitted and no client exists yet, the client is
            bootstrapped by attaching to an already-running controller.

    Returns:
        The cached (or newly created) client.

    Raises:
        AssertionError: If no client exists, ``conf`` is None and attaching to
            an existing controller did not produce a client.
    """
    global _TQ_CLIENT

    if _TQ_CLIENT is not None:
        return _TQ_CLIENT

    if conf is None:
        # Attaching to a running controller sets _TQ_CLIENT as a side effect.
        _init_from_existing()
        if _TQ_CLIENT is None:
            # Raised explicitly instead of via `assert` so the guard is not
            # stripped when Python runs with -O.
            raise AssertionError(
                "TransferQueueController has not been initialized yet. Please call init() first."
            )
        return _TQ_CLIENT

    client = TransferQueueClient(
        client_id=f"TransferQueueClient_{os.getpid()}", controller_info=conf.controller.zmq_info
    )
    backend_name = conf.backend.storage_backend
    client.initialize_storage_manager(manager_type=backend_name, config=conf.backend[backend_name])
    # Publish the client only once its storage manager is ready; if the call
    # above raises, no half-initialised client stays cached for later callers.
    _TQ_CLIENT = client
    return _TQ_CLIENT
def _maybe_create_tq_storage(conf: DictConfig) -> DictConfig:
    """Bootstrap the storage resources of the configured backend.

    Looks up the bootstrap provider registered for ``conf.backend.storage_backend``,
    invokes it with ``conf`` and records whatever it returns in the module-level
    ``_TQ_STORAGE`` dict under the backend name. Problems are logged, not raised.

    Args:
        conf: Merged TransferQueue configuration.

    Returns:
        The same ``conf`` object that was passed in.
    """
    global _TQ_STORAGE

    if _TQ_STORAGE is None:
        _TQ_STORAGE = {}

    backend = conf.backend.storage_backend
    bootstrap = StorageBootstrapProvider.get_provider(backend)

    if bootstrap is None:
        logger.error(f"Storage backend {backend} not registered. Please add it to the StorageBootstrapProvider.")
        return conf

    provisioned = bootstrap(conf)
    if provisioned is None:
        logger.error(f"Not found available {backend} storage resources, please check the config.")
        return conf

    _TQ_STORAGE[backend] = provisioned
    return conf
def _init_from_existing() -> bool:
    """Attach this process to an already-running TransferQueueController.

    Resolves the named Ray actor (unless a handle is already cached), then polls
    the controller once per second until it reports a non-None config, and
    finally makes sure the client exists for that config.

    Returns:
        True once the client has been initialized from the existing controller,
        False if no controller actor could be found.
    """
    global _TQ_CONTROLLER

    if _TQ_CONTROLLER is None:
        try:
            _TQ_CONTROLLER = ray.get_actor("TransferQueueController", namespace="transfer_queue")
        except ValueError:
            logger.info("Called _init_from_existing() but TransferQueueController has not been initialized yet.")
            return False

    logger.info("Found existing TransferQueueController instance. Connecting...")

    # The controller may exist before its config has been stored; keep polling.
    while True:
        published_conf = ray.get(_TQ_CONTROLLER.get_config.remote())
        if published_conf is not None:
            break
        logger.debug("Waiting for controller to initialize... Retrying in 1s")
        time.sleep(1)

    _maybe_create_tq_client(published_conf)
    logger.info("TransferQueueClient initialized.")
    return True
def _resolve_sampler(sampler_spec: Any) -> Any:
    """Translate the ``controller.sampler`` config entry into the object handed to the controller.

    A ``BaseSampler`` instance is used as is, a ``BaseSampler`` subclass is
    instantiated, and a string is looked up by name among this module's globals.
    Any other value is passed through unchanged.

    Raises:
        ValueError: If a string does not name a ``BaseSampler`` subclass or
            instance available in this module.
    """
    if isinstance(sampler_spec, BaseSampler):
        return sampler_spec
    if isinstance(sampler_spec, type) and issubclass(sampler_spec, BaseSampler):
        return sampler_spec()
    if isinstance(sampler_spec, str):
        resolved = globals().get(sampler_spec)
        names_a_sampler = isinstance(resolved, BaseSampler) or (
            isinstance(resolved, type) and issubclass(resolved, BaseSampler)
        )
        # globals() also holds modules, functions, etc.; only accept real samplers.
        if not names_a_sampler:
            raise ValueError(f"Could not find sampler {sampler_spec}")
        return resolved
    return sampler_spec


def init(conf: DictConfig | None = None) -> DictConfig | None:
    """Initialize the TransferQueue system.

    Sets up the TransferQueue controller, the distributed storage and the client.
    Call it once at the beginning of the program before any data operation.
    If a controller already exists, it is reused and only the client is
    initialized; the provided ``conf`` is ignored in that case.

    Args:
        conf: Optional custom config merged on top of the packaged default
            ``config.yaml``. Only takes effect on first-time initialization.

    Returns:
        The merged configuration on first-time initialization. When this call
        attaches to an already-running controller at entry, the caller's
        ``conf`` is returned unchanged (possibly None).

    Raises:
        ValueError: If ``controller.sampler`` is a string that does not name a
            known sampler.

    Example:
        >>> # In process 0, node A
        >>> import transfer_queue as tq
        >>> tq.init()     # Initialize the TransferQueue
        >>> tq.put(...)   # then you can use tq for data operations
        >>>
        >>> # In process 1, node B (with Ray connected to node A)
        >>> import transfer_queue as tq
        >>> tq.init()     # Only creates a TransferQueueClient linked to the existing TQ
        >>> metadata = tq.get_meta(...)
        >>> data = tq.get_data(metadata)
    """
    global _TQ_CONTROLLER

    if _init_from_existing():
        return conf

    logger.info("No TransferQueueController found. Starting first-time initialization...")

    # allow_objects lets the config carry sampler classes / instances.
    final_conf = OmegaConf.create({}, flags={"allow_objects": True})
    default_conf = OmegaConf.load(resources.files("transfer_queue") / "config.yaml")
    final_conf = OmegaConf.merge(final_conf, default_conf)
    if conf:
        final_conf = OmegaConf.merge(final_conf, conf)

    sampler = _resolve_sampler(final_conf.controller.sampler)

    try:
        _TQ_CONTROLLER = TransferQueueController.options(
            name="TransferQueueController", namespace="transfer_queue"
        ).remote(sampler=sampler, polling_mode=final_conf.controller.polling_mode)
        logger.info("TransferQueueController has been created.")
    except ValueError:
        # Lost the race for the actor name: another process created the controller.
        logger.info("Some other rank has initialized TransferQueueController. Try to connect to existing controller.")
        _init_from_existing()
        return final_conf

    final_conf.controller.zmq_info = process_zmq_server_info(_TQ_CONTROLLER)
    final_conf = _maybe_create_tq_storage(final_conf)

    # Publish the config so that processes attaching later can read it.
    ray.get(_TQ_CONTROLLER.store_config.remote(final_conf))
    logger.info(f"TransferQueue config: {final_conf}")

    metrics_conf = final_conf.get("metrics", {})
    if metrics_conf.get("enabled", False):
        metrics_port = metrics_conf.get("port", 0)
        metrics_endpoint = ray.get(_TQ_CONTROLLER.start_metrics.remote(port=metrics_port))
        final_conf.metrics.enabled = True
        final_conf.metrics.endpoint = metrics_endpoint
        # Store again so the endpoint is visible to other processes.
        ray.get(_TQ_CONTROLLER.store_config.remote(final_conf))

        if final_conf.backend.storage_backend == "SimpleStorage":
            storage_zmq_info = final_conf.backend.SimpleStorage.get("zmq_info")
            if storage_zmq_info:
                ray.get(_TQ_CONTROLLER.register_storage_units_for_metrics.remote(storage_zmq_info))
            if _TQ_STORAGE and "SimpleStorage" in _TQ_STORAGE:
                futures = [handle.start_metrics.remote(port=0) for handle in _TQ_STORAGE["SimpleStorage"].values()]
                ray.get(futures)

        logger.info(f"Prometheus metrics endpoint: http://{metrics_endpoint}/metrics")

    _maybe_create_tq_client(final_conf)
    return final_conf
def close():
    """Close the TransferQueue system.

    Cleans up, in this order:
        - the storage resources recorded by this process (per backend),
        - the client and its associated resources,
        - the controller actor (killed via Ray).

    Storage cleanup and the controller kill are best-effort: failures are
    logged and do not stop the remaining steps. An exception raised by the
    client's ``close()`` still propagates, but only after the controller has
    been handled. All module-level handles are reset to None in every case.

    Note:
        Call this when the TransferQueue system is no longer needed.
    """
    global _TQ_CLIENT
    global _TQ_STORAGE
    global _TQ_CONTROLLER

    # Detach the registry first so it is reset even if a cleanup step fails.
    storage_registry = _TQ_STORAGE or {}
    _TQ_STORAGE = None

    for key, value in storage_registry.items():
        try:
            if key == "SimpleStorage":
                for storage in value.values():
                    ray.kill(storage)
            elif key == "MooncakeStore":
                check = subprocess.run(["pgrep", "-f", "mooncake_master"], stdout=subprocess.PIPE, text=True)
                if check.returncode == 0:
                    pids = check.stdout.strip().replace("\n", ", ")
                    logger.warning(
                        f"TransferQueue will not stop mooncake_master process with PID: {pids}. "
                        f"Consider manually killing the mooncake_master."
                    )
                if _TQ_CLIENT:
                    try:
                        ret = _TQ_CLIENT.storage_manager.storage_client._store.remove_all()
                        if ret < 0:
                            logger.error("Failed to remove existing keys in mooncake_master.")
                        else:
                            logger.info("Successfully removed all existing keys in mooncake_master.")
                    except Exception as exc:
                        logger.warning(f"Could not remove existing keys in mooncake_master: {exc!r}")
            elif key == "Yuanrong":
                cleanup_yuanrong_resources(value)
            else:
                logger.warning(f"close for _TQ_STORAGE with key {key} is not supported for now.")
        except Exception as exc:
            # Best-effort: report and continue with the remaining backends.
            logger.warning(f"Failed to clean up storage resources for {key}: {exc!r}")

    try:
        if _TQ_CLIENT:
            _TQ_CLIENT.close()
    finally:
        # Runs even when the client's close() raises, so the controller is
        # never left alive and no stale handle remains cached.
        _TQ_CLIENT = None
        if _TQ_CONTROLLER:
            try:
                ray.kill(_TQ_CONTROLLER)
            except Exception as exc:
                logger.warning(f"Failed to kill TransferQueueController: {exc!r}")
        _TQ_CONTROLLER = None
def get_metrics_endpoint() -> str | None:
    """Return the Prometheus metrics endpoint stored in the controller's config.

    The endpoint is read from the ``metrics.endpoint`` entry of the config held
    by the controller, so it can be queried from any process that can reach the
    controller (attaching to an existing one if necessary).

    Returns:
        The endpoint address (e.g. ``"10.0.1.42:38271"``), or None when no
        controller is available, the controller has no config yet, or no
        endpoint is recorded.

    Example:
        >>> import transfer_queue as tq
        >>> tq.init({"metrics": {"enabled": True}})
        >>> endpoint = tq.get_metrics_endpoint()
        >>> print(endpoint)  # e.g. "10.0.1.42:38271"
    """
    if _TQ_CONTROLLER is None:
        _init_from_existing()
        if _TQ_CONTROLLER is None:
            return None

    controller_conf = ray.get(_TQ_CONTROLLER.get_config.remote())
    if controller_conf is None:
        return None

    metrics_section = controller_conf.get("metrics", {})
    return metrics_section.get("endpoint", None)
def kv_put(
    key: str,
    partition_id: str,
    fields: TensorDict | dict[str, Any] | None = None,
    tag: dict[str, Any] | None = None,
    data_parser: Callable[[Any], Any] | None = None,
) -> KVBatchMeta:
    """Put a single key-value pair to TransferQueue.

    Convenience method for storing data under a user-specified key instead of
    a BatchMeta. The key is translated to a BatchMeta internally and the data
    is stored through the regular put mechanism.

    Args:
        key: User-specified key for the data sample (one row).
        partition_id: Logical partition to store the data in.
        fields: Data fields to store, as a TensorDict or a dict. Each key in
            ``fields`` is treated as a column of the sample. For a dict, tensor
            values are unsqueezed to add a batch dimension and other values are
            wrapped in a NonTensorStack. If omitted, only ``tag`` is updated.
        tag: Optional metadata tag to associate with the key.
        data_parser: Optional callable to parse reference data (e.g., URLs) into
            real content. Its input is a slice of ``fields`` in plain dict
            format (not TensorDict), mapping field_name -> batched values. For a
            regular tensor column the value is a batched tensor; for nested
            tensors (jagged or strided) and NonTensorStack columns the values
            are extracted into a list. It must modify values in-place based on
            the original keys: do not add or remove keys, keep the number of
            elements per column unchanged, and do not change the order of values
            within a column. Only supported by SimpleStorage.

    Returns:
        KVBatchMeta containing the key, tags, partition_id and fields. ``fields``
        lists all fields stored for this sample, including those written by
        this call.

    Raises:
        ValueError: If neither ``fields`` nor ``tag`` is provided.
        ValueError: If a nested tensor is provided (use kv_batch_put instead).
        ValueError: If ``fields`` is neither a dict nor a TensorDict.
        RuntimeError: If the retrieved BatchMeta does not contain exactly one sample.

    Example:
        >>> import transfer_queue as tq
        >>> import torch
        >>> tq.init()
        >>> meta = tq.kv_put(
        ...     key="sample_1",
        ...     partition_id="train",
        ...     fields={"input_ids": torch.tensor([1, 2, 3])},
        ...     tag={"score": 0.95}
        ... )
        >>> print(meta.fields)  # ['input_ids']
    """
    if fields is None and tag is None:
        raise ValueError("Please provide at least one parameter of `fields` or `tag`.")

    # Validate and normalise `fields` BEFORE contacting the controller:
    # kv_retrieve_meta(create=True) registers the key, so rejecting bad input
    # afterwards would leave a key behind that never received data.
    if isinstance(fields, dict):
        batch = {}
        for field_name, value in fields.items():
            if isinstance(value, torch.Tensor):
                if value.is_nested:
                    raise ValueError("Please use (async)kv_batch_put for batch operation")
                batch[field_name] = value.unsqueeze(0)
            else:
                batch[field_name] = NonTensorStack(value)
        fields = TensorDict(batch, batch_size=[1])
    elif fields is not None and not isinstance(fields, TensorDict):
        raise ValueError("`fields` can only be dict or TensorDict")

    tq_client = _maybe_create_tq_client()
    batch_meta = tq_client.kv_retrieve_meta(keys=[key], partition_id=partition_id, create=True)
    if batch_meta.size != 1:
        raise RuntimeError(f"Retrieved BatchMeta size {batch_meta.size} does not match with input `key` size of 1!")

    if tag is not None:
        batch_meta.update_custom_meta([tag])

    if fields is not None:
        batch_meta = tq_client.put(fields, batch_meta, data_parser=data_parser)
    else:
        # Tag-only update: push the custom meta without writing any data.
        tq_client.set_custom_meta(batch_meta)

    return KVBatchMeta(
        keys=[key],
        tags=batch_meta.custom_meta,
        partition_id=partition_id,
        fields=batch_meta.field_names,
        extra_info=batch_meta.extra_info,
    )
def kv_batch_put(
    keys: list[str],
    partition_id: str,
    fields: TensorDict | None = None,
    tags: list[dict[str, Any]] | None = None,
    data_parser: Callable[[Any], Any] | None = None,
) -> KVBatchMeta:
    """Batch put multiple key-value pairs into the TransferQueue.

    Stores several key-value entries in one operation, which is significantly
    more efficient than repeated calls to ``kv_put``.

    Args:
        keys: List of user-defined unique keys for the data entries.
        partition_id: Logical partition where the data will be stored.
        fields: TensorDict with batched data for all keys. Must have
            ``batch_size == len(keys)``. If omitted, only the tags are updated.
        tags: List of metadata dictionaries, one per key. Its length must match
            the number of keys.
        data_parser: Optional callable to parse raw reference data (e.g., URLs)
            into real content before storage. Its input is a plain dict (not
            TensorDict) mapping field names to batched values. The parser
            **must modify data in-place** without adding/removing keys or
            changing element counts/order. Only supported by ``SimpleStorage``.

    Returns:
        KVBatchMeta containing the stored keys, tags, partition ID and field
        information. ``fields`` lists all persisted fields of the written samples.

    Raises:
        ValueError: When both ``fields`` and ``tags`` are None.
        ValueError: When the ``fields`` batch size mismatches the key count.
        ValueError: When the ``tags`` length mismatches the key count.
        RuntimeError: When the retrieved metadata size mismatches the key count.

    Example:
        >>> import transfer_queue as tq
        >>> from tensordict import TensorDict
        >>> tq.init()
        >>> keys = ["sample_1", "sample_2", "sample_3"]
        >>> fields = TensorDict({
        ...     "input_ids": torch.randn(3, 10),
        ...     "attention_mask": torch.ones(3, 10),
        ... }, batch_size=3)
        >>> tags = [{"score": 0.9}, {"score": 0.85}, {"score": 0.95}]
        >>> meta = tq.kv_batch_put(keys=keys, partition_id="train", fields=fields, tags=tags)
        >>> print(meta.fields)
    """
    num_keys = len(keys)

    # All argument validation happens before the controller is contacted:
    # kv_retrieve_meta(create=True) registers the keys, so a late rejection
    # would leave keys behind that never received data or tags.
    if fields is None and tags is None:
        raise ValueError("Please provide at least one parameter of fields or tag.")
    if fields is not None and fields.batch_size[0] != num_keys:
        raise ValueError(f"Length of `keys` ({num_keys}) does not match `fields` batch size ({fields.batch_size[0]}).")
    if tags is not None and len(tags) != num_keys:
        raise ValueError(f"Length of `keys` ({num_keys}) does not match length of `tags` ({len(tags)}).")

    tq_client = _maybe_create_tq_client()
    batch_meta = tq_client.kv_retrieve_meta(keys=keys, partition_id=partition_id, create=True)
    if batch_meta.size != num_keys:
        raise RuntimeError(f"Retrieved BatchMeta size {batch_meta.size} does not match input `keys` size {num_keys}.")

    if tags is not None:
        batch_meta.update_custom_meta(tags)

    if fields is not None:
        batch_meta = tq_client.put(fields, batch_meta, data_parser=data_parser)
    else:
        # Tag-only update: push the custom meta without writing any data.
        tq_client.set_custom_meta(batch_meta)

    return KVBatchMeta(
        keys=keys,
        tags=batch_meta.custom_meta,
        partition_id=partition_id,
        fields=batch_meta.field_names,
        extra_info=batch_meta.extra_info,
    )
def kv_batch_get_by_meta(meta: KVBatchMeta, select_fields: list[str] | str | None = None) -> TensorDict:
    """Fetch data from TransferQueue for the samples described by a KVBatchMeta.

    Takes the keys and partition_id from ``meta`` (typically returned by
    kv_put / kv_batch_put) and delegates the retrieval to ``kv_batch_get``.

    Args:
        meta: KVBatchMeta carrying keys, partition_id and fields.
        select_fields: Field name or list of field names to retrieve instead of
            the fields recorded in ``meta``. When None, ``meta.fields`` is used.

    Returns:
        TensorDict with the requested data.

    Raises:
        ValueError: If ``meta.partition_id`` is None.
        ValueError: If any entry of ``select_fields`` is not in ``meta.fields``.
        ValueError: If keys or partition are not found.
        ValueError: If some requested field is not ready for every key.

    Example:
        >>> import transfer_queue as tq
        >>> tq.init()
        >>> keys = ["sample_1", "sample_2", "sample_3"]
        >>> fields = TensorDict({
        ...     "input_ids": torch.randn(3, 10),
        ...     "attention_mask": torch.ones(3, 10),
        ... }, batch_size=3)
        >>> meta = tq.kv_batch_put(keys=keys, partition_id="train", fields=fields)
        >>> data = tq.kv_batch_get_by_meta(meta)
    """
    if meta.partition_id is None:
        raise ValueError("Must provide partition_id in the input KVBatchMeta.")

    requested: list[str] | None
    if select_fields is None:
        requested = meta.fields
    else:
        requested = [select_fields] if isinstance(select_fields, str) else select_fields
        # Every explicitly requested field must be known to the metadata.
        if meta.fields is None or not all(name in meta.fields for name in requested):
            raise ValueError(
                f"Some fields assigned in select_fields not found in the metadata. "
                f"Assigned: {requested}; Fields in KVBatchMeta: {meta.fields}."
            )

    return kv_batch_get(keys=meta.keys, partition_id=meta.partition_id, select_fields=requested)
def kv_batch_get(keys: list[str] | str, partition_id: str, select_fields: list[str] | str | None = None) -> TensorDict:
    """Get data from TransferQueue using user-specified keys.

    Convenience method for retrieving data by key instead of by index.

    Args:
        keys: Single key or list of keys to retrieve.
        partition_id: Partition containing the keys.
        select_fields: Optional field name or list of field names to retrieve.
            If None, all fields are retrieved.

    Returns:
        TensorDict with the requested data.

    Raises:
        ValueError: If keys or partition are not found.
        ValueError: If some requested field is not ready for every key.

    Example:
        >>> import transfer_queue as tq
        >>> tq.init()
        >>> # Get single key with all fields
        >>> data = tq.kv_batch_get(keys="sample_1", partition_id="train")
        >>> # Get multiple keys with specific fields
        >>> data = tq.kv_batch_get(
        ...     keys=["sample_1", "sample_2"],
        ...     partition_id="train",
        ...     select_fields="input_ids"
        ... )
    """
    # Normalise a single key to a list, as kv_clear does, so the client always
    # receives a list of keys.
    key_list = [keys] if isinstance(keys, str) else keys

    tq_client = _maybe_create_tq_client()
    batch_meta = tq_client.kv_retrieve_meta(keys=key_list, partition_id=partition_id, create=False)
    if batch_meta.size == 0:
        raise ValueError("keys or partition were not found!")

    if select_fields is not None:
        fields_to_fetch = [select_fields] if isinstance(select_fields, str) else select_fields
        batch_meta = batch_meta.select_fields(fields_to_fetch)

    if not batch_meta.is_ready:
        raise ValueError("Some fields are not ready in all the requested keys!")

    return tq_client.get_data(batch_meta)
def kv_list(partition_id: str | None = None) -> dict[str, dict[str, Any]]:
    """List keys and their metadata for one partition or for all partitions.

    Args:
        partition_id: Partition to query. When None (default), keys from all
            partitions are returned.

    Returns:
        Whatever the client's ``kv_list`` returns: a nested dictionary mapping
        partition IDs to their keys and each key's metadata, shaped like::

            {
                "partition_id": {
                    "key_name": {"tag1": <value>, ...},
                    ...
                },
                ...
            }

    Example:
        >>> import transfer_queue as tq
        >>> tq.init()
        >>> # Case 1: Retrieve a specific partition
        >>> partitions = tq.kv_list(partition_id="train")
        >>> print(f"Keys: {list(partitions['train'].keys())}")
        >>> print(f"Tags: {list(partitions['train'].values())}")
        >>> # Case 2: Retrieve all partitions
        >>> all_partitions = tq.kv_list()
        >>> for pid, keys in all_partitions.items():
        ...     print(f"Partition: {pid}, Key count: {len(keys)}")
    """
    return _maybe_create_tq_client().kv_list(partition_id)
def kv_clear(keys: list[str] | str, partition_id: str) -> None:
    """Clear key-value pairs from TransferQueue.

    Looks up the metadata of the given keys (without creating them) and, if any
    sample was found, asks the client to clear those samples.

    Args:
        keys: Single key or list of keys to clear.
        partition_id: Partition containing the keys.

    Example:
        >>> import transfer_queue as tq
        >>> tq.init()
        >>> # Clear single key
        >>> tq.kv_clear(keys="sample_1", partition_id="train")
        >>> # Clear multiple keys
        >>> tq.kv_clear(keys=["sample_1", "sample_2"], partition_id="train")
    """
    key_list = [keys] if isinstance(keys, str) else keys

    client = _maybe_create_tq_client()
    found_meta = client.kv_retrieve_meta(keys=key_list, partition_id=partition_id, create=False)

    # Nothing matched: there is nothing to clear.
    if found_meta.size > 0:
        client.clear_samples(found_meta)
async def async_kv_put(
    key: str,
    partition_id: str,
    fields: TensorDict | dict[str, Any] | None = None,
    tag: dict[str, Any] | None = None,
    data_parser: Callable[[Any], Any] | None = None,
) -> KVBatchMeta:
    """Asynchronously put a single key-value pair to TransferQueue.

    Convenience method for storing data under a user-specified key instead of
    a BatchMeta. The key is translated to a BatchMeta internally and the data
    is stored through the regular put mechanism.

    Args:
        key: User-specified key for the data sample (one row).
        partition_id: Logical partition to store the data in.
        fields: Data fields to store, as a TensorDict or a dict. Each key in
            ``fields`` is treated as a column of the sample. For a dict, tensor
            values are unsqueezed to add a batch dimension and other values are
            wrapped in a NonTensorStack. If omitted, only ``tag`` is updated.
        tag: Optional metadata tag to associate with the key.
        data_parser: Optional callable to parse reference data (e.g., URLs) into
            real content. Its input is a slice of ``fields`` in plain dict
            format (not TensorDict), mapping field_name -> batched values. For a
            regular tensor column the value is a batched tensor; for nested
            tensors (jagged or strided) and NonTensorStack columns the values
            are extracted into a list. It must modify values in-place based on
            the original keys: do not add or remove keys, keep the number of
            elements per column unchanged, and do not change the order of values
            within a column. Only supported by SimpleStorage.

    Returns:
        KVBatchMeta containing the key, tags, partition_id and fields. ``fields``
        lists all fields stored for this sample, including those written by
        this call.

    Raises:
        ValueError: If neither ``fields`` nor ``tag`` is provided.
        ValueError: If a nested tensor is provided (use async_kv_batch_put instead).
        ValueError: If ``fields`` is neither a dict nor a TensorDict.
        RuntimeError: If the retrieved BatchMeta does not contain exactly one sample.

    Example:
        >>> import transfer_queue as tq
        >>> import torch
        >>> tq.init()
        >>> meta = await tq.async_kv_put(
        ...     key="sample_1",
        ...     partition_id="train",
        ...     fields={"input_ids": torch.tensor([1, 2, 3])},
        ...     tag={"score": 0.95}
        ... )
        >>> print(meta.fields)  # ['input_ids']
    """
    if fields is None and tag is None:
        raise ValueError("Please provide at least one parameter of fields or tag.")

    # Validate and normalise `fields` BEFORE contacting the controller:
    # async_kv_retrieve_meta(create=True) registers the key, so rejecting bad
    # input afterwards would leave a key behind that never received data.
    if isinstance(fields, dict):
        batch = {}
        for field_name, value in fields.items():
            if isinstance(value, torch.Tensor):
                if value.is_nested:
                    raise ValueError("Please use (async)kv_batch_put for batch operation")
                batch[field_name] = value.unsqueeze(0)
            else:
                batch[field_name] = NonTensorStack(value)
        fields = TensorDict(batch, batch_size=[1])
    elif fields is not None and not isinstance(fields, TensorDict):
        raise ValueError("`fields` can only be dict or TensorDict")

    tq_client = _maybe_create_tq_client()
    batch_meta = await tq_client.async_kv_retrieve_meta(keys=[key], partition_id=partition_id, create=True)
    if batch_meta.size != 1:
        raise RuntimeError(f"Retrieved BatchMeta size {batch_meta.size} does not match with input `key` size of 1!")

    if tag is not None:
        batch_meta.update_custom_meta([tag])

    if fields is not None:
        batch_meta = await tq_client.async_put(fields, batch_meta, data_parser=data_parser)
    else:
        # Tag-only update: push the custom meta without writing any data.
        await tq_client.async_set_custom_meta(batch_meta)

    return KVBatchMeta(
        keys=[key],
        tags=batch_meta.custom_meta,
        partition_id=partition_id,
        fields=batch_meta.field_names,
        extra_info=batch_meta.extra_info,
    )
async def async_kv_batch_put(
    keys: list[str],
    partition_id: str,
    fields: TensorDict | None = None,
    tags: list[dict[str, Any]] | None = None,
    data_parser: Callable[[Any], Any] | None = None,
) -> KVBatchMeta:
    """Asynchronously put multiple key-value pairs to TransferQueue in batch.

    Stores several key-value pairs in one operation, which is more efficient
    than calling ``async_kv_put`` multiple times.

    Args:
        keys: List of user-specified keys for the data.
        partition_id: Logical partition to store the data in.
        fields: TensorDict containing data for all keys. Must have
            ``batch_size == len(keys)``. If omitted, only the tags are updated.
        tags: List of metadata tags, one for each key.
        data_parser: Optional callable to parse reference data (e.g., URLs) into
            real content. Its input is a slice of ``fields`` in plain dict
            format (not TensorDict), mapping field_name -> batched values. For a
            regular tensor column the value is a batched tensor; for nested
            tensors (jagged or strided) and NonTensorStack columns the values
            are extracted into a list. It must modify values in-place based on
            the original keys: do not add or remove keys, keep the number of
            elements per column unchanged, and do not change the order of values
            within a column. Only supported by SimpleStorage.

    Returns:
        KVBatchMeta containing the keys, tags, partition_id and fields.
        ``fields`` lists all fields stored for these samples, including those
        written by this call.

    Raises:
        ValueError: If neither ``fields`` nor ``tags`` is provided.
        ValueError: If the length of ``keys`` doesn't match the length of
            ``tags`` or the batch size of the ``fields`` TensorDict.
        RuntimeError: If the retrieved BatchMeta size doesn't match the length
            of ``keys``.

    Example:
        >>> import transfer_queue as tq
        >>> tq.init()
        >>> keys = ["sample_1", "sample_2", "sample_3"]
        >>> fields = TensorDict({
        ...     "input_ids": torch.randn(3, 10),
        ...     "attention_mask": torch.ones(3, 10),
        ... }, batch_size=3)
        >>> tags = [{"score": 0.9}, {"score": 0.85}, {"score": 0.95}]
        >>> meta = await tq.async_kv_batch_put(keys=keys, partition_id="train", fields=fields, tags=tags)
        >>> print(meta.fields)  # ['input_ids', 'attention_mask']
    """
    # All argument validation happens before the controller is contacted:
    # async_kv_retrieve_meta(create=True) registers the keys, so a late
    # rejection would leave keys behind that never received data or tags.
    if fields is None and tags is None:
        raise ValueError("Please provide at least one parameter of `fields` or `tags`.")
    if fields is not None and fields.batch_size[0] != len(keys):
        raise ValueError(
            f"`keys` with length {len(keys)} does not match the `fields` TensorDict with "
            f"batch_size {fields.batch_size[0]}"
        )
    if tags is not None and len(tags) != len(keys):
        raise ValueError(f"keys with length {len(keys)} does not match length of tags {len(tags)}")

    tq_client = _maybe_create_tq_client()
    batch_meta = await tq_client.async_kv_retrieve_meta(keys=keys, partition_id=partition_id, create=True)
    if batch_meta.size != len(keys):
        raise RuntimeError(
            f"Retrieved BatchMeta size {batch_meta.size} does not match with input `keys` size {len(keys)}!"
        )

    if tags is not None:
        batch_meta.update_custom_meta(tags)

    if fields is not None:
        batch_meta = await tq_client.async_put(fields, batch_meta, data_parser=data_parser)
    else:
        # Tag-only update: push the custom meta without writing any data.
        await tq_client.async_set_custom_meta(batch_meta)

    return KVBatchMeta(
        keys=keys,
        tags=batch_meta.custom_meta,
        partition_id=partition_id,
        fields=batch_meta.field_names,
        extra_info=batch_meta.extra_info,
    )
async def async_kv_batch_get_by_meta(meta: KVBatchMeta, select_fields: list[str] | str | None = None) -> TensorDict:
    """Asynchronously fetch data for the samples described by a KVBatchMeta.

    Takes the keys and partition_id from ``meta`` (typically returned by
    async_kv_put / async_kv_batch_put) and delegates the retrieval to
    ``async_kv_batch_get``.

    Args:
        meta: KVBatchMeta carrying keys, partition_id and fields.
        select_fields: Field name or list of field names to retrieve instead of
            the fields recorded in ``meta``. When None, ``meta.fields`` is used.

    Returns:
        TensorDict with the requested data.

    Raises:
        ValueError: If ``meta.partition_id`` is None.
        ValueError: If any entry of ``select_fields`` is not in ``meta.fields``.
        ValueError: If keys or partition are not found.
        ValueError: If some requested field is not ready for every key.

    Example:
        >>> import transfer_queue as tq
        >>> tq.init()
        >>> keys = ["sample_1", "sample_2", "sample_3"]
        >>> fields = TensorDict({
        ...     "input_ids": torch.randn(3, 10),
        ...     "attention_mask": torch.ones(3, 10),
        ... }, batch_size=3)
        >>> meta = await tq.async_kv_batch_put(keys=keys, partition_id="train", fields=fields)
        >>> data = await tq.async_kv_batch_get_by_meta(meta)
    """
    if meta.partition_id is None:
        raise ValueError("Must provide partition_id in the input KVBatchMeta.")

    requested: list[str] | None
    if select_fields is None:
        requested = meta.fields
    else:
        requested = [select_fields] if isinstance(select_fields, str) else select_fields
        # Every explicitly requested field must be known to the metadata.
        if meta.fields is None or not all(name in meta.fields for name in requested):
            raise ValueError(
                f"Some fields assigned in select_fields not found in the metadata. "
                f"Assigned: {requested}; Fields in KVBatchMeta: {meta.fields}."
            )

    return await async_kv_batch_get(keys=meta.keys, partition_id=meta.partition_id, select_fields=requested)
async def async_kv_batch_get(
    keys: list[str] | str, partition_id: str, select_fields: list[str] | str | None = None
) -> TensorDict:
    """Asynchronously get data from TransferQueue using user-specified keys.

    Convenience method for retrieving data by key instead of by index.

    Args:
        keys: Single key or list of keys to retrieve.
        partition_id: Partition containing the keys.
        select_fields: Optional field name or list of field names to retrieve.
            If None, all fields are retrieved.

    Returns:
        TensorDict with the requested data.

    Raises:
        ValueError: If keys or partition are not found.
        ValueError: If some requested field is not ready for every key.

    Example:
        >>> import transfer_queue as tq
        >>> tq.init()
        >>> # Get single key with all fields
        >>> data = await tq.async_kv_batch_get(keys="sample_1", partition_id="train")
        >>> # Get multiple keys with specific fields
        >>> data = await tq.async_kv_batch_get(
        ...     keys=["sample_1", "sample_2"],
        ...     partition_id="train",
        ...     select_fields="input_ids"
        ... )
    """
    # Normalise a single key to a list, as async_kv_clear does, so the client
    # always receives a list of keys.
    key_list = [keys] if isinstance(keys, str) else keys

    tq_client = _maybe_create_tq_client()
    batch_meta = await tq_client.async_kv_retrieve_meta(keys=key_list, partition_id=partition_id, create=False)
    if batch_meta.size == 0:
        raise ValueError("keys or partition were not found!")

    if select_fields is not None:
        fields_to_fetch = [select_fields] if isinstance(select_fields, str) else select_fields
        batch_meta = batch_meta.select_fields(fields_to_fetch)

    if not batch_meta.is_ready:
        raise ValueError("Some fields are not ready in all the requested keys!")

    return await tq_client.async_get_data(batch_meta)
async def async_kv_list(partition_id: str | None = None) -> dict[str, dict[str, Any]]:
    """Asynchronously list keys and their metadata for one or all partitions.

    Args:
        partition_id: Partition to query. When None (default), keys from all
            partitions are returned.

    Returns:
        Whatever the client's ``async_kv_list`` returns: a nested dictionary
        mapping partition IDs to their keys and each key's metadata, shaped like::

            {
                "partition_id": {
                    "key_name": {"tag1": <value>, ...},
                    ...
                },
                ...
            }

    Example:
        >>> import transfer_queue as tq
        >>> tq.init()
        >>> # Case 1: Retrieve a specific partition
        >>> partitions = await tq.async_kv_list(partition_id="train")
        >>> print(f"Keys: {list(partitions['train'].keys())}")
        >>> print(f"Tags: {list(partitions['train'].values())}")
        >>> # Case 2: Retrieve all partitions
        >>> all_partitions = await tq.async_kv_list()
        >>> for pid, keys in all_partitions.items():
        ...     print(f"Partition: {pid}, Key count: {len(keys)}")
    """
    client = _maybe_create_tq_client()
    return await client.async_kv_list(partition_id)
async def async_kv_clear(keys: list[str] | str, partition_id: str) -> None:
    """Asynchronously clear key-value pairs from TransferQueue.

    Looks up the metadata of the given keys (without creating them) and, if any
    sample was found, asks the client to clear those samples.

    Args:
        keys: Single key or list of keys to clear.
        partition_id: Partition containing the keys.

    Example:
        >>> import transfer_queue as tq
        >>> tq.init()
        >>> # Clear single key
        >>> await tq.async_kv_clear(keys="sample_1", partition_id="train")
        >>> # Clear multiple keys
        >>> await tq.async_kv_clear(keys=["sample_1", "sample_2"], partition_id="train")
    """
    key_list = [keys] if isinstance(keys, str) else keys

    client = _maybe_create_tq_client()
    found_meta = await client.async_kv_retrieve_meta(keys=key_list, partition_id=partition_id, create=False)

    # Nothing matched: there is nothing to clear.
    if found_meta.size > 0:
        await client.async_clear_samples(found_meta)
def get_client():
    """Return the process-wide TransferQueueClient for using the low-level API.

    Unlike the KV helpers, this never creates a client or attaches to an
    existing controller; it only returns the one already cached.

    Raises:
        AssertionError: If no client has been created in this process yet.
    """
    if _TQ_CLIENT is None:
        # Raised explicitly instead of via `assert`, so that running Python
        # with -O cannot turn this guard into a silent `return None`.
        raise AssertionError("Please initialize the TransferQueue first by calling `tq.init()`!")
    return _TQ_CLIENT