"""test_quant_config
If needed, run the following code before importing transformers to point HuggingFace downloads at a mirror endpoint.
```
import os
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
```
"""
import os
import unittest
from enum import Enum
from pathlib import Path
from parameterized import parameterized
from tensor_cast.transformers.utils import AutoModelConfigLoader, init_on_device_without_buffers
from tensor_cast.utils import get_modules_to_not_convert, pattern_match
from transformers.initialization import no_init_weights
class ConfigMode(Enum):
    """Where the model config used by a test case comes from."""

    # The model name is joined onto the local model-config asset directory.
    local = 0
    # The model name is handed to the config loader unchanged.
    remote = 1
class QuantConfigTestCase(unittest.TestCase):
    """Pattern-matching checks for the modules a quant config leaves unconverted."""

    def setUp(self):
        """Record the directory that holds the local model config fixtures."""
        # parents[2] is the directory three levels above this file; the
        # fixtures sit in its assets/model_config subdirectory.
        base_dir = Path(__file__).resolve().parents[2]
        self.model_config_dir = str(base_dir / "assets" / "model_config")

    @parameterized.expand(
        [
            ["deepseekv3.1_remote", ConfigMode.local, [False, False, False]],
            ["moonshotai/Kimi-K2-Thinking", ConfigMode.remote, [True, True, True]],
            ["minimax_m2", ConfigMode.local, [True, True, False]],
        ]
    )
    def test_pattern_match(self, model_name_or_path, config_mode, match_result):
        """Compare pattern_match results for fixed module names with expectations.

        Args:
            model_name_or_path: Model identifier; for ``ConfigMode.local`` it
                is a name relative to ``self.model_config_dir``.
            config_mode: ``ConfigMode`` member selecting how the identifier
                is resolved.
            match_result: Expected ``pattern_match`` result for each probed
                module name, in the same order as the names below.
        """
        module_names = [
            "lm_head",
            "model.layers.0.mlp.gate_proj",
            "model.layers.60.mlp.shared_experts.down_proj",
        ]

        # Local configs are addressed relative to the fixture directory;
        # any other mode passes the identifier through untouched.
        if config_mode is ConfigMode.local:
            model_name_or_path = os.path.join(self.model_config_dir, model_name_or_path)

        with init_on_device_without_buffers("meta"), no_init_weights():
            loader = AutoModelConfigLoader()
            hf_config = loader.load_config(model_name_or_path)
            quant_config = loader.load_quant_config(hf_config)
            excluded_patterns = get_modules_to_not_convert(quant_config)

        observed = []
        for module_name in module_names:
            observed.append(pattern_match(module_name, excluded_patterns))

        self.assertListEqual(observed, match_result)