# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2025 The TransferQueue Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
from threading import Thread
from unittest.mock import patch

import pytest
import torch
import zmq
from tensordict import NonTensorStack, TensorDict

from transfer_queue import TransferQueueClient
from transfer_queue.metadata import BatchMeta
from transfer_queue.utils.enum_utils import Role
from transfer_queue.utils.zmq_utils import (
    ZMQMessage,
    ZMQRequestType,
    ZMQServerInfo,
)

TEST_DATA = TensorDict(
    {
        "log_probs": [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0]), torch.tensor([7.0, 8.0, 9.0])],
        "variable_length_sequences": torch.nested.as_nested_tensor(
            [
                torch.tensor([-0.5, -1.2, -0.8]),
                torch.tensor([-0.3, -1.5, -2.1, -0.9]),
                torch.tensor([-1.1, -0.7]),
            ],
            layout=torch.jagged,
        ),
        "prompt_text": ["Hello world!", "This is a longer sentence for testing", "Test case"],
    },
    batch_size=[3],
)


# Mock Controller for Client Unit Testing
class MockController:
    def __init__(self, controller_id="controller_0"):
        self.controller_id = controller_id
        self.context = zmq.Context()

        # Socket for data requests
        self.request_socket = self.context.socket(zmq.ROUTER)
        self.request_port = self._bind_to_random_port(self.request_socket)

        self.zmq_server_info = ZMQServerInfo(
            role=Role.CONTROLLER,
            id=controller_id,
            ip="127.0.0.1",
            ports={
                "request_handle_socket": self.request_port,
            },
        )

        self.running = True
        self.request_thread = Thread(target=self._handle_requests, daemon=True)
        self.request_thread.start()

    def _bind_to_random_port(self, socket):
        port = socket.bind_to_random_port("tcp://127.0.0.1")
        return port

    def _handle_requests(self):
        poller = zmq.Poller()
        poller.register(self.request_socket, zmq.POLLIN)

        while self.running:
            try:
                socks = dict(poller.poll(100))  # 100ms timeout
                if self.request_socket in socks:
                    messages = self.request_socket.recv_multipart(copy=False)
                    identity = messages.pop(0)
                    serialized_msg = messages
                    request_msg = ZMQMessage.deserialize(serialized_msg)

                    # Determine response based on request type
                    if request_msg.request_type == ZMQRequestType.GET_META:
                        response_body = self._mock_batch_meta(request_msg.body)
                        response_type = ZMQRequestType.GET_META_RESPONSE
                    elif request_msg.request_type == ZMQRequestType.CLEAR_META:
                        response_body = {"message": "clear meta ok"}
                        response_type = ZMQRequestType.CLEAR_META_RESPONSE
                    elif request_msg.request_type == ZMQRequestType.CLEAR_PARTITION:
                        response_body = {"message": "clear partition ok"}
                        response_type = ZMQRequestType.CLEAR_PARTITION_RESPONSE
                    elif request_msg.request_type == ZMQRequestType.GET_PARTITION_META:
                        # Mock partition metadata response
                        response_body = self._mock_batch_meta(request_msg.body)
                        response_type = ZMQRequestType.GET_PARTITION_META_RESPONSE
                    elif request_msg.request_type == ZMQRequestType.GET_CONSUMPTION:
                        # Mock consumption status check - all consumed
                        response_body = {
                            "partition_id": request_msg.body.get("partition_id"),
                            "global_index": torch.tensor([0, 1, 2]),
                            "consumption_status": torch.tensor([1, 1, 1]),
                        }
                        response_type = ZMQRequestType.CONSUMPTION_RESPONSE
                    elif request_msg.request_type == ZMQRequestType.GET_PRODUCTION:
                        # Mock production status check - all produced
                        response_body = {
                            "partition_id": request_msg.body.get("partition_id"),
                            "global_index": torch.tensor([0, 1, 2]),
                            "production_status": torch.tensor([[1, 1, 1], [1, 1, 1]]),
                        }
                        response_type = ZMQRequestType.PRODUCTION_RESPONSE
                    elif request_msg.request_type == ZMQRequestType.GET_LIST_PARTITIONS:
                        # Mock partition list
                        response_body = {
                            "partition_ids": ["partition_0", "partition_1", "test_partition"],
                        }
                        response_type = ZMQRequestType.LIST_PARTITIONS_RESPONSE
                    elif request_msg.request_type == ZMQRequestType.SET_CUSTOM_META:
                        response_body = {"message": "success"}
                        response_type = ZMQRequestType.SET_CUSTOM_META_RESPONSE
                    elif request_msg.request_type == ZMQRequestType.RESET_CONSUMPTION:
                        # Mock reset consumption - always succeed
                        response_body = {
                            "success": True,
                            "message": "Consumption reset successfully",
                        }
                        response_type = ZMQRequestType.RESET_CONSUMPTION_RESPONSE
                    elif request_msg.request_type == ZMQRequestType.KV_RETRIEVE_META:
                        response_body = self._mock_kv_retrieve_meta(request_msg.body)
                        response_type = ZMQRequestType.KV_RETRIEVE_META_RESPONSE
                    elif request_msg.request_type == ZMQRequestType.KV_RETRIEVE_KEYS:
                        response_body = self._mock_kv_retrieve_keys(request_msg.body)
                        response_type = ZMQRequestType.KV_RETRIEVE_KEYS_RESPONSE
                    elif request_msg.request_type == ZMQRequestType.KV_LIST:
                        response_body = self._mock_kv_list(request_msg.body)
                        response_type = ZMQRequestType.KV_LIST_RESPONSE
                    else:
                        response_body = {"error": f"Unknown request type: {request_msg.request_type}"}
                        response_type = ZMQRequestType.CLEAR_META_RESPONSE

                    # Send response
                    response_msg = ZMQMessage.create(
                        request_type=response_type,
                        sender_id=self.controller_id,
                        receiver_id=request_msg.sender_id,
                        body=response_body,
                    )
                    self.request_socket.send_multipart([identity, *response_msg.serialize()])
            except zmq.Again:
                continue
            except Exception as e:
                print(f"MockController ERROR: {e}")
                raise

    def _mock_batch_meta(self, request_body):
        batch_size = request_body.get("batch_size", 1)
        data_fields = request_body.get("data_fields", [])

        # Build columnar field_schema
        field_schema = {
            field_name: {"dtype": None, "shape": None, "is_nested": False, "is_non_tensor": False}
            for field_name in data_fields
        }

        metadata = BatchMeta(
            global_indexes=list(range(batch_size)),
            partition_ids=["0"] * batch_size,
            field_schema=field_schema,
        )

        return {"metadata": metadata}

    def _mock_kv_retrieve_meta(self, request_body):
        """Mock KV retrieve keys response."""
        keys = request_body.get("keys", [])
        create = request_body.get("create", False)
        partition_id = request_body.get("partition_id", "")

        if not hasattr(self, "_kv_partition_keys"):
            self._kv_partition_keys = {}

        start_index = self._get_next_kv_index(partition_id)
        global_indexes = list(range(start_index, start_index + len(keys)))

        # Build columnar BatchMeta for KV interface
        field_schema = {
            "data": {"dtype": "torch.float32", "shape": [1, 10], "is_nested": False, "is_non_tensor": False}
        }
        import numpy as np

        production_status = np.ones(len(global_indexes), dtype=np.int8)
        metadata = BatchMeta(
            global_indexes=global_indexes,
            partition_ids=[partition_id] * len(global_indexes),
            field_schema=field_schema,
            production_status=production_status,
        )

        if create:
            if partition_id not in self._kv_partition_keys:
                self._kv_partition_keys[partition_id] = []
            self._kv_partition_keys[partition_id].extend(keys)

        if global_indexes:
            self._update_kv_index(partition_id, global_indexes[-1] + 1)

        return {"metadata": metadata}

    def _mock_kv_list(self, request_body):
        """Mock KV list response."""
        partition_id = request_body.get("partition_id", None)

        # Initialize key tracking if not exists
        if not hasattr(self, "_kv_partition_keys"):
            self._kv_partition_keys = {}

        # Return cached keys for this partition
        keys = self._kv_partition_keys.get(partition_id, [])

        return {"partition_info": {partition_id: {k: {} for k in keys}}, "message": "success"}

    def _mock_kv_retrieve_keys(self, request_body):
        """Mock KV retrieve indexes response."""
        global_indexes = request_body.get("global_indexes", [])
        partition_id = request_body.get("partition_id", "")

        # Initialize key tracking if not exists
        if not hasattr(self, "_kv_partition_keys"):
            self._kv_partition_keys = {}

        # Initialize index to key mapping if not exists
        if not hasattr(self, "_kv_index_to_key"):
            self._kv_index_to_key = {}

        # Get keys for this partition
        partition_keys = self._kv_partition_keys.get(partition_id, [])

        # Build reverse mapping from index to key if needed
        if not hasattr(self, "_kv_partition_index_map"):
            self._kv_partition_index_map = {}

        if partition_id not in self._kv_partition_index_map:
            # Build the mapping from stored keys
            start_idx = self._get_next_kv_index(partition_id) - len(partition_keys)
            self._kv_partition_index_map[partition_id] = {}
            for i, key in enumerate(partition_keys):
                self._kv_partition_index_map[partition_id][start_idx + i] = key

        index_map = self._kv_partition_index_map.get(partition_id, {})

        # Retrieve keys for the given global_indexes
        keys = []
        for idx in global_indexes:
            keys.append(index_map.get(idx, None))

        return {"keys": keys}

    def _get_next_kv_index(self, partition_id):
        """Get next available index for KV keys in partition."""
        if not hasattr(self, "_kv_index_map"):
            self._kv_index_map = {}
        if partition_id not in self._kv_index_map:
            self._kv_index_map[partition_id] = 0
            # Also initialize key tracking
            if not hasattr(self, "_kv_partition_keys"):
                self._kv_partition_keys = {}
            self._kv_partition_keys[partition_id] = []
        return self._kv_index_map[partition_id]

    def _update_kv_index(self, partition_id, next_index):
        """Update next available index for KV keys."""
        if not hasattr(self, "_kv_index_map"):
            self._kv_index_map = {}
        self._kv_index_map[partition_id] = next_index

    def stop(self):
        self.running = False
        time.sleep(0.2)  # Give thread time to stop
        self.request_socket.close()
        self.context.term()


# Mock Storage for Client Unit Testing
class MockStorage:
    def __init__(self, storage_id="storage_0"):
        self.storage_id = storage_id
        self.context = zmq.Context()

        # Socket for data operations
        self.data_socket = self.context.socket(zmq.ROUTER)
        self.data_port = self._bind_to_random_port(self.data_socket)

        self.zmq_server_info = ZMQServerInfo(
            role=Role.STORAGE,
            id=storage_id,
            ip="127.0.0.1",
            ports={
                "put_get_socket": self.data_port,
            },
        )

        self.running = True
        self.data_thread = Thread(target=self._handle_data_requests, daemon=True)
        self.data_thread.start()

    def _bind_to_random_port(self, socket):
        port = socket.bind_to_random_port("tcp://127.0.0.1")
        return port

    def _handle_data_requests(self):
        poller = zmq.Poller()
        poller.register(self.data_socket, zmq.POLLIN)

        while self.running:
            try:
                socks = dict(poller.poll(100))  # 100ms timeout
                if self.data_socket in socks:
                    messages = self.data_socket.recv_multipart(copy=False)
                    identity = messages.pop(0)
                    serialized_msg = messages
                    msg = ZMQMessage.deserialize(serialized_msg)

                    # Handle different request types
                    if msg.request_type == ZMQRequestType.PUT_DATA:
                        response_body = {"message": "Data stored successfully"}
                        response_type = ZMQRequestType.PUT_DATA_RESPONSE
                    elif msg.request_type == ZMQRequestType.GET_DATA:
                        response_body = self._handle_get_data(msg.body)
                        response_type = ZMQRequestType.GET_DATA_RESPONSE
                    elif msg.request_type == ZMQRequestType.CLEAR_DATA:
                        response_body = {"message": "Data cleared successfully"}
                        response_type = ZMQRequestType.CLEAR_DATA_RESPONSE

                    # Send response
                    response_msg = ZMQMessage.create(
                        request_type=response_type,
                        sender_id=self.storage_id,
                        receiver_id=msg.sender_id,
                        body=response_body,
                    )
                    self.data_socket.send_multipart([identity, *response_msg.serialize()])
            except zmq.Again:
                continue
            except Exception as e:
                if self.running:
                    print(f"MockStorage running exception: {e}")
                else:
                    print(f"MockStorage ERROR: {e}")
                    raise

    def _handle_get_data(self, request_body):
        """Handle GET_DATA request by retrieving stored data"""
        global_indexes = request_body.get("global_indexes", [])
        fields = request_body.get("fields", [])

        result: dict[str, list] = {}
        for field in fields:
            gathered_items = [TEST_DATA[field][i] for i in global_indexes]

            if gathered_items:
                all_tensors = all(isinstance(x, torch.Tensor) for x in gathered_items)
                if all_tensors:
                    result[field] = torch.nested.as_nested_tensor(gathered_items, layout=torch.jagged)
                else:
                    result[field] = NonTensorStack(*gathered_items)

        return {"data": TensorDict(result)}

    def stop(self):
        self.running = False
        time.sleep(0.2)  # Give thread time to stop
        self.data_socket.close()
        self.context.term()


# Test Fixtures
@pytest.fixture
def mock_controller():
    controller = MockController()
    yield controller
    controller.stop()


@pytest.fixture
def mock_storage():
    storage = MockStorage()
    yield storage
    storage.stop()


@pytest.fixture
def client_setup(mock_controller, mock_storage):
    # Create client with mock controller and storage
    client_id = "client_0"

    client = TransferQueueClient(
        client_id=client_id,
        controller_info=mock_controller.zmq_server_info,
    )

    # Mock the storage manager to avoid handshake issues but mock all data operations
    with patch(
        "transfer_queue.storage.managers.simple_storage_manager.AsyncSimpleStorageManager._connect_to_controller"
    ):
        config = {
            "controller_info": mock_controller.zmq_server_info,
            "zmq_info": {mock_storage.storage_id: mock_storage.zmq_server_info},
        }
        client.initialize_storage_manager(manager_type="SimpleStorage", config=config)

        # Mock all storage manager methods to avoid real ZMQ operations
        async def mock_put_data(data, metadata, data_parser=None):
            pass  # Just pretend to store the data

        async def mock_get_data(metadata):
            # Return the test data when requested
            return TEST_DATA

        async def mock_clear_data(metadata):
            pass  # Just pretend to clear the data

        client.storage_manager.put_data = mock_put_data
        client.storage_manager.get_data = mock_get_data
        client.storage_manager.clear_data = mock_clear_data

    yield client, mock_controller, mock_storage


# Test basic functionality
def test_client_initialization(client_setup):
    """Test client initialization and connection setup"""
    client, mock_controller, mock_storage = client_setup

    assert client.client_id is not None
    assert client._controller is not None
    assert client._controller.id == mock_controller.controller_id


def test_put_and_get_data(client_setup):
    """Test basic put and get operations"""
    client, _, _ = client_setup

    # Test put operation
    client.put(data=TEST_DATA, partition_id="0")

    # Get metadata for retrieving data
    metadata = client.get_meta(
        data_fields=["log_probs", "variable_length_sequences", "prompt_text"], batch_size=2, partition_id="0"
    )

    # Test get operation
    result = client.get_data(metadata)

    # Verify result structure
    assert "log_probs" in result
    assert "variable_length_sequences" in result
    assert "prompt_text" in result

    torch.testing.assert_close(result["log_probs"][0], torch.tensor([1.0, 2.0, 3.0]))
    torch.testing.assert_close(result["log_probs"][1], torch.tensor([4.0, 5.0, 6.0]))
    torch.testing.assert_close(result["variable_length_sequences"][0], torch.tensor([-0.5, -1.2, -0.8]))
    torch.testing.assert_close(result["variable_length_sequences"][1], torch.tensor([-0.3, -1.5, -2.1, -0.9]))
    assert result["prompt_text"][0] == "Hello world!"
    assert result["prompt_text"][1] == "This is a longer sentence for testing"


def test_get_meta(client_setup):
    """Test metadata retrieval"""
    client, _, _ = client_setup

    # Test get_meta operation
    metadata = client.get_meta(data_fields=["tokens", "labels"], batch_size=10, partition_id="0")

    # Verify metadata structure
    assert hasattr(metadata, "global_indexes")
    assert hasattr(metadata, "field_names")
    assert hasattr(metadata, "size")
    assert len(metadata.global_indexes) == 10


# Test with single controller and multiple storage units
def test_single_controller_multiple_storages():
    """Test client with single controller and multiple storage units"""
    # Create single controller and multiple storage units
    controller = MockController("controller_0")
    storages = [MockStorage(f"storage_{i}") for i in range(3)]

    try:
        # Create client with single controller
        client_id = "client_test_single_controller"

        client = TransferQueueClient(client_id=client_id, controller_info=controller.zmq_server_info)

        # Mock the storage manager to avoid handshake issues but mock all data operations
        with patch(
            "transfer_queue.storage.managers.simple_storage_manager.AsyncSimpleStorageManager._connect_to_controller"
        ):
            config = {
                "controller_info": controller.zmq_server_info,
                "zmq_info": {s.storage_id: s.zmq_server_info for s in storages},
            }
            client.initialize_storage_manager(manager_type="SimpleStorage", config=config)

            # Mock all storage manager methods to avoid real ZMQ operations
            async def mock_put_data(data, metadata, data_parser=None):
                pass  # Just pretend to store the data

            async def mock_get_data(metadata):
                # Return some test data when requested
                return TensorDict({"tokens": torch.randint(0, 100, (5, 128))}, batch_size=5)

            async def mock_clear_data(metadata):
                pass  # Just pretend to clear the data

            client.storage_manager.put_data = mock_put_data
            client.storage_manager.get_data = mock_get_data
            client.storage_manager.clear_data = mock_clear_data

        # Verify controller is set
        assert client._controller is not None
        assert client._controller.id == controller.controller_id

        # Test basic operation
        test_data = TensorDict({"tokens": torch.randint(0, 100, (5, 128))}, batch_size=5)

        # Test put operation
        client.put(data=test_data, partition_id="0")

    finally:
        # Clean up
        controller.stop()
        for s in storages:
            s.stop()


# Test error handling
def test_put_without_required_params(client_setup):
    """Test put operation without required parameters"""
    client, _, _ = client_setup

    # Create test data
    test_data = TensorDict({"tokens": torch.randint(0, 100, (5, 128))}, batch_size=5)

    # Test put without partition id (should fail)
    with pytest.raises(ValueError):
        client.put(data=test_data)


# Test new status checking methods
def test_check_consumption_status(client_setup):
    """Test consumption status checking"""
    client, _, _ = client_setup

    # Test synchronous check_consumption_status
    is_consumed = client.check_consumption_status(task_name="generate_sequences", partition_id="train_0")
    assert is_consumed is True


def test_check_production_status(client_setup):
    """Test production status checking"""
    client, _, _ = client_setup

    # Test synchronous check_production_status
    is_produced = client.check_production_status(data_fields=["prompt_ids", "attention_mask"], partition_id="train_0")
    assert is_produced is True


def test_get_consumption_status(client_setup):
    """Test get_consumption_status - returns global_index and consumption_status tensors"""
    client, _, _ = client_setup

    # Test synchronous get_consumption_status
    global_index, consumption_status = client.get_consumption_status(
        task_name="generate_sequences", partition_id="train_0"
    )

    # Verify return types
    assert global_index is not None
    assert consumption_status is not None

    # Verify global_index contains expected values
    assert torch.equal(global_index, torch.tensor([0, 1, 2], dtype=torch.long))

    # Verify consumption_status (mock returns all consumed)
    expected_status = torch.tensor([1, 1, 1], dtype=torch.int8)
    assert torch.equal(consumption_status, expected_status)

    print("✓ get_consumption_status returns correct global_index and consumption_status")


def test_get_production_status(client_setup):
    """Test get_production_status - returns global_index and production_status tensors"""
    client, _, _ = client_setup

    # Test synchronous get_production_status
    global_index, production_status = client.get_production_status(
        data_fields=["prompt_ids", "attention_mask"], partition_id="train_0"
    )

    # Verify return types
    assert global_index is not None
    assert production_status is not None

    # Verify global_index contains expected values
    assert torch.equal(global_index, torch.tensor([0, 1, 2], dtype=torch.long))

    # Verify production_status shape (mock returns 2x3 matrix)
    expected_status = torch.tensor([[1, 1, 1], [1, 1, 1]], dtype=torch.int8)
    assert torch.equal(production_status, expected_status)

    print("✓ get_production_status returns correct global_index and production_status")


def test_get_partition_list(client_setup):
    """Test partition list retrieval"""
    client, _, _ = client_setup

    # Test synchronous get_partition_list
    partition_list = client.get_partition_list()
    assert isinstance(partition_list, list)
    assert len(partition_list) > 0
    assert "partition_0" in partition_list
    assert "partition_1" in partition_list
    assert "test_partition" in partition_list


def test_reset_consumption(client_setup):
    """Test synchronous reset_consumption - resets consumption status for a partition"""
    client, _, _ = client_setup

    # Test synchronous reset_consumption with task_name
    success = client.reset_consumption(partition_id="train_0", task_name="generate_sequences")
    assert success is True

    print("✓ reset_consumption with task_name returns True")


def test_reset_consumption_all_tasks(client_setup):
    """Test synchronous reset_consumption without task_name (resets all tasks)"""
    client, _, _ = client_setup

    # Test synchronous reset_consumption without task_name (reset all tasks)
    success = client.reset_consumption(partition_id="train_0")
    assert success is True

    print("✓ reset_consumption without task_name (all tasks) returns True")


@pytest.mark.asyncio
async def test_async_reset_consumption(client_setup):
    """Test async reset_consumption - resets consumption status for a partition"""
    client, _, _ = client_setup

    # Test async_reset_consumption with task_name
    success = await client.async_reset_consumption(partition_id="train_0", task_name="generate_sequences")
    assert success is True

    print("✓ async_reset_consumption with task_name returns True")


@pytest.mark.asyncio
async def test_async_reset_consumption_all_tasks(client_setup):
    """Test async reset_consumption without task_name (resets all tasks)"""
    client, _, _ = client_setup

    # Test async_reset_consumption without task_name (reset all tasks)
    success = await client.async_reset_consumption(partition_id="train_0")
    assert success is True

    print("✓ async_reset_consumption without task_name (all tasks) returns True")


@pytest.mark.asyncio
async def test_async_check_consumption_status(client_setup):
    """Test async consumption status checking"""
    client, _, _ = client_setup

    # Test async_check_consumption_status
    is_consumed = await client.async_check_consumption_status(task_name="generate_sequences", partition_id="train_0")
    assert is_consumed is True


@pytest.mark.asyncio
async def test_async_check_production_status(client_setup):
    """Test async production status checking"""
    client, _, _ = client_setup

    # Test async_check_production_status
    is_produced = await client.async_check_production_status(
        data_fields=["prompt_ids", "attention_mask"], partition_id="train_0"
    )
    assert is_produced is True


@pytest.mark.asyncio
async def test_async_get_consumption_status(client_setup):
    """Test async get_consumption_status - returns global_index and consumption_status tensors"""
    client, _, _ = client_setup

    # Test async_get_consumption_status
    global_index, consumption_status = await client.async_get_consumption_status(
        task_name="generate_sequences", partition_id="train_0"
    )

    # Verify return types
    assert global_index is not None
    assert consumption_status is not None

    # Verify global_index contains expected values
    assert torch.equal(global_index, torch.tensor([0, 1, 2], dtype=torch.long))

    # Verify consumption_status (mock returns all consumed)
    expected_status = torch.tensor([1, 1, 1], dtype=torch.int8)
    assert torch.equal(consumption_status, expected_status)

    print("✓ async_get_consumption_status returns correct global_index and consumption_status")


@pytest.mark.asyncio
async def test_async_get_production_status(client_setup):
    """Test async get_production_status - returns global_index and production_status tensors"""
    client, _, _ = client_setup

    # Test async_get_production_status
    global_index, production_status = await client.async_get_production_status(
        data_fields=["prompt_ids", "attention_mask"], partition_id="train_0"
    )

    # Verify return types
    assert global_index is not None
    assert production_status is not None

    # Verify global_index contains expected values
    assert torch.equal(global_index, torch.tensor([0, 1, 2], dtype=torch.long))

    # Verify production_status shape (mock returns 2x3 matrix)
    expected_status = torch.tensor([[1, 1, 1], [1, 1, 1]], dtype=torch.int8)
    assert torch.equal(production_status, expected_status)

    print("✓ async_get_production_status returns correct global_index and production_status")


@pytest.mark.asyncio
async def test_async_get_partition_list(client_setup):
    """Test async partition list retrieval"""
    client, _, _ = client_setup

    # Test async_get_partition_list
    partition_list = await client.async_get_partition_list()
    assert isinstance(partition_list, list)
    assert len(partition_list) > 0
    assert "partition_0" in partition_list
    assert "partition_1" in partition_list
    assert "test_partition" in partition_list


# Test clear methods
@pytest.mark.asyncio
async def test_async_clear_partition(client_setup):
    """Test async clear partition operation"""
    client, _, _ = client_setup

    # Test async_clear_partition
    await client.async_clear_partition(partition_id="test_partition")

    # If no exception is raised, the test passes
    assert True


@pytest.mark.asyncio
async def test_async_clear_samples(client_setup):
    """Test async clear samples operation"""
    client, _, _ = client_setup

    # First get metadata to create a BatchMeta object
    metadata = await client.async_get_meta(data_fields=["tokens", "labels"], batch_size=2, partition_id="0")

    # Test async_clear_samples
    await client.async_clear_samples(metadata=metadata)

    # If no exception is raised, the test passes
    assert True


def test_clear_partition(client_setup):
    """Test synchronous clear partition operation"""
    client, _, _ = client_setup

    # Test synchronous clear_partition
    client.clear_partition(partition_id="test_partition")

    # If no exception is raised, the test passes
    assert True


def test_clear_samples(client_setup):
    """Test synchronous clear samples operation"""
    client, _, _ = client_setup

    # First get metadata to create a BatchMeta object
    metadata = client.get_meta(data_fields=["tokens", "labels"], batch_size=2, partition_id="0")

    # Test synchronous clear_samples
    client.clear_samples(metadata=metadata)

    # If no exception is raised, the test passes
    assert True


@pytest.mark.asyncio
async def test_async_clear_samples_with_empty_metadata(client_setup):
    """Test async_clear_samples with empty BatchMeta"""
    client, _, _ = client_setup

    # Create empty BatchMeta
    metadata = BatchMeta(global_indexes=[], partition_ids=[], field_schema={})

    # The clear operation should complete without raising an exception
    # because the mock storage manager is configured to handle this
    await client.async_clear_samples(metadata=metadata)

    # If no exception is raised, the test passes
    assert True


@pytest.mark.asyncio
async def test_sync_methods_work_in_async_event_loop(client_setup):
    """Test all synchronous methods can be called from within an asyncio event loop.

    This test verifies that the sync methods can be called directly from an async
    function without causing "asyncio.run() cannot be called from a running loop" errors.
    """
    client, _, _ = client_setup

    test_data = TensorDict({"tokens": torch.randint(0, 100, (3, 64))}, batch_size=3)

    # Test sync put
    metadata = client.put(data=test_data, partition_id="0")
    assert metadata is not None

    # Test sync get_meta - use fields that mock returns
    metadata = client.get_meta(
        data_fields=["log_probs", "variable_length_sequences", "prompt_text"], batch_size=2, partition_id="0"
    )
    assert metadata is not None
    assert len(metadata.global_indexes) == 2

    # Test sync get_data - verify we get the expected fields from mock
    result = client.get_data(metadata)
    assert result is not None
    assert "log_probs" in result
    assert "prompt_text" in result

    # Test sync check_consumption_status
    is_consumed = client.check_consumption_status(task_name="generate_sequences", partition_id="train_0")
    assert isinstance(is_consumed, bool)

    # Test sync get_consumption_status
    global_index, consumption_status = client.get_consumption_status(
        task_name="generate_sequences", partition_id="train_0"
    )
    assert global_index is not None
    assert consumption_status is not None

    # Test sync check_production_status
    is_produced = client.check_production_status(data_fields=["log_probs", "prompt_text"], partition_id="train_0")
    assert isinstance(is_produced, bool)

    # Test sync get_production_status
    global_index, production_status = client.get_production_status(
        data_fields=["log_probs", "prompt_text"], partition_id="train_0"
    )
    assert global_index is not None
    assert production_status is not None

    # Test sync get_partition_list
    partition_list = client.get_partition_list()
    assert isinstance(partition_list, list)
    assert len(partition_list) > 0

    # Test sync clear_partition
    client.clear_partition(partition_id="test_partition")

    # Test sync clear_samples
    metadata = client.get_meta(data_fields=["log_probs", "prompt_text"], batch_size=2, partition_id="0")
    client.clear_samples(metadata=metadata)

    print("✓ All sync methods work correctly when called from within asyncio event loop")


@pytest.mark.asyncio
async def test_sync_and_async_methods_mixed_usage(client_setup):
    """Test mixing sync and async method calls within the same async context.

    This test verifies that async methods and sync methods can be used interchangeably
    without conflicts when called from an async function.
    """
    client, _, _ = client_setup

    test_data = TensorDict({"tokens": torch.randint(0, 100, (2, 32))}, batch_size=2)

    # Call sync method first
    sync_put_result = client.put(data=test_data, partition_id="0")
    assert sync_put_result is not None

    # Call async method
    async_metadata = await client.async_get_meta(data_fields=["tokens"], batch_size=2, partition_id="0")
    assert async_metadata is not None

    # Call sync method again
    sync_get_meta_result = client.get_meta(data_fields=["tokens"], batch_size=2, partition_id="0")
    assert sync_get_meta_result is not None

    # Call async method
    async_data = await client.async_get_data(sync_get_meta_result)
    assert async_data is not None

    print("✓ Mixed async and sync method calls work correctly")


# =====================================================
# Custom Meta Interface Tests
# =====================================================


class TestClientCustomMetaInterface:
    """Tests for client custom_meta interface methods."""

    def test_set_custom_meta_sync(self, client_setup):
        """Test synchronous set_custom_meta method."""
        client, _, _ = client_setup

        # Test synchronous set_custom_meta

        # First get metadata
        metadata = client.get_meta(data_fields=["input_ids"], batch_size=2, partition_id="0")
        # Set custom_meta on the metadata
        metadata.update_custom_meta(
            [
                {"input_ids": {"token_count": 100}},
                {"input_ids": {"token_count": 120}},
            ]
        )

        # Call set_custom_meta with metadata (BatchMeta)
        client.set_custom_meta(metadata)
        print("✓ set_custom_meta sync method works")

    @pytest.mark.asyncio
    async def test_set_custom_meta_async(self, client_setup):
        """Test asynchronous async_set_custom_meta method."""
        client, _, _ = client_setup

        # First get metadata
        metadata = await client.async_get_meta(data_fields=["input_ids"], batch_size=2, partition_id="0")
        # Set custom_meta on the metadata
        metadata.update_custom_meta(
            [
                {"input_ids": {"token_count": 100}},
                {"input_ids": {"token_count": 120}},
            ]
        )

        # Call async_set_custom_meta with metadata (BatchMeta)
        await client.async_set_custom_meta(metadata)
        print("✓ async_set_custom_meta async method works")


# =====================================================
# KV Interface Tests
# =====================================================


class TestClientKVInterface:
    """Tests for client KV interface methods."""

    @pytest.mark.asyncio
    async def test_async_kv_retrieve_meta_single(self, client_setup):
        """Test async_kv_retrieve_meta with single key."""
        client, _, _ = client_setup

        # Test async_kv_retrieve_meta with single key
        metadata = await client.async_kv_retrieve_meta(
            keys="test_key_1",
            partition_id="test_partition",
            create=True,
        )

        # Verify metadata structure
        assert metadata is not None
        assert hasattr(metadata, "global_indexes")
        assert hasattr(metadata, "size")
        assert metadata.size == 1

    @pytest.mark.asyncio
    async def test_async_kv_retrieve_meta_multiple(self, client_setup):
        """Test async_kv_retrieve_meta with multiple keys."""
        client, _, _ = client_setup

        # Test async_kv_retrieve_meta with multiple keys
        keys = ["key_a", "key_b", "key_c"]
        metadata = await client.async_kv_retrieve_meta(
            keys=keys,
            partition_id="test_partition",
            create=True,
        )

        # Verify metadata structure
        assert metadata is not None
        assert hasattr(metadata, "global_indexes")
        assert hasattr(metadata, "size")
        assert metadata.size == 3

    @pytest.mark.asyncio
    async def test_async_kv_retrieve_meta_create_false(self, client_setup):
        """Test async_kv_retrieve_meta with create=False (retrieve existing keys)."""
        client, _, _ = client_setup

        # create some keys
        await client.async_kv_retrieve_meta(
            keys="existing_key",
            partition_id="existing_partition",
            create=True,
        )

        # Then retrieve them with create=False
        metadata = await client.async_kv_retrieve_meta(
            keys="existing_key",
            partition_id="existing_partition",
            create=False,
        )

        # Verify metadata structure
        assert metadata is not None
        assert metadata.size == 1

    @pytest.mark.asyncio
    async def test_async_kv_retrieve_meta_invalid_keys_type(self, client_setup):
        """Test async_kv_retrieve_meta raises error with invalid keys type."""
        client, _, _ = client_setup

        # Test with invalid keys type (not string or list)
        with pytest.raises(TypeError):
            await client.async_kv_retrieve_meta(
                keys=123,  # Invalid type
                partition_id="test_partition",
                create=True,
            )

    @pytest.mark.asyncio
    async def test_async_kv_list_with_keys(self, client_setup):
        """Test async_kv_list returns keys after they are registered."""
        client, mock_controller, _ = client_setup

        # First register some keys
        await client.async_kv_retrieve_meta(
            keys=["key_1", "key_2"],
            partition_id="kv_partition",
            create=True,
        )

        # Then list them
        partition_info = await client.async_kv_list(partition_id="kv_partition")

        # Verify keys are returned
        assert len(partition_info["kv_partition"]) >= 2
        assert "key_1" in partition_info["kv_partition"]
        assert "key_2" in partition_info["kv_partition"]
        assert list(partition_info["kv_partition"].values()) == [{}, {}]

    @pytest.mark.asyncio
    async def test_async_kv_list_multiple_partitions(self, client_setup):
        """Test async_kv_list with multiple partitions."""
        client, _, _ = client_setup

        # Create keys in different partitions
        await client.async_kv_retrieve_meta(
            keys="partition_a_key",
            partition_id="partition_a",
            create=True,
        )
        await client.async_kv_retrieve_meta(
            keys="partition_b_key",
            partition_id="partition_b",
            create=True,
        )

        # List keys for each partition
        partition_a = await client.async_kv_list(partition_id="partition_a")
        partition_b = await client.async_kv_list(partition_id="partition_b")

        # Verify keys are isolated per partition
        assert "partition_a" in partition_a
        assert "partition_b" in partition_b
        assert "partition_a" not in partition_b
        assert "partition_b" not in partition_a
        assert "partition_a_key" in partition_a["partition_a"]
        assert "partition_b_key" not in partition_a["partition_a"]
        assert "partition_b_key" in partition_b["partition_b"]
        assert "partition_a_key" not in partition_b["partition_b"]
        assert list(partition_a["partition_a"].values()) == [{}]
        assert list(partition_b["partition_b"].values()) == [{}]

    def test_kv_retrieve_meta_type_validation(self, client_setup):
        """Test synchronous kv_retrieve_meta type validation."""
        import asyncio

        client, _, _ = client_setup

        # Test with non-string element in list
        async def test_invalid_list():
            with pytest.raises(TypeError):
                await client.async_kv_retrieve_meta(
                    keys=["valid_key", 123],  # Invalid: 123 is not a string
                    partition_id="test_partition",
                    create=True,
                )

        asyncio.run(test_invalid_list())

    @pytest.mark.asyncio
    async def test_async_kv_retrieve_keys_single(self, client_setup):
        """Test async_kv_retrieve_keys with single global_index."""
        client, _, _ = client_setup
        partition_id = "test_partition_idx"

        # First create a key using kv_retrieve_meta
        await client.async_kv_retrieve_meta(
            keys=["test_key"],
            partition_id=partition_id,
            create=True,
        )

        # Now retrieve the key using global_index 0
        keys = await client.async_kv_retrieve_keys(
            global_indexes=[0],
            partition_id=partition_id,
        )

        assert keys == ["test_key"]

    @pytest.mark.asyncio
    async def test_async_kv_retrieve_keys_multiple(self, client_setup):
        """Test async_kv_retrieve_keys with multiple global_indexes."""
        client, _, _ = client_setup
        partition_id = "test_partition_idx"

        # First create keys using kv_retrieve_meta
        keys_to_create = ["key_a", "key_b", "key_c"]
        await client.async_kv_retrieve_meta(
            keys=keys_to_create,
            partition_id=partition_id,
            create=True,
        )

        # Retrieve keys using global_indexes [0, 1, 2]
        keys = await client.async_kv_retrieve_keys(
            global_indexes=[0, 1, 2],
            partition_id=partition_id,
        )

        assert keys == ["key_a", "key_b", "key_c"]

    @pytest.mark.asyncio
    async def test_async_kv_retrieve_keys_partial(self, client_setup):
        """Test async_kv_retrieve_keys with subset of global_indexes."""
        client, _, _ = client_setup
        partition_id = "test_partition_idx"

        # First create keys using kv_retrieve_meta
        await client.async_kv_retrieve_meta(
            keys=["first_key", "second_key", "third_key"],
            partition_id=partition_id,
            create=True,
        )

        # Retrieve only first and third keys
        keys = await client.async_kv_retrieve_keys(
            global_indexes=[0, 2],
            partition_id=partition_id,
        )

        assert keys == ["first_key", "third_key"]

    @pytest.mark.asyncio
    async def test_async_kv_retrieve_keys_single_int(self, client_setup):
        """Test async_kv_retrieve_keys accepts a single int."""
        client, _, _ = client_setup
        partition_id = "test_partition_idx"

        # First create a key using kv_retrieve_meta
        await client.async_kv_retrieve_meta(
            keys=["single_key"],
            partition_id=partition_id,
            create=True,
        )

        # Now retrieve the key using a single int (not a list)
        keys = await client.async_kv_retrieve_keys(
            global_indexes=0,
            partition_id=partition_id,
        )

        assert keys == ["single_key"]

    @pytest.mark.asyncio
    async def test_async_kv_retrieve_keys_invalid_type(self, client_setup):
        """Test async_kv_retrieve_keys raises error with invalid global_indexes type."""
        client, _, _ = client_setup

        # Test with invalid type (string instead of int)
        with pytest.raises(TypeError):
            await client.async_kv_retrieve_keys(
                global_indexes=["not_an_int"],
                partition_id="test_partition",
            )

    @pytest.mark.asyncio
    async def test_async_kv_retrieve_keys_empty_list(self, client_setup):
        """Test async_kv_retrieve_keys raises error with empty list."""
        client, _, _ = client_setup

        with pytest.raises(ValueError):
            await client.async_kv_retrieve_keys(
                global_indexes=[],
                partition_id="test_partition",
            )

    @pytest.mark.asyncio
    async def test_async_kv_retrieve_keys_non_existent(self, client_setup):
        """Test async_kv_retrieve_keys returns None for non-existent global_indexes."""
        client, _, _ = client_setup
        partition_id = "test_partition_idx"

        # First create a key using kv_retrieve_meta
        await client.async_kv_retrieve_meta(
            keys=["existing_key"],
            partition_id=partition_id,
            create=True,
        )

        # Try to retrieve a non-existent global_index
        keys = await client.async_kv_retrieve_keys(
            global_indexes=[99],
            partition_id=partition_id,
        )
        assert keys == [None]

    @pytest.mark.asyncio
    async def test_async_kv_retrieve_keys_multiple_partitions(self, client_setup):
        """Test async_kv_retrieve_keys returns keys from the correct partition."""
        client, _, _ = client_setup
        partition_1 = "partition_1"
        partition_2 = "partition_2"

        # Create keys in both partitions
        await client.async_kv_retrieve_meta(
            keys=["key_1"],
            partition_id=partition_1,
            create=True,
        )
        await client.async_kv_retrieve_meta(
            keys=["key_2"],
            partition_id=partition_2,
            create=True,
        )

        # Retrieve key from partition_1 (global_index 0)
        keys_1 = await client.async_kv_retrieve_keys(
            global_indexes=[0],
            partition_id=partition_1,
        )

        # Retrieve key from partition_2 (global_index 0)
        keys_2 = await client.async_kv_retrieve_keys(
            global_indexes=[0],
            partition_id=partition_2,
        )

        assert keys_1 == ["key_1"]
        assert keys_2 == ["key_2"]

    def test_kv_retrieve_keys_sync(self, client_setup):
        """Test synchronous kv_retrieve_keys."""
        client, _, _ = client_setup
        partition_id = "test_partition_sync"

        # First create a key using kv_retrieve_meta
        client.kv_retrieve_meta(
            keys=["sync_key"],
            partition_id=partition_id,
            create=True,
        )

        # Now retrieve the key using global_index
        keys = client.kv_retrieve_keys(
            global_indexes=[0],
            partition_id=partition_id,
        )

        assert keys == ["sync_key"]

    def test_kv_retrieve_keys_type_validation(self, client_setup):
        """Test synchronous kv_retrieve_keys type validation."""
        client, _, _ = client_setup

        # Test with non-int element in list
        with pytest.raises(TypeError):
            client.kv_retrieve_keys(
                global_indexes=[0, "invalid"],
                partition_id="test_partition",
            )