from unittest.mock import MagicMock

import numpy as np
import pytest
import torch

# Skip this whole module when the Yuanrong SDK is unavailable. The guard must run
# before importing the client module below, which may depend on ``yr`` at import time.
pytest.importorskip("yr")

from transfer_queue.storage.clients.yuanrong_client import GeneralKVClientAdapter  # noqa: E402
class MockBuffer:
    """In-memory stand-in for the buffer objects handed out by the mocked ``mcreate``.

    Backed by a zero-filled ``bytearray``. ``MutableData()`` returns that same
    object on every call (no copy), so bytes written through it remain visible
    to anything else holding a reference to it.
    """

    def __init__(self, size):
        backing = bytearray(size)  # ``size`` zero bytes
        self.data = backing

    def MutableData(self):  # noqa: N802 - camelCase presumably mirrors the real buffer API
        """Return the shared, writable backing ``bytearray``."""
        backing = self.data
        return backing
class TestYuanrongKVClientZCopy:
    """Zero-copy ``mset``/``mget`` tests for ``GeneralKVClientAdapter`` over a mocked ``yr`` backend."""

    @pytest.fixture
    def mock_kv_client(self, mocker):
        """Patch the ``yr.datasystem`` clients and host lookup; return the mocked KV client.

        ``yr.datasystem.KVClient`` is patched to return one ``MagicMock`` whose
        ``init`` returns ``None``. ``DsTensorClient`` is replaced with a default
        mock, and ``find_reachable_host`` always resolves to ``127.0.0.1``
        (presumably so no real host probing happens during the test).
        """
        mock_client = MagicMock()
        mock_client.init.return_value = None
        mocker.patch("yr.datasystem.KVClient", return_value=mock_client)
        mocker.patch("yr.datasystem.DsTensorClient")
        mocker.patch(
            "transfer_queue.storage.clients.yuanrong_client.find_reachable_host",
            return_value="127.0.0.1",
        )
        return mock_client

    @pytest.fixture
    def storage_client(self, mock_kv_client):
        """Build an adapter on top of the patched backend.

        ``mock_kv_client`` is requested only for its patching side effects.
        """
        return GeneralKVClientAdapter({"worker_port": 31501})

    def test_mset_mget_p2p(self, storage_client, mocker):
        """Round-trip a tensor and a string through ``mset_zero_copy``/``mget_zero_copy``.

        Serialization is replaced with a trivial single-frame codec. The backend's
        ``mcreate`` and ``get_buffers`` are wired to share the same ``bytearray``
        objects, so the bytes written during ``mset`` are exactly the bytes read
        back during ``mget``.
        """
        expected_tensor = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
        expected_string = "hello yuanrong"
        # Byte length of the serialized tensor. The mock decoder uses it to tell
        # the tensor payload apart from the UTF-8 string payload.
        tensor_nbytes = expected_tensor.numel() * expected_tensor.element_size()
        # The size-based heuristic only works if the two payloads differ in length.
        assert tensor_nbytes != len(expected_string.encode("utf-8"))

        def mock_encode(obj):
            # Single-frame codec: raw tensor bytes, or the UTF-8 text of ``str(obj)``.
            if isinstance(obj, torch.Tensor):
                return [obj.numpy().tobytes()]
            return [str(obj).encode("utf-8")]

        def mock_decode(frames):
            # ``bytes(...)`` accepts bytes, bytearray and memoryview alike, whichever
            # buffer type the adapter hands to the decoder.
            data = bytes(frames[0])
            if len(data) == tensor_nbytes:
                # ``frombuffer`` over ``bytes`` is read-only; copy so the tensor owns
                # writable storage.
                return torch.from_numpy(np.frombuffer(data, dtype=np.float32).copy())
            try:
                return data.decode("utf-8")
            except UnicodeDecodeError:
                return data

        mocker.patch("transfer_queue.utils.serial_utils.encode", side_effect=mock_encode)
        mocker.patch("transfer_queue.utils.serial_utils.decode", side_effect=mock_decode)

        # Every buffer handed out by ``mcreate`` is recorded here in creation order.
        # ``get_buffers`` returns this very list object, so ``mget`` sees whatever
        # ``mset`` wrote into those buffers.
        stored_raw_buffers = []

        def side_effect_mcreate(keys, sizes):
            buffers = [MockBuffer(size) for size in sizes]
            stored_raw_buffers.extend(buf.MutableData() for buf in buffers)
            return buffers

        ds_client = storage_client._ds_client
        ds_client.mcreate.side_effect = side_effect_mcreate
        ds_client.get_buffers.return_value = stored_raw_buffers

        keys = ["tensor_key", "string_key"]
        # Pass a clone so the assertion below does not depend on the adapter
        # leaving its input tensor untouched.
        storage_client.mset_zero_copy(keys, [expected_tensor.clone(), expected_string])
        results = storage_client.mget_zero_copy(keys)

        assert len(results) == 2
        assert torch.allclose(results[0], expected_tensor)
        assert results[1] == expected_string