import pytest
import torch
import torch.nn as nn
from amct_pytorch.algorithms import AlgorithmRegistry
from amct_pytorch.common.config.config import INT8_MINMAX_WEIGHT_QUANT_CFG
from amct_pytorch.common.config.fields import QuantConfig
from amct_pytorch.common.config.parser import (
_build_layer_types_and_quant_type,
_check_fuzzy_config_warnings,
_check_layer_constraints,
_is_layer_supported,
check_config,
check_quant_op_constraint,
check_skip_layer,
get_supported_layers,
parse_config,
set_default_config,
)
# Algorithm name used as the key of the "algorithm" section in the test configs
# and as the algorithm argument passed to the parser checks below.
MINMAX = 'minmax'
# Name under which a throwaway custom algorithm is registered in
# AlgorithmRegistry.algo by the customized-algorithm test.
MY_CUSTOM_PARSER = 'my_custom_parser'
def test_set_default_config_returns_int8_minmax():
    """``set_default_config()`` must equal ``INT8_MINMAX_WEIGHT_QUANT_CFG``."""
    default_cfg = set_default_config()
    assert default_cfg == INT8_MINMAX_WEIGHT_QUANT_CFG
def test_check_skip_layer_empty_returns_false():
    """An empty list or ``None`` as the skip list never marks a layer as skipped."""
    for empty_skip_list in ([], None):
        assert check_skip_layer("layer.0", empty_skip_list) is False
def test_check_skip_layer_matched():
    """A layer whose name ends with a skip entry is reported as skipped."""
    layer_name = "model.layers.0.lm_head"
    skip_entries = ["lm_head"]
    assert check_skip_layer(layer_name, skip_entries) is True
def test_check_skip_layer_substring_match():
    """A skip entry contained in the layer name is enough to skip the layer."""
    layer_name = "model.layers.0.self_attn"
    skip_entries = ["self_attn"]
    assert check_skip_layer(layer_name, skip_entries) is True
def test_check_skip_layer_no_match():
    """A layer name that contains none of the skip entries is not skipped."""
    layer_name = "model.layers.0.mlp"
    skip_entries = ["lm_head"]
    assert check_skip_layer(layer_name, skip_entries) is False
def _make_quant_config(*, weight_type="int8", weight_strategy="channel", group_size=None,
                       input_type="int8", input_strategy="tensor", enable_input=True):
    """Build a ``QuantConfig`` for the tests from a handful of keyword knobs.

    The config always uses ``batch_num`` 1, symmetric weights and inputs, and
    the ``MINMAX`` algorithm with empty options.

    Args:
        weight_type: Value of the weights ``type`` field.
        weight_strategy: Value of the weights ``strategy`` field.
        group_size: When not ``None``, added to the weights section as
            ``group_size``; otherwise the key is omitted entirely.
        input_type: Value of the inputs ``type`` field.
        input_strategy: Value of the inputs ``strategy`` field.
        enable_input: Value of the inputs ``enable_quant`` field.

    Returns:
        ``QuantConfig`` built from the assembled dict and ``AlgorithmRegistry``.
    """
    weights_section = {"type": weight_type, "symmetric": True, "strategy": weight_strategy}
    if group_size is not None:
        weights_section["group_size"] = group_size
    inputs_section = {
        "type": input_type,
        "symmetric": True,
        "strategy": input_strategy,
        "enable_quant": enable_input,
    }
    raw_cfg = {
        "batch_num": 1,
        "quant_cfg": {"weights": weights_section, "inputs": inputs_section},
        "algorithm": {MINMAX: {}},
    }
    return QuantConfig(raw_cfg, AlgorithmRegistry)
def test_check_config_valid_int8_int8():
    """The default int8/int8 config passes ``check_config`` without raising."""
    default_cfg = _make_quant_config()
    check_config("int8 int8", default_cfg, MINMAX)
def test_check_config_invalid_comb():
    """The "float64 float64" combination is rejected with a ``ValueError``."""
    expected_msg = "Do not support combination"
    with pytest.raises(ValueError, match=expected_msg):
        check_config("float64 float64", _make_quant_config(), MINMAX)
def test_check_config_algo_not_support_comb():
    """The mxfp8/mxfp8 combination is rejected when checked against ``MINMAX``."""
    expected_msg = "do not support act and weight quant dtype"
    with pytest.raises(ValueError, match=expected_msg):
        check_config("mxfp8_e4m3fn mxfp8_e4m3fn", _make_quant_config(), MINMAX)
def test_check_config_weight_strategy_not_supported():
    """int8 weights with the "group" strategy are rejected for ``MINMAX`` int8/int8."""
    group_cfg = _make_quant_config(
        weight_type="int8",
        weight_strategy="group",
        group_size=64,
    )
    expected_msg = "do not support weight quant strategy"
    with pytest.raises(ValueError, match=expected_msg):
        check_config("int8 int8", group_cfg, MINMAX)
def test_check_config_act_strategy_not_supported():
    """The "token" activation strategy is rejected for float8/float8 under "ofmr"."""
    fp8_dtype = "float8_e4m3fn"
    token_cfg = _make_quant_config(
        input_strategy="token",
        weight_type=fp8_dtype,
        input_type=fp8_dtype,
    )
    expected_msg = "do not support activation quant strategy"
    with pytest.raises(ValueError, match=expected_msg):
        check_config("float8_e4m3fn float8_e4m3fn", token_cfg, "ofmr")
def test_check_config_mxfp8_act_strategy_not_group():
    """mxfp8 inputs with the "tensor" strategy are rejected under "mxquant"."""
    mx_dtype = "mxfp8_e4m3fn"
    tensor_act_cfg = _make_quant_config(
        input_type=mx_dtype,
        input_strategy="tensor",
        weight_type=mx_dtype,
        weight_strategy="group",
        group_size=32,
    )
    expected_msg = "only support activation quant strategy group"
    with pytest.raises(ValueError, match=expected_msg):
        check_config("mxfp8_e4m3fn mxfp8_e4m3fn", tensor_act_cfg, "mxquant")
def test_check_config_group_size_not_multiple_of_32():
    """A weight-only int4 group config with ``group_size=33`` is rejected."""
    odd_group_cfg = _make_quant_config(
        weight_type="int4",
        weight_strategy="group",
        group_size=33,
        enable_input=False,
    )
    expected_msg = "integer multiple of 32"
    with pytest.raises(ValueError, match=expected_msg):
        check_config("NOT_QUANTIZE int4", odd_group_cfg, MINMAX)
def test_check_config_group_size_less_than_32():
    """A weight-only int4 group config with ``group_size=16`` is rejected."""
    small_group_cfg = _make_quant_config(
        weight_type="int4",
        weight_strategy="group",
        group_size=16,
        enable_input=False,
    )
    expected_msg = "group_size larger than 32"
    with pytest.raises(ValueError, match=expected_msg):
        check_config("NOT_QUANTIZE int4", small_group_cfg, MINMAX)
def _make_mock_linear(in_features=64, out_features=64, has_bias=False):
    """Return an ``nn.Linear`` of the requested shape.

    Args:
        in_features: Input feature count passed to ``nn.Linear``.
        out_features: Output feature count passed to ``nn.Linear``.
        has_bias: Forwarded as the ``bias`` argument of ``nn.Linear``.
    """
    return nn.Linear(in_features, out_features, bias=has_bias)
def test_check_quant_op_constraint_non_linear():
    """A ``Conv2d`` module passes ``check_quant_op_constraint`` for int8/int8."""
    conv = nn.Conv2d(3, 3, 1)
    verdict = check_quant_op_constraint(conv, "conv", "int8 int8", _make_quant_config())
    assert verdict is True
def test_check_quant_op_constraint_cin_not_multiple_of_64():
    """A Linear with 63 input features fails the float8/float4 constraint check."""
    linear = _make_mock_linear(in_features=63, out_features=64)
    fp4_cfg = _make_quant_config(
        weight_type="float4_e2m1",
        weight_strategy="group",
        group_size=64,
    )
    verdict = check_quant_op_constraint(linear, "layer.0", "float8_e4m3fn float4_e2m1", fp4_cfg)
    assert verdict is False
def test_check_quant_op_constraint_has_bias():
    """A 64x64 Linear with a bias fails the float8/float4 constraint check."""
    linear = _make_mock_linear(in_features=64, out_features=64, has_bias=True)
    fp4_cfg = _make_quant_config(
        weight_type="float4_e2m1",
        weight_strategy="group",
        group_size=64,
    )
    verdict = check_quant_op_constraint(linear, "layer.0", "float8_e4m3fn float4_e2m1", fp4_cfg)
    assert verdict is False
def test_check_quant_op_constraint_no_bias_no_group_size():
    """A bias-free 64x64 Linear with channel-wise int8/int8 passes the check."""
    linear = _make_mock_linear(in_features=64, out_features=64, has_bias=False)
    channel_cfg = _make_quant_config(weight_strategy="channel")
    assert check_quant_op_constraint(linear, "layer.0", "int8 int8", channel_cfg) is True
def test_check_quant_op_constraint_group_size_none():
    """A weight-only int8 channel config without ``group_size`` passes the check."""
    linear = _make_mock_linear(in_features=32, out_features=32)
    weight_only_cfg = _make_quant_config(
        weight_type="int8",
        weight_strategy="channel",
        enable_input=False,
    )
    verdict = check_quant_op_constraint(linear, "layer.0", "NOT_QUANTIZE int8", weight_only_cfg)
    assert verdict is True
def test_check_quant_op_constraint_mxfp8_ceiling_odd():
    """A Linear with 31 input features fails the mxfp8/mxfp8 constraint check."""
    linear = _make_mock_linear(in_features=31, out_features=64)
    mx_dtype = "mxfp8_e4m3fn"
    mx_cfg = _make_quant_config(
        weight_type=mx_dtype,
        weight_strategy="group",
        group_size=32,
        input_type=mx_dtype,
        input_strategy="group",
    )
    verdict = check_quant_op_constraint(linear, "layer.0", "mxfp8_e4m3fn mxfp8_e4m3fn", mx_cfg)
    assert verdict is False
def test_check_quant_op_constraint_mxfp4_shape_cout_not_64():
    """A Linear with 63 output features fails the weight-only mxfp4 check."""
    linear = _make_mock_linear(in_features=64, out_features=63)
    mxfp4_cfg = _make_quant_config(
        weight_type="mxfp4_e2m1",
        weight_strategy="group",
        group_size=32,
        enable_input=False,
    )
    verdict = check_quant_op_constraint(linear, "layer.0", "NOT_QUANTIZE mxfp4_e2m1", mxfp4_cfg)
    assert verdict is False
def test_check_quant_op_constraint_int4_cin_not_8():
    """A 7-in / 8-out Linear fails the weight-only int4 channel check."""
    linear = _make_mock_linear(in_features=7, out_features=8)
    int4_cfg = _make_quant_config(
        weight_type="int4",
        weight_strategy="channel",
        enable_input=False,
    )
    verdict = check_quant_op_constraint(linear, "layer.0", "NOT_QUANTIZE int4", int4_cfg)
    assert verdict is False
def test_check_fuzzy_config_warnings_skip_interaction(caplog):
    """Smoke test: a fuzzy weights pattern combined with a matching skip layer.

    Only verifies that ``_check_fuzzy_config_warnings`` completes without
    raising; the ``caplog`` fixture is requested but its records are not
    inspected.
    """
    target_layer = "model.layers.0.self_attn.q_proj"
    fuzzy_weights = {"type": "int4", "symmetric": True, "strategy": "channel"}
    raw_cfg = {
        "batch_num": 1,
        "quant_cfg": {"*self_attn.q_proj.weights": fuzzy_weights},
        "algorithm": {MINMAX: {}},
        "skip_layers": ["model.layers.0.self_attn.q_proj"],
    }
    quant_config = QuantConfig(raw_cfg, AlgorithmRegistry)
    _check_fuzzy_config_warnings([target_layer], quant_config)
def test_is_layer_supported_conv2d_padding_mode():
    """A ``Conv2d`` with ``padding_mode='reflect'`` is reported as unsupported."""
    conv = nn.Conv2d(3, 3, 1, padding_mode='reflect')
    layer_types = {"Conv2d": "ofmr"}
    int8_cfg = _make_quant_config(weight_type="int8", input_type="int8")
    supported = _is_layer_supported(conv, "conv.0", layer_types, "int8 int8", int8_cfg)
    assert supported is False
def test_check_layer_constraints_custom_algo():
    """``_check_layer_constraints`` returns True for the algorithm name "custom_algo"."""
    linear = _make_mock_linear(in_features=64, out_features=64)
    default_cfg = _make_quant_config()
    verdict = _check_layer_constraints(linear, "layer.0", "custom_algo", "int8 int8", default_cfg)
    assert verdict is True
def test_check_layer_constraints_weight_dtype_mismatch():
    """A float64 Linear fails the layer constraints for weight-only int4."""
    fp64_linear = nn.Linear(64, 64, dtype=torch.float64)
    int4_cfg = _make_quant_config(weight_type="int4", enable_input=False)
    verdict = _check_layer_constraints(fp64_linear, "layer.0", MINMAX, "NOT_QUANTIZE int4", int4_cfg)
    assert verdict is False
def test_check_quant_op_constraint_group_size_too_large():
    """``group_size=128`` on a 32x32 Linear fails the int4 constraint check."""
    big_group_cfg = _make_quant_config(
        weight_type="int4",
        weight_strategy="group",
        group_size=128,
    )
    linear = _make_mock_linear(in_features=32, out_features=32)
    verdict = check_quant_op_constraint(linear, "layer.0", "NOT_QUANTIZE int4", big_group_cfg)
    assert verdict is False
def test_check_quant_op_constraint_int4_shape():
    """A 7x7 Linear fails the int4 channel-wise constraint check."""
    linear = _make_mock_linear(in_features=7, out_features=7)
    int4_cfg = _make_quant_config(weight_type="int4", weight_strategy="channel")
    verdict = check_quant_op_constraint(linear, "layer.0", "NOT_QUANTIZE int4", int4_cfg)
    assert verdict is False
def test_check_fuzzy_config_warnings_no_fuzzy():
    """Smoke test: a config with only a plain ``weights`` key does not raise."""
    plain_weights = {"type": "int8", "symmetric": True, "strategy": "channel"}
    raw_cfg = {
        "batch_num": 1,
        "quant_cfg": {"weights": plain_weights},
        "algorithm": {MINMAX: {}},
    }
    quant_config = QuantConfig(raw_cfg, AlgorithmRegistry)
    layer_names = ["layer.0", "layer.1"]
    _check_fuzzy_config_warnings(layer_names, quant_config)
def test_check_fuzzy_config_warnings_with_match(caplog):
    """Smoke test: the fuzzy pattern ``*self_attn.q_proj`` with a q_proj layer name.

    Only verifies the call completes; ``caplog`` is requested but not inspected.
    """
    fuzzy_weights = {"type": "int4", "symmetric": True, "strategy": "channel"}
    raw_cfg = {
        "batch_num": 1,
        "quant_cfg": {"*self_attn.q_proj.weights": fuzzy_weights},
        "algorithm": {MINMAX: {}},
    }
    quant_config = QuantConfig(raw_cfg, AlgorithmRegistry)
    layer_names = ["model.layers.0.self_attn.q_proj"]
    _check_fuzzy_config_warnings(layer_names, quant_config)
def test_check_fuzzy_config_warnings_no_match(caplog):
    """Smoke test: the fuzzy pattern ``*self_attn.q_proj`` with a gate_proj layer name.

    Only verifies the call completes; ``caplog`` is requested but not inspected.
    """
    fuzzy_weights = {"type": "int4", "symmetric": True, "strategy": "channel"}
    raw_cfg = {
        "batch_num": 1,
        "quant_cfg": {"*self_attn.q_proj.weights": fuzzy_weights},
        "algorithm": {MINMAX: {}},
    }
    quant_config = QuantConfig(raw_cfg, AlgorithmRegistry)
    layer_names = ["model.layers.0.mlp.gate_proj"]
    _check_fuzzy_config_warnings(layer_names, quant_config)
def test_build_layer_types_single_algo():
    """With ``MINMAX`` and int8 weights only, "Linear" is mapped and the combo is weight-only int8."""
    raw_cfg = {
        "batch_num": 1,
        "quant_cfg": {"weights": {"type": "int8", "symmetric": True, "strategy": "channel"}},
        "algorithm": {MINMAX: {}},
    }
    quant_config = QuantConfig(raw_cfg, AlgorithmRegistry)
    layer_types, quant_type_comb = _build_layer_types_and_quant_type(quant_config, AlgorithmRegistry)
    assert "Linear" in layer_types
    assert quant_type_comb == "NOT_QUANTIZE int8"
def test_build_layer_types_multi_algo():
    """With the "awq" algorithm and int4 weights, "Linear" is mapped and the combo is weight-only int4."""
    raw_cfg = {
        "batch_num": 1,
        "quant_cfg": {"weights": {"type": "int4", "symmetric": True, "strategy": "channel"}},
        "algorithm": {"awq": {"grids_num": 20}},
    }
    quant_config = QuantConfig(raw_cfg, AlgorithmRegistry)
    layer_types, quant_type_comb = _build_layer_types_and_quant_type(quant_config, AlgorithmRegistry)
    assert "Linear" in layer_types
    assert quant_type_comb == "NOT_QUANTIZE int4"
def test_build_layer_types_weight_none():
    """With input quantization explicitly disabled, the combo is "NOT_QUANTIZE int8"."""
    raw_cfg = {
        "batch_num": 1,
        "quant_cfg": {
            "weights": {"type": "int8", "symmetric": True, "strategy": "channel"},
            "inputs": {"enable_quant": False},
        },
        "algorithm": {MINMAX: {}},
    }
    quant_config = QuantConfig(raw_cfg, AlgorithmRegistry)
    # Only the quant-type combination is checked; the layer-type map is ignored.
    _, quant_type_comb = _build_layer_types_and_quant_type(quant_config, AlgorithmRegistry)
    assert quant_type_comb == "NOT_QUANTIZE int8"
def test_is_layer_supported_linear():
    """A 4x4 Linear is supported for weight-only int8 under ``MINMAX``."""
    linear = nn.Linear(4, 4)
    layer_types = {"Linear": MINMAX}
    raw_cfg = {
        "batch_num": 1,
        "quant_cfg": {"weights": {"type": "int8", "symmetric": True, "strategy": "channel"}},
        "algorithm": {MINMAX: {}},
    }
    quant_config = QuantConfig(raw_cfg, AlgorithmRegistry)
    supported = _is_layer_supported(linear, "layer.0", layer_types, "NOT_QUANTIZE int8", quant_config)
    assert supported is True
def test_is_layer_supported_skip_layer():
    """A Linear whose name is listed in ``skip_layers`` is reported as unsupported."""
    linear = nn.Linear(4, 4)
    layer_types = {"Linear": MINMAX}
    raw_cfg = {
        "batch_num": 1,
        "quant_cfg": {"weights": {"type": "int8", "symmetric": True, "strategy": "channel"}},
        "algorithm": {MINMAX: {}},
        "skip_layers": ["layer.0"],
    }
    quant_config = QuantConfig(raw_cfg, AlgorithmRegistry)
    supported = _is_layer_supported(linear, "layer.0", layer_types, "NOT_QUANTIZE int8", quant_config)
    assert supported is False
def test_check_layer_constraints_no_weight():
    """A module that defines no ``weight`` attribute passes the layer constraints."""

    class NoWeightModule(nn.Module):
        """Empty module used as a stand-in for a layer without weights."""

    weightless = NoWeightModule()
    raw_cfg = {
        "batch_num": 1,
        "quant_cfg": {"weights": {"type": "int8", "symmetric": True, "strategy": "channel"}},
        "algorithm": {MINMAX: {}},
    }
    quant_config = QuantConfig(raw_cfg, AlgorithmRegistry)
    verdict = _check_layer_constraints(weightless, "norm.0", MINMAX, "NOT_QUANTIZE int8", quant_config)
    assert verdict is True
class _MockModel(nn.Module):
    """Tiny two-layer model used by the layer-discovery tests.

    Each entry of ``layers`` holds ``self_attn.q_proj`` and ``mlp.gate_proj``,
    both 64x64 bfloat16 ``nn.Linear`` modules, giving four Linear layers total.
    """

    def __init__(self):
        super().__init__()
        blocks = []
        for _ in range(2):
            block = nn.Module()
            block.self_attn = nn.Module()
            block.self_attn.q_proj = nn.Linear(64, 64, dtype=torch.bfloat16)
            block.mlp = nn.Module()
            block.mlp.gate_proj = nn.Linear(64, 64, dtype=torch.bfloat16)
            blocks.append(block)
        self.layers = nn.ModuleList(blocks)
def test_get_supported_layers_finds_linears():
    """All four Linear layers of ``_MockModel`` are reported for weight-only int8."""
    model = _MockModel()
    raw_cfg = {
        "batch_num": 1,
        "quant_cfg": {"weights": {"type": "int8", "symmetric": True, "strategy": "channel"}},
        "algorithm": {MINMAX: {}},
    }
    quant_config = QuantConfig(raw_cfg, AlgorithmRegistry)
    supported = get_supported_layers(model, quant_config, AlgorithmRegistry)
    assert len(supported) == 4
def test_parse_config_returns_detail():
    """``parse_config`` yields four entries, each with batch, weights and algorithm info."""
    model = _MockModel()
    raw_cfg = {
        "batch_num": 1,
        "quant_cfg": {"weights": {"type": "int8", "symmetric": True, "strategy": "channel"}},
        "algorithm": {MINMAX: {}},
    }
    parsed = parse_config(model, raw_cfg, AlgorithmRegistry)
    assert len(parsed) == 4
    for layer_name in parsed:
        detail = parsed[layer_name]
        for required_key in ("batch_num", "weights_cfg", "algorithm"):
            assert required_key in detail
def test_check_quant_op_constraint_group_size_valid_returns_true():
    """A 128-in / 64-out Linear with int4 ``group_size=64`` passes the check."""
    linear = _make_mock_linear(in_features=128, out_features=64)
    group_cfg = _make_quant_config(
        weight_type="int4",
        weight_strategy="group",
        group_size=64,
    )
    verdict = check_quant_op_constraint(linear, "layer.0", "NOT_QUANTIZE int4", group_cfg)
    assert verdict is True
def test_check_quant_op_constraint_group_size_valid_less_than_64_returns_true():
    """A 64-in / 32-out Linear with int8 ``group_size=32`` passes the check."""
    linear = _make_mock_linear(in_features=64, out_features=32)
    group_cfg = _make_quant_config(
        weight_type="int8",
        weight_strategy="group",
        group_size=32,
    )
    verdict = check_quant_op_constraint(linear, "layer.0", "NOT_QUANTIZE int8", group_cfg)
    assert verdict is True
def test_check_fuzzy_config_warnings_pattern_no_match(caplog):
    """Smoke test: the fuzzy pattern ``*nonexistent`` with a q_proj layer name.

    Only verifies the call completes; ``caplog`` is requested but not inspected.
    """
    fuzzy_weights = {"type": "int4", "symmetric": True, "strategy": "channel"}
    raw_cfg = {
        "batch_num": 1,
        "quant_cfg": {"*nonexistent.weights": fuzzy_weights},
        "algorithm": {MINMAX: {}},
    }
    quant_config = QuantConfig(raw_cfg, AlgorithmRegistry)
    _check_fuzzy_config_warnings(["model.layers.0.self_attn.q_proj"], quant_config)
def test_build_layer_types_customized_algo(monkeypatch):
    """Smoke test: building layer types with a custom, non-built-in algorithm.

    Registers ``MY_CUSTOM_PARSER`` in ``AlgorithmRegistry.algo`` and empties
    the parser's ``BUILT_IN_ALGORITHM`` list, then checks that
    ``_build_layer_types_and_quant_type`` completes without raising.

    Both changes go through ``monkeypatch`` so the shared registry is restored
    to its exact prior state at teardown, including when an entry with the
    same name already existed or the test body fails.
    """
    monkeypatch.setitem(AlgorithmRegistry.algo, MY_CUSTOM_PARSER, {"Linear": object()})
    monkeypatch.setattr("amct_pytorch.common.config.parser.BUILT_IN_ALGORITHM", [])
    raw_cfg = {
        "batch_num": 1,
        "quant_cfg": {"weights": {"type": "int8", "symmetric": True, "strategy": "channel"}},
        "algorithm": {MY_CUSTOM_PARSER: {}},
    }
    quant_config = QuantConfig(raw_cfg, AlgorithmRegistry)
    _build_layer_types_and_quant_type(quant_config, AlgorithmRegistry)
def test_get_supported_layers_no_weights():
    """Smoke test: ``get_supported_layers`` runs on ``_MockModel`` with int4 weights.

    The result is not inspected; the test only checks the call does not raise.
    """
    model = _MockModel()
    raw_cfg = {
        "batch_num": 1,
        "quant_cfg": {"weights": {"type": "int4", "symmetric": True, "strategy": "channel"}},
        "algorithm": {MINMAX: {}},
    }
    quant_config = QuantConfig(raw_cfg, AlgorithmRegistry)
    get_supported_layers(model, quant_config, AlgorithmRegistry)
def test_get_supported_layers_constraint_skip():
    """A replaced ``q_proj`` with a 64x63 float32 weight is left out of the result."""
    model = _MockModel()

    class BadWeightLinear(nn.Module):
        """Bias-free stand-in module exposing only a 64x63 float32 ``weight``."""

        def __init__(self):
            super().__init__()
            raw_weight = torch.randn(64, 63).to(torch.float32)
            self.weight = nn.Parameter(raw_weight)
            self.bias = None

    model.layers[0].self_attn.q_proj = BadWeightLinear()
    fp4_cfg = _make_quant_config(
        weight_type="float4_e2m1",
        weight_strategy="group",
        group_size=64,
        enable_input=False,
    )
    supported = get_supported_layers(model, fp4_cfg, AlgorithmRegistry)
    assert "layers.0.self_attn.q_proj" not in supported
def test_build_layer_types_and_quant_type_when_wts_type_is_none():
    """The quant-type combination is ``None`` when the weights quant type is ``None``.

    Uses ``MagicMock`` stand-ins for the config and the registry so the weights
    ``quant_type`` can be set to ``None`` directly.
    """
    # Function-local import: the module header does not import MagicMock.
    from unittest.mock import MagicMock

    quant_config = MagicMock()
    quant_config.quant_cfg.inputs_cfg.quant_input = False
    quant_config.quant_cfg.inputs_cfg.quant_type = "NOT_QUANTIZE"
    quant_config.quant_cfg.weights_cfg.quant_type = None

    registry = MagicMock()
    registry.algo = {"awq": {"Linear": {}}}

    # Only the combination is checked; the layer-type map is ignored.
    _, quant_type_comb = _build_layer_types_and_quant_type(quant_config, registry)
    assert quant_type_comb is None