"""Unit tests for transfer_queue.utils.tensor_utils."""
import pytest
import torch
from transfer_queue.utils.tensor_utils import (
allocate_empty_tensors,
compute_stride,
get_nbytes,
merge_contiguous_memory,
)
class TestComputeStride:
    """Check the strides compute_stride reports for shapes of several ranks."""

    def test_3d(self):
        """A (2, 3, 4) shape is expected to yield strides (12, 4, 1)."""
        shape = (2, 3, 4)
        expected = (12, 4, 1)
        assert compute_stride(shape) == expected

    def test_1d(self):
        """A one-dimensional shape is expected to yield a single unit stride."""
        shape = (5,)
        expected = (1,)
        assert compute_stride(shape) == expected

    def test_scalar(self):
        """An empty shape is expected to yield an empty stride tuple."""
        shape = ()
        expected = ()
        assert compute_stride(shape) == expected

    def test_2d(self):
        """A (3, 5) shape is expected to yield strides (5, 1)."""
        shape = (3, 5)
        expected = (5, 1)
        assert compute_stride(shape) == expected
class TestGetNbytes:
    """Check the per-tensor byte counts returned by get_nbytes."""

    def test_basic(self):
        """Two tensors: a (2, 3) float32 and a (4,) int32."""
        byte_counts = get_nbytes([torch.float32, torch.int32], [(2, 3), (4,)])
        # 6 elements * 4 bytes and 4 elements * 4 bytes.
        expected = [24, 16]
        assert byte_counts == expected

    def test_scalar(self):
        """A float64 with an empty tuple shape is expected to take 8 bytes."""
        byte_counts = get_nbytes([torch.float64], [()])
        expected = [8]
        assert byte_counts == expected

    def test_list_shape(self):
        """A shape given as an empty list is expected to count one float32."""
        byte_counts = get_nbytes([torch.float32], [[]])
        expected = [4]
        assert byte_counts == expected

    def test_mixed_dtypes(self):
        """Ten-element tensors of 2-, 4- and 8-byte dtypes."""
        byte_counts = get_nbytes(
            [torch.float16, torch.float32, torch.int64],
            [(10,), (10,), (10,)],
        )
        expected = [20, 40, 80]
        assert byte_counts == expected
class TestAllocateEmptyTensors:
    """Check the tensors, pointers and regions returned by allocate_empty_tensors."""

    def test_basic(self):
        """Two float32 tensors and one int32 tensor are expected in two regions."""
        tensors, tensor_ptrs, regions, sizes = allocate_empty_tensors(
            [torch.float32, torch.float32, torch.int32],
            [(2, 3), (4,), (5,)],
        )

        assert len(tensors) == 3
        assert len(tensor_ptrs) == 3
        assert len(regions) == 2
        assert len(sizes) == 2

        # Both float32 tensors are expected to be backed by the first region's
        # storage, the int32 tensor by the second.
        for tensor, region_index in zip(tensors, (0, 0, 1)):
            assert tensor.untyped_storage().data_ptr() == regions[region_index]

        for tensor, expected_shape in zip(tensors, ([2, 3], [4], [5])):
            assert list(tensor.shape) == expected_shape

    def test_scalar(self):
        """Empty shapes are expected to give one-element tensors, one region per dtype."""
        tensors, tensor_ptrs, regions, sizes = allocate_empty_tensors(
            [torch.float32, torch.int32],
            [(), ()],
        )

        assert len(tensors) == 2
        for position in (0, 1):
            assert tensors[position].numel() == 1
        assert len(regions) == 2

    def test_empty(self):
        """Empty inputs are expected to return a tuple of four empty lists."""
        expected = ([], [], [], [])
        assert allocate_empty_tensors([], []) == expected

    def test_regions_complex(self):
        """Mixed dtypes and shapes: verify region counts, sizes, and per-tensor offsets."""
        tensors, tensor_ptrs, regions, sizes = allocate_empty_tensors(
            [torch.float32, torch.int32, torch.float32, torch.float64, torch.int32],
            [(2, 3), (4,), (), (2, 2), (3, 2)],
        )

        # Three distinct regions are expected.
        assert len(regions) == 3
        assert len(sizes) == 3
        assert len(set(regions)) == 3

        # Expected byte size per region: 7 float32, 10 int32, 4 float64 elements.
        for region_index, expected_size in enumerate((28, 40, 32)):
            assert sizes[region_index] == expected_size

        # (region index, byte offset inside that region) expected for each tensor.
        layout = [(0, 0), (1, 0), (0, 24), (2, 0), (1, 16)]
        for tensor_index, (region_index, offset) in enumerate(layout):
            assert tensor_ptrs[tensor_index] == regions[region_index] + offset
class TestMergeContiguousMemory:
    """Check how merge_contiguous_memory combines (pointer, size) regions."""

    def test_basic_merge(self):
        """Back-to-back regions of different sizes are expected to merge into one."""
        starts, lengths = merge_contiguous_memory([0, 10, 30], [10, 20, 10])
        assert starts == [0]
        assert lengths == [40]

    def test_no_contiguous(self):
        """Regions separated by gaps are expected to come back unchanged."""
        starts, lengths = merge_contiguous_memory([0, 100, 200], [50, 50, 50])
        assert starts == [0, 100, 200]
        assert lengths == [50, 50, 50]

    def test_unsorted_input(self):
        """Adjacent regions passed out of address order are expected to merge."""
        starts, lengths = merge_contiguous_memory([100, 0, 50], [50, 50, 50])
        assert starts == [0]
        assert lengths == [150]

    def test_single_region(self):
        """A single region is expected to be returned as-is."""
        starts, lengths = merge_contiguous_memory([10], [100])
        assert starts == [10]
        assert lengths == [100]

    def test_empty(self):
        """Empty inputs are expected to return a pair of empty lists."""
        outcome = merge_contiguous_memory([], [])
        assert outcome == ([], [])

    def test_mismatched_lengths_both_empty_not_triggered(self):
        """A length mismatch is expected to raise ValueError even when one list is empty."""
        message = "ptrs and sizes must have the same length"
        mismatched_inputs = (([], [10]), ([0], []))
        for bad_ptrs, bad_sizes in mismatched_inputs:
            with pytest.raises(ValueError, match=message):
                merge_contiguous_memory(bad_ptrs, bad_sizes)

    def test_three_continuous(self):
        """Three equal-sized back-to-back regions are expected to merge into one."""
        starts, lengths = merge_contiguous_memory([0, 10, 20], [10, 10, 10])
        assert starts == [0]
        assert lengths == [30]