from types import SimpleNamespace
import pytest
import torch
import torch.nn as nn
from amct_pytorch.common.optimization.base_solver import BaseSolver
class _ConcreteSolver(BaseSolver):
def solve(self, calibration_data):
return calibration_data
def _args(quant_target=None):
return SimpleNamespace(quant_target=quant_target or ["mlp"])
def test_cannot_instantiate_abstract_base_solver():
with pytest.raises(TypeError):
BaseSolver(_args(), 0, nn.Linear(4, 4))
def test_concrete_solver_records_attributes():
solver = _ConcreteSolver(args=_args(["attn-linear"]), layer_idx=3, model=nn.Linear(4, 4))
assert solver.layer_idx == 3
assert solver.quant_target == ["attn-linear"]
assert solver.optimizer is None and solver.lr_scheduler is None
assert solver.max_iters == 100
assert solver.current_iter == 0
assert solver.granularity == "block"
def test_finalize_uses_export_ptq_params_when_available():
class _Exportable(nn.Module):
def export_ptq_params(self):
return {"k": torch.tensor([1.0])}
solver = _ConcreteSolver(args=_args(), layer_idx=0, model=_Exportable())
out = solver.finalize()
assert torch.equal(out["k"], torch.tensor([1.0]))
def test_finalize_falls_back_to_trainable_params():
m = nn.Module()
m.trainable = nn.Parameter(torch.ones(2))
m.frozen = nn.Parameter(torch.zeros(2), requires_grad=False)
solver = _ConcreteSolver(args=_args(), layer_idx=0, model=m)
out = solver.finalize()
assert set(out) == {"trainable"}
assert torch.equal(out["trainable"], torch.ones(2))
def test_step_calls_optimizer_and_scheduler_and_increments_iter():
class _StubOpt:
def __init__(self):
self.stepped = 0
self.zeroed = 0
def step(self):
self.stepped += 1
def zero_grad(self):
self.zeroed += 1
class _StubSched:
def __init__(self):
self.stepped = 0
def step(self):
self.stepped += 1
opt = _StubOpt()
sched = _StubSched()
solver = _ConcreteSolver(args=_args(), layer_idx=0, model=nn.Linear(4, 4))
solver.optimizer = opt
solver.lr_scheduler = sched
solver.step()
solver.step()
assert opt.stepped == 2
assert opt.zeroed == 2
assert sched.stepped == 2
assert solver.current_iter == 2
def test_step_no_op_when_optimizer_and_scheduler_unset():
solver = _ConcreteSolver(args=_args(), layer_idx=0, model=nn.Linear(4, 4))
solver.step()
assert solver.current_iter == 1
def test_unimplemented_hooks_default_to_no_op():
solver = _ConcreteSolver(args=_args(), layer_idx=0, model=nn.Linear(4, 4))
assert solver.forward_once() is None
assert solver.origin_forward() is None
assert solver.quant_forward() is None
def test_base_solver_granularity_class_attribute():
from amct_pytorch.common.optimization.base_solver import (
BaseSolver as base_solver_cls,
)
assert base_solver_cls.granularity == "block"