"""
-------------------------------------------------------------------------
This file is part of the MindStudio project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.
...
-------------------------------------------------------------------------
"""
import unittest
from unittest.mock import Mock, patch
import torch
from torch import nn
from msmodelslim.core.base.protocol import BatchProcessRequest
from msmodelslim.ir import FP8FakeQuantActivationPerHead, FakeQuantActivationPerToken
from msmodelslim.processor.quant.fa3 import (
FA3QuantProcessor,
FA3QuantProcessorConfig,
FA3QuantAdapterInterface,
FA3QuantPlaceHolder,
)
from msmodelslim.processor.quant.fa3.processor import FA3AttentionDetails, _FA3PerHeadObserver
from msmodelslim.utils.exception import UnsupportedError, SpecError, SchemaValidateError
from msmodelslim.core.quantizer.base import QConfig
from msmodelslim.ir.qal.qbase import QDType, QScope
from msmodelslim.ir.qal import QParam, QScheme
def create_qconfig(dtype, scope, symmetric=True, method="minmax"):
    """Build a ``QConfig`` from the given dtype, scope, symmetry flag and method."""
    settings = {
        "dtype": dtype,
        "scope": scope,
        "symmetric": symmetric,
        "method": method,
    }
    return QConfig(**settings)
def create_branch_qconfig_dict():
    """Return the raw per-branch mapping that is handed to the config as ``details``.

    Only the ``fa_q`` and ``fa_k`` branches are present; both are symmetric,
    per-head, minmax configurations that differ only in their dtype.
    """
    branch_dtypes = (("fa_q", "int8"), ("fa_k", "fp8_e4m3"))
    return {
        branch: {"dtype": dtype, "scope": "per_head", "symmetric": True, "method": "minmax"}
        for branch, dtype in branch_dtypes
    }
def create_processor_config(include=None, exclude=None, qconfig=None, details=None):
    """Construct an ``FA3QuantProcessorConfig`` directly, filling in the required fields.

    ``include`` falls back to ``["*"]`` and ``exclude`` to ``[]`` when they are ``None``.
    ``qconfig`` and ``details`` are forwarded only when they are not ``None``, so the
    config class sees them as "not provided" otherwise.
    """
    if include is None:
        include = ["*"]
    if exclude is None:
        exclude = []

    fields = {"type": "fa3_quant", "include": include, "exclude": exclude}
    optional = {"qconfig": qconfig, "details": details}
    fields.update({key: value for key, value in optional.items() if value is not None})
    return FA3QuantProcessorConfig(**fields)
def create_mock_adapter():
    """Create a mock adapter that is spec'd against ``FA3QuantAdapterInterface``.

    ``inject_fa3_placeholders`` is replaced by a plain ``Mock`` so tests can set
    side effects on it and inspect its calls.
    """
    mocked = Mock(spec=FA3QuantAdapterInterface)
    mocked.configure_mock(inject_fa3_placeholders=Mock())
    return mocked
def create_simple_model():
    """Build a minimal model with one 64x64 linear layer and two empty attachment slots."""

    class SimpleModel(nn.Module):
        """Single linear layer; ``block`` and ``layer`` start out as ``None``."""

        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(64, 64)
            # Slots that individual tests overwrite with the sub-module under test.
            self.block = self.layer = None

        def forward(self, x):
            projected = self.linear(x)
            return projected

    model = SimpleModel()
    return model
class TestFA3QuantProcessorConfig(unittest.TestCase):
    """Construction and default behaviour of ``FA3QuantProcessorConfig``."""

    def test_qconfig_default_to_INT8_per_head_when_qconfig_not_provided(self):
        """Without a qconfig, a symmetric INT8 per-head ``QConfig`` is generated."""
        cfg = create_processor_config()

        self.assertIsInstance(cfg.qconfig, QConfig)
        self.assertIsNone(cfg.details)
        self.assertEqual(cfg.qconfig.dtype, QDType.INT8)
        self.assertEqual(cfg.qconfig.scope, QScope.PER_HEAD)
        self.assertTrue(cfg.qconfig.symmetric)

    def test_qconfig_stores_instance_when_qconfig_instance_given(self):
        """A ``QConfig`` instance passed in is stored on the config."""
        fp8_qconfig = create_qconfig(QDType.FP8_E4M3, QScope.PER_HEAD, symmetric=True)
        cfg = create_processor_config(qconfig=fp8_qconfig)

        self.assertIsInstance(cfg.qconfig, QConfig)
        self.assertIsNone(cfg.details)
        self.assertEqual(cfg.qconfig.dtype, QDType.FP8_E4M3)

    def test_details_stores_attention_details_when_branch_dict_given(self):
        """A details dict is converted to ``FA3AttentionDetails`` and qconfig stays ``None``."""
        cfg = create_processor_config(details=create_branch_qconfig_dict())

        self.assertIsInstance(cfg.details, FA3AttentionDetails)
        self.assertIsNone(cfg.qconfig)

        parsed = cfg.details
        # The helper dict configures only fa_q and fa_k, so fa_v must remain unset.
        for branch in ("fa_q", "fa_k"):
            self.assertIsNotNone(getattr(parsed, branch))
        self.assertIsNone(parsed.fa_v)

    def test_mutual_exclusivity_raises_when_both_qconfig_and_details_provided(self):
        """Supplying both qconfig and details raises ``SchemaValidateError``."""
        int8_qconfig = create_qconfig(QDType.INT8, QScope.PER_HEAD)
        branch_details = create_branch_qconfig_dict()

        with self.assertRaises(SchemaValidateError) as raised:
            create_processor_config(qconfig=int8_qconfig, details=branch_details)

        self.assertIn("only one of the qconfig and details", str(raised.exception))

    def test_type_equals_fa3_quant_when_created(self):
        """The ``type`` field reads back as ``'fa3_quant'``."""
        cfg = create_processor_config()
        self.assertEqual(cfg.type, "fa3_quant")

    def test_include_exclude_stored_when_provided(self):
        """Explicit include/exclude lists are stored unchanged."""
        cfg = create_processor_config(include=["layer1"], exclude=["layer2"])

        self.assertEqual(cfg.include, ["layer1"])
        self.assertEqual(cfg.exclude, ["layer2"])
class TestFA3QuantProcessor(unittest.TestCase):
    """Initialisation, capability flags and pre/postprocess behaviour of ``FA3QuantProcessor``."""

    def setUp(self):
        """Create a fresh mock adapter and model, and turn ``breakpoint()`` into a no-op."""
        self.adapter = create_mock_adapter()
        self.simple_model = create_simple_model()
        # Replace the builtin so a breakpoint() call reached during a test does nothing.
        self._breakpoint_patcher = patch('builtins.breakpoint', lambda *args, **kwargs: None)
        self._breakpoint_patcher.start()
        # Registered immediately after start(): cleanups run even when setUp fails later
        # on, whereas tearDown would be skipped in that case.
        self.addCleanup(self._breakpoint_patcher.stop)

    def test_init_stores_config_and_adapter_when_valid(self):
        """A valid config and adapter are stored and include/exclude are populated."""
        config = create_processor_config()
        processor = FA3QuantProcessor(self.simple_model, config, self.adapter)

        self.assertEqual(processor.config, config)
        self.assertEqual(processor.adapter, self.adapter)
        self.assertIsNotNone(processor.include)
        self.assertIsNotNone(processor.exclude)

    def test_init_raises_UnsupportedError_when_adapter_is_None(self):
        """Constructing the processor without an adapter raises ``UnsupportedError``."""
        config = create_processor_config()
        with self.assertRaises(UnsupportedError):
            FA3QuantProcessor(self.simple_model, config, adapter=None)

    def test_init_raises_UnsupportedError_when_adapter_not_interface(self):
        """An adapter that is not spec'd against the FA3 interface raises ``UnsupportedError``."""
        config = create_processor_config()
        with self.assertRaises(UnsupportedError):
            FA3QuantProcessor(self.simple_model, config, adapter=Mock())

    def test_is_data_free_returns_False_when_per_head_qconfig(self):
        """``is_data_free()`` is False for the default config."""
        config = create_processor_config()
        processor = FA3QuantProcessor(self.simple_model, config, self.adapter)
        self.assertFalse(processor.is_data_free())

    def test_is_data_free_returns_True_when_per_token_qconfig(self):
        """``is_data_free()`` is True for a per-token qconfig."""
        config = create_processor_config(qconfig=create_qconfig(QDType.INT8, QScope.PER_TOKEN))
        processor = FA3QuantProcessor(self.simple_model, config, self.adapter)
        self.assertTrue(processor.is_data_free())

    def test_is_data_free_returns_False_when_details_per_head(self):
        """``is_data_free()`` is False when the details branches are per-head."""
        details = {
            "fa_q": create_qconfig(QDType.INT8, QScope.PER_HEAD).model_dump(),
            "fa_k": create_qconfig(QDType.FP8_E4M3, QScope.PER_HEAD).model_dump(),
        }
        config = create_processor_config(details=details)
        processor = FA3QuantProcessor(self.simple_model, config, self.adapter)
        self.assertFalse(processor.is_data_free())

    def test_is_data_free_returns_True_when_details_per_token(self):
        """``is_data_free()`` is True when the details branches are per-token."""
        details = {
            "fa_q": create_qconfig(QDType.INT8, QScope.PER_TOKEN).model_dump(),
            "fa_v": create_qconfig(QDType.FP8_E4M3, QScope.PER_TOKEN).model_dump(),
        }
        config = create_processor_config(details=details)
        processor = FA3QuantProcessor(self.simple_model, config, self.adapter)
        self.assertTrue(processor.is_data_free())

    def test_support_distributed_returns_True_always(self):
        """``support_distributed()`` returns True for the default config."""
        config = create_processor_config()
        processor = FA3QuantProcessor(self.simple_model, config, self.adapter)
        self.assertTrue(processor.support_distributed())

    @patch("msmodelslim.processor.quant.fa3.processor.dist.is_initialized", return_value=False)
    def test_preprocess_calls_adapter_and_replaces_placeholder_when_dist_not_init(self, _mock_dist):
        """preprocess calls the adapter once and swaps the placeholder for a per-head observer."""
        config = create_processor_config(include=["block.fa_q"])
        model = create_simple_model()
        block = nn.Module()
        block.fa_q = FA3QuantPlaceHolder(ratio=0.9)
        model.block = block
        processor = FA3QuantProcessor(model, config, self.adapter)
        request = BatchProcessRequest(name="block", module=block, datas=None, outputs=None)
        self.adapter.inject_fa3_placeholders.side_effect = None

        processor.preprocess(request)

        self.adapter.inject_fa3_placeholders.assert_called_once()
        call_args = self.adapter.inject_fa3_placeholders.call_args[0]
        self.assertEqual(call_args[0], "block")
        self.assertIs(call_args[1], block)
        # The third positional argument is the include/exclude predicate handed to the adapter.
        should_inject = call_args[2]
        self.assertTrue(should_inject("block.fa_q"))
        self.assertFalse(should_inject("block.other"))
        self.assertIsInstance(block.fa_q, _FA3PerHeadObserver)

    @patch("msmodelslim.processor.quant.fa3.processor.dist.is_initialized", return_value=False)
    def test_preprocess_logs_warning_when_adapter_raises(self, _mock_dist):
        """An exception raised by the adapter is logged as a warning rather than propagated."""
        self.adapter.inject_fa3_placeholders.side_effect = RuntimeError("mock error")
        config = create_processor_config()
        model = create_simple_model()
        block = nn.Module()
        model.block = block
        processor = FA3QuantProcessor(model, config, self.adapter)
        request = BatchProcessRequest(name="block", module=block, datas=None, outputs=None)

        with self.assertLogs('msmodelslim.processor.fa3_quant', level='WARNING') as log:
            processor.preprocess(request)

        self.assertTrue(any("mock error" in msg for msg in log.output))

    def _prepare_postprocess_model(self, scope=QScope.PER_HEAD, collect_data=True, qconfig=None, details=None):
        """Build a model whose ``layer.fa_q`` is a per-head observer, ready for postprocess.

        Args:
            scope: Not read by this helper. It is kept so existing keyword call sites
                stay valid; the effective scope comes from ``qconfig`` or ``details``.
            collect_data: When True, run one random tensor through the observer so it
                holds statistics before postprocess is invoked.
            qconfig: Optional qconfig forwarded to the processor config.
            details: Optional per-branch details forwarded to the processor config.

        Returns:
            A ``(processor, request, layer, observer)`` tuple.
        """
        model = create_simple_model()
        config_kwargs = {"include": ["layer.fa_q"]}
        if qconfig is not None:
            config_kwargs["qconfig"] = qconfig
        if details is not None:
            config_kwargs["details"] = details
        config = create_processor_config(**config_kwargs)
        processor = FA3QuantProcessor(model, config, self.adapter)

        layer = nn.Module()
        observer = _FA3PerHeadObserver(ratio=1.0, name="layer.fa_q")
        layer.fa_q = observer
        model.layer = layer
        if collect_data:
            observer(torch.randn(2, 4, 10, 16))

        request = BatchProcessRequest(name="layer", module=layer, datas=None, outputs=None)
        return processor, request, layer, observer

    @patch("msmodelslim.processor.quant.fa3.processor.calculate_qparam")
    @patch("msmodelslim.ir.FP8FakeQuantActivationPerHead")
    def test_postprocess_replaces_observer_with_per_head_IR_when_calibrated(self, mock_activation, mock_calc):
        """A calibrated observer is replaced by a per-head fake-quant IR built from its stats."""
        processor, request, layer, observer = self._prepare_postprocess_model(
            scope=QScope.PER_HEAD, qconfig=create_qconfig(QDType.INT8, QScope.PER_HEAD)
        )
        fake_qparam = QParam(
            scheme=QScheme(scope=QScope.PER_HEAD, dtype=QDType.FP8_E4M3, symmetric=True), ext={"scale": torch.rand(4)}
        )
        mock_calc.return_value = fake_qparam
        real_ir = FP8FakeQuantActivationPerHead(fake_qparam)
        mock_activation.create.return_value = real_ir

        processor.postprocess(request)

        mock_calc.assert_called_once()
        _, kwargs = mock_calc.call_args
        self.assertTrue(torch.equal(kwargs["min_val"], observer.min_val.squeeze()))
        self.assertTrue(torch.equal(kwargs["max_val"], observer.max_val.squeeze()))
        self.assertEqual(kwargs["q_dtype"], QDType.INT8)
        self.assertIsInstance(layer.fa_q, FP8FakeQuantActivationPerHead)

        # The replacement module must still be callable and preserve the input shape.
        test_input = torch.randn(2, 4, 10, 16)
        with torch.no_grad():
            output = layer.fa_q(test_input)
        self.assertEqual(output.shape, test_input.shape)

    def test_postprocess_raises_SpecError_when_no_calibration_data(self):
        """postprocess raises ``SpecError`` when the observer never saw any data."""
        processor, request, _, _ = self._prepare_postprocess_model(
            scope=QScope.PER_HEAD, qconfig=create_qconfig(QDType.INT8, QScope.PER_HEAD), collect_data=False
        )

        with self.assertRaises(SpecError) as ctx:
            processor.postprocess(request)

        self.assertIn("no any update_stats", str(ctx.exception))

    @patch("msmodelslim.ir.auto.AutoFakeQuantActivation.create")
    def test_postprocess_replaces_observer_with_per_token_IR_when_per_token_qconfig(self, mock_create):
        """With a per-token qconfig the observer is replaced by the IR returned from ``create``."""
        model = create_simple_model()
        config = create_processor_config(
            include=["layer.fa_q"], qconfig=create_qconfig(QDType.FP8_E4M3, QScope.PER_TOKEN, symmetric=True)
        )
        processor = FA3QuantProcessor(model, config, self.adapter)
        layer = nn.Module()
        layer.fa_q = _FA3PerHeadObserver(ratio=1.0)
        model.layer = layer
        request = BatchProcessRequest(name="layer", module=layer, datas=None, outputs=None)
        fake_qparam = QParam(scheme=QScheme(scope=QScope.PER_TOKEN, dtype=QDType.FP8_E4M3, symmetric=True))
        real_ir = FakeQuantActivationPerToken(fake_qparam)
        mock_create.return_value = real_ir

        processor.postprocess(request)

        mock_create.assert_called_once()
        self.assertIs(layer.fa_q, real_ir)

    def test_postprocess_raises_UnsupportedError_when_unsupported_scope(self):
        """A per-channel qconfig makes postprocess raise ``UnsupportedError``."""
        model = create_simple_model()
        config = create_processor_config(
            include=["layer.fa_q"], qconfig=create_qconfig(QDType.INT8, QScope.PER_CHANNEL, symmetric=True)
        )
        processor = FA3QuantProcessor(model, config, self.adapter)
        layer = nn.Module()
        observer = _FA3PerHeadObserver(ratio=1.0)
        layer.fa_q = observer
        model.layer = layer
        observer(torch.randn(2, 4, 10, 16))
        request = BatchProcessRequest(name="layer", module=layer, datas=None, outputs=None)

        with self.assertRaises(UnsupportedError):
            processor.postprocess(request)

    def test_postprocess_clears_dist_helper_after_execution(self):
        """``dist_helper`` is reset to ``None`` once postprocess has run."""
        processor, request, _, _ = self._prepare_postprocess_model(qconfig=create_qconfig(QDType.INT8, QScope.PER_HEAD))
        processor.dist_helper = Mock()

        processor.postprocess(request)

        self.assertIsNone(processor.dist_helper)

    @patch("msmodelslim.processor.quant.fa3.processor.calculate_qparam")
    @patch("msmodelslim.ir.FP8FakeQuantActivationPerHead")
    def test_postprocess_uses_details_branch_qconfig_when_details_configured(self, mock_activation, mock_calc):
        """In details mode, postprocess uses the QConfig of the branch matching the module name."""
        details_dict = {
            "fa_q": create_qconfig(QDType.INT8, QScope.PER_HEAD).model_dump(),
            "fa_k": create_qconfig(QDType.FP8_E4M3, QScope.PER_HEAD).model_dump(),
        }
        processor, request, layer, _ = self._prepare_postprocess_model(details=details_dict, collect_data=True)
        fake_qparam = QParam(
            scheme=QScheme(scope=QScope.PER_HEAD, dtype=QDType.FP8_E4M3, symmetric=True), ext={"scale": torch.rand(4)}
        )
        mock_calc.return_value = fake_qparam
        real_ir = FP8FakeQuantActivationPerHead(fake_qparam)
        mock_activation.create.return_value = real_ir

        processor.postprocess(request)

        mock_calc.assert_called_once()
        _, kwargs = mock_calc.call_args
        # layer.fa_q must pick up the fa_q branch (INT8), not the fa_k branch.
        self.assertEqual(kwargs["q_dtype"], QDType.INT8)
        self.assertIsInstance(layer.fa_q, FP8FakeQuantActivationPerHead)

    def test_postprocess_skips_module_when_qconfig_is_none_in_details(self):
        """A branch with no qconfig in details is skipped and a debug message is logged."""
        with patch("msmodelslim.processor.quant.fa3.processor.get_logger") as mock_get_logger:
            mock_logger = Mock()
            mock_get_logger.return_value = mock_logger

            # Only fa_q is configured; fa_k is left without a qconfig.
            details = FA3AttentionDetails(fa_q=create_qconfig(QDType.INT8, QScope.PER_HEAD))
            config = create_processor_config(include=["layer.fa_q", "layer.fa_k"], details=details)
            model = create_simple_model()
            processor = FA3QuantProcessor(model, config, self.adapter)

            layer = nn.Module()
            observer_q = _FA3PerHeadObserver(ratio=1.0, name="layer.fa_q")
            observer_k = _FA3PerHeadObserver(ratio=1.0, name="layer.fa_k")
            layer.fa_q = observer_q
            layer.fa_k = observer_k
            model.layer = layer
            observer_q(torch.randn(2, 4, 10, 16))
            observer_k(torch.randn(2, 4, 10, 16))
            request = BatchProcessRequest(name="layer", module=layer, datas=None, outputs=None)

            with (
                patch.object(processor, '_process_per_head') as mock_per_head,
                patch.object(processor, '_process_per_token') as mock_per_token,
            ):
                processor.postprocess(request)

            mock_per_head.assert_called_once()
            args, _ = mock_per_head.call_args
            self.assertEqual(args[1], "layer.fa_q")
            mock_per_token.assert_not_called()

            debug_calls = [
                c for c in mock_logger.debug.call_args_list if "layer.fa_k" in str(c) and "skipping" in str(c)
            ]
            self.assertTrue(debug_calls, "Debug log for skipping layer.fa_k not found")
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()