import sys
from unittest import mock
import pytest
import torch
try:
import torch_npu
except ImportError:
pass
class MockDsTensorClient:
def __init__(self, host, port, device_id):
self.storage = {}
def init(self):
pass
def mset_d2h(self, keys, values):
for k, v in zip(keys, values, strict=True):
assert v.device.type == "npu"
self.storage[k] = v
def mget_h2d(self, keys, out_tensors):
for i, k in enumerate(keys):
if k in self.storage:
out_tensors[i].copy_(self.storage[k])
def delete(self, keys):
for k in keys:
self.storage.pop(k, None)
class MockKVClient:
def __init__(self, host, port):
self.storage = {}
def init(self):
pass
def mcreate(self, keys, sizes):
class MockBuffer:
def __init__(self, size):
self._data = bytearray(size)
def MutableData(self):
return memoryview(self._data)
self._current_keys = keys
return [MockBuffer(s) for s in sizes]
def mset_buffer(self, buffers):
for key, buf in zip(self._current_keys, buffers, strict=True):
self.storage[key] = bytes(buf.MutableData())
def get_buffers(self, keys):
return [memoryview(self.storage[k]) if k in self.storage else None for k in keys]
def delete(self, keys):
for k in keys:
self.storage.pop(k, None)
@pytest.fixture
def mock_yr_datasystem():
    """Hide any real ``yr`` modules and inject mocks for the duration of one test.

    ``yr`` and ``yr.datasystem`` are replaced in ``sys.modules`` by ``MagicMock``
    objects whose ``DsTensorClient`` / ``KVClient`` attributes are the in-memory
    mock classes from this file. The client module's import flag, its
    ``datasystem`` reference and its ``find_reachable_host`` helper are patched
    as well. Everything is undone when the test finishes.
    """
    ds_mock = mock.MagicMock()
    ds_mock.DsTensorClient = MockDsTensorClient
    ds_mock.KVClient = MockKVClient
    yr_mock = mock.MagicMock(datasystem=ds_mock)

    def mock_find_reachable_host(port, timeout=1.0):
        return "127.0.0.1"

    client_module = "transfer_queue.storage.clients.yuanrong_client"

    # Snapshot sys.modules *before* removing anything: patch.dict restores the
    # snapshot on exit, so real ``yr`` modules come back after the test instead
    # of staying deleted for the remainder of the session.
    with mock.patch.dict("sys.modules"):
        # Match the package itself and its submodules only; a bare
        # startswith("yr") would also drop unrelated modules named "yr...".
        stale = [name for name in sys.modules if name == "yr" or name.startswith("yr.")]
        for name in stale:
            del sys.modules[name]

        with (
            mock.patch.dict("sys.modules", {"yr": yr_mock, "yr.datasystem": ds_mock}),
            mock.patch(f"{client_module}.YUANRONG_DATASYSTEM_IMPORTED", True, create=True),
            # create=True mirrors the flag patch above. NOTE(review): presumably the
            # attribute is absent when the client module was imported without the
            # real package; confirm against the client module.
            mock.patch(f"{client_module}.datasystem", ds_mock, create=True),
            mock.patch(f"{client_module}.find_reachable_host", side_effect=mock_find_reachable_host),
        ):
            yield
@pytest.fixture
def config():
    """Provide the client configuration dict shared by the tests in this module."""
    settings = dict(worker_port=12345, enable_yr_npu_optimization=True)
    return settings
def assert_tensors_equal(a: torch.Tensor, b: torch.Tensor):
    """Assert that two tensors agree in shape, dtype and element values.

    Both tensors are moved to CPU before the value comparison, so they may
    live on different devices.
    """
    # Checked in this order so a shape mismatch is reported before dtype.
    assert a.shape == b.shape
    assert a.dtype == b.dtype
    values_match = torch.equal(a.cpu(), b.cpu())
    assert values_match
class TestYuanrongStorageE2E:
    """End-to-end put/get/clear tests for ``YuanrongStorageClient`` on mocked backends."""

    @pytest.fixture(autouse=True)
    def setup_client(self, mock_yr_datasystem, config):
        """Bind the client class and config on the test instance before every test."""
        # Imported here, after mock_yr_datasystem has run, rather than at module
        # level (presumably so the client module sees the mocked ``yr`` package).
        from transfer_queue.storage.clients.yuanrong_client import YuanrongStorageClient

        self.client_cls = YuanrongStorageClient
        self.config = config

    @staticmethod
    def _require_npu():
        """Skip the calling test unless ``torch.npu`` exists and reports a device."""
        if not (hasattr(torch, "npu") and torch.npu.is_available()):
            pytest.skip("NPU required")

    def _create_data(self, mode="cpu"):
        """Build sample keys and values together with their shapes and dtypes.

        Args:
            mode: ``"cpu"`` for a CPU tensor plus a string and an int,
                ``"npu"`` for two NPU tensors; any other value yields one NPU
                tensor plus one string. The NPU variants skip the test when no
                NPU is available.

        Returns:
            ``(keys, vals, shapes, dtypes)``; non-tensor values get shape ``[]``
            and dtype ``None``.
        """
        if mode == "cpu":
            keys = ["t", "s", "i"]
            vals = [torch.randn(2), "hi", 1]
        elif mode == "npu":
            self._require_npu()
            keys = ["n1", "n2"]
            vals = [torch.randn(2).npu(), torch.tensor([1]).npu()]
        else:
            self._require_npu()
            keys = ["n1", "c1"]
            vals = [torch.randn(2).npu(), "cpu"]
        shapes = [list(v.shape) if isinstance(v, torch.Tensor) else [] for v in vals]
        dtypes = [v.dtype if isinstance(v, torch.Tensor) else None for v in vals]
        return keys, vals, shapes, dtypes

    def test_mock_can_work(self, config):
        """Verify that every strategy of a fresh client is backed by one of the mocks."""
        mock_class = (MockDsTensorClient, MockKVClient)
        client = self.client_cls(config)
        # Guard against a vacuous pass: with no strategies the loop checks nothing.
        assert client._strategies, "client was created without any storage strategy"
        for strategy in client._strategies:
            assert isinstance(strategy._ds_client, mock_class)

    def test_cpu_only_flow(self, config):
        """Verify that CPU values round-trip with tag "2" and are gone after clear."""
        client = self.client_cls(config)
        keys, vals, shp, dt = self._create_data("cpu")
        meta = client.put(keys, vals)
        assert all(m == "2" for m in meta)
        ret = client.get(keys, shp, dt, meta)
        for o, r in zip(vals, ret, strict=True):
            if isinstance(o, torch.Tensor):
                assert_tensors_equal(o, r)
            else:
                assert o == r
        client.clear(keys, meta)
        assert all(v is None for v in client.get(keys, shp, dt, meta))

    def test_npu_only_flow(self, config):
        """Verify that NPU tensors round-trip through put/get with tag "1"."""
        keys, vals, shp, dt = self._create_data("npu")
        client = self.client_cls(config)
        meta = client.put(keys, vals)
        assert all(m == "1" for m in meta)
        ret = client.get(keys, shp, dt, meta)
        for o, r in zip(vals, ret, strict=True):
            assert_tensors_equal(o, r)
        client.clear(keys, meta)

    def test_mixed_flow(self, config):
        """Verify that a batch mixing an NPU tensor and a string uses both tags and round-trips."""
        keys, vals, shp, dt = self._create_data("mixed")
        client = self.client_cls(config)
        meta = client.put(keys, vals)
        assert set(meta) == {"1", "2"}
        ret = client.get(keys, shp, dt, meta)
        for o, r in zip(vals, ret, strict=True):
            if isinstance(o, torch.Tensor):
                assert_tensors_equal(o, r)
            else:
                assert o == r

    def test_get_with_invalid_backend_meta_raises_error(self, config):
        """Verify that get raises ValueError when backend_meta contains an unrecognized tag."""
        client = self.client_cls(config)
        keys = ["k1"]
        shapes = [[]]
        dtypes = [None]
        invalid_meta = ["99"]
        with pytest.raises(ValueError, match="Cannot retrieve stored data"):
            client.get(keys, shapes, dtypes, invalid_meta)

    def test_get_with_empty_backend_meta_raises_error(self, config):
        """Verify that get raises ValueError when backend_meta contains empty tags (not previously stored)."""
        client = self.client_cls(config)
        keys = ["k1"]
        shapes = [[]]
        dtypes = [None]
        empty_meta = [""]
        with pytest.raises(ValueError, match="no backend metadata"):
            client.get(keys, shapes, dtypes, empty_meta)

    def test_put_with_no_strategies_raises_error(self, config):
        """Verify that put raises ValueError when no strategy supports the value type."""
        import re

        client = self.client_cls(config)
        client._strategies = []
        # ``match`` is a regular expression, so the interpolated value is escaped
        # to compare the message text literally.
        expected = re.escape(f"No storage backend can handle {self.client_cls.ROUTE_ITEM_AS_VALUE}")
        with pytest.raises(ValueError, match=expected):
            client.put(["k1"], [1])

    def test_clear_with_empty_backend_meta_silent(self, config):
        """Verify that clear silently skips keys with empty backend_meta (not previously stored)."""
        client = self.client_cls(config)
        empty_meta = [""]
        client.clear(["k1"], empty_meta)