# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2025 The TransferQueue Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import subprocess
import time
from importlib import resources
from typing import Any, Callable

import ray
import torch
from omegaconf import DictConfig, OmegaConf
from tensordict import TensorDict
from tensordict.tensorclass import NonTensorStack

from transfer_queue.client import TransferQueueClient
from transfer_queue.controller import TransferQueueController
from transfer_queue.metadata import KVBatchMeta
from transfer_queue.sampler import *  # noqa: F401
from transfer_queue.sampler import BaseSampler
from transfer_queue.storage.bootstrap import StorageBootstrapProvider
from transfer_queue.utils.logging_utils import get_logger
from transfer_queue.utils.yuanrong_utils import cleanup_yuanrong_resources
from transfer_queue.utils.zmq_utils import process_zmq_server_info

logger = get_logger(__name__)

_TQ_CLIENT: Any = None
_TQ_STORAGE: Any = None
_TQ_CONTROLLER: Any = None


def _maybe_create_tq_client(conf: DictConfig | None = None) -> TransferQueueClient:
    global _TQ_CLIENT
    if _TQ_CLIENT is None:
        if conf is None:
            _init_from_existing()
            assert _TQ_CLIENT is not None, (
                "TransferQueueController has not been initialized yet. Please call init() first."
            )
            return _TQ_CLIENT

        pid = os.getpid()
        _TQ_CLIENT = TransferQueueClient(
            client_id=f"TransferQueueClient_{pid}", controller_info=conf.controller.zmq_info
        )

        backend_name = conf.backend.storage_backend

        _TQ_CLIENT.initialize_storage_manager(manager_type=backend_name, config=conf.backend[backend_name])

    return _TQ_CLIENT


def _maybe_create_tq_storage(conf: DictConfig) -> DictConfig:
    global _TQ_STORAGE

    if _TQ_STORAGE is None:
        _TQ_STORAGE = {}
        backend_name = conf.backend.storage_backend
        provider_fn = StorageBootstrapProvider.get_provider(backend_name)
        if provider_fn is not None:
            backend_resources = provider_fn(conf)
            if backend_resources is not None:
                _TQ_STORAGE[backend_name] = backend_resources
            else:
                logger.error(f"Not found available {backend_name} storage resources, please check the config.")
        else:
            logger.error(
                f"Storage backend {backend_name} not registered. Please add it to the StorageBootstrapProvider."
            )
    return conf


def _init_from_existing() -> bool:
    """Initialize the TransferQueueClient from existing controller.

    Returns:
        True if successfully initialized from existing controller, False otherwise.
    """
    global _TQ_CONTROLLER
    try:
        if _TQ_CONTROLLER is None:
            _TQ_CONTROLLER = ray.get_actor("TransferQueueController", namespace="transfer_queue")

    except ValueError:
        logger.info("Called _init_from_existing() but TransferQueueController has not been initialized yet.")
        return False

    logger.info("Found existing TransferQueueController instance. Connecting...")

    conf = None
    while conf is None:
        conf = ray.get(_TQ_CONTROLLER.get_config.remote())
        if conf is not None:
            _maybe_create_tq_client(conf)

            logger.info("TransferQueueClient initialized.")
            return True

        logger.debug("Waiting for controller to initialize... Retrying in 1s")
        time.sleep(1)

    return False


# ==================== Initialization API ====================
def init(conf: DictConfig | None = None) -> DictConfig | None:
    """Initialize the TransferQueue system.

    This function sets up the TransferQueue controller, distributed storage, and client.
    It should be called once at the beginning of the program before any data operations.

    If a controller already exists, reuse it and only initialize the client;
    the provided `conf` will be ignored in this case.

    Args:
        conf: Optional custom config merged with default `config.yaml`.
              Only takes effect on first-time initialization, ignored when attaching
              to an existing controller.
    Returns:
        The merged configuration dictionary.

    Example:
        >>> # In process 0, node A
        >>> import transfer_queue as tq
        >>> tq.init()   # Initialize the TransferQueue
        >>> tq.put(...) # then you can use tq for data operations
        >>>
        >>> # In process 1, node B (with Ray connected to node A)
        >>> import transfer_queue as tq
        >>> tq.init()   # This will only initialize a TransferQueueClient and link with existing TQ
        >>> metadata = tq.get_meta(...)
        >>> data = tq.get_data(metadata)
    """
    if _init_from_existing():
        return conf

    logger.info("No TransferQueueController found. Starting first-time initialization...")

    final_conf = OmegaConf.create({}, flags={"allow_objects": True})
    default_conf = OmegaConf.load(resources.files("transfer_queue") / "config.yaml")
    final_conf = OmegaConf.merge(final_conf, default_conf)
    if conf:
        final_conf = OmegaConf.merge(final_conf, conf)

    # TODO(hz): support load custom sampler class from external module.
    try:
        sampler = final_conf.controller.sampler
        if isinstance(sampler, BaseSampler):
            # user pass a pre-initialized sampler instance
            sampler = sampler
        elif isinstance(sampler, type) and issubclass(sampler, BaseSampler):
            # user pass a sampler class
            sampler = sampler()
        elif isinstance(sampler, str):
            # user pass a sampler name str
            # try to convert as sampler class
            sampler = globals()[final_conf.controller.sampler]
    except KeyError:
        raise ValueError(f"Could not find sampler {final_conf.controller.sampler}") from None

    try:
        global _TQ_CONTROLLER
        _TQ_CONTROLLER = TransferQueueController.options(  # type: ignore[attr-defined]
            name="TransferQueueController", namespace="transfer_queue"
        ).remote(sampler=sampler, polling_mode=final_conf.controller.polling_mode)
        logger.info("TransferQueueController has been created.")
    except ValueError:
        logger.info("Some other rank has initialized TransferQueueController. Try to connect to existing controller.")
        _init_from_existing()
        return final_conf

    controller_zmq_info = process_zmq_server_info(_TQ_CONTROLLER)
    final_conf.controller.zmq_info = controller_zmq_info

    final_conf = _maybe_create_tq_storage(final_conf)

    ray.get(_TQ_CONTROLLER.store_config.remote(final_conf))
    logger.info(f"TransferQueue config: {final_conf}")

    # start Prometheus metrics exporter if enabled
    metrics_conf = final_conf.get("metrics", {})
    if metrics_conf.get("enabled", False):
        metrics_port = metrics_conf.get("port", 0)
        metrics_endpoint = ray.get(_TQ_CONTROLLER.start_metrics.remote(port=metrics_port))
        final_conf.metrics.enabled = True
        final_conf.metrics.endpoint = metrics_endpoint
        # Update stored config so other processes can discover the endpoint
        ray.get(_TQ_CONTROLLER.store_config.remote(final_conf))
        # Register storage units for metrics collection (SimpleStorage only)
        if final_conf.backend.storage_backend == "SimpleStorage":
            storage_zmq_info = final_conf.backend.SimpleStorage.get("zmq_info")
            if storage_zmq_info:
                ray.get(_TQ_CONTROLLER.register_storage_units_for_metrics.remote(storage_zmq_info))
            # Start metrics exporter on each storage unit
            if _TQ_STORAGE and "SimpleStorage" in _TQ_STORAGE:
                futures = [handle.start_metrics.remote(port=0) for handle in _TQ_STORAGE["SimpleStorage"].values()]
                ray.get(futures)
        logger.info(f"Prometheus metrics endpoint: http://{metrics_endpoint}/metrics")

    _maybe_create_tq_client(final_conf)
    return final_conf


def close():
    """Close the TransferQueue system.

    This function cleans up the TransferQueue system, including:
    - Closing the client and its associated resources
    - Cleaning up distributed storage (only for the process that initialized it)
    - Killing the controller actor

    Note:
        This function should be called when the TransferQueue system is no longer needed.
    """
    global _TQ_CLIENT
    global _TQ_STORAGE
    global _TQ_CONTROLLER

    try:
        if _TQ_STORAGE:
            for key, value in _TQ_STORAGE.items():
                if key == "SimpleStorage":
                    # only the process that do first-time init can clean the distributed storage
                    for storage in value.values():
                        ray.kill(storage)
                elif key == "MooncakeStore":
                    check = subprocess.run(["pgrep", "-f", "mooncake_master"], stdout=subprocess.PIPE, text=True)
                    if check.returncode == 0:
                        pids = check.stdout.strip().replace("\n", ", ")
                        logger.warning(
                            f"TransferQueue will not stop mooncake_master process with PID: {pids}. "
                            f"Consider manually killing the mooncake_master."
                        )

                    if _TQ_CLIENT:
                        try:
                            ret = _TQ_CLIENT.storage_manager.storage_client._store.remove_all()
                            if ret < 0:
                                logger.error("Failed to remove existing keys in mooncake_master.")
                            else:
                                logger.info("Successfully removed all existing keys in mooncake_master.")
                        except Exception:
                            pass
                elif key == "Yuanrong":
                    cleanup_yuanrong_resources(value)
                else:
                    logger.warning(f"close for _TQ_STORAGE with key {key} is not supported for now.")

        _TQ_STORAGE = None
    except Exception:
        pass

    if _TQ_CLIENT:
        _TQ_CLIENT.close()
        _TQ_CLIENT = None

    if _TQ_CONTROLLER:
        try:
            ray.kill(_TQ_CONTROLLER)
        except Exception:
            pass
        _TQ_CONTROLLER = None


# ==================== Metrics API ====================
def get_metrics_endpoint() -> str | None:
    """Return the Prometheus metrics endpoint address (``host:port``), or *None* if metrics are disabled.

    Works from any process — the endpoint is stored in the Controller's config
    so that processes joining via ``_init_from_existing()`` can discover it too.

    Example:
        >>> import transfer_queue as tq
        >>> tq.init({"metrics": {"enabled": True}})
        >>> endpoint = tq.get_metrics_endpoint()
        >>> print(endpoint)   # e.g. "10.0.1.42:38271"
        >>> # Use endpoint to register Prometheus scrape target
    """
    if _TQ_CONTROLLER is None:
        _init_from_existing()
    if _TQ_CONTROLLER is None:
        return None
    conf = ray.get(_TQ_CONTROLLER.get_config.remote())
    if conf is None:
        return None
    return conf.get("metrics", {}).get("endpoint", None)


# ==================== High-Level KV Interface API ====================
def kv_put(
    key: str,
    partition_id: str,
    fields: TensorDict | dict[str, Any] | None = None,
    tag: dict[str, Any] | None = None,
    data_parser: Callable[[Any], Any] | None = None,
) -> KVBatchMeta:
    """Put a single key-value pair to TransferQueue.

    This is a convenience method for putting data using a user-specified key
    instead of BatchMeta. Internally, the key is translated to a BatchMeta
    and the data is stored using the regular put mechanism.

    Args:
        key: User-specified key for the data sample (in row)
        partition_id: Logical partition to store the data in
        fields: Data fields to store. Can be a TensorDict or a dict of tensors.
                Each key in `fields` will be treated as a column for the data sample.
                If dict is provided, tensors will be unsqueezed to add batch dimension.
                If not provided, will only update the newly given tag to the key.
        tag: Optional metadata tag to associate with the key
        data_parser: Optional callable to parse reference data (e.g., URLs) into real
                     content. The input is a slice of the `fields` parameter passed to
                     kv_put / kv_batch_put, in plain dict format (not TensorDict),
                     mapping field_name -> batched values. For a regular tensor column
                     the value is a batched tensor; for nested tensors (jagged or
                     strided) and NonTensorStack columns the values are extracted into
                     a list. It must modify values in-place based on the original keys;
                     do not add or remove keys. The number of elements per column must
                     also remain unchanged. Do not change the inner order of values
                     within each column. Only supported by SimpleStorage.

    Returns:
        KVBatchMeta: Metadata containing the key, tags, partition_id, and fields.
                     The `fields` attribute includes all fields stored for this sample,
                     including any new fields written by this put operation.

    Raises:
        ValueError: If neither fields nor tag is provided
        ValueError: If nested tensors are provided (use kv_batch_put instead)
        RuntimeError: If retrieved BatchMeta size doesn't match length of `keys`

    Example:
        >>> import transfer_queue as tq
        >>> import torch
        >>> tq.init()
        >>> # Put with both fields and tag
        >>> meta = tq.kv_put(
        ...     key="sample_1",
        ...     partition_id="train",
        ...     fields={"input_ids": torch.tensor([1, 2, 3])},
        ...     tag={"score": 0.95}
        ... )
        >>> print(meta.fields)  # ['input_ids']
    """
    if fields is None and tag is None:
        raise ValueError("Please provide at least one parameter of `fields` or `tag`.")

    tq_client = _maybe_create_tq_client()

    # 1. translate user-specified key to BatchMeta
    batch_meta = tq_client.kv_retrieve_meta(keys=[key], partition_id=partition_id, create=True)

    if batch_meta.size != 1:
        raise RuntimeError(f"Retrieved BatchMeta size {batch_meta.size} does not match with input `key` size of 1!")

    # 2. register the user-specified tag to BatchMeta
    if tag is not None:
        batch_meta.update_custom_meta([tag])

    # 3. put data
    if fields is not None:
        if isinstance(fields, dict):
            # TODO: consider whether to support this...
            batch = {}
            for field_name, value in fields.items():
                if isinstance(value, torch.Tensor):
                    if value.is_nested:
                        raise ValueError("Please use (async)kv_batch_put for batch operation")
                    batch[field_name] = value.unsqueeze(0)
                else:
                    batch[field_name] = NonTensorStack(value)
            fields = TensorDict(batch, batch_size=[1])
        elif not isinstance(fields, TensorDict):
            raise ValueError("`fields` can only be dict or TensorDict")

        # After put, batch_meta.field_names will include the new fields written by user
        batch_meta = tq_client.put(fields, batch_meta, data_parser=data_parser)
    else:
        # Directly update custom_meta (tag) to controller
        tq_client.set_custom_meta(batch_meta)

    fields_to_return = batch_meta.field_names

    return KVBatchMeta(
        keys=[key],
        tags=batch_meta.custom_meta,
        partition_id=partition_id,
        fields=fields_to_return,
        extra_info=batch_meta.extra_info,
    )


def kv_batch_put(
    keys: list[str],
    partition_id: str,
    fields: TensorDict | None = None,
    tags: list[dict[str, Any]] | None = None,
    data_parser: Callable[[Any], Any] | None = None,
) -> KVBatchMeta:
    """Batch put multiple key-value pairs into the TransferQueue.

    This method stores multiple key-value entries in a single operation,
    which is significantly more efficient than repeated calls to ``kv_put``.

    Args:
        keys: List of user-defined unique keys for the data entries.
        partition_id: Logical partition where the data will be stored.
        fields: TensorDict containing batched data for all keys. Must have ``batch_size == len(keys)``.
            If not provided, only the associated tags will be updated.
        tags: List of metadata dictionaries, one per key.  Length must match the number of keys.
        data_parser: Optional callable to parse raw reference data (e.g., URLs) into real content
            before storage. The input is a plain dict (not TensorDict) mapping field names to
            batched values. The parser  **must modify data in-place** without adding/removing
            keys or changing element counts/order. Only supported by ``SimpleStorage`` backend.

    Returns:
        KVBatchMeta: Metadata object containing stored keys, tags, partition ID,
            and field information. The ``fields`` attribute includes all
            persisted fields for the written samples.

    Raises:
        ValueError: When both ``fields`` and ``tags`` are empty.
        ValueError: When ``fields`` batch size mismatches key count.
        ValueError: When ``tags`` length mismatches key count.
        RuntimeError: When retrieved metadata size mismatches input key count.

    Example:
        >>> import transfer_queue as tq
        >>> from tensordict import TensorDict
        >>> tq.init()
        >>> keys = ["sample_1", "sample_2", "sample_3"]
        >>> fields = TensorDict({
        ...     "input_ids": torch.randn(3, 10),
        ...     "attention_mask": torch.ones(3, 10),
        ... }, batch_size=3)
        >>> tags = [{"score": 0.9}, {"score": 0.85}, {"score": 0.95}]
        >>> meta = tq.kv_batch_put(keys=keys, partition_id="train", fields=fields, tags=tags)
        >>> print(meta.fields)
    """
    num_keys = len(keys)

    if fields is None and tags is None:
        raise ValueError("Please provide at least one parameter of fields or tag.")

    if fields is not None and fields.batch_size[0] != num_keys:
        raise ValueError(f"Length of `keys` ({num_keys}) does not match `fields` batch size ({fields.batch_size[0]}).")

    tq_client = _maybe_create_tq_client()
    batch_meta = tq_client.kv_retrieve_meta(keys=keys, partition_id=partition_id, create=True)

    if batch_meta.size != num_keys:
        raise RuntimeError(f"Retrieved BatchMeta size {batch_meta.size} does not match input `keys` size {num_keys}.")

    if tags is not None:
        if len(tags) != num_keys:
            raise ValueError(f"Length of `keys` ({num_keys}) does not match length of `tags` ({len(tags)}).")
        batch_meta.update_custom_meta(tags)

    if fields is not None:
        batch_meta = tq_client.put(fields, batch_meta, data_parser=data_parser)
    else:  # tags is not None
        tq_client.set_custom_meta(batch_meta)

    return KVBatchMeta(
        keys=keys,
        tags=batch_meta.custom_meta,
        partition_id=partition_id,
        fields=batch_meta.field_names,
        extra_info=batch_meta.extra_info,
    )


def kv_batch_get_by_meta(meta: KVBatchMeta, select_fields: list[str] | str | None = None) -> TensorDict:
    """Get data from TransferQueue using KVBatchMeta.

    This is a convenience method for retrieving data using KVBatchMeta returned
    from a previous put operation. It extracts the keys and partition_id from
    the metadata to fetch the corresponding data.

    Args:
        meta: KVBatchMeta object returned from a previous put operation (e.g., kv_put,
              kv_batch_put). It contains keys, partition_id, and fields information.
        select_fields: Optional field(s) to retrieve, which overrides the fields
                       recorded in the given KVBatchMeta. If None, uses all fields
                       from meta.fields. Can be a single field name (str) or a list
                       of field names.

    Returns:
        TensorDict with the requested data

    Raises:
        ValueError: If keys or partition are not found
        ValueError: If empty fields exist in any key (sample)
        ValueError: If any field in select_fields doesn't exist in KVBatchMeta.fields

    Example:
        >>> import transfer_queue as tq
        >>> tq.init()
        >>> # First put some data
        >>> keys = ["sample_1", "sample_2", "sample_3"]
        >>> fields = TensorDict({
        ...     "input_ids": torch.randn(3, 10),
        ...     "attention_mask": torch.ones(3, 10),
        ... }, batch_size=3)
        >>> meta = tq.kv_batch_put(keys=keys, partition_id="train", fields=fields)
        >>> # Then retrieve it using the returned metadata
        >>> data = tq.kv_batch_get_by_meta(meta)
    """
    if meta.partition_id is None:
        raise ValueError("Must provide partition_id in the input KVBatchMeta.")
    if select_fields is not None:
        if isinstance(select_fields, str):
            fields_to_fetch: list[str] | None = [select_fields]
        else:
            fields_to_fetch = select_fields

        assert fields_to_fetch is not None
        if meta.fields is None or any(f not in meta.fields for f in fields_to_fetch):
            raise ValueError(
                f"Some fields assigned in select_fields not found in the metadata. "
                f"Assigned: {fields_to_fetch}; Fields in KVBatchMeta: {meta.fields}."
            )
    else:
        fields_to_fetch = meta.fields
    return kv_batch_get(keys=meta.keys, partition_id=meta.partition_id, select_fields=fields_to_fetch)


def kv_batch_get(keys: list[str] | str, partition_id: str, select_fields: list[str] | str | None = None) -> TensorDict:
    """Get data from TransferQueue using user-specified keys.

    This is a convenience method for retrieving data using keys instead of indexes.

    Args:
        keys: Single key or list of keys to retrieve
        partition_id: Partition containing the keys
        select_fields: Optional field(s) to retrieve. If None, retrieves all fields

    Returns:
        TensorDict with the requested data

    Raises:
        ValueError: If keys or partition are not found
        ValueError: If empty fields exist in any key (sample)

    Example:
        >>> import transfer_queue as tq
        >>> tq.init()
        >>> # Get single key with all fields
        >>> data = tq.kv_batch_get(keys="sample_1", partition_id="train")
        >>> # Get multiple keys with specific fields
        >>> data = tq.kv_batch_get(
        ...     keys=["sample_1", "sample_2"],
        ...     partition_id="train",
        ...     select_fields="input_ids"
        ... )
    """
    tq_client = _maybe_create_tq_client()

    batch_meta = tq_client.kv_retrieve_meta(keys=keys, partition_id=partition_id, create=False)

    if batch_meta.size == 0:
        raise ValueError("keys or partition were not found!")

    fields_to_fetch: list[str] | None
    if select_fields is not None:
        if isinstance(select_fields, str):
            fields_to_fetch = [select_fields]
        else:
            fields_to_fetch = select_fields
        batch_meta = batch_meta.select_fields(fields_to_fetch)

    if not batch_meta.is_ready:
        raise ValueError("Some fields are not ready in all the requested keys!")

    data = tq_client.get_data(batch_meta)
    return data


def kv_list(partition_id: str | None = None) -> dict[str, dict[str, Any]]:
    """List all keys and their metadata in one or all partitions.

    Args:
        partition_id: The specific partition_id to query.
            If None (default), returns keys from all partitions.

    Returns:
        A nested dictionary mapping partition IDs to their keys and metadata.

        Structure:
        {
            "partition_id": {
                "key_name": {
                    "tag1": <value>,
                    ... (other metadata)
                },
                ...,
            },
            ...
        }

    Example:
        >>> import transfer_queue as tq
        >>> tq.init()
        >>> # Case 1: Retrieve a specific partition
        >>> partitions = tq.kv_list(partition_id="train")
        >>> print(f"Keys: {list(partitions['train'].keys())}")
        >>> print(f"Tags: {list(partitions['train'].values())}")
        >>> # Case 2: Retrieve all partitions
        >>> all_partitions = tq.kv_list()
        >>> for pid, keys in all_partitions.items():
        >>>     print(f"Partition: {pid}, Key count: {len(keys)}")
    """
    tq_client = _maybe_create_tq_client()

    partition_info = tq_client.kv_list(partition_id)

    return partition_info


def kv_clear(keys: list[str] | str, partition_id: str) -> None:
    """Clear key-value pairs from TransferQueue.

    This removes the specified keys and their associated data from both
    the controller and storage units.

    Args:
        keys: Single key or list of keys to clear
        partition_id: Partition containing the keys

    Example:
        >>> import transfer_queue as tq
        >>> tq.init()
        >>> # Clear single key
        >>> tq.kv_clear(keys="sample_1", partition_id="train")
        >>> # Clear multiple keys
        >>> tq.kv_clear(keys=["sample_1", "sample_2"], partition_id="train")
    """

    if isinstance(keys, str):
        keys = [keys]

    tq_client = _maybe_create_tq_client()
    batch_meta = tq_client.kv_retrieve_meta(keys=keys, partition_id=partition_id, create=False)

    if batch_meta.size > 0:
        tq_client.clear_samples(batch_meta)


# ==================== KV Interface API ====================
async def async_kv_put(
    key: str,
    partition_id: str,
    fields: TensorDict | dict[str, Any] | None = None,
    tag: dict[str, Any] | None = None,
    data_parser: Callable[[Any], Any] | None = None,
) -> KVBatchMeta:
    """Asynchronously put a single key-value pair to TransferQueue.

    This is a convenience method for putting data using a user-specified key
    instead of BatchMeta. Internally, the key is translated to a BatchMeta
    and the data is stored using the regular put mechanism.

    Args:
        key: User-specified key for the data sample (in row)
        partition_id: Logical partition to store the data in
        fields: Data fields to store. Can be a TensorDict or a dict of tensors.
                Each key in `fields` will be treated as a column for the data sample.
                If dict is provided, tensors will be unsqueezed to add batch dimension.
                If not provided, will only update the newly given tag to the key.
        tag: Optional metadata tag to associate with the key
        data_parser: Optional callable to parse reference data (e.g., URLs) into real
                     content. The input is a slice of the `fields` parameter passed to
                     kv_put / kv_batch_put, in plain dict format (not TensorDict),
                     mapping field_name -> batched values. For a regular tensor column
                     the value is a batched tensor; for nested tensors (jagged or
                     strided) and NonTensorStack columns the values are extracted into
                     a list. It must modify values in-place based on the original keys;
                     do not add or remove keys. The number of elements per column must
                     also remain unchanged. Do not change the inner order of values
                     within each column. Only supported by SimpleStorage.

    Returns:
        KVBatchMeta: Metadata containing the key, tags, partition_id, and fields.
                     The `fields` attribute includes all fields stored for this sample,
                     including any new fields written by this put operation.

    Raises:
        ValueError: If neither fields nor tag is provided
        ValueError: If nested tensors are provided (use kv_batch_put instead)
        RuntimeError: If retrieved BatchMeta size doesn't match length of `keys`

    Example:
        >>> import transfer_queue as tq
        >>> import torch
        >>> tq.init()
        >>> # Put with both fields and tag
        >>> meta = await tq.async_kv_put(
        ...     key="sample_1",
        ...     partition_id="train",
        ...     fields={"input_ids": torch.tensor([1, 2, 3])},
        ...     tag={"score": 0.95}
        ... )
        >>> print(meta.fields)  # ['input_ids']
    """

    if fields is None and tag is None:
        raise ValueError("Please provide at least one parameter of fields or tag.")

    tq_client = _maybe_create_tq_client()

    # 1. translate user-specified key to BatchMeta
    batch_meta = await tq_client.async_kv_retrieve_meta(keys=[key], partition_id=partition_id, create=True)

    if batch_meta.size != 1:
        raise RuntimeError(f"Retrieved BatchMeta size {batch_meta.size} does not match with input `key` size of 1!")

    # 2. register the user-specified tag to BatchMeta
    if tag is not None:
        batch_meta.update_custom_meta([tag])

    # 3. put data
    if fields is not None:
        if isinstance(fields, dict):
            # TODO: consider whether to support this...
            batch = {}
            for field_name, value in fields.items():
                if isinstance(value, torch.Tensor):
                    if value.is_nested:
                        raise ValueError("Please use (async)kv_batch_put for batch operation")
                    batch[field_name] = value.unsqueeze(0)
                else:
                    batch[field_name] = NonTensorStack(value)
            fields = TensorDict(batch, batch_size=[1])
        elif not isinstance(fields, TensorDict):
            raise ValueError("`fields` can only be dict or TensorDict")

        # After put, batch_meta.field_names will include the new fields written by user
        batch_meta = await tq_client.async_put(fields, batch_meta, data_parser=data_parser)
    else:
        # Directly update custom_meta (tag) to controller
        await tq_client.async_set_custom_meta(batch_meta)

    fields_to_return = batch_meta.field_names

    return KVBatchMeta(
        keys=[key],
        tags=batch_meta.custom_meta,
        partition_id=partition_id,
        fields=fields_to_return,
        extra_info=batch_meta.extra_info,
    )


async def async_kv_batch_put(
    keys: list[str],
    partition_id: str,
    fields: TensorDict | None = None,
    tags: list[dict[str, Any]] | None = None,
    data_parser: Callable[[Any], Any] | None = None,
) -> KVBatchMeta:
    """Asynchronously put multiple key-value pairs to TransferQueue in batch.

    This method stores multiple key-value pairs in a single operation, which is more
    efficient than calling kv_put multiple times.

    Args:
        keys: List of user-specified keys for the data
        partition_id: Logical partition to store the data in
        fields: TensorDict containing data for all keys. Must have batch_size == len(keys).
                If not provided, will only update the newly given tags to the keys.
        tags: List of metadata tags, one for each key
        data_parser: Optional callable to parse reference data (e.g., URLs) into real
                     content. The input is a slice of the `fields` parameter passed to
                     kv_put / kv_batch_put, in plain dict format (not TensorDict),
                     mapping field_name -> batched values. For a regular tensor column
                     the value is a batched tensor; for nested tensors (jagged or
                     strided) and NonTensorStack columns the values are extracted into
                     a list. It must modify values in-place based on the original keys;
                     do not add or remove keys. The number of elements per column must
                     also remain unchanged. Do not change the inner order of values
                     within each column. Only supported by SimpleStorage.

    Returns:
        KVBatchMeta: Metadata containing the keys, tags, partition_id, and fields.
                     The `fields` attribute includes all fields stored for these samples,
                     including any new fields written by this put operation.

    Raises:
        ValueError: If neither `fields` nor `tags` is provided
        ValueError: If length of `keys` doesn't match length of `tags` or the batch_size of `fields` TensorDict
        RuntimeError: If retrieved BatchMeta size doesn't match length of `keys`

    Example:
        >>> import transfer_queue as tq
        >>> tq.init()
        >>> keys = ["sample_1", "sample_2", "sample_3"]
        >>> fields = TensorDict({
        ...     "input_ids": torch.randn(3, 10),
        ...     "attention_mask": torch.ones(3, 10),
        ... }, batch_size=3)
        >>> tags = [{"score": 0.9}, {"score": 0.85}, {"score": 0.95}]
        >>> meta = await tq.async_kv_batch_put(keys=keys, partition_id="train", fields=fields, tags=tags)
        >>> print(meta.fields)  # ['input_ids', 'attention_mask']
    """

    if fields is None and tags is None:
        raise ValueError("Please provide at least one parameter of `fields` or `tags`.")

    if fields is not None and fields.batch_size[0] != len(keys):
        raise ValueError(
            f"`keys` with length {len(keys)} does not match the `fields` TensorDict with "
            f"batch_size {fields.batch_size[0]}"
        )

    tq_client = _maybe_create_tq_client()

    # 1. translate user-specified key to BatchMeta
    batch_meta = await tq_client.async_kv_retrieve_meta(keys=keys, partition_id=partition_id, create=True)

    if batch_meta.size != len(keys):
        raise RuntimeError(
            f"Retrieved BatchMeta size {batch_meta.size} does not match with input `keys` size {len(keys)}!"
        )

    # 2. register the user-specified tags to BatchMeta
    if tags is not None:
        if len(tags) != len(keys):
            raise ValueError(f"keys with length {len(keys)} does not match length of tags {len(tags)}")
        batch_meta.update_custom_meta(tags)

    # 3. put data
    if fields is not None:
        # After put, batch_meta.field_names will include the new fields written by user
        batch_meta = await tq_client.async_put(fields, batch_meta, data_parser=data_parser)
    else:
        # Directly update custom_meta (tags) to controller
        await tq_client.async_set_custom_meta(batch_meta)

    fields_to_return = batch_meta.field_names

    return KVBatchMeta(
        keys=keys,
        tags=batch_meta.custom_meta,
        partition_id=partition_id,
        fields=fields_to_return,
        extra_info=batch_meta.extra_info,
    )


async def async_kv_batch_get_by_meta(meta: KVBatchMeta, select_fields: list[str] | str | None = None) -> TensorDict:
    """Asynchronously get data from TransferQueue using KVBatchMeta.

    This is a convenience method for retrieving data using KVBatchMeta returned
    from a previous put operation. It extracts the keys and partition_id from
    the metadata to fetch the corresponding data.

    Args:
        meta: KVBatchMeta object returned from a previous put operation (e.g., async_kv_put,
              async_kv_batch_put). It contains keys, partition_id, and fields information.
        select_fields: Optional field(s) to retrieve, which overrides the fields
                       recorded in the given KVBatchMeta. If None, uses all fields
                       from meta.fields. Can be a single field name (str) or a list
                       of field names.

    Returns:
        TensorDict with the requested data

    Raises:
        ValueError: If keys or partition are not found
        ValueError: If empty fields exist in any key (sample)
        ValueError: If any field in select_fields doesn't exist in KVBatchMeta.fields

    Example:
        >>> import transfer_queue as tq
        >>> tq.init()
        >>> # First put some data
        >>> keys = ["sample_1", "sample_2", "sample_3"]
        >>> fields = TensorDict({
        ...     "input_ids": torch.randn(3, 10),
        ...     "attention_mask": torch.ones(3, 10),
        ... }, batch_size=3)
        >>> meta = await tq.async_kv_batch_put(keys=keys, partition_id="train", fields=fields)
        >>> # Then retrieve it using the returned metadata
        >>> data = await tq.async_kv_batch_get_by_meta(meta)
    """
    if meta.partition_id is None:
        raise ValueError("Must provide partition_id in the input KVBatchMeta.")

    fields_to_fetch: list[str] | None
    if select_fields is not None:
        if isinstance(select_fields, str):
            fields_to_fetch = [select_fields]
        else:
            fields_to_fetch = select_fields

        assert fields_to_fetch is not None
        if meta.fields is None or any(f not in meta.fields for f in fields_to_fetch):
            raise ValueError(
                f"Some fields assigned in select_fields not found in the metadata. "
                f"Assigned: {fields_to_fetch}; Fields in KVBatchMeta: {meta.fields}."
            )
    else:
        fields_to_fetch = meta.fields
    return await async_kv_batch_get(keys=meta.keys, partition_id=meta.partition_id, select_fields=fields_to_fetch)


async def async_kv_batch_get(
    keys: list[str] | str, partition_id: str, select_fields: list[str] | str | None = None
) -> TensorDict:
    """Asynchronously get data from TransferQueue using user-specified keys.

    This is a convenience method for retrieving data using keys instead of indexes.

    Args:
        keys: Single key or list of keys to retrieve
        partition_id: Partition containing the keys
        select_fields: Optional field(s) to retrieve. If None, retrieves all fields

    Returns:
        TensorDict with the requested data

    Raises:
        ValueError: If keys or partition are not found
        ValueError: If empty fields exist in any key (sample)

    Example:
        >>> import transfer_queue as tq
        >>> tq.init()
        >>> # Get single key with all fields
        >>> data = await tq.async_kv_batch_get(keys="sample_1", partition_id="train")
        >>> # Get multiple keys with specific fields
        >>> data = await tq.async_kv_batch_get(
        ...     keys=["sample_1", "sample_2"],
        ...     partition_id="train",
        ...     select_fields="input_ids"
        ... )
    """
    tq_client = _maybe_create_tq_client()

    batch_meta = await tq_client.async_kv_retrieve_meta(keys=keys, partition_id=partition_id, create=False)

    if batch_meta.size == 0:
        raise ValueError("keys or partition were not found!")

    if select_fields is not None:
        if isinstance(select_fields, str):
            fields_to_fetch = [select_fields]
        else:
            fields_to_fetch = select_fields
        batch_meta = batch_meta.select_fields(fields_to_fetch)

    if not batch_meta.is_ready:
        raise ValueError("Some fields are not ready in all the requested keys!")

    data = await tq_client.async_get_data(batch_meta)
    return data


async def async_kv_list(partition_id: str | None = None) -> dict[str, dict[str, Any]]:
    """Asynchronously list all keys and their metadata in one or all partitions.

    Args:
        partition_id: The specific partition_id to query.
            If None (default), returns keys from all partitions.

    Returns:
        A nested dictionary mapping partition IDs to their keys and metadata.

        Structure:
        {
            "partition_id": {
                "key_name": {
                    "tag1": <value>,
                    ... (other metadata)
                },
                ...,
            },
            ...
        }


    Example:
        >>> import transfer_queue as tq
        >>> tq.init()
        >>> # Case 1: Retrieve a specific partition
        >>> partitions = await tq.async_kv_list(partition_id="train")
        >>> print(f"Keys: {list(partitions['train'].keys())}")
        >>> print(f"Tags: {list(partitions['train'].values())}")
        >>> # Case 2: Retrieve all partitions
        >>> all_partitions = await tq.async_kv_list()
        >>> for pid, keys in all_partitions.items():
        >>>     print(f"Partition: {pid}, Key count: {len(keys)}")
    """
    tq_client = _maybe_create_tq_client()

    partition_info = await tq_client.async_kv_list(partition_id)

    return partition_info


async def async_kv_clear(keys: list[str] | str, partition_id: str) -> None:
    """Asynchronously clear key-value pairs from TransferQueue.

    This removes the specified keys and their associated data from both
    the controller and storage units.

    Args:
        keys: Single key or list of keys to clear
        partition_id: Partition containing the keys

    Example:
        >>> import transfer_queue as tq
        >>> tq.init()
        >>> # Clear single key
        >>> await tq.async_kv_clear(keys="sample_1", partition_id="train")
        >>> # Clear multiple keys
        >>> await tq.async_kv_clear(keys=["sample_1", "sample_2"], partition_id="train")
    """

    if isinstance(keys, str):
        keys = [keys]

    tq_client = _maybe_create_tq_client()
    batch_meta = await tq_client.async_kv_retrieve_meta(keys=keys, partition_id=partition_id, create=False)

    if batch_meta.size > 0:
        await tq_client.async_clear_samples(batch_meta)


# ==================== Low-Level Native API ====================
# For low-level API support, please refer to transfer_queue/client.py for details.
def get_client():
    """Get a TransferQueueClient for using low-level API"""
    assert _TQ_CLIENT is not None, "Please initialize the TransferQueue first by calling `tq.init()`!"
    return _TQ_CLIENT