import logging
import os
import unittest
from amct_pytorch import convert, quantize
LLAMA2_7B_MODEL_PATH = "meta-llama/Llama-2-7b-hf"
RUN_SKIPPED = os.getenv('RUN_SKIPPED_TESTS', 'False').lower() == 'true'
logger = logging.getLogger(__name__)
class TestFlatQuant(unittest.TestCase):
    '''
    System-level test (ST) for the FlatQuant algorithm.

    The single test case loads a Llama-2 checkpoint, runs ``quantize`` and
    then ``convert`` on it, and checks the class names of the first decoder
    layer's attention and MLP modules after each step.
    '''

    @classmethod
    def setUpClass(cls):
        """Log the start of the test class."""
        logger.info('TestFlatQuant START!')

    @classmethod
    def tearDownClass(cls):
        """Log the end of the test class."""
        logger.info('TestFlatQuant END!')

    @unittest.skipIf(not RUN_SKIPPED, "Skip by default due to requiring the actual Llama model")
    def test_int4_tensor_sym_flatquant_success(self):
        """Quantize and convert Llama-2 with an int4 symmetric FlatQuant config.

        Skipped unless the ``RUN_SKIPPED_TESTS`` environment variable is set
        to ``true``, because it needs the real model weights.
        """
        cfg = {
            'batch_num': 4,
            'quant_cfg': {
                'inputs': {
                    'enable_quant': True,
                    'type': 'int4',
                    'symmetric': True,
                    'strategy': 'token',
                },
                'weights': {
                    'type': 'int4',
                    'symmetric': True,
                    'strategy': 'channel',
                },
            },
            'algorithm': {
                'flatquant': {
                    'use_kcache_quant': False,
                    'k_bits': 16,
                    'use_vcache_quant': False,
                    'v_bits': 16,
                    'use_o_quant': False,
                },
            },
            'skip_layers': {'lm_head'},
        }

        # Imported inside the test so that loading this module (with the test
        # skipped) does not execute the transformers import.
        import transformers

        config = transformers.LlamaConfig.from_pretrained(
            LLAMA2_7B_MODEL_PATH, attn_implementation='eager')
        # No auth token is supplied. The deprecated `use_auth_token` kwarg is
        # deliberately not passed (transformers replaces it with `token`, and
        # None is the default either way).
        model = transformers.LlamaForCausalLM.from_pretrained(
            LLAMA2_7B_MODEL_PATH,
            torch_dtype='auto',
            config=config,
            low_cpu_mem_usage=True,
        )
        model.seqlen = 2048
        logger.info('---> Loading %s Model with seq_len: %s',
                    LLAMA2_7B_MODEL_PATH, model.seqlen)

        first_layer = model.model.layers[0]

        # After quantize(), the same model object is expected to carry the
        # FlatQuant wrapper modules. Attributes are re-read from the layer
        # after each step so that replaced sub-modules are observed.
        quantize(model, cfg)
        self.assertEqual(type(first_layer.self_attn).__name__, 'FlatQuantAttention')
        self.assertEqual(type(first_layer.mlp).__name__, 'FlatQuantMLP')

        # After convert(), those modules are expected to be the Npu* variants.
        convert(model)
        first_layer = model.model.layers[0]
        self.assertEqual(type(first_layer.self_attn).__name__, 'NpuFlatQuantAttention')
        self.assertEqual(type(first_layer.mlp).__name__, 'NpuFlatQuantMLP')