"""UT for the public re-export (forwarding) modules.
These modules expose the documented public import paths and forward to the real
implementations under ``amct_pytorch.classic.graph_based.amct_pytorch``. The test
asserts that each documented path imports and resolves to the exact same object
as its underlying implementation.
"""
import importlib
import unittest
class TestPublicForwarding(unittest.TestCase):
"""Verify documented public import paths forward to the real implementations."""
def test_auto_calibration_forwarding(self):
"""auto_calibration public path forwards to the real base classes."""
base = "amct_pytorch.classic.graph_based.amct_pytorch.common.auto_calibration"
for name in ("AutoCalibrationEvaluatorBase", "AccuracyBasedAutoCalibrationBase",
"AutoCalibrationStrategyBase", "SensitivityBase",
"BinarySearchStrategy", "CosineSimilaritySensitivity"):
self._assert_same("amct_pytorch.common.auto_calibration", name, base, name)
def test_qat_module_forwarding(self):
"""Each QAT submodule public path forwards to the real implementation."""
base = "amct_pytorch.classic.graph_based.amct_pytorch.nn.module.quantization"
cases = [
("amct_pytorch.nn.module.quantization.conv2d", "Conv2dQAT"),
("amct_pytorch.nn.module.quantization.conv3d", "Conv3dQAT"),
("amct_pytorch.nn.module.quantization.conv_transpose_2d", "ConvTranspose2dQAT"),
("amct_pytorch.nn.module.quantization.linear", "LinearQAT"),
("amct_pytorch.nn.module.quantization.quant_calibration_op", "QuantCalibrationOp"),
]
for public_mod, attr in cases:
self._assert_same(public_mod, attr, base, attr)
def test_qat_quantization_package_exports(self):
"""The quantization package re-exports all QAT classes."""
mod = importlib.import_module("amct_pytorch.nn.module.quantization")
for attr in ("Conv2dQAT", "Conv3dQAT", "ConvTranspose2dQAT",
"LinearQAT", "QuantCalibrationOp"):
self.assertTrue(hasattr(mod, attr))
def test_tensor_decompose_forwarding(self):
"""tensor_decompose public path forwards to the real functions."""
base = "amct_pytorch.classic.graph_based.amct_pytorch.tensor_decompose"
for name in ("auto_decomposition", "decompose_network"):
self._assert_same("amct_pytorch.tensor_decompose", name, base, name)
def test_auto_channel_prune_forwarding(self):
"""auto_channel_prune public submodules forward to the real base classes."""
base = "amct_pytorch.classic.graph_based.amct_pytorch.common.auto_channel_prune"
self._assert_same(
"amct_pytorch.common.auto_channel_prune.sensitivity_base",
"SensitivityBase", base, "SensitivityBase")
self._assert_same(
"amct_pytorch.common.auto_channel_prune.search_channel_base",
"SearchChannelBase", base, "SearchChannelBase")
def test_auto_channel_prune_package_exports(self):
"""The auto_channel_prune package re-exports both base classes."""
mod = importlib.import_module("amct_pytorch.common.auto_channel_prune")
self.assertTrue(hasattr(mod, "SensitivityBase"))
self.assertTrue(hasattr(mod, "SearchChannelBase"))
def _assert_same(self, public_mod, public_attr, impl_mod, impl_attr):
"""Assert the public attribute is the same object as the implementation."""
pub = getattr(importlib.import_module(public_mod), public_attr)
impl = getattr(importlib.import_module(impl_mod), impl_attr)
self.assertIs(pub, impl, f"{public_mod}.{public_attr} should forward to "
f"{impl_mod}.{impl_attr}")
if __name__ == "__main__":
    # Run the tests in this module when the file is executed as a script.
    unittest.main()