"""Regression guard for the top-level ``amct_pytorch`` public API.
Background: the graph-based interfaces (``create_quant_config`` and friends)
are defined in ``amct_pytorch.classic.graph_based.amct_pytorch`` and surfaced at
the top level by a single wildcard re-export in ``amct_pytorch/__init__.py``.
Refactoring the package layout once dropped that line, which silently removed
~25 interfaces from ``amct_pytorch.*``. These tests make that regression fail
loudly.
The static checks parse source with ``ast`` and need no built package, so they
run in any CI environment. The behavioural check imports the real package and is
skipped when the generated protobuf modules are absent (unbuilt source tree).
"""
import ast
import importlib
from pathlib import Path
import pytest
# Repository root: two directories above the directory that holds this file.
REPO_ROOT = Path(__file__).resolve().parents[2]

# ``__init__.py`` of the top-level package and of the graph-based sub-package.
TOP_INIT = REPO_ROOT.joinpath("amct_pytorch", "__init__.py")
GRAPH_INIT = REPO_ROOT.joinpath(
    "amct_pytorch", "classic", "graph_based", "amct_pytorch", "__init__.py"
)
GRAPH_MODULE = "amct_pytorch.classic.graph_based.amct_pytorch"

# Known-good subset of graph-based interface names used as a fixed anchor.
BASELINE_INTERFACES = (
    "create_quant_config",
    "quantize_model",
    "save_model",
    "create_quant_cali_config",
    "create_quant_retrain_config",
    "accuracy_based_auto_calibration",
    "create_distill_config",
    "auto_channel_prune_search",
)

# Interface names expected to be imported from ``amct_pytorch.classic``.
CLASSIC_SOURCE = "amct_pytorch.classic"
CLASSIC_INTERFACES = ("quantize", "convert", "algorithm_register")

# Config constants expected to be imported from ``amct_pytorch.common.config``.
CONFIG_SOURCE = "amct_pytorch.common.config"
CONFIG_INTERFACES = (
    "INT4_AWQ_WEIGHT_QUANT_CFG",
    "INT4_GPTQ_WEIGHT_QUANT_CFG",
    "INT8_SMOOTHQUANT_CFG",
    "INT8_MINMAX_WEIGHT_QUANT_CFG",
    "HIFP8_OFMR_CFG",
    "FP8_OFMR_CFG",
    "MXFP8_QUANT_CFG",
    "MXFP4_AWQ_WEIGHT_QUANT_CFG",
    "HIFP8_CAST_CFG",
    "HIFP8_QUANTILE_CFG",
)

# Maps every non-graph interface name to the module it is imported from
# (classic names first, then config names).
NON_GRAPH_INTERFACES = {
    **dict.fromkeys(CLASSIC_INTERFACES, CLASSIC_SOURCE),
    **dict.fromkeys(CONFIG_INTERFACES, CONFIG_SOURCE),
}
def _module_dunder_all(init_path):
    """Return the string names statically declared in ``__all__`` of ``init_path``.

    The file is parsed with ``ast`` and never imported. The first
    ``__all__ = ...`` (or annotated ``__all__: T = ...``) binding found by
    ``ast.walk`` supplies the base list; every ``__all__ += ...`` statement
    then extends it. String constants are read from list/tuple/set literals
    and from ``+`` chains of such literals. Operands that are not literals
    (calls, names, attributes) contribute nothing instead of raising.
    Mutations through method calls such as ``__all__.extend(...)`` are not
    tracked.

    Args:
        init_path: Object with a ``read_text(encoding=...)`` method, e.g. a
            ``pathlib.Path`` pointing at the module source.

    Returns:
        A list of names, or ``[]`` when no ``__all__`` is declared.
    """

    def literal_names(value):
        # String constants of a literal container, following ``a + b`` chains.
        if isinstance(value, (ast.List, ast.Tuple, ast.Set)):
            return [
                elt.value
                for elt in value.elts
                if isinstance(elt, ast.Constant) and isinstance(elt.value, str)
            ]
        if isinstance(value, ast.BinOp) and isinstance(value.op, ast.Add):
            return literal_names(value.left) + literal_names(value.right)
        # Non-literal operand: nothing can be read statically.
        return []

    def binds_dunder_all(target):
        return isinstance(target, ast.Name) and target.id == "__all__"

    tree = ast.parse(init_path.read_text(encoding="utf-8"))
    declared = None  # names from the first plain/annotated binding
    appended = []  # names contributed by ``__all__ += ...``
    for node in ast.walk(tree):
        if isinstance(node, ast.Assign):
            if declared is None and any(map(binds_dunder_all, node.targets)):
                declared = literal_names(node.value)
        elif isinstance(node, ast.AnnAssign):
            # A bare annotation (``__all__: list``) has no value to read.
            if (
                declared is None
                and node.value is not None
                and binds_dunder_all(node.target)
            ):
                declared = literal_names(node.value)
        elif isinstance(node, ast.AugAssign):
            if isinstance(node.op, ast.Add) and binds_dunder_all(node.target):
                appended.extend(literal_names(node.value))
    return (declared or []) + appended
def _explicit_imported_names(init_path, module):
    """Return the names bound by ``from <module> import ...`` in ``init_path``.

    The file is parsed with ``ast`` and never imported. For an aliased import
    (``from m import foo as bar``) the bound name ``bar`` is reported, because
    that is the attribute that ends up on the importing module. Wildcard
    imports (``*``) are ignored.

    Args:
        init_path: Object with a ``read_text(encoding=...)`` method, e.g. a
            ``pathlib.Path`` pointing at the module source.
        module: Dotted module name compared against the ``from`` clause.

    Returns:
        A list of bound names in the order ``ast.walk`` visits the imports.
    """
    tree = ast.parse(init_path.read_text(encoding="utf-8"))
    bound = []
    for node in ast.walk(tree):
        if isinstance(node, ast.ImportFrom) and node.module == module:
            bound.extend(
                alias.asname or alias.name
                for alias in node.names
                if alias.name != "*"
            )
    return bound
def test_graph_module_all_covers_baseline_interfaces():
    """Every baseline interface is listed in the graph module's ``__all__``.

    The baseline is a fixed, known-good subset. Comparing against it (rather
    than only against whatever ``__all__`` currently holds) means an emptied
    or shrunken ``__all__`` cannot make the check pass vacuously.
    """
    declared = frozenset(_module_dunder_all(GRAPH_INIT))
    absent = [
        interface
        for interface in BASELINE_INTERFACES
        if interface not in declared
    ]
    assert not absent, (
        f"Baseline public interfaces missing from __all__ of {GRAPH_MODULE}: "
        f"{absent}. Declared names: {sorted(declared)}"
    )
def test_top_level_interfaces_importable_when_built():
    """Behavioural guard: each graph ``__all__`` name exists on ``amct_pytorch``.

    Both the top-level package and the graph module are imported for real; if
    either import raises ``ImportError`` the test is skipped. Otherwise the
    complete statically-parsed ``__all__`` of the graph module (not only the
    baseline subset) is checked against the top-level namespace.
    """
    try:
        top_level = importlib.import_module("amct_pytorch")
        importlib.import_module(GRAPH_MODULE)
    except ImportError as exc:
        pytest.skip(
            f"graph-based package not importable in this environment: {exc}")

    declared = _module_dunder_all(GRAPH_INIT)
    # An empty list would make the membership check below pass vacuously.
    assert declared, "graph module __all__ unexpectedly empty"

    absent = [entry for entry in declared if not hasattr(top_level, entry)]
    assert not absent, (
        f"amct_pytorch is missing top-level interfaces {absent}; the "
        f"graph-based wildcard re-export in __init__.py was likely dropped."
    )
@pytest.mark.parametrize("name", sorted(NON_GRAPH_INTERFACES))
def test_non_graph_interface_in_top_all(name):
    """A non-graph public interface is still listed in the top-level ``__all__``.

    Purely static: the top-level ``__init__.py`` is parsed, not imported. One
    test case is generated per key of ``NON_GRAPH_INTERFACES``.
    """
    declared = _module_dunder_all(TOP_INIT)
    is_declared = name in declared
    assert is_declared, (
        f"'{name}' dropped from amct_pytorch.__all__; it is part of the public "
        f"top-level API. Declared names: {declared}"
    )
@pytest.mark.parametrize("name", sorted(NON_GRAPH_INTERFACES))
def test_non_graph_interface_explicitly_imported(name):
    """A non-graph interface has a matching ``from <src> import`` statement.

    Purely static: listing a name in ``__all__`` does not bind it, so the
    top-level ``__init__.py`` must also import it from the source module
    recorded in ``NON_GRAPH_INTERFACES``.
    """
    origin = NON_GRAPH_INTERFACES[name]
    bound_names = _explicit_imported_names(TOP_INIT, origin)
    is_bound = name in bound_names
    assert is_bound, (
        f"'{name}' is in __all__ but not imported from '{origin}' in "
        f"amct_pytorch/__init__.py, so amct_pytorch.{name} would fail. "
        f"Names imported from {origin}: {bound_names}"
    )
def test_config_interfaces_importable():
    """Behavioural guard: the config constants exist at their source module.

    Importing ``amct_pytorch.common.config`` first executes the parent
    package's ``amct_pytorch/__init__.py``, so the import can fail on an
    unbuilt source tree for the same reason the top-level package does. In
    that case the test is skipped, matching the other behavioural checks in
    this file. If the config module itself is the one reported missing, the
    error is re-raised: that is a real regression, not an environment issue.
    """
    try:
        config = importlib.import_module(CONFIG_SOURCE)
    except ImportError as exc:
        # ``ImportError.name`` is the module that could not be imported.
        if exc.name == CONFIG_SOURCE:
            raise
        pytest.skip(
            f"{CONFIG_SOURCE} not importable in this environment: {exc}")
    missing = [n for n in CONFIG_INTERFACES if not hasattr(config, n)]
    assert not missing, (
        f"{CONFIG_SOURCE} is missing config constants {missing}; the top-level "
        f"amct_pytorch re-export of these would break."
    )
def test_non_graph_interfaces_importable_when_built():
    """Behavioural guard: non-graph names resolve on the real ``amct_pytorch``.

    The top-level package is imported for real; an ``ImportError`` skips the
    test. Otherwise every key of ``NON_GRAPH_INTERFACES`` must be an attribute
    of the imported package.
    """
    try:
        top_level = importlib.import_module("amct_pytorch")
    except ImportError as exc:
        pytest.skip(f"amct_pytorch not importable in this environment: {exc}")

    absent = [
        interface
        for interface in NON_GRAPH_INTERFACES
        if not hasattr(top_level, interface)
    ]
    assert not absent, (
        f"amct_pytorch is missing non-graph top-level interfaces {absent}; an "
        f"explicit import in __init__.py was likely dropped."
    )