from types import SimpleNamespace
import pytest
import torch
from amct_pytorch.algorithms.quant.auto_round import (
AutoRound,
_get_scale_shape,
_reshape_pad_tensor_by_group_size,
_revert_tensor_by_pad,
)
def _auto_round_args(w_size=(4, 8), quant_dtype="int"):
return SimpleNamespace(w_size=w_size, quant_dtype=quant_dtype)
def test_reshape_pad_tensor_by_group_size_no_pad():
weight = torch.randn(4, 8)
reshaped, orig, pad = _reshape_pad_tensor_by_group_size(weight, -1)
assert reshaped.shape == (4, 8)
assert pad == 0
def test_reshape_pad_tensor_by_group_size_with_pad():
weight = torch.randn(2, 10)
reshaped, orig, pad = _reshape_pad_tensor_by_group_size(weight, 4)
assert pad == 2
assert reshaped.shape[-1] == 4
def test_reshape_pad_tensor_by_group_size_raises_on_non_2d():
with pytest.raises(ValueError, match="2D"):
_reshape_pad_tensor_by_group_size(torch.randn(2, 3, 4), 4)
def test_revert_tensor_by_pad_no_pad():
tensor = torch.randn(4, 8)
reverted = _revert_tensor_by_pad(tensor, (4, 8), 0)
assert reverted.shape == (4, 8)
assert torch.equal(reverted, tensor)
def test_revert_tensor_by_pad_with_pad():
tensor = torch.randn(4, 8)
reverted = _revert_tensor_by_pad(tensor, (4, 6), 2)
assert reverted.shape == (4, 6)
def test_revert_tensor_by_pad_raises_on_non_2d():
with pytest.raises(ValueError, match="2D"):
_revert_tensor_by_pad(torch.randn(4, 8), (2, 3, 4), 0)
def test_get_scale_shape_per_row():
assert _get_scale_shape((4, 8), -1) == 4
def test_get_scale_shape_per_tensor():
assert _get_scale_shape((4, 8), 0) == 1
def test_get_scale_shape_grouped():
assert _get_scale_shape((4, 32), 16) == 8
def test_auto_round_init_int_dtype():
args = _auto_round_args(w_size=(4, 8), quant_dtype="int")
ar = AutoRound(args, w_bits=8)
assert ar.bits == 8
assert ar.group_size == -1
def test_auto_round_init_mxfp_dtype():
args = _auto_round_args(w_size=(4, 32), quant_dtype="mxfp")
ar = AutoRound(args, w_bits=4)
assert ar.bits == 4
assert ar.group_size == 32
def test_auto_round_trainable_params():
args = _auto_round_args()
ar = AutoRound(args, w_bits=8)
params = ar.trainable_params()
assert len(params) == 3
def test_auto_round_export_ptq_params():
args = _auto_round_args()
ar = AutoRound(args, w_bits=8)
exported = ar.export_ptq_params()
assert "value" in exported
assert "min_scale" in exported
assert "max_scale" in exported
def test_auto_round_forward_passthrough():
args = _auto_round_args()
ar = AutoRound(args, w_bits=8)
weight = torch.randn(4, 8)
out = ar(weight)
assert torch.equal(out, weight)
def test_get_group_size_unsupported_raises():
args = _auto_round_args(w_size=(4, 8), quant_dtype="fp8")
with pytest.raises(ValueError, match="hifx"):
AutoRound._get_group_size(args)
def test_auto_round_load_ptq_params():
args = _auto_round_args(w_size=(4, 32), quant_dtype="int")
ar = AutoRound(args, w_bits=8)
value = torch.randn(4, 32)
min_scale = torch.randn(4)
max_scale = torch.randn(4)
ar.load_ptq_params({"value": value, "min_scale": min_scale, "max_scale": max_scale})
assert torch.equal(ar.value.data, value.to(dtype=ar.value.dtype, device=ar.value.device))
def test_auto_round_compute_clip_range():
args = _auto_round_args(w_size=(2, 8), quant_dtype="int")
ar = AutoRound(args, w_bits=8)
weight = torch.randn(2, 8)
clip_min, clip_max = ar._compute_clip_range(weight)
assert clip_min.shape == (2, 1)
assert clip_max.shape == (2, 1)
assert (clip_min <= clip_max).all()
def test_auto_round_quantize():
args = _auto_round_args(w_size=(2, 8), quant_dtype="int")
ar = AutoRound(args, w_bits=8)
def mock_quant_obj(weight, v=None):
return weight
weight = torch.randn(2, 8)
result = ar.quantize(weight, mock_quant_obj)
assert result.shape == weight.shape
def test_auto_round_quantize_with_group_size():
args = _auto_round_args(w_size=(2, 12), quant_dtype="int")
ar = AutoRound(args, w_bits=8)
def mock_quant_obj(weight, v=None):
return weight
weight = torch.randn(2, 12)
result = ar.quantize(weight, mock_quant_obj)
assert result.shape == weight.shape