import unittest
from unittest.mock import MagicMock, patch
import torch
from mindie_llm.runtime.utils.torch_utils import set_default_torch_dtype
class TestSetDefaultTorchDtype(unittest.TestCase):
    """Test cases for the set_default_torch_dtype context manager.

    Each test enters the context manager with some dtype, checks what
    ``torch.get_default_dtype()`` reports inside the ``with`` block, and then
    checks that the default recorded in ``setUp`` is back in effect afterwards.
    """

    def setUp(self):
        """Record the default dtype in effect before the test runs."""
        self.original_dtype = torch.get_default_dtype()

    def tearDown(self):
        """Put the recorded default dtype back so a failing test cannot leak state."""
        torch.set_default_dtype(self.original_dtype)

    def _check_round_trip(self, dtype):
        """Enter the context with *dtype*, verify it applies, then verify it is undone."""
        with set_default_torch_dtype(dtype):
            self.assertEqual(torch.get_default_dtype(), dtype)
            # A tensor built from float literals without an explicit dtype
            # is expected to pick up the current default.
            tensor = torch.tensor([1.0, 2.0, 3.0])
            self.assertEqual(tensor.dtype, dtype)
        self.assertEqual(torch.get_default_dtype(), self.original_dtype)

    def test_set_default_torch_dtype_float32(self):
        """Test setting default dtype to float32."""
        self._check_round_trip(torch.float32)

    def test_set_default_torch_dtype_float16(self):
        """Test setting default dtype to float16."""
        self._check_round_trip(torch.float16)

    def test_set_default_torch_dtype_bfloat16(self):
        """Test setting default dtype to bfloat16."""
        self._check_round_trip(torch.bfloat16)

    def test_set_default_torch_dtype_float64(self):
        """Test setting default dtype to float64."""
        self._check_round_trip(torch.float64)

    def test_set_default_torch_dtype_restores_after_exception(self):
        """Test that dtype is restored even when an exception occurs."""
        # assertRaises (rather than try/except/pass) also fails the test if the
        # context manager swallows the exception instead of propagating it.
        with self.assertRaises(ValueError):
            with set_default_torch_dtype(torch.float16):
                self.assertEqual(torch.get_default_dtype(), torch.float16)
                raise ValueError("Test exception")
        self.assertEqual(torch.get_default_dtype(), self.original_dtype)

    def test_set_default_torch_dtype_nested_contexts(self):
        """Test that nested contexts unwind one level at a time."""
        with set_default_torch_dtype(torch.float16):
            self.assertEqual(torch.get_default_dtype(), torch.float16)
            with set_default_torch_dtype(torch.float32):
                self.assertEqual(torch.get_default_dtype(), torch.float32)
                tensor = torch.tensor([1.0, 2.0])
                self.assertEqual(tensor.dtype, torch.float32)
            # Leaving the inner context must return to the outer dtype,
            # not to the original default.
            self.assertEqual(torch.get_default_dtype(), torch.float16)
            tensor = torch.tensor([1.0, 2.0])
            self.assertEqual(tensor.dtype, torch.float16)
        self.assertEqual(torch.get_default_dtype(), self.original_dtype)

    def test_set_default_torch_dtype_multiple_tensors(self):
        """Test that multiple tensors created inside the context use the new dtype."""
        with set_default_torch_dtype(torch.float16):
            created = (
                torch.tensor([1.0, 2.0]),
                torch.tensor([3.0, 4.0]),
                torch.randn(5, 5),
            )
            for tensor in created:
                self.assertEqual(tensor.dtype, torch.float16)
        self.assertEqual(torch.get_default_dtype(), self.original_dtype)

    def test_set_default_torch_dtype_context_manager_returns_none(self):
        """Test that the context manager yields None."""
        with set_default_torch_dtype(torch.float16) as ctx:
            self.assertIsNone(ctx)

    def test_set_default_torch_dtype_same_dtype(self):
        """Test setting the dtype to the value that is already the default."""
        with set_default_torch_dtype(self.original_dtype):
            self.assertEqual(torch.get_default_dtype(), self.original_dtype)
            tensor = torch.tensor([1.0, 2.0])
            self.assertEqual(tensor.dtype, self.original_dtype)
        self.assertEqual(torch.get_default_dtype(), self.original_dtype)

    def test_set_default_torch_dtype_empty_context(self):
        """Test the context manager with an empty body."""
        with set_default_torch_dtype(torch.float16):
            pass
        self.assertEqual(torch.get_default_dtype(), self.original_dtype)

    def test_set_default_torch_dtype_sequential_contexts(self):
        """Test that back-to-back contexts each restore the original dtype."""
        for dtype in (torch.float16, torch.float32):
            with set_default_torch_dtype(dtype):
                self.assertEqual(torch.get_default_dtype(), dtype)
            self.assertEqual(torch.get_default_dtype(), self.original_dtype)

    def test_set_default_torch_dtype_with_explicit_dtype_override(self):
        """Test that an explicit dtype in tensor creation overrides the default."""
        with set_default_torch_dtype(torch.float16):
            explicit = torch.tensor([1.0, 2.0], dtype=torch.float32)
            self.assertEqual(explicit.dtype, torch.float32)
            implicit = torch.tensor([1.0, 2.0])
            self.assertEqual(implicit.dtype, torch.float16)
        self.assertEqual(torch.get_default_dtype(), self.original_dtype)

    def test_set_default_torch_dtype_exception_in_context(self):
        """Test that work done before an exception sees the new dtype and later work does not."""
        # assertRaises makes the test fail if the RuntimeError does not
        # propagate out of the context manager.
        with self.assertRaises(RuntimeError):
            with set_default_torch_dtype(torch.float16):
                self.assertEqual(torch.get_default_dtype(), torch.float16)
                inside = torch.tensor([1.0, 2.0])
                self.assertEqual(inside.dtype, torch.float16)
                raise RuntimeError("Test error")
        self.assertEqual(torch.get_default_dtype(), self.original_dtype)
        # Tensors created after the failed context must use the original default again.
        outside = torch.tensor([1.0, 2.0])
        self.assertEqual(outside.dtype, self.original_dtype)
# Run the test suite when this file is executed directly rather than imported.
if __name__ == '__main__':
    unittest.main()