import hashlib
import unittest
import torch
from mindie_llm.text_generator.mempool.utils import flatten_tensors, str_to_md5_hex, parse_global_segment_size
class TestMemPoolConfig(unittest.TestCase):
    """Unit tests for helpers in ``mindie_llm.text_generator.mempool.utils``.

    Covers ``flatten_tensors``, ``str_to_md5_hex`` and
    ``parse_global_segment_size``.
    """

    def _assert_tensor_lists_equal(self, actual, expected):
        """Assert ``actual`` is a list whose tensors match ``expected`` in order.

        Plain ``assertEqual`` on lists of tensors only works via the list
        identity short-circuit: for equal-but-distinct multi-element tensors,
        ``==`` returns a tensor whose truth value is ambiguous, so the test
        would error instead of failing with a useful message.
        """
        self.assertIsInstance(actual, list)
        self.assertEqual(len(actual), len(expected))
        for index, (got, want) in enumerate(zip(actual, expected)):
            with self.subTest(index=index):
                self.assertTrue(
                    torch.equal(got, want),
                    f"tensor at index {index}: {got!r} != {want!r}",
                )

    def test_flatten_tensors_single_tensor(self):
        """A bare tensor (not wrapped in a list) flattens to a one-element list."""
        t1 = torch.tensor([1, 2])
        result = flatten_tensors(t1)
        self._assert_tensor_lists_equal(result, [t1])

    def test_flatten_tensors_simple(self):
        """One level of nesting is flattened, preserving order."""
        t1 = torch.tensor([1, 2])
        t2 = torch.tensor([3, 4])
        nested = [t1, [t2]]
        result = flatten_tensors(nested)
        self._assert_tensor_lists_equal(result, [t1, t2])

    def test_flatten_tensors_deep(self):
        """Several levels of nesting are flattened depth-first, preserving order."""
        t1 = torch.tensor([1])
        t2 = torch.tensor([2])
        t3 = torch.tensor([3])
        # Distinct values for every tensor so an ordering mistake is detectable.
        t4 = torch.tensor([4])
        nested = [[[t1, t2], [t3, t4]]]
        result = flatten_tensors(nested)
        self._assert_tensor_lists_equal(result, [t1, t2, t3, t4])

    def test_flatten_tensors_type_error(self):
        """Non-tensor leaves are rejected with ``TypeError``."""
        with self.assertRaises(TypeError):
            flatten_tensors(["not", "a", "tensor"])

    def test_str_to_md5_hex(self):
        """The result is the MD5 hex digest of the UTF-8 encoded input."""
        input_str = "hello"
        expected = hashlib.md5(input_str.encode("utf-8")).hexdigest()
        self.assertEqual(str_to_md5_hex(input_str), expected)

    def test_str_to_md5_hex_empty(self):
        """The empty string hashes to the MD5 digest of empty bytes."""
        expected = hashlib.md5(b"").hexdigest()
        self.assertEqual(str_to_md5_hex(""), expected)

    def test_str_to_md5_hex_non_ascii(self):
        """Non-ASCII input is hashed over its UTF-8 encoding."""
        # ASCII-only input cannot distinguish UTF-8 from e.g. latin-1.
        input_str = "h\u00e9llo \u4e16\u754c"
        expected = hashlib.md5(input_str.encode("utf-8")).hexdigest()
        self.assertEqual(str_to_md5_hex(input_str), expected)

    def test_parse_global_segment_size_basic(self):
        """Size suffixes are powers of 1024; bare numbers/strings are bytes.

        Suffixes are exercised in both lower and upper case, and fractional
        values are truncated to ``int``.
        """
        cases = [
            ('1kb', 1024),
            ('1mb', 1024 ** 2),
            ('1GB', 1024 ** 3),
            (100, 100),
            ('100', 100),
            ('1.5KB', int(1.5 * 1024)),
        ]
        for raw, expected in cases:
            with self.subTest(raw=raw):
                self.assertEqual(parse_global_segment_size(raw), expected)

    def test_parse_global_segment_size_error(self):
        """Unparseable or empty input raises ``ValueError``."""
        for raw in ('abc', ""):
            with self.subTest(raw=raw):
                with self.assertRaises(ValueError):
                    parse_global_segment_size(raw)