from types import SimpleNamespace
import pytest
import torch
import torch.nn as nn
from amct_pytorch.algorithms.registry_factory import ALGO_REGISTRY
from amct_pytorch.quantization.dtypes import register_dtype
from amct_pytorch.quantization.modules.quant_base import (
ActivationQuantizer,
WeightQuantizer,
build_algorithms_by_target,
get_algo_names_by_target,
set_act_quantizer_state,
set_quantizer_state,
set_weight_quantizer_state,
)
# Module-level setup: runs once at import time, before any test executes.
# NOTE(review): presumably registers the quantization dtypes the quantizers
# below rely on -- confirm against amct_pytorch.quantization.dtypes.
register_dtype()
# Registry names for the throwaway algorithms that individual tests register
# in ALGO_REGISTRY and pop again in their ``finally`` blocks.
UT_OBSERVE_ALGO = '_ut_observe'
UT_DOUBLE_ALGO = '_ut_double'
UT_QUANT_HOOK_ALGO = '_ut_quant_hook'
UT_QH_A_ALGO = '_ut_qh_a'
UT_QH_EXPORT_ALGO = '_ut_qh_export'
# NOTE(review): unlike its siblings this constant has no ``_ALGO`` suffix.
UT_QH_B = '_ut_qh_b'
def _args(algos=(), quant_dtype="int", w_bits=8, quant_target=()):
return SimpleNamespace(
algos=list(algos),
quant_dtype=quant_dtype,
w_bits=w_bits,
quant_target=list(quant_target),
)
@pytest.fixture
def _ephemeral_algo():
    """Yield the name of a throwaway algorithm registered for two targets.

    The algorithm is registered for the "weight" and "activation" targets
    and its forward doubles the input. Once the test is done the entry is
    popped from ``ALGO_REGISTRY._items``.
    """
    algo_name = "_ut_lwc_like"

    @ALGO_REGISTRY.register(name=algo_name, targets=("weight", "activation"))
    class _Algo(nn.Module):
        def __init__(self, args):
            super().__init__()
            self.args = args

        def forward(self, x):
            return x * 2

    yield algo_name

    # Teardown: drop the registration so it does not leak into other tests.
    ALGO_REGISTRY._items.pop(algo_name, None)
def test_get_algo_names_filters_by_target(_ephemeral_algo):
    """The algorithm is listed for its declared targets and for no other."""
    cfg = _args(algos=[_ephemeral_algo])
    expected = [_ephemeral_algo]

    for declared_target in ("weight", "activation"):
        assert get_algo_names_by_target(cfg, declared_target) == expected

    assert not get_algo_names_by_target(cfg, "structure")
def test_get_algo_names_raises_on_algo_without_targets():
    """An algorithm registered without ``targets`` triggers a ValueError."""
    algo_name = "_ut_no_targets"
    register = ALGO_REGISTRY.register(name=algo_name)
    register(type("T", (), {}))

    try:
        cfg = _args(algos=[algo_name])
        with pytest.raises(ValueError, match="missing registry metadata 'targets'"):
            get_algo_names_by_target(cfg, "weight")
    finally:
        # Always unregister, even when the assertion above fails.
        ALGO_REGISTRY._items.pop(algo_name, None)
def test_build_algorithms_returns_module_dict_for_non_structure(_ephemeral_algo):
    """For the "weight" target the builder returns a ModuleDict keyed by algorithm name."""
    cfg = _args(algos=[_ephemeral_algo])
    built = build_algorithms_by_target(cfg, "weight")

    assert isinstance(built, nn.ModuleDict)
    assert _ephemeral_algo in built
def test_build_algorithms_structure_returns_none_when_no_match(_ephemeral_algo):
    """With no algorithm declaring "structure", the builder returns None."""
    cfg = _args(algos=[_ephemeral_algo])
    built = build_algorithms_by_target(cfg, "structure")
    assert built is None
def test_build_algorithms_structure_returns_single_algorithm():
    """A lone "structure" algorithm is returned as an instance of its class."""
    algo_name = "_ut_struct_one"

    @ALGO_REGISTRY.register(name=algo_name, targets=("structure",))
    class _Algo(nn.Module):
        def __init__(self, args, ctx):
            super().__init__()

    try:
        cfg = _args(algos=[algo_name])
        ctx = SimpleNamespace()
        built = build_algorithms_by_target(cfg, "structure", ctx)
        assert isinstance(built, _Algo)
    finally:
        # Always unregister the temporary algorithm.
        ALGO_REGISTRY._items.pop(algo_name, None)
def test_build_algorithms_structure_raises_on_multiple_matches():
    """Two algorithms both targeting "structure" are rejected with ValueError."""
    first, second = "_ut_struct_a", "_ut_struct_b"

    for algo_name in (first, second):

        @ALGO_REGISTRY.register(name=algo_name, targets=("structure",))
        class _Algo(nn.Module):
            def __init__(self, args, ctx):
                super().__init__()

    try:
        with pytest.raises(ValueError, match="Only one 'structure' algorithm"):
            build_algorithms_by_target(
                _args(algos=[first, second]),
                "structure",
                SimpleNamespace(),
            )
    finally:
        # Remove both registrations regardless of the test outcome.
        ALGO_REGISTRY._items.pop(first, None)
        ALGO_REGISTRY._items.pop(second, None)
def _model_with_quantizers():
    """Return a bare module holding an activation quantizer, a weight quantizer and a Linear.

    The children are attached as ``act``, ``weight`` and ``linear``.
    """
    model = nn.Module()
    model.act = ActivationQuantizer(_args(), bits=8)
    model.weight = WeightQuantizer(_args(), w_bits=8)
    # A non-quantizer child, alongside the two quantizers.
    model.linear = nn.Linear(4, 4)
    return model
@pytest.mark.parametrize("flag", [True, False])
def test_set_quantizer_state_toggles_both_kinds(flag):
    """set_quantizer_state applies ``enable`` to both quantizer kinds."""
    model = _model_with_quantizers()

    set_quantizer_state(model, enable=flag)

    assert model.act.enable is flag
    assert model.weight.enable is flag
def test_set_weight_quantizer_state_only_touches_weight():
    """Enabling weight quantizers leaves the activation quantizer disabled."""
    model = _model_with_quantizers()

    set_weight_quantizer_state(model, enable=True)

    assert model.weight.enable is True
    assert model.act.enable is False
def test_set_act_quantizer_state_only_touches_activation():
    """Enabling activation quantizers leaves the weight quantizer disabled."""
    model = _model_with_quantizers()

    set_act_quantizer_state(model, enable=True)

    assert model.act.enable is True
    assert model.weight.enable is False
def test_activation_quantizer_forward_passthrough_when_disabled():
    """A freshly built activation quantizer returns its input unchanged."""
    quantizer = ActivationQuantizer(_args(), bits=8)
    sample = torch.randn(2, 32)

    result = quantizer(sample)

    assert torch.equal(result, sample)
def test_activation_quantizer_forward_quantizes_when_enabled():
    """With ``enable`` set, the output keeps the input's shape and dtype."""
    quantizer = ActivationQuantizer(_args(), bits=8)
    quantizer.enable = True
    sample = torch.randn(2, 32, dtype=torch.float32)

    result = quantizer(sample)

    assert result.shape == sample.shape
    assert result.dtype == sample.dtype
def test_activation_quantizer_trainable_params_collects_from_algorithms(_ephemeral_algo):
    """trainable_params() is empty when the only algorithm is the fixture one.

    The fixture algorithm defines just ``__init__`` and ``forward``.
    """
    cfg = _args(algos=[_ephemeral_algo])
    quantizer = ActivationQuantizer(cfg, bits=8)
    assert not quantizer.trainable_params()
def test_activation_quantizer_deploy_hooks_are_no_ops():
    """deploy() and load_deploy() on an activation quantizer both return None."""
    quantizer = ActivationQuantizer(_args(), bits=8)

    deploy_result = quantizer.deploy()
    assert deploy_result is None

    load_result = quantizer.load_deploy(scale=1.0, zero=0.0)
    assert load_result is None
def test_weight_quantizer_forward_passthrough_when_disabled():
    """A freshly built weight quantizer returns the weight unchanged."""
    quantizer = WeightQuantizer(_args(w_bits=8), w_bits=8)
    weight = torch.randn(4, 8)

    result = quantizer(weight)

    assert torch.equal(result, weight)
def test_weight_quantizer_forward_quantizes_when_enabled():
    """With ``enable`` set, the output keeps the weight's shape."""
    quantizer = WeightQuantizer(_args(w_bits=8), w_bits=8)
    quantizer.enable = True
    weight = torch.randn(4, 8, dtype=torch.float32)

    result = quantizer(weight)

    assert result.shape == weight.shape
def test_weight_quantizer_observe_input_dispatches_to_algorithms_with_hook():
    """observe_input() forwards (x, weight) once to an algorithm defining that hook."""
    recorded = []

    @ALGO_REGISTRY.register(name=UT_OBSERVE_ALGO, targets=("weight",))
    class _Obs(nn.Module):
        def __init__(self, args, *_):
            super().__init__()

        def observe_input(self, x, weight):
            # Record every call so the test can inspect the arguments.
            recorded.append((x, weight))

    try:
        cfg = _args(algos=[UT_OBSERVE_ALGO], w_bits=8)
        quantizer = WeightQuantizer(cfg, w_bits=8)
        activations = torch.zeros(1, 4)
        weight = torch.ones(4, 4)

        quantizer.observe_input(activations, weight)

        assert len(recorded) == 1
        seen_x, seen_w = recorded[0]
        assert torch.equal(seen_x, activations)
        assert torch.equal(seen_w, weight)
    finally:
        # Always unregister the temporary algorithm.
        ALGO_REGISTRY._items.pop(UT_OBSERVE_ALGO, None)
def test_weight_quantizer_algo_forward_chains_non_quantize_algos():
    """algo_forward() applies a forward-only algorithm and reports no quantize algorithm."""

    @ALGO_REGISTRY.register(name=UT_DOUBLE_ALGO, targets=("weight",))
    class _Double(nn.Module):
        def __init__(self, args, *_):
            super().__init__()

        def forward(self, x):
            return x * 2

    try:
        cfg = _args(algos=[UT_DOUBLE_ALGO], w_bits=8)
        quantizer = WeightQuantizer(cfg, w_bits=8)

        transformed, quant_algo = quantizer.algo_forward(torch.ones(1, 4))

        assert quant_algo is None
        expected = torch.full((1, 4), 2.0)
        assert torch.equal(transformed, expected)
    finally:
        # Always unregister the temporary algorithm.
        ALGO_REGISTRY._items.pop(UT_DOUBLE_ALGO, None)
def test_weight_quantizer_algo_forward_picks_quantize_hook_separately():
    """An algorithm with a ``quantize`` hook is returned by algo_forward(), not applied.

    The tensor comes back unchanged and the second return value is the
    algorithm instance.
    """

    @ALGO_REGISTRY.register(name=UT_QUANT_HOOK_ALGO, targets=("weight",))
    class _Q(nn.Module):
        def __init__(self, args, *_):
            super().__init__()

        def quantize(self, x, quant_obj):
            return x * 0

    try:
        cfg = _args(algos=[UT_QUANT_HOOK_ALGO], w_bits=8)
        quantizer = WeightQuantizer(cfg, w_bits=8)
        weight = torch.ones(1, 4)

        transformed, quant_algo = quantizer.algo_forward(weight)

        assert torch.equal(transformed, weight)
        assert isinstance(quant_algo, _Q)
    finally:
        # Always unregister the temporary algorithm.
        ALGO_REGISTRY._items.pop(UT_QUANT_HOOK_ALGO, None)
def test_weight_quantizer_algo_forward_rejects_multiple_quantize_hooks():
    """Two weight algorithms that both define ``quantize`` make algo_forward() raise."""
    hook_names = (UT_QH_A_ALGO, UT_QH_B)

    for algo_name in hook_names:

        @ALGO_REGISTRY.register(name=algo_name, targets=("weight",))
        class _Q(nn.Module):
            def __init__(self, args, *_):
                super().__init__()

            def quantize(self, x, q):
                return x

    try:
        cfg = _args(algos=[UT_QH_A_ALGO, UT_QH_B], w_bits=8)
        quantizer = WeightQuantizer(cfg, w_bits=8)
        with pytest.raises(ValueError, match="Only one weight algorithm"):
            quantizer.algo_forward(torch.zeros(1, 4))
    finally:
        # Remove both registrations regardless of the test outcome.
        for algo_name in hook_names:
            ALGO_REGISTRY._items.pop(algo_name, None)
def test_weight_quantizer_export_deploy_uses_quant_obj_export():
    """export_deploy() output contains the "qweight" and "weight_scale" entries."""
    quantizer = WeightQuantizer(_args(w_bits=8), w_bits=8)

    exported = quantizer.export_deploy(torch.randn(4, 8))

    assert "qweight" in exported
    assert "weight_scale" in exported
def test_weight_quantizer_export_deploy_rejects_quantize_hook_path():
    """export_deploy() raises NotImplementedError when an algorithm supplies ``quantize``."""

    @ALGO_REGISTRY.register(name=UT_QH_EXPORT_ALGO, targets=("weight",))
    class _Q(nn.Module):
        def __init__(self, args, *_):
            super().__init__()

        def quantize(self, x, q):
            return x

    try:
        cfg = _args(algos=[UT_QH_EXPORT_ALGO], w_bits=8)
        quantizer = WeightQuantizer(cfg, w_bits=8)
        with pytest.raises(NotImplementedError, match="custom weight quantize"):
            quantizer.export_deploy(torch.zeros(4, 8))
    finally:
        # Always unregister the temporary algorithm.
        ALGO_REGISTRY._items.pop(UT_QH_EXPORT_ALGO, None)
def test_build_algorithms_raises_when_algo_declares_targets_but_mismatches():
    """A weight-only algorithm requested for "activation" yields an empty ModuleDict.

    NOTE(review): despite the "raises" in the name, this test expects no
    exception -- it checks that the mismatching algorithm is left out.
    """
    algo_name = "_ut_struct_mis"

    @ALGO_REGISTRY.register(name=algo_name, targets=("weight",))
    class _Algo(nn.Module):
        def __init__(self, args, ctx=None):
            super().__init__()

    try:
        cfg = _args(algos=[algo_name])
        built = build_algorithms_by_target(cfg, "activation")
        assert isinstance(built, nn.ModuleDict)
        assert len(built) == 0
    finally:
        # Always unregister the temporary algorithm.
        ALGO_REGISTRY._items.pop(algo_name, None)
def test_activation_quantizer_trainable_params_returns_params_from_algo():
    """The quantizer reports the single parameter its algorithm exposes."""
    algo_name = "_ut_act_tp"

    @ALGO_REGISTRY.register(name=algo_name, targets=("activation",))
    class _AlgoWithParams(nn.Module):
        def __init__(self, args):
            super().__init__()
            self.p = nn.Parameter(torch.tensor(1.0))

        def forward(self, x):
            return x

        def trainable_params(self):
            return [self.p]

    try:
        cfg = _args(algos=[algo_name])
        quantizer = ActivationQuantizer(cfg, bits=8)
        collected = quantizer.trainable_params()
        assert len(collected) == 1
    finally:
        # Always unregister the temporary algorithm.
        ALGO_REGISTRY._items.pop(algo_name, None)
def test_activation_quantizer_forward_applies_algo_when_enabled():
    """An enabled activation quantizer with an algorithm keeps dtype and shape.

    Only dtype and shape are checked; the output values are not asserted.
    """
    algo_name = "_ut_act_fwd"

    @ALGO_REGISTRY.register(name=algo_name, targets=("activation",))
    class _DoubleAlgo(nn.Module):
        def __init__(self, args):
            super().__init__()

        def forward(self, x):
            return x * 2

    try:
        cfg = _args(algos=[algo_name])
        quantizer = ActivationQuantizer(cfg, bits=8)
        quantizer.enable = True
        sample = torch.tensor([1.0, 2.0, 3.0])

        result = quantizer(sample)

        assert result.dtype == sample.dtype
        assert result.shape == sample.shape
    finally:
        # Always unregister the temporary algorithm.
        ALGO_REGISTRY._items.pop(algo_name, None)
def test_weight_quantizer_trainable_params_returns_params_from_algo():
    """The weight quantizer reports the single parameter its algorithm exposes."""
    algo_name = "_ut_wt_tp"

    @ALGO_REGISTRY.register(name=algo_name, targets=("weight",))
    class _WtAlgoWithParams(nn.Module):
        def __init__(self, args, *_):
            super().__init__()
            self.p = nn.Parameter(torch.tensor(2.0))

        def forward(self, x):
            return x

        def trainable_params(self):
            return [self.p]

    try:
        cfg = _args(algos=[algo_name], w_bits=8)
        quantizer = WeightQuantizer(cfg, w_bits=8)
        collected = quantizer.trainable_params()
        assert len(collected) == 1
    finally:
        # Always unregister the temporary algorithm.
        ALGO_REGISTRY._items.pop(algo_name, None)
def test_weight_quantizer_forward_uses_quantize_algo_when_enabled():
    """An enabled weight quantizer returns what the algorithm's ``quantize`` hook produces."""
    algo_name = "_ut_wt_qalgo"

    @ALGO_REGISTRY.register(name=algo_name, targets=("weight",))
    class _QAlgo(nn.Module):
        def __init__(self, args, *_):
            super().__init__()

        def quantize(self, x, quant_obj):
            # Easily recognisable transform so the test can tell the hook ran.
            return x * 100

    try:
        cfg = _args(algos=[algo_name], w_bits=8)
        quantizer = WeightQuantizer(cfg, w_bits=8)
        quantizer.enable = True
        weight = torch.tensor([1.0, 2.0])

        result = quantizer(weight)

        assert torch.equal(result, weight * 100)
    finally:
        # Always unregister the temporary algorithm.
        ALGO_REGISTRY._items.pop(algo_name, None)
def test_weight_quantizer_export_deploy_rejects_unsupported_dtype(monkeypatch):
    """export_deploy() raises when the quant object has no usable ``export_deploy``.

    The hook is blanked out through ``monkeypatch`` rather than by direct
    assignment, so the attribute is restored at teardown and the change
    cannot leak out of this test if the quant object is shared.
    """
    wq = WeightQuantizer(_args(w_bits=8), w_bits=8)
    # raising=False keeps the tolerance of a plain assignment: it also works
    # when the quant object does not define export_deploy at all.
    monkeypatch.setattr(wq.quant_obj, "export_deploy", None, raising=False)

    with pytest.raises(NotImplementedError, match="does not implement export_deploy"):
        wq.export_deploy(torch.randn(4, 8))
def test_build_algorithms_raises_with_missing_targets():
    """Building with an algorithm registered without ``targets`` raises ValueError."""
    algo_name = "_ut_missing_targets"
    register = ALGO_REGISTRY.register(name=algo_name)
    register(type("T", (), {}))

    try:
        cfg = _args(algos=[algo_name])
        with pytest.raises(ValueError, match="missing registry metadata"):
            build_algorithms_by_target(cfg, "weight")
    finally:
        # Always unregister the temporary algorithm.
        ALGO_REGISTRY._items.pop(algo_name, None)
def test_build_algorithms_raises_when_target_not_in_algo_targets():
    """A weight-only algorithm requested for "activation" yields an empty ModuleDict.

    NOTE(review): despite the "raises" in the name, this test expects no
    exception -- it checks that the mismatching algorithm is left out.
    """
    algo_name = "_ut_struct_nonmatch"

    @ALGO_REGISTRY.register(name=algo_name, targets=("weight",))
    class _Algo(nn.Module):
        def __init__(self, args, ctx=None):
            super().__init__()

    try:
        cfg = _args(algos=[algo_name])
        built = build_algorithms_by_target(cfg, "activation")
        assert isinstance(built, nn.ModuleDict)
        assert len(built) == 0
    finally:
        # Always unregister the temporary algorithm.
        ALGO_REGISTRY._items.pop(algo_name, None)
def test_build_algorithms_by_target_raises_on_missing_targets_metadata(monkeypatch):
    """build_algorithms_by_target raises when a selected item has no ``targets`` metadata.

    Name filtering is stubbed to always return one fake algorithm, and the
    registry lookup is stubbed to return an item with empty metadata, so the
    builder's own metadata check is what gets exercised.
    """
    # The module object itself is needed in order to patch one of its
    # attributes; the imports at the top of this file only pull individual
    # names out of it.
    from amct_pytorch.quantization.modules import quant_base as quant_base_mod

    monkeypatch.setattr(
        quant_base_mod,
        "get_algo_names_by_target",
        lambda args, target: ["fake_algo"],
    )
    monkeypatch.setattr(
        ALGO_REGISTRY,
        "get_item",
        lambda name: SimpleNamespace(metadata={}, target=lambda *a: None),
    )

    args = SimpleNamespace()
    with pytest.raises(ValueError, match="missing registry metadata"):
        quant_base_mod.build_algorithms_by_target(args, "mlp")
def test_build_algorithms_by_target_raises_on_mismatched_target(monkeypatch):
    """build_algorithms_by_target raises when a selected item's targets exclude the request.

    Name filtering is stubbed to always return one fake algorithm, and the
    registry lookup is stubbed to return an item whose targets are
    ``("attn",)``, while the test asks for "mlp".
    """
    # The module object itself is needed in order to patch one of its
    # attributes; the imports at the top of this file only pull individual
    # names out of it.
    from amct_pytorch.quantization.modules import quant_base as quant_base_mod

    monkeypatch.setattr(
        quant_base_mod,
        "get_algo_names_by_target",
        lambda args, target: ["fake_algo"],
    )
    monkeypatch.setattr(
        ALGO_REGISTRY,
        "get_item",
        lambda name: SimpleNamespace(
            metadata={"targets": ("attn",)}, target=lambda *a: None
        ),
    )

    args = SimpleNamespace()
    with pytest.raises(ValueError, match="cannot be used for target"):
        quant_base_mod.build_algorithms_by_target(args, "mlp")