#!/usr/bin/env python

# coding=utf-8

# Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.

# MindIE is licensed under Mulan PSL v2.

# You can use this software according to the terms and conditions of the Mulan PSL v2.

# You may obtain a copy of Mulan PSL v2 at:

#          http://license.coscl.org.cn/MulanPSL2

# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,

# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,

# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.

# See the Mulan PSL v2 for more details.

# pylint: disable=abstract-method,duplicate-code

import os

import importlib

import unittest

from unittest import mock

import json

import torch

from torch import nn

from mindiesd.quantization.config import QuantConfig, LayerQuantConfig, OnlineQuantConfig

from mindiesd.quantization.layer import (

    W8A8QuantBaseLinear,

    WeightQuantLinear,

    FP8RotateQuantFA,

    W8A8MXFP8QuantLinear,

    W4A4QuantLinear,

    W4A4MXFP4QuantLinear,

)

from mindiesd.quantization.mode import QuantAlgorithm

from mindiesd.quantization.quantize import smooth_quantize_w8a8, smooth_quantize, quantize

from mindiesd.quantization.quantize import weight_quantize, w8a16_quantize, add_fa_quant

from mindiesd.quantization.quantize import get_cfg_and_weights

from mindiesd.quantization.quantize import _online_quantize_impl

from mindiesd.utils import ParametersInvalid, ConfigError

from mindiesd.utils.get_platform import NPUDevice, get_npu_device



quantize_module = importlib.import_module("mindiesd.quantization.quantize")





class CustomLinear(nn.Linear):

    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):

        super().__init__(in_features, out_features, bias, device=device, dtype=dtype)





class MockSafeTensorHandler:

    def __init__(self, data):

        self.data = data



    def get_tensor(self, key):

        return self.data.get(key, None)



    def keys(self):

        return self.data.keys()





def create_mock_handler(mock_data):

    return MockSafeTensorHandler(mock_data)





class FakeOnlineQuantLinear(nn.Module):

    init_records = []



    def __init__(self, original_linear, dtype=torch.bfloat16, fallback_timesteps=None):

        super().__init__()

        self.input_feature = original_linear.in_features

        self.output_feature = original_linear.out_features

        self.dtype = dtype

        self.fallback_timesteps = fallback_timesteps

        self.register_buffer(

            "weight", torch.empty(original_linear.out_features, original_linear.in_features), persistent=False

        )

        FakeOnlineQuantLinear.init_records.append(

            {

                "dtype": dtype,

                "fallback_timesteps": fallback_timesteps,

                "in_features": original_linear.in_features,

                "out_features": original_linear.out_features,

            }

        )



    def forward(self, x):

        return torch.empty(*x.shape[:-1], self.output_feature, dtype=self.dtype)





@unittest.skipIf(

    os.environ.get("MINDIE_TEST_MODE", "ALL") == "CPU", "Skip NPU-dependent tests when MINDIE_TEST_MODE is CPU."

)

class TestSmoothQuantize(unittest.TestCase):

    def setUp(self):

        in_features = 10

        out_features = 10

        self.weights = {

            "0.quant_bias": torch.ones(out_features, dtype=torch.int32),

            "0.deq_scale": torch.ones(out_features, dtype=torch.int64),

            "0.input_scale": torch.ones(1, dtype=torch.float16),

            "0.input_offset": torch.ones(1, dtype=torch.float16),

            "0.weight": torch.ones(out_features, in_features, dtype=torch.int8),

            "0.bias": torch.ones(out_features, dtype=torch.float32),

        }

        self.weights2 = {

            "0.linear.quant_bias": torch.ones(out_features, dtype=torch.int32),

            "0.linear.deq_scale": torch.ones(out_features, dtype=torch.int64),

            "0.linear.input_scale": torch.ones(1, dtype=torch.float16),

            "0.linear.input_offset": torch.ones(1, dtype=torch.float16),

            "0.linear.weight": torch.ones(out_features, in_features, dtype=torch.int8),

            "0.linear.bias": torch.ones(out_features, dtype=torch.float32),

            "0.div.mul_scale": torch.ones(out_features, dtype=torch.float32),

        }

        self.weights3 = {

            "0.weight": torch.ones(out_features, in_features, dtype=torch.int8),

            "0.weight_scale": torch.ones(out_features, out_features, dtype=torch.float32),

            "0.bias": torch.ones(out_features, dtype=torch.float32),

        }

        self.weights4 = {

            "0.linear.weight": torch.ones(out_features, in_features, dtype=torch.int8),

            "0.linear.weight_scale": torch.ones(out_features, out_features, dtype=torch.float32),

            "0.linear.bias": torch.ones(out_features, dtype=torch.float32),

            "0.div.mul_scale": torch.ones(out_features, dtype=torch.float32),

        }

        in_features_w4a4 = 8

        out_features_w4a4 = 8

        self.weights5 = {

            "0.linear.weight": torch.ones(out_features_w4a4, in_features_w4a4, dtype=torch.int8),

            "0.linear.weight_scale": torch.ones(out_features_w4a4, out_features_w4a4, dtype=torch.float32),

            "0.linear.bias": torch.ones(out_features_w4a4, dtype=torch.float32),

            "0.div.mul_scale": torch.ones(out_features_w4a4, dtype=torch.float32),

        }

        self.weights6 = {

            "0.weight": torch.ones(out_features_w4a4, in_features_w4a4, dtype=torch.float8_e4m3fn),

            "0.weight_scale": torch.ones(out_features_w4a4, out_features_w4a4, dtype=torch.uint8),

            "0.bias": torch.ones(out_features_w4a4, dtype=torch.float32),

        }



    def test_smooth_quantize_w8a8_with_linear(self):

        layer = nn.Linear(10, 10)

        cfg = QuantConfig()

        quant_layer, is_modified = smooth_quantize_w8a8("0", layer, cfg, create_mock_handler(self.weights))

        self.assertIsInstance(quant_layer, W8A8QuantBaseLinear)

        self.assertTrue(is_modified)



    def test_smooth_quantize_w4a4_with_linear(self):

        layer = nn.Linear(8, 8)

        cfg = QuantConfig(quant_algo=QuantAlgorithm.W4A4_DYNAMIC)

        quant_layer, is_modified = smooth_quantize_w8a8("0", layer, cfg, create_mock_handler(self.weights5))

        self.assertIsInstance(quant_layer, W4A4QuantLinear)

        self.assertTrue(is_modified)



    def test_smooth_quantize_w8a8_with_anti_linear(self):

        layer = nn.Linear(10, 10)

        cfg = QuantConfig()

        quant_layer, is_modified = smooth_quantize_w8a8("0", layer, cfg, create_mock_handler(self.weights2))

        self.assertIsInstance(quant_layer, W8A8QuantBaseLinear)

        self.assertTrue(is_modified)



    def test_smooth_quantize_w8a8_with_fuse_linear(self):

        layer = nn.Linear(10, 10)

        layer.fuse_algo = QuantAlgorithm.W8A8

        cfg = QuantConfig()

        quant_layer, is_modified = smooth_quantize_w8a8("0", layer, cfg, create_mock_handler(self.weights))

        self.assertIsInstance(quant_layer, W8A8QuantBaseLinear)

        self.assertTrue(is_modified)



    def test_smooth_quantize_w8a8_with_unsupported_layer(self):

        layer = nn.ReLU()

        cfg = QuantConfig()

        quant_layer, is_modified = smooth_quantize_w8a8("0", layer, cfg, create_mock_handler(self.weights))

        self.assertEqual(quant_layer, layer)

        self.assertFalse(is_modified)



    def test_smooth_quantize_w8a8_mxfp8_with_linear(self):

        layer = nn.Linear(10, 10)

        cfg = QuantConfig(quant_algo=QuantAlgorithm.W8A8_MXFP8)

        quant_layer, is_modified = smooth_quantize_w8a8("0", layer, cfg, create_mock_handler(self.weights3))

        self.assertIsInstance(quant_layer, W8A8MXFP8QuantLinear)

        self.assertTrue(is_modified)



    def test_smooth_quantize_w8a8_mxfp8_with_anti_linear(self):

        layer = nn.Linear(10, 10)

        cfg = QuantConfig(quant_algo=QuantAlgorithm.W8A8_MXFP8)

        quant_layer, is_modified = smooth_quantize_w8a8("0", layer, cfg, create_mock_handler(self.weights4))

        self.assertIsInstance(quant_layer, W8A8MXFP8QuantLinear)

        self.assertTrue(is_modified)



    @unittest.skipIf(get_npu_device() != NPUDevice.A5, "Skip unsupported tests when device is not available.")

    def test_smooth_quantize_w4a4_mxfp4_with_linear(self):

        layer = nn.Linear(8, 8)

        cfg = QuantConfig(quant_algo=QuantAlgorithm.W4A4_MXFP4_DYNAMIC)

        quant_layer, is_modified = smooth_quantize_w8a8("0", layer, cfg, create_mock_handler(self.weights6))

        self.assertIsInstance(quant_layer, W4A4MXFP4QuantLinear)

        self.assertTrue(is_modified)



    def test_smooth_quantize_with_supported_algo(self):

        layer = nn.Linear(10, 10)

        cfg = QuantConfig(quant_algo=QuantAlgorithm.W8A8)

        quant_layer, is_modified = smooth_quantize("0", layer, cfg, create_mock_handler(self.weights))

        self.assertIsInstance(quant_layer, W8A8QuantBaseLinear)

        self.assertTrue(is_modified)



    def test_smooth_quantize_with_unsupported_algo(self):

        layer = nn.Linear(10, 10)

        cfg = QuantConfig(quant_algo=QuantAlgorithm.NO_QUANT)

        quant_layer, is_modified = smooth_quantize("0", layer, cfg, create_mock_handler(self.weights))

        self.assertEqual(quant_layer, layer)

        self.assertFalse(is_modified)





@unittest.skipIf(

    os.environ.get("MINDIE_TEST_MODE", "ALL") == "CPU", "Skip NPU-dependent tests when MINDIE_TEST_MODE is CPU."

)

class TestQuantize(unittest.TestCase):

    def setUp(self):

        in_features = 10

        out_features = 10

        self.weights = {

            "0.quant_bias": torch.ones(out_features, dtype=torch.int32),

            "0.weight_scale": torch.ones(1, dtype=torch.bfloat16),

            "0.deq_scale": torch.ones(out_features, dtype=torch.int64),

            "0.input_scale": torch.ones(1, dtype=torch.float16),

            "0.input_offset": torch.ones(1, dtype=torch.float16),

            "0.weight": torch.ones(out_features, in_features, dtype=torch.int8),

            "0.bias": torch.ones(out_features, dtype=torch.float32),

        }

        self.weights2 = {

            "0.weight_scale": torch.ones(1, dtype=torch.bfloat16),

            "0.weight": torch.ones(out_features, in_features, dtype=torch.int8),

            "0.bias": torch.ones(out_features, dtype=torch.float32),

        }



    @mock.patch.object(quantize_module, "get_cfg_and_weights")  # 装饰器指定被mock的函数

    def test_quantize_with_non_quant_config(self, mock_func):

        model = nn.Sequential(nn.Linear(10, 10))

        cfg = LayerQuantConfig()

        mock_func.return_value = (cfg, create_mock_handler(self.weights))

        quantized_model = quantize.__wrapped__(model, "path", custom_cfg=cfg)

        self.assertEqual(quantized_model, model)



    @mock.patch.object(quantize_module, "get_cfg_and_weights")  # 装饰器指定被mock的函数

    def test_quantize_with_empty_model(self, mock_func):

        model = nn.Sequential()

        cfg = LayerQuantConfig()

        mock_func.return_value = (cfg, create_mock_handler(self.weights))

        quantized_model = quantize.__wrapped__(model, "path", custom_cfg=cfg)

        self.assertEqual(quantized_model, model)



    @mock.patch.object(quantize_module, "get_cfg_and_weights")  # 装饰器指定被mock的函数

    def test_quantize_with_excluded_layer(self, mock_func):

        model = nn.Sequential(nn.Linear(10, 10))

        cfg = LayerQuantConfig(

            quantized_layers={"1": QuantConfig(quant_algo=QuantAlgorithm.W8A8, exclude_layers=tuple(["0"]))}

        )

        mock_func.return_value = (cfg, create_mock_handler(self.weights))

        quantized_model = quantize.__wrapped__(model, "path", custom_cfg=cfg)

        self.assertEqual(quantized_model, model)



    @mock.patch.object(quantize_module, "get_cfg_and_weights")  # 装饰器指定被mock的函数

    def test_quantize_with_w8a8_dynamic_layer(self, mock_func):

        model = nn.Sequential(nn.Linear(10, 10))

        cfg = LayerQuantConfig(quantized_layers={"0": QuantConfig(quant_algo=QuantAlgorithm.W8A8)})

        mock_func.return_value = cfg, create_mock_handler(self.weights)

        quantized_model = quantize.__wrapped__(model, "path", custom_cfg=cfg)

        self.assertIsInstance(quantized_model[0], W8A8QuantBaseLinear)



    @mock.patch.object(quantize_module, "get_cfg_and_weights")  # 装饰器指定被mock的函数

    def test_quantize_with_w8a8_timestep_layer(self, mock_func):

        model = nn.Sequential(nn.Linear(10, 10))

        cfg = LayerQuantConfig(quantized_layers={"0": QuantConfig(quant_algo=QuantAlgorithm.W8A8_DYNAMIC)})

        mock_func.return_value = cfg, create_mock_handler(self.weights)

        quantized_model = quantize.__wrapped__(model, "path", custom_cfg=cfg)

        self.assertIsInstance(quantized_model[0], W8A8QuantBaseLinear)



    @mock.patch.object(quantize_module, "get_cfg_and_weights")  # 装饰器指定被mock的函数

    def test_quantize_with_w8a8_layer(self, mock_func):

        model = nn.Sequential(nn.Linear(10, 10))

        cfg = LayerQuantConfig(quantized_layers={"0": QuantConfig(quant_algo=QuantAlgorithm.W8A8_TIMESTEP)})

        mock_func.return_value = cfg, create_mock_handler(self.weights)

        quantized_model = quantize.__wrapped__(model, "path", custom_cfg=cfg, t_idx=5)

        self.assertIsInstance(quantized_model[0], W8A8QuantBaseLinear)



    @mock.patch.object(quantize_module, "get_cfg_and_weights")  # 装饰器指定被mock的函数

    def test_quantize_with_custom_w8a8_layer(self, mock_func):

        model = nn.Sequential(nn.Linear(10, 10))

        cfg = LayerQuantConfig(quantized_layers={"0": QuantConfig(quant_algo=QuantAlgorithm.W8A8)})

        mock_func.return_value = cfg, create_mock_handler(self.weights)

        quantized_model = quantize.__wrapped__(model, "path", custom_cfg=cfg, map={CustomLinear: W8A8QuantBaseLinear})

        self.assertIsInstance(quantized_model[0], W8A8QuantBaseLinear)



    @mock.patch.object(quantize_module, "get_cfg_and_weights")  # 装饰器指定被mock的函数

    def test_quantize_with_w8a8_fuse_layer(self, mock_func):

        model = nn.Sequential(nn.Linear(10, 10))

        model[0].fuse_algo = QuantAlgorithm.W8A8

        cfg = LayerQuantConfig(quantized_layers={"0": QuantConfig(quant_algo=QuantAlgorithm.W8A8)})

        mock_func.return_value = cfg, create_mock_handler(self.weights)

        quantized_model = quantize.__wrapped__(model, "path", custom_cfg=cfg)

        self.assertIsInstance(quantized_model[0], W8A8QuantBaseLinear)



    @mock.patch.object(quantize_module, "get_cfg_and_weights")  # 装饰器指定被mock的函数

    def test_quantize_with_w8a16_layer(self, mock_func):

        model = nn.Sequential(nn.Linear(10, 10))

        cfg = LayerQuantConfig(quantized_layers={"0": QuantConfig(quant_algo=QuantAlgorithm.W8A16)})

        mock_func.return_value = cfg, create_mock_handler(self.weights2)

        quantized_model = quantize.__wrapped__(model, "path", custom_cfg=cfg)

        self.assertIsInstance(quantized_model[0], WeightQuantLinear)



    @mock.patch("mindiesd.utils.file_utils.safe_open")

    @mock.patch("mindiesd.utils.file_utils.check_file_safety")

    def test_quantize_decorator_invalid_config(self, mock_check_safety, mock_safe_open):

        # Mock file with invalid config

        mock_file = mock.MagicMock()

        mock_file.read.return_value = json.dumps({"layer1": "W8A8"})

        mock_safe_open.return_value.__enter__.return_value = mock_file



        # Test invalid config case

        model = nn.Sequential(nn.Linear(10, 10))

        with self.assertRaises(ParametersInvalid):

            quantize(model, "path/to/quant_des.json")



    @mock.patch("mindiesd.utils.file_utils.safe_open")

    @mock.patch("mindiesd.utils.file_utils.check_file_safety")

    def test_quantize_decorator_file_error(self, mock_check_safety, mock_safe_open):

        # Mock file operation error

        mock_safe_open.side_effect = FileNotFoundError()



        # Test file error case

        model = nn.Sequential(nn.Linear(10, 10))

        with self.assertRaises(FileNotFoundError):

            quantize(model, "path/to/quant_des.json")



    @mock.patch.object(torch.npu, "empty_cache")

    def test_quantize_with_online_config(self, mock_empty_cache):

        model = nn.Sequential(nn.Linear(10, 8), nn.ReLU(), nn.Linear(8, 4))

        config = OnlineQuantConfig(quant_type=QuantAlgorithm.W8A8_MXFP8)

        FakeOnlineQuantLinear.init_records = []



        with mock.patch.dict(

            quantize_module._ONLINE_QUANT_LAYER_MAP,

            {QuantAlgorithm.W8A8_MXFP8: FakeOnlineQuantLinear},

        ):

            quantized_model = quantize.__wrapped__(model, online_config=config, dtype=torch.float16)



        self.assertIs(quantized_model, model)

        self.assertIsInstance(model[0], FakeOnlineQuantLinear)

        self.assertIsInstance(model[1], nn.ReLU)

        self.assertIsInstance(model[2], FakeOnlineQuantLinear)

        self.assertEqual(

            [record["dtype"] for record in FakeOnlineQuantLinear.init_records], [torch.float16, torch.float16]

        )

        mock_empty_cache.assert_called_once()



    @mock.patch.object(torch.npu, "empty_cache")

    def test_online_quantize_with_fallback_layers_and_timesteps(self, mock_empty_cache):

        model = nn.ModuleDict(

            {

                "main": nn.Linear(8, 8),

                "skip": nn.Linear(8, 8),

                "fallback": nn.Linear(8, 8),

            }

        )

        config = OnlineQuantConfig(

            quant_type=QuantAlgorithm.W4A4_MXFP4_DYNAMIC,

            fallback_layers={"skip": QuantAlgorithm.W16A16, "fallback": QuantAlgorithm.W8A8},

            fallback_timesteps=[3, 7],

        )

        FakeOnlineQuantLinear.init_records = []



        with mock.patch.dict(

            quantize_module._ONLINE_QUANT_LAYER_MAP,

            {

                QuantAlgorithm.W4A4_MXFP4_DYNAMIC: FakeOnlineQuantLinear,

                QuantAlgorithm.W8A8: FakeOnlineQuantLinear,

            },

        ):

            quantized_model = _online_quantize_impl(model, config, dtype=torch.bfloat16)



        self.assertIs(quantized_model, model)

        self.assertIsInstance(model["main"], FakeOnlineQuantLinear)

        self.assertIsInstance(model["skip"], nn.Linear)

        self.assertIsInstance(model["fallback"], FakeOnlineQuantLinear)

        self.assertEqual(FakeOnlineQuantLinear.init_records[0]["fallback_timesteps"], [3, 7])

        self.assertIsNone(FakeOnlineQuantLinear.init_records[1]["fallback_timesteps"])

        mock_empty_cache.assert_called_once()



    def test_quantize_rejects_mixed_offline_and_online_config(self):

        model = nn.Sequential(nn.Linear(10, 10))

        config = OnlineQuantConfig(quant_type=QuantAlgorithm.W8A8_DYNAMIC)



        with self.assertRaises(ParametersInvalid):

            quantize(model, "path/to/quant_des.json", online_config=config)



    def test_quantize_rejects_missing_quant_source(self):

        model = nn.Sequential(nn.Linear(10, 10))



        with self.assertRaises(ConfigError):

            quantize(model)



    def test_quantize_rejects_invalid_dtype(self):

        model = nn.Sequential(nn.Linear(10, 10))

        config = OnlineQuantConfig(quant_type=QuantAlgorithm.W8A8_DYNAMIC)



        with self.assertRaises(ParametersInvalid):

            quantize(model, online_config=config, dtype=torch.float32)



    def test_quantize_rejects_invalid_online_config_type(self):

        model = nn.Sequential(nn.Linear(10, 10))



        with self.assertRaises(ParametersInvalid):

            quantize(model, online_config={})



    def test_quantize_rejects_invalid_quant_des_path(self):

        model = nn.Sequential(nn.Linear(10, 10))



        with self.assertRaises(ConfigError):

            quantize(model, "")



    @mock.patch("mindiesd.utils.file_utils.check_file_safety")

    @mock.patch("mindiesd.utils.file_utils.standardize_path", side_effect=lambda path: path)

    def test_quantize_rejects_invalid_timestep_config(self, _mock_standardize, _mock_check_safety):

        model = nn.Sequential(nn.Linear(10, 10))



        with self.assertRaises(ParametersInvalid):

            quantize(model, "path/to/quant_des.json", timestep_config="invalid")



    @mock.patch.object(torch.npu, "empty_cache")

    def test_online_quantize_rejects_unsupported_mutated_fallback(self, _mock_empty_cache):

        model = nn.Sequential(nn.Linear(10, 10))

        config = OnlineQuantConfig(

            quant_type=QuantAlgorithm.W8A8_DYNAMIC,

            fallback_layers={"0": QuantAlgorithm.W8A8},

        )

        config.fallback_layers["0"] = QuantAlgorithm.NO_QUANT



        with self.assertRaises(ParametersInvalid):

            _online_quantize_impl(model, config)





@unittest.skipIf(

    os.environ.get("MINDIE_TEST_MODE", "ALL") == "CPU", "Skip NPU-dependent tests when MINDIE_TEST_MODE is CPU."

)

class TestWeightQuantize(unittest.TestCase):

    def setUp(self):

        in_features = 8

        out_features = 8

        self.weights = {

            "0.weight": torch.ones(out_features, in_features, dtype=torch.int8),

            "0.bias": torch.ones(out_features, dtype=torch.float32),

            "0.weight_scale": torch.ones(out_features, dtype=torch.float16),

            "0.weight_offset": torch.ones(out_features, dtype=torch.float16),

        }



    def test_weight_quantize_with_w8a16(self):

        layer = nn.Linear(8, 8)

        cfg = QuantConfig(quant_algo=QuantAlgorithm.W8A16)

        quant_layer, is_modified = weight_quantize("0", layer, cfg, create_mock_handler(self.weights))

        self.assertIsInstance(quant_layer, WeightQuantLinear)

        self.assertTrue(is_modified)



    def test_weight_quantize_with_w4a16(self):

        layer = nn.Linear(8, 8)

        cfg = QuantConfig(quant_algo=QuantAlgorithm.W4A16)

        quant_layer, is_modified = weight_quantize("0", layer, cfg, create_mock_handler(self.weights))

        self.assertIsInstance(quant_layer, WeightQuantLinear)

        self.assertTrue(is_modified)



    def test_weight_quantize_with_unsupported_algo(self):

        layer = nn.Linear(8, 8)

        cfg = QuantConfig(quant_algo=QuantAlgorithm.NO_QUANT)

        quant_layer, is_modified = weight_quantize("0", layer, cfg, create_mock_handler(self.weights))

        self.assertEqual(quant_layer, layer)

        self.assertFalse(is_modified)



    def test_w8a16_quantize_with_linear(self):

        layer = nn.Linear(8, 8)

        cfg = QuantConfig(quant_algo=QuantAlgorithm.W8A16)

        quant_layer, is_modified = w8a16_quantize("0", layer, cfg, create_mock_handler(self.weights))

        self.assertIsInstance(quant_layer, WeightQuantLinear)

        self.assertTrue(is_modified)



    def test_w8a16_quantize_with_unsupported_layer(self):

        layer = nn.ReLU()

        cfg = QuantConfig(quant_algo=QuantAlgorithm.W8A16)

        quant_layer, is_modified = w8a16_quantize("0", layer, cfg, create_mock_handler(self.weights))

        self.assertEqual(quant_layer, layer)

        self.assertFalse(is_modified)



    def test_w8a16_quantize_with_custom_map(self):

        layer = nn.Linear(8, 8)

        cfg = QuantConfig(quant_algo=QuantAlgorithm.W8A16)

        custom_map = {nn.Linear: WeightQuantLinear}

        quant_layer, is_modified = w8a16_quantize("0", layer, cfg, create_mock_handler(self.weights), map=custom_map)

        self.assertIsInstance(quant_layer, WeightQuantLinear)

        self.assertTrue(is_modified)





@unittest.skipIf(

    os.environ.get("MINDIE_TEST_MODE", "ALL") == "CPU", "Skip NPU-dependent tests when MINDIE_TEST_MODE is CPU."

)

class TestAddFAQuant(unittest.TestCase):

    def setUp(self):

        self.weights = {

            "test_layer.q_rot": torch.randn(128, 128, dtype=torch.float16),

            "test_layer.k_rot": torch.randn(128, 128, dtype=torch.float16),

        }



    def test_add_fa_quant_with_valid_layer(self):

        # 创建一个具有必要属性的模拟层

        class MockLayer(nn.Module):

            pass



        layer = MockLayer()

        cfg = QuantConfig(quant_algo=QuantAlgorithm.FP8_DYNAMIC)

        add_fa_quant(layer, cfg, "test_layer", create_mock_handler(self.weights))

        self.assertTrue(hasattr(layer, 'fa_quant'))

        self.assertIsInstance(layer.fa_quant, FP8RotateQuantFA)



    def test_add_fa_quant_with_invalid_layer(self):

        # 创建一个没有必要属性的层

        layer = nn.Linear(10, 10)

        cfg = QuantConfig(quant_algo=QuantAlgorithm.NO_QUANT)

        add_fa_quant(layer, cfg, "test_layer", self.weights)

        self.assertFalse(hasattr(layer, 'fa_quant'))





@unittest.skipIf(

    os.environ.get("MINDIE_TEST_MODE", "ALL") == "CPU", "Skip NPU-dependent tests when MINDIE_TEST_MODE is CPU."

)

class TestGetCfgAndWeights(unittest.TestCase):

    def setUp(self):

        self.quant_des_path = "path/to/quant_des.json"

        self.quant_weight_path = "path/to/quant_model_weight_w8a8.safetensors"

        self.quant_des_dict = {"model_quant_type": "W8A8", "layer1": "W8A8", "layer2": "FLOAT"}

        self.quant_weights = {"weight": torch.ones(1)}



    @mock.patch("mindiesd.utils.file_utils.safe_open")

    @mock.patch("mindiesd.utils.file_utils.check_file_safety")

    @mock.patch("safetensors.safe_open")

    def test_get_cfg_and_weights_normal(self, mock_safe_open0, mock_check_safety, mock_safe_open1):

        # Mock file operations

        mock_file = mock.MagicMock()

        mock_file.read.return_value = json.dumps(self.quant_des_dict)

        mock_safe_open1.return_value.__enter__.return_value = mock_file

        mock_safe_open0.return_value.__enter__.return_value = create_mock_handler(self.quant_weights)



        # Test normal case

        cfg, weights = get_cfg_and_weights(self.quant_des_path)



        # Verify results

        self.assertEqual(cfg.quant_algo, QuantAlgorithm.W8A8)

        self.assertEqual(cfg.exclude_layers, tuple(["layer2"]))



        # Verify calls

        mock_safe_open1.assert_called_once()

        mock_check_safety.assert_called()

        mock_safe_open0.assert_called_once()





if __name__ == '__main__':

    unittest.main()