#!/usr/bin/env python
# -*- coding: UTF-8 -*-

"""
-------------------------------------------------------------------------
This file is part of the MindStudio project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.

MindStudio is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:

         http://license.coscl.org.cn/MulanPSL2

THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""
import os
import json
import sys

import torch

current_directory = os.path.dirname(os.path.abspath(__file__))
parent_directory = os.path.abspath(os.path.join(current_directory, '..', ".."))
sys.path.append(parent_directory)

from example.common.security.path import get_valid_write_path, get_valid_read_path
from example.common.utils import SafeGenerator, ArgumentParser, StringArgumentValidator, MAX_KEY_LENGTH, \
    MAX_JSON_LENGTH, cmd_bool, parse_tokenizer_args
from msmodelslim.pytorch.llm_ptq.anti_outlier import AntiOutlier, AntiOutlierConfig
from msmodelslim.pytorch.llm_ptq.llm_ptq_tools import Calibrator, QuantConfig
from msmodelslim import logger

CPU = "cpu"
NPU = "npu"


def cmd_bool(cmd_arg):
    """Convert the command-line literal ``"True"`` / ``"False"`` into a bool.

    Used as an argparse ``type=`` callable by ``parse_arguments``. This local
    definition shadows the ``cmd_bool`` imported from ``example.common.utils``.

    Args:
        cmd_arg: Raw value taken from the command line.

    Returns:
        ``True`` when ``cmd_arg`` equals ``"True"``, ``False`` when it equals
        ``"False"``.

    Raises:
        ValueError: For any other value (the comparison is case-sensitive).
    """
    for literal, parsed in (("True", True), ("False", False)):
        if cmd_arg == literal:
            return parsed
    raise ValueError(f"{cmd_arg} should be True or False")


def parse_arguments():
    """Declare the command-line options of this script and parse them.

    Boolean switches declared with ``type=cmd_bool`` expect the literal text
    ``True`` or ``False`` on the command line; ``--mindie_format`` is a plain
    ``store_true`` flag.

    Returns:
        The object returned by ``parser.parse_args()``; every option below is
        available on it as an attribute named after the option.
    """
    # Anchor the default calibration file on the absolute location of this
    # script. Before Python 3.9 the ``__file__`` of a script could be a
    # relative path (e.g. "quant.py"), in which case two dirname() calls
    # collapse to "" and the default would silently point at
    # "common/boolq.jsonl" under the current working directory.
    script_directory = os.path.dirname(os.path.abspath(__file__))
    default_calib_file = os.path.join(os.path.dirname(script_directory), 'common', 'boolq.jsonl')

    parser = ArgumentParser()

    # Model location, output location and calibration data.
    parser.add_argument('--model_path', type=str, help="model and tokenizer path")
    parser.add_argument('--save_directory', type=str)
    parser.add_argument('--part_file_size', type=int, default=None)
    parser.add_argument(
        '--calib_texts',
        type=str,
        nargs='+',
        default=None)
    parser.add_argument(
        '--calib_file',
        type=str,
        help='A jsonl file contains calibration data.',
        default=default_calib_file)

    # Quantization settings (forwarded to QuantConfig / AntiOutlierConfig by the caller).
    parser.add_argument('--w_bit', type=int, default=8)
    parser.add_argument('--a_bit', type=int, default=8)
    parser.add_argument('--disable_names', type=str, nargs='+', default=None)
    parser.add_argument('--device_type', type=str, choices=[CPU, NPU], default=NPU)
    parser.add_argument('--fraction', type=float, default=0.01)
    parser.add_argument("--act_method", type=int, choices=[1, 2, 3], default=3,
                        help=" 1: MinMax, 2: Histogram, 3: Auto")
    parser.add_argument('--co_sparse', type=cmd_bool, default=False)
    parser.add_argument('--anti_method', type=str, default='')
    parser.add_argument('--disable_level', type=str, default='L0')
    parser.add_argument('--do_smooth', type=cmd_bool, default=False)
    parser.add_argument('--use_sigma', type=cmd_bool, default=False)
    parser.add_argument('--use_reduce_quant', type=cmd_bool, default=False)
    parser.add_argument('--tp_size', type=int, default=1)
    parser.add_argument('--sigma_factor', type=float, default=3.0)
    parser.add_argument('--is_lowbit', type=cmd_bool, default=False)
    parser.add_argument('--w_sym', type=cmd_bool, default=True)
    parser.add_argument('--use_kvcache_quant', type=cmd_bool, default=False)
    parser.add_argument('--use_fa_quant', type=cmd_bool, default=False)
    parser.add_argument('--fa_amp', type=int, default=0)
    parser.add_argument('--open_outlier', type=cmd_bool, default=True)
    parser.add_argument('--group_size', type=int, default=64)
    parser.add_argument('--is_dynamic', type=cmd_bool, default=False)

    # Tokenizer-related settings; string lengths are bounded by validators.
    parser.add_argument('--input_ids_name', type=str, default='input_ids',
                        validator=StringArgumentValidator(min_length=1, max_length=MAX_KEY_LENGTH))
    parser.add_argument('--attention_mask_name', type=str, default='attention_mask',
                        validator=StringArgumentValidator(min_length=1, max_length=MAX_KEY_LENGTH))
    parser.add_argument('--tokenizer_args', type=str, default='{"padding_side":"left","pad_token":"<unk>"}',
                        validator=StringArgumentValidator(min_length=2, max_length=MAX_JSON_LENGTH))
    parser.add_argument('--disable_last_linear', type=cmd_bool, default=True)
    parser.add_argument('--model_name', type=str, default=None,
                        validator=StringArgumentValidator(min_length=1, max_length=MAX_KEY_LENGTH, allow_none=True))
    parser.add_argument('--trust_remote_code', type=cmd_bool, default=False)
    # Adjacent string literals are used instead of a backslash continuation
    # inside the literal, which would embed the source indentation in the text.
    parser.add_argument('--mindie_format', action="store_true",
                        help="Compatible with quantization formats "
                             "supported by before 2.1.RC1 version of MindIE")
    return parser.parse_args()


class Quantifier:
    """Load a model and its tokenizer, then calibrate and save a quantized copy.

    The model, config and tokenizer are obtained through ``SafeGenerator``.
    ``convert`` optionally runs ``AntiOutlier`` first and then ``Calibrator``.
    """

    def __init__(self, model_path_or_name, quant_config=None,
                 anti_outlier_config=None, device_type='cpu', trust_remote_code=False, **kwargs):
        """Load config, model (in eval mode) and tokenizer.

        Args:
            model_path_or_name: Passed to the ``SafeGenerator`` loaders.
            quant_config: Configuration handed to ``Calibrator`` in ``convert``.
            anti_outlier_config: Configuration handed to ``AntiOutlier`` in
                ``convert``; ``None`` skips the anti-outlier step.
            device_type: ``CPU`` or ``NPU``. On CPU the model is loaded with
                ``torch.float32`` and ``device_map="cpu"``; otherwise the
                config's ``torch_dtype`` and ``device_map="auto"`` are used.
            trust_remote_code: Forwarded to every ``SafeGenerator`` loader.
            **kwargs: ``tokenizer_args`` (dict of extra tokenizer keyword
                arguments, default ``{}``) and ``model_name`` (default ``None``).
        """
        safe_generator = SafeGenerator()
        self.device_type = device_type
        device_map = CPU if self.device_type == CPU else "auto"
        self.trust_remote_code = trust_remote_code

        self.quant_config = quant_config
        self.anti_outlier_config = anti_outlier_config
        self.model_path_or_name = model_path_or_name
        self.config = safe_generator.get_config_from_pretrained(
            self.model_path_or_name,
            trust_remote_code=self.trust_remote_code
        )
        self.dtype = self.config.torch_dtype if self.device_type == NPU else torch.float32
        self.model = safe_generator.get_model_from_pretrained(
            self.model_path_or_name,
            low_cpu_mem_usage=True,
            torch_dtype=self.dtype,
            trust_remote_code=self.trust_remote_code,
            device_map=device_map,
        ).eval()

        tokenizer_args = kwargs.get("tokenizer_args", {})
        self.tokenizer = safe_generator.get_tokenizer_from_pretrained(
            self.model_path_or_name,
            use_fast=False,
            trust_remote_code=self.trust_remote_code,
            device_map=device_map,
            legacy=False,
            **tokenizer_args
        )

        # To make sure tokenizer.pad_token_id is set correctly, take the following steps.
        # NOTE(review): this unconditionally replaces any ``pad_token`` supplied
        # through ``tokenizer_args`` with the EOS token - confirm this is intended.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.model.config.pad_token_id = self.tokenizer.pad_token_id

        self.model_name = kwargs.get("model_name", None)

    def get_tokenized_data(self, input_texts,
                           input_ids_name='input_ids',
                           attention_mask_name='attention_mask'):
        """Tokenize each text separately and collect the calibration inputs.

        Args:
            input_texts: Iterable of texts; each one is tokenized on its own.
            input_ids_name: Key of the token ids in the tokenizer output.
            attention_mask_name: Key of the attention mask in the tokenizer output.

        Returns:
            A list with one ``[input_ids, attention_mask]`` pair per input text,
            taken from the tokenizer output after moving it to ``self.device_type``.
        """
        tokenized_data = []
        for input_text in input_texts:
            inputs = self.tokenizer(input_text, return_tensors='pt', padding=True).to(self.device_type)
            tokenized_data.append(
                [inputs.data[input_ids_name], inputs.data[attention_mask_name]])
        return tokenized_data

    def convert(self, tokenized_data, save_path, disable_level, part_file_size=None, mindie_format=None):
        """Run the optional anti-outlier pass, calibrate, and save the result.

        Args:
            tokenized_data: Calibration data handed to ``AntiOutlier`` and
                ``Calibrator`` as ``calib_data``.
            save_path: Destination passed to ``Calibrator.save``.
            disable_level: Forwarded to ``Calibrator``.
            part_file_size: Forwarded to ``Calibrator.save``.
            mindie_format: When truthy the weights are saved with save type
                ``"safe_tensor"``, otherwise ``"ascendV1"``. ``None`` (default)
                keeps the historical behaviour of reading ``mindie_format``
                from the module-level ``args`` namespace created when this file
                runs as a script, and falls back to ``False`` when no such
                namespace exists (e.g. when the class is imported).
        """
        if self.device_type == NPU:
            # Quantizing on NPU requires binary compilation to be enabled so that
            # operators are not compiled online.
            torch.npu.set_compile_mode(jit_compile=False)
            torch.npu.set_option({"NPU_FUZZY_COMPILE_BLACKLIST": "ReduceProd"})

        if self.anti_outlier_config is not None:
            if self.model_name == "baichuan":
                anti_outlier = AntiOutlier(self.model, calib_data=tokenized_data,
                                           cfg=self.anti_outlier_config, norm_class_name="RMSNorm")
            else:
                anti_outlier = AntiOutlier(self.model, calib_data=tokenized_data, cfg=self.anti_outlier_config)
            anti_outlier.process()

        calibrator = Calibrator(self.model, self.quant_config, calib_data=tokenized_data, disable_level=disable_level)
        calibrator.run()

        if mindie_format is None:
            # Previously this method read the global ``args`` directly, which
            # raised NameError (after the whole calibration run) whenever the
            # class was used outside the ``__main__`` script.
            mindie_format = getattr(globals().get("args"), "mindie_format", False)
        save_type = "safe_tensor" if mindie_format else "ascendV1"
        calibrator.save(save_path, save_type=[save_type], part_file_size=part_file_size)


if __name__ == '__main__':
    # Script entry point: parse options, build the quantization configs, run
    # calibration, then write the quantized weights plus an updated config and
    # the tokenizer files into the save directory.
    #
    # Keep the name `args` at module level: Quantifier.convert reads
    # `args.mindie_format` from it.
    args = parse_arguments()
    checker = SafeGenerator()

    model_path = args.model_path
    save_directory = args.save_directory
    if save_directory is None:
        # Fail fast: without this check the missing option only surfaces as a
        # TypeError from os.path.exists(None) after the model has been loaded
        # and the calibration data tokenized.
        raise ValueError("--save_directory is required")

    try:
        rank: int = int(os.getenv("RANK", "0"))
    except ValueError as e:
        logger.warning(f"Error converting 'RANK' environment variable to integer: {e}")
        logger.info("Defaulting to 0.")
        rank = 0

    # disable_names is forwarded exactly as given on the command line (None
    # when the option is absent); nothing is generated automatically here.
    disable_names = args.disable_names

    quant_conf = QuantConfig(
        w_bit=args.w_bit,
        a_bit=args.a_bit,
        disable_names=disable_names,
        dev_type=args.device_type,
        dev_id=rank,
        act_method=args.act_method,
        w_sym=args.w_sym,
        mm_tensor=False,
        co_sparse=args.co_sparse,
        fraction=args.fraction,
        sigma_factor=args.sigma_factor,
        use_sigma=args.use_sigma,
        is_lowbit=args.is_lowbit,
        do_smooth=args.do_smooth,
        open_outlier=args.open_outlier,
        group_size=args.group_size,
        use_kvcache_quant=args.use_kvcache_quant,
        is_dynamic=args.is_dynamic,
        disable_last_linear=args.disable_last_linear,
    )

    if args.use_fa_quant:
        quant_conf = quant_conf.fa_quant(fa_amp=args.fa_amp)

    # Anti-outlier is only configured when --anti_method is non-empty; method
    # 'm3' additionally receives the bit widths and weight symmetry.
    anti_outlier_config_val = None
    if args.anti_method == 'm3':
        anti_outlier_config_val = AntiOutlierConfig(a_bit=args.a_bit, w_bit=args.w_bit,
                                                    anti_method=args.anti_method, w_sym=args.w_sym,
                                                    dev_type=args.device_type, dev_id=rank)
    elif args.anti_method:
        anti_outlier_config_val = AntiOutlierConfig(anti_method=args.anti_method,
                                                    dev_type=args.device_type, dev_id=rank)

    tokenizer_args = parse_tokenizer_args(
        args.tokenizer_args,
        default={"padding_side": "left", "pad_token": "<unk>"}
    )
    quantifier = Quantifier(
        model_path, quant_conf, anti_outlier_config_val,
        device_type=args.device_type, tokenizer_args=tokenizer_args,
        model_name=args.model_name, trust_remote_code=args.trust_remote_code
    )

    # NOTE(review): --calib_file has a non-empty default, so --calib_texts is
    # only honoured when --calib_file is explicitly set to an empty string.
    # Confirm that this precedence is intended.
    tokenized_calib_data = None
    calib_file = args.calib_file
    calib_texts = checker.load_jsonl(calib_file) if calib_file else args.calib_texts
    if calib_texts is not None:
        tokenized_calib_data = quantifier.get_tokenized_data(
            calib_texts,
            input_ids_name=args.input_ids_name,
            attention_mask_name=args.attention_mask_name
        )

    if not os.path.exists(save_directory):
        os.makedirs(save_directory, exist_ok=True)

    # Validate the destination directory before anything is written to it.
    save_directory = get_valid_write_path(save_directory, is_dir=True)
    quantifier.convert(tokenized_calib_data, save_directory, args.disable_level, part_file_size=args.part_file_size)

    # Derive the quant_type label recorded in the saved config.
    quant_type = f"w{args.w_bit}a{args.a_bit}"
    # To accommodate the tool's sparse quantization being invoked with
    # w_bit=4, a_bit=8, quant_type is temporarily overridden for that case.
    is_sparse_compress = args.w_bit == 4 and args.a_bit == 8 and (args.co_sparse or args.is_lowbit)
    if is_sparse_compress:
        quant_type = "w8a8s"
    is_w8a8_dynamic = args.w_bit == 8 and args.a_bit == 8 and args.is_dynamic
    if is_w8a8_dynamic:
        quant_type = "w8a8_dynamic"

    auto_config = checker.get_config_from_pretrained(model_path, trust_remote_code=args.trust_remote_code)
    checker.modify_config(model_path, save_directory, auto_config.torch_dtype,
                          quant_type, args)
    checker.copy_tokenizer_files(model_path, save_directory)