import copy
import logging
import unittest
import sys
from unittest.mock import MagicMock
from unittest.mock import patch
import torch
import torch.nn as nn
from utils import TestModelLongcatFlashMLA
from mock_torch_npu import mock_npu_quantize, mock_npu_anti_quant, mock_npu
from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config
from transformers.cache_utils import DynamicCache
from amct_pytorch import quantize, convert
# Fix the global torch RNG seed at import time so the random tensors created
# below (e.g. the torch.randn input in setUpClass) are reproducible across runs.
torch.manual_seed(0)
class TestLongcatFlashMLA(unittest.TestCase):
'''
ST FOR KVCACHE
'''
@classmethod
def setUpClass(cls):
    """Create the shared model, input tensor and caches once for the class.

    The original model is run five times on the same ``hidden_states`` with
    ``kvcache_ori`` passed as ``past_key_values``; ``cls.ori_out`` keeps the
    output of the last run.
    """
    model_config = DeepseekV3Config()
    # Build the model before drawing the input: presumably both consume the
    # global RNG, so swapping them could change the generated values.
    reference_model = TestModelLongcatFlashMLA(model_config)
    cls.test_model = reference_model.to(torch.bfloat16)
    sample = torch.randn(1, 16, model_config.hidden_size)
    cls.hidden_states = sample.to(torch.bfloat16)
    # Separate cache objects: kvcache_ori is filled below, the other two are
    # left empty for the test methods to pass to their own forward calls.
    for cache_attr in ('kvcache_ori', 'kvcache_quant', 'kvcache'):
        setattr(cls, cache_attr, DynamicCache())
    forward_runs = 5
    for _ in range(forward_runs):
        cls.ori_out = cls.test_model(
            cls.hidden_states,
            past_key_values=cls.kvcache_ori,
        )
    logging.info('TestLongcatFlashMLA START!')
@classmethod
def tearDownClass(cls):
    """Log that every test of this class has finished."""
    end_message = 'TestLongcatFlashMLA END!'
    logging.info(end_message)
def setUp(self):
    """Install a ``MagicMock`` as the ``torch_npu`` module for one test.

    Whatever was registered in ``sys.modules`` under ``'torch_npu'`` before
    the test is remembered and put back by a cleanup callback. Cleanups run
    after ``tearDown``, so the stub neither replaces nor removes a
    previously imported ``torch_npu`` for the rest of the process.
    """
    # Sentinel that distinguishes "no module registered" from any real value.
    absent = object()
    previous = sys.modules.get('torch_npu', absent)
    sys.modules['torch_npu'] = MagicMock()

    def _restore():
        # tearDown may already have deleted the stub, hence the tolerant pop.
        if previous is absent:
            sys.modules.pop('torch_npu', None)
        else:
            sys.modules['torch_npu'] = previous

    self.addCleanup(_restore)
def tearDown(self):
    """Remove the ``torch_npu`` entry from ``sys.modules`` after each test."""
    # pop() without a default still raises KeyError when the entry is missing.
    sys.modules.pop('torch_npu')
@patch('torch_npu.npu_quantize', wraps=mock_npu_quantize)
@patch('torch_npu.npu_anti_quant', wraps=mock_npu_anti_quant)
def test_quantize_longcat_success(self, mock_1, mock_2):
cfg = {
'batch_num': 1,
'quant_cfg': {
'kvcache': {
'type': 'hifloat8',
'symmetric': True,
'strategy': 'tensor',
},
},
'algorithm': {'quantile'},
}
model = copy.deepcopy(self.test_model)
quantize(model, cfg)
model(self.hidden_states, past_key_values=self.kvcache_quant)
self.assertEqual(type(model.attn).__name__, 'LongcatFlashMLAQuant')
self.assertIsNotNone(model.attn.scale_k)
self.assertIsNotNone(model.attn.scale_v)
convert(model)
self.assertEqual(type(model.attn).__name__, 'NpuLongcatFlashMLA')
torch.Tensor.npu = mock_npu
for _ in range(5):
quant_out = model(self.hidden_states.npu(), past_key_values=self.kvcache)