"""
-------------------------------------------------------------------------
This file is part of the MindStudio project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.
MindStudio is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""
import pytest
import torch
import torch.nn as nn
from msmodelslim.processor.anti_outlier.common import VirtualVModuleFromQKVFused
class TestVirtualVModule:
    """Tests for VirtualVModuleFromQKVFused.

    The tests expect the fused QKV weight to be laid out row-wise as
    [Q | K | V]: ``NUM_HEADS * HEAD_DIM`` query rows, followed by
    ``num_key_value_heads * HEAD_DIM`` rows for K and the same number for V.
    Every fused layer built here is sized so that this layout fills it
    exactly. This keeps the expected V slice consistent with the layer's real
    shape for MHA, GQA and MQA alike.
    """

    HIDDEN_SIZE = 512
    NUM_HEADS = 8
    HEAD_DIM = HIDDEN_SIZE // NUM_HEADS  # 64

    # ------------------------------------------------------------------ helpers

    @classmethod
    def _fused_out_features(cls, num_kv_heads):
        """Return the fused QKV output size for ``num_kv_heads`` K/V heads."""
        return (cls.NUM_HEADS + 2 * num_kv_heads) * cls.HEAD_DIM

    @classmethod
    def _expected_v_range(cls, num_kv_heads):
        """Return the expected ``(start, end)`` row range of V in the fused weight."""
        v_start = (cls.NUM_HEADS + num_kv_heads) * cls.HEAD_DIM
        return v_start, v_start + num_kv_heads * cls.HEAD_DIM

    @classmethod
    def _make_qkv(cls, num_kv_heads, bias=True):
        """Build a fused QKV ``nn.Linear`` whose size matches the head config."""
        return nn.Linear(cls.HIDDEN_SIZE, cls._fused_out_features(num_kv_heads), bias=bias)

    @classmethod
    def _make_virtual_v(cls, qkv_module, num_kv_heads):
        """Wrap ``qkv_module`` with ``NUM_HEADS`` attention heads."""
        return VirtualVModuleFromQKVFused(
            qkv_module=qkv_module,
            num_attention_heads=cls.NUM_HEADS,
            num_key_value_heads=num_kv_heads,
        )

    @classmethod
    def _check_update_round_trip(cls, virtual_v):
        """Overwrite V weight/bias, call update_weights and verify the write-back.

        The new V rows must land in the V slice of the fused module, while the
        Q/K rows of both weight and bias must be left unchanged.
        """
        fused = virtual_v.qkv_module
        original_weight = fused.weight.detach().clone()
        original_bias = fused.bias.detach().clone()

        new_v_weight = torch.randn_like(virtual_v.weight)
        new_v_bias = torch.randn_like(virtual_v.bias)
        virtual_v.weight.data = new_v_weight
        virtual_v.bias.data = new_v_bias
        virtual_v.update_weights()

        v_start, v_end = cls._expected_v_range(virtual_v.num_key_value_heads)
        assert torch.allclose(fused.weight[v_start:v_end], new_v_weight)
        assert torch.allclose(fused.bias[v_start:v_end], new_v_bias)
        assert torch.allclose(fused.weight[:v_start], original_weight[:v_start])
        assert torch.allclose(fused.bias[:v_start], original_bias[:v_start])

    # ----------------------------------------------------------------- fixtures

    @pytest.fixture
    def qkv_module_with_bias(self):
        """Fused MHA QKV projection (512 -> 1536) with a bias."""
        return self._make_qkv(self.NUM_HEADS, bias=True)

    @pytest.fixture
    def qkv_module_without_bias(self):
        """Fused MHA QKV projection (512 -> 1536) without a bias."""
        return self._make_qkv(self.NUM_HEADS, bias=False)

    # -------------------------------------------------------------------- tests

    def test_virtual_v_module_init_normal(self, qkv_module_with_bias):
        """Construction records the head config and extracts V weight and bias."""
        virtual_v = self._make_virtual_v(qkv_module_with_bias, self.NUM_HEADS)

        assert virtual_v.attention_type == "MHA"
        assert virtual_v.num_attention_heads == 8
        assert virtual_v.num_key_value_heads == 8
        assert virtual_v.qkv_module is qkv_module_with_bias
        assert virtual_v.weight is not None
        assert virtual_v.bias is not None
        assert virtual_v.weight.shape == (512, 512)
        assert virtual_v.bias.shape == (512,)

    def test_virtual_v_module_update_weights(self, qkv_module_with_bias):
        """update_weights writes new V rows back and keeps the Q/K rows intact."""
        virtual_v = self._make_virtual_v(qkv_module_with_bias, self.NUM_HEADS)
        self._check_update_round_trip(virtual_v)

    def test_virtual_v_module_determine_attention_type_mha(self):
        """Equal query and key/value head counts are classified as MHA."""
        virtual_v = self._make_virtual_v(self._make_qkv(8), num_kv_heads=8)
        assert virtual_v.attention_type == "MHA"

    def test_virtual_v_module_determine_attention_type_mqa(self):
        """A single key/value head is classified as MQA."""
        # 512 -> 640 = (8 Q + 1 K + 1 V heads) * 64.
        virtual_v = self._make_virtual_v(self._make_qkv(1), num_kv_heads=1)
        assert virtual_v.attention_type == "MQA"

    def test_virtual_v_module_determine_attention_type_gqa(self):
        """Fewer (but more than one) key/value heads are classified as GQA."""
        # 512 -> 1024 = (8 Q + 4 K + 4 V heads) * 64.
        virtual_v = self._make_virtual_v(self._make_qkv(4), num_kv_heads=4)
        assert virtual_v.attention_type == "GQA"

    def test_virtual_v_module_get_v_indices_mha(self):
        """V rows for MHA follow 8 Q heads and 8 K heads."""
        virtual_v = self._make_virtual_v(self._make_qkv(8), num_kv_heads=8)
        v_range = virtual_v._get_v_indices(self.HEAD_DIM)
        assert tuple(v_range) == self._expected_v_range(8) == (1024, 1536)

    def test_virtual_v_module_get_v_indices_mqa(self):
        """V rows for MQA follow 8 Q heads and a single K head."""
        virtual_v = self._make_virtual_v(self._make_qkv(1), num_kv_heads=1)
        v_range = virtual_v._get_v_indices(self.HEAD_DIM)
        assert tuple(v_range) == self._expected_v_range(1) == (576, 640)

    def test_virtual_v_module_get_v_indices_gqa(self):
        """V rows for GQA follow 8 Q heads and 4 K heads."""
        virtual_v = self._make_virtual_v(self._make_qkv(4), num_kv_heads=4)
        v_range = virtual_v._get_v_indices(self.HEAD_DIM)
        assert tuple(v_range) == self._expected_v_range(4) == (768, 1024)

    def test_virtual_v_module_extract_v_weights_with_bias(self, qkv_module_with_bias):
        """Extracted V weight and bias equal the V slice of the fused module."""
        virtual_v = self._make_virtual_v(qkv_module_with_bias, self.NUM_HEADS)

        assert virtual_v.weight is not None
        assert virtual_v.bias is not None
        assert virtual_v.weight.shape == (512, 512)
        assert virtual_v.bias.shape == (512,)

        v_start, v_end = self._expected_v_range(self.NUM_HEADS)
        assert torch.allclose(virtual_v.weight, qkv_module_with_bias.weight[v_start:v_end])
        assert torch.allclose(virtual_v.bias, qkv_module_with_bias.bias[v_start:v_end])

    def test_virtual_v_module_extract_v_weights_without_bias(self, qkv_module_without_bias):
        """Without a fused bias, only the V weight is extracted and bias is None."""
        virtual_v = self._make_virtual_v(qkv_module_without_bias, self.NUM_HEADS)

        assert virtual_v.weight is not None
        assert virtual_v.bias is None
        assert virtual_v.weight.shape == (512, 512)

        v_start, v_end = self._expected_v_range(self.NUM_HEADS)
        assert torch.allclose(virtual_v.weight, qkv_module_without_bias.weight[v_start:v_end])

    def test_virtual_v_module_update_weights_without_bias(self, qkv_module_without_bias):
        """Adding a V bias to a bias-less fused module creates a zero-padded bias.

        The Q/K part of the new fused bias must be zero, and the Q/K weight
        rows must not be touched by the write-back.
        """
        virtual_v = self._make_virtual_v(qkv_module_without_bias, self.NUM_HEADS)
        original_weight = virtual_v.qkv_module.weight.detach().clone()

        new_v_bias = torch.randn(512)
        virtual_v.bias = nn.Parameter(new_v_bias)
        new_v_weight = torch.randn_like(virtual_v.weight)
        virtual_v.weight.data = new_v_weight
        virtual_v.update_weights()

        fused = virtual_v.qkv_module
        assert fused.bias is not None

        v_start, v_end = self._expected_v_range(self.NUM_HEADS)
        assert torch.allclose(fused.weight[v_start:v_end], new_v_weight)
        assert torch.allclose(fused.bias[v_start:v_end], new_v_bias)
        assert torch.allclose(fused.bias[:v_start], torch.zeros_like(fused.bias[:v_start]))
        assert torch.allclose(fused.weight[:v_start], original_weight[:v_start])

    def test_virtual_v_module_update_weights_with_bias(self, qkv_module_with_bias):
        """With an existing fused bias, only its V slice is replaced."""
        virtual_v = self._make_virtual_v(qkv_module_with_bias, self.NUM_HEADS)
        self._check_update_round_trip(virtual_v)
if __name__ == "__main__":
    # Propagate pytest's exit code so a direct run fails visibly when tests fail.
    raise SystemExit(pytest.main([__file__]))