# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2025 The TransferQueue Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
from unittest import mock

import pytest
import torch

try:
    import torch_npu  # noqa: F401
except ImportError:
    pass


# --- Mock Backend Implementation ---
# In real scenarios, multiple DsTensorClients or KVClients share storage.
# Here, each mockClient is implemented with independent storage using a simple dictionary,
# and is only suitable for unit testing.


class MockDsTensorClient:
    def __init__(self, host, port, device_id):
        self.storage = {}

    def init(self):
        pass

    def mset_d2h(self, keys, values):
        for k, v in zip(keys, values, strict=True):
            assert v.device.type == "npu"
            self.storage[k] = v

    def mget_h2d(self, keys, out_tensors):
        for i, k in enumerate(keys):
            # Note: If key is missing, tensor remains unchanged (mock limitation)
            if k in self.storage:
                out_tensors[i].copy_(self.storage[k])

    def delete(self, keys):
        for k in keys:
            self.storage.pop(k, None)


class MockKVClient:
    def __init__(self, host, port):
        self.storage = {}

    def init(self):
        pass

    def mcreate(self, keys, sizes):
        class MockBuffer:
            def __init__(self, size):
                self._data = bytearray(size)

            def MutableData(self):
                return memoryview(self._data)

        self._current_keys = keys
        return [MockBuffer(s) for s in sizes]

    def mset_buffer(self, buffers):
        for key, buf in zip(self._current_keys, buffers, strict=True):
            self.storage[key] = bytes(buf.MutableData())

    def get_buffers(self, keys):
        return [memoryview(self.storage[k]) if k in self.storage else None for k in keys]

    def delete(self, keys):
        for k in keys:
            self.storage.pop(k, None)


# --- Fixtures ---


@pytest.fixture
def mock_yr_datasystem():
    """Wipe real 'yr' modules and inject mocks."""

    # 1. Clean up sys.modules to force a fresh import under mock conditions
    # This ensures top-level code in yuanrong_client.py is re-executed
    to_delete = [k for k in sys.modules if k.startswith("yr")]
    for mod in to_delete:
        del sys.modules[mod]

    # 2. Setup Mock Objects
    ds_mock = mock.MagicMock()
    ds_mock.DsTensorClient = MockDsTensorClient
    ds_mock.KVClient = MockKVClient

    yr_mock = mock.MagicMock(datasystem=ds_mock)

    # 3. Apply patches
    # - sys.modules: Redirects 'import yr' to our mocks
    # - YUANRONG_DATASYSTEM_IMPORTED: Forces the existence check to True so initialize the client successfully
    # - datasystem: Direct attribute patch for the module
    # - find_reachable_host: Mock host detection to avoid real network checks
    def mock_find_reachable_host(port, timeout=1.0):
        return "127.0.0.1"

    with (
        mock.patch.dict("sys.modules", {"yr": yr_mock, "yr.datasystem": ds_mock}),
        mock.patch("transfer_queue.storage.clients.yuanrong_client.YUANRONG_DATASYSTEM_IMPORTED", True, create=True),
        mock.patch("transfer_queue.storage.clients.yuanrong_client.datasystem", ds_mock),
        mock.patch(
            "transfer_queue.storage.clients.yuanrong_client.find_reachable_host", side_effect=mock_find_reachable_host
        ),
    ):
        yield


@pytest.fixture
def config():
    return {"worker_port": 12345, "enable_yr_npu_optimization": True}


def assert_tensors_equal(a: torch.Tensor, b: torch.Tensor):
    assert a.shape == b.shape and a.dtype == b.dtype
    # Move to CPU for cross-device comparison
    assert torch.equal(a.cpu(), b.cpu())


# --- Test Suite ---


class TestYuanrongStorageE2E:
    @pytest.fixture(autouse=True)
    def setup_client(self, mock_yr_datasystem, config):
        # Lazy import to ensure mocks are active
        from transfer_queue.storage.clients.yuanrong_client import YuanrongStorageClient

        self.client_cls = YuanrongStorageClient
        self.config = config

    def _create_data(self, mode="cpu"):
        if mode == "cpu":
            keys = ["t", "s", "i"]
            vals = [torch.randn(2), "hi", 1]
        elif mode == "npu":
            if not (hasattr(torch, "npu") and torch.npu.is_available()):
                pytest.skip("NPU required")
            keys = ["n1", "n2"]
            vals = [torch.randn(2).npu(), torch.tensor([1]).npu()]
        else:  # mixed
            if not (hasattr(torch, "npu") and torch.npu.is_available()):
                pytest.skip("NPU required")
            keys = ["n1", "c1"]
            vals = [torch.randn(2).npu(), "cpu"]

        shapes = [list(v.shape) if isinstance(v, torch.Tensor) else [] for v in vals]
        dtypes = [v.dtype if isinstance(v, torch.Tensor) else None for v in vals]
        return keys, vals, shapes, dtypes

    def test_mock_can_work(self, config):
        mock_class = (MockDsTensorClient, MockKVClient)
        client = self.client_cls(config)
        for strategy in client._strategies:
            assert isinstance(strategy._ds_client, mock_class)

    def test_cpu_only_flow(self, config):
        client = self.client_cls(config)
        keys, vals, shp, dt = self._create_data("cpu")

        # Put & Verify Meta
        meta = client.put(keys, vals)
        # "2" is a tag added by YuanrongStorageClient, indicating that it is processed via General KV path.
        assert all(m == "2" for m in meta)

        # Get & Verify Values
        ret = client.get(keys, shp, dt, meta)
        for o, r in zip(vals, ret, strict=True):
            if isinstance(o, torch.Tensor):
                assert_tensors_equal(o, r)
            else:
                assert o == r

        # Clear & Verify
        client.clear(keys, meta)
        assert all(v is None for v in client.get(keys, shp, dt, meta))

    def test_npu_only_flow(self, config):
        keys, vals, shp, dt = self._create_data("npu")
        client = self.client_cls(config)

        meta = client.put(keys, vals)
        # "1" is a tag added by YuanrongStorageClient, indicating that it is processed via NPU path.
        assert all(m == "1" for m in meta)

        ret = client.get(keys, shp, dt, meta)
        for o, r in zip(vals, ret, strict=True):
            assert_tensors_equal(o, r)

        client.clear(keys, meta)

    def test_mixed_flow(self, config):
        keys, vals, shp, dt = self._create_data("mixed")
        client = self.client_cls(config)

        meta = client.put(keys, vals)
        assert set(meta) == {"1", "2"}

        ret = client.get(keys, shp, dt, meta)
        for o, r in zip(vals, ret, strict=True):
            if isinstance(o, torch.Tensor):
                assert_tensors_equal(o, r)
            else:
                assert o == r

    def test_get_with_invalid_backend_meta_raises_error(self, config):
        """Verify that get raises ValueError when backend_meta contains an unrecognized tag."""
        client = self.client_cls(config)
        keys = ["k1"]
        shapes = [[]]
        dtypes = [None]
        invalid_meta = ["99"]
        with pytest.raises(ValueError, match="Cannot retrieve stored data"):
            client.get(keys, shapes, dtypes, invalid_meta)

    def test_get_with_empty_backend_meta_raises_error(self, config):
        """Verify that get raises ValueError when backend_meta contains empty tags (not previously stored)."""
        client = self.client_cls(config)
        keys = ["k1"]
        shapes = [[]]
        dtypes = [None]
        empty_meta = [""]
        with pytest.raises(ValueError, match="no backend metadata"):
            client.get(keys, shapes, dtypes, empty_meta)

    def test_put_with_no_strategies_raises_error(self, config):
        """Verify that put raises ValueError when no strategy supports the value type."""
        client = self.client_cls(config)
        client._strategies = []
        with pytest.raises(ValueError, match=f"No storage backend can handle {self.client_cls.ROUTE_ITEM_AS_VALUE}"):
            client.put(["k1"], [1])

    def test_clear_with_empty_backend_meta_silent(self, config):
        """Verify that clear silently skips keys with empty backend_meta (not previously stored)."""
        client = self.client_cls(config)
        empty_meta = [""]
        # No exception, no warning — only debug log
        client.clear(["k1"], empty_meta)