# ----------------------------------------------------------------------------
# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------
import copy
import logging
import sys
import unittest
from unittest.mock import MagicMock, patch

import torch
import torch.nn as nn
from mock_torch_npu import (
    mock_npu,
    mock_npu_convert_weight_to_int4pack,
    mock_npu_dtype_cast,
    mock_npu_dynamic_mx_quant,
    mock_npu_format_cast,
    mock_npu_quant_matmul,
    mock_npu_quantize,
    mock_npu_trans_quant_param,
    mock_npu_weight_quant_batchmatmul,
)
from utils import TestModel, TestModelBias

from amct_pytorch import convert, quantize

logger = logging.getLogger(__name__)

MINMAX_QUANT = 'MinMaxQuant'
LINEAR = 'Linear'
NPU_QUANTIZATION_LINEAR = 'NpuQuantizationLinear'
NPU_WEIGHT_QUANTIZED_LINEAR = 'NpuWeightQuantizedLinear'

torch.manual_seed(0)


class TestMinMax(unittest.TestCase):
    '''
    System tests for the MinMax quantization algorithm.

    Each test quantizes a copy of a model with ``quantize`` using a one-batch MinMax
    config, runs one calibration batch, checks which layers were replaced and the
    scales/offsets that were computed, then ``convert``s the model and runs it once
    through the mocked ``torch_npu`` operators.
    '''
    # Sentinel meaning "this module/attribute did not exist before setUp".
    _MISSING = object()

    @classmethod
    def setUpClass(cls):
        cls.test_model = TestModel().to(torch.bfloat16)
        cls.inputs = torch.randn(64, 64).to(torch.bfloat16)
        cls.ori_out = cls.test_model(cls.inputs)
        logger.info('TestMinMax START!')

    @classmethod
    def tearDownClass(cls):
        logger.info('TestMinMax END!')

    def setUp(self):
        # The patch('torch_npu.<op>') decorators need an importable torch_npu module, so a
        # fresh MagicMock is installed for every test. Whatever torch_npu entry and
        # torch.Tensor.npu attribute existed before are remembered so tearDown can put
        # them back instead of leaking this test's mocks into the rest of the session.
        self._saved_torch_npu = sys.modules.get('torch_npu', self._MISSING)
        self._saved_tensor_npu = vars(torch.Tensor).get('npu', self._MISSING)
        sys.modules['torch_npu'] = MagicMock()

    def tearDown(self):
        if self._saved_torch_npu is self._MISSING:
            sys.modules.pop('torch_npu', None)
        else:
            sys.modules['torch_npu'] = self._saved_torch_npu
        # Undo the `torch.Tensor.npu = mock_npu` monkeypatch made by the tests.
        if self._saved_tensor_npu is not self._MISSING:
            torch.Tensor.npu = self._saved_tensor_npu
        elif 'npu' in vars(torch.Tensor):
            del torch.Tensor.npu

    @staticmethod
    def _quant_spec(dtype, *, symmetric, strategy, group_size=None):
        '''Return one ``quant_cfg`` entry (used for both 'weights' and 'inputs').'''
        spec = {'type': dtype, 'symmetric': symmetric, 'strategy': strategy}
        if group_size is not None:
            spec['group_size'] = group_size
        return spec

    @staticmethod
    def _minmax_cfg(weights, inputs=None):
        '''Return a one-batch MinMax config; ``inputs`` is omitted for weight-only quantization.'''
        quant_cfg = {'weights': weights}
        if inputs is not None:
            quant_cfg['inputs'] = inputs
        return {'batch_num': 1, 'quant_cfg': quant_cfg, 'algorithm': {'minmax'}}

    def _quantize_and_calibrate(self, cfg, model=None, inputs=None):
        '''
        Quantize ``model`` in place with ``cfg`` and run one calibration forward pass.

        ``model`` defaults to a bfloat16 deep copy of the shared test model and
        ``inputs`` to the shared bfloat16 inputs. Returns the (same) model object.
        '''
        if model is None:
            model = copy.deepcopy(self.test_model).to(torch.bfloat16)
        if inputs is None:
            inputs = self.inputs
        quantize(model, cfg)
        model(inputs)
        return model

    def _convert_and_run(self, model, inputs=None):
        '''Convert a calibrated ``model`` for deployment and run it once on mocked NPU inputs.'''
        if inputs is None:
            inputs = self.inputs
        # Route Tensor.npu() to the mock; tearDown restores the previous attribute.
        torch.Tensor.npu = mock_npu
        convert(model)
        return model(inputs.npu())

    # Weight-only quantization (no 'inputs' entry): int8 weights.
    @patch('torch_npu.npu_quantize', wraps=mock_npu_quantize)
    @patch('torch_npu.npu_quant_matmul', wraps=mock_npu_quant_matmul)
    @patch('torch_npu.npu_weight_quant_batchmatmul', wraps=mock_npu_weight_quant_batchmatmul)
    @patch('torch_npu.npu_convert_weight_to_int4pack', wraps=mock_npu_convert_weight_to_int4pack)
    @patch(
        'amct_pytorch.classic.deploy_op.weight_npu_quant_module.check_parameters_in_schema',
        MagicMock(return_value=True),
    )
    def test_int8_tensor_sym_minmax_success(self, mock_1, mock_2, mock_3, mock_4):
        cfg = self._minmax_cfg(self._quant_spec('int8', symmetric=True, strategy='tensor'))
        model = self._quantize_and_calibrate(cfg)
        self.assertIsNotNone(model.linear1.scale_w)
        self.assertEqual(model.linear1.scale_w.shape[0], 1)
        self.assertIsNone(model.linear2.offset_w)
        self.assertEqual(type(model.linear3).__name__, MINMAX_QUANT)
        self._convert_and_run(model)

    @patch('torch_npu.npu_quantize', wraps=mock_npu_quantize)
    @patch('torch_npu.npu_quant_matmul', wraps=mock_npu_quant_matmul)
    @patch('torch_npu.npu_weight_quant_batchmatmul', wraps=mock_npu_weight_quant_batchmatmul)
    @patch('torch_npu.npu_convert_weight_to_int4pack', wraps=mock_npu_convert_weight_to_int4pack)
    @patch(
        'amct_pytorch.classic.deploy_op.weight_npu_quant_module.check_parameters_in_schema',
        MagicMock(return_value=True),
    )
    def test_int8_tensor_asym_minmax_success(self, mock_1, mock_2, mock_3, mock_4):
        cfg = self._minmax_cfg(self._quant_spec('int8', symmetric=False, strategy='tensor'))
        model = self._quantize_and_calibrate(cfg)
        self.assertIsNotNone(model.linear1.scale_w)
        self.assertEqual(model.linear1.scale_w.shape[0], 1)
        self.assertEqual(type(model.linear3).__name__, MINMAX_QUANT)
        self.assertIsNotNone(model.linear2.offset_w)
        self._convert_and_run(model)

    @patch('torch_npu.npu_quantize', wraps=mock_npu_quantize)
    @patch('torch_npu.npu_quant_matmul', wraps=mock_npu_quant_matmul)
    @patch('torch_npu.npu_weight_quant_batchmatmul', wraps=mock_npu_weight_quant_batchmatmul)
    @patch('torch_npu.npu_convert_weight_to_int4pack', wraps=mock_npu_convert_weight_to_int4pack)
    @patch(
        'amct_pytorch.classic.deploy_op.weight_npu_quant_module.check_parameters_in_schema',
        MagicMock(return_value=True),
    )
    def test_int8_channel_sym_minmax_success(self, mock_1, mock_2, mock_3, mock_4):
        cfg = self._minmax_cfg(self._quant_spec('int8', symmetric=True, strategy='channel'))
        model = copy.deepcopy(self.test_model).to(torch.bfloat16)
        quantize(model, cfg)
        # Checked before calibration: weight scales are expected right after quantize().
        self.assertEqual(model.linear1.scale_w.shape[0], 64)
        self.assertEqual(type(model.linear1).__name__, MINMAX_QUANT)
        model(self.inputs)
        self._convert_and_run(model)

    @patch('torch_npu.npu_quantize', wraps=mock_npu_quantize)
    @patch('torch_npu.npu_quant_matmul', wraps=mock_npu_quant_matmul)
    @patch('torch_npu.npu_weight_quant_batchmatmul', wraps=mock_npu_weight_quant_batchmatmul)
    @patch('torch_npu.npu_convert_weight_to_int4pack', wraps=mock_npu_convert_weight_to_int4pack)
    @patch(
        'amct_pytorch.classic.deploy_op.weight_npu_quant_module.check_parameters_in_schema',
        MagicMock(return_value=True),
    )
    def test_int8_channel_asym_minmax_success(self, mock_1, mock_2, mock_3, mock_4):
        cfg = self._minmax_cfg(self._quant_spec('int8', symmetric=False, strategy='channel'))
        model = self._quantize_and_calibrate(cfg)
        self.assertEqual(model.linear1.scale_w.shape[0], 64)
        self.assertEqual(type(model.linear1).__name__, MINMAX_QUANT)
        self._convert_and_run(model)

    @patch('torch_npu.npu_quantize', wraps=mock_npu_quantize)
    @patch('torch_npu.npu_quant_matmul', wraps=mock_npu_quant_matmul)
    @patch('torch_npu.npu_weight_quant_batchmatmul', wraps=mock_npu_weight_quant_batchmatmul)
    @patch('torch_npu.npu_convert_weight_to_int4pack', wraps=mock_npu_convert_weight_to_int4pack)
    @patch(
        'amct_pytorch.classic.deploy_op.weight_npu_quant_module.check_parameters_in_schema',
        MagicMock(return_value=True),
    )
    def test_int8_group_sym_minmax_success(self, mock_1, mock_2, mock_3, mock_4):
        cfg = self._minmax_cfg(
            self._quant_spec('int8', symmetric=True, strategy='group', group_size=32)
        )
        model = self._quantize_and_calibrate(cfg)
        self.assertEqual(model.linear1.scale_w.shape[0], 64)
        self.assertEqual(model.linear1.scale_w.shape[1], 2)
        self.assertEqual(type(model.linear1).__name__, MINMAX_QUANT)
        self._convert_and_run(model)

    @patch('torch_npu.npu_quantize', wraps=mock_npu_quantize)
    @patch('torch_npu.npu_quant_matmul', wraps=mock_npu_quant_matmul)
    @patch('torch_npu.npu_weight_quant_batchmatmul', wraps=mock_npu_weight_quant_batchmatmul)
    @patch('torch_npu.npu_convert_weight_to_int4pack', wraps=mock_npu_convert_weight_to_int4pack)
    @patch(
        'amct_pytorch.classic.deploy_op.weight_npu_quant_module.check_parameters_in_schema',
        MagicMock(return_value=True),
    )
    def test_int8_group_asym_minmax_success(self, mock_1, mock_2, mock_3, mock_4):
        cfg = self._minmax_cfg(
            self._quant_spec('int8', symmetric=False, strategy='group', group_size=32)
        )
        model = self._quantize_and_calibrate(cfg)
        self._convert_and_run(model)

    # Weight-only quantization (no 'inputs' entry): int4 weights.
    @patch('torch_npu.npu_quantize', wraps=mock_npu_quantize)
    @patch('torch_npu.npu_quant_matmul', wraps=mock_npu_quant_matmul)
    @patch('torch_npu.npu_weight_quant_batchmatmul', wraps=mock_npu_weight_quant_batchmatmul)
    @patch('torch_npu.npu_convert_weight_to_int4pack', wraps=mock_npu_convert_weight_to_int4pack)
    @patch(
        'amct_pytorch.classic.deploy_op.weight_npu_quant_module.check_parameters_in_schema',
        MagicMock(return_value=True),
    )
    def test_int4_tensor_sym_minmax_success(self, mock_1, mock_2, mock_3, mock_4):
        cfg = self._minmax_cfg(self._quant_spec('int4', symmetric=True, strategy='tensor'))
        model = self._quantize_and_calibrate(cfg)
        self.assertEqual(type(model.linear1).__name__, MINMAX_QUANT)
        self.assertEqual(type(model.linear2).__name__, MINMAX_QUANT)
        self.assertEqual(type(model.linear3).__name__, LINEAR)
        self._convert_and_run(model)
        self.assertEqual(type(model.linear1).__name__, NPU_WEIGHT_QUANTIZED_LINEAR)
        self.assertEqual(type(model.linear2).__name__, NPU_WEIGHT_QUANTIZED_LINEAR)

    @patch('torch_npu.npu_quantize', wraps=mock_npu_quantize)
    @patch('torch_npu.npu_quant_matmul', wraps=mock_npu_quant_matmul)
    @patch('torch_npu.npu_weight_quant_batchmatmul', wraps=mock_npu_weight_quant_batchmatmul)
    @patch('torch_npu.npu_convert_weight_to_int4pack', wraps=mock_npu_convert_weight_to_int4pack)
    @patch(
        'amct_pytorch.classic.deploy_op.weight_npu_quant_module.check_parameters_in_schema',
        MagicMock(return_value=True),
    )
    def test_int4_tensor_asym_minmax_success(self, mock_1, mock_2, mock_3, mock_4):
        cfg = self._minmax_cfg(self._quant_spec('int4', symmetric=False, strategy='tensor'))
        model = self._quantize_and_calibrate(cfg)
        self.assertEqual(type(model.linear1).__name__, MINMAX_QUANT)
        self.assertEqual(type(model.linear2).__name__, MINMAX_QUANT)
        self.assertEqual(model.linear1.scale_w.shape[0], 1)
        self.assertEqual(model.linear2.scale_w.shape[0], 1)
        self._convert_and_run(model)
        self.assertEqual(type(model.linear1).__name__, NPU_WEIGHT_QUANTIZED_LINEAR)
        self.assertEqual(type(model.linear2).__name__, NPU_WEIGHT_QUANTIZED_LINEAR)

    @patch('torch_npu.npu_quantize', wraps=mock_npu_quantize)
    @patch('torch_npu.npu_quant_matmul', wraps=mock_npu_quant_matmul)
    @patch('torch_npu.npu_weight_quant_batchmatmul', wraps=mock_npu_weight_quant_batchmatmul)
    @patch('torch_npu.npu_convert_weight_to_int4pack', wraps=mock_npu_convert_weight_to_int4pack)
    @patch(
        'amct_pytorch.classic.deploy_op.weight_npu_quant_module.check_parameters_in_schema',
        MagicMock(return_value=True),
    )
    def test_int4_channel_sym_minmax_success(self, mock_1, mock_2, mock_3, mock_4):
        cfg = self._minmax_cfg(self._quant_spec('int4', symmetric=True, strategy='channel'))
        model = self._quantize_and_calibrate(cfg)
        self.assertEqual(type(model.linear1).__name__, MINMAX_QUANT)
        self.assertEqual(type(model.linear2).__name__, MINMAX_QUANT)
        self.assertEqual(model.linear1.scale_w.shape[0], 64)
        self.assertEqual(model.linear2.scale_w.shape[0], 32)
        self._convert_and_run(model)
        self.assertEqual(type(model.linear1).__name__, NPU_WEIGHT_QUANTIZED_LINEAR)
        self.assertEqual(type(model.linear2).__name__, NPU_WEIGHT_QUANTIZED_LINEAR)

    @patch('torch_npu.npu_quantize', wraps=mock_npu_quantize)
    @patch('torch_npu.npu_quant_matmul', wraps=mock_npu_quant_matmul)
    @patch('torch_npu.npu_weight_quant_batchmatmul', wraps=mock_npu_weight_quant_batchmatmul)
    @patch('torch_npu.npu_convert_weight_to_int4pack', wraps=mock_npu_convert_weight_to_int4pack)
    @patch(
        'amct_pytorch.classic.deploy_op.weight_npu_quant_module.check_parameters_in_schema',
        MagicMock(return_value=True),
    )
    def test_int4_channel_asym_minmax_success(self, mock_1, mock_2, mock_3, mock_4):
        cfg = self._minmax_cfg(self._quant_spec('int4', symmetric=False, strategy='channel'))
        model = self._quantize_and_calibrate(cfg)
        self.assertEqual(type(model.linear1).__name__, MINMAX_QUANT)
        self.assertEqual(type(model.linear2).__name__, MINMAX_QUANT)
        self.assertEqual(type(model.linear3).__name__, LINEAR)
        self.assertEqual(model.linear1.scale_w.shape[0], 64)
        self.assertEqual(model.linear2.scale_w.shape[0], 32)
        self.assertIsNotNone(model.linear1.offset_w[0])
        self.assertIsNotNone(model.linear2.offset_w[0])
        self._convert_and_run(model)
        self.assertEqual(type(model.linear1).__name__, NPU_WEIGHT_QUANTIZED_LINEAR)
        self.assertEqual(type(model.linear2).__name__, NPU_WEIGHT_QUANTIZED_LINEAR)
        self.assertEqual(type(model.linear3).__name__, LINEAR)

    @patch('torch_npu.npu_quantize', wraps=mock_npu_quantize)
    @patch('torch_npu.npu_quant_matmul', wraps=mock_npu_quant_matmul)
    @patch('torch_npu.npu_weight_quant_batchmatmul', wraps=mock_npu_weight_quant_batchmatmul)
    @patch('torch_npu.npu_convert_weight_to_int4pack', wraps=mock_npu_convert_weight_to_int4pack)
    @patch(
        'amct_pytorch.classic.deploy_op.weight_npu_quant_module.check_parameters_in_schema',
        MagicMock(return_value=True),
    )
    def test_int4_group_sym_minmax_success(self, mock_1, mock_2, mock_3, mock_4):
        cfg = self._minmax_cfg(
            self._quant_spec('int4', symmetric=True, strategy='group', group_size=32)
        )
        model = self._quantize_and_calibrate(cfg)
        self.assertEqual(type(model.linear1).__name__, MINMAX_QUANT)
        self.assertEqual(type(model.linear2).__name__, MINMAX_QUANT)
        self.assertEqual(type(model.linear3).__name__, LINEAR)
        self.assertEqual(model.linear1.scale_w.shape[0], 64)
        self.assertEqual(model.linear1.scale_w.shape[1], 2)
        self._convert_and_run(model)
        self.assertEqual(type(model.linear1).__name__, NPU_WEIGHT_QUANTIZED_LINEAR)

    @patch('torch_npu.npu_quantize', wraps=mock_npu_quantize)
    @patch('torch_npu.npu_quant_matmul', wraps=mock_npu_quant_matmul)
    @patch('torch_npu.npu_weight_quant_batchmatmul', wraps=mock_npu_weight_quant_batchmatmul)
    @patch('torch_npu.npu_convert_weight_to_int4pack', wraps=mock_npu_convert_weight_to_int4pack)
    @patch(
        'amct_pytorch.classic.deploy_op.weight_npu_quant_module.check_parameters_in_schema',
        MagicMock(return_value=True),
    )
    def test_int4_group_asym_minmax_success(self, mock_1, mock_2, mock_3, mock_4):
        cfg = self._minmax_cfg(
            self._quant_spec('int4', symmetric=False, strategy='group', group_size=32)
        )
        model = self._quantize_and_calibrate(cfg)
        self._convert_and_run(model)
        self.assertEqual(type(model.linear1).__name__, NPU_WEIGHT_QUANTIZED_LINEAR)

    # Activation + weight quantization: int8 activations with int8 weights.
    @patch('torch_npu.npu_quantize', wraps=mock_npu_quantize)
    @patch('torch_npu.npu_quant_matmul', wraps=mock_npu_quant_matmul)
    @patch('torch_npu.npu_weight_quant_batchmatmul', wraps=mock_npu_weight_quant_batchmatmul)
    @patch('torch_npu.npu_convert_weight_to_int4pack', wraps=mock_npu_convert_weight_to_int4pack)
    @patch(
        'amct_pytorch.classic.deploy_op.npu_quantization_linear.check_parameters_in_schema',
        MagicMock(return_value=True),
    )
    def test_int8_int8_tensor_sym_minmax_success(self, mock_1, mock_2, mock_3, mock_4):
        cfg = self._minmax_cfg(
            self._quant_spec('int8', symmetric=True, strategy='tensor'),
            self._quant_spec('int8', symmetric=True, strategy='tensor'),
        )
        model = self._quantize_and_calibrate(cfg)
        self.assertEqual(type(model.linear1).__name__, MINMAX_QUANT)
        self.assertEqual(type(model.linear2).__name__, MINMAX_QUANT)
        self.assertEqual(type(model.linear3).__name__, MINMAX_QUANT)
        self.assertEqual(model.linear1.scale_w.shape[0], 1)
        self._convert_and_run(model)
        self.assertEqual(type(model.linear1).__name__, NPU_QUANTIZATION_LINEAR)
        self.assertEqual(type(model.linear2).__name__, NPU_QUANTIZATION_LINEAR)
        self.assertEqual(type(model.linear3).__name__, NPU_QUANTIZATION_LINEAR)

    @patch(
        'amct_pytorch.classic.deploy_op.npu_quantization_linear.check_parameters_in_schema',
        MagicMock(return_value=True),
    )
    def test_int8_int8_tensor_asym_minmax_invalid(self):
        # Asymmetric weights are rejected when activations are also int8.
        cfg = self._minmax_cfg(
            self._quant_spec('int8', symmetric=False, strategy='channel'),
            self._quant_spec('int8', symmetric=False, strategy='tensor'),
        )
        model = copy.deepcopy(self.test_model).to(torch.bfloat16)
        with self.assertRaisesRegex(ValueError, 'int8 int8 only support symmetric weight quantization'):
            quantize(model, cfg)

    @patch('torch_npu.npu_quantize', wraps=mock_npu_quantize)
    @patch('torch_npu.npu_quant_matmul', wraps=mock_npu_quant_matmul)
    @patch('torch_npu.npu_weight_quant_batchmatmul', wraps=mock_npu_weight_quant_batchmatmul)
    @patch('torch_npu.npu_convert_weight_to_int4pack', wraps=mock_npu_convert_weight_to_int4pack)
    @patch(
        'amct_pytorch.classic.deploy_op.npu_quantization_linear.check_parameters_in_schema',
        MagicMock(return_value=True),
    )
    def test_int8_int8_tensor_asym_minmax_success(self, mock_1, mock_2, mock_3, mock_4):
        # Symmetric per-channel weights with asymmetric per-tensor activations.
        cfg = self._minmax_cfg(
            self._quant_spec('int8', symmetric=True, strategy='channel'),
            self._quant_spec('int8', symmetric=False, strategy='tensor'),
        )
        model = self._quantize_and_calibrate(cfg)
        self.assertEqual(type(model.linear1).__name__, MINMAX_QUANT)
        self.assertEqual(type(model.linear2).__name__, MINMAX_QUANT)
        self.assertEqual(type(model.linear3).__name__, MINMAX_QUANT)
        self.assertEqual(model.linear1.scale_w.shape[0], 64)
        self.assertEqual(model.linear1.scale_d.shape[0], 1)
        self._convert_and_run(model)
        self.assertEqual(type(model.linear1).__name__, NPU_QUANTIZATION_LINEAR)
        self.assertEqual(type(model.linear2).__name__, NPU_QUANTIZATION_LINEAR)
        self.assertEqual(type(model.linear3).__name__, NPU_QUANTIZATION_LINEAR)

    @patch('torch_npu.npu_quantize', wraps=mock_npu_quantize)
    @patch('torch_npu.npu_quant_matmul', wraps=mock_npu_quant_matmul)
    @patch('torch_npu.npu_weight_quant_batchmatmul', wraps=mock_npu_weight_quant_batchmatmul)
    @patch('torch_npu.npu_convert_weight_to_int4pack', wraps=mock_npu_convert_weight_to_int4pack)
    @patch(
        'amct_pytorch.classic.deploy_op.npu_quantization_linear.check_parameters_in_schema',
        MagicMock(return_value=True),
    )
    def test_int8_int8_channel_sym_minmax_success(self, mock_1, mock_2, mock_3, mock_4):
        cfg = self._minmax_cfg(
            self._quant_spec('int8', symmetric=True, strategy='channel'),
            self._quant_spec('int8', symmetric=True, strategy='tensor'),
        )
        model = self._quantize_and_calibrate(cfg)
        self._convert_and_run(model)
        self.assertEqual(type(model.linear1).__name__, NPU_QUANTIZATION_LINEAR)
        self.assertEqual(type(model.linear2).__name__, NPU_QUANTIZATION_LINEAR)
        self.assertEqual(type(model.linear3).__name__, NPU_QUANTIZATION_LINEAR)

    # Activation + weight quantization: int8 activations with int4 weights.
    # mock_1 is the npu_trans_quant_param patch (innermost decorator that passes a mock).
    @patch('torch_npu.npu_quantize', wraps=mock_npu_quantize)
    @patch('torch_npu.npu_quant_matmul', wraps=mock_npu_quant_matmul)
    @patch('torch_npu.npu_weight_quant_batchmatmul', wraps=mock_npu_weight_quant_batchmatmul)
    @patch('torch_npu.npu_convert_weight_to_int4pack', wraps=mock_npu_convert_weight_to_int4pack)
    @patch('torch_npu.npu_trans_quant_param', wraps=mock_npu_trans_quant_param)
    @patch(
        'amct_pytorch.classic.deploy_op.npu_quantization_linear.check_parameters_in_schema',
        MagicMock(return_value=True),
    )
    def test_int8_int4_tensor_tensor_sym_minmax_success(self, mock_1, mock_2, mock_3, mock_4, mock_5):
        self._run_a8w4_minmax_case('tensor', True)
        self.assertTrue(mock_1.called)

    @patch('torch_npu.npu_quantize', wraps=mock_npu_quantize)
    @patch('torch_npu.npu_quant_matmul', wraps=mock_npu_quant_matmul)
    @patch('torch_npu.npu_weight_quant_batchmatmul', wraps=mock_npu_weight_quant_batchmatmul)
    @patch('torch_npu.npu_convert_weight_to_int4pack', wraps=mock_npu_convert_weight_to_int4pack)
    @patch('torch_npu.npu_trans_quant_param', wraps=mock_npu_trans_quant_param)
    @patch(
        'amct_pytorch.classic.deploy_op.npu_quantization_linear.check_parameters_in_schema',
        MagicMock(return_value=True),
    )
    def test_int8_int4_tensor_channel_sym_minmax_success(self, mock_1, mock_2, mock_3, mock_4, mock_5):
        self._run_a8w4_minmax_case('channel', True)
        self.assertTrue(mock_1.called)

    @patch('torch_npu.npu_quantize', wraps=mock_npu_quantize)
    @patch('torch_npu.npu_quant_matmul', wraps=mock_npu_quant_matmul)
    @patch('torch_npu.npu_weight_quant_batchmatmul', wraps=mock_npu_weight_quant_batchmatmul)
    @patch('torch_npu.npu_convert_weight_to_int4pack', wraps=mock_npu_convert_weight_to_int4pack)
    @patch('torch_npu.npu_trans_quant_param', wraps=mock_npu_trans_quant_param)
    @patch(
        'amct_pytorch.classic.deploy_op.npu_quantization_linear.check_parameters_in_schema',
        MagicMock(return_value=True),
    )
    def test_int8_int4_asym_act_tensor_minmax_success(self, mock_1, mock_2, mock_3, mock_4, mock_5):
        self._run_a8w4_minmax_case('tensor', False)
        self.assertTrue(mock_1.called)

    @patch('torch_npu.npu_quantize', wraps=mock_npu_quantize)
    @patch('torch_npu.npu_quant_matmul', wraps=mock_npu_quant_matmul)
    @patch('torch_npu.npu_weight_quant_batchmatmul', wraps=mock_npu_weight_quant_batchmatmul)
    @patch('torch_npu.npu_convert_weight_to_int4pack', wraps=mock_npu_convert_weight_to_int4pack)
    @patch('torch_npu.npu_trans_quant_param', wraps=mock_npu_trans_quant_param)
    @patch(
        'amct_pytorch.classic.deploy_op.npu_quantization_linear.check_parameters_in_schema',
        MagicMock(return_value=True),
    )
    def test_int8_int4_asym_act_channel_minmax_success(self, mock_1, mock_2, mock_3, mock_4, mock_5):
        self._run_a8w4_minmax_case('channel', False)
        self.assertTrue(mock_1.called)

    # Weight-only quantization (no 'inputs' entry): float4_e2m1 weights.
    @patch('torch_npu.npu_quantize', wraps=mock_npu_quantize)
    @patch('torch_npu.npu_quant_matmul', wraps=mock_npu_quant_matmul)
    @patch('torch_npu.npu_weight_quant_batchmatmul', wraps=mock_npu_weight_quant_batchmatmul)
    @patch('torch_npu.npu_convert_weight_to_int4pack', wraps=mock_npu_convert_weight_to_int4pack)
    @patch('torch_npu.npu_format_cast', wraps=mock_npu_format_cast)
    @patch('torch_npu.npu_dtype_cast', wraps=mock_npu_dtype_cast)
    @patch('torch_npu.npu_dynamic_mx_quant', wraps=mock_npu_dynamic_mx_quant)
    @patch(
        'amct_pytorch.classic.deploy_op.weight_npu_quant_module.check_parameters_in_schema',
        MagicMock(return_value=True),
    )
    def test_fp4_group_sym_minmax_success(self, mock_1, mock_2, mock_3, mock_4, mock_5, mock_6, mock_7):
        cfg = self._minmax_cfg(
            self._quant_spec('float4_e2m1', symmetric=True, strategy='group', group_size=32)
        )
        model = self._quantize_and_calibrate(cfg)
        self.assertEqual(type(model.linear1).__name__, MINMAX_QUANT)
        self.assertEqual(type(model.linear2).__name__, LINEAR)
        self.assertEqual(type(model.linear3).__name__, LINEAR)
        self.assertIsNotNone(model.linear1.scale_w)
        self._convert_and_run(model)
        self.assertEqual(type(model.linear1).__name__, NPU_WEIGHT_QUANTIZED_LINEAR)
        self.assertEqual(type(model.linear2).__name__, LINEAR)
        self.assertEqual(type(model.linear3).__name__, LINEAR)

    # Activation + weight quantization: float8_e4m3fn activations with float4_e2m1 weights.
    @patch('torch_npu.npu_quantize', wraps=mock_npu_quantize)
    @patch('torch_npu.npu_quant_matmul', wraps=mock_npu_quant_matmul)
    @patch('torch_npu.npu_weight_quant_batchmatmul', wraps=mock_npu_weight_quant_batchmatmul)
    @patch('torch_npu.npu_convert_weight_to_int4pack', wraps=mock_npu_convert_weight_to_int4pack)
    @patch('torch_npu.npu_format_cast', wraps=mock_npu_format_cast)
    @patch('torch_npu.npu_dtype_cast', wraps=mock_npu_dtype_cast)
    @patch('torch_npu.npu_dynamic_mx_quant', wraps=mock_npu_dynamic_mx_quant)
    @patch('torch_npu.npu_trans_quant_param', wraps=mock_npu_trans_quant_param)
    @patch(
        'amct_pytorch.classic.deploy_op.npu_quantization_linear.check_parameters_in_schema',
        MagicMock(return_value=True),
    )
    def test_fp8_fp4_group_sym_minmax_success(
        self, mock_1, mock_2, mock_3, mock_4, mock_5, mock_6, mock_7, mock_8
    ):
        cfg = self._minmax_cfg(
            self._quant_spec('float4_e2m1', symmetric=True, strategy='group', group_size=32),
            self._quant_spec('float8_e4m3fn', symmetric=True, strategy='tensor'),
        )
        model = self._quantize_and_calibrate(cfg, model=TestModelBias().to(torch.bfloat16))
        self.assertEqual(type(model.linear1).__name__, LINEAR)
        self.assertEqual(type(model.linear2).__name__, MINMAX_QUANT)
        self.assertEqual(type(model.linear3).__name__, LINEAR)
        self.assertIsNotNone(model.linear2.scale_w1)
        self.assertIsNotNone(model.linear2.scale_d)
        self._convert_and_run(model)
        self.assertEqual(type(model.linear1).__name__, LINEAR)
        self.assertEqual(type(model.linear2).__name__, NPU_QUANTIZATION_LINEAR)
        self.assertEqual(type(model.linear3).__name__, LINEAR)

    def _run_a8w4_minmax_case(self, weight_strategy, act_symmetric):
        '''
        Run the int8-activation / int4-weight MinMax flow on a fresh float16 ``TestModel``.

        Args:
            weight_strategy: 'tensor' or 'channel' granularity for the symmetric int4 weights.
            act_symmetric: whether the per-tensor int8 activation quantization is symmetric;
                decides whether an activation offset (``offset_d``) is expected.
        '''
        inputs = self.inputs.to(torch.float16)
        cfg = self._minmax_cfg(
            self._quant_spec('int4', symmetric=True, strategy=weight_strategy),
            self._quant_spec('int8', symmetric=act_symmetric, strategy='tensor'),
        )
        model = self._quantize_and_calibrate(
            cfg, model=TestModel().to(torch.float16), inputs=inputs
        )
        self.assertEqual(type(model.linear1).__name__, MINMAX_QUANT)
        self.assertEqual(type(model.linear2).__name__, MINMAX_QUANT)
        self.assertEqual(type(model.linear3).__name__, LINEAR)
        self.assertIsNone(model.linear1.offset_w)
        self.assertIsNotNone(model.linear1.scale_d)
        if act_symmetric:
            self.assertIsNone(model.linear1.offset_d)
        else:
            self.assertIsNotNone(model.linear1.offset_d)
        self._convert_and_run(model, inputs)
        self.assertEqual(type(model.linear1).__name__, NPU_QUANTIZATION_LINEAR)
        self.assertEqual(type(model.linear2).__name__, NPU_QUANTIZATION_LINEAR)
        self.assertEqual(type(model.linear3).__name__, LINEAR)