import unittest
import numpy as np
from mindie_llm.model_wrapper.utils.common_util import (
get_ipv4_obj, get_ipv6_obj, get_ip_obj,
ipv4_to_list, ipv6_to_list, ip_string_to_list,
ip_array_to_ipv4, ip_array_to_ipv6, ip_array_to_string
)
from mindie_llm.model_wrapper.utils.common_util import (
TransferType, generate_lora_strings, generate_user_request_id_string,
generate_mem_pool_decisions, generate_dp_inst_id, split_list_equally
)
class TestIPFunctions(unittest.TestCase):
    """Tests for the IP parsing / conversion helpers imported from ``common_util``.

    The expectations below pin two things:

    * the textual form returned for valid addresses (IPv6 comes back in its
      compressed spelling), and
    * the 8-element integer list layout: an IPv4 address occupies the first
      four slots and the remaining four are ``-1``; an IPv6 address fills all
      eight slots with its 16-bit groups.
    """

    # Shared fixtures. They are stored as immutable values and copied with
    # ``list(...)`` at each use so that every test hands a fresh list to the
    # function under test.
    IPV4_TEXT = '192.128.1.1'
    IPV4_SLOTS = (192, 128, 1, 1, -1, -1, -1, -1)
    IPV6_LONG_TEXT = '2001:0db8:85a3:0000:0000:8a2e:0370:7334'
    IPV6_SHORT_TEXT = '2001:db8:85a3::8a2e:370:7334'
    IPV6_SLOTS = (0x2001, 0x0db8, 0x85a3, 0, 0, 0x8a2e, 0x0370, 0x7334)
    IPV6_NINE_GROUPS = '2001:0db8:85a3:0000:0000:8a2e:0370:7334:1234'

    # ---- get_ipv4_obj / get_ipv6_obj / get_ip_obj -------------------------

    def test_get_ipv4_obj_valid(self):
        """A well-formed IPv4 string round-trips through ``str()``."""
        address = get_ipv4_obj(self.IPV4_TEXT, 'ip')
        self.assertEqual(str(address), '192.128.1.1')

    def test_get_ipv4_obj_invalid(self):
        """An octet above 255 is rejected with a message naming the field."""
        with self.assertRaises(ValueError) as raised:
            get_ipv4_obj('256.128.1.1', 'ip')
        self.assertIn("ip=256.128.1.1 is invalid IPv4 address.", str(raised.exception))

    def test_get_ipv6_obj_valid(self):
        """A fully expanded IPv6 string is rendered in compressed form."""
        address = get_ipv6_obj(self.IPV6_LONG_TEXT, 'ip')
        self.assertEqual(str(address), self.IPV6_SHORT_TEXT)

    def test_get_ipv6_obj_invalid(self):
        """Nine groups is not a valid IPv6 address."""
        with self.assertRaises(ValueError) as raised:
            get_ipv6_obj(self.IPV6_NINE_GROUPS, 'ip')
        self.assertIn(
            "ip=2001:0db8:85a3:0000:0000:8a2e:0370:7334:1234 is invalid IPv6 address.",
            str(raised.exception),
        )

    def test_get_ip_obj_valid_ipv4(self):
        """The generic entry point accepts IPv4 text."""
        address = get_ip_obj(self.IPV4_TEXT, 'ip')
        self.assertEqual(str(address), '192.128.1.1')

    def test_get_ip_obj_valid_ipv6(self):
        """The generic entry point accepts IPv6 text."""
        address = get_ip_obj(self.IPV6_LONG_TEXT, 'ip')
        self.assertEqual(str(address), self.IPV6_SHORT_TEXT)

    def test_get_ip_obj_invalid(self):
        """Text that is neither IPv4 nor IPv6 is rejected."""
        with self.assertRaises(ValueError) as raised:
            get_ip_obj('1234', 'ip')
        self.assertIn("ip=1234 is invalid IP address.", str(raised.exception))

    # ---- string -> list ---------------------------------------------------

    def test_ipv4_to_list_valid(self):
        """IPv4 text becomes four octets followed by four ``-1`` fillers."""
        slots = ipv4_to_list(self.IPV4_TEXT)
        self.assertEqual(slots, list(self.IPV4_SLOTS))

    def test_ipv4_to_list_format_error(self):
        """Only three dotted segments is a format error."""
        with self.assertRaises(ValueError) as raised:
            ipv4_to_list('192.128.1')
        self.assertIn("192.128.1 is invalid IPv4 format.", str(raised.exception))

    def test_ipv4_to_list_item_not_number(self):
        """A non-numeric segment is reported by value."""
        with self.assertRaises(ValueError) as raised:
            ipv4_to_list('192.128.1.a')
        self.assertIn("IPv4 segment 'a' is not a valid number.", str(raised.exception))

    def test_ipv4_to_list_item_out_of_range(self):
        """A numeric segment above 255 is reported as out of range."""
        with self.assertRaises(ValueError) as raised:
            ipv4_to_list('192.128.1.256')
        self.assertIn("IPv4 segment '256' out of range [0, 255].", str(raised.exception))

    def test_ipv6_to_list_valid(self):
        """IPv6 text becomes its eight 16-bit groups."""
        slots = ipv6_to_list(self.IPV6_LONG_TEXT)
        self.assertEqual(slots, list(self.IPV6_SLOTS))

    def test_ipv6_to_list_invalid(self):
        """Nine groups cannot be converted to a list."""
        with self.assertRaises(ValueError) as raised:
            ipv6_to_list(self.IPV6_NINE_GROUPS)
        self.assertIn(
            "2001:0db8:85a3:0000:0000:8a2e:0370:7334:1234 is invalid IPv6 address.",
            str(raised.exception),
        )

    def test_ip_string_to_list_valid_ipv4(self):
        """The generic converter handles IPv4 text."""
        slots = ip_string_to_list(self.IPV4_TEXT)
        self.assertEqual(slots, list(self.IPV4_SLOTS))

    def test_ip_string_to_list_valid_ipv6(self):
        """The generic converter handles IPv6 text."""
        slots = ip_string_to_list(self.IPV6_LONG_TEXT)
        self.assertEqual(slots, list(self.IPV6_SLOTS))

    def test_ip_string_to_list_invalid(self):
        """The generic converter rejects text that is not an IP address."""
        with self.assertRaises(ValueError) as raised:
            ip_string_to_list('1234')
        self.assertIn("1234 is invalid IP address.", str(raised.exception))

    # ---- list -> string ---------------------------------------------------

    def test_ip_array_to_ipv4_valid(self):
        """A ``-1``-padded list is rendered as dotted IPv4 text."""
        text = ip_array_to_ipv4(list(self.IPV4_SLOTS))
        self.assertEqual(text, '192.128.1.1')

    def test_ip_array_to_ipv4_invalid(self):
        """An octet above 255 in the list is rejected."""
        with self.assertRaises(ValueError) as raised:
            ip_array_to_ipv4([256, 128, 1, 1, -1, -1, -1, -1])
        self.assertIn("IPv4 segment '256' is out of range [0, 255].", str(raised.exception))

    def test_ip_array_to_ipv6_valid(self):
        """Eight 16-bit groups are rendered as compressed IPv6 text."""
        text = ip_array_to_ipv6(list(self.IPV6_SLOTS))
        self.assertEqual(text, self.IPV6_SHORT_TEXT)

    def test_ip_array_to_ipv6_invalid(self):
        """A group above 0xFFFF in the list is rejected."""
        with self.assertRaises(ValueError) as raised:
            ip_array_to_ipv6([0x20010, 0x0db8, 0x85a3, 0, 0, 0x8a2e, 0x0370, 0x7334])
        self.assertIn("IPv6 segment '0x20010' is out of range [0, 0xFFFF].", str(raised.exception))

    def test_ip_array_to_string_valid_ipv4(self):
        """The generic renderer produces IPv4 text for a ``-1``-padded list."""
        text = ip_array_to_string(list(self.IPV4_SLOTS))
        self.assertEqual(text, '192.128.1.1')

    def test_ip_array_to_string_valid_ipv6(self):
        """The generic renderer produces IPv6 text for a full list."""
        text = ip_array_to_string(list(self.IPV6_SLOTS))
        self.assertEqual(text, self.IPV6_SHORT_TEXT)

    def test_ip_array_to_string_invalid_array_len(self):
        """A seven-element list is rejected; exactly eight are required."""
        with self.assertRaises(ValueError) as raised:
            ip_array_to_string([192, 128, 1, 1, -1, -1, -1])
        self.assertIn("ip_array must be an array of 8 integers.", str(raised.exception))
class TestGenerateLoraStrings(unittest.TestCase):
    """Tests for ``generate_lora_strings`` using a minimal request stub."""

    @staticmethod
    def _request_returning(tensor):
        """Build a stub whose ``get_tensor_by_name`` always yields ``tensor``."""

        class MockRequest:
            def get_tensor_by_name(self, name):
                return tensor

        return MockRequest()

    def test_generate_lora_strings_lora_id_is_none(self):
        """The literal string ``"None"`` from the request maps to ``None``."""
        outcome = generate_lora_strings(self._request_returning("None"))
        self.assertIsNone(outcome)

    def test_generate_lora_strings_lora_id_is_valid(self):
        """Any other lora id string is returned unchanged."""
        outcome = generate_lora_strings(self._request_returning("Lora1"))
        self.assertEqual(outcome, "Lora1")
class TestGenerateUserRequestIdString(unittest.TestCase):
    """Tests for ``generate_user_request_id_string`` using a minimal request stub."""

    @staticmethod
    def _request_from(make_tensor):
        """Build a stub that calls ``make_tensor()`` on every tensor lookup.

        A factory is used (rather than a shared value) so each lookup receives
        a freshly created array.
        """

        class MockRequest:
            def get_tensor_by_name(self, name):
                return make_tensor()

        return MockRequest()

    def test_generate_user_request_id_string_valid(self):
        """For a 2-D uint8 tensor the result is the first row as text.

        Row 0 is ``[65, 66, 67]`` and the expected id is ``"ABC"``.
        """
        stub = self._request_from(
            lambda: np.array([[65, 66, 67], [68, 69, 70]], dtype=np.uint8)
        )
        outcome = generate_user_request_id_string(stub)
        self.assertEqual(outcome, "ABC")

    def test_generate_user_request_id_string_invalid(self):
        """A 1-D uint8 tensor yields ``None``."""
        stub = self._request_from(lambda: np.array([1, 2, 3], dtype=np.uint8))
        outcome = generate_user_request_id_string(stub)
        self.assertIsNone(outcome)
class TestGenerateMemPoolDecisions(unittest.TestCase):
    """Tests for ``generate_mem_pool_decisions``.

    Each test feeds the function a list of stub requests whose
    ``get_tensor_by_name`` returns a canned value, together with a transfer
    operation, and checks the returned decision array (or ``None``).

    Byte-string tensors are built with ``np.bytes_``. The former alias
    ``np.string_`` was removed in NumPy 2.0, while ``np.bytes_`` is available
    in both NumPy 1.x and 2.x.
    """

    def test_generate_mem_pool_decisions_valid_transfer_tensor(self):
        """A byte-string tensor holding ``[1, 1, 1]`` yields a 1x3 decision array."""

        class MockRequest:
            def get_tensor_by_name(self, name):
                return np.array([b'[1, 1, 1]'], dtype=np.bytes_)

        requests = [MockRequest()]
        transfer_operation = TransferType.H2D
        expected_array = np.array([[1, 1, 1]])
        result = generate_mem_pool_decisions(requests, transfer_operation)
        self.assertTrue(
            np.array_equal(result, expected_array),
            msg=f"unexpected decisions: {result!r}",
        )

    def test_generate_mem_pool_decisions_invalid_transfer_tensor(self):
        """A request without the tensor (lookup returns ``None``) yields ``None``."""

        class MockRequest:
            def get_tensor_by_name(self, name):
                return None

        requests = [MockRequest()]
        transfer_operation = TransferType.H2D
        result = generate_mem_pool_decisions(requests, transfer_operation)
        self.assertIsNone(result)

    def test_generate_mem_pool_decisions_invalid_transfer_operation(self):
        """A transfer operation that is not a ``TransferType`` member yields ``None``."""

        class MockRequest:
            def get_tensor_by_name(self, name):
                return np.array([65, 66, 67, 68, 69, 70], dtype=np.uint8)

        requests = [MockRequest()]
        transfer_operation = "invalid"
        result = generate_mem_pool_decisions(requests, transfer_operation)
        self.assertIsNone(result)

    def test_generate_mem_pool_decisions_requests_is_none(self):
        """Passing ``None`` as the request list raises ``ValueError``."""
        requests = None
        transfer_operation = TransferType.H2D
        with self.assertRaises(ValueError) as context:
            generate_mem_pool_decisions(requests, transfer_operation)
        self.assertIn("requests is not set.", str(context.exception))

    def test_generate_mem_pool_decisions_decode_exception(self):
        """A byte-string tensor holding ``[1]`` yields ``None``.

        NOTE(review): the test name suggests the payload fails to decode inside
        the function under test; the exact reason is not visible from this
        file -- confirm against ``generate_mem_pool_decisions``.
        """

        class MockRequest:
            def get_tensor_by_name(self, name):
                return np.array([b'[1]'], dtype=np.bytes_)

        requests = [MockRequest()]
        transfer_operation = TransferType.H2D
        result = generate_mem_pool_decisions(requests, transfer_operation)
        self.assertIsNone(result)
class TestGenerateDpInstId(unittest.TestCase):
    """Tests for ``generate_dp_inst_id``."""

    def test_generate_dp_inst_id(self):
        """Instance id ``"123"`` with a dp size of 3 yields three ids.

        Each expected id is the original id followed by a seven-digit,
        zero-padded index (0, 1, 2).
        """
        produced = generate_dp_inst_id("123", 3)
        self.assertEqual(produced, ["1230000000", "1230000001", "1230000002"])
class TestSplitListEqually(unittest.TestCase):
    """Tests for ``split_list_equally``."""

    def test_split_list_equally_valid_split(self):
        """Six items split into three chunks gives consecutive pairs."""
        items = [1, 2, 3, 4, 5, 6]
        chunks = split_list_equally(items, 3)
        self.assertEqual(chunks, [[1, 2], [3, 4], [5, 6]])

    def test_split_list_equally_invalid_n(self):
        """A chunk count of zero is rejected with ``ValueError``."""
        items = [1, 2, 3, 4, 5, 6]
        with self.assertRaises(ValueError) as raised:
            split_list_equally(items, 0)
        self.assertIn("Number of chunks 0 must be greater than 0", str(raised.exception))

    def test_split_list_equally_uneven_split(self):
        """Six items cannot be split into four equal chunks."""
        items = [1, 2, 3, 4, 5, 6]
        with self.assertRaises(ValueError) as raised:
            split_list_equally(items, 4)
        self.assertIn("Length 6 of the list cannot be divided evenly by 4", str(raised.exception))
# Run the whole suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()