# Copyright Huawei Technologies Co., Ltd. 2025. All rights reserved.
import os
import json
import sys
import torch

current_directory = os.path.dirname(os.path.abspath(__file__))
parent_directory = os.path.abspath(os.path.join(current_directory, '..', ".."))
sys.path.append(parent_directory)
import torch.nn.functional as F
from example.common.security.path import get_valid_write_path, get_valid_read_path
from example.common.utils import SafeGenerator, ArgumentParser, StringArgumentValidator, MAX_KEY_LENGTH, \
    MAX_JSON_LENGTH, cmd_bool, parse_tokenizer_args
from msmodelslim.pytorch.llm_ptq.anti_outlier import AntiOutlier, AntiOutlierConfig
from msmodelslim.pytorch.llm_ptq.llm_ptq_tools import Calibrator, QuantConfig
from msmodelslim import logger


CPU = "cpu"
NPU = "npu"


def cmd_bool(cmd_arg):
    """Convert the command-line literals ``"True"`` / ``"False"`` into a bool.

    Used as the ``type=`` callable for the boolean options of this script.
    Matching is exact and case-sensitive.

    Args:
        cmd_arg: raw option value as received from the command line.

    Returns:
        ``True`` for ``"True"``, ``False`` for ``"False"``.

    Raises:
        ValueError: if ``cmd_arg`` is neither ``"True"`` nor ``"False"``.
    """
    if cmd_arg not in ("True", "False"):
        raise ValueError(f"{cmd_arg} should be True or False")
    return cmd_arg == "True"


def get_anti_dataset(tokenizer, anti_dict, device="npu"):
    """Tokenize the ``'prompt'`` entry of ``anti_dict`` into one ``input_ids`` tensor.

    Args:
        tokenizer: callable invoked as ``tokenizer(prompt, return_tensors='pt')``;
            its result must expose ``.data['input_ids']``.
        anti_dict: mapping holding the prompt under the key ``'prompt'``.
        device: device the token ids are moved to.

    Returns:
        The token-id tensors, right-padded with 0 to a common width along
        dimension 1 and concatenated with ``torch.cat``.
    """
    prompts = [anti_dict['prompt']]
    token_ids = [
        tokenizer(prompt, return_tensors='pt').data['input_ids'].to(device)
        for prompt in prompts
    ]
    # Widest sequence decides how much right-padding every entry receives.
    widest = max((ids.shape[1] for ids in token_ids), default=0)
    padded = [F.pad(ids, (0, widest - ids.shape[1]), value=0) for ids in token_ids]
    return torch.cat(padded)


def get_calib_dataset(tokenizer, calib_dict, device="npu"):
    """Tokenize the ``'prompt'`` entry of ``calib_dict`` into calibration samples.

    Args:
        tokenizer: callable invoked as ``tokenizer(prompt, return_tensors='pt')``;
            its result must support ``.to(device)`` and expose ``.data['input_ids']``.
        calib_dict: mapping holding the prompt under the key ``'prompt'``.
        device: device the tokenizer output is moved to.

    Returns:
        A list with one sample; each sample is a single-element list
        ``[input_ids]``.
    """
    prompts = [calib_dict['prompt']]
    return [
        [tokenizer(prompt, return_tensors='pt').to(device).data['input_ids']]
        for prompt in prompts
    ]



def parse_arguments():
    """Declare this script's command-line options and parse them.

    The default ``--calib_file`` and ``--anti_file`` paths point into the
    ``common`` directory located beside this script's own directory.

    Returns:
        The object returned by ``parser.parse_args()``.
    """
    common_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'common')
    parser = ArgumentParser()

    # Model location and output settings.
    parser.add_argument('--model_path', type=str, help="model and tokenizer path")
    parser.add_argument('--save_directory', type=str)
    parser.add_argument('--part_file_size', type=int, default=None)

    # Calibration and anti-outlier data sources.
    parser.add_argument('--calib_texts', type=str, nargs='+', default=None)
    parser.add_argument('--calib_file', type=str,
                        help='A jsonl file contains calibration data.',
                        default=os.path.join(common_dir, 'teacher_qualification.jsonl'))
    parser.add_argument('--anti_file', type=str,
                        default=os.path.join(common_dir, 'mix_dataset_glm.json'))

    # Quantization settings.
    parser.add_argument('--w_bit', type=int, default=8)
    parser.add_argument('--a_bit', type=int, default=8)
    parser.add_argument('--disable_names', type=str, nargs='+', default=None)
    parser.add_argument('--device_type', type=str, choices=[CPU, NPU], default=CPU)
    parser.add_argument('--fraction', type=float, default=0.01)
    parser.add_argument("--act_method", type=int, choices=[1, 2, 3], default=1,
                        help=" 1: MinMax, 2: Histogram, 3: Auto")
    parser.add_argument('--co_sparse', type=cmd_bool, default=False)
    parser.add_argument('--anti_method', type=str, default='')
    parser.add_argument('--disable_level', type=str, default='L0')
    parser.add_argument('--do_smooth', type=cmd_bool, default=False)
    parser.add_argument('--use_sigma', type=cmd_bool, default=False)
    parser.add_argument('--use_reduce_quant', type=cmd_bool, default=False)
    parser.add_argument('--tp_size', type=int, default=1)
    parser.add_argument('--sigma_factor', type=float, default=3.0)
    parser.add_argument('--is_lowbit', type=cmd_bool, default=False)
    parser.add_argument('--mm_tensor', type=cmd_bool, default=True)
    parser.add_argument('--w_sym', type=cmd_bool, default=True)
    parser.add_argument('--use_kvcache_quant', type=cmd_bool, default=False)
    parser.add_argument('--use_fa_quant', type=cmd_bool, default=False)
    parser.add_argument('--fa_amp', type=int, default=0)
    parser.add_argument('--open_outlier', type=cmd_bool, default=True)
    parser.add_argument('--group_size', type=int, default=64)
    parser.add_argument('--is_dynamic', type=cmd_bool, default=False)

    # Length-validated string options (tokenizer output keys and tokenizer kwargs).
    parser.add_argument(
        '--input_ids_name', type=str, default='input_ids',
        validator=StringArgumentValidator(min_length=1, max_length=MAX_KEY_LENGTH))
    parser.add_argument(
        '--attention_mask_name', type=str, default='attention_mask',
        validator=StringArgumentValidator(min_length=1, max_length=MAX_KEY_LENGTH))
    parser.add_argument(
        '--tokenizer_args', type=str, default='{"padding_side":"left","pad_token":"<unk>"}',
        validator=StringArgumentValidator(min_length=2, max_length=MAX_JSON_LENGTH))

    parser.add_argument('--disable_last_linear', type=cmd_bool, default=True)
    parser.add_argument(
        '--model_name', type=str, default=None,
        validator=StringArgumentValidator(min_length=1, max_length=MAX_KEY_LENGTH, allow_none=True))
    parser.add_argument('--trust_remote_code', type=cmd_bool, default=False)
    parser.add_argument('--mindie_format', action="store_true", help="Compatible with quantization formats \
                        supported by before 2.1.RC1 version of MindIE")

    return parser.parse_args()


class Quantifier:
    """Load a model and tokenizer, then run anti-outlier processing and quantization calibration."""

    def __init__(self, model_path_or_name, quant_config=None,
                 anti_outlier_config=None, device_type='cpu', trust_remote_code=False, **kwargs):
        """Load the config, model and tokenizer found at ``model_path_or_name``.

        Args:
            model_path_or_name: location handed to the ``SafeGenerator`` loaders.
            quant_config: configuration later passed to ``Calibrator``.
            anti_outlier_config: configuration later passed to ``AntiOutlier``;
                ``None`` skips the anti-outlier step in :meth:`convert`.
            device_type: ``'cpu'`` or ``'npu'``.
            trust_remote_code: forwarded to every ``SafeGenerator`` loader call.
            **kwargs: optional ``tokenizer_args`` (dict of extra keyword arguments
                for the tokenizer loader) and ``model_name``.
        """
        safe_generator = SafeGenerator()
        self.device_type = device_type
        # On CPU everything is placed on "cpu"; otherwise placement is left to "auto".
        device_map = CPU if self.device_type == CPU else "auto"
        self.trust_remote_code = trust_remote_code

        self.quant_config = quant_config
        self.anti_outlier_config = anti_outlier_config
        self.model_path_or_name = model_path_or_name
        self.config = safe_generator.get_config_from_pretrained(
            self.model_path_or_name,
            trust_remote_code=self.trust_remote_code
        )
        # NPU keeps the dtype declared in the model config; CPU loads in float32.
        self.dtype = self.config.torch_dtype if self.device_type == NPU else torch.float32
        self.model = safe_generator.get_model_from_pretrained(
            self.model_path_or_name,
            low_cpu_mem_usage=True,
            torch_dtype=self.dtype,
            device_map=device_map,
            trust_remote_code=self.trust_remote_code
        )
        tokenizer_args = kwargs.get("tokenizer_args", {})
        self.tokenizer = safe_generator.get_tokenizer_from_pretrained(
            self.model_path_or_name,
            use_fast=False,
            trust_remote_code=self.trust_remote_code,
            legacy=False,
            **tokenizer_args
        )
        self.model_name = kwargs.get("model_name", None)

    def get_tokenized_data(self, input_texts,
                           input_ids_name='input_ids',
                           attention_mask_name='attention_mask'):
        """Tokenize each text and collect its ids and attention mask.

        Args:
            input_texts: iterable of inputs, each passed to the tokenizer on its own.
            input_ids_name: key of the token ids in the tokenizer output.
            attention_mask_name: key of the attention mask in the tokenizer output.

        Returns:
            A list with one ``[input_ids, attention_mask]`` pair per input text,
            moved to ``self.device_type``.
        """
        tokenized_data = []
        for input_text in input_texts:
            inputs = self.tokenizer(input_text, return_tensors='pt', padding=True).to(self.device_type)
            tokenized_data.append(
                [inputs.data[input_ids_name], inputs.data[attention_mask_name]])
        return tokenized_data

    def convert(self, tokenized_calib_data, tokenized_anti_data, save_path, disable_level,
                part_file_size=None, mindie_format=None):
        """Run optional anti-outlier processing, calibrate, and save the quantized weights.

        Args:
            tokenized_calib_data: calibration data passed to ``Calibrator``.
            tokenized_anti_data: calibration data passed to ``AntiOutlier``; only
                used when an anti-outlier config was supplied at construction.
            save_path: destination passed to ``Calibrator.save``.
            disable_level: forwarded to ``Calibrator``.
            part_file_size: forwarded to ``Calibrator.save``.
            mindie_format: ``True`` saves as ``"safe_tensor"``, ``False`` as
                ``"ascendV1"``. ``None`` (default) keeps the script behaviour of
                reading ``mindie_format`` from the module-level ``args`` when it
                exists, and otherwise uses ``"ascendV1"``.
        """
        if self.device_type == NPU:
            # Avoid online (JIT) operator compilation; use precompiled binary operators.
            torch.npu.set_compile_mode(jit_compile=False)

        if self.anti_outlier_config is not None:
            anti_outlier = AntiOutlier(self.model, calib_data=tokenized_anti_data, cfg=self.anti_outlier_config)
            anti_outlier.process()

        calibrator = Calibrator(self.model,
                                self.quant_config,
                                calib_data=tokenized_calib_data,
                                disable_level=disable_level)
        calibrator.run()
        save_type = self._resolve_save_type(mindie_format)
        calibrator.save(save_path, save_type=[save_type], part_file_size=part_file_size)

    @staticmethod
    def _resolve_save_type(mindie_format=None):
        """Return the ``Calibrator.save`` format name for the given flag.

        When ``mindie_format`` is ``None`` the flag is taken from the module-level
        ``args`` namespace if the script defined one. This avoids a ``NameError``
        when the class is used from an importing module, where ``args`` does not exist.
        """
        if mindie_format is None:
            cli_args = globals().get("args")
            mindie_format = getattr(cli_args, "mindie_format", False)
        return "safe_tensor" if mindie_format else "ascendV1"


if __name__ == '__main__':
    # NOTE: ``args`` must stay a module-level name; ``Quantifier.convert`` looks it up
    # in module scope to pick the save format.
    args = parse_arguments()
    checker = SafeGenerator()

    # RANK is used as the device id; fall back to 0 when it is not a valid integer.
    try:
        rank: int = int(os.getenv("RANK", "0"))
    except ValueError as e:
        logger.warning(f"Error converting 'RANK' environment variable to integer: {e}")
        logger.info("Defaulting to 0.")
        rank = 0

    model_path = args.model_path
    save_directory = args.save_directory
    num_layers = checker.get_config_from_pretrained(
        model_path,
        trust_remote_code=args.trust_remote_code
    ).num_layers

    # When no layers were named explicitly and activations are 8-bit, exclude the
    # ``mlp.dense_4h_to_h`` projection of every layer from quantization.
    disable_names = args.disable_names
    if not disable_names and args.a_bit == 8:
        disable_names = [f"transformer.encoder.layers.{i}.mlp.dense_4h_to_h"
                         for i in range(num_layers)]

    # NOTE(review): ``--mm_tensor`` is parsed but not forwarded; ``mm_tensor`` is fixed
    # to False here — confirm this is intentional.
    quant_conf = QuantConfig(
        w_bit=args.w_bit,
        a_bit=args.a_bit,
        disable_names=disable_names,
        dev_type=args.device_type,
        dev_id=rank,
        act_method=args.act_method,
        w_sym=args.w_sym,
        mm_tensor=False,
        co_sparse=args.co_sparse,
        fraction=args.fraction,
        sigma_factor=args.sigma_factor,
        use_sigma=args.use_sigma,
        is_lowbit=args.is_lowbit,
        do_smooth=args.do_smooth,
        open_outlier=args.open_outlier,
        group_size=args.group_size,
        use_kvcache_quant=args.use_kvcache_quant,
        is_dynamic=args.is_dynamic,
        disable_last_linear=args.disable_last_linear,
    )

    if args.use_fa_quant:
        quant_conf = quant_conf.fa_quant(fa_amp=args.fa_amp)

    # 'm3' receives the full bit-width/symmetry/device-id settings; any other
    # non-empty method only gets the method name and device type.
    anti_outlier_config_val = None
    if args.anti_method == 'm3':
        anti_outlier_config_val = AntiOutlierConfig(a_bit=args.a_bit, w_bit=args.w_bit,
                                                    anti_method=args.anti_method, w_sym=args.w_sym,
                                                    dev_type=args.device_type, dev_id=rank)
    elif args.anti_method:
        anti_outlier_config_val = AntiOutlierConfig(anti_method=args.anti_method,
                                                    dev_type=args.device_type)
    tokenizer_args = parse_tokenizer_args(
        args.tokenizer_args,
        default={"padding_side": "left", "pad_token": "<unk>"}
    )
    quantifier = Quantifier(
        model_path, quant_conf, anti_outlier_config_val,
        device_type=args.device_type, tokenizer_args=tokenizer_args,
        model_name=args.model_name, trust_remote_code=args.trust_remote_code
    )
    anti_file_path = get_valid_read_path(args.anti_file, "json", is_dir=False)
    calib_file_path = get_valid_read_path(args.calib_file, "json", is_dir=False)

    # Decode the prompt files as UTF-8 explicitly so the result does not depend on
    # the platform's locale encoding.
    # NOTE(review): the default --calib_file ends in ".jsonl" but is read here as a
    # single JSON document — confirm the file holds one JSON array.
    with open(anti_file_path, 'r', encoding='utf-8') as f:
        anti_prompts = json.load(f)
    with open(calib_file_path, 'r', encoding='utf-8') as f:
        calib_prompts = json.load(f)

    # Each anti-outlier sample is wrapped in its own single-element list.
    anti_dataset = []
    for anti_prompt in anti_prompts:
        tokenizer_anti_prompt = get_anti_dataset(quantifier.tokenizer, anti_prompt, args.device_type)
        anti_dataset.append([tokenizer_anti_prompt])

    # get_calib_dataset already returns a list of samples, so flatten it in.
    calib_dataset = []
    for calib_prompt in calib_prompts:
        calib_dataset.extend(get_calib_dataset(quantifier.tokenizer, calib_prompt, args.device_type))

    if not os.path.exists(save_directory):
        os.makedirs(save_directory, mode=0o750, exist_ok=True)

    save_directory = get_valid_write_path(save_directory, is_dir=True)
    quantifier.convert(calib_dataset,
                       anti_dataset, save_directory,
                       args.disable_level,
                       part_file_size=args.part_file_size)

    # Name of the quantization type recorded in the saved config: w4a8 with
    # co_sparse/is_lowbit is reported as "w8a8s", dynamic w8a8 as "w8a8_dynamic".
    quant_type = f"w{args.w_bit}a{args.a_bit}"
    is_sparseCompress = args.w_bit == 4 and args.a_bit == 8 and (args.co_sparse or args.is_lowbit)
    if is_sparseCompress:
        quant_type = "w8a8s"
    is_w8a8_dynamic = args.w_bit == 8 and args.a_bit == 8 and args.is_dynamic
    if is_w8a8_dynamic:
        quant_type = "w8a8_dynamic"
    auto_config = checker.get_config_from_pretrained(model_path, trust_remote_code=args.trust_remote_code)
    checker.modify_config(model_path, save_directory, auto_config.torch_dtype,
                          quant_type, args)
    checker.copy_tokenizer_files(model_path, save_directory)