import sys
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from amct_pytorch.common.optimization.blockwise_solver import BlockwiseSolver
# Register a MagicMock under the module name "torch_npu" so that any later
# `import torch_npu` resolves to the mock instead of the real package.
# NOTE(review): this assignment runs after the `amct_pytorch` import above; if
# that import needs `torch_npu` at import time the stub is installed too late
# -- confirm the intended ordering.
sys.modules["torch_npu"] = MagicMock()
def _args(epochs=1, base_lr=1e-3, optimizer="adam"):
return SimpleNamespace(
quant_target=["mlp"],
device=torch.device("cpu"),
epochs=epochs,
base_lr=base_lr,
optimizer=optimizer,
weight_decay=0.0,
momentum=0.9,
lr_scheduler="none",
)
def test_reconstruction_loss_is_mse():
    """``_reconstruction_loss`` must agree with ``mse_loss`` on the same tensors."""
    prediction = torch.tensor([[1.0, 2.0]])
    target = torch.tensor([[2.0, 4.0]])
    solver = BlockwiseSolver(_args(), layer_idx=0, model=nn.Linear(4, 4))

    reference = torch.nn.functional.mse_loss(prediction, target)
    actual = solver._reconstruction_loss(prediction, target)

    assert torch.allclose(actual, reference)
class _ModuleWithTrainable(nn.Module):
def __init__(self):
super().__init__()
self.weight = nn.Parameter(torch.zeros(4))
self.bias = nn.Parameter(torch.zeros(4))
def trainable_params(self):
return [self.weight, self.bias]
def test_collect_trainable_param_groups_marks_grad_and_dedups():
    """Aliased submodules contribute their parameters once, all with grad enabled.

    The container holds two ``_ModuleWithTrainable`` instances plus a second
    attribute pointing at the first one, so four distinct parameters are
    expected in a single group whose ``lr`` is ten times ``base_lr``.
    """
    first = _ModuleWithTrainable()
    container = nn.Module()
    container.a = first
    container.b = _ModuleWithTrainable()
    container.shared_alias = first  # same object as ``container.a``

    solver = BlockwiseSolver(_args(), layer_idx=0, model=container)
    param_groups = solver._collect_trainable_param_groups(container)

    assert len(param_groups) == 1
    group = param_groups[0]
    collected = group["params"]
    assert len(collected) == 4
    assert all(param.requires_grad for param in collected)
    expected_lr = _args().base_lr * 10
    assert group["lr"] == pytest.approx(expected_lr)
def test_collect_trainable_param_groups_returns_empty_when_no_trainable():
    """A plain ``nn.Linear`` (no ``trainable_params`` method) yields a falsy result."""
    plain_layer = nn.Linear(4, 4)
    solver = BlockwiseSolver(_args(), layer_idx=0, model=plain_layer)

    param_groups = solver._collect_trainable_param_groups(plain_layer)

    assert not param_groups
def test_collect_trainable_param_groups_freezes_other_params_first():
    """After collection, no parameter of a plain ``nn.Linear`` requires grad."""
    plain_layer = nn.Linear(4, 4)
    solver = BlockwiseSolver(_args(), layer_idx=0, model=plain_layer)

    # Return value is irrelevant here; only the effect on requires_grad is checked.
    solver._collect_trainable_param_groups(plain_layer)

    still_trainable = [p for p in plain_layer.parameters() if p.requires_grad]
    assert not still_trainable
def test_collect_trainable_param_groups_skips_none_and_duplicates():
    """``None`` entries and repeated parameters collapse to a single entry."""

    class _ModuleWithFlawedTrainable(nn.Module):
        """Module whose ``trainable_params`` returns ``None`` and a duplicate."""

        def __init__(self):
            super().__init__()
            self.w = nn.Parameter(torch.zeros(4))
            self.b = nn.Parameter(torch.zeros(4))

        def trainable_params(self):
            # ``b`` is registered on the module but is not listed here.
            repeated = self.w
            return [None, repeated, repeated]

    container = nn.Module()
    container.flawed = _ModuleWithFlawedTrainable()
    solver = BlockwiseSolver(_args(), layer_idx=0, model=container)

    param_groups = solver._collect_trainable_param_groups(container)

    assert len(param_groups) == 1
    only_group = param_groups[0]
    assert len(only_group["params"]) == 1
class _TrainableLinear(nn.Linear):
def trainable_params(self):
return [self.weight, self.bias]
def _build_loader(in_dim=4, out_dim=4, n=4):
inps = torch.randn(n, in_dim)
targets = torch.randn(n, out_dim)
return DataLoader(TensorDataset(inps, targets), batch_size=2)
def test_solve_runs_one_epoch_and_advances_iter():
    """One epoch over the default loader leaves ``current_iter`` at 2.

    Also checks that an optimizer was created and that no LR scheduler is set
    for the ``lr_scheduler="none"`` configuration produced by ``_args``.
    """
    trainable_layer = _TrainableLinear(4, 4)
    single_epoch_args = _args(epochs=1)
    solver = BlockwiseSolver(single_epoch_args, layer_idx=0, model=trainable_layer)

    loader = _build_loader()
    solver.solve(loader)

    assert solver.current_iter == 2
    assert solver.optimizer is not None
    assert solver.lr_scheduler is None
def test_solve_returns_model_unchanged_when_no_trainable_params():
    """``solve`` hands back the same layer object and never builds an optimizer."""
    plain_layer = nn.Linear(4, 4)
    solver = BlockwiseSolver(_args(), layer_idx=0, model=plain_layer)

    returned = solver.solve(_build_loader())

    assert returned is plain_layer
    assert solver.optimizer is None
def test_optimize_block_raises_on_uninitialized_optimizer():
    """Calling ``_optimize_block`` before an optimizer is set raises ``RuntimeError``."""
    plain_layer = nn.Linear(4, 4)
    solver = BlockwiseSolver(_args(), layer_idx=0, model=plain_layer)
    expected_message = "Optimizer has not been initialized"

    with pytest.raises(RuntimeError, match=expected_message):
        solver._optimize_block(_build_loader())
def test_optimize_block_raises_on_loader_yielding_single_tensor():
    """A loader whose dataset wraps only one tensor is rejected with ``ValueError``."""
    trainable_layer = _TrainableLinear(4, 4)
    config = _args()
    solver = BlockwiseSolver(config, layer_idx=0, model=trainable_layer)
    param_groups = solver._collect_trainable_param_groups(trainable_layer)

    from amct_pytorch.common.optimization.factory import build_optimizer

    # Install the optimizer by hand so ``_optimize_block`` can be called directly.
    solver.optimizer = build_optimizer(_args(), param_groups)

    inputs_only = TensorDataset(torch.randn(2, 4))
    bad_loader = DataLoader(inputs_only, batch_size=1)

    with pytest.raises(ValueError, match="Expected PTQ dataloader to yield"):
        solver._optimize_block(bad_loader)
def test_optimize_block_returns_zero_when_loader_empty():
    """``_optimize_block`` returns ``0.0`` when the loader has no samples."""
    trainable_layer = _TrainableLinear(4, 4)
    solver = BlockwiseSolver(_args(), layer_idx=0, model=trainable_layer)
    param_groups = solver._collect_trainable_param_groups(trainable_layer)

    from amct_pytorch.common.optimization.factory import build_optimizer

    # Install the optimizer by hand so ``_optimize_block`` can be called directly.
    solver.optimizer = build_optimizer(_args(), param_groups)

    no_inputs = torch.empty(0, 4)
    no_targets = torch.empty(0, 4)
    empty_loader = DataLoader(TensorDataset(no_inputs, no_targets), batch_size=1)

    result = solver._optimize_block(empty_loader)
    assert result == 0.0
def test_optimize_block_handles_tuple_outputs_from_model():
    """``_optimize_block`` runs on a model whose ``forward`` returns a tuple.

    The model returns ``(tensor, "metadata")``; after one pass over the default
    loader ``current_iter`` is expected to be 2.
    """

    class _TupleModel(nn.Module):
        """Wraps one ``nn.Linear`` and returns its output paired with a string."""

        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 4)

        def forward(self, x):
            projected = self.linear(x)
            return projected, "metadata"

        def trainable_params(self):
            inner = self.linear
            return [inner.weight, inner.bias]

    tuple_model = _TupleModel()
    solver = BlockwiseSolver(_args(), layer_idx=0, model=tuple_model)
    param_groups = solver._collect_trainable_param_groups(tuple_model)

    from amct_pytorch.common.optimization.factory import build_optimizer

    # Install the optimizer by hand so ``_optimize_block`` can be called directly.
    solver.optimizer = build_optimizer(_args(), param_groups)

    solver._optimize_block(_build_loader())

    assert solver.current_iter == 2