from types import SimpleNamespace
import pytest
import torch
from amct_pytorch.common.datasets.ptq_provider import (
BlockPtqBatch,
LlmPtqDataProvider,
)
from amct_pytorch.common.models.llm.common.ptq_units import make_ptq_unit
# Shared string used by the tests below: it is passed to ``make_ptq_unit`` as
# both of its first two positional arguments and is the value that
# ``batch.unit_name`` is compared against.
MLP = 'mlp'
@pytest.fixture(autouse=True)
def _stub_npu_empty_cache(monkeypatch):
    """Install a stand-in ``torch.npu`` whose ``empty_cache`` does nothing.

    Declared ``autouse`` so it is applied to every test in this module.
    ``raising=False`` lets the patch succeed even when ``torch`` has no
    ``npu`` attribute to begin with.
    """

    class _FakeNpu:
        @staticmethod
        def empty_cache():
            return None

    monkeypatch.setattr(torch, "npu", _FakeNpu(), raising=False)
def _make_provider(cali_bsz=2, device="cpu", data_dir="/tmp/ptq", pipeline=None):
    """Build an ``LlmPtqDataProvider`` from a minimal argument namespace.

    The block range is always ``start_block_idx=2`` / ``end_block_idx=5``;
    the other namespace fields come straight from the keyword arguments.
    ``pipeline`` is forwarded to the provider unchanged.
    """
    fields = {
        "cali_bsz": cali_bsz,
        "device": device,
        "data_dir": data_dir,
        "start_block_idx": 2,
        "end_block_idx": 5,
    }
    namespace = SimpleNamespace(**fields)
    return LlmPtqDataProvider(namespace, pipeline=pipeline)
def test_load_unit_inputs_delegates_to_pipeline():
    """``load_unit_inputs`` hands ``data_dir`` and the unit to the pipeline.

    A recording pipeline stub captures what it was called with; the provider
    is expected to return the stub's result untouched.
    """
    captured = {}

    class _RecordingPipeline:
        @staticmethod
        def load_unit_inputs(data_dir, unit):
            captured.update(data_dir=data_dir, unit=unit)
            return ("inps", "kwargs")

    unit = make_ptq_unit(MLP, MLP, 1, module=None)
    provider = _make_provider(data_dir="/d", pipeline=_RecordingPipeline())

    result = provider.load_unit_inputs(unit)

    assert result == ("inps", "kwargs")
    assert captured == {"data_dir": "/d", "unit": unit}
def test_get_model_data_packs_args():
    """``get_model_data`` reports the data dir and the block index range."""
    expected = dict(data_dir="/d", start_block_idx=2, end_block_idx=5)

    model_data = _make_provider(data_dir="/d").get_model_data()

    assert model_data == expected
def test_build_unit_batch_without_gts():
    """Without targets, the batch carries unit info and input-only batches.

    Four samples with ``cali_bsz=2`` are expected to produce two loader
    items, each a one-element container holding a ``(2, 2)`` tensor.
    """
    samples = torch.arange(8).reshape(4, 2)
    unit = make_ptq_unit(MLP, MLP, layer_idx=3, module=None, metadata={"k": "v"})
    provider = _make_provider(cali_bsz=2)

    batch = provider.build_unit_batch(unit, samples, kwargs={"a": 1})

    # Descriptive fields copied from the unit and call arguments.
    assert isinstance(batch, BlockPtqBatch)
    assert batch.layer_idx == 3
    assert batch.unit_name == MLP
    assert batch.kwargs == {"a": 1}
    assert batch.has_gts is False
    assert batch.metadata == {"k": "v"}
    assert batch.num_samples == 4

    # Shape of what the data loader yields.
    chunks = list(batch.data_loader)
    assert len(chunks) == 2
    for chunk in chunks:
        assert len(chunk) == 1
    first_inputs = chunks[0][0]
    assert first_inputs.shape == (2, 2)
def test_build_unit_batch_with_gts_pairs_inputs_and_targets():
    """With targets, the loader yields a single (inputs, targets) pair."""
    inputs = torch.arange(4).reshape(2, 2)
    targets = torch.arange(4, 8).reshape(2, 2)
    unit = make_ptq_unit(MLP, MLP, 0, module=None)

    batch = _make_provider(cali_bsz=2).build_unit_batch(
        unit, inputs, kwargs=None, gts=targets
    )

    assert batch.has_gts is True
    assert batch.kwargs is None

    # Unpacking fails unless the loader yields exactly one two-element item.
    [pair] = list(batch.data_loader)
    loaded_inputs, loaded_targets = pair
    assert torch.equal(loaded_inputs, inputs)
    assert torch.equal(loaded_targets, targets)
def test_build_unit_batch_metadata_is_none_when_unit_has_empty_dict():
    """A unit created without metadata yields a batch whose metadata is None."""
    unit = make_ptq_unit(MLP, MLP, 0, module=None)
    samples = torch.zeros(2, 2)

    batch = _make_provider().build_unit_batch(unit, samples, kwargs=None)

    assert batch.metadata is None
class _ScalingModule(torch.nn.Module):
    """Test module that multiplies its input by a fixed factor."""

    def __init__(self, factor):
        """Store ``factor`` as a plain attribute after base-class setup."""
        super().__init__()
        self.factor = factor

    def forward(self, x, **kwargs):
        """Return ``x * self.factor``; extra keyword arguments are ignored."""
        scaled = x * self.factor
        return scaled
def test_materialize_gt_runs_module_and_concatenates():
    """``materialize_gt`` output matches the module applied to all samples.

    Four samples with ``cali_bsz=2`` are fed through a doubling module; the
    result is expected to keep the input shape and equal ``inps * 2.0``.
    """
    samples = torch.arange(8, dtype=torch.float32).reshape(4, 2)
    doubler = _ScalingModule(2.0)

    result = _make_provider(cali_bsz=2).materialize_gt(samples, doubler)

    assert result.shape == samples.shape
    assert torch.equal(result, samples * 2.0)
def test_materialize_gt_passes_forward_kwargs():
    """``kwargs`` given to ``materialize_gt`` reach every forward call.

    The module records the ``scale`` it receives; with four samples and
    ``cali_bsz=2`` it is expected to be invoked twice, both with ``3.0``.
    """
    seen = []

    class _RecordingScale(torch.nn.Module):
        def forward(self, x, scale=1.0):
            seen.append(scale)
            return x * scale

    samples = torch.ones(4, 2)
    provider = _make_provider(cali_bsz=2)

    result = provider.materialize_gt(
        samples, _RecordingScale(), kwargs={"scale": 3.0}
    )

    assert seen == [3.0, 3.0]
    assert torch.equal(result, samples * 3.0)
def test_materialize_gt_takes_first_when_module_returns_tuple():
    """When the module returns a tuple, only its first element is kept."""

    class _TupleReturning(torch.nn.Module):
        def forward(self, x, **kw):
            # Second element is a non-tensor sentinel that must be dropped.
            return x + 1, "aux"

    samples = torch.zeros(2, 3)

    result = _make_provider(cali_bsz=2).materialize_gt(samples, _TupleReturning())

    assert torch.equal(result, torch.ones(2, 3))
def test_materialize_gt_non_floating_point_input():
    """Integer inputs are accepted and the module's float output is returned.

    Only shape and dtype are checked, so the random input values do not
    affect the outcome.
    """

    class _FloatCastingModule(torch.nn.Module):
        def forward(self, x, **kwargs):
            return x.float() * 2.0

    token_ids = torch.randint(0, 10, (4,), dtype=torch.int64)
    provider = _make_provider(cali_bsz=2)

    result = provider.materialize_gt(token_ids, _FloatCastingModule())

    assert result.shape == (4,)
    assert result.dtype == torch.float32