"""Unit tests for the packed-buffer batch serialization helpers in
``transfer_queue.utils.serial_utils``:
* ``calc_packed_size``
* ``pack_into`` / ``unpack_from``
* ``batch_encode_into``
* ``batch_decode_from``
"""
import numpy as np
import pytest
import torch
from transfer_queue.utils import serial_utils
def test_calc_packed_size_then_pack_unpack_roundtrip():
    """A buffer sized by ``calc_packed_size`` holds the packed items, and
    ``unpack_from`` recovers them byte-for-byte."""
    payloads = [b"hello", b"world!", b"x"]
    target = bytearray(serial_utils.calc_packed_size(payloads))
    serial_utils.pack_into(target, payloads)
    views = serial_utils.unpack_from(target)
    assert list(map(bytes, views)) == payloads
def test_pack_into_writes_only_within_its_slice():
    """``pack_into`` must not touch bytes outside the memoryview slice it gets.

    The regions before and after the slice are pre-filled with a non-zero
    sentinel. With zero padding, as a fresh ``bytearray`` would have, a stray
    write of zero bytes past either edge (e.g. via ``memoryview.obj`` or a raw
    pointer into the whole backing buffer) would go unnoticed.
    """
    items = [b"alpha", b"beta", b"gamma"]
    sz = serial_utils.calc_packed_size(items)
    pad_before, pad_after = 17, 23
    sentinel = bytes([0xA5])
    big = bytearray(pad_before + sz + pad_after)
    # Only the padding carries the sentinel. The packed region itself starts
    # zeroed, exactly as in a freshly allocated buffer.
    big[:pad_before] = sentinel * pad_before
    big[pad_before + sz :] = sentinel * pad_after
    window = memoryview(big)[pad_before : pad_before + sz]
    serial_utils.pack_into(window, items)
    assert bytes(big[:pad_before]) == sentinel * pad_before
    assert bytes(big[pad_before + sz :]) == sentinel * pad_after
    recovered = serial_utils.unpack_from(window)
    assert [bytes(mv) for mv in recovered] == items
def test_unpack_from_zero_item_buffer():
    """Packing an empty item list yields a buffer that unpacks to ``[]``."""
    no_items: list[bytes] = []
    packed = bytearray(serial_utils.calc_packed_size(no_items))
    serial_utils.pack_into(packed, no_items)
    assert serial_utils.unpack_from(packed) == []
def _mooncake_alloc(sizes: list[int]) -> list[torch.Tensor]:
"""Single big torch.uint8 tensor sliced into N views (mooncake-style)."""
big = torch.empty(sum(sizes), dtype=torch.uint8)
buffers: list[torch.Tensor] = []
offset = 0
for s in sizes:
buffers.append(big[offset : offset + s])
offset += s
return buffers
def _yuanrong_alloc(sizes: list[int]) -> list[bytearray]:
"""N independent bytearrays (yuanrong-style per-key buffer)."""
return [bytearray(s) for s in sizes]
def _decode_from_returned(buffers, alloc_kind):
    """Decode buffers produced by ``batch_encode_into``.

    Mooncake-style tensor views are passed to ``batch_decode_from`` as-is.
    Buffers of any other kind are first copied into immutable ``bytes``.
    """
    payloads = buffers if alloc_kind == "mooncake" else [bytes(b) for b in buffers]
    return serial_utils.batch_decode_from(payloads)
def _roundtrip(values, alloc, alloc_kind, *, num_workers: int = 1):
    """Encode ``values`` into buffers from ``alloc`` and decode them back.

    Returns ``(decoded, buffers, sizes)`` so callers can inspect the raw
    encoding as well as the decoded values.
    """
    encoded = serial_utils.batch_encode_into(values, alloc, num_workers=num_workers)
    buffers, sizes = encoded
    return _decode_from_returned(buffers, alloc_kind), buffers, sizes
def test_batch_encode_into_return_shapes():
    """Returns one buffer and one size per value, and each buffer spans
    exactly its reported size."""
    values = [{"x": 1}, "a string", torch.arange(8, dtype=torch.float32)]
    buffers, sizes = serial_utils.batch_encode_into(values, _mooncake_alloc)
    expected_count = len(values)
    assert len(buffers) == expected_count
    assert len(sizes) == expected_count
    for buf, expected_size in zip(buffers, sizes, strict=True):
        assert buf.nbytes == expected_size
def test_batch_encode_into_allows_padded_buffers():
    """The allocator may hand back buffers larger than requested.

    The returned sizes still report the actual packed length, and trimming
    each buffer to that length decodes back to the original values.
    """
    extra = 32

    def oversized_alloc(sizes):
        return [bytearray(n + extra) for n in sizes]

    values = [b"alpha", {"k": "v"}, torch.arange(4, dtype=torch.float32)]
    buffers, sizes = serial_utils.batch_encode_into(values, oversized_alloc)
    for buf, used in zip(buffers, sizes, strict=True):
        assert len(buf) == used + extra
    trimmed = [bytes(buf[:used]) for buf, used in zip(buffers, sizes, strict=True)]
    _assert_equal_payloads(serial_utils.batch_decode_from(trimmed), values)
# Payload batches shared by the round-trip and parallel-vs-serial tests. Each
# entry is one batch_encode_into call's worth of values; the id names the
# feature being exercised.
_ROUNDTRIP_PARAMS = [
    pytest.param([42, 3.14, "hello", b"bytes"], id="primitives"),
    pytest.param(
        [{"a": 1, "b": [1, 2, 3]}, {"nested": {"k": "v"}}],
        id="nested-dicts",
    ),
    pytest.param([torch.arange(10, dtype=torch.float32)], id="single-tensor"),
    pytest.param(
        [
            torch.arange(100, dtype=torch.float32),
            torch.randn(4, 4, dtype=torch.bfloat16),
            torch.zeros(3, 5, dtype=torch.int64),
        ],
        id="mixed-tensors",
    ),
    pytest.param(
        [np.arange(50, dtype=np.float64), np.ones((3, 3), dtype=np.int32)],
        id="numpy-arrays",
    ),
    pytest.param(
        [{"meta": "v1", "arr": torch.arange(5, dtype=torch.float32)}, [1, 2, "three"]],
        id="heterogeneous",
    ),
    pytest.param(
        [
            torch.randn(2, 3, 4, 5, dtype=torch.float32),
            torch.randn(2, 3, 4, 5, 6, dtype=torch.bfloat16),
        ],
        id="high-rank-tensors",
    ),
    pytest.param(
        # Strided layout first, then jagged. Within each layout, a float32
        # arange-based tensor precedes a bfloat16 randn-based one. Component
        # lengths differ so the tensors are genuinely ragged.
        [
            torch.nested.nested_tensor(
                [make(n, dtype=dtype) for n in lengths],
                layout=layout,
            )
            for layout, lengths in ((torch.strided, (3, 5)), (torch.jagged, (4, 7)))
            for make, dtype in ((torch.arange, torch.float32), (torch.randn, torch.bfloat16))
        ],
        id="nested-tensors",
    ),
    pytest.param(
        [{"only": "one", "tensor": torch.arange(3, dtype=torch.float32)}],
        id="single-value",
    ),
]
@pytest.mark.parametrize("values", _ROUNDTRIP_PARAMS)
def test_batch_encode_decode_roundtrip_mooncake(values):
    """Round-trip through per-value views of one shared backing tensor."""
    decoded = _roundtrip(values, _mooncake_alloc, "mooncake")[0]
    _assert_equal_payloads(decoded, values)
@pytest.mark.parametrize("values", _ROUNDTRIP_PARAMS)
def test_batch_encode_decode_roundtrip_yuanrong(values):
    """Round-trip through independent per-value bytearrays."""
    decoded = _roundtrip(values, _yuanrong_alloc, "yuanrong")[0]
    _assert_equal_payloads(decoded, values)
def test_batch_encode_decode_empty_list():
    """An empty batch calls alloc exactly once with no sizes, yields empty
    results, and decodes to ``[]``."""
    recorded_requests = []

    def recording_alloc(sizes):
        recorded_requests.append(list(sizes))
        return []

    buffers, sizes = serial_utils.batch_encode_into([], recording_alloc)
    assert buffers == []
    assert sizes == []
    assert recorded_requests == [[]]
    assert serial_utils.batch_decode_from([]) == []
@pytest.mark.parametrize("values", _ROUNDTRIP_PARAMS)
def test_batch_encode_into_parallel_matches_serial(values):
    """Encoding with 4 workers is byte-identical to encoding with 1 worker."""
    results = {
        workers: serial_utils.batch_encode_into(values, _yuanrong_alloc, num_workers=workers)
        for workers in (1, 4)
    }
    serial_buffers, serial_sizes = results[1]
    par_buffers, par_sizes = results[4]
    assert serial_sizes == par_sizes
    assert list(map(bytes, serial_buffers)) == list(map(bytes, par_buffers))
def test_batch_encode_into_parallel_roundtrip_many_objects():
    """64 float32 tensors of random length (1..256, seeded) survive an
    8-worker encode/decode round-trip."""
    rng = np.random.default_rng(42)
    values = [
        torch.from_numpy(rng.random(int(rng.integers(1, 257))).astype(np.float32))
        for _ in range(64)
    ]
    decoded = _roundtrip(values, _yuanrong_alloc, "yuanrong", num_workers=8)[0]
    _assert_equal_payloads(decoded, values)
def _assert_equal_payloads(decoded, original):
assert len(decoded) == len(original)
for got, want in zip(decoded, original, strict=True):
if isinstance(want, torch.Tensor):
assert isinstance(got, torch.Tensor)
assert got.dtype == want.dtype
if want.is_nested:
assert got.is_nested
assert got.layout == want.layout
got_subs = got.unbind()
want_subs = want.unbind()
assert len(got_subs) == len(want_subs)
for g, w in zip(got_subs, want_subs, strict=True):
assert g.shape == w.shape
assert torch.equal(g, w)
else:
assert got.shape == want.shape
assert torch.equal(got, want)
elif isinstance(want, np.ndarray):
assert isinstance(got, np.ndarray)
assert got.dtype == want.dtype
assert got.shape == want.shape
assert np.array_equal(got, want)
elif isinstance(want, dict):
assert isinstance(got, dict)
assert got.keys() == want.keys()
for k in want:
_assert_equal_payloads([got[k]], [want[k]])
elif isinstance(want, list):
assert isinstance(got, list)
_assert_equal_payloads(got, want)
else:
assert got == want