import asyncio
import os
import threading
from typing import Any, Callable
import torch
import zmq
import zmq.asyncio
from tensordict import TensorDict
from transfer_queue.metadata import BatchMeta
from transfer_queue.storage import StorageManagerFactory
from transfer_queue.utils.common import limit_pytorch_auto_parallel_threads
from transfer_queue.utils.logging_utils import get_logger
from transfer_queue.utils.zmq_utils import (
ZMQMessage,
ZMQRequestType,
ZMQServerInfo,
with_zmq_socket,
)
# Module-level logger, named after this module.
logger = get_logger(__name__)

# Thread count passed as ``target_num_threads`` to
# limit_pytorch_auto_parallel_threads() by the storage operations below.
# Read from the TQ_NUM_THREADS environment variable; falls back to 8.
TQ_NUM_THREADS = int(os.getenv("TQ_NUM_THREADS", "8"))

# Decorator built once and reused by every method that talks to the controller.
# The identity callback reads the client's id, and the peer callback always
# resolves to the single registered controller (``target`` is ignored).
with_controller_socket = with_zmq_socket(
    "request_handle_socket",
    get_peer=lambda self, target: self._controller,
    get_identity=lambda self: self.client_id,
)
class AsyncTransferQueueClient:
"""Asynchronous client for interacting with TransferQueue controller and storage systems.
This client provides async methods for data transfer operations including getting metadata,
reading data from storage, writing data to storage, and clearing data.
"""
def __init__(
    self,
    client_id: str,
    controller_info: ZMQServerInfo,
):
    """Initialize the asynchronous TransferQueue client.

    Args:
        client_id: Unique identifier for this client instance.
        controller_info: ZMQ server information of the single controller.

    Raises:
        ValueError: If ``controller_info`` is None.
        TypeError: If ``controller_info`` is not a ``ZMQServerInfo``.
    """
    if controller_info is None:
        raise ValueError("controller_info cannot be None")
    if not isinstance(controller_info, ZMQServerInfo):
        raise TypeError(f"controller_info must be ZMQServerInfo, got {type(controller_info)}")

    self.client_id = client_id
    self._controller: ZMQServerInfo = controller_info
    # Declare the attribute up front so it always exists on the instance.
    # It stays None until initialize_storage_manager() is called; the storage
    # methods treat None as "not initialized".
    self.storage_manager = None
    logger.info(f"[{self.client_id}]: Registered Controller server {controller_info.id} at {controller_info.ip}")
def initialize_storage_manager(
    self,
    manager_type: str,
    config: dict[str, Any],
):
    """Create the storage manager and attach it to this client.

    The manager is built by ``StorageManagerFactory.create`` and stored on
    ``self.storage_manager``.

    Args:
        manager_type: Type of storage manager to create. Supported types include:
            AsyncSimpleStorageManager, KVStorageManager (under development), etc.
        config: Configuration dictionary for the storage manager.
            For AsyncSimpleStorageManager, must contain the following required keys:
            - zmq_info: ZMQ server information about the storage units
    """
    manager = StorageManagerFactory.create(
        manager_type,
        controller_info=self._controller,
        config=config,
    )
    self.storage_manager = manager
@with_controller_socket
async def async_get_meta(
    self,
    data_fields: list[str],
    batch_size: int,
    partition_id: str,
    mode: str = "fetch",
    task_name: str | None = None,
    sampling_config: dict[str, Any] | None = None,
    socket: zmq.asyncio.Socket | None = None,
) -> BatchMeta:
    """Asynchronously request batch metadata from the controller via ZMQ.

    Sends a GET_META request carrying all arguments and returns the
    ``"metadata"`` entry of the controller's response body.

    Args:
        data_fields: List of data field names to retrieve metadata for
        batch_size: Number of samples to request in the batch
        partition_id: Current data partition id
        mode: Data fetch mode. Options:
            - 'fetch': Get ready data only
            - 'force_fetch': Get data regardless of readiness (may return unready samples)
            - 'insert': Internal usage - should not be used by users
        task_name: Optional task name associated with the request
        sampling_config: Optional sampling configuration for custom samplers;
            forwarded to the controller unchanged.
        socket: ZMQ async socket for message transmission (injected by decorator)

    Returns:
        BatchMeta: The metadata object found in the controller's response.

    Raises:
        RuntimeError: If the controller answers with anything other than a
            GET_META_RESPONSE message.

    Example:
        >>> # Fetch metadata of ready samples
        >>> batch_meta = asyncio.run(client.async_get_meta(
        ...     data_fields=["input_ids", "attention_mask"],
        ...     batch_size=4,
        ...     partition_id="train_0",
        ...     mode="fetch",
        ...     task_name="generate_sequences",
        ... ))
        >>> print(batch_meta.is_ready)  # True if all samples ready
        >>>
        >>> # Force fetch: may include samples that are not ready yet
        >>> batch_meta = asyncio.run(client.async_get_meta(
        ...     data_fields=["input_ids", "attention_mask"],
        ...     batch_size=4,
        ...     partition_id="train_0",
        ...     mode="force_fetch",
        ... ))
        >>> print(batch_meta.is_ready)  # May be False if some samples not ready
    """
    assert socket is not None

    payload = {
        "data_fields": data_fields,
        "batch_size": batch_size,
        "partition_id": partition_id,
        "mode": mode,
        "task_name": task_name,
        "sampling_config": sampling_config,
    }
    request_msg = ZMQMessage.create(
        request_type=ZMQRequestType.GET_META,
        sender_id=self.client_id,
        receiver_id=self._controller.id,
        body=payload,
    )

    # One request/response round trip with the controller.
    await socket.send_multipart(request_msg.serialize())
    reply_frames = await socket.recv_multipart(copy=False)
    response_msg = ZMQMessage.deserialize(reply_frames)
    logger.debug(
        f"[{self.client_id}]: Client get_meta response: {response_msg} from controller {self._controller.id}"
    )

    # Any response type other than GET_META_RESPONSE is treated as a failure.
    if response_msg.request_type != ZMQRequestType.GET_META_RESPONSE:
        raise RuntimeError(
            f"[{self.client_id}]: Failed to get metadata from controller {self._controller.id}: "
            f"{response_msg.body.get('message', 'Unknown error')}"
        )
    return response_msg.body["metadata"]
@with_controller_socket
async def async_set_custom_meta(
    self,
    metadata: BatchMeta,
    socket: zmq.asyncio.Socket | None = None,
) -> None:
    """Asynchronously send per-sample custom metadata to the controller.

    The custom metadata held by ``metadata`` is regrouped per partition as
    ``{partition_id: {global_index: custom_meta}}`` and sent in a single
    SET_CUSTOM_META request. Nothing is sent when the batch has no global
    indexes or no custom metadata.

    Args:
        metadata: BatchMeta containing the samples and their custom metadata to store.
            The custom_meta should be set using BatchMeta.update_custom_meta()
            before calling this method.
        socket: ZMQ async socket for message transmission (injected by decorator)

    Raises:
        RuntimeError: If no controller is registered, or the controller answers
            with anything other than a SET_CUSTOM_META_RESPONSE message.

    Example:
        >>> # Create batch with custom metadata
        >>> batch_meta = client.get_meta(data_fields=["input_ids"], batch_size=4, ...)
        >>> batch_meta.update_custom_meta([{"score": 0.9}, {"score": 0.8}])
        >>> asyncio.run(client.async_set_custom_meta(batch_meta))
    """
    assert socket is not None

    if not self._controller:
        raise RuntimeError("No controller registered")

    global_indexes = metadata.global_indexes
    custom_meta = metadata.get_all_custom_meta()
    if len(global_indexes) == 0 or len(custom_meta) == 0:
        logger.debug(f"[{self.client_id}]: Empty BatchMeta or custom_meta provided. No action taken.")
        return

    per_partition_chunks = metadata.chunk_by_partition()

    # Start with one empty bucket per distinct partition id, then fill each
    # bucket from the chunk that belongs to that partition.
    partition_custom_meta: dict[str, dict[int, dict]] = {pid: {} for pid in set(metadata.partition_ids)}
    for chunk in per_partition_chunks:
        chunk_custom_meta = chunk.get_all_custom_meta()
        # The first partition id of a chunk selects the bucket.
        bucket = partition_custom_meta[chunk.partition_ids[0]]
        index_to_meta = {}
        for pos in range(len(chunk_custom_meta)):
            index_to_meta[chunk.global_indexes[pos]] = chunk_custom_meta[pos]
        bucket.update(index_to_meta)

    request_msg = ZMQMessage.create(
        request_type=ZMQRequestType.SET_CUSTOM_META,
        sender_id=self.client_id,
        receiver_id=self._controller.id,
        body={"partition_custom_meta": partition_custom_meta},
    )

    await socket.send_multipart(request_msg.serialize())
    reply_frames = await socket.recv_multipart(copy=False)
    response_msg = ZMQMessage.deserialize(reply_frames)
    logger.debug(
        f"[{self.client_id}]: Client set_custom_meta response: {response_msg} from controller {self._controller.id}"
    )

    if response_msg.request_type == ZMQRequestType.SET_CUSTOM_META_RESPONSE:
        return
    raise RuntimeError(
        f"[{self.client_id}]: Failed to set custom metadata to controller {self._controller.id}: "
        f"{response_msg.body.get('message', 'Unknown error')}"
    )
async def async_put(
    self,
    data: TensorDict,
    metadata: BatchMeta | None = None,
    partition_id: str | None = None,
    data_parser: Callable[[Any], Any] | None = None,
) -> BatchMeta:
    """Asynchronously write data to storage units based on metadata.

    If metadata is not provided, it is requested from the controller in
    "insert" mode using the data's field names, the size of its first batch
    dimension, and ``partition_id``.

    After the data is written, the custom_meta in metadata is sent to the
    TransferQueue Controller via ``async_set_custom_meta``.

    Note:
        When using multiple workers for distributed execution, there may be data
        ordering inconsistencies between workers during put operations.

    Args:
        data: Data to write as TensorDict. When ``metadata`` is not given it
            must have at least one batch dimension, because ``batch_size[0]``
            is used as the number of samples to insert.
        metadata: Records the metadata of a batch of data samples, containing index and
            storage unit information. If None, metadata will be auto-generated.
        partition_id: Target data partition id (required if metadata is not provided)
        data_parser: Optional callable to parse reference data (e.g., URLs) into real
            content. The input is a slice of the `data` parameter, in plain
            dict format (not TensorDict), mapping field_name -> batched values.
            For a regular tensor column the value is a batched tensor; for
            nested tensors (jagged or strided) and NonTensorStack columns
            the values are extracted into a list. It must modify values
            in-place based on the original keys; do not add or remove keys.
            The number of elements per column must also remain unchanged.
            Do not change the inner order of values within each column.
            Only supported by SimpleStorage.

    Returns:
        BatchMeta: The result of ``metadata.add_fields(data)`` applied to the
        metadata used for the put operation.

    Raises:
        ValueError: If metadata is None or empty, if partition_id is None when
            metadata is not provided, or if ``data`` has no batch dimension
            when metadata is not provided.
        RuntimeError: If the storage manager has not been initialized.

    Example:
        >>> batch_size = 4
        >>> seq_len = 16
        >>> current_partition_id = "train_0"
        >>> # Example 1: Normal usage with existing metadata
        >>> batch_meta = asyncio.run(client.async_get_meta(
        ...     data_fields=["prompts", "attention_mask"],
        ...     batch_size=batch_size,
        ...     partition_id=current_partition_id,
        ...     mode="fetch",
        ...     task_name="generate_sequences",
        ... ))
        >>> batch = asyncio.run(client.async_get_data(batch_meta))
        >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)}, batch_size=[batch_size])
        >>> asyncio.run(client.async_put(data=output, metadata=batch_meta))
        >>>
        >>> # Example 2: Initial data insertion without pre-existing metadata
        >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id!
        >>> # Please make sure the corresponding partition_id is empty before calling the async_put()
        >>> # without metadata.
        >>> # Now we only support put all the data of the corresponding partition id in once. You should repeat with
        >>> # interleave the initial data if n_sample > 1 before calling the async_put().
        >>> original_prompts = torch.randn(batch_size, seq_len)
        >>> n_samples = 4
        >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0)
        >>> prompts_repeated_batch = TensorDict(
        ...     {"prompts": prompts_repeated}, batch_size=[batch_size * n_samples]
        ... )
        >>> # This will create metadata in "insert" mode internally.
        >>> metadata = asyncio.run(client.async_put(data=prompts_repeated_batch, partition_id=current_partition_id))
    """
    if not hasattr(self, "storage_manager") or self.storage_manager is None:
        raise RuntimeError(
            f"[{self.client_id}]: Storage manager not initialized. "
            "Call initialize_storage_manager() before performing storage operations."
        )

    # Informational only: 1-D tensors are accepted, the user is just told
    # what a key-value based backend may hand back.
    for field_name, field_data in data.items():
        if isinstance(field_data, torch.Tensor) and field_data.ndim == 1:
            logger.info(
                f"[{self.client_id}]: Data field '{field_name}' is a tensor with only one dimension. "
                f"You may receive 2D tensors in key-value based backend."
            )

    if metadata is None:
        if partition_id is None:
            raise ValueError("partition_id must be provided if metadata is not given")
        # A TensorDict built without an explicit batch_size has an empty
        # batch_size; indexing it would raise a bare IndexError, so fail with
        # an actionable message instead.
        if len(data.batch_size) == 0:
            raise ValueError(
                "data must have at least one batch dimension (non-empty batch_size) "
                "if metadata is not given"
            )

        metadata = await self.async_get_meta(
            data_fields=list(data.keys()),
            batch_size=data.batch_size[0],
            partition_id=partition_id,
            mode="insert",
        )

    if not metadata or metadata.size == 0:
        raise ValueError("metadata cannot be none or empty")

    with limit_pytorch_auto_parallel_threads(
        target_num_threads=TQ_NUM_THREADS, info=f"[{self.client_id}] async_put"
    ):
        await self.storage_manager.put_data(data, metadata, data_parser=data_parser)

    await self.async_set_custom_meta(metadata)

    logger.debug(
        f"[{self.client_id}]: partition {partition_id} put {metadata.size} samples to storage units successfully."
    )

    metadata = metadata.add_fields(data)
    return metadata
async def async_get_data(self, metadata: BatchMeta) -> TensorDict:
    """Asynchronously fetch the data described by ``metadata`` from storage.

    Args:
        metadata: Batch metadata containing data location information and global indexes

    Returns:
        TensorDict containing the requested data fields (e.g., "prompts",
        "attention_mask") as returned by the storage manager. An empty
        TensorDict (``batch_size=0``) is returned when ``metadata`` is falsy,
        has size 0, or has no field names.

    Raises:
        RuntimeError: If the storage manager has not been initialized.

    Example:
        >>> batch_meta = asyncio.run(client.async_get_meta(
        ...     data_fields=["prompts", "attention_mask"],
        ...     batch_size=4,
        ...     partition_id="train_0",
        ...     mode="fetch",
        ...     task_name="generate_sequences",
        ... ))
        >>> batch = asyncio.run(client.async_get_data(batch_meta))
        >>> print(batch)
        >>> # TensorDict with fields "prompts", "attention_mask", and sample order matching metadata global_indexes
    """
    if getattr(self, "storage_manager", None) is None:
        raise RuntimeError(
            f"[{self.client_id}]: Storage manager not initialized. "
            "Call initialize_storage_manager() before performing storage operations."
        )

    # Nothing to fetch: warn and hand back an empty container.
    if not metadata or metadata.size == 0 or len(metadata.field_names) == 0:
        logger.warning(f"[{self.client_id}]: Empty BatchMeta provided to get_data. Returning empty TensorDict.")
        return TensorDict({}, batch_size=0)

    thread_limit = limit_pytorch_auto_parallel_threads(
        target_num_threads=TQ_NUM_THREADS, info=f"[{self.client_id}] async_get_data"
    )
    with thread_limit:
        results = await self.storage_manager.get_data(metadata)

    logger.debug(f"[{self.client_id}]: get_data with {metadata.size} samples successfully.")
    return results
async def async_clear_partition(self, partition_id: str):
    """Asynchronously clear a whole partition from the controller and the storage units.

    The partition's metadata is fetched first; if the controller returns
    nothing for it, a warning is logged and no clearing is done. Otherwise
    the partition is cleared in the controller and then in storage.

    Args:
        partition_id: The partition id to clear data for

    Raises:
        RuntimeError: If any step fails; the original exception is chained
            as the cause.
    """
    try:
        if getattr(self, "storage_manager", None) is None:
            raise RuntimeError(
                f"[{self.client_id}]: Storage manager not initialized. "
                "Call initialize_storage_manager() before performing storage operations."
            )
        if not self._controller:
            raise RuntimeError("No controller registered")

        partition_meta = await self._get_partition_meta(partition_id)
        if not partition_meta:
            logger.warning(f"Try to clear an non-exist partition {partition_id}. No action will be taken.")
            return

        # Controller first, storage second.
        await self._clear_partition_in_controller(partition_id)
        await self.storage_manager.clear_data(partition_meta)

        logger.debug(f"[{self.client_id}]: Clear operation for partition_id {partition_id} completed.")
    except Exception as exc:
        # Every failure, including the guards above, is surfaced as RuntimeError.
        raise RuntimeError(f"Error in clear operation: {str(exc)}") from exc
async def async_clear_samples(self, metadata: BatchMeta):
    """Asynchronously clear specific samples from the controller and the storage units.

    An empty batch (``metadata.size == 0``) only logs a warning. Otherwise
    the samples are cleared in the controller and then in storage.

    Args:
        metadata: The BatchMeta of the corresponding data to be cleared

    Raises:
        RuntimeError: If any step fails; the original exception is chained
            as the cause.
    """
    try:
        if getattr(self, "storage_manager", None) is None:
            raise RuntimeError(
                f"[{self.client_id}]: Storage manager not initialized. "
                "Call initialize_storage_manager() before performing storage operations."
            )

        if metadata.size == 0:
            logger.warning(f"[{self.client_id}]: Empty BatchMeta provided to clear_samples. No action taken.")
            return

        if not self._controller:
            raise RuntimeError("No controller registered")

        # Controller first, storage second.
        await self._clear_meta_in_controller(metadata)
        await self.storage_manager.clear_data(metadata)

        logger.debug(f"[{self.client_id}]: Clear operation for batch {metadata} completed.")
    except Exception as exc:
        raise RuntimeError(f"Error in clear_samples operation: {str(exc)}") from exc
@with_controller_socket
async def _clear_meta_in_controller(self, metadata: BatchMeta, socket=None):
    """Ask the controller to drop the metadata of the given samples.

    Sends a CLEAR_META request carrying the batch's global indexes and
    partition ids.

    Args:
        metadata: The BatchMeta of the corresponding data to be cleared
        socket: ZMQ socket (injected by decorator)

    Raises:
        RuntimeError: If the controller does not answer with CLEAR_META_RESPONSE
    """
    payload = {
        "global_indexes": metadata.global_indexes,
        "partition_ids": metadata.partition_ids,
    }
    request_msg = ZMQMessage.create(
        request_type=ZMQRequestType.CLEAR_META,
        sender_id=self.client_id,
        receiver_id=self._controller.id,
        body=payload,
    )

    await socket.send_multipart(request_msg.serialize())
    reply_frames = await socket.recv_multipart(copy=False)
    response_msg = ZMQMessage.deserialize(reply_frames)

    if response_msg.request_type == ZMQRequestType.CLEAR_META_RESPONSE:
        return
    raise RuntimeError("Failed to clear samples metadata in controller.")
@with_controller_socket
async def _get_partition_meta(self, partition_id: str, socket=None) -> BatchMeta:
    """Fetch the metadata of a whole partition from the controller.

    Args:
        partition_id: Partition id to get partition metadata for
        socket: ZMQ socket (injected by decorator)

    Returns:
        BatchMeta: The ``"metadata"`` entry of the controller's response body.

    Raises:
        RuntimeError: If the controller does not answer with
            GET_PARTITION_META_RESPONSE
    """
    request_msg = ZMQMessage.create(
        request_type=ZMQRequestType.GET_PARTITION_META,
        sender_id=self.client_id,
        receiver_id=self._controller.id,
        body={"partition_id": partition_id},
    )

    await socket.send_multipart(request_msg.serialize())
    reply_frames = await socket.recv_multipart(copy=False)
    response_msg = ZMQMessage.deserialize(reply_frames)

    if response_msg.request_type == ZMQRequestType.GET_PARTITION_META_RESPONSE:
        return response_msg.body["metadata"]
    raise RuntimeError("Failed to get metadata for clear operation.")
@with_controller_socket
async def _clear_partition_in_controller(self, partition_id, socket=None):
    """Ask the controller to clear a whole partition.

    Args:
        partition_id: Partition id to clear metadata for
        socket: ZMQ socket (injected by decorator)

    Raises:
        RuntimeError: If the controller does not answer with
            CLEAR_PARTITION_RESPONSE
    """
    request_msg = ZMQMessage.create(
        request_type=ZMQRequestType.CLEAR_PARTITION,
        sender_id=self.client_id,
        receiver_id=self._controller.id,
        body={"partition_id": partition_id},
    )

    await socket.send_multipart(request_msg.serialize())
    reply_frames = await socket.recv_multipart(copy=False)
    response_msg = ZMQMessage.deserialize(reply_frames)

    if response_msg.request_type == ZMQRequestType.CLEAR_PARTITION_RESPONSE:
        return
    raise RuntimeError(f"Failed to clear partition {partition_id} in controller.")
@with_controller_socket
async def async_get_consumption_status(
    self,
    task_name: str,
    partition_id: str,
    socket: zmq.asyncio.Socket | None = None,
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
    """Get the consumption status of a partition for a specific task.

    Args:
        task_name: Name of the task to check consumption for
        partition_id: Partition id to check consumption status for
        socket: ZMQ async socket for message transmission (injected by decorator)

    Returns:
        Tuple of:
        - Partition global index tensor
        - Consumption status tensor for the specified task. 1 for consumed, 0 for not consumed.
        Either element is None when the response body lacks the matching key.

    Raises:
        RuntimeError: If communication fails or controller returns error response

    Example:
        >>> # Get consumption status
        >>> global_index, consumption_status = asyncio.run(client.async_get_consumption_status(
        ...     task_name="generate_sequences",
        ...     partition_id="train_0"
        ... ))
        >>> print(f"Global index: {global_index}, Consumption status: {consumption_status}")
    """
    assert socket is not None

    request_msg = ZMQMessage.create(
        request_type=ZMQRequestType.GET_CONSUMPTION,
        sender_id=self.client_id,
        receiver_id=self._controller.id,
        body={"partition_id": partition_id, "task_name": task_name},
    )

    try:
        await socket.send_multipart(request_msg.serialize())
        reply_frames = await socket.recv_multipart(copy=False)
        response_msg = ZMQMessage.deserialize(reply_frames)
        logger.debug(
            f"[{self.client_id}]: Client get consumption response: {response_msg} "
            f"from controller {self._controller.id}"
        )

        if response_msg.request_type != ZMQRequestType.CONSUMPTION_RESPONSE:
            raise RuntimeError(
                f"[{self.client_id}]: Failed to get consumption status from controller {self._controller.id}: "
                f"{response_msg.body.get('message', 'Unknown error')}"
            )

        reply_body = response_msg.body
        return reply_body.get("global_index"), reply_body.get("consumption_status")
    except Exception as exc:
        # Also re-wraps the RuntimeError raised just above.
        raise RuntimeError(f"[{self.client_id}]: Error in get_consumption_status: {str(exc)}") from exc
@with_controller_socket
async def async_get_production_status(
    self,
    data_fields: list[str],
    partition_id: str,
    socket: zmq.asyncio.Socket | None = None,
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
    """Get the production status of specific data fields in a partition.

    Args:
        data_fields: Data fields to check production status for
        partition_id: Partition id to check production status for
        socket: ZMQ async socket for message transmission (injected by decorator)

    Returns:
        Tuple of:
        - Partition global index tensor
        - Production status tensor for the specified fields. 1 for ready, 0 for not ready.
        Either element is None when the response body lacks the matching key.

    Raises:
        RuntimeError: If communication fails or controller returns error response

    Example:
        >>> # Get production status
        >>> global_index, production_status = asyncio.run(client.async_get_production_status(
        ...     data_fields=["input_ids", "attention_mask"],
        ...     partition_id="train_0"
        ... ))
        >>> print(f"Global index: {global_index}, Production status: {production_status}")
    """
    assert socket is not None

    request_msg = ZMQMessage.create(
        request_type=ZMQRequestType.GET_PRODUCTION,
        sender_id=self.client_id,
        receiver_id=self._controller.id,
        body={"partition_id": partition_id, "data_fields": data_fields},
    )

    try:
        await socket.send_multipart(request_msg.serialize())
        reply_frames = await socket.recv_multipart(copy=False)
        response_msg = ZMQMessage.deserialize(reply_frames)
        logger.debug(
            f"[{self.client_id}]: Client get production response: {response_msg} "
            f"from controller {self._controller.id}"
        )

        if response_msg.request_type != ZMQRequestType.PRODUCTION_RESPONSE:
            raise RuntimeError(
                f"[{self.client_id}]: Failed to get production status from controller {self._controller.id}: "
                f"{response_msg.body.get('message', 'Unknown error')}"
            )

        reply_body = response_msg.body
        return reply_body.get("global_index"), reply_body.get("production_status")
    except Exception as exc:
        # Also re-wraps the RuntimeError raised just above.
        raise RuntimeError(f"[{self.client_id}]: Error in get_data_production_status: {str(exc)}") from exc
async def async_check_consumption_status(
    self,
    task_name: str,
    partition_id: str,
) -> bool:
    """Check whether every sample of a partition has been consumed by a task.

    Args:
        task_name: Name of the task to check consumption for
        partition_id: Partition id to check consumption status for

    Returns:
        bool: True if all samples have been consumed by the task. False if
        any sample is unconsumed, or if no status is available (the status is
        None or an empty tensor).

    Raises:
        RuntimeError: If communication fails or controller returns error response

    Example:
        >>> # Check if all samples have been consumed
        >>> is_consumed = asyncio.run(client.async_check_consumption_status(
        ...     task_name="generate_sequences",
        ...     partition_id="train_0"
        ... ))
        >>> print(f"All samples consumed: {is_consumed}")
    """
    _global_index, status = await self.async_get_consumption_status(
        task_name=task_name,
        partition_id=partition_id,
    )

    # No status at all, or a partition without samples, counts as "not consumed".
    has_samples = status is not None and status.numel() > 0
    if not has_samples:
        return False

    all_consumed = torch.all(status == 1)
    return all_consumed.item()
async def async_check_production_status(
    self,
    data_fields: list[str],
    partition_id: str,
) -> bool:
    """Check whether the given fields of every sample in a partition are ready
    (produced) for consumption.

    Args:
        data_fields: Data fields to check production status for
        partition_id: Partition id to check production status for

    Returns:
        bool: True if all samples have been produced and are ready. False if
        any sample is not ready, or if no status is available (the status is
        None or an empty tensor).

    Raises:
        RuntimeError: If communication fails or controller returns error response

    Example:
        >>> # Check if all samples are ready for consumption
        >>> is_ready = asyncio.run(client.async_check_production_status(
        ...     data_fields=["input_ids", "attention_mask"],
        ...     partition_id="train_0"
        ... ))
        >>> print(f"All samples ready: {is_ready}")
    """
    _, production_status = await self.async_get_production_status(
        data_fields=data_fields,
        partition_id=partition_id,
    )

    # torch.all() over an empty tensor is vacuously True, which would report
    # a partition with no samples as "all ready". Treat empty like None,
    # matching async_check_consumption_status.
    if production_status is None or production_status.numel() == 0:
        return False
    return torch.all(production_status == 1).item()
@with_controller_socket
async def async_reset_consumption(
    self,
    partition_id: str,
    task_name: str | None = None,
    socket: zmq.asyncio.Socket | None = None,
) -> bool:
    """Asynchronously reset consumption status for a partition.

    This allows the same data to be re-consumed, useful for debugging scenarios
    where the same rollout data needs to be trained multiple times.

    Args:
        partition_id: Partition id to reset consumption status for
        task_name: Name of the task to reset. If None, resets all tasks.
            The key is left out of the request when it is None.
        socket: ZMQ async socket for message transmission (injected by decorator)

    Returns:
        bool: The ``"success"`` value of the controller's response (False if
        the key is missing). A warning is logged when it is falsy.

    Raises:
        RuntimeError: If communication fails or controller returns error response

    Example:
        >>> # Reset consumption for train task to re-train on same data
        >>> success = asyncio.run(client.async_reset_consumption(
        ...     partition_id="train_0",
        ...     task_name="train"
        ... ))
        >>> print(f"Reset successful: {success}")
    """
    assert socket is not None

    payload = {"partition_id": partition_id}
    if task_name is not None:
        payload["task_name"] = task_name

    request_msg = ZMQMessage.create(
        request_type=ZMQRequestType.RESET_CONSUMPTION,
        sender_id=self.client_id,
        receiver_id=self._controller.id,
        body=payload,
    )

    try:
        await socket.send_multipart(request_msg.serialize())
        reply_frames = await socket.recv_multipart(copy=False)
        response_msg = ZMQMessage.deserialize(reply_frames)
        logger.debug(
            f"[{self.client_id}]: Client reset consumption response: {response_msg} "
            f"from controller {self._controller.id}"
        )

        if response_msg.request_type != ZMQRequestType.RESET_CONSUMPTION_RESPONSE:
            raise RuntimeError(
                f"[{self.client_id}]: Failed to reset consumption from controller {self._controller.id}: "
                f"{response_msg.body.get('message', 'Unknown error')}"
            )

        success = response_msg.body.get("success", False)
        if not success:
            logger.warning(f"[{self.client_id}]: Reset consumption failed: {response_msg.body.get('message')}")
        return success
    except Exception as exc:
        # Also re-wraps the RuntimeError raised just above.
        raise RuntimeError(f"[{self.client_id}]: Error in reset_consumption: {str(exc)}") from exc
@with_controller_socket
async def async_get_partition_list(
    self,
    socket: zmq.asyncio.Socket | None = None,
) -> list[str]:
    """Asynchronously fetch the list of partition ids from the controller.

    Args:
        socket: ZMQ socket (injected by decorator)

    Returns:
        list[str]: The ``"partition_ids"`` entry of the controller's response
        (an empty list if the key is missing).

    Raises:
        RuntimeError: If communication fails or the controller does not
            answer with LIST_PARTITIONS_RESPONSE.

    Example:
        >>> partition_ids = asyncio.run(client.async_get_partition_list())
        >>> print(f"Available partitions: {partition_ids}")
    """
    request_msg = ZMQMessage.create(
        request_type=ZMQRequestType.GET_LIST_PARTITIONS,
        sender_id=self.client_id,
        receiver_id=self._controller.id,
        body={},
    )

    try:
        assert socket is not None
        await socket.send_multipart(request_msg.serialize())
        reply_frames = await socket.recv_multipart(copy=False)
        response_msg = ZMQMessage.deserialize(reply_frames)
        logger.debug(
            f"[{self.client_id}]: Client get partition list response: {response_msg} "
            f"from controller {self._controller.id}"
        )

        if response_msg.request_type != ZMQRequestType.LIST_PARTITIONS_RESPONSE:
            raise RuntimeError(
                f"[{self.client_id}]: Failed to get partition list from controller {self._controller.id}: "
                f"{response_msg.body.get('message', 'Unknown error')}"
            )
        return response_msg.body.get("partition_ids", [])
    except Exception as exc:
        # Also re-wraps the assertion and the RuntimeError raised above.
        raise RuntimeError(f"[{self.client_id}]: Error in get_partition_list: {str(exc)}") from exc
@with_controller_socket
async def async_kv_retrieve_meta(
    self,
    keys: list[str] | str,
    partition_id: str,
    create: bool = False,
    socket: zmq.asyncio.Socket | None = None,
) -> BatchMeta:
    """Asynchronously retrieve BatchMeta by user-defined keys.

    Retrieves metadata for given keys from a specified partition.
    If keys do not exist and `create=True`, they will be automatically registered.

    Args:
        keys: A single key or a non-empty list of keys to retrieve.
        partition_id: The ID of the logical partition to search for keys.
        create: If True, automatically create entries for missing keys.
        socket: ZMQ socket injected by @with_controller_socket.

    Returns:
        BatchMeta: Metadata for the requested keys, or ``BatchMeta.empty()``
        when the response carries no ``"metadata"`` entry.

    Raises:
        ValueError: If ``keys`` is an empty list.
        TypeError: If ``keys`` is neither a string nor a list of strings.
        RuntimeError: If communication fails or the controller does not
            answer with KV_RETRIEVE_META_RESPONSE.
    """
    # Normalize a single key to a one-element list; validate everything else.
    if isinstance(keys, str):
        keys = [keys]
    elif not isinstance(keys, list):
        raise TypeError("Only string or list of strings are allowed as `keys`.")
    elif len(keys) < 1:
        raise ValueError("Received an empty list as keys.")
    elif not all(isinstance(k, str) for k in keys):
        raise TypeError("Not all elements in `keys` are strings.")

    payload = {
        "keys": keys,
        "partition_id": partition_id,
        "create": create,
    }
    request_msg = ZMQMessage.create(
        request_type=ZMQRequestType.KV_RETRIEVE_META,
        sender_id=self.client_id,
        receiver_id=self._controller.id,
        body=payload,
    )

    try:
        assert socket is not None, "Socket must be initialized before use"
        await socket.send_multipart(request_msg.serialize())
        reply_frames = await socket.recv_multipart(copy=False)
        response_msg = ZMQMessage.deserialize(reply_frames)
        logger.debug(
            f"[{self.client_id}] Received KV_RETRIEVE_META response: {response_msg} "
            f"from controller {self._controller.id}"
        )

        if response_msg.request_type != ZMQRequestType.KV_RETRIEVE_META_RESPONSE:
            raise RuntimeError(
                f"[{self.client_id}] Failed to retrieve metadata {response_msg.body.get('message', 'Unknown error')}"
            )
        return response_msg.body.get("metadata", BatchMeta.empty())
    except Exception as exc:
        # Also re-wraps the assertion and the RuntimeError raised above.
        raise RuntimeError(f"[{self.client_id}] Failed in async_kv_retrieve_meta: {exc}") from exc
@with_controller_socket
async def async_kv_retrieve_keys(
    self,
    global_indexes: list[int] | int,
    partition_id: str,
    socket: zmq.asyncio.Socket | None = None,
) -> list[str]:
    """Asynchronously retrieve keys according to global_indexes from the controller.

    Args:
        global_indexes: A single global index or a non-empty list of global
            indexes to retrieve from the controller
        partition_id: The ID of the logical partition to search for global_indexes.
        socket: ZMQ socket (injected by decorator)

    Returns:
        keys: list of keys of the corresponding global_indexes

    Raises:
        ValueError: If `global_indexes` is an empty list
        TypeError: If `global_indexes` is not a list of int or an int
        RuntimeError: If communication fails, the controller returns an error
            response, or the number of returned keys differs from the number
            of requested indexes
    """
    # Normalize a single index to a one-element list; validate everything else.
    if isinstance(global_indexes, int):
        global_indexes = [global_indexes]
    elif not isinstance(global_indexes, list):
        raise TypeError("Only int or list of int are allowed as `global_indexes`.")
    elif len(global_indexes) < 1:
        raise ValueError("Received an empty list as `global_indexes`.")
    elif not all(isinstance(idx, int) for idx in global_indexes):
        raise TypeError("Not all elements in `global_indexes` are int.")

    request_msg = ZMQMessage.create(
        request_type=ZMQRequestType.KV_RETRIEVE_KEYS,
        sender_id=self.client_id,
        receiver_id=self._controller.id,
        body={"global_indexes": global_indexes, "partition_id": partition_id},
    )

    try:
        assert socket is not None
        await socket.send_multipart(request_msg.serialize())
        reply_frames = await socket.recv_multipart(copy=False)
        response_msg = ZMQMessage.deserialize(reply_frames)
        logger.debug(
            f"[{self.client_id}]: Client get kv_retrieve_indexes response: {response_msg} "
            f"from controller {self._controller.id}"
        )

        if response_msg.request_type != ZMQRequestType.KV_RETRIEVE_KEYS_RESPONSE:
            raise RuntimeError(
                f"[{self.client_id}]: Failed to retrieve indexes from controller {self._controller.id}: "
                f"{response_msg.body.get('message', 'Unknown error')}"
            )

        keys = response_msg.body.get("keys", [])
        # A length mismatch means at least one index had no key.
        if len(keys) != len(global_indexes):
            raise RuntimeError("Some global_indexes have no corresponding keys!")
        return keys
    except Exception as exc:
        # Also re-wraps the assertion and the RuntimeErrors raised above.
        raise RuntimeError(f"[{self.client_id}]: Error in kv_retrieve_indexes: {str(exc)}") from exc
@with_controller_socket
async def async_kv_list(
    self,
    partition_id: str | None = None,
    socket: zmq.asyncio.Socket | None = None,
) -> dict[str, dict[str, Any]]:
    """Asynchronously retrieve keys and custom_meta from the controller for one or all partitions.

    Args:
        partition_id: The specific partition_id to query.
            If None (default), returns keys from all partitions.
        socket: ZMQ socket (injected by decorator)

    Returns:
        The ``"partition_info"`` entry of the controller's response (an empty
        dict if the key is missing): a nested dictionary mapping partition
        IDs to their keys and metadata.

        Structure:
        {
            "partition_id": {
                "key_name": {
                    "tag1": <value>,
                    ... (other metadata)
                },
                ...,
            },
            ...
        }

    Raises:
        RuntimeError: If communication fails or the controller does not
            answer with KV_LIST_RESPONSE.
    """
    request_msg = ZMQMessage.create(
        request_type=ZMQRequestType.KV_LIST,
        sender_id=self.client_id,
        receiver_id=self._controller.id,
        body={"partition_id": partition_id},
    )

    try:
        assert socket is not None
        await socket.send_multipart(request_msg.serialize())
        reply_frames = await socket.recv_multipart(copy=False)
        response_msg = ZMQMessage.deserialize(reply_frames)
        logger.debug(
            f"[{self.client_id}]: Client get kv_list response: {response_msg} from controller {self._controller.id}"
        )

        if response_msg.request_type != ZMQRequestType.KV_LIST_RESPONSE:
            raise RuntimeError(
                f"[{self.client_id}]: Failed to list keys from controller {self._controller.id}: "
                f"{response_msg.body.get('message', 'Unknown error')}"
            )
        return response_msg.body.get("partition_info", {})
    except Exception as exc:
        # Also re-wraps the assertion and the RuntimeError raised above.
        raise RuntimeError(f"[{self.client_id}]: Error in kv_list: {str(exc)}") from exc
def close(self) -> None:
    """Release the storage manager, if one was initialized.

    Best-effort cleanup: when the instance has a truthy ``storage_manager``
    that exposes a ``close`` attribute, it is called. Any exception raised
    along the way is logged as a warning instead of being propagated.
    """
    try:
        manager = getattr(self, "storage_manager", None)
        if not manager:
            # Nothing was initialized (or it is falsy) - nothing to release.
            return
        if hasattr(manager, "close"):
            manager.close()
    except Exception as e:
        logger.warning(f"Error closing storage manager: {e}")
class TransferQueueClient(AsyncTransferQueueClient):
"""Synchronous client wrapper for TransferQueue.
Provides synchronous versions of all async methods for convenience.
"""
def __init__(
    self,
    client_id: str,
    controller_info: ZMQServerInfo,
):
    """Initialize the synchronous TransferQueue client.

    After the base-class initialization, a dedicated asyncio event loop is
    created and driven by a daemon thread, and the blocking wrappers around
    the ``async_*`` methods are bound to the instance.

    Args:
        client_id: Unique identifier for this client instance
        controller_info: Single controller ZMQ server information
    """
    super().__init__(client_id, controller_info)

    # The loop has to be stored on the instance before the thread is started,
    # because the thread target (`_start_loop`) reads `self._loop`.
    background_loop = asyncio.new_event_loop()
    self._loop = background_loop

    loop_thread = threading.Thread(target=self._start_loop, daemon=True)
    self._thread = loop_thread
    loop_thread.start()

    self._bind_sync_methods()
def _start_loop(self):
"""Start the synchronous loop."""
asyncio.set_event_loop(self._loop)
self._loop.run_forever()
def _bind_sync_methods(
self,
):
"""Convert and bind synchronous methods."""
def _run(coro):
future = asyncio.run_coroutine_threadsafe(coro, self._loop)
return future.result()
def _make_sync(async_method):
def wrapper(*args, **kwargs):
return _run(async_method(*args, **kwargs))
return wrapper
self._put = _make_sync(self.async_put)
self._get_meta = _make_sync(self.async_get_meta)
self._get_data = _make_sync(self.async_get_data)
self._clear_partition = _make_sync(self.async_clear_partition)
self._clear_samples = _make_sync(self.async_clear_samples)
self._get_consumption_status = _make_sync(self.async_get_consumption_status)
self._get_production_status = _make_sync(self.async_get_production_status)
self._check_consumption_status = _make_sync(self.async_check_consumption_status)
self._check_production_status = _make_sync(self.async_check_production_status)
self._get_partition_list = _make_sync(self.async_get_partition_list)
self._set_custom_meta = _make_sync(self.async_set_custom_meta)
self._reset_consumption = _make_sync(self.async_reset_consumption)
self._kv_retrieve_meta = _make_sync(self.async_kv_retrieve_meta)
self._kv_retrieve_keys = _make_sync(self.async_kv_retrieve_keys)
self._kv_list = _make_sync(self.async_kv_list)
def get_meta(
    self,
    data_fields: list[str],
    batch_size: int,
    partition_id: str,
    mode: str = "fetch",
    task_name: str | None = None,
    sampling_config: dict[str, Any] | None = None,
) -> BatchMeta:
    """Blocking variant of ``async_get_meta``: fetch batch metadata from the controller.

    The arguments are forwarded by keyword to the ``_get_meta`` wrapper bound
    in ``_bind_sync_methods``, which runs the coroutine on the client's
    background event loop and waits for its result.

    Args:
        data_fields: List of data field names to retrieve metadata for
        batch_size: Number of samples to request in the batch
        partition_id: Current data partition id
        mode: Data fetch mode. Options:
            - 'fetch': Get ready data only
            - 'force_fetch': Get data regardless of readiness (may return unready samples)
            - 'insert': Internal usage - should not be used by users
        task_name: Optional task name associated with the request
        sampling_config: Optional sampling configuration for custom samplers.

    Returns:
        BatchMeta: Metadata object containing data structure, sample information, and readiness status

    Raises:
        RuntimeError: If communication fails or controller returns error response

    Example:
        >>> # Example 1: Basic fetch metadata
        >>> batch_meta = client.get_meta(
        ...     data_fields=["input_ids", "attention_mask"],
        ...     batch_size=4,
        ...     partition_id="train_0",
        ...     mode="fetch",
        ...     task_name="generate_sequences"
        ... )
        >>> print(batch_meta.is_ready)  # True if all samples ready
        >>>
        >>> # Example 2: Fetch with self-defined samplers (using GRPOGroupNSampler as an example)
        >>> batch_meta = client.get_meta(
        ...     data_fields=["input_ids", "attention_mask"],
        ...     batch_size=8,
        ...     partition_id="train_0",
        ...     mode="fetch",
        ...     task_name="generate_sequences",
        ...     sampling_config={"n_samples_per_prompt": 4}
        ... )
        >>> print(batch_meta.is_ready)  # True if all samples ready
        >>>
        >>> # Example 3: Force fetch metadata (bypass production status check and Sampler,
        >>> # so may include unready and already-consumed samples. No filtering by consumption
        >>> # status is applied.) data_fields, batch_size and partition_id are required by
        >>> # this method's signature, so they are passed here as well.
        >>> batch_meta = client.get_meta(
        ...     data_fields=["input_ids", "attention_mask"],
        ...     batch_size=4,
        ...     partition_id="train_0",
        ...     mode="force_fetch",
        ... )
        >>> print(batch_meta.is_ready)  # May be False if some samples not ready
    """
    request = {
        "data_fields": data_fields,
        "batch_size": batch_size,
        "partition_id": partition_id,
        "mode": mode,
        "task_name": task_name,
        "sampling_config": sampling_config,
    }
    return self._get_meta(**request)
def set_custom_meta(self, metadata: BatchMeta) -> None:
    """Blocking variant of ``async_set_custom_meta``: push custom metadata to the controller.

    Sends the per-sample custom metadata (custom_meta) carried by ``metadata``
    to the controller. The custom_meta is stored in the controller and can be
    retrieved along with the BatchMeta in subsequent get_meta calls.

    Args:
        metadata: BatchMeta containing the samples and their custom metadata to store.
            The custom_meta should be set using BatchMeta.update_custom_meta()
            before calling this method.

    Raises:
        RuntimeError: If communication fails or controller returns error response

    Example:
        >>> # Create batch with custom metadata
        >>> batch_meta = client.get_meta(data_fields=["input_ids"], batch_size=2, ...)
        >>> batch_meta.update_custom_meta([{"score": 0.9}, {"score": 0.8}])
        >>> client.set_custom_meta(batch_meta)
    """
    send_custom_meta = self._set_custom_meta
    return send_custom_meta(metadata=metadata)
def put(
    self,
    data: TensorDict,
    metadata: BatchMeta | None = None,
    partition_id: str | None = None,
    data_parser: Callable[[Any], Any] | None = None,
) -> BatchMeta:
    """Blocking variant of ``async_put``: write data to storage units based on metadata.

    If metadata is not provided, it will be created automatically using insert
    mode with the provided data fields and partition_id. During put, the
    custom_meta in metadata will update the corresponding custom_meta in the
    TransferQueue Controller.

    Note:
        When using multiple workers for distributed execution, there may be data
        ordering inconsistencies between workers during put operations.

    Args:
        data: Data to write as TensorDict
        metadata: Records the metadata of a batch of data samples, containing index and
            storage unit information. If None, metadata will be auto-generated.
        partition_id: Target data partition id (required if metadata is not provided)
        data_parser: Optional callable to parse reference data (e.g., URLs) into real
            content. The input is a slice of the `data` parameter, in plain
            dict format (not TensorDict), mapping field_name -> batched values.
            For a regular tensor column the value is a batched tensor; for
            nested tensors (jagged or strided) and NonTensorStack columns
            the values are extracted into a list. It must modify values
            in-place based on the original keys; do not add or remove keys.
            The number of elements per column must also remain unchanged.
            Do not change the inner order of values within each column.
            Only supported by SimpleStorage.

    Returns:
        BatchMeta: The metadata used for the put operation (currently returns the input metadata or
            auto-retrieved metadata; will be updated in a future version to reflect the post-put state)

    Raises:
        ValueError: If metadata is None or empty, or if partition_id is None when metadata is not provided
        RuntimeError: If storage operation fails

    Example:
        >>> batch_size = 4
        >>> seq_len = 16
        >>> current_partition_id = "train_0"
        >>> # Example 1: Normal usage with existing metadata
        >>> batch_meta = client.get_meta(
        ...     data_fields=["prompts", "attention_mask"],
        ...     batch_size=batch_size,
        ...     partition_id=current_partition_id,
        ...     mode="fetch",
        ...     task_name="generate_sequences",
        ... )
        >>> batch = client.get_data(batch_meta)
        >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)})
        >>> client.put(data=output, metadata=batch_meta)
        >>>
        >>> # Example 2: Initial data insertion without pre-existing metadata
        >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id!
        >>> # Please make sure the corresponding partition_id is empty before calling put()
        >>> # without metadata.
        >>> # Currently all the data of the corresponding partition id must be put at once. Repeat the
        >>> # initial data with repeat_interleave if n_samples > 1 before calling put().
        >>> original_prompts = torch.randn(batch_size, seq_len)
        >>> n_samples = 4
        >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0)
        >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated})
        >>> # This will create metadata in "insert" mode internally.
        >>> metadata = client.put(data=prompts_repeated_batch, partition_id=current_partition_id)
    """
    put_kwargs = {
        "data": data,
        "metadata": metadata,
        "partition_id": partition_id,
        "data_parser": data_parser,
    }
    return self._put(**put_kwargs)
def get_data(self, metadata: BatchMeta) -> TensorDict:
    """Blocking variant of ``async_get_data``: fetch data from storage units into a TensorDict.

    Args:
        metadata: Batch metadata containing data location information and global indexes

    Returns:
        TensorDict containing:
            - Requested data fields (e.g., "prompts", "attention_mask")

    Example:
        >>> batch_meta = client.get_meta(
        ...     data_fields=["prompts", "attention_mask"],
        ...     batch_size=4,
        ...     partition_id="train_0",
        ...     mode="fetch",
        ...     task_name="generate_sequences",
        ... )
        >>> batch = client.get_data(batch_meta)
        >>> print(batch)
        >>> # TensorDict with fields "prompts", "attention_mask", and sample order matching
        >>> # metadata global_indexes
    """
    fetch = self._get_data
    return fetch(metadata=metadata)
def clear_partition(self, partition_id: str):
    """Blocking variant of ``async_clear_partition``.

    Clears the whole partition from all storage units and the controller.

    Args:
        partition_id: The partition id to clear data for

    Returns:
        Whatever the underlying ``_clear_partition`` wrapper returns.

    Raises:
        RuntimeError: If clear operation fails
    """
    outcome = self._clear_partition(partition_id=partition_id)
    return outcome
def clear_samples(self, metadata: BatchMeta):
    """Blocking variant of ``async_clear_samples``.

    Clears specific samples from all storage units and the controller.

    Args:
        metadata: The BatchMeta of the corresponding data to be cleared

    Returns:
        Whatever the underlying ``_clear_samples`` wrapper returns.

    Raises:
        RuntimeError: If clear operation fails
    """
    outcome = self._clear_samples(metadata=metadata)
    return outcome
def get_consumption_status(
self,
task_name: str,
partition_id: str,
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
"""Synchronously get consumption status for a specific task and partition.
Args:
task_name: Name of the task to check consumption for
partition_id: Partition id to check consumption status for
Returns:
Tuple of:
- Partition global index tensor
- Consumption status tensor for the specified task. 1 for consumed, 0 for not consumed.
Raises:
RuntimeError: If communication fails or controller returns error response
Example:
>>> global_index, consumption_status = client.get_consumption_status(
... task_name="generate_sequences",
... partition_id="train_0"
... )
>>> print(f"Global index: {global_index}, Consumption status: {consumption_status}")
"""
return self._get_consumption_status(task_name, partition_id)
def get_production_status(
self,
data_fields: list[str],
partition_id: str,
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
"""Synchronously get production status for specific data fields and partition.
Args:
data_fields: Data fields to check production status for
partition_id: Partition id to check production status for
Returns:
Tuple of:
- Partition global index tensor
- Production status tensor for the specified fields. 1 for ready, 0 for not ready.
Raises:
RuntimeError: If communication fails or controller returns error response
Example:
>>> global_index, production_status = client.get_production_status(
... data_fields=["input_ids", "attention_mask"],
... partition_id="train_0"
... )
>>> print(f"Global index: {global_index}, Production status: {production_status}")
"""
return self._get_production_status(data_fields=data_fields, partition_id=partition_id)
def check_consumption_status(self, task_name: str, partition_id: str) -> bool:
    """Blocking variant of ``async_check_consumption_status``.

    Checks whether all samples of a partition have been consumed by a specific task.

    Args:
        task_name: Name of the task to check consumption for
        partition_id: Partition id to check consumption status for

    Returns:
        bool: True if all samples have been consumed by the task, False otherwise

    Raises:
        RuntimeError: If communication fails or controller returns error response

    Example:
        >>> # Check if all samples have been consumed
        >>> is_consumed = client.check_consumption_status(
        ...     task_name="generate_sequences",
        ...     partition_id="train_0"
        ... )
        >>> print(f"All samples consumed: {is_consumed}")
    """
    query = {"task_name": task_name, "partition_id": partition_id}
    return self._check_consumption_status(**query)
def check_production_status(self, data_fields: list[str], partition_id: str) -> bool:
    """Blocking variant of ``async_check_production_status``.

    Checks whether all samples of a partition are ready (produced) for consumption.

    Args:
        data_fields: Data fields to check production status for
        partition_id: Partition id to check production status for

    Returns:
        bool: True if all samples have been produced and ready, False otherwise

    Raises:
        RuntimeError: If communication fails or controller returns error response

    Example:
        >>> # Check if all samples are ready for consumption
        >>> is_ready = client.check_production_status(
        ...     data_fields=["input_ids", "attention_mask"],
        ...     partition_id="train_0"
        ... )
        >>> print(f"All samples ready: {is_ready}")
    """
    query = {"data_fields": data_fields, "partition_id": partition_id}
    return self._check_production_status(**query)
def reset_consumption(self, partition_id: str, task_name: str | None = None) -> bool:
"""Synchronously reset consumption status for a partition.
This allows the same data to be re-consumed, useful for debugging scenarios
where the same rollout data needs to be trained multiple times.
Args:
partition_id: Partition id to reset consumption status for
task_name: Name of the task to reset. If None, resets all tasks.
Returns:
bool: True if reset was successful, False otherwise
Raises:
RuntimeError: If communication fails or controller returns error response
Example:
>>> # Reset consumption for train task to re-train on same data
>>> success = client.reset_consumption(
... partition_id="train_0",
... task_name="train"
... )
>>> print(f"Reset successful: {success}")
"""
return self._reset_consumption(partition_id, task_name)
def get_partition_list(
    self,
) -> list[str]:
    """Blocking variant of ``async_get_partition_list``.

    Returns:
        list[str]: List of partition ids managed by the controller

    Example:
        >>> partition_ids = client.get_partition_list()
        >>> print(f"Available partitions: {partition_ids}")
    """
    partition_ids = self._get_partition_list()
    return partition_ids
def kv_retrieve_meta(
    self,
    keys: list[str] | str,
    partition_id: str,
    create: bool = False,
) -> BatchMeta:
    """Blocking variant of ``async_kv_retrieve_meta``: look up BatchMeta by user-defined keys.

    Retrieves metadata for the given keys from a specified partition.
    If keys do not exist and `create=True`, they will be automatically registered.

    Args:
        keys: A key or list of keys to retrieve from the controller.
        partition_id: Logical partition to query.
        create: If True, automatically create entries for non-existent keys.

    Returns:
        BatchMeta: Metadata for the requested keys.

    Raises:
        TypeError: If `keys` is not a list of string or a string
    """
    lookup = {"keys": keys, "partition_id": partition_id, "create": create}
    return self._kv_retrieve_meta(**lookup)
def kv_retrieve_keys(
self,
global_indexes: list[int] | int,
partition_id: str,
) -> BatchMeta:
"""Synchronously retrieve keys according to global_indexes from the controller.
Args:
global_indexes: List of global_indexes to retrieve from the controller
partition_id: The ID of the logical partition to search for global_indexes.
Returns:
keys: list of keys of the corresponding global_indexes
Raises:
TypeError: If `global_indexes` is not a list of int or an int
RuntimeError: If some indexes in `global_indexes` do not have corresponding keys
"""
return self._kv_retrieve_keys(global_indexes=global_indexes, partition_id=partition_id)
def kv_list(
self,
partition_id: str | None = None,
) -> dict[str, dict[str, Any]]:
"""Synchronously retrieve keys and custom_meta from the controller for one or all partitions.
Args:
partition_id: The specific partition_id to query.
If None (default), returns keys from all partitions.
socket: ZMQ socket (injected by decorator)
Returns:
A nested dictionary mapping partition IDs to their keys and metadata.
Structure:
{
"partition_id": {
"key_name": {
"tag1": <value>,
... (other metadata)
},
...,
},
...
}
"""
return self._kv_list(partition_id=partition_id)
def close(self) -> None:
    """Close the client and cleanup resources including event loop and thread.

    Stops the background event loop, waits up to 5 seconds for its thread to
    finish, closes the loop and finally delegates to the base-class ``close``.

    The method is safe to call more than once: a loop that is already closed
    is neither stopped nor closed again, and the base-class cleanup runs even
    if stopping or joining raises.
    """
    loop = getattr(self, "_loop", None)
    thread = getattr(self, "_thread", None)

    try:
        # call_soon_threadsafe raises RuntimeError on a closed loop, so only
        # ask a still-open loop to stop.
        if loop is not None and not loop.is_closed():
            loop.call_soon_threadsafe(loop.stop)

        if thread is not None:
            thread.join(timeout=5.0)
            if thread.is_alive():
                logger.warning(f"[{self.client_id}]: Background thread did not stop within timeout")

        if loop is not None and not loop.is_closed():
            try:
                # Raises if the loop is still running (e.g. the join above timed
                # out); that is reported as a warning rather than propagated.
                loop.close()
            except Exception as e:
                logger.warning(f"[{self.client_id}]: Error closing event loop: {e}")
    finally:
        super().close()