import os
import importlib
import unittest
from unittest import mock
import json
import torch
from torch import nn
from mindiesd.quantization.config import QuantConfig, LayerQuantConfig, OnlineQuantConfig
from mindiesd.quantization.layer import (
W8A8QuantBaseLinear,
WeightQuantLinear,
FP8RotateQuantFA,
W8A8MXFP8QuantLinear,
W4A4QuantLinear,
W4A4MXFP4QuantLinear,
)
from mindiesd.quantization.mode import QuantAlgorithm
from mindiesd.quantization.quantize import smooth_quantize_w8a8, smooth_quantize, quantize
from mindiesd.quantization.quantize import weight_quantize, w8a16_quantize, add_fa_quant
from mindiesd.quantization.quantize import get_cfg_and_weights
from mindiesd.quantization.quantize import _online_quantize_impl
from mindiesd.utils import ParametersInvalid, ConfigError
from mindiesd.utils.get_platform import NPUDevice, get_npu_device
# Module object for mindiesd.quantization.quantize. The tests need the module itself
# (not the `quantize` callable imported above) as the target of mock.patch.object and
# to reach its private `_ONLINE_QUANT_LAYER_MAP` for mock.patch.dict.
quantize_module = importlib.import_module("mindiesd.quantization.quantize")
class CustomLinear(nn.Linear):
    """Thin ``nn.Linear`` subclass that adds no behaviour; serves as a distinct layer type in tests."""

    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
        # Forward every argument untouched to nn.Linear.
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(in_features, out_features, bias, **factory_kwargs)
class MockSafeTensorHandler:
    """Dict-backed stand-in exposing the ``get_tensor`` / ``keys`` methods the tests rely on."""

    def __init__(self, data):
        # Mapping of tensor name -> value; stored as-is (no copy).
        self.data = data

    def get_tensor(self, key):
        """Return the value stored under ``key``, or ``None`` when the key is absent."""
        store = self.data
        return store.get(key)

    def keys(self):
        """Return the key view of the underlying mapping."""
        store = self.data
        return store.keys()
def create_mock_handler(mock_data):
    """Wrap ``mock_data`` (a name -> tensor mapping) in a ``MockSafeTensorHandler``."""
    handler = MockSafeTensorHandler(mock_data)
    return handler
class FakeOnlineQuantLinear(nn.Module):
    """Test double for an online-quantized linear layer.

    Every construction appends a dict describing the call (``dtype``,
    ``fallback_timesteps``, ``in_features``, ``out_features``) to the
    class-level ``init_records`` list so tests can inspect how the layer
    was built.

    Args:
        original_linear: Layer being replaced; only its ``in_features`` and
            ``out_features`` attributes are read.
        dtype: Dtype of the tensor produced by ``forward``.
        fallback_timesteps: Stored and recorded as-is; not interpreted here.
    """

    # Shared by all instances on purpose. Tests reset it by rebinding
    # ``FakeOnlineQuantLinear.init_records = []`` before running the code under test.
    init_records = []

    def __init__(self, original_linear, dtype=torch.bfloat16, fallback_timesteps=None):
        super().__init__()
        self.input_feature = original_linear.in_features
        self.output_feature = original_linear.out_features
        self.dtype = dtype
        self.fallback_timesteps = fallback_timesteps
        # Uninitialised placeholder with the (out_features, in_features) layout;
        # persistent=False keeps it out of state_dict().
        self.register_buffer(
            "weight",
            torch.empty(original_linear.out_features, original_linear.in_features),
            persistent=False,
        )
        FakeOnlineQuantLinear.init_records.append(
            {
                "dtype": dtype,
                "fallback_timesteps": fallback_timesteps,
                "in_features": original_linear.in_features,
                "out_features": original_linear.out_features,
            }
        )

    def forward(self, x):
        """Return an uninitialised tensor shaped like ``x`` with the last dim set to ``out_features``.

        The result is allocated on ``x.device`` so that it can be combined with
        other tensors derived from the input without a device mismatch.
        """
        return torch.empty(*x.shape[:-1], self.output_feature, dtype=self.dtype, device=x.device)
@unittest.skipIf(
    os.environ.get("MINDIE_TEST_MODE", "ALL") == "CPU", "Skip NPU-dependent tests when MINDIE_TEST_MODE is CPU."
)
class TestSmoothQuantize(unittest.TestCase):
    """Exercise ``smooth_quantize_w8a8`` and ``smooth_quantize`` against in-memory weight dictionaries."""

    def setUp(self):
        """Create a fresh set of fake weight dictionaries before every test."""
        fp16, fp32 = torch.float16, torch.float32
        n_in, n_out = 10, 10

        # Keys directly under "0."; used by the W8A8 and smooth_quantize tests.
        self.weights = {
            "0.quant_bias": torch.ones(n_out, dtype=torch.int32),
            "0.deq_scale": torch.ones(n_out, dtype=torch.int64),
            "0.input_scale": torch.ones(1, dtype=fp16),
            "0.input_offset": torch.ones(1, dtype=fp16),
            "0.weight": torch.ones(n_out, n_in, dtype=torch.int8),
            "0.bias": torch.ones(n_out, dtype=fp32),
        }
        # Same tensors under "0.linear." plus a "0.div.mul_scale" entry; used by the W8A8 "anti" test.
        self.weights2 = {
            "0.linear.quant_bias": torch.ones(n_out, dtype=torch.int32),
            "0.linear.deq_scale": torch.ones(n_out, dtype=torch.int64),
            "0.linear.input_scale": torch.ones(1, dtype=fp16),
            "0.linear.input_offset": torch.ones(1, dtype=fp16),
            "0.linear.weight": torch.ones(n_out, n_in, dtype=torch.int8),
            "0.linear.bias": torch.ones(n_out, dtype=fp32),
            "0.div.mul_scale": torch.ones(n_out, dtype=fp32),
        }
        # Weight + weight_scale + bias; used by the W8A8_MXFP8 test.
        self.weights3 = {
            "0.weight": torch.ones(n_out, n_in, dtype=torch.int8),
            "0.weight_scale": torch.ones(n_out, n_out, dtype=fp32),
            "0.bias": torch.ones(n_out, dtype=fp32),
        }
        # "0.linear." / "0.div." variant of the above; used by the W8A8_MXFP8 "anti" test.
        self.weights4 = {
            "0.linear.weight": torch.ones(n_out, n_in, dtype=torch.int8),
            "0.linear.weight_scale": torch.ones(n_out, n_out, dtype=fp32),
            "0.linear.bias": torch.ones(n_out, dtype=fp32),
            "0.div.mul_scale": torch.ones(n_out, dtype=fp32),
        }

        n_in_w4a4, n_out_w4a4 = 8, 8
        # 8x8 variant; used by the W4A4_DYNAMIC test.
        self.weights5 = {
            "0.linear.weight": torch.ones(n_out_w4a4, n_in_w4a4, dtype=torch.int8),
            "0.linear.weight_scale": torch.ones(n_out_w4a4, n_out_w4a4, dtype=fp32),
            "0.linear.bias": torch.ones(n_out_w4a4, dtype=fp32),
            "0.div.mul_scale": torch.ones(n_out_w4a4, dtype=fp32),
        }
        # float8 weight with uint8 scale; used by the W4A4_MXFP4_DYNAMIC test.
        self.weights6 = {
            "0.weight": torch.ones(n_out_w4a4, n_in_w4a4, dtype=torch.float8_e4m3fn),
            "0.weight_scale": torch.ones(n_out_w4a4, n_out_w4a4, dtype=torch.uint8),
            "0.bias": torch.ones(n_out_w4a4, dtype=fp32),
        }

    def _apply(self, quant_fn, layer, cfg, weights):
        """Call ``quant_fn`` for layer prefix "0" with ``weights`` wrapped in a mock handler."""
        return quant_fn("0", layer, cfg, create_mock_handler(weights))

    def _check_replaced(self, outcome, expected_cls):
        """Assert the returned layer has type ``expected_cls`` and the modified flag is truthy."""
        new_layer, changed = outcome
        self.assertIsInstance(new_layer, expected_cls)
        self.assertTrue(changed)

    def _check_untouched(self, outcome, layer):
        """Assert the original ``layer`` came back and the modified flag is falsy."""
        new_layer, changed = outcome
        self.assertEqual(new_layer, layer)
        self.assertFalse(changed)

    def test_smooth_quantize_w8a8_with_linear(self):
        outcome = self._apply(smooth_quantize_w8a8, nn.Linear(10, 10), QuantConfig(), self.weights)
        self._check_replaced(outcome, W8A8QuantBaseLinear)

    def test_smooth_quantize_w4a4_with_linear(self):
        linear = nn.Linear(8, 8)
        w4a4_cfg = QuantConfig(quant_algo=QuantAlgorithm.W4A4_DYNAMIC)
        outcome = self._apply(smooth_quantize_w8a8, linear, w4a4_cfg, self.weights5)
        self._check_replaced(outcome, W4A4QuantLinear)

    def test_smooth_quantize_w8a8_with_anti_linear(self):
        outcome = self._apply(smooth_quantize_w8a8, nn.Linear(10, 10), QuantConfig(), self.weights2)
        self._check_replaced(outcome, W8A8QuantBaseLinear)

    def test_smooth_quantize_w8a8_with_fuse_linear(self):
        linear = nn.Linear(10, 10)
        # Tag the layer with a fuse algorithm before quantizing.
        linear.fuse_algo = QuantAlgorithm.W8A8
        outcome = self._apply(smooth_quantize_w8a8, linear, QuantConfig(), self.weights)
        self._check_replaced(outcome, W8A8QuantBaseLinear)

    def test_smooth_quantize_w8a8_with_unsupported_layer(self):
        relu = nn.ReLU()
        outcome = self._apply(smooth_quantize_w8a8, relu, QuantConfig(), self.weights)
        self._check_untouched(outcome, relu)

    def test_smooth_quantize_w8a8_mxfp8_with_linear(self):
        linear = nn.Linear(10, 10)
        mxfp8_cfg = QuantConfig(quant_algo=QuantAlgorithm.W8A8_MXFP8)
        outcome = self._apply(smooth_quantize_w8a8, linear, mxfp8_cfg, self.weights3)
        self._check_replaced(outcome, W8A8MXFP8QuantLinear)

    def test_smooth_quantize_w8a8_mxfp8_with_anti_linear(self):
        linear = nn.Linear(10, 10)
        mxfp8_cfg = QuantConfig(quant_algo=QuantAlgorithm.W8A8_MXFP8)
        outcome = self._apply(smooth_quantize_w8a8, linear, mxfp8_cfg, self.weights4)
        self._check_replaced(outcome, W8A8MXFP8QuantLinear)

    @unittest.skipIf(get_npu_device() != NPUDevice.A5, "Skip unsupported tests when device is not available.")
    def test_smooth_quantize_w4a4_mxfp4_with_linear(self):
        linear = nn.Linear(8, 8)
        mxfp4_cfg = QuantConfig(quant_algo=QuantAlgorithm.W4A4_MXFP4_DYNAMIC)
        outcome = self._apply(smooth_quantize_w8a8, linear, mxfp4_cfg, self.weights6)
        self._check_replaced(outcome, W4A4MXFP4QuantLinear)

    def test_smooth_quantize_with_supported_algo(self):
        linear = nn.Linear(10, 10)
        w8a8_cfg = QuantConfig(quant_algo=QuantAlgorithm.W8A8)
        outcome = self._apply(smooth_quantize, linear, w8a8_cfg, self.weights)
        self._check_replaced(outcome, W8A8QuantBaseLinear)

    def test_smooth_quantize_with_unsupported_algo(self):
        linear = nn.Linear(10, 10)
        no_quant_cfg = QuantConfig(quant_algo=QuantAlgorithm.NO_QUANT)
        outcome = self._apply(smooth_quantize, linear, no_quant_cfg, self.weights)
        self._check_untouched(outcome, linear)
@unittest.skipIf(
    os.environ.get("MINDIE_TEST_MODE", "ALL") == "CPU", "Skip NPU-dependent tests when MINDIE_TEST_MODE is CPU."
)
class TestQuantize(unittest.TestCase):
    """Checks for the ``quantize`` entry point: offline (descriptor-driven) and online paths.

    Offline tests call ``quantize.__wrapped__`` with ``get_cfg_and_weights``
    patched, so no descriptor or weight file is read. Argument-validation
    tests call the decorated ``quantize`` directly.
    """

    def setUp(self):
        """Build the fake weight dictionaries for layer prefix "0"."""
        in_features = 10
        out_features = 10
        # Used by every offline test except the W8A16 one.
        self.weights = {
            "0.quant_bias": torch.ones(out_features, dtype=torch.int32),
            "0.weight_scale": torch.ones(1, dtype=torch.bfloat16),
            "0.deq_scale": torch.ones(out_features, dtype=torch.int64),
            "0.input_scale": torch.ones(1, dtype=torch.float16),
            "0.input_offset": torch.ones(1, dtype=torch.float16),
            "0.weight": torch.ones(out_features, in_features, dtype=torch.int8),
            "0.bias": torch.ones(out_features, dtype=torch.float32),
        }
        # Weight-only subset used by the W8A16 test.
        self.weights2 = {
            "0.weight_scale": torch.ones(1, dtype=torch.bfloat16),
            "0.weight": torch.ones(out_features, in_features, dtype=torch.int8),
            "0.bias": torch.ones(out_features, dtype=torch.float32),
        }

    def _quantize_offline(self, mock_get_cfg, model, cfg, weights, **kwargs):
        """Run the undecorated ``quantize`` with ``get_cfg_and_weights`` stubbed to return ``cfg`` and ``weights``."""
        mock_get_cfg.return_value = (cfg, create_mock_handler(weights))
        return quantize.__wrapped__(model, "path", custom_cfg=cfg, **kwargs)

    # ------------------------------------------------------------------ offline path

    @mock.patch.object(quantize_module, "get_cfg_and_weights")
    def test_quantize_with_non_quant_config(self, mock_func):
        """A default ``LayerQuantConfig`` leaves the model as it was."""
        model = nn.Sequential(nn.Linear(10, 10))
        cfg = LayerQuantConfig()
        quantized_model = self._quantize_offline(mock_func, model, cfg, self.weights)
        self.assertEqual(quantized_model, model)

    @mock.patch.object(quantize_module, "get_cfg_and_weights")
    def test_quantize_with_empty_model(self, mock_func):
        """An empty ``nn.Sequential`` is returned as-is."""
        model = nn.Sequential()
        cfg = LayerQuantConfig()
        quantized_model = self._quantize_offline(mock_func, model, cfg, self.weights)
        self.assertEqual(quantized_model, model)

    @mock.patch.object(quantize_module, "get_cfg_and_weights")
    def test_quantize_with_excluded_layer(self, mock_func):
        """The only config entry is for "1" and lists "0" in exclude_layers; the model is returned as-is."""
        model = nn.Sequential(nn.Linear(10, 10))
        cfg = LayerQuantConfig(
            quantized_layers={"1": QuantConfig(quant_algo=QuantAlgorithm.W8A8, exclude_layers=("0",))}
        )
        quantized_model = self._quantize_offline(mock_func, model, cfg, self.weights)
        self.assertEqual(quantized_model, model)

    # The three tests below previously carried each other's names (the W8A8 body was
    # named "dynamic", the W8A8_DYNAMIC body "timestep", the W8A8_TIMESTEP body plain
    # "w8a8"). Each name now matches the algorithm its body configures; the set of
    # test IDs is unchanged.

    @mock.patch.object(quantize_module, "get_cfg_and_weights")
    def test_quantize_with_w8a8_layer(self, mock_func):
        """Layer "0" configured as W8A8 becomes a ``W8A8QuantBaseLinear``."""
        model = nn.Sequential(nn.Linear(10, 10))
        cfg = LayerQuantConfig(quantized_layers={"0": QuantConfig(quant_algo=QuantAlgorithm.W8A8)})
        quantized_model = self._quantize_offline(mock_func, model, cfg, self.weights)
        self.assertIsInstance(quantized_model[0], W8A8QuantBaseLinear)

    @mock.patch.object(quantize_module, "get_cfg_and_weights")
    def test_quantize_with_w8a8_dynamic_layer(self, mock_func):
        """Layer "0" configured as W8A8_DYNAMIC becomes a ``W8A8QuantBaseLinear``."""
        model = nn.Sequential(nn.Linear(10, 10))
        cfg = LayerQuantConfig(quantized_layers={"0": QuantConfig(quant_algo=QuantAlgorithm.W8A8_DYNAMIC)})
        quantized_model = self._quantize_offline(mock_func, model, cfg, self.weights)
        self.assertIsInstance(quantized_model[0], W8A8QuantBaseLinear)

    @mock.patch.object(quantize_module, "get_cfg_and_weights")
    def test_quantize_with_w8a8_timestep_layer(self, mock_func):
        """Layer "0" configured as W8A8_TIMESTEP (called with ``t_idx=5``) becomes a ``W8A8QuantBaseLinear``."""
        model = nn.Sequential(nn.Linear(10, 10))
        cfg = LayerQuantConfig(quantized_layers={"0": QuantConfig(quant_algo=QuantAlgorithm.W8A8_TIMESTEP)})
        quantized_model = self._quantize_offline(mock_func, model, cfg, self.weights, t_idx=5)
        self.assertIsInstance(quantized_model[0], W8A8QuantBaseLinear)

    @mock.patch.object(quantize_module, "get_cfg_and_weights")
    def test_quantize_with_custom_w8a8_layer(self, mock_func):
        """Passing a custom ``map`` still yields a ``W8A8QuantBaseLinear`` for layer "0"."""
        # NOTE(review): the model holds a plain nn.Linear, which is not an instance of
        # CustomLinear, so the CustomLinear entry in `map` cannot match this layer.
        # Presumably the model was meant to contain CustomLinear(10, 10) — confirm
        # against quantize()'s map handling before changing.
        model = nn.Sequential(nn.Linear(10, 10))
        cfg = LayerQuantConfig(quantized_layers={"0": QuantConfig(quant_algo=QuantAlgorithm.W8A8)})
        quantized_model = self._quantize_offline(
            mock_func, model, cfg, self.weights, map={CustomLinear: W8A8QuantBaseLinear}
        )
        self.assertIsInstance(quantized_model[0], W8A8QuantBaseLinear)

    @mock.patch.object(quantize_module, "get_cfg_and_weights")
    def test_quantize_with_w8a8_fuse_layer(self, mock_func):
        """A layer tagged with ``fuse_algo`` is still replaced by a ``W8A8QuantBaseLinear``."""
        model = nn.Sequential(nn.Linear(10, 10))
        model[0].fuse_algo = QuantAlgorithm.W8A8
        cfg = LayerQuantConfig(quantized_layers={"0": QuantConfig(quant_algo=QuantAlgorithm.W8A8)})
        quantized_model = self._quantize_offline(mock_func, model, cfg, self.weights)
        self.assertIsInstance(quantized_model[0], W8A8QuantBaseLinear)

    @mock.patch.object(quantize_module, "get_cfg_and_weights")
    def test_quantize_with_w8a16_layer(self, mock_func):
        """Layer "0" configured as W8A16 becomes a ``WeightQuantLinear``."""
        model = nn.Sequential(nn.Linear(10, 10))
        cfg = LayerQuantConfig(quantized_layers={"0": QuantConfig(quant_algo=QuantAlgorithm.W8A16)})
        quantized_model = self._quantize_offline(mock_func, model, cfg, self.weights2)
        self.assertIsInstance(quantized_model[0], WeightQuantLinear)

    # ------------------------------------------------------------------ decorated entry point

    # mock.patch decorators apply bottom-up: the first mock argument belongs to the
    # decorator closest to the function (check_file_safety), the second to safe_open.
    @mock.patch("mindiesd.utils.file_utils.safe_open")
    @mock.patch("mindiesd.utils.file_utils.check_file_safety")
    def test_quantize_decorator_invalid_config(self, mock_check_safety, mock_safe_open):
        """A descriptor containing only ``{"layer1": "W8A8"}`` makes ``quantize`` raise ``ParametersInvalid``."""
        mock_file = mock.MagicMock()
        mock_file.read.return_value = json.dumps({"layer1": "W8A8"})
        mock_safe_open.return_value.__enter__.return_value = mock_file
        model = nn.Sequential(nn.Linear(10, 10))
        with self.assertRaises(ParametersInvalid):
            quantize(model, "path/to/quant_des.json")

    @mock.patch("mindiesd.utils.file_utils.safe_open")
    @mock.patch("mindiesd.utils.file_utils.check_file_safety")
    def test_quantize_decorator_file_error(self, mock_check_safety, mock_safe_open):
        """A ``FileNotFoundError`` raised while opening the descriptor propagates to the caller."""
        mock_safe_open.side_effect = FileNotFoundError()
        model = nn.Sequential(nn.Linear(10, 10))
        with self.assertRaises(FileNotFoundError):
            quantize(model, "path/to/quant_des.json")

    # ------------------------------------------------------------------ online path

    @mock.patch.object(torch.npu, "empty_cache")
    def test_quantize_with_online_config(self, mock_empty_cache):
        """Online quantization swaps both Linear layers in place and forwards ``dtype`` to the new layers."""
        model = nn.Sequential(nn.Linear(10, 8), nn.ReLU(), nn.Linear(8, 4))
        config = OnlineQuantConfig(quant_type=QuantAlgorithm.W8A8_MXFP8)
        # Start from an empty record list so the dtype assertion only sees this test's layers.
        FakeOnlineQuantLinear.init_records = []
        with mock.patch.dict(
            quantize_module._ONLINE_QUANT_LAYER_MAP,
            {QuantAlgorithm.W8A8_MXFP8: FakeOnlineQuantLinear},
        ):
            quantized_model = quantize.__wrapped__(model, online_config=config, dtype=torch.float16)
        self.assertIs(quantized_model, model)
        self.assertIsInstance(model[0], FakeOnlineQuantLinear)
        self.assertIsInstance(model[1], nn.ReLU)
        self.assertIsInstance(model[2], FakeOnlineQuantLinear)
        self.assertEqual(
            [record["dtype"] for record in FakeOnlineQuantLinear.init_records], [torch.float16, torch.float16]
        )
        mock_empty_cache.assert_called_once()

    @mock.patch.object(torch.npu, "empty_cache")
    def test_online_quantize_with_fallback_layers_and_timesteps(self, mock_empty_cache):
        """Fallback layers: W16A16 keeps the original Linear, W8A8 is replaced without fallback_timesteps."""
        model = nn.ModuleDict(
            {
                "main": nn.Linear(8, 8),
                "skip": nn.Linear(8, 8),
                "fallback": nn.Linear(8, 8),
            }
        )
        config = OnlineQuantConfig(
            quant_type=QuantAlgorithm.W4A4_MXFP4_DYNAMIC,
            fallback_layers={"skip": QuantAlgorithm.W16A16, "fallback": QuantAlgorithm.W8A8},
            fallback_timesteps=[3, 7],
        )
        FakeOnlineQuantLinear.init_records = []
        with mock.patch.dict(
            quantize_module._ONLINE_QUANT_LAYER_MAP,
            {
                QuantAlgorithm.W4A4_MXFP4_DYNAMIC: FakeOnlineQuantLinear,
                QuantAlgorithm.W8A8: FakeOnlineQuantLinear,
            },
        ):
            quantized_model = _online_quantize_impl(model, config, dtype=torch.bfloat16)
        self.assertIs(quantized_model, model)
        self.assertIsInstance(model["main"], FakeOnlineQuantLinear)
        self.assertIsInstance(model["skip"], nn.Linear)
        self.assertIsInstance(model["fallback"], FakeOnlineQuantLinear)
        # Records are in construction order: first the "main" layer, then "fallback".
        self.assertEqual(FakeOnlineQuantLinear.init_records[0]["fallback_timesteps"], [3, 7])
        self.assertIsNone(FakeOnlineQuantLinear.init_records[1]["fallback_timesteps"])
        mock_empty_cache.assert_called_once()

    # ------------------------------------------------------------------ argument validation

    def test_quantize_rejects_mixed_offline_and_online_config(self):
        """Supplying both a descriptor path and ``online_config`` raises ``ParametersInvalid``."""
        model = nn.Sequential(nn.Linear(10, 10))
        config = OnlineQuantConfig(quant_type=QuantAlgorithm.W8A8_DYNAMIC)
        with self.assertRaises(ParametersInvalid):
            quantize(model, "path/to/quant_des.json", online_config=config)

    def test_quantize_rejects_missing_quant_source(self):
        """Calling ``quantize`` with neither a descriptor path nor ``online_config`` raises ``ConfigError``."""
        model = nn.Sequential(nn.Linear(10, 10))
        with self.assertRaises(ConfigError):
            quantize(model)

    def test_quantize_rejects_invalid_dtype(self):
        """``dtype=torch.float32`` is rejected with ``ParametersInvalid``."""
        model = nn.Sequential(nn.Linear(10, 10))
        config = OnlineQuantConfig(quant_type=QuantAlgorithm.W8A8_DYNAMIC)
        with self.assertRaises(ParametersInvalid):
            quantize(model, online_config=config, dtype=torch.float32)

    def test_quantize_rejects_invalid_online_config_type(self):
        """A plain dict passed as ``online_config`` is rejected with ``ParametersInvalid``."""
        model = nn.Sequential(nn.Linear(10, 10))
        with self.assertRaises(ParametersInvalid):
            quantize(model, online_config={})

    def test_quantize_rejects_invalid_quant_des_path(self):
        """An empty descriptor path raises ``ConfigError``."""
        model = nn.Sequential(nn.Linear(10, 10))
        with self.assertRaises(ConfigError):
            quantize(model, "")

    @mock.patch("mindiesd.utils.file_utils.check_file_safety")
    @mock.patch("mindiesd.utils.file_utils.standardize_path", side_effect=lambda path: path)
    def test_quantize_rejects_invalid_timestep_config(self, _mock_standardize, _mock_check_safety):
        """A string passed as ``timestep_config`` raises ``ParametersInvalid``."""
        model = nn.Sequential(nn.Linear(10, 10))
        with self.assertRaises(ParametersInvalid):
            quantize(model, "path/to/quant_des.json", timestep_config="invalid")

    @mock.patch.object(torch.npu, "empty_cache")
    def test_online_quantize_rejects_unsupported_mutated_fallback(self, _mock_empty_cache):
        """A fallback entry changed to NO_QUANT after construction is rejected by ``_online_quantize_impl``."""
        model = nn.Sequential(nn.Linear(10, 10))
        config = OnlineQuantConfig(
            quant_type=QuantAlgorithm.W8A8_DYNAMIC,
            fallback_layers={"0": QuantAlgorithm.W8A8},
        )
        # Mutated after the config object exists, so the unsupported value is first
        # seen by _online_quantize_impl (presumably bypassing any constructor-time
        # validation — confirm against OnlineQuantConfig).
        config.fallback_layers["0"] = QuantAlgorithm.NO_QUANT
        with self.assertRaises(ParametersInvalid):
            _online_quantize_impl(model, config)
@unittest.skipIf(
    os.environ.get("MINDIE_TEST_MODE", "ALL") == "CPU", "Skip NPU-dependent tests when MINDIE_TEST_MODE is CPU."
)
class TestWeightQuantize(unittest.TestCase):
    """Exercise ``weight_quantize`` and ``w8a16_quantize`` against an in-memory weight dictionary."""

    def setUp(self):
        """Create the fake 8x8 weight dictionary for layer prefix "0"."""
        n_in, n_out = 8, 8
        fp16 = torch.float16
        self.weights = {
            "0.weight": torch.ones(n_out, n_in, dtype=torch.int8),
            "0.bias": torch.ones(n_out, dtype=torch.float32),
            "0.weight_scale": torch.ones(n_out, dtype=fp16),
            "0.weight_offset": torch.ones(n_out, dtype=fp16),
        }

    def _apply(self, quant_fn, layer, algo, **extra):
        """Build a ``QuantConfig`` for ``algo`` and call ``quant_fn`` for prefix "0" with mock weights."""
        cfg = QuantConfig(quant_algo=algo)
        return quant_fn("0", layer, cfg, create_mock_handler(self.weights), **extra)

    def _check_replaced(self, outcome):
        """Assert a ``WeightQuantLinear`` came back and the modified flag is truthy."""
        new_layer, changed = outcome
        self.assertIsInstance(new_layer, WeightQuantLinear)
        self.assertTrue(changed)

    def _check_untouched(self, outcome, layer):
        """Assert the original ``layer`` came back and the modified flag is falsy."""
        new_layer, changed = outcome
        self.assertEqual(new_layer, layer)
        self.assertFalse(changed)

    def test_weight_quantize_with_w8a16(self):
        outcome = self._apply(weight_quantize, nn.Linear(8, 8), QuantAlgorithm.W8A16)
        self._check_replaced(outcome)

    def test_weight_quantize_with_w4a16(self):
        outcome = self._apply(weight_quantize, nn.Linear(8, 8), QuantAlgorithm.W4A16)
        self._check_replaced(outcome)

    def test_weight_quantize_with_unsupported_algo(self):
        linear = nn.Linear(8, 8)
        outcome = self._apply(weight_quantize, linear, QuantAlgorithm.NO_QUANT)
        self._check_untouched(outcome, linear)

    def test_w8a16_quantize_with_linear(self):
        outcome = self._apply(w8a16_quantize, nn.Linear(8, 8), QuantAlgorithm.W8A16)
        self._check_replaced(outcome)

    def test_w8a16_quantize_with_unsupported_layer(self):
        relu = nn.ReLU()
        outcome = self._apply(w8a16_quantize, relu, QuantAlgorithm.W8A16)
        self._check_untouched(outcome, relu)

    def test_w8a16_quantize_with_custom_map(self):
        linear = nn.Linear(8, 8)
        # Explicit layer-type -> quantized-class mapping passed through the `map` keyword.
        layer_map = {nn.Linear: WeightQuantLinear}
        outcome = self._apply(w8a16_quantize, linear, QuantAlgorithm.W8A16, map=layer_map)
        self._check_replaced(outcome)
@unittest.skipIf(
    os.environ.get("MINDIE_TEST_MODE", "ALL") == "CPU", "Skip NPU-dependent tests when MINDIE_TEST_MODE is CPU."
)
class TestAddFAQuant(unittest.TestCase):
    """Checks that ``add_fa_quant`` attaches (or does not attach) an ``fa_quant`` attribute to a layer."""

    def setUp(self):
        """Create random 128x128 float16 ``q_rot`` / ``k_rot`` entries under the "test_layer" prefix."""
        self.weights = {
            "test_layer.q_rot": torch.randn(128, 128, dtype=torch.float16),
            "test_layer.k_rot": torch.randn(128, 128, dtype=torch.float16),
        }

    def test_add_fa_quant_with_valid_layer(self):
        """With FP8_DYNAMIC the layer gains an ``fa_quant`` attribute of type ``FP8RotateQuantFA``."""

        class MockLayer(nn.Module):
            pass

        layer = MockLayer()
        cfg = QuantConfig(quant_algo=QuantAlgorithm.FP8_DYNAMIC)
        add_fa_quant(layer, cfg, "test_layer", create_mock_handler(self.weights))
        self.assertTrue(hasattr(layer, 'fa_quant'))
        self.assertIsInstance(layer.fa_quant, FP8RotateQuantFA)

    def test_add_fa_quant_with_invalid_layer(self):
        """With NO_QUANT no ``fa_quant`` attribute is added to the layer."""
        layer = nn.Linear(10, 10)
        cfg = QuantConfig(quant_algo=QuantAlgorithm.NO_QUANT)
        # Pass a handler (not the raw dict) so the weights argument has the same
        # get_tensor/keys interface as in every other call site of this module.
        add_fa_quant(layer, cfg, "test_layer", create_mock_handler(self.weights))
        self.assertFalse(hasattr(layer, 'fa_quant'))
@unittest.skipIf(
    os.environ.get("MINDIE_TEST_MODE", "ALL") == "CPU", "Skip NPU-dependent tests when MINDIE_TEST_MODE is CPU."
)
class TestGetCfgAndWeights(unittest.TestCase):
    """Check ``get_cfg_and_weights`` with all file access replaced by mocks."""

    def setUp(self):
        """Prepare the fake descriptor content, weight content and file paths."""
        self.quant_des_path = "path/to/quant_des.json"
        self.quant_weight_path = "path/to/quant_model_weight_w8a8.safetensors"
        self.quant_des_dict = {"model_quant_type": "W8A8", "layer1": "W8A8", "layer2": "FLOAT"}
        self.quant_weights = {"weight": torch.ones(1)}

    # mock.patch decorators apply bottom-up, so the mock arguments arrive in the order:
    # safetensors.safe_open, check_file_safety, file_utils.safe_open.
    @mock.patch("mindiesd.utils.file_utils.safe_open")
    @mock.patch("mindiesd.utils.file_utils.check_file_safety")
    @mock.patch("safetensors.safe_open")
    def test_get_cfg_and_weights_normal(self, st_open_mock, check_safety_mock, file_open_mock):
        """The descriptor yields quant_algo W8A8 with "layer2" excluded; each opener is called exactly once."""
        # The descriptor file read through file_utils.safe_open returns the JSON text.
        descriptor = mock.MagicMock()
        descriptor.read.return_value = json.dumps(self.quant_des_dict)
        file_open_mock.return_value.__enter__.return_value = descriptor
        # The safetensors context manager yields the in-memory weight handler.
        st_open_mock.return_value.__enter__.return_value = create_mock_handler(self.quant_weights)

        cfg, _weights = get_cfg_and_weights(self.quant_des_path)

        self.assertEqual(cfg.quant_algo, QuantAlgorithm.W8A8)
        self.assertEqual(cfg.exclude_layers, ("layer2",))
        file_open_mock.assert_called_once()
        check_safety_mock.assert_called()
        st_open_mock.assert_called_once()
# Run every test case in this module when the file is executed directly.
if __name__ == '__main__':
    unittest.main()