# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2025 The TransferQueue Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time

import pytest
import ray
import tensordict
import torch
import zmq

from transfer_queue.storage.simple_storage import SimpleStorageUnit
from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType


class MockStorageClient:
    """Mock client for testing storage unit operations."""

    def __init__(self, storage_put_get_address):
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.DEALER)
        self.socket.setsockopt(zmq.RCVTIMEO, 5000)  # 5 second timeout
        self.socket.connect(storage_put_get_address)

    def send_put(self, client_id, global_indexes, field_data, data_parser=None):
        body = {"global_indexes": global_indexes, "data": field_data}
        if data_parser is not None:
            body["data_parser"] = data_parser
        msg = ZMQMessage.create(
            request_type=ZMQRequestType.PUT_DATA,
            sender_id=f"mock_client_{client_id}",
            body=body,
        )
        self.socket.send_multipart(msg.serialize())
        return ZMQMessage.deserialize(self.socket.recv_multipart(copy=False))

    def send_get(self, client_id, global_indexes, fields):
        msg = ZMQMessage.create(
            request_type=ZMQRequestType.GET_DATA,
            sender_id=f"mock_client_{client_id}",
            body={"global_indexes": global_indexes, "fields": fields},
        )
        self.socket.send_multipart(msg.serialize())
        return ZMQMessage.deserialize(self.socket.recv_multipart(copy=False))

    def send_clear(self, client_id, global_indexes):
        msg = ZMQMessage.create(
            request_type=ZMQRequestType.CLEAR_DATA,
            sender_id=f"mock_client_{client_id}",
            body={"global_indexes": global_indexes},
        )
        self.socket.send_multipart(msg.serialize())
        return ZMQMessage.deserialize(self.socket.recv_multipart(copy=False))

    def close(self):
        self.socket.close()
        self.context.term()


@pytest.fixture(scope="session")
def ray_setup():
    """Initialize Ray for testing."""
    ray.init(ignore_reinit_error=True)
    yield
    ray.shutdown()


@pytest.fixture
def storage_setup(ray_setup):
    """Set up storage unit for testing."""
    storage_size = 10000
    tensordict.set_list_to_stack(True).set()

    # Start Ray actor for SimpleStorageUnit
    storage_actor = SimpleStorageUnit.options(max_concurrency=50, num_cpus=1).remote(storage_unit_size=storage_size)

    # Get ZMQ server info from storage unit
    zmq_info = ray.get(storage_actor.get_zmq_server_info.remote())
    put_get_address = zmq_info.to_addr("put_get_socket")
    time.sleep(1)  # Wait for socket to be ready

    yield storage_actor, put_get_address

    # Cleanup
    ray.kill(storage_actor)


def test_put_get_single_client(storage_setup):
    """Test basic put and get operations with a single client."""
    _, put_get_address = storage_setup

    client = MockStorageClient(put_get_address)

    # PUT data
    global_indexes = [0, 1, 2]
    field_data = {
        "log_probs": [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0]), torch.tensor([7.0, 8.0, 9.0])],
        "rewards": [torch.tensor([10.0]), torch.tensor([20.0]), torch.tensor([30.0])],
    }

    response = client.send_put(0, global_indexes, field_data)
    assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    # GET data
    response = client.send_get(0, [0, 1], ["log_probs", "rewards"])
    assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE

    retrieved_data = response.body["data"]
    assert "log_probs" in retrieved_data
    assert "rewards" in retrieved_data
    assert len(retrieved_data["log_probs"]) == 2
    assert len(retrieved_data["rewards"]) == 2

    # Verify data correctness
    torch.testing.assert_close(retrieved_data["log_probs"][0], torch.tensor([1.0, 2.0, 3.0]))
    torch.testing.assert_close(retrieved_data["log_probs"][1], torch.tensor([4.0, 5.0, 6.0]))
    torch.testing.assert_close(retrieved_data["rewards"][0], torch.tensor([10.0]))
    torch.testing.assert_close(retrieved_data["rewards"][1], torch.tensor([20.0]))

    client.close()


def test_put_get_multiple_clients(storage_setup):
    """Test put and get operations with multiple clients."""
    _, put_get_address = storage_setup

    num_clients = 3
    clients = [MockStorageClient(put_get_address) for _ in range(num_clients)]

    # Each client puts unique data using different global_indexes
    for i, client in enumerate(clients):
        global_indexes = [i * 10 + 0, i * 10 + 1, i * 10 + 2]
        field_data = {
            "log_probs": [
                torch.tensor([i, i + 1, i + 2]),
                torch.tensor([i + 3, i + 4, i + 5]),
                torch.tensor([i + 6, i + 7, i + 8]),
            ],
            "rewards": [torch.tensor([i * 10]), torch.tensor([i * 10 + 10]), torch.tensor([i * 10 + 20])],
        }

        response = client.send_put(i, global_indexes, field_data)
        assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    # Test overlapping global indexes
    overlapping_client = MockStorageClient(put_get_address)
    overlap_global_indexes = [0]  # Overlaps with first client's index 0
    overlap_field_data = {"log_probs": [torch.tensor([999, 999, 999])], "rewards": [torch.tensor([999])]}
    response = overlapping_client.send_put(99, overlap_global_indexes, overlap_field_data)
    assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    # Each original client gets its own data (except for index 0 which was overwritten)
    for i, client in enumerate(clients):
        response = client.send_get(i, [i * 10 + 0, i * 10 + 1], ["log_probs", "rewards"])
        assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE

        retrieved_data = response.body["data"]
        assert len(retrieved_data["log_probs"]) == 2
        assert len(retrieved_data["rewards"]) == 2

        # For index 0, expect data from overlapping_client; others from original client
        if i == 0:
            # Index 0 was overwritten
            torch.testing.assert_close(retrieved_data["log_probs"][0], torch.tensor([999, 999, 999]))
            torch.testing.assert_close(retrieved_data["rewards"][0], torch.tensor([999]))
            # Index 1 remains original
            torch.testing.assert_close(retrieved_data["log_probs"][1], torch.tensor([3, 4, 5]))
            torch.testing.assert_close(retrieved_data["rewards"][1], torch.tensor([10]))
        else:
            # All data remains original
            torch.testing.assert_close(retrieved_data["log_probs"][0], torch.tensor([i, i + 1, i + 2]))
            torch.testing.assert_close(retrieved_data["log_probs"][1], torch.tensor([i + 3, i + 4, i + 5]))
            torch.testing.assert_close(retrieved_data["rewards"][0], torch.tensor([i * 10]))
            torch.testing.assert_close(retrieved_data["rewards"][1], torch.tensor([i * 10 + 10]))

    # Cleanup
    for client in clients:
        client.close()
    overlapping_client.close()


def test_performance_basic(storage_setup):
    """Basic performance test with larger data volume."""
    _, put_get_address = storage_setup

    client = MockStorageClient(put_get_address)

    # PUT performance test
    put_latencies = []
    num_puts = 10  # Reduced for faster testing
    batch_size = 16  # Reduced for faster testing

    for i in range(num_puts):
        start = time.time()

        # Use batch size and index mapping
        global_indexes = list(range(i * batch_size, (i + 1) * batch_size))

        # Create tensor data
        log_probs_data = []
        rewards_data = []

        for _ in range(batch_size):
            # Smaller tensors for faster testing
            log_probs_tensor = torch.randn(100)
            rewards_tensor = torch.randn(100)
            log_probs_data.append(log_probs_tensor)
            rewards_data.append(rewards_tensor)

        field_data = {"log_probs": log_probs_data, "rewards": rewards_data}

        response = client.send_put(0, global_indexes, field_data)
        latency = time.time() - start
        put_latencies.append(latency)
        assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    # GET performance test
    get_latencies = []
    num_gets = 10

    for i in range(num_gets):
        start = time.time()
        # Retrieve batch of data
        global_indexes = list(range(i * batch_size, (i + 1) * batch_size))
        response = client.send_get(0, global_indexes, ["log_probs", "rewards"])
        latency = time.time() - start
        get_latencies.append(latency)
        assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE

    avg_put_latency = sum(put_latencies) / len(put_latencies) * 1000  # ms
    avg_get_latency = sum(get_latencies) / len(get_latencies) * 1000  # ms

    # More lenient performance thresholds for testing environment
    assert avg_put_latency < 1500, f"Avg PUT latency {avg_put_latency}ms exceeds threshold"
    assert avg_get_latency < 1500, f"Avg GET latency {avg_get_latency}ms exceeds threshold"

    client.close()


def test_put_get_nested_tensor(storage_setup):
    """Test put and get operations with nested tensors."""
    _, put_get_address = storage_setup

    client = MockStorageClient(put_get_address)

    # PUT data with nested tensors
    global_indexes = [0, 1, 2]
    field_data = {
        "variable_length_sequences": [
            torch.tensor([-0.5, -1.2, -0.8]),
            torch.tensor([-0.3, -1.5, -2.1, -0.9]),
            torch.tensor([-1.1, -0.7]),
        ],
        "attention_mask": [torch.tensor([1, 1, 1]), torch.tensor([1, 1, 1, 1]), torch.tensor([1, 1])],
    }

    response = client.send_put(0, global_indexes, field_data)
    assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    # GET data
    response = client.send_get(0, [0, 2], ["variable_length_sequences", "attention_mask"])
    assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE

    retrieved_data = response.body["data"]
    assert "variable_length_sequences" in retrieved_data
    assert "attention_mask" in retrieved_data
    assert len(retrieved_data["variable_length_sequences"]) == 2
    assert len(retrieved_data["attention_mask"]) == 2

    # Verify data correctness
    torch.testing.assert_close(retrieved_data["variable_length_sequences"][0], torch.tensor([-0.5, -1.2, -0.8]))
    torch.testing.assert_close(retrieved_data["variable_length_sequences"][1], torch.tensor([-1.1, -0.7]))
    torch.testing.assert_close(retrieved_data["attention_mask"][0], torch.tensor([1, 1, 1]))
    torch.testing.assert_close(retrieved_data["attention_mask"][1], torch.tensor([1, 1]))

    client.close()


def test_put_get_non_tensor_data(storage_setup):
    """Test put and get operations with non-tensor data (strings)."""
    _, put_get_address = storage_setup

    client = MockStorageClient(put_get_address)

    # PUT data with non-tensor data
    global_indexes = [0, 1, 2]
    field_data = {
        "prompt_text": ["Hello world!", "This is a longer sentence for testing", "Test case"],
        "response_text": ["Hi there!", "This is the response to the longer sentence", "Test response"],
    }

    response = client.send_put(0, global_indexes, field_data)
    assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    # GET data
    response = client.send_get(0, [0, 1, 2], ["prompt_text", "response_text"])
    assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE

    retrieved_data = response.body["data"]
    assert "prompt_text" in retrieved_data
    assert "response_text" in retrieved_data

    # Verify data correctness
    assert isinstance(retrieved_data["prompt_text"][0], str)
    assert isinstance(retrieved_data["response_text"][0], str)

    assert retrieved_data["prompt_text"][0] == "Hello world!"
    assert retrieved_data["prompt_text"][1] == "This is a longer sentence for testing"
    assert retrieved_data["prompt_text"][2] == "Test case"
    assert retrieved_data["response_text"][0] == "Hi there!"
    assert retrieved_data["response_text"][1] == "This is the response to the longer sentence"
    assert retrieved_data["response_text"][2] == "Test response"

    client.close()


def test_put_get_single_item(storage_setup):
    """Test put and get operations for a single item."""
    _, put_get_address = storage_setup

    client = MockStorageClient(put_get_address)

    # PUT single item data
    field_data = {
        "prompt_text": ["Hello world!"],
        "attention_mask": [torch.tensor([1, 1, 1])],
    }
    response = client.send_put(0, [0], field_data)
    assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    # GET data
    response = client.send_get(0, [0], ["prompt_text", "attention_mask"])
    assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE

    retrieved_data = response.body["data"]

    assert "prompt_text" in retrieved_data
    assert "attention_mask" in retrieved_data

    assert retrieved_data["prompt_text"][0] == "Hello world!"
    assert len(retrieved_data["attention_mask"]) == 1
    torch.testing.assert_close(retrieved_data["attention_mask"][0], torch.tensor([1, 1, 1]))

    client.close()


def test_clear_data(storage_setup):
    """Test clear operations."""
    _, put_get_address = storage_setup

    client = MockStorageClient(put_get_address)

    # PUT data first
    global_indexes = [0, 1, 2]
    field_data = {
        "log_probs": [torch.tensor([1.0]), torch.tensor([2.0]), torch.tensor([3.0])],
        "rewards": [torch.tensor([10.0]), torch.tensor([20.0]), torch.tensor([30.0])],
    }

    response = client.send_put(0, global_indexes, field_data)
    assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    # Verify data exists
    response = client.send_get(0, [0, 1, 2], ["log_probs"])
    assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE
    assert len(response.body["data"]["log_probs"]) == 3

    # Clear data
    response = client.send_clear(0, [0, 2])  # Clear only indexes 0 and 2
    assert response.request_type == ZMQRequestType.CLEAR_DATA_RESPONSE

    # Verify some data is cleared (but index 1 should still exist)
    response = client.send_get(0, [1], ["log_probs"])
    assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE
    assert len(response.body["data"]["log_probs"]) == 1
    torch.testing.assert_close(response.body["data"]["log_probs"][0], torch.tensor([2.0]))

    client.close()


def test_storage_unit_data_direct():
    """Test StorageUnitData class directly without ZMQ."""
    from transfer_queue.storage import StorageUnitData

    storage_data = StorageUnitData(storage_size=10)

    field_data = {
        "log_probs": [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])],
        "rewards": [torch.tensor([10.0]), torch.tensor([20.0])],
    }
    # global_indexes = global_index values (e.g., 0 and 1)
    storage_data.put_data(field_data, [0, 1])

    result = storage_data.get_data(["log_probs", "rewards"], [0, 1])
    assert "log_probs" in result
    assert "rewards" in result
    assert len(result["log_probs"]) == 2
    assert len(result["rewards"]) == 2

    result_single = storage_data.get_data(["log_probs"], [0])
    torch.testing.assert_close(result_single["log_probs"][0], torch.tensor([1.0, 2.0]))

    # clear: key is removed (not set to None)
    storage_data.clear([0])
    assert 0 not in storage_data.field_data["log_probs"]  # key gone
    assert 1 in storage_data.field_data["log_probs"]  # other key intact


def test_storage_unit_data_capacity_uses_active_keys():
    """Capacity check must use _active_keys, not scan field_data."""
    from transfer_queue.storage.simple_storage import StorageUnitData

    storage = StorageUnitData(storage_size=3)

    # Fill to capacity
    storage.put_data({"f": [1, 2, 3]}, global_indexes=[0, 1, 2])
    assert len(storage._active_keys) == 3

    # Exceeding capacity must raise
    with pytest.raises(ValueError, match="Storage capacity exceeded"):
        storage.put_data({"f": [4]}, global_indexes=[3])

    # After clearing one key, adding one more should succeed
    storage.clear(keys=[2])
    assert len(storage._active_keys) == 2
    storage.put_data({"f": [4]}, global_indexes=[3])
    assert storage._active_keys == {0, 1, 3}


def test_storage_unit_data_parser(storage_setup):
    """Test data_parser functionality in SimpleStorageUnit.

    Writes two columns:
    - normal_data: regular tensors, should remain unchanged
    - data_to_be_parsed: list of shape descriptors (list of ints)

    data_parser converts shape descriptors into random tensors of those shapes.
    """
    _, put_get_address = storage_setup
    client = MockStorageClient(put_get_address)

    def create_data_by_shape_parser(field_data):
        if "data_to_be_parsed" in field_data:
            shapes = field_data["data_to_be_parsed"]
            field_data["data_to_be_parsed"] = [torch.randn(shape) for shape in shapes]
        return field_data

    # Prepare data: normal_data is a batch tensor, data_to_be_parsed is a list of shape lists
    field_data = {
        "normal_data": torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
        "data_to_be_parsed": [[2, 3], [1, 4], [3, 2]],
    }
    global_indexes = [0, 1, 2]

    # Put with data_parser
    response = client.send_put(0, global_indexes, field_data, data_parser=create_data_by_shape_parser)
    assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE, f"Put failed: {response.body}"

    # Get back
    response = client.send_get(0, global_indexes, ["normal_data", "data_to_be_parsed"])
    assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE

    result = response.body["data"]

    # Verify normal_data is unchanged
    torch.testing.assert_close(result["normal_data"][0], torch.tensor([1.0, 2.0]))
    torch.testing.assert_close(result["normal_data"][1], torch.tensor([3.0, 4.0]))
    torch.testing.assert_close(result["normal_data"][2], torch.tensor([5.0, 6.0]))

    # Verify data_to_be_parsed shapes match the input shape descriptors
    expected_shapes = [(2, 3), (1, 4), (3, 2)]
    for i, expected_shape in enumerate(expected_shapes):
        actual_shape = tuple(result["data_to_be_parsed"][i].shape)
        assert actual_shape == expected_shape, (
            f"Shape mismatch at index {i}: expected {expected_shape}, got {actual_shape}"
        )

    client.close()


def test_storage_unit_data_parser_callable_types(storage_setup):
    """Test that various callable types (partial, callable class) work as data_parser."""
    _, put_get_address = storage_setup
    client = MockStorageClient(put_get_address)

    from functools import partial

    # 1. Test functools.partial
    def _partial_parser(field_data, prefix):
        if "text" in field_data:
            field_data["text"] = [f"{prefix}{t}" for t in field_data["text"]]
        return field_data

    partial_parser = partial(_partial_parser, prefix="parsed_")

    response = client.send_put(
        0,
        [0, 1],
        {"text": ["a", "b"]},
        data_parser=partial_parser,
    )
    assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE, f"partial parser failed: {response.body}"

    response = client.send_get(0, [0, 1], ["text"])
    assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE
    assert response.body["data"]["text"] == ["parsed_a", "parsed_b"]

    # 2. Test callable class instance
    class CallableParser:
        def __call__(self, field_data):
            if "value" in field_data:
                field_data["value"] = [v * 2 for v in field_data["value"]]
            return field_data

    callable_parser = CallableParser()
    response = client.send_put(
        0,
        [2, 3],
        {"value": [1, 2]},
        data_parser=callable_parser,
    )
    assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE, f"callable class parser failed: {response.body}"

    response = client.send_get(0, [2, 3], ["value"])
    assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE
    assert response.body["data"]["value"] == [2, 4]

    client.close()


def test_storage_unit_data_parser_validation(storage_setup):
    """Test that invalid data_parser inputs produce clear error messages."""
    _, put_get_address = storage_setup
    client = MockStorageClient(put_get_address)

    # 1. Non-callable data_parser should return a clear TypeError
    response = client.send_put(
        0,
        [0],
        {"data": [1]},
        data_parser="not_callable",
    )
    assert response.request_type == ZMQRequestType.PUT_ERROR
    assert "data_parser must be callable" in response.body["message"]

    # 2. data_parser returning non-dict should return a clear TypeError
    def bad_parser(field_data):
        return "not_a_dict"

    response = client.send_put(
        0,
        [1],
        {"data": [1]},
        data_parser=bad_parser,
    )
    assert response.request_type == ZMQRequestType.PUT_ERROR
    assert "data_parser must return a dict" in response.body["message"]

    # 3. data_parser deleting a key should return a clear ValueError
    def delete_key_parser(field_data):
        del field_data["data"]
        return field_data

    response = client.send_put(
        0,
        [2],
        {"data": [1], "extra": [2]},
        data_parser=delete_key_parser,
    )
    assert response.request_type == ZMQRequestType.PUT_ERROR
    assert "data_parser must not change dict keys" in response.body["message"]

    # 4. data_parser adding a key should return a clear ValueError
    def add_key_parser(field_data):
        field_data["new_key"] = [999]
        return field_data

    response = client.send_put(
        0,
        [3],
        {"data": [1]},
        data_parser=add_key_parser,
    )
    assert response.request_type == ZMQRequestType.PUT_ERROR
    assert "data_parser must not change dict keys" in response.body["message"]

    # 5. data_parser changing element count should return a clear ValueError
    def wrong_len_parser(field_data):
        field_data["data"] = field_data["data"][:-1]
        return field_data

    response = client.send_put(
        0,
        [4, 5],
        {"data": [1, 2]},
        data_parser=wrong_len_parser,
    )
    assert response.request_type == ZMQRequestType.PUT_ERROR
    assert "data_parser changed the number of elements" in response.body["message"]

    client.close()