"""
Unit tests for PureInferenceEngine - Zero ROS dependencies.
These tests verify the core inference logic can run independently:
- Device resolution
- Mock policy inference
- Tensor conversion
- InferenceResult structure
Run with: pytest tests/test_pure_inference_engine.py -v
"""
import numpy as np
import pytest
import torch
from torch import Tensor
from inference_service.core import (
InferenceResult,
MockPolicyWrapper,
PureInferenceEngine,
resolve_device,
)
class TestResolveDevice:
    """Tests for resolving device strings into ``torch.device`` objects."""

    def test_auto_returns_valid_device(self):
        """Auto should return a valid torch device."""
        device = resolve_device("auto")
        assert isinstance(device, torch.device)
        assert device.type in ("cuda", "mps", "cpu")

    def test_cpu_explicit(self):
        """Explicit CPU should return CPU device."""
        device = resolve_device("cpu")
        assert device.type == "cpu"

    def test_cuda_explicit_if_available(self):
        """Explicit CUDA should resolve when available and raise otherwise."""
        if torch.cuda.is_available():
            device = resolve_device("cuda")
            assert device.type == "cuda"
        else:
            with pytest.raises(RuntimeError, match="CUDA requested but not available"):
                resolve_device("cuda")

    def test_cuda_with_index(self):
        """CUDA with index should parse correctly."""
        # Skip explicitly instead of passing vacuously on machines without a
        # GPU, so the test report shows that this check did not actually run.
        if not torch.cuda.is_available():
            pytest.skip("CUDA not available; cannot resolve 'cuda:0'")
        device = resolve_device("cuda:0")
        assert device.type == "cuda"
        assert device.index == 0

    def test_invalid_device_raises(self):
        """Invalid device string should raise ValueError."""
        with pytest.raises(ValueError, match="Unknown device"):
            resolve_device("invalid_device")
class TestMockPolicyWrapper:
    """Behavioural checks for the ``MockPolicyWrapper`` test double."""

    @staticmethod
    def _loaded_on_cpu(**wrapper_kwargs):
        """Construct a ``MockPolicyWrapper`` from *wrapper_kwargs* and load it on CPU."""
        mock = MockPolicyWrapper(**wrapper_kwargs)
        mock.load("", torch.device("cpu"))
        return mock

    def test_single_action_output(self):
        """A chunk size of one should yield a single flat float32 action."""
        mock = self._loaded_on_cpu(action_dim=7, chunk_size=1)
        out = mock.infer({"observation.state": torch.randn(1, 7)})
        assert out.shape == (7,)
        assert out.dtype == torch.float32

    def test_chunk_action_output(self):
        """A chunk size of sixteen should yield a (chunk, action_dim) tensor."""
        mock = self._loaded_on_cpu(action_dim=7, chunk_size=16)
        out = mock.infer({"observation.state": torch.randn(1, 7)})
        assert out.shape == (16, 7)

    def test_policy_type(self):
        """Constructor metadata should be reported back unchanged."""
        mock = MockPolicyWrapper(policy_type="test_policy")
        assert mock.policy_type == "test_policy"
        assert mock.backend_type == ""
        assert mock.uses_action_chunking is False

    def test_chunk_size(self):
        """``get_chunk_size`` should echo the configured chunk size."""
        assert MockPolicyWrapper(chunk_size=32).get_chunk_size() == 32
class TestPureInferenceEngine:
    """End-to-end checks of ``PureInferenceEngine`` driven by a mock policy."""

    @staticmethod
    def _state_only_batch():
        """Return a fresh batch containing just a random (1, 7) state tensor."""
        return {"observation.state": torch.randn(1, 7)}

    @pytest.fixture
    def mock_engine(self):
        """Engine wrapping a single-step (chunk_size=1) mock policy."""
        single_step = MockPolicyWrapper(action_dim=7, chunk_size=1)
        return PureInferenceEngine(policy_wrapper=single_step)

    @pytest.fixture
    def mock_chunking_engine(self):
        """Engine wrapping a 16-step chunked mock policy."""
        chunked = MockPolicyWrapper(action_dim=7, chunk_size=16)
        return PureInferenceEngine(policy_wrapper=chunked)

    def test_inference_with_tensor_input(self, mock_engine):
        """Torch tensors should be accepted and produce a populated result."""
        outcome = mock_engine(self._state_only_batch())
        assert isinstance(outcome, InferenceResult)
        assert outcome.action.shape == (7,)
        assert outcome.chunk_size == 1
        assert outcome.policy_type == "mock"

    def test_inference_with_numpy_input(self, mock_engine):
        """Numpy arrays should be accepted and come back as a torch Tensor."""
        state = np.random.randn(1, 7).astype(np.float32)
        outcome = mock_engine({"observation.state": state})
        assert isinstance(outcome.action, Tensor)
        assert outcome.action.shape == (7,)

    def test_inference_with_image(self, mock_engine):
        """An image tensor alongside the state should not disturb inference."""
        observations = self._state_only_batch()
        observations["observation.image"] = torch.randn(1, 3, 224, 224)
        outcome = mock_engine(observations)
        assert outcome.action.shape == (7,)

    def test_chunking_inference(self, mock_chunking_engine):
        """A chunked policy should return the full action chunk."""
        outcome = mock_chunking_engine(self._state_only_batch())
        assert outcome.action.shape == (16, 7)
        assert outcome.chunk_size == 16

    def test_latency_measurement(self, mock_engine):
        """Each call should record a non-negative latency."""
        outcome = mock_engine(self._state_only_batch())
        assert outcome.latency_ms >= 0

    def test_device_property(self, mock_engine):
        """The engine should expose the device it resolved."""
        assert mock_engine.device.type in ("cpu", "cuda", "mps")

    def test_policy_type_property(self, mock_engine):
        """The engine should surface the wrapped policy's type."""
        assert mock_engine.policy_type == "mock"

    def test_chunk_size_property(self, mock_engine):
        """The engine should surface the wrapped policy's chunk size."""
        assert mock_engine.chunk_size == 1

    def test_use_action_chunking_property(self, mock_engine, mock_chunking_engine):
        """The engine should report its chunking status."""
        assert mock_engine.use_action_chunking is False
        # NOTE(review): the chunk_size=16 engine is also expected to report
        # False here; presumably the mock never opts into chunking — confirm
        # against MockPolicyWrapper.uses_action_chunking.
        assert mock_chunking_engine.use_action_chunking is False

    def test_result_to_numpy(self, mock_engine):
        """The result's action should convert to a numpy array."""
        as_array = mock_engine(self._state_only_batch()).to_numpy()
        assert isinstance(as_array, np.ndarray)
        assert as_array.shape == (7,)
class TestInferenceResult:
    """Checks for the ``InferenceResult`` dataclass in isolation."""

    def test_default_values(self):
        """Only ``action`` is required; every other field has a default."""
        res = InferenceResult(action=torch.randn(7))
        assert res.chunk_size == 1
        assert res.latency_ms == 0.0
        assert res.policy_type == ""
        assert res.backend_type == ""

    def test_shape_property(self):
        """``shape`` should mirror the shape of the stored action."""
        res = InferenceResult(action=torch.randn(16, 7), chunk_size=16)
        assert res.shape == (16, 7)

    def test_to_numpy(self):
        """``to_numpy`` should return an ndarray matching the tensor's values."""
        source = torch.randn(7)
        converted = InferenceResult(action=source).to_numpy()
        assert isinstance(converted, np.ndarray)
        np.testing.assert_array_almost_equal(converted, source.numpy())
class TestTensorConversion:
    """Tests for the engine's handling of raw numpy image observations."""

    def test_integer_image_normalization(self):
        """Integer (uint8) HWC images should be accepted by the engine.

        The engine is expected to normalize integer images to 0-1. The mock
        policy does not expose the tensors it receives, so this test only
        checks that inference succeeds across the full uint8 value range.
        """
        rng = np.random.default_rng(0)
        # ``high`` is exclusive: 256 is needed for 255 (the uint8 maximum)
        # to be sampled at all.
        image = rng.integers(0, 256, size=(224, 224, 3), dtype=np.uint8)
        # Pin both extremes so the boundary values are always present.
        image[0, 0, :] = 0
        image[-1, -1, :] = 255
        engine = PureInferenceEngine(policy_wrapper=MockPolicyWrapper())
        result = engine({"observation.image": image})
        assert result.action is not None

    def test_float_image_passthrough(self):
        """Float images should pass through."""
        rng = np.random.default_rng(0)
        image = rng.standard_normal((224, 224, 3), dtype=np.float32)
        engine = PureInferenceEngine(policy_wrapper=MockPolicyWrapper())
        result = engine({"observation.image": image})
        assert result.action is not None
if __name__ == "__main__":
    # Propagate pytest's exit status so running this file directly (e.g. from
    # a CI script) exits non-zero when any test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))