"""Unit tests for the HiFloat8 amct_ops fallback path.

These tests exercise the code that the mocked-npu_quantize tests in test_gptq.py
and test_cast.py cannot reach: the hifloat8_supported() capability probe and the
hifloat8_fake_quant() amct_ops fallback (both in quant_util.py). The fallback is
the whole point of PR #147: run HiFloat8 fake-quant when torch_npu lacks a native
HiFloat8 cast.
"""
import sys
import types
import unittest
from unittest.mock import MagicMock, patch
import torch
from amct_pytorch.common.utils import quant_util
def _install_fake_torch_npu(with_hif8=True, cast_raises=False):
    """Register a stand-in ``torch_npu`` module in ``sys.modules``.

    Args:
        with_hif8: when true, expose a ``hifloat8`` dtype sentinel (a plain
            string) on the fake module.
        cast_raises: when true, the fake ``npu_dtype_cast`` raises a
            RuntimeError resembling an unsupported DT_HIFLOAT8 cast; otherwise
            it returns its ``tensor`` argument unchanged.

    Returns:
        The fake module object that was registered under ``'torch_npu'``.
    """
    def _cast(tensor, dtype, input_dtype=None):
        # Identity cast, unless the caller asked to simulate a runtime that
        # rejects the HiFloat8 dtype.
        if not cast_raises:
            return tensor
        raise RuntimeError('aclnnCast 161002: DT_HIFLOAT8 unsupported')

    fake_module = types.ModuleType('torch_npu')
    if with_hif8:
        fake_module.hifloat8 = 'hifloat8_enum'
    fake_module.npu_dtype_cast = _cast
    sys.modules['torch_npu'] = fake_module
    return fake_module
def _install_fake_amct_ops(record):
    """Register a stub ``amct_ops`` / ``amct_ops.hifloat8_cast`` pair in ``sys.modules``.

    The stub ``encode_to_hifloat8`` returns its input unchanged and the stub
    ``decode_from_hifloat8`` returns ``codes.to(dtype)``. Both write into the
    caller-owned dict *record*:

    * encode sets ``'encoded_dtype'`` (the input's ``.dtype``) and
      ``'encode_called'`` (True);
    * decode sets ``'decode_called'`` (True).
    """
    def encode_to_hifloat8(tensor):
        # Capture the dtype handed to the encoder so tests can check it.
        record.update(encoded_dtype=tensor.dtype, encode_called=True)
        return tensor

    def decode_from_hifloat8(codes, dtype):
        record['decode_called'] = True
        return codes.to(dtype)

    cast_submodule = types.ModuleType('amct_ops.hifloat8_cast')
    cast_submodule.encode_to_hifloat8 = encode_to_hifloat8
    cast_submodule.decode_from_hifloat8 = decode_from_hifloat8
    parent_package = types.ModuleType('amct_ops')
    parent_package.hifloat8_cast = cast_submodule
    # Register both names so `import amct_ops.hifloat8_cast` and
    # `from amct_ops.hifloat8_cast import ...` resolve straight from sys.modules.
    sys.modules.update({
        'amct_ops': parent_package,
        'amct_ops.hifloat8_cast': cast_submodule,
    })
class TestHifloat8Supported(unittest.TestCase):
    """quant_util.hifloat8_supported() probe branches.

    Each test installs its own fake (or deliberately missing) ``torch_npu``
    module and checks the probe's boolean answer.
    """

    def setUp(self):
        # hifloat8_supported() is memoised (it exposes cache_clear()); clear it
        # before and after every test so an answer computed against one fake
        # torch_npu never leaks into another test or another test module.
        quant_util.hifloat8_supported.cache_clear()
        self.addCleanup(quant_util.hifloat8_supported.cache_clear)
        # Put back whatever sys.modules['torch_npu'] held before the test
        # (possibly a real, already-imported torch_npu) instead of
        # unconditionally popping it.
        if 'torch_npu' in sys.modules:
            self.addCleanup(sys.modules.__setitem__, 'torch_npu',
                            sys.modules['torch_npu'])
        else:
            self.addCleanup(sys.modules.pop, 'torch_npu', None)
        # Identity Tensor.npu() so code under test that moves tensors to the
        # device keeps them on CPU. create=True: stock torch without torch_npu
        # does not define Tensor.npu, so reading it unconditionally would raise
        # AttributeError; the stub is added for the test and removed afterwards.
        npu_patch = patch.object(torch.Tensor, 'npu', lambda t: t, create=True)
        npu_patch.start()
        self.addCleanup(npu_patch.stop)

    def test_returns_false_when_torch_npu_missing(self):
        # A None entry in sys.modules makes `import torch_npu` raise ImportError.
        sys.modules['torch_npu'] = None
        self.assertFalse(quant_util.hifloat8_supported())

    def test_returns_false_when_hif8_attr_absent(self):
        _install_fake_torch_npu(with_hif8=False)
        self.assertFalse(quant_util.hifloat8_supported())

    def test_returns_false_when_probe_cast_raises(self):
        _install_fake_torch_npu(with_hif8=True, cast_raises=True)
        self.assertFalse(quant_util.hifloat8_supported())

    def test_returns_true_when_probe_round_trip_ok(self):
        _install_fake_torch_npu(with_hif8=True, cast_raises=False)
        self.assertTrue(quant_util.hifloat8_supported())
class TestHifloat8FakeQuant(unittest.TestCase):
    """quant_util.hifloat8_fake_quant() native vs amct_ops fallback.

    Each test patches hifloat8_supported() to force one path. Fake
    ``torch_npu`` / ``amct_ops`` modules stand in for the real packages.
    """

    # sys.modules entries the tests overwrite with fakes (or None to force an
    # ImportError). Their prior values are restored after every test.
    _STUBBED_MODULES = ('torch_npu', 'amct_ops', 'amct_ops.hifloat8_cast')

    def setUp(self):
        # Restore each entry to whatever it held before the test (possibly a
        # real, already-imported module) rather than unconditionally popping.
        for name in self._STUBBED_MODULES:
            if name in sys.modules:
                self.addCleanup(sys.modules.__setitem__, name,
                                sys.modules[name])
            else:
                self.addCleanup(sys.modules.pop, name, None)
        # Identity Tensor.npu() so code under test that moves tensors to the
        # device keeps them on CPU. create=True: stock torch without torch_npu
        # does not define Tensor.npu, so reading it unconditionally would raise
        # AttributeError; the stub is added for the test and removed afterwards.
        npu_patch = patch.object(torch.Tensor, 'npu', lambda t: t, create=True)
        npu_patch.start()
        self.addCleanup(npu_patch.stop)

    @patch.object(quant_util, 'hifloat8_supported', return_value=True)
    def test_uses_native_cast_when_supported(self, _):
        mod = _install_fake_torch_npu(with_hif8=True)
        # Spy on the fake cast (its behaviour is unchanged) to prove the native
        # path actually ran. The recording amct_ops stub proves the fallback
        # was not taken instead.
        mod.npu_dtype_cast = MagicMock(side_effect=mod.npu_dtype_cast)
        record = {}
        _install_fake_amct_ops(record)
        x = torch.randn(4, 8, dtype=torch.float16)
        out = quant_util.hifloat8_fake_quant(x)
        self.assertEqual(out.shape, x.shape)
        self.assertTrue(mod.npu_dtype_cast.called)
        self.assertFalse(record.get('encode_called'))
        self.assertFalse(record.get('decode_called'))

    @patch.object(quant_util, 'hifloat8_supported', return_value=False)
    def test_falls_back_to_amct_ops_when_unsupported(self, _):
        record = {}
        _install_fake_amct_ops(record)
        x = torch.randn(4, 8, dtype=torch.bfloat16)
        out = quant_util.hifloat8_fake_quant(x)
        self.assertTrue(record.get('encode_called'))
        self.assertTrue(record.get('decode_called'))
        self.assertEqual(out.dtype, torch.bfloat16)
        self.assertEqual(out.shape, x.shape)

    @patch.object(quant_util, 'hifloat8_supported', return_value=False)
    def test_fp32_input_downcast_to_bf16_for_amct_ops(self, _):
        record = {}
        _install_fake_amct_ops(record)
        x = torch.randn(4, 8, dtype=torch.float32)
        out = quant_util.hifloat8_fake_quant(x)
        # The encoder must receive bf16, and the result is cast back to the
        # caller's fp32.
        self.assertEqual(record.get('encoded_dtype'), torch.bfloat16)
        self.assertEqual(out.dtype, torch.float32)

    @patch.object(quant_util, 'hifloat8_supported', return_value=False)
    def test_raises_importerror_when_amct_ops_missing(self, _):
        # None entries make `import amct_ops...` raise ImportError.
        sys.modules['amct_ops'] = None
        sys.modules['amct_ops.hifloat8_cast'] = None
        x = torch.randn(2, 4, dtype=torch.bfloat16)
        with self.assertRaises(ImportError):
            quant_util.hifloat8_fake_quant(x)
if __name__ == '__main__':
    # Executing this file directly runs every TestCase defined in this module.
    unittest.main(module='__main__')