# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2025 The TransferQueue Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import os
import threading
from typing import Any, Callable

import torch
import zmq
import zmq.asyncio
from tensordict import TensorDict

from transfer_queue.metadata import BatchMeta
from transfer_queue.storage import StorageManagerFactory
from transfer_queue.utils.common import limit_pytorch_auto_parallel_threads
from transfer_queue.utils.logging_utils import get_logger
from transfer_queue.utils.zmq_utils import (
    ZMQMessage,
    ZMQRequestType,
    ZMQServerInfo,
    with_zmq_socket,
)

logger = get_logger(__name__)

TQ_NUM_THREADS = int(os.environ.get("TQ_NUM_THREADS", 8))

# Pre-bound decorator for controller socket operations.
with_controller_socket = with_zmq_socket(
    "request_handle_socket",
    get_identity=lambda self: self.client_id,
    get_peer=lambda self, target: self._controller,
)


class AsyncTransferQueueClient:
    """Asynchronous client for interacting with TransferQueue controller and storage systems.

    This client provides async methods for data transfer operations including getting metadata,
    reading data from storage, writing data to storage, and clearing data.
    """

    def __init__(
        self,
        client_id: str,
        controller_info: ZMQServerInfo,
    ):
        """Initialize the asynchronous TransferQueue client.

        Args:
            client_id: Unique identifier for this client instance
            controller_info: Single controller ZMQ server information
        """
        if controller_info is None:
            raise ValueError("controller_info cannot be None")
        if not isinstance(controller_info, ZMQServerInfo):
            raise TypeError(f"controller_info must be ZMQServerInfo, got {type(controller_info)}")
        self.client_id = client_id
        self._controller: ZMQServerInfo = controller_info
        logger.info(f"[{self.client_id}]: Registered Controller server {controller_info.id} at {controller_info.ip}")

    def initialize_storage_manager(
        self,
        manager_type: str,
        config: dict[str, Any],
    ):
        """Initialize the storage manager.

        Args:
            manager_type: Type of storage manager to create. Supported types include:
                          AsyncSimpleStorageManager, KVStorageManager (under development), etc.
            config: Configuration dictionary for the storage manager.
                    For AsyncSimpleStorageManager, must contain the following required keys:
                    - zmq_info: ZMQ server information about the storage units

        """
        self.storage_manager = StorageManagerFactory.create(
            manager_type, controller_info=self._controller, config=config
        )

    # ==================== Basic API ====================
    @with_controller_socket
    async def async_get_meta(
        self,
        data_fields: list[str],
        batch_size: int,
        partition_id: str,
        mode: str = "fetch",
        task_name: str | None = None,
        sampling_config: dict[str, Any] | None = None,
        socket: zmq.asyncio.Socket | None = None,
    ) -> BatchMeta:
        """Asynchronously fetch data metadata from the controller via ZMQ.

        Args:
            data_fields: List of data field names to retrieve metadata for
            batch_size: Number of samples to request in the batch
            partition_id: Current data partition id
            mode: Data fetch mode. Options:
                - 'fetch': Get ready data only
                - 'force_fetch': Get data regardless of readiness (may return unready samples)
                - 'insert': Internal usage - should not be used by users
            task_name: Optional task name associated with the request
            sampling_config: Optional sampling configuration for custom samplers.
            socket: ZMQ async socket for message transmission (injected by decorator)

        Returns:
            BatchMeta: Metadata object containing data structure, sample information, and readiness status

        Raises:
            RuntimeError: If communication fails or controller returns error response

        Example:
            >>> # Example 1: Basic fetch metadata
            >>> batch_meta = asyncio.run(client.async_get_meta(
            ...     data_fields=["input_ids", "attention_mask"],
            ...     batch_size=4,
            ...     partition_id="train_0",
            ...     mode="fetch",
            ...     task_name="generate_sequences"
            ... ))
            >>> print(batch_meta.is_ready)  # True if all samples ready
            >>>
            >>> # Example 2: Fetch with self-defined samplers (using GRPOGroupNSampler as an example)
            >>> batch_meta = asyncio.run(client.async_get_meta(
            ...     data_fields=["input_ids", "attention_mask"],
            ...     batch_size=8,
            ...     partition_id="train_0",
            ...     mode="fetch",
            ...     task_name="generate_sequences",
            ... ))
            >>> print(batch_meta.is_ready)  # True if all samples ready
            >>>
            >>> # Example 3: Force fetch metadata (bypass production status check and Sampler,
            >>> # so may include unready and already-consumed samples. No filtering by consumption status is applied.)
            >>> batch_meta = asyncio.run(client.async_get_meta(
            ...     partition_id="train_0",   # optional
            ...     mode="force_fetch",
            ... ))
            >>> print(batch_meta.is_ready)  # May be False if some samples not ready
        """
        assert socket is not None
        request_msg = ZMQMessage.create(
            request_type=ZMQRequestType.GET_META,  # type: ignore[arg-type]
            sender_id=self.client_id,
            receiver_id=self._controller.id,
            body={
                "data_fields": data_fields,
                "batch_size": batch_size,
                "partition_id": partition_id,
                "mode": mode,
                "task_name": task_name,
                "sampling_config": sampling_config,
            },
        )

        await socket.send_multipart(request_msg.serialize())
        response_serialized = await socket.recv_multipart(copy=False)
        response_msg = ZMQMessage.deserialize(response_serialized)
        logger.debug(
            f"[{self.client_id}]: Client get_meta response: {response_msg} from controller {self._controller.id}"
        )

        if response_msg.request_type == ZMQRequestType.GET_META_RESPONSE:
            return response_msg.body["metadata"]
        else:
            raise RuntimeError(
                f"[{self.client_id}]: Failed to get metadata from controller {self._controller.id}: "
                f"{response_msg.body.get('message', 'Unknown error')}"
            )

    @with_controller_socket
    async def async_set_custom_meta(
        self,
        metadata: BatchMeta,
        socket: zmq.asyncio.Socket | None = None,
    ) -> None:
        """
        Asynchronously send custom metadata to the controller.

        This method sends per-sample custom metadata (custom_meta) to the controller.
        The custom_meta is stored in the controller and can be retrieved along with
        the BatchMeta in subsequent get_meta calls.

        Args:
            metadata: BatchMeta containing the samples and their custom metadata to store.
                     The custom_meta should be set using BatchMeta.update_custom_meta()
                     before calling this method.
            socket: ZMQ async socket for message transmission (injected by decorator)

        Raises:
            RuntimeError: If communication fails or controller returns error response

        Example:
            >>> # Create batch with custom metadata
            >>> batch_meta = client.get_meta(data_fields=["input_ids"], batch_size=4, ...)
            >>> batch_meta.update_custom_meta([{"score": 0.9}, {"score": 0.8}])
            >>> asyncio.run(client.async_set_custom_meta(batch_meta))
        """
        assert socket is not None

        if not self._controller:
            raise RuntimeError("No controller registered")

        global_indexes = metadata.global_indexes
        custom_meta = metadata.get_all_custom_meta()

        if len(global_indexes) == 0 or len(custom_meta) == 0:
            logger.debug(f"[{self.client_id}]: Empty BatchMeta or custom_meta provided. No action taken.")
            return

        # chunk metadata according to partition_ids
        metadata_chunks = metadata.chunk_by_partition()

        # Build partition_custom_meta in format: {partition_id: {global_index: {meta1:xxx, meta2:xxx}}}
        partition_custom_meta: dict[str, dict[int, dict]] = {pid: {} for pid in set(metadata.partition_ids)}

        for meta in metadata_chunks:
            custom_meta = meta.get_all_custom_meta()
            partition_custom_meta[meta.partition_ids[0]].update(
                {meta.global_indexes[i]: custom_meta[i] for i in range(len(custom_meta))}
            )

        request_msg = ZMQMessage.create(
            request_type=ZMQRequestType.SET_CUSTOM_META,  # type: ignore[arg-type]
            sender_id=self.client_id,
            receiver_id=self._controller.id,
            body={
                "partition_custom_meta": partition_custom_meta,
            },
        )

        await socket.send_multipart(request_msg.serialize())
        response_serialized = await socket.recv_multipart(copy=False)
        response_msg = ZMQMessage.deserialize(response_serialized)
        logger.debug(
            f"[{self.client_id}]: Client set_custom_meta response: {response_msg} from controller {self._controller.id}"
        )

        if response_msg.request_type != ZMQRequestType.SET_CUSTOM_META_RESPONSE:
            raise RuntimeError(
                f"[{self.client_id}]: Failed to set custom metadata to controller {self._controller.id}: "
                f"{response_msg.body.get('message', 'Unknown error')}"
            )

    async def async_put(
        self,
        data: TensorDict,
        metadata: BatchMeta | None = None,
        partition_id: str | None = None,
        data_parser: Callable[[Any], Any] | None = None,
    ) -> BatchMeta:
        """Asynchronously write data to storage units based on metadata.

        If metadata is not provided, it will be created automatically using insert mode
        with the provided data fields and partition_id.

        During put, the custom_meta in metadata will update the corresponding custom_meta in
        TransferQueue Controller.

        Note:
            When using multiple workers for distributed execution, there may be data
            ordering inconsistencies between workers during put operations.

        Args:
            data: Data to write as TensorDict
            metadata: Records the metadata of a batch of data samples, containing index and
                      storage unit information. If None, metadata will be auto-generated.
            partition_id: Target data partition id (required if metadata is not provided)
            data_parser: Optional callable to parse reference data (e.g., URLs) into real
                         content. The input is a slice of the `data` parameter, in plain
                         dict format (not TensorDict), mapping field_name -> batched values.
                         For a regular tensor column the value is a batched tensor; for
                         nested tensors (jagged or strided) and NonTensorStack columns
                         the values are extracted into a list. It must modify values
                         in-place based on the original keys; do not add or remove keys.
                         The number of elements per column must also remain unchanged.
                         Do not change the inner order of values within each column.
                         Only supported by SimpleStorage.

        Returns:
            BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved
                       metadata; will be updated in a future version to reflect the post-put state)

        Raises:
            ValueError: If metadata is None or empty, or if partition_id is None when metadata is not provided
            RuntimeError: If storage operation fails

        Example:
            >>> batch_size = 4
            >>> seq_len = 16
            >>> current_partition_id = "train_0"
            >>> # Example 1: Normal usage with existing metadata
            >>> batch_meta = asyncio.run(client.async_get_meta(
            ...     data_fields=["prompts", "attention_mask"],
            ...     batch_size=batch_size,
            ...     partition_id=current_partition_id,
            ...     mode="fetch",
            ...     task_name="generate_sequences",
            ... ))
            >>> batch = asyncio.run(client.async_get_data(batch_meta))
            >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)})
            >>> asyncio.run(client.async_put(data=output, metadata=batch_meta))
            >>>
            >>> # Example 2: Initial data insertion without pre-existing metadata
            >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id!
            >>> # Please make sure the corresponding partition_id is empty before calling the async_put()
            >>> # without metadata.
            >>> # Now we only support put all the data of the corresponding partition id in once. You should repeat with
            >>> # interleave the initial data if n_sample > 1 before calling the async_put().
            >>> original_prompts = torch.randn(batch_size, seq_len)
            >>> n_samples = 4
            >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0)
            >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated})
            >>> # This will create metadata in "insert" mode internally.
            >>> metadata = asyncio.run(client.async_put(data=prompts_repeated_batch, partition_id=current_partition_id))
        """

        if not hasattr(self, "storage_manager") or self.storage_manager is None:
            raise RuntimeError(
                f"[{self.client_id}]: Storage manager not initialized. "
                "Call initialize_storage_manager() before performing storage operations."
            )

        for field_name, field_data in data.items():
            if isinstance(field_data, torch.Tensor) and field_data.ndim == 1:
                logger.info(
                    f"[{self.client_id}]: Data field '{field_name}' is a tensor with only one dimension. "
                    f"You may receive 2D tensors in key-value based backend."
                )

        if metadata is None:
            if partition_id is None:
                raise ValueError("partition_id must be provided if metadata is not given")

            metadata = await self.async_get_meta(
                data_fields=list(data.keys()),
                batch_size=data.batch_size[0],
                partition_id=partition_id,
                mode="insert",
            )

        if not metadata or metadata.size == 0:
            raise ValueError("metadata cannot be none or empty")

        with limit_pytorch_auto_parallel_threads(
            target_num_threads=TQ_NUM_THREADS, info=f"[{self.client_id}] async_put"
        ):
            await self.storage_manager.put_data(data, metadata, data_parser=data_parser)

        await self.async_set_custom_meta(metadata)

        logger.debug(
            f"[{self.client_id}]: partition {partition_id} put {metadata.size} samples to storage units successfully."
        )

        # update metadata after put
        metadata = metadata.add_fields(data)

        return metadata

    async def async_get_data(self, metadata: BatchMeta) -> TensorDict:
        """Asynchronously fetch data from storage units and organize into TensorDict.

        Args:
            metadata: Batch metadata containing data location information and global indexes

        Returns:
            TensorDict containing:
                - Requested data fields (e.g., "prompts", "attention_mask")

        Example:
            >>> batch_meta = asyncio.run(client.async_get_meta(
            ...     data_fields=["prompts", "attention_mask"],
            ...     batch_size=4,
            ...     partition_id="train_0",
            ...     mode="fetch",
            ...     task_name="generate_sequences",
            ... ))
            >>> batch = asyncio.run(client.async_get_data(batch_meta))
            >>> print(batch)
            >>> # TensorDict with fields "prompts", "attention_mask", and sample order matching metadata global_indexes
        """

        if not hasattr(self, "storage_manager") or self.storage_manager is None:
            raise RuntimeError(
                f"[{self.client_id}]: Storage manager not initialized. "
                "Call initialize_storage_manager() before performing storage operations."
            )

        if not metadata or metadata.size == 0 or len(metadata.field_names) == 0:
            logger.warning(f"[{self.client_id}]: Empty BatchMeta provided to get_data. Returning empty TensorDict.")
            return TensorDict({}, batch_size=0)

        with limit_pytorch_auto_parallel_threads(
            target_num_threads=TQ_NUM_THREADS, info=f"[{self.client_id}] async_get_data"
        ):
            results = await self.storage_manager.get_data(metadata)

        logger.debug(f"[{self.client_id}]: get_data with {metadata.size} samples successfully.")

        return results

    async def async_clear_partition(self, partition_id: str):
        """Asynchronously clear the whole partition from all storage units and the controller.

        Args:
            partition_id: The partition id to clear data for

        Raises:
            RuntimeError: If clear operation fails
        """
        try:
            if not hasattr(self, "storage_manager") or self.storage_manager is None:
                raise RuntimeError(
                    f"[{self.client_id}]: Storage manager not initialized. "
                    "Call initialize_storage_manager() before performing storage operations."
                )

            if not self._controller:
                raise RuntimeError("No controller registered")

            metadata = await self._get_partition_meta(partition_id)

            if not metadata:
                logger.warning(f"Try to clear an non-exist partition {partition_id}. No action will be taken.")
                return

            # Clear the controller metadata
            await self._clear_partition_in_controller(partition_id)

            # Clear storage unit data
            await self.storage_manager.clear_data(metadata)

            logger.debug(f"[{self.client_id}]: Clear operation for partition_id {partition_id} completed.")
        except Exception as e:
            raise RuntimeError(f"Error in clear operation: {str(e)}") from e

    async def async_clear_samples(self, metadata: BatchMeta):
        """Asynchronously clear specific samples from all storage units and the controller.

        Args:
            metadata: The BatchMeta of the corresponding data to be cleared

        Raises:
            RuntimeError: If clear operation fails
        """
        try:
            if not hasattr(self, "storage_manager") or self.storage_manager is None:
                raise RuntimeError(
                    f"[{self.client_id}]: Storage manager not initialized. "
                    "Call initialize_storage_manager() before performing storage operations."
                )

            if metadata.size == 0:
                logger.warning(f"[{self.client_id}]: Empty BatchMeta provided to clear_samples. No action taken.")
                return

            if not self._controller:
                raise RuntimeError("No controller registered")

            # Clear the controller metadata
            await self._clear_meta_in_controller(metadata)

            # Clear storage unit data
            await self.storage_manager.clear_data(metadata)

            logger.debug(f"[{self.client_id}]: Clear operation for batch {metadata} completed.")
        except Exception as e:
            raise RuntimeError(f"Error in clear_samples operation: {str(e)}") from e

    @with_controller_socket
    async def _clear_meta_in_controller(self, metadata: BatchMeta, socket=None):
        """Clear metadata in the controller.

        Args:
            metadata: The BatchMeta of the corresponding data to be cleared
            socket: ZMQ socket (injected by decorator)

        Raises:
            RuntimeError: If clear operation fails
        """

        request_msg = ZMQMessage.create(
            request_type=ZMQRequestType.CLEAR_META,  # type: ignore[arg-type]
            sender_id=self.client_id,
            receiver_id=self._controller.id,
            body={"global_indexes": metadata.global_indexes, "partition_ids": metadata.partition_ids},
        )

        await socket.send_multipart(request_msg.serialize())
        response_serialized = await socket.recv_multipart(copy=False)
        response_msg = ZMQMessage.deserialize(response_serialized)

        if response_msg.request_type != ZMQRequestType.CLEAR_META_RESPONSE:
            raise RuntimeError("Failed to clear samples metadata in controller.")

    @with_controller_socket
    async def _get_partition_meta(self, partition_id: str, socket=None) -> BatchMeta:
        """Get metadata required for the whole partition from controller.

        Args:
            partition_id: Partition id to get partition metadata for
            socket: ZMQ socket (injected by decorator)

        Returns:
            BatchMeta: Records the metadata of a batch of data samples.

        Raises:
            RuntimeError: If controller returns error response
        """
        request_msg = ZMQMessage.create(
            request_type=ZMQRequestType.GET_PARTITION_META,  # type: ignore[arg-type]
            sender_id=self.client_id,
            receiver_id=self._controller.id,
            body={"partition_id": partition_id},
        )

        await socket.send_multipart(request_msg.serialize())
        response_serialized = await socket.recv_multipart(copy=False)
        response_msg = ZMQMessage.deserialize(response_serialized)

        if response_msg.request_type != ZMQRequestType.GET_PARTITION_META_RESPONSE:
            raise RuntimeError("Failed to get metadata for clear operation.")

        return response_msg.body["metadata"]

    @with_controller_socket
    async def _clear_partition_in_controller(self, partition_id, socket=None):
        """Clear the whole partition in the controller.

        Args:
            partition_id: Partition id to clear metadata for
            socket: ZMQ socket (injected by decorator)

        Raises:
            RuntimeError: If clear operation fails
        """

        request_msg = ZMQMessage.create(
            request_type=ZMQRequestType.CLEAR_PARTITION,
            sender_id=self.client_id,
            receiver_id=self._controller.id,
            body={"partition_id": partition_id},
        )

        await socket.send_multipart(request_msg.serialize())
        response_serialized = await socket.recv_multipart(copy=False)
        response_msg = ZMQMessage.deserialize(response_serialized)

        if response_msg.request_type != ZMQRequestType.CLEAR_PARTITION_RESPONSE:
            raise RuntimeError(f"Failed to clear partition {partition_id} in controller.")

    # ==================== Status Query API ====================
    @with_controller_socket
    async def async_get_consumption_status(
        self,
        task_name: str,
        partition_id: str,
        socket: zmq.asyncio.Socket | None = None,
    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
        """Get consumption status for current partition in a specific task.

        Args:
            task_name: Name of the task to check consumption for
            partition_id: Partition id to check consumption status for
            socket: ZMQ async socket for message transmission (injected by decorator)

        Returns:
            Tuple of:
            - Partition global index tensor
            - Consumption status tensor for the specified task. 1 for consumed, 0 for not consumed.

        Raises:
            RuntimeError: If communication fails or controller returns error response

        Example:
            >>> # Get consumption status
            >>> global_index, consumption_status = asyncio.run(client.async_get_consumption_status(
            ...     task_name="generate_sequences",
            ...     partition_id="train_0"
            ... ))
            >>> print(f"Global index: {global_index}, Consumption status: {consumption_status}")
        """

        assert socket is not None
        request_msg = ZMQMessage.create(
            request_type=ZMQRequestType.GET_CONSUMPTION,  # type: ignore[arg-type]
            sender_id=self.client_id,
            receiver_id=self._controller.id,
            body={
                "partition_id": partition_id,
                "task_name": task_name,
            },
        )

        try:
            await socket.send_multipart(request_msg.serialize())
            response_serialized = await socket.recv_multipart(copy=False)
            response_msg = ZMQMessage.deserialize(response_serialized)
            logger.debug(
                f"[{self.client_id}]: Client get consumption response: {response_msg} "
                f"from controller {self._controller.id}"
            )

            if response_msg.request_type == ZMQRequestType.CONSUMPTION_RESPONSE:
                global_index = response_msg.body.get("global_index")
                consumption_status = response_msg.body.get("consumption_status")
                return global_index, consumption_status
            else:
                raise RuntimeError(
                    f"[{self.client_id}]: Failed to get consumption status from controller {self._controller.id}: "
                    f"{response_msg.body.get('message', 'Unknown error')}"
                )
        except Exception as e:
            raise RuntimeError(f"[{self.client_id}]: Error in get_consumption_status: {str(e)}") from e

    @with_controller_socket
    async def async_get_production_status(
        self,
        data_fields: list[str],
        partition_id: str,
        socket: zmq.asyncio.Socket | None = None,
    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
        """Get production status for specific data fields and partition.

        Args:
            data_fields: Data fields to check production status for
            partition_id: Partition id to check production status for
            socket: ZMQ async socket for message transmission (injected by decorator)

        Returns:
            Tuple of:
            - Partition global index tensor
            - Production status tensor for the specified fields. 1 for ready, 0 for not ready.

        Raises:
            RuntimeError: If communication fails or controller returns error response

        Example:
            >>> # Get production status
            >>> global_index, production_status = asyncio.run(client.async_get_production_status(
            ...     data_fields=["input_ids", "attention_mask"],
            ...     partition_id="train_0"
            ... ))
            >>> print(f"Global index: {global_index}, Production status: {production_status}")
        """
        assert socket is not None
        request_msg = ZMQMessage.create(
            request_type=ZMQRequestType.GET_PRODUCTION,  # type: ignore[arg-type]
            sender_id=self.client_id,
            receiver_id=self._controller.id,
            body={
                "partition_id": partition_id,
                "data_fields": data_fields,
            },
        )

        try:
            await socket.send_multipart(request_msg.serialize())
            response_serialized = await socket.recv_multipart(copy=False)
            response_msg = ZMQMessage.deserialize(response_serialized)
            logger.debug(
                f"[{self.client_id}]: Client get production response: {response_msg} "
                f"from controller {self._controller.id}"
            )

            if response_msg.request_type == ZMQRequestType.PRODUCTION_RESPONSE:
                global_index = response_msg.body.get("global_index")
                production_status = response_msg.body.get("production_status")
                return global_index, production_status
            else:
                raise RuntimeError(
                    f"[{self.client_id}]: Failed to get production status from controller {self._controller.id}: "
                    f"{response_msg.body.get('message', 'Unknown error')}"
                )
        except Exception as e:
            raise RuntimeError(f"[{self.client_id}]: Error in get_data_production_status: {str(e)}") from e

    async def async_check_consumption_status(
        self,
        task_name: str,
        partition_id: str,
    ) -> bool:
        """Check if all samples for current partition have been consumed by a specific task.

        Args:
            task_name: Name of the task to check consumption for
            partition_id: Partition id to check consumption status for

        Returns:
            bool: True if all samples have been consumed by the task, False otherwise

        Raises:
            RuntimeError: If communication fails or controller returns error response

        Example:
            >>> # Check if all samples have been consumed
            >>> is_consumed = asyncio.run(client.async_check_consumption_status(
            ...     task_name="generate_sequences",
            ...     partition_id="train_0"
            ... ))
            >>> print(f"All samples consumed: {is_consumed}")
        """

        _, consumption_status = await self.async_get_consumption_status(
            task_name=task_name,
            partition_id=partition_id,
        )

        if consumption_status is None or consumption_status.numel() == 0:
            return False
        return torch.all(consumption_status == 1).item()

    async def async_check_production_status(
        self,
        data_fields: list[str],
        partition_id: str,
    ) -> bool:
        """Check if the all specific fields of samples for current partition are ready
        (produced) for consumption.

        Args:
            data_fields: Data fields to check production status for
            partition_id: Partition id to check production status for

        Returns:
            bool: True if all samples have been produced and ready, False otherwise

        Raises:
            RuntimeError: If communication fails or controller returns error response

        Example:
            >>> # Check if all samples are ready for consumption
            >>> is_ready = asyncio.run(client.async_check_production_status(
            ...     data_fields=["input_ids", "attention_mask"],
            ...     partition_id="train_0"
            ... ))
            >>> print(f"All samples ready: {is_ready}")
        """
        _, production_status = await self.async_get_production_status(
            data_fields=data_fields,
            partition_id=partition_id,
        )

        if production_status is None:
            return False
        return torch.all(production_status == 1).item()

    @with_controller_socket
    async def async_reset_consumption(
        self,
        partition_id: str,
        task_name: str | None = None,
        socket: zmq.asyncio.Socket | None = None,
    ) -> bool:
        """Asynchronously reset consumption status for a partition.

        This allows the same data to be re-consumed, useful for debugging scenarios
        where the same rollout data needs to be trained multiple times.

        Args:
            partition_id: Partition id to reset consumption status for
            task_name: Name of the task to reset. If None, resets all tasks.
            socket: ZMQ async socket for message transmission (injected by decorator)

        Returns:
            bool: True if reset was successful, False otherwise

        Raises:
            RuntimeError: If communication fails or controller returns error response

        Example:
            >>> # Reset consumption for train task to re-train on same data
            >>> success = asyncio.run(client.async_reset_consumption(
            ...     partition_id="train_0",
            ...     task_name="train"
            ... ))
            >>> print(f"Reset successful: {success}")
        """
        assert socket is not None
        body = {"partition_id": partition_id}
        if task_name is not None:
            body["task_name"] = task_name
        request_msg = ZMQMessage.create(
            request_type=ZMQRequestType.RESET_CONSUMPTION,  # type: ignore[arg-type]
            sender_id=self.client_id,
            receiver_id=self._controller.id,
            body=body,
        )
        try:
            await socket.send_multipart(request_msg.serialize())
            response_serialized = await socket.recv_multipart(copy=False)
            response_msg = ZMQMessage.deserialize(response_serialized)
            logger.debug(
                f"[{self.client_id}]: Client reset consumption response: {response_msg} "
                f"from controller {self._controller.id}"
            )
            if response_msg.request_type == ZMQRequestType.RESET_CONSUMPTION_RESPONSE:
                success = response_msg.body.get("success", False)
                if not success:
                    logger.warning(f"[{self.client_id}]: Reset consumption failed: {response_msg.body.get('message')}")
                return success
            else:
                raise RuntimeError(
                    f"[{self.client_id}]: Failed to reset consumption from controller {self._controller.id}: "
                    f"{response_msg.body.get('message', 'Unknown error')}"
                )
        except Exception as e:
            raise RuntimeError(f"[{self.client_id}]: Error in reset_consumption: {str(e)}") from e

    @with_controller_socket
    async def async_get_partition_list(
        self,
        socket: zmq.asyncio.Socket | None = None,
    ) -> list[str]:
        """Asynchronously fetch the list of partition ids from the controller.

        Args:
            socket: ZMQ socket (injected by decorator)

        Returns:
            list[str]: List of partition ids managed by the controller

        Example:
            >>> partition_ids = asyncio.run(client.get_partition_list())
            >>> print(f"Available partitions: {partition_ids}")
        """
        request_msg = ZMQMessage.create(
            request_type=ZMQRequestType.GET_LIST_PARTITIONS,  # type: ignore[arg-type]
            sender_id=self.client_id,
            receiver_id=self._controller.id,
            body={},
        )

        try:
            assert socket is not None
            await socket.send_multipart(request_msg.serialize())
            response_serialized = await socket.recv_multipart(copy=False)
            response_msg = ZMQMessage.deserialize(response_serialized)
            logger.debug(
                f"[{self.client_id}]: Client get partition list response: {response_msg} "
                f"from controller {self._controller.id}"
            )

            if response_msg.request_type == ZMQRequestType.LIST_PARTITIONS_RESPONSE:
                partition_ids = response_msg.body.get("partition_ids", [])
                return partition_ids
            else:
                raise RuntimeError(
                    f"[{self.client_id}]: Failed to get partition list from controller {self._controller.id}: "
                    f"{response_msg.body.get('message', 'Unknown error')}"
                )
        except Exception as e:
            raise RuntimeError(f"[{self.client_id}]: Error in get_partition_list: {str(e)}") from e

    # ==================== KV Interface API ====================
    @with_controller_socket
    async def async_kv_retrieve_meta(
        self,
        keys: list[str] | str,
        partition_id: str,
        create: bool = False,
        socket: zmq.asyncio.Socket | None = None,
    ) -> BatchMeta:
        """Asynchronously retrieve BatchMeta by user-defined keys.

        Retrieves metadata for given keys from a specified partition.
        If keys do not exist and `create=True`, they will be automatically registered.

        Args:
            keys: List of keys to retrieve.
            partition_id: The ID of the logical partition to search for keys.
            create:  If True, automatically create entries for missing keys.
            socket: ZMQ socket injected by @with_controller_socket.

        Returns:
            BatchMeta: Metadata for the requested keys.
        """
        if isinstance(keys, str):
            keys = [keys]
        elif isinstance(keys, list):
            if len(keys) < 1:
                raise ValueError("Received an empty list as keys.")
            # validate all the elements are str
            if not all(isinstance(k, str) for k in keys):
                raise TypeError("Not all elements in `keys` are strings.")
        else:
            raise TypeError("Only string or list of strings are allowed as `keys`.")

        request_msg = ZMQMessage.create(
            request_type=ZMQRequestType.KV_RETRIEVE_META,  # type: ignore[arg-type]
            sender_id=self.client_id,
            receiver_id=self._controller.id,
            body={
                "keys": keys,
                "partition_id": partition_id,
                "create": create,
            },
        )

        try:
            assert socket is not None, "Socket must be initialized before use"
            await socket.send_multipart(request_msg.serialize())
            response_serialized = await socket.recv_multipart(copy=False)
            response_msg = ZMQMessage.deserialize(response_serialized)
            logger.debug(
                f"[{self.client_id}] Received KV_RETRIEVE_META response: {response_msg} "
                f"from controller {self._controller.id}"
            )

            if response_msg.request_type == ZMQRequestType.KV_RETRIEVE_META_RESPONSE:
                return response_msg.body.get("metadata", BatchMeta.empty())

            raise RuntimeError(
                f"[{self.client_id}] Failed to retrieve metadata {response_msg.body.get('message', 'Unknown error')}"
            )
        except Exception as e:
            raise RuntimeError(f"[{self.client_id}] Failed in async_kv_retrieve_meta: {e}") from e

    @with_controller_socket
    async def async_kv_retrieve_keys(
        self,
        global_indexes: list[int] | int,
        partition_id: str,
        socket: zmq.asyncio.Socket | None = None,
    ) -> list[str]:
        """Asynchronously retrieve keys according to global_indexes from the controller.

        Args:
            global_indexes: List of global_indexes to retrieve from the controller
            partition_id: The ID of the logical partition to search for global_indexes.
            socket: ZMQ socket (injected by decorator)

        Returns:
            keys: list of keys of the corresponding global_indexes

        Raises:
            TypeError: If `global_indexes` is not a list of int or an int
            RuntimeError: If some indexes in `global_indexes` do not have corresponding keys
        """

        if isinstance(global_indexes, int):
            global_indexes = [global_indexes]
        elif isinstance(global_indexes, list):
            if len(global_indexes) < 1:
                raise ValueError("Received an empty list as `global_indexes`.")
            # validate all the elements are int
            if not all(isinstance(idx, int) for idx in global_indexes):
                raise TypeError("Not all elements in `global_indexes` are int.")
        else:
            raise TypeError("Only int or list of int are allowed as `global_indexes`.")

        request_msg = ZMQMessage.create(
            request_type=ZMQRequestType.KV_RETRIEVE_KEYS,  # type: ignore[arg-type]
            sender_id=self.client_id,
            receiver_id=self._controller.id,
            body={"global_indexes": global_indexes, "partition_id": partition_id},
        )

        try:
            assert socket is not None
            await socket.send_multipart(request_msg.serialize())
            response_serialized = await socket.recv_multipart(copy=False)
            response_msg = ZMQMessage.deserialize(response_serialized)
            logger.debug(
                f"[{self.client_id}]: Client get kv_retrieve_indexes response: {response_msg} "
                f"from controller {self._controller.id}"
            )

            if response_msg.request_type == ZMQRequestType.KV_RETRIEVE_KEYS_RESPONSE:
                keys = response_msg.body.get("keys", [])
                if len(keys) != len(global_indexes):
                    raise RuntimeError("Some global_indexes have no corresponding keys!")
                return keys
            else:
                raise RuntimeError(
                    f"[{self.client_id}]: Failed to retrieve indexes from controller {self._controller.id}: "
                    f"{response_msg.body.get('message', 'Unknown error')}"
                )
        except Exception as e:
            raise RuntimeError(f"[{self.client_id}]: Error in kv_retrieve_indexes: {str(e)}") from e

    @with_controller_socket
    async def async_kv_list(
        self,
        partition_id: str | None = None,
        socket: zmq.asyncio.Socket | None = None,
    ) -> dict[str, dict[str, Any]]:
        """Asynchronously retrieve keys and custom_meta from the controller for one or all partitions.

        Args:
            partition_id: The specific partition_id to query.
                If None (default), returns keys from all partitions.
            socket: ZMQ socket (injected by decorator)

        Returns:
            A nested dictionary mapping partition IDs to their keys and metadata.

            Structure:
            {
                "partition_id": {
                    "key_name": {
                        "tag1": <value>,
                        ... (other metadata)
                    },
                    ...,
                },
                ...
            }
        """

        request_msg = ZMQMessage.create(
            request_type=ZMQRequestType.KV_LIST,  # type: ignore[arg-type]
            sender_id=self.client_id,
            receiver_id=self._controller.id,
            body={
                "partition_id": partition_id,
            },
        )

        try:
            assert socket is not None
            await socket.send_multipart(request_msg.serialize())
            response_serialized = await socket.recv_multipart(copy=False)
            response_msg = ZMQMessage.deserialize(response_serialized)
            logger.debug(
                f"[{self.client_id}]: Client get kv_list response: {response_msg} from controller {self._controller.id}"
            )

            if response_msg.request_type == ZMQRequestType.KV_LIST_RESPONSE:
                partition_info = response_msg.body.get("partition_info", {})
                return partition_info
            else:
                raise RuntimeError(
                    f"[{self.client_id}]: Failed to list keys from controller {self._controller.id}: "
                    f"{response_msg.body.get('message', 'Unknown error')}"
                )
        except Exception as e:
            raise RuntimeError(f"[{self.client_id}]: Error in kv_list: {str(e)}") from e

    def close(self) -> None:
        """Close the client and cleanup resources including storage manager."""
        try:
            if hasattr(self, "storage_manager") and self.storage_manager:
                if hasattr(self.storage_manager, "close"):
                    self.storage_manager.close()
        except Exception as e:
            logger.warning(f"Error closing storage manager: {e}")


class TransferQueueClient(AsyncTransferQueueClient):
    """Synchronous client wrapper for TransferQueue.

    Provides synchronous versions of all async methods for convenience.
    """

    def __init__(
        self,
        client_id: str,
        controller_info: ZMQServerInfo,
    ):
        """Initialize the synchronous TransferQueue client.

        Args:
            client_id: Unique identifier for this client instance
            controller_info: Single controller ZMQ server information
        """
        super().__init__(
            client_id,
            controller_info,
        )

        # create new event loop in a separate thread
        self._loop = asyncio.new_event_loop()
        self._thread = threading.Thread(target=self._start_loop, daemon=True)
        self._thread.start()

        # convert and bind sync methods
        self._bind_sync_methods()

    def _start_loop(self):
        """Start the synchronous loop."""
        asyncio.set_event_loop(self._loop)
        self._loop.run_forever()

    def _bind_sync_methods(
        self,
    ):
        """Convert and bind synchronous methods."""

        def _run(coro):
            future = asyncio.run_coroutine_threadsafe(coro, self._loop)
            return future.result()

        def _make_sync(async_method):
            def wrapper(*args, **kwargs):
                return _run(async_method(*args, **kwargs))

            return wrapper

        # Bind internal sync wrappers. Public methods are defined explicitly below
        # to ensure proper type hints and documentation.
        self._put = _make_sync(self.async_put)
        self._get_meta = _make_sync(self.async_get_meta)
        self._get_data = _make_sync(self.async_get_data)
        self._clear_partition = _make_sync(self.async_clear_partition)
        self._clear_samples = _make_sync(self.async_clear_samples)
        self._get_consumption_status = _make_sync(self.async_get_consumption_status)
        self._get_production_status = _make_sync(self.async_get_production_status)
        self._check_consumption_status = _make_sync(self.async_check_consumption_status)
        self._check_production_status = _make_sync(self.async_check_production_status)
        self._get_partition_list = _make_sync(self.async_get_partition_list)
        self._set_custom_meta = _make_sync(self.async_set_custom_meta)
        self._reset_consumption = _make_sync(self.async_reset_consumption)
        self._kv_retrieve_meta = _make_sync(self.async_kv_retrieve_meta)
        self._kv_retrieve_keys = _make_sync(self.async_kv_retrieve_keys)
        self._kv_list = _make_sync(self.async_kv_list)

    # ==================== Basic API ====================
    def get_meta(
        self,
        data_fields: list[str],
        batch_size: int,
        partition_id: str,
        mode: str = "fetch",
        task_name: str | None = None,
        sampling_config: dict[str, Any] | None = None,
    ) -> BatchMeta:
        """Synchronously fetch data metadata from the controller via ZMQ.

        Args:
            data_fields: List of data field names to retrieve metadata for
            batch_size: Number of samples to request in the batch
            partition_id: Current data partition id
            mode: Data fetch mode. Options:
                - 'fetch': Get ready data only
                - 'force_fetch': Get data regardless of readiness (may return unready samples)
                - 'insert': Internal usage - should not be used by users
            task_name: Optional task name associated with the request
            sampling_config: Optional sampling configuration for custom samplers.


        Returns:
            BatchMeta: Metadata object containing data structure, sample information, and readiness status

        Raises:
            RuntimeError: If communication fails or controller returns error response

        Example:
            >>> # Example 1: Basic fetch metadata
            >>> batch_meta = client.get_meta(
            ...     data_fields=["input_ids", "attention_mask"],
            ...     batch_size=4,
            ...     partition_id="train_0",
            ...     mode="fetch",
            ...     task_name="generate_sequences"
            ... )
            >>> print(batch_meta.is_ready)  # True if all samples ready
            >>>
            >>> # Example 2: Fetch with self-defined samplers (using GRPOGroupNSampler as an example)
            >>> batch_meta = client.get_meta(
            ...     data_fields=["input_ids", "attention_mask"],
            ...     batch_size=8,
            ...     partition_id="train_0",
            ...     mode="fetch",
            ...     task_name="generate_sequences",
            ...     sampling_config={"n_samples_per_prompt": 4}
            ... )
            >>> print(batch_meta.is_ready)  # True if all samples ready
            >>>
            >>> # Example 3: Force fetch metadata (bypass production status check and Sampler,
            >>> # so may include unready and already-consumed samples. No filtering by consumption status is applied.)
            >>> batch_meta = client.get_meta(
            ...     partition_id="train_0",   # optional
            ...     mode="force_fetch",
            ... )
            >>> print(batch_meta.is_ready)  # May be False if some samples not ready
        """
        return self._get_meta(
            data_fields=data_fields,
            batch_size=batch_size,
            partition_id=partition_id,
            mode=mode,
            task_name=task_name,
            sampling_config=sampling_config,
        )

    def set_custom_meta(self, metadata: BatchMeta) -> None:
        """Synchronously send custom metadata to the controller.

        This method sends per-sample custom metadata (custom_meta) to the controller.
        The custom_meta is stored in the controller and can be retrieved along with
        the BatchMeta in subsequent get_meta calls.

        Args:
            metadata: BatchMeta containing the samples and their custom metadata to store.
                     The custom_meta should be set using BatchMeta.update_custom_meta()
                     before calling this method.

        Raises:
            RuntimeError: If communication fails or controller returns error response

        Example:
            >>> # Create batch with custom metadata
            >>> batch_meta = client.get_meta(data_fields=["input_ids"], batch_size=2, ...)
            >>> batch_meta.update_custom_meta([{"score": 0.9}, {"score": 0.8}])
            >>> client.set_custom_meta(batch_meta)
        """

        return self._set_custom_meta(metadata=metadata)

    def put(
        self,
        data: TensorDict,
        metadata: BatchMeta | None = None,
        partition_id: str | None = None,
        data_parser: Callable[[Any], Any] | None = None,
    ) -> BatchMeta:
        """Synchronously write data to storage units based on metadata.

        If metadata is not provided, it will be created automatically using insert mode
        with the provided data fields and partition_id.

        During put, the custom_meta in metadata will update the corresponding custom_meta in
        TransferQueue Controller.

        Note:
            When using multiple workers for distributed execution, there may be data
            ordering inconsistencies between workers during put operations.

        Args:
            data: Data to write as TensorDict
            metadata: Records the metadata of a batch of data samples, containing index and
                      storage unit information. If None, metadata will be auto-generated.
            partition_id: Target data partition id (required if metadata is not provided)
            data_parser: Optional callable to parse reference data (e.g., URLs) into real
                         content. The input is a slice of the `data` parameter, in plain
                         dict format (not TensorDict), mapping field_name -> batched values.
                         For a regular tensor column the value is a batched tensor; for
                         nested tensors (jagged or strided) and NonTensorStack columns
                         the values are extracted into a list. It must modify values
                         in-place based on the original keys; do not add or remove keys.
                         The number of elements per column must also remain unchanged.
                         Do not change the inner order of values within each column.
                         Only supported by SimpleStorage.

        Returns:
            BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved
                       metadata; will be updated in a future version to reflect the post-put state)

        Raises:
            ValueError: If metadata is None or empty, or if partition_id is None when metadata is not provided
            RuntimeError: If storage operation fails

        Example:
            >>> batch_size = 4
            >>> seq_len = 16
            >>> current_partition_id = "train_0"
            >>> # Example 1: Normal usage with existing metadata
            >>> batch_meta = client.get_meta(
            ...     data_fields=["prompts", "attention_mask"],
            ...     batch_size=batch_size,
            ...     partition_id=current_partition_id,
            ...     mode="fetch",
            ...     task_name="generate_sequences",
            ... )
            >>> batch = client.get_data(batch_meta)
            >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)})
            >>> client.put(data=output, metadata=batch_meta)
            >>>
            >>> # Example 2: Initial data insertion without pre-existing metadata
            >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id!
            >>> # Please make sure the corresponding partition_id is empty before calling the async_put()
            >>> # without metadata.
            >>> # Now we only support put all the data of the corresponding partition id in once. You should repeat with
            >>> # interleave the initial data if n_sample > 1 before calling the async_put().
            >>> original_prompts = torch.randn(batch_size, seq_len)
            >>> n_samples = 4
            >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0)
            >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated})
            >>> # This will create metadata in "insert" mode internally.
            >>> metadata = client.put(data=prompts_repeated_batch, partition_id=current_partition_id)
        """
        return self._put(data=data, metadata=metadata, partition_id=partition_id, data_parser=data_parser)

    def get_data(self, metadata: BatchMeta) -> TensorDict:
        """Synchronously fetch data from storage units and organize into TensorDict.

        Args:
            metadata: Batch metadata containing data location information and global indexes

        Returns:
            TensorDict containing:
                - Requested data fields (e.g., "prompts", "attention_mask")

        Example:
            >>> batch_meta = client.get_meta(
            ...     data_fields=["prompts", "attention_mask"],
            ...     batch_size=4,
            ...     partition_id="train_0",
            ...     mode="fetch",
            ...     task_name="generate_sequences",
            ... )
            >>> batch = client.get_data(batch_meta)
            >>> print(batch)
            >>> # TensorDict with fields "prompts", "attention_mask", and sample order matching metadata global_indexes
        """
        return self._get_data(metadata=metadata)

    def clear_partition(self, partition_id: str):
        """Synchronously clear the whole partition from all storage units and the controller.

        Args:
            partition_id: The partition id to clear data for

        Raises:
            RuntimeError: If clear operation fails
        """
        return self._clear_partition(partition_id=partition_id)

    def clear_samples(self, metadata: BatchMeta):
        """Synchronously clear specific samples from all storage units and the controller.

        Args:
            metadata: The BatchMeta of the corresponding data to be cleared

        Raises:
            RuntimeError: If clear operation fails
        """
        return self._clear_samples(metadata=metadata)

    # ==================== Status Query API ====================
    def get_consumption_status(
        self,
        task_name: str,
        partition_id: str,
    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
        """Synchronously get consumption status for a specific task and partition.

        Args:
            task_name: Name of the task to check consumption for
            partition_id: Partition id to check consumption status for

        Returns:
            Tuple of:
            - Partition global index tensor
            - Consumption status tensor for the specified task. 1 for consumed, 0 for not consumed.

        Raises:
            RuntimeError: If communication fails or controller returns error response

        Example:
            >>> global_index, consumption_status = client.get_consumption_status(
            ...     task_name="generate_sequences",
            ...     partition_id="train_0"
            ... )
            >>> print(f"Global index: {global_index}, Consumption status: {consumption_status}")
        """
        return self._get_consumption_status(task_name, partition_id)

    def get_production_status(
        self,
        data_fields: list[str],
        partition_id: str,
    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
        """Synchronously get production status for specific data fields and partition.

        Args:
            data_fields: Data fields to check production status for
            partition_id: Partition id to check production status for

        Returns:
            Tuple of:
            - Partition global index tensor
            - Production status tensor for the specified fields. 1 for ready, 0 for not ready.

        Raises:
            RuntimeError: If communication fails or controller returns error response

        Example:
            >>> global_index, production_status = client.get_production_status(
            ...     data_fields=["input_ids", "attention_mask"],
            ...     partition_id="train_0"
            ... )
            >>> print(f"Global index: {global_index}, Production status: {production_status}")
        """
        return self._get_production_status(data_fields=data_fields, partition_id=partition_id)

    def check_consumption_status(self, task_name: str, partition_id: str) -> bool:
        """Synchronously check if all samples for a partition have been consumed by a specific task.

        Args:
            task_name: Name of the task to check consumption for
            partition_id: Partition id to check consumption status for

        Returns:
            bool: True if all samples have been consumed by the task, False otherwise

        Raises:
            RuntimeError: If communication fails or controller returns error response

        Example:
            >>> # Check if all samples have been consumed
            >>> is_consumed = client.check_consumption_status(
            ...     task_name="generate_sequences",
            ...     partition_id="train_0"
            ... )
            >>> print(f"All samples consumed: {is_consumed}")
        """
        return self._check_consumption_status(task_name=task_name, partition_id=partition_id)

    def check_production_status(self, data_fields: list[str], partition_id: str) -> bool:
        """Synchronously check if all samples for a partition are ready (produced) for consumption.

        Args:
            data_fields: Data fields to check production status for
            partition_id: Partition id to check production status for

        Returns:
            bool: True if all samples have been produced and ready, False otherwise

        Raises:
            RuntimeError: If communication fails or controller returns error response

        Example:
            >>> # Check if all samples are ready for consumption
            >>> is_ready = client.check_production_status(
            ...     data_fields=["input_ids", "attention_mask"],
            ...     partition_id="train_0"
            ... )
            >>> print(f"All samples ready: {is_ready}")
        """
        return self._check_production_status(data_fields=data_fields, partition_id=partition_id)

    def reset_consumption(self, partition_id: str, task_name: str | None = None) -> bool:
        """Synchronously reset consumption status for a partition.

        This allows the same data to be re-consumed, useful for debugging scenarios
        where the same rollout data needs to be trained multiple times.

        Args:
            partition_id: Partition id to reset consumption status for
            task_name: Name of the task to reset. If None, resets all tasks.

        Returns:
            bool: True if reset was successful, False otherwise

        Raises:
            RuntimeError: If communication fails or controller returns error response

        Example:
            >>> # Reset consumption for train task to re-train on same data
            >>> success = client.reset_consumption(
            ...     partition_id="train_0",
            ...     task_name="train"
            ... )
            >>> print(f"Reset successful: {success}")
        """
        return self._reset_consumption(partition_id, task_name)

    def get_partition_list(
        self,
    ) -> list[str]:
        """Synchronously fetch the list of partition ids from the controller.

        Returns:
            list[str]: List of partition ids managed by the controller

        Example:
            >>> partition_ids = client.get_partition_list()
            >>> print(f"Available partitions: {partition_ids}")
        """
        return self._get_partition_list()

    # ==================== KV Interface API ====================
    def kv_retrieve_meta(
        self,
        keys: list[str] | str,
        partition_id: str,
        create: bool = False,
    ) -> BatchMeta:
        """Synchronously retrieve BatchMeta by user-defined keys.

        Retrieves metadata for given keys from a specified partition.
        If keys do not exist and `create=True`, they will be automatically registered.

        Args:
            keys: List of keys to retrieve from the controller.
            partition_id: Logical partition to query.
            create: If True, automatically create entries for non-existent keys.

        Returns:
            BatchMeta: Metadata for the requested keys.

        Raises:
            TypeError: If `keys` is not a list of string or a string
        """
        return self._kv_retrieve_meta(keys=keys, partition_id=partition_id, create=create)

    def kv_retrieve_keys(
        self,
        global_indexes: list[int] | int,
        partition_id: str,
    ) -> BatchMeta:
        """Synchronously retrieve keys according to global_indexes from the controller.

        Args:
            global_indexes: List of global_indexes to retrieve from the controller
            partition_id: The ID of the logical partition to search for global_indexes.

        Returns:
            keys: list of keys of the corresponding global_indexes

        Raises:
            TypeError: If `global_indexes` is not a list of int or an int
            RuntimeError: If some indexes in `global_indexes` do not have corresponding keys
        """

        return self._kv_retrieve_keys(global_indexes=global_indexes, partition_id=partition_id)

    def kv_list(
        self,
        partition_id: str | None = None,
    ) -> dict[str, dict[str, Any]]:
        """Synchronously retrieve keys and custom_meta from the controller for one or all partitions.

        Args:
            partition_id: The specific partition_id to query.
                If None (default), returns keys from all partitions.
            socket: ZMQ socket (injected by decorator)

        Returns:
            A nested dictionary mapping partition IDs to their keys and metadata.

            Structure:
            {
                "partition_id": {
                    "key_name": {
                        "tag1": <value>,
                        ... (other metadata)
                    },
                    ...,
                },
                ...
            }
        """

        return self._kv_list(partition_id=partition_id)

    def close(self) -> None:
        """Close the client and cleanup resources including event loop and thread."""

        if hasattr(self, "_loop") and self._loop is not None:
            self._loop.call_soon_threadsafe(self._loop.stop)

            if hasattr(self, "_thread") and self._thread is not None:
                self._thread.join(timeout=5.0)
                if self._thread.is_alive():
                    logger.warning(f"[{self.client_id}]: Background thread did not stop within timeout")

            try:
                self._loop.close()
            except Exception as e:
                logger.warning(f"[{self.client_id}]: Error closing event loop: {e}")

        super().close()