"""
-------------------------------------------------------------------------
This file is part of the MindStudio project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.
MindStudio is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""
import unittest
import torch
from msmodelslim.processor.quarot.common.hadamard import (
is_pow2,
walsh_matrix,
random_hadamard_matrix,
matmul_had_u,
matmul_had_u_t,
)
class TestIsPow2(unittest.TestCase):
    """Tests for ``is_pow2``."""

    def test_powers_of_two_should_return_true(self):
        """Every value 2**0 .. 2**10 is expected to be reported as a power of two."""
        for exponent in range(11):
            n = 1 << exponent
            with self.subTest(n=n):
                self.assertTrue(is_pow2(n))

    def test_non_powers_of_two_should_return_false(self):
        """Zero and a selection of non-powers are expected to be rejected."""
        for n in (0, 3, 5, 6, 7, 9, 10, 15, 17, 100, 1000):
            with self.subTest(n=n):
                self.assertFalse(is_pow2(n))

    def test_negative_numbers_should_return_false(self):
        """Negative inputs are expected to be rejected."""
        for negative in (-1, -2):
            self.assertFalse(is_pow2(negative))
class TestWalshMatrix(unittest.TestCase):
    """Tests for ``walsh_matrix``."""

    @staticmethod
    def _walsh_on_cpu(size):
        """Build a float32 Walsh matrix of the given size on the CPU."""
        return walsh_matrix(size, torch.float32, torch.device('cpu'))

    def test_size_1_should_return_1x1_matrix(self):
        """Size 1 is expected to give a 1x1 matrix whose only entry is 1.0."""
        matrix = self._walsh_on_cpu(1)
        self.assertEqual(matrix.shape, (1, 1))
        self.assertEqual(matrix[0, 0].item(), 1.0)

    def test_size_2_should_return_2x2_matrix(self):
        """Size 2 is expected to give a 2x2 matrix."""
        matrix = self._walsh_on_cpu(2)
        self.assertEqual(matrix.shape, (2, 2))

    def test_size_4_should_return_4x4_matrix(self):
        """Size 4 is expected to give a 4x4 matrix."""
        matrix = self._walsh_on_cpu(4)
        self.assertEqual(matrix.shape, (4, 4))

    def test_matrix_should_be_orthogonal(self):
        """The Gram matrix ``W.T @ W`` is expected to equal ``size * I``."""
        for size in (2, 4, 8):
            with self.subTest(size=size):
                matrix = self._walsh_on_cpu(size)
                gram = torch.matmul(matrix.T, matrix)
                scaled_identity = size * torch.eye(size)
                self.assertTrue(torch.allclose(gram, scaled_identity, atol=1e-5))

    def test_entries_should_be_plus_or_minus_one(self):
        """Every entry of the size-8 matrix is expected to be +1 or -1."""
        matrix = self._walsh_on_cpu(8)
        # An entry is +1 or -1 exactly when its absolute value equals 1.
        self.assertTrue(torch.all(matrix.abs() == 1.0))
class TestRandomHadamardMatrix(unittest.TestCase):
    """Tests for ``random_hadamard_matrix``."""

    def test_should_return_correct_size(self):
        """The result is expected to be a square ``size x size`` matrix."""
        size = 8
        result = random_hadamard_matrix(size, torch.float32, torch.device('cpu'))
        self.assertEqual(result.shape, (size, size))

    def test_should_be_orthogonal(self):
        """``Q.T @ Q`` is expected to equal the identity (within 1e-5)."""
        size = 16
        result = random_hadamard_matrix(size, torch.float32, torch.device('cpu'))
        product = result.T @ result
        expected = torch.eye(size)
        self.assertTrue(torch.allclose(product, expected, atol=1e-5))

    def test_different_calls_should_produce_different_matrices(self):
        """Two independent draws are expected to differ.

        NOTE(review): this test previously used size 8. Presumably the
        randomness is a +/-1 sign vector of length ``size`` (TODO confirm
        against the implementation); in that case two draws coincide with
        probability 2**-size, which at size 8 makes the test fail spuriously
        about once in 256 runs. A larger power-of-two size makes a collision
        practically impossible.
        """
        size = 64
        result1 = random_hadamard_matrix(size, torch.float32, torch.device('cpu'))
        result2 = random_hadamard_matrix(size, torch.float32, torch.device('cpu'))
        self.assertFalse(torch.allclose(result1, result2))
class TestMatmulHadU(unittest.TestCase):
    """Tests for ``matmul_had_u``."""

    def test_should_preserve_shape(self):
        """A (2, 4, 8) input is expected to come back with the same shape."""
        batch = torch.randn(2, 4, 8)
        self.assertEqual(matmul_had_u(batch).shape, batch.shape)

    def test_power_of_two_size_should_work(self):
        """Inputs whose last dimension is 2, 4, 8 or 16 keep their shape."""
        for size in (2, 4, 8, 16):
            with self.subTest(size=size):
                row = torch.randn(1, size)
                transformed = matmul_had_u(row)
                self.assertEqual(transformed.shape, row.shape)

    def test_should_be_normalized(self):
        """Each row of the transformed 8x8 identity is expected to have unit norm."""
        transformed = matmul_had_u(torch.eye(8))
        row_norms = transformed.norm(dim=1)
        unit = torch.ones(8)
        self.assertTrue(torch.allclose(row_norms, unit, atol=1e-5))

    def test_transpose_should_work(self):
        """With ``transpose=True`` the output shape still matches the input."""
        sample = torch.randn(2, 8)
        transformed = matmul_had_u(sample, transpose=True)
        self.assertEqual(transformed.shape, sample.shape)
class TestMatmulHadUT(unittest.TestCase):
    """Tests for ``matmul_had_u_t``."""

    def test_should_produce_same_result_as_transpose_mode(self):
        """The result is expected to match ``matmul_had_u(x, transpose=True)``."""
        sample = torch.randn(2, 8)
        via_helper = matmul_had_u_t(sample)
        via_flag = matmul_had_u(sample, transpose=True)
        self.assertTrue(torch.allclose(via_helper, via_flag))

    def test_should_preserve_shape(self):
        """A (3, 16) input is expected to come back with the same shape."""
        sample = torch.randn(3, 16)
        self.assertEqual(matmul_had_u_t(sample).shape, sample.shape)

    def test_inverse_should_recover_original(self):
        """Applying ``matmul_had_u`` then ``matmul_had_u_t`` should return the input."""
        original = torch.randn(1, 8)
        round_trip = matmul_had_u_t(matmul_had_u(original))
        self.assertTrue(torch.allclose(round_trip, original, atol=1e-5))
class TestGetHadK(unittest.TestCase):
    """Tests for ``get_had_k``."""

    def test_should_return_hadamard_matrix_for_power_of_two(self):
        """For size 8, ``get_had_k`` is expected to return ``(None, 1)``."""
        from msmodelslim.processor.quarot.common.hadamard import get_had_k

        outcome = get_had_k(8)
        had_k, k = outcome
        self.assertIsNone(had_k)
        self.assertEqual(k, 1)

    def test_should_raise_for_non_decomposable(self):
        """Size 3 is expected to be rejected with ``UnsupportedError``."""
        from msmodelslim.processor.quarot.common.hadamard import get_had_k
        from msmodelslim.utils.exception import UnsupportedError

        self.assertRaises(UnsupportedError, get_had_k, 3)
class TestLoadHadamardMatrixFromTxt(unittest.TestCase):
    """Tests for ``load_hadamard_matrix_from_txt``."""

    def test_should_load_existing_file(self):
        """Loading ``had.12.txt`` is expected to yield a 12x12 matrix.

        The load itself stays best-effort: if it raises (for example because
        the data file is not available in this environment) the test is
        reported as skipped with the reason. The assertions run outside the
        ``try`` block; previously they were inside an ``except Exception: pass``
        wrapper, which also swallowed ``AssertionError`` and made the test
        unable to fail.
        """
        from msmodelslim.processor.quarot.common.hadamard import load_hadamard_matrix_from_txt

        try:
            matrix = load_hadamard_matrix_from_txt("had.12.txt")
        except Exception as exc:  # the loader's failure types are not visible here
            self.skipTest(f"had.12.txt could not be loaded: {exc!r}")

        self.assertIsNotNone(matrix)
        self.assertEqual(matrix.shape[0], 12)
        self.assertEqual(matrix.shape[1], 12)
class TestTxtSafeLoad(unittest.TestCase):
    """Tests for ``txt_safe_load``."""

    def test_should_raise_for_nonexistent_file(self):
        """Loading a path that does not exist is expected to raise an exception."""
        from msmodelslim.processor.quarot.common.hadamard import txt_safe_load

        missing_path = "/nonexistent/path/file.txt"
        self.assertRaises(Exception, txt_safe_load, missing_path)
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()