"""Unit tests for tensor_cast.core.quantization.config helpers."""
from tensor_cast.core.quantization.config import (
_filter_action_kwargs,
_set_linear_patterns,
create_quant_config,
)
from tensor_cast.core.quantization.datatypes import QuantizeLinearAction
from tensor_cast.model_config import QuantConfig
from tensor_cast.quantize_utils import LinearQuantType, QuantGranularity, get_quant_config
class TestQuantizationConfigHelpers:
    """Tests for the kwarg-filtering, pattern-registration and config-building helpers."""

    def test_filter_action_kwargs_strips_mxfp4_only_fields_for_fp8(self):
        """With the FP8 action, only ``extra_flag`` is expected to survive filtering."""
        mixed_kwargs = dict(
            weight_group_size=32,
            weight_quant_granularity=QuantGranularity.PER_GROUP,
            extra_flag=True,
        )

        result = _filter_action_kwargs(QuantizeLinearAction.FP8, mixed_kwargs)

        assert result == {"extra_flag": True}

    def test_filter_action_kwargs_preserves_all_fields_for_mxfp4(self):
        """With the MXFP4 action, the filtered kwargs are expected to equal the input."""
        group_kwargs = dict(
            weight_group_size=32,
            weight_quant_granularity=QuantGranularity.PER_GROUP,
        )

        result = _filter_action_kwargs(QuantizeLinearAction.MXFP4, group_kwargs)

        assert result == group_kwargs

    def test_set_linear_patterns_registers_config_per_pattern(self):
        """Each supplied pattern should map to an FP8 entry in ``linear_configs``."""
        expected_patterns = ("*.self_attn.*", "*.mlp.gate_proj")
        config = QuantConfig()

        # Hand the helper its own list so the expectations below cannot be
        # affected by anything it does to its argument.
        _set_linear_patterns(config, list(expected_patterns), QuantizeLinearAction.FP8)

        for pattern in expected_patterns:
            assert config.linear_configs[pattern].quant_type == LinearQuantType.FP8

    def test_create_quant_config_mxfp4_experts_fp8_backbone_override(self):
        """MXFP4 with an FP8 backbone override: experts resolve to MXFP4, the rest to FP8."""
        config = create_quant_config(
            quantize_linear_action=QuantizeLinearAction.MXFP4,
            quantize_backbone_linear_action=QuantizeLinearAction.FP8,
            weight_group_size=32,
            weight_quant_granularity=QuantGranularity.PER_GROUP,
        )

        def resolve(layer_name):
            # Every lookup in this test shares the config and the "default_dit" argument.
            return get_quant_config(layer_name, config, "default_dit")

        attention_cfg = resolve("model.layers.3.self_attn.q_a_proj")
        shared_expert_cfg = resolve("model.layers.3.mlp.shared_experts.gate_proj")
        routed_expert_cfg = resolve("model.layers.3.mlp.experts.5.gate_proj")

        # Quantization type per layer family.
        assert attention_cfg.quant_type == LinearQuantType.FP8
        assert shared_expert_cfg.quant_type == LinearQuantType.FP8
        assert routed_expert_cfg.quant_type == LinearQuantType.MXFP4

        # The FP8 backbone entry is expected to report per-tensor dynamic granularity,
        # while the routed expert keeps the group settings passed to the factory.
        assert attention_cfg.dynamic_quant_granularity == QuantGranularity.PER_TENSOR
        assert routed_expert_cfg.weight_group_size == 32
        assert routed_expert_cfg.weight_quant_granularity == QuantGranularity.PER_GROUP

    def test_create_quant_config_linear_only_keeps_original_behavior(self):
        """Passing only the linear action (W8A8 dynamic) should yield a W8A8 layer config."""
        config = create_quant_config(
            quantize_linear_action=QuantizeLinearAction.W8A8_DYNAMIC,
        )

        gate_proj_cfg = get_quant_config("model.layers.0.mlp.gate_proj", config, "default_dit")

        assert gate_proj_cfg.quant_type == LinearQuantType.W8A8