from types import SimpleNamespace
import pytest
import torch
from amct_pytorch.algorithms.quant.omniquant import OmniQuant
def _make(dim_size=4):
return OmniQuant(args=SimpleNamespace(), ctx=SimpleNamespace(dim_size=dim_size))
def test_initial_state():
om = _make(dim_size=8)
assert om.dim == 8
assert om.log_scale.shape == (1, 8)
assert torch.equal(om.log_scale.data, torch.zeros(1, 8))
assert om.is_observe is False
def test_forward_with_zero_log_scale_is_passthrough_division():
om = _make(dim_size=4)
x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
y = om(x)
assert torch.allclose(y, x)
def test_forward_inv_t_multiplies_by_scale():
om = _make(dim_size=4)
with torch.no_grad():
om.log_scale.fill_(0.0)
x = torch.tensor([[2.0, 4.0, 6.0, 8.0]])
y = om(x, inv_t=True)
assert torch.allclose(y, x)
def test_forward_with_log_scale_2_divides_input_by_e2():
om = _make(dim_size=4)
with torch.no_grad():
om.log_scale.fill_(2.0)
x = torch.ones(1, 4)
y_div = om(x, inv_t=False)
y_mul = om(x, inv_t=True)
expected_scale = torch.exp(torch.tensor(2.0))
assert torch.allclose(y_div, x / expected_scale, atol=1e-4)
assert torch.allclose(y_mul, x * expected_scale, atol=1e-4)
def test_observe_mode_updates_log_scale_with_running_max():
om = _make(dim_size=2)
om.is_observe = True
out = om(torch.tensor([[3.0, 5.0]]))
assert torch.equal(out, torch.tensor([[3.0, 5.0]]))
assert torch.allclose(om.log_scale.data, torch.log(torch.tensor([[3.0, 5.0]])), atol=1e-4)
om(torch.tensor([[1.0, 6.0]]))
expected = torch.log(torch.tensor([[3.0, 6.0]]))
assert torch.allclose(om.log_scale.data, expected, atol=1e-4)
def test_get_scale_clamps_to_finite_range():
om = _make(dim_size=2)
with torch.no_grad():
om.log_scale.fill_(50.0)
scale = om._get_scale(dtype=torch.float32, device=torch.device("cpu"))
assert scale.max().item() == pytest.approx(1e4)
def test_get_scale_clamps_low_floor():
om = _make(dim_size=2)
with torch.no_grad():
om.log_scale.fill_(-50.0)
scale = om._get_scale(dtype=torch.float32, device=torch.device("cpu"))
assert scale.min().item() == pytest.approx(1e-4)
def test_export_load_round_trip():
om = _make(dim_size=4)
with torch.no_grad():
om.log_scale.copy_(torch.tensor([[0.1, 0.2, 0.3, 0.4]]))
params = om.export_ptq_params()
assert "log_scale" in params
other = _make(dim_size=4)
other.load_ptq_params(params)
assert torch.equal(other.log_scale.data, om.log_scale.data)
def test_load_ptq_params_is_no_op_when_log_scale_missing():
om = _make(dim_size=4)
original = om.log_scale.data.clone()
om.load_ptq_params({"unrelated": torch.zeros(1)})
assert torch.equal(om.log_scale.data, original)
def test_trainable_params_returns_log_scale():
om = _make(dim_size=4)
params = om.trainable_params()
assert any(p is om.log_scale for p in params)
def test_forward_preserves_input_dtype():
om = _make(dim_size=4)
x = torch.ones(1, 4, dtype=torch.bfloat16)
assert om(x).dtype == torch.bfloat16
def test_transform_is_a_no_op():
assert _make().transform() is None