"""
-------------------------------------------------------------------------
This file is part of the MindStudio project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.
MindStudio is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""
import unittest
import torch
from torch import nn
from msmodelslim.processor.quarot.common.quarot_utils import (
QuaRotMode,
create_rot,
rotate_linear,
rotate_weight,
bake_mean_into_linear,
is_power_of_two,
get_decompose_dim,
)
from msmodelslim.utils.exception import UnsupportedError
class TestQuaRotMode(unittest.TestCase):
"""测试 QuaRotMode 枚举类"""
def test_enum_values_should_be_correct(self):
"""测试枚举值"""
self.assertEqual(QuaRotMode.HADAMARD.value, "hadamard")
self.assertEqual(QuaRotMode.BLOCK_HADAMARD_SHIFTED.value, "block_hadamard_shifted")
class TestIsPowerOfTwo(unittest.TestCase):
"""测试 is_power_of_two 函数"""
def test_powers_of_two_should_return_true(self):
"""测试 2 的幂次返回 True"""
powers = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
for n in powers:
with self.subTest(n=n):
self.assertTrue(is_power_of_two(n))
def test_non_powers_of_two_should_return_false(self):
"""测试非 2 的幂次返回 False"""
non_powers = [0, 3, 5, 6, 7, 9, 10, 15, 17, 100, 1000]
for n in non_powers:
with self.subTest(n=n):
self.assertFalse(is_power_of_two(n))
def test_negative_numbers_should_return_false(self):
"""测试负数返回 False"""
self.assertFalse(is_power_of_two(-1))
self.assertFalse(is_power_of_two(-2))
class TestCreateRot(unittest.TestCase):
"""测试 create_rot 函数"""
def test_hadamard_mode_should_return_correct_size(self):
"""测试 HADAMARD 模式返回正确大小"""
size = 8
rot = create_rot(QuaRotMode.HADAMARD, size)
self.assertEqual(rot.shape, (size, size))
def test_hadamard_mode_with_block_size_should_return_correct_size(self):
"""测试 HADAMARD 模式带 block_size"""
size = 16
block_size = 4
rot = create_rot(QuaRotMode.HADAMARD, size, block_size=block_size)
self.assertEqual(rot.shape, (size, size))
def test_hadamard_mode_should_be_orthogonal(self):
"""测试 HADAMARD 模式生成正交矩阵"""
size = 8
rot = create_rot(QuaRotMode.HADAMARD, size)
product = rot.T @ rot
expected = torch.eye(size)
self.assertTrue(torch.allclose(product, expected, atol=1e-5))
def test_block_hadamard_shifted_mode_should_return_correct_size(self):
"""测试 BLOCK_HADAMARD_SHIFTED 模式返回正确大小"""
size = 64
rot = create_rot(QuaRotMode.BLOCK_HADAMARD_SHIFTED, size)
self.assertEqual(rot.shape, (size, size))
def test_block_hadamard_shifted_with_rot_step_1(self):
"""测试 BLOCK_HADAMARD_SHIFTED 模式 rot_step=1"""
size = 64
rot = create_rot(QuaRotMode.BLOCK_HADAMARD_SHIFTED, size, rot_step=1)
self.assertEqual(rot.shape, (size, size))
def test_block_hadamard_shifted_with_rot_step_2(self):
"""测试 BLOCK_HADAMARD_SHIFTED 模式 rot_step=2"""
size = 64
rot = create_rot(QuaRotMode.BLOCK_HADAMARD_SHIFTED, size, rot_step=2)
self.assertEqual(rot.shape, (size, size))
def test_block_hadamard_shifted_with_rot_step_3(self):
"""测试 BLOCK_HADAMARD_SHIFTED 模式 rot_step=3"""
size = 64
rot = create_rot(QuaRotMode.BLOCK_HADAMARD_SHIFTED, size, rot_step=3)
self.assertEqual(rot.shape, (size, size))
def test_block_hadamard_shifted_with_invalid_rot_step_should_raise(self):
"""测试 BLOCK_HADAMARD_SHIFTED 模式无效 rot_step 抛出异常"""
size = 64
with self.assertRaises(UnsupportedError):
create_rot(QuaRotMode.BLOCK_HADAMARD_SHIFTED, size, rot_step=4)
class TestRotateLinear(unittest.TestCase):
"""测试 rotate_linear 函数"""
def test_right_rotate_should_modify_weight(self):
"""测试右旋转修改权重"""
linear = nn.Linear(8, 4)
rot = torch.eye(8)
original_weight = linear.weight.data.clone()
rotate_linear(linear, rot, right_rotate=True)
self.assertTrue(torch.allclose(linear.weight.data, original_weight))
def test_left_rotate_should_modify_weight(self):
"""测试左旋转修改权重"""
linear = nn.Linear(8, 4)
rot = torch.eye(4)
original_weight = linear.weight.data.clone()
rotate_linear(linear, rot, right_rotate=False)
self.assertTrue(torch.allclose(linear.weight.data, original_weight))
def test_left_rotate_with_bias_should_modify_bias(self):
"""测试左旋转修改偏置"""
linear = nn.Linear(8, 4, bias=True)
rot = torch.eye(4)
original_bias = linear.bias.data.clone()
rotate_linear(linear, rot, right_rotate=False)
self.assertTrue(torch.allclose(linear.bias.data, original_bias))
def test_right_rotate_should_not_modify_bias(self):
"""测试右旋转不修改偏置"""
linear = nn.Linear(8, 4, bias=True)
rot = torch.eye(8)
original_bias = linear.bias.data.clone()
rotate_linear(linear, rot, right_rotate=True)
self.assertTrue(torch.allclose(linear.bias.data, original_bias))
def test_block_rotation_should_expand_rotation_matrix(self):
"""测试块旋转自动扩展旋转矩阵"""
linear = nn.Linear(16, 4)
rot = torch.eye(8)
original_weight = linear.weight.data.clone()
rotate_linear(linear, rot, right_rotate=True)
self.assertTrue(torch.allclose(linear.weight.data, original_weight, atol=1e-5))
def test_invalid_rotation_dim_should_raise(self):
"""测试无效旋转维度抛出异常"""
linear = nn.Linear(10, 4)
rot = torch.eye(3)
with self.assertRaises(UnsupportedError):
rotate_linear(linear, rot, right_rotate=True)
class TestRotateWeight(unittest.TestCase):
"""测试 rotate_weight 函数"""
def test_right_rotate_should_modify_weight(self):
"""测试右旋转修改权重"""
weight = torch.randn(4, 8)
rot = torch.eye(8)
original_weight = weight.clone()
rotate_weight(weight, rot, right_rotate=True)
self.assertTrue(torch.allclose(weight.data, original_weight))
def test_left_rotate_should_modify_weight(self):
"""测试左旋转修改权重"""
weight = torch.randn(4, 8)
rot = torch.eye(4)
original_weight = weight.clone()
rotate_weight(weight, rot, right_rotate=False)
self.assertTrue(torch.allclose(weight.data, original_weight))
def test_list_rotation_should_be_converted_to_block_diag(self):
"""测试列表旋转转换为块对角矩阵"""
weight = torch.randn(4, 16)
rot_list = [torch.eye(8), torch.eye(8)]
original_weight = weight.clone()
rotate_weight(weight, rot_list, right_rotate=True)
self.assertTrue(torch.allclose(weight.data, original_weight, atol=1e-5))
class TestBakeMeanIntoLinear(unittest.TestCase):
"""测试 bake_mean_into_linear 函数"""
def test_should_subtract_mean_from_weight(self):
"""测试从权重中减去均值"""
linear = nn.Linear(8, 4)
bake_mean_into_linear(linear)
weight_mean = linear.weight.data.double().mean()
self.assertTrue(torch.abs(weight_mean) < 0.5)
def test_should_subtract_mean_from_bias(self):
"""测试从偏置中减去均值"""
linear = nn.Linear(8, 4, bias=True)
bake_mean_into_linear(linear)
bias_mean = linear.bias.data.double().mean()
self.assertTrue(torch.abs(bias_mean) < 1e-4)
def test_without_bias_should_only_modify_weight(self):
"""测试无偏置时只修改权重"""
linear = nn.Linear(8, 4, bias=False)
original_weight = linear.weight.data.clone()
bake_mean_into_linear(linear)
self.assertFalse(torch.allclose(linear.weight.data, original_weight))
class TestGetDecomposeDim(unittest.TestCase):
"""测试 get_decompose_dim 函数"""
def test_valid_dimensions_should_return_tuple(self):
"""测试有效维度返回元组"""
for n in [16, 32, 64]:
with self.subTest(n=n):
a, b = get_decompose_dim(n)
self.assertIsInstance(a, int)
self.assertIsInstance(b, int)
self.assertGreater(a, 0)
self.assertGreater(b, 0)
class TestFuseLnLinear(unittest.TestCase):
"""测试 fuse_ln_linear 函数"""
def test_should_fuse_single_layernorm(self):
"""测试融合单个 layernorm"""
from msmodelslim.processor.quarot.common.quarot_utils import fuse_ln_linear
ln = nn.Module()
ln.weight = nn.Parameter(torch.tensor([2.0, 3.0, 4.0]))
ln.bias = nn.Parameter(torch.tensor([0.1, 0.2, 0.3]))
linear = nn.Linear(3, 2)
fuse_ln_linear([ln], [linear])
self.assertTrue(torch.allclose(ln.weight.data, torch.ones(3)))
self.assertTrue(torch.allclose(ln.bias.data, torch.zeros(3)))
def test_should_raise_for_2d_layernorm_weight(self):
"""测试 2D layernorm weight 抛出异常"""
from msmodelslim.processor.quarot.common.quarot_utils import fuse_ln_linear
ln = nn.Module()
ln.weight = nn.Parameter(torch.randn(2, 3))
linear = nn.Linear(3, 2)
with self.assertRaises(UnsupportedError):
fuse_ln_linear([ln], [linear])
def test_should_raise_for_mismatched_dimensions(self):
"""测试维度不匹配抛出异常"""
from msmodelslim.processor.quarot.common.quarot_utils import fuse_ln_linear
ln = nn.Module()
ln.weight = nn.Parameter(torch.randn(4))
linear = nn.Linear(3, 2)
with self.assertRaises(UnsupportedError):
fuse_ln_linear([ln], [linear])
def test_should_fuse_with_bias(self):
"""测试融合带 bias 的 layernorm"""
from msmodelslim.processor.quarot.common.quarot_utils import fuse_ln_linear
ln = nn.Module()
ln.weight = nn.Parameter(torch.tensor([2.0, 3.0]))
ln.bias = nn.Parameter(torch.tensor([0.1, 0.2]))
linear = nn.Linear(2, 3, bias=True)
fuse_ln_linear([ln], [linear])
self.assertTrue(torch.allclose(ln.weight.data, torch.ones(2)))
class TestOnlineRotateOprojInput(unittest.TestCase):
"""测试 online_rotate_o_proj_input 函数"""
def test_should_rotate_o_proj_weight(self):
"""测试旋转 o_proj 权重"""
from msmodelslim.processor.quarot.common.quarot_utils import online_rotate_o_proj_input
o_proj = nn.Linear(8, 4)
v_proj = nn.Linear(8, 4)
ov_pairs = {o_proj: v_proj}
rot_online = torch.tensor([[0.5, 0.5], [0.5, -0.5]])
num_attn_heads = 2
online_rotate_o_proj_input(ov_pairs, rot_online, num_attn_heads)
class TestOnlineRotateDownProj(unittest.TestCase):
"""测试 online_rotate_down_proj 函数"""
def test_should_rotate_down_proj_weight(self):
"""测试旋转 down_proj 权重"""
from msmodelslim.processor.quarot.common.quarot_utils import online_rotate_down_proj
up_proj = nn.Linear(8, 4)
down_proj = nn.Linear(4, 8)
pairs = {up_proj: down_proj}
rot1 = torch.tensor(
[[0.5, 0.5, 0.5, 0.5], [0.5, -0.5, 0.5, -0.5], [0.5, 0.5, -0.5, -0.5], [0.5, -0.5, -0.5, 0.5]]
)
rot2 = torch.tensor([[0.5, 0.5], [0.5, -0.5]])
online_rotate_down_proj(pairs, rot1, rot2)
if __name__ == '__main__':
unittest.main()