# Copyright Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.

import os
import sys
import argparse
import json
from typing import List

import numpy as np
import torch


from safetensors.torch import load_file, save_file

current_directory = os.path.dirname(os.path.abspath(__file__))
parent_directory = os.path.abspath(os.path.join(current_directory, '..'))
sys.path.append(parent_directory)

from example.common.security.path import get_valid_write_path, SafeWriteUmask, get_valid_read_path

TOOL_AWQ = 'awq'
TOOL_GPTQ = 'gptq'
QUANT_TYPE = ['W4A16', 'W8A16']
STORAGE_BITS = 32
ORDINAL_PACK_ORDER = [0, 1, 2, 3, 4, 5, 6, 7]
AWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]


def awq_pack(iweight: torch.Tensor, w_bit: int, direction: str = "column"):
    pack_num = STORAGE_BITS // w_bit
    shifts = torch.arange(0, STORAGE_BITS, w_bit, device=iweight.device)

    iweight = iweight.to(torch.int8)
    iweight = torch.bitwise_and(iweight, 0x0F)  # eventually correct overflow

    if direction == "column":
        iweight = iweight.view(-1, iweight.shape[1] // pack_num, pack_num)
        qmatrix = torch.bitwise_left_shift(iweight, shifts[None, None, :]).sum(dim=-1)

    elif direction == "row":
        iweight = iweight.view(iweight.shape[0] // pack_num, pack_num, -1)
        qmatrix = torch.bitwise_left_shift(iweight, shifts[None, :, None]).sum(dim=1)

    qmatrix = qmatrix.to(torch.int32)
    return qmatrix


def apply_order(
    iweight: torch.Tensor,
    w_bit: int,
    direction: str = "column",
    order: List[int] = None,
):
    pack_num = STORAGE_BITS // w_bit
    try:
        if direction == "column":
            iweight = iweight.view(-1, pack_num)[:, order].view(iweight.shape)
        elif direction == "row":
            iweight = iweight.view(pack_num, -1)[order, :].view(iweight.shape)
    except IndexError as ide:
        raise IndexError(f"Order index {order} out of range for pack_num {pack_num}. "
                           f"Order indices must be < {pack_num} for w_bit={w_bit}") from ide
    return iweight


def gptq_qweight_pack(iweight: torch.Tensor, w_bit: int):
    i = 0
    row = 0
    iweight = iweight.numpy().astype(np.uint32)
    if len(iweight.shape) < 2:
        raise ValueError("Expected qweight to have at least 2 dimensions, but got shape: {}".format(iweight.shape))
    qweight = np.zeros((iweight.shape[0] // STORAGE_BITS * w_bit, iweight.shape[1]), dtype=np.uint32)
    while row < qweight.shape[0]:
        if w_bit in (4, 8):
            for j in range(i, i + (32 // w_bit)):
                qweight[row] |= iweight[j] << (w_bit * (j - i))
            i += 32 // w_bit
            row += 1
    qweight = qweight.astype(np.int32)
    qweight = torch.from_numpy(qweight)
    return qweight


def gptq_qzeros_pack(zeros: torch.Tensor, w_bit: int):
    i = 0
    col = 0
    if zeros.dtype == torch.bfloat16:
        zeros = zeros.to(torch.float16)
    
    #AutoGPTQ zeros -= 1, or it may breaks exllama kernels 
    zeros -= 1    
    zeros = zeros.numpy().astype(np.uint32)
    if len(zeros.shape) < 2:
        raise ValueError("Expected zeros to have at least 2 dimensions, but got shape: {}".format(zeros.shape))
    qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // STORAGE_BITS * w_bit), dtype=np.uint32)
    while col < qzeros.shape[1]:
        if w_bit in (4, 8):
            for j in range(i, i + (32 // w_bit)):
                qzeros[:, col] |= zeros[:, j] << (w_bit * (j - i))
            i += 32 // w_bit
            col += 1
            
    qzeros = qzeros.astype(np.int32)
    qzeros = torch.from_numpy(qzeros)
    return qzeros  


def convert_ms_to_vllm(target_tool, w_bit, weight_dict, json_dict):

    vllm_weight_dict = {}
    for name, quant_type in json_dict.items():
        if name in weight_dict.keys():
            tensor = weight_dict[name]
            if quant_type in QUANT_TYPE:
                order = AWQ_PACK_ORDER
                direction = 'column'
                if name.endswith('.weight') and '.module.weight' not in name:
                    base_name = name.rsplit('.', 1)[0]
                    tmp_key = base_name + '.module.weight'
                    if tmp_key in weight_dict.keys():
                        continue
                    vllm_name = base_name + '.qweight'
                    tensor = tensor.t().contiguous()
                    tensor = torch.clamp(tensor, -2**(w_bit - 1), 2**(w_bit - 1) - 1)
                    if w_bit == 8:
                        tensor = tensor.to(torch.int32)
                    tensor.add_(2 ** (w_bit - 1))

                    if target_tool == TOOL_AWQ:
                        iweights = apply_order(tensor, w_bit, direction, order)
                        qweight = awq_pack(iweights, w_bit)
                    elif target_tool == TOOL_GPTQ:
                        qweight = gptq_qweight_pack(tensor, w_bit)
                    vllm_weight_dict[vllm_name] = qweight

                elif name.endswith('.weight_scale'):
                    vllm_name = name.rsplit('.', 1)[0] + '.scales'
                    tensor = tensor.t().contiguous()
                    vllm_weight_dict[vllm_name] = tensor

                elif name.endswith('.weight_offset'):
                    vllm_name = name.rsplit('.', 1)[0] + '.qzeros'
                    tensor = tensor.t().contiguous()
                    tensor = torch.clamp(tensor, -2**(w_bit - 1), 2**(w_bit - 1) - 1)
                    if w_bit == 8:
                        tensor = tensor.to(torch.int32)
                    tensor.add_(2 ** (w_bit - 1))

                    if target_tool == TOOL_AWQ:
                        izeros = apply_order(tensor, w_bit, direction, order)
                        qzeros = awq_pack(izeros, w_bit, direction)
                    elif target_tool == TOOL_GPTQ:
                        qzeros = gptq_qzeros_pack(tensor, w_bit)
                    vllm_weight_dict[vllm_name] = qzeros

                elif 'module.weight' in name and 'model.norm' not in name:
                    vllm_name = name.replace('module.weight', 'weight')
                    vllm_weight_dict[vllm_name] = tensor
                elif 'model.norm.module.bias' in name or 'model.norm.module.weight' in name:
                    pass
                else:
                    vllm_weight_dict[name] = tensor
            else:
                vllm_weight_dict[name] = tensor
    return vllm_weight_dict
    

def load_json_info(json_file_path):
    with open(json_file_path, 'r', encoding='utf-8') as f:
        quant_info = json.load(f)

    return quant_info


def check_w_bit(value):
    ivalue = int(value)
    if ivalue not in (4, 8):
        raise argparse.ArgumentTypeError(f"Invalid w_bit value: {value}. Supported values are 4 and 8.")
    return ivalue


def check_target_tool(value):
    if value not in ("awq", "gptq"):
        raise argparse.ArgumentTypeError(f"Invalid target_tool value: {value}. Supported values are 'awq' and 'gptq'.")
    return value


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default=None, help="Quantied safetensors file path")
    parser.add_argument("--json", type=str, default=None, help="Quantied description file path")
    parser.add_argument("--save_path", type=str, 
                        default='res.safetensors', 
                        help="The path to save converted quant weights")
    parser.add_argument("--w_bit", type=check_w_bit, default=4, help="Quantied weight bits")
    parser.add_argument(
        "--target_tool", type=check_target_tool, default="awq", help="target tool, value include awq and gptq"
        )
    args = parser.parse_args()

    save_path = args.save_path
    w_bit = args.w_bit
    quant_tool = args.target_tool
    
    model_path = get_valid_read_path(args.model, size_max=0)
    tensor_info = load_file(model_path)

    json_path = get_valid_read_path(args.json)
    json_info = load_json_info(json_path)

    vllm_weight = convert_ms_to_vllm(quant_tool, w_bit, weight_dict=tensor_info, json_dict=json_info)

    save_path = get_valid_write_path(save_path)
    with SafeWriteUmask(umask=0o377):
        save_file(vllm_weight, save_path)