#!/usr/bin/env python

# -*- coding: UTF-8 -*-



"""

-------------------------------------------------------------------------

This file is part of the MindStudio project.

Copyright (c) 2025 Huawei Technologies Co.,Ltd.



MindStudio is licensed under Mulan PSL v2.

You may use this software according to the terms and conditions of the Mulan PSL v2.

You may obtain a copy of Mulan PSL v2 at:



         http://license.coscl.org.cn/MulanPSL2



THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,

EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,

MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.

See the Mulan PSL v2 for more details.

-------------------------------------------------------------------------



Post-process AscendV1 quantized weights: convert deq_scale from float32/bf16 to int64

(bit-pattern of float32 stored as int64), so that checkpoints saved with bf16 default

can be used where int64 deq_scale is expected. Supports single-file and sharded layouts.

"""



import argparse

import json

import logging

import os

import shutil

import sys

import tempfile

from collections import defaultdict

from typing import Optional, Set



import numpy as np

import torch

from safetensors.torch import load_file, save_file



# Configure the root logger when this module is loaded: INFO level, "LEVEL: message" lines.
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

# Module-level logger used by the functions below.
logger = logging.getLogger(__name__)





def _deqscale2int64(scale: torch.Tensor) -> torch.Tensor:
    """
    Reinterpret the raw bytes of a deq_scale tensor as int32 values, widened to int64.

    The tensor is moved to CPU and serialized to bytes; those bytes are read
    back as int32, so a float32 input yields one integer per element holding
    that element's bit pattern. The integers are then widened to int64.

    Same semantics as AscendV1 saver (non-bf16 path). No msmodelslim dependency.

    Args:
        scale: deq_scale tensor; callers in this file always pass float32.

    Returns:
        A new int64 tensor. It is always 1-D because the values are rebuilt
        from a flat byte buffer.
    """
    raw_bytes = scale.cpu().numpy().tobytes()
    bit_patterns = np.frombuffer(raw_bytes, dtype=np.int32)
    return torch.tensor(bit_patterns.astype(np.int64))



# File name of the quantization description JSON looked up inside the model directory.
ASCENDV1_DESC_JSON_NAME = "quant_model_description.json"

# Weight file name used by the single-file layout.
ASCENDV1_SAFETENSORS_NAME = "quant_model_weights.safetensors"

# Index file name used by the sharded layout; its "weight_map" entry maps tensor keys to shard file names.
ASCENDV1_SAFETENSORS_INDEX_NAME = "quant_model_weights.safetensors.index.json"

# Description values for which a key ending in ".deq_scale" is selected for conversion.
DEQ_SCALE_QUANT_TYPES = ("W8A8", "W8A8_MIX")

# Extensions of non-weight files that copy_all_to_output copies to the output directory.
SUPPORTED_CONFIG_EXTENSIONS = (".json", ".py")

# Upper bound on directory entries in the source directory; copy_all_to_output raises ValueError beyond it.
MAX_CONFIG_FILES = 1024





def parse_args():
    """Build the command-line parser and parse the process arguments.

    Returns:
        argparse.Namespace with ``model_path`` (str, required),
        ``output_dir`` (str or None) and ``dry_run`` (bool, False unless the
        flag is given).
    """
    cli = argparse.ArgumentParser(
        description="Convert deq_scale in AscendV1 quant weights from float32/bf16 to int64."
    )
    # Each entry is (flag, keyword arguments for add_argument).
    option_specs = (
        (
            "--model_path",
            {
                "type": str,
                "required": True,
                "help": "Directory containing quant_model_description.json and safetensors weight(s).",
            },
        ),
        (
            "--output_dir",
            {
                "type": str,
                "default": None,
                "help": "Output directory. If not set, overwrites in-place (with temp + rename).",
            },
        ),
        (
            "--dry_run",
            {
                "action": "store_true",
                "help": "Only print which keys would be converted, do not write files.",
            },
        ),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()





def get_deq_scale_keys_from_description(model_path: str) -> Optional[Set[str]]:
    """Collect the deq_scale keys listed in quant_model_description.json.

    An entry is selected when its value is a string equal to one of
    DEQ_SCALE_QUANT_TYPES (W8A8/W8A8_MIX) and its key ends with ".deq_scale".

    Args:
        model_path: Directory expected to contain the description file.

    Returns:
        The set of selected keys, or None when the description file is
        missing, its top-level JSON value is not an object, or no entry
        matches.
    """
    description_file = os.path.join(model_path, ASCENDV1_DESC_JSON_NAME)
    if not os.path.isfile(description_file):
        return None

    with open(description_file, "r", encoding="utf-8") as handle:
        description = json.load(handle)
    if not isinstance(description, dict):
        return None

    selected = {
        name
        for name, quant_type in description.items()
        if isinstance(quant_type, str)
        and quant_type in DEQ_SCALE_QUANT_TYPES
        and name.endswith(".deq_scale")
    }
    # An empty result is reported as None so callers fall back to name-based inference.
    return selected or None





def is_deq_scale_key_candidate(key: str) -> bool:
    """Fallback name check: True when ``key`` contains ".deq_scale".

    The substring may appear anywhere in the key, not only as its suffix
    (a key ending in ".deq_scale" necessarily contains it).
    """
    return ".deq_scale" in key





def convert_tensor_if_needed(tensor: torch.Tensor, key: str, dry_run: bool, converted: list, skipped: list):
    """Convert a float32/bf16 deq_scale tensor to int64, recording the outcome.

    Args:
        tensor: Candidate deq_scale tensor.
        key: Tensor name, appended to ``converted`` or ``skipped``.
        dry_run: Accepted for signature compatibility; not consulted here, so
            the conversion is computed even in dry-run mode.
        converted: List that receives ``key`` when a new tensor is produced.
        skipped: List that receives ``key`` when the tensor is returned as-is
            (already int64, or any other dtype).

    Returns:
        A new int64 tensor for float32/bfloat16 input; otherwise the very
        same ``tensor`` object.
    """
    if tensor.dtype in (torch.float32, torch.bfloat16):
        # bf16 is upcast to float32 before its bits are taken; for float32
        # the .to() call leaves the tensor unchanged.
        as_int64 = _deqscale2int64(tensor.to(torch.float32))
        converted.append(key)
        return as_int64

    skipped.append(key)
    return tensor





def process_single_file(
    model_path: str,
    output_dir: str,
    deq_scale_keys: Optional[Set[str]],
    dry_run: bool,
    converted: list,
    skipped: list,
):
    """Convert deq_scale tensors stored in the single quant_model_weights.safetensors file.

    The whole file is loaded on CPU. A tensor is passed to
    ``convert_tensor_if_needed`` when its key is in ``deq_scale_keys`` (if
    given), its name passes ``is_deq_scale_key_candidate`` and its dtype is
    float32, bfloat16 or int64. The file is rewritten only when at least one
    tensor was replaced and ``dry_run`` is false.

    Args:
        model_path: Directory containing the weight file.
        output_dir: Directory to write to. When equal to ``model_path`` the
            file is rewritten through a temporary file in the same directory
            followed by ``os.replace``; otherwise it is saved directly under
            ``output_dir``, which is not created here.
        deq_scale_keys: Optional allow-list of keys; None disables this filter.
        dry_run: When true, nothing is written.
        converted: List extended in place with converted keys.
        skipped: List extended in place with keys left unchanged.

    Returns:
        False when the weight file does not exist, True otherwise.

    Note:
        When nothing is modified no file is written, even if ``output_dir``
        differs from ``model_path``.
    """
    src_file = os.path.join(model_path, ASCENDV1_SAFETENSORS_NAME)
    if not os.path.isfile(src_file):
        return False

    tensors = load_file(src_file, device="cpu")
    modified = False
    for key in list(tensors):
        if deq_scale_keys is not None and key not in deq_scale_keys:
            continue
        if not is_deq_scale_key_candidate(key):
            continue
        original = tensors[key]
        if original.dtype not in (torch.float32, torch.bfloat16, torch.int64):
            continue
        replacement = convert_tensor_if_needed(original, key, dry_run, converted, skipped)
        if replacement is not original:
            tensors[key] = replacement
            modified = True

    if dry_run or not modified:
        if dry_run and converted:
            logger.info("Dry run: would convert keys in single file: %s", src_file)
        return True

    out_file = os.path.join(output_dir, ASCENDV1_SAFETENSORS_NAME)
    if output_dir == model_path:
        # Write next to the target so os.replace stays on one filesystem.
        fd, tmp_path = tempfile.mkstemp(suffix=".safetensors", dir=output_dir)
        os.close(fd)
        try:
            save_file(tensors, tmp_path)
            os.replace(tmp_path, src_file)
        finally:
            # After a successful replace the temp path no longer exists. On any
            # failure, including KeyboardInterrupt/SystemExit, remove the
            # leftover so an interrupted run leaves no stray temp file.
            if os.path.isfile(tmp_path):
                os.remove(tmp_path)
    else:
        save_file(tensors, out_file)

    logger.info("Processed single file: %s -> %s", src_file, out_file)
    return True





def process_sharded(
    model_path: str,
    output_dir: str,
    deq_scale_keys: Optional[Set[str]],
    dry_run: bool,
    converted: list,
    skipped: list,
):
    """Convert deq_scale tensors in a sharded layout described by the index JSON.

    The index file's ``weight_map`` (tensor key -> shard file name) is
    grouped by shard. A shard is loaded only if the index lists at least one
    key for it that passes the ``deq_scale_keys`` filter (if given) and
    ``is_deq_scale_key_candidate``. Within a loaded shard, tensors with dtype
    float32, bfloat16 or int64 are passed to ``convert_tensor_if_needed``. A
    shard is rewritten only when at least one tensor was replaced and
    ``dry_run`` is false.

    Args:
        model_path: Directory containing the index file and the shards.
        output_dir: Directory to write to. When equal to ``model_path`` each
            modified shard is rewritten through a temporary file in the same
            directory followed by ``os.replace``; otherwise it is saved
            directly under ``output_dir``, which is not created here.
        deq_scale_keys: Optional allow-list of keys; None disables this filter.
        dry_run: When true, nothing is written.
        converted: List extended in place with converted keys.
        skipped: List extended in place with keys left unchanged.

    Returns:
        False when the index file does not exist. True otherwise, including
        when the index has no ``weight_map`` or a listed shard is missing
        (both are logged as warnings).
    """
    index_path = os.path.join(model_path, ASCENDV1_SAFETENSORS_INDEX_NAME)
    if not os.path.isfile(index_path):
        return False

    with open(index_path, "r", encoding="utf-8") as f:
        index_data = json.load(f)
    weight_map = index_data.get("weight_map")
    if not weight_map:
        logger.warning("Index file has no weight_map: %s", index_path)
        return True

    # Group tensor keys by the shard file the index says holds them.
    file_to_keys = defaultdict(list)
    for key, filename in weight_map.items():
        file_to_keys[filename].append(key)

    for filename, keys_in_file in file_to_keys.items():
        src_file = os.path.join(model_path, filename)
        if not os.path.isfile(src_file):
            logger.warning("Shard file not found: %s", src_file)
            continue

        # Decide from the key names alone whether this shard can contain
        # anything to convert, so irrelevant shards are never loaded.
        candidate_keys = [
            key
            for key in keys_in_file
            if (deq_scale_keys is None or key in deq_scale_keys)
            and is_deq_scale_key_candidate(key)
        ]
        if not candidate_keys:
            continue

        tensors = load_file(src_file, device="cpu")
        modified = False
        for key in candidate_keys:
            # The index may list a key the shard does not actually contain.
            if key not in tensors:
                continue
            original = tensors[key]
            if original.dtype not in (torch.float32, torch.bfloat16, torch.int64):
                continue
            replacement = convert_tensor_if_needed(original, key, dry_run, converted, skipped)
            if replacement is not original:
                tensors[key] = replacement
                modified = True

        if dry_run or not modified:
            continue

        out_file = os.path.join(output_dir, filename)
        if output_dir == model_path:
            # Write next to the target so os.replace stays on one filesystem.
            fd, tmp_path = tempfile.mkstemp(suffix=".safetensors", dir=output_dir)
            os.close(fd)
            try:
                save_file(tensors, tmp_path)
                os.replace(tmp_path, src_file)
            finally:
                # After a successful replace the temp path no longer exists. On
                # any failure, including KeyboardInterrupt/SystemExit, remove
                # the leftover so an interrupted run leaves no stray temp file.
                if os.path.isfile(tmp_path):
                    os.remove(tmp_path)
        else:
            save_file(tensors, out_file)
        logger.info("Processed shard: %s -> %s", src_file, out_file)

    if dry_run and converted:
        logger.info("Dry run: would convert keys in sharded files under: %s", model_path)
    return True





def copy_all_to_output(src_dir: str, dst_dir: str, weight_map: Optional[dict]):
    """Copy config files and weight files from ``src_dir`` into ``dst_dir``.

    Only regular files directly inside ``src_dir`` are considered. A file is
    copied when its name ends with ".safetensors" or its extension is in
    SUPPORTED_CONFIG_EXTENSIONS. When ``weight_map`` is given, ".safetensors"
    files whose name is not one of its values are skipped. Each copy is made
    with ``shutil.copy2`` and then set to mode 0o600.

    Args:
        src_dir: Source directory.
        dst_dir: Destination directory; created with mode 0o750 if it is not
            an existing directory. Nothing happens when it equals ``src_dir``.
        weight_map: Optional mapping of tensor key to shard file name; None
            means every ".safetensors" file is copied.

    Raises:
        ValueError: If ``src_dir`` holds more than MAX_CONFIG_FILES entries.
    """
    if src_dir == dst_dir:
        return
    if not os.path.isdir(dst_dir):
        os.makedirs(dst_dir, 0o750)

    names = os.listdir(src_dir)
    if len(names) > MAX_CONFIG_FILES:
        raise ValueError(f"Too many files in directory ({len(names)}), limit {MAX_CONFIG_FILES}.")

    # Build the shard-name lookup once instead of once per directory entry.
    shard_names = None if weight_map is None else set(weight_map.values())

    for name in names:
        src_path = os.path.join(src_dir, name)
        if not os.path.isfile(src_path):
            continue
        _, ext = os.path.splitext(name)
        is_safetensors = name.endswith(".safetensors")
        is_config = ext in SUPPORTED_CONFIG_EXTENSIONS
        if is_safetensors and shard_names is not None and name not in shard_names:
            continue
        if not is_config and not is_safetensors:
            continue
        dst_path = os.path.join(dst_dir, name)
        shutil.copy2(src_path, dst_path)
        os.chmod(dst_path, 0o600)

    logger.info("Copied config and weight files to: %s", dst_dir)





def main():
    """Command-line entry point: convert deq_scale tensors under ``--model_path``.

    Steps:
      1. Resolve ``model_path`` and ``output_dir`` (defaults to ``model_path``)
         to absolute paths.
      2. When writing to a different directory (and not in dry-run mode), copy
         the config and weight files there first, then convert the copies in
         place. Existing files of the same name in that directory are
         overwritten by the copy.
      3. Pick the deq_scale keys from the description file when available,
         otherwise fall back to key-name/dtype inference.
      4. Process the sharded layout if the index file exists, else the
         single-file layout.

    Exits with status 1 when ``model_path`` is not a directory, when
    ``output_dir`` exists but is not a directory, when neither weight layout
    is found, or when processing reports failure.
    """
    args = parse_args()
    model_path = os.path.abspath(os.path.expanduser(args.model_path))
    if not os.path.isdir(model_path):
        logger.error("model_path is not a directory: %s", model_path)
        sys.exit(1)
    output_dir = os.path.abspath(os.path.expanduser(args.output_dir or model_path))

    # Dry runs and in-place runs both read straight from model_path.
    work_dir = model_path
    if output_dir != model_path and not args.dry_run:
        if os.path.exists(output_dir) and not os.path.isdir(output_dir):
            logger.error("output_dir exists and is not a directory: %s", output_dir)
            sys.exit(1)
        # exist_ok: re-using an existing output directory must not crash.
        os.makedirs(output_dir, mode=0o750, exist_ok=True)

        # Copy the full layout first, then convert in place under output_dir.
        weight_map = None
        index_path = os.path.join(model_path, ASCENDV1_SAFETENSORS_INDEX_NAME)
        if os.path.isfile(index_path):
            with open(index_path, "r", encoding="utf-8") as f:
                weight_map = json.load(f).get("weight_map") or {}
        copy_all_to_output(model_path, output_dir, weight_map)
        work_dir = output_dir

    deq_scale_keys = get_deq_scale_keys_from_description(work_dir)
    if deq_scale_keys is not None:
        logger.info("Using deq_scale keys from description: %d keys", len(deq_scale_keys))
    else:
        logger.info("No description or no W8A8/W8A8_MIX deq_scale entries; will infer from key name and dtype.")

    converted = []
    skipped = []

    index_path = os.path.join(work_dir, ASCENDV1_SAFETENSORS_INDEX_NAME)
    single_path = os.path.join(work_dir, ASCENDV1_SAFETENSORS_NAME)
    ret = False
    if os.path.isfile(index_path):
        ret = process_sharded(work_dir, work_dir, deq_scale_keys, args.dry_run, converted, skipped)
    elif os.path.isfile(single_path):
        ret = process_single_file(work_dir, work_dir, deq_scale_keys, args.dry_run, converted, skipped)
    else:
        logger.error(
            "Neither %s nor %s found in %s",
            ASCENDV1_SAFETENSORS_NAME,
            ASCENDV1_SAFETENSORS_INDEX_NAME,
            work_dir,
        )
        sys.exit(1)
    if not ret:
        logger.error("Processing failed (file missing or invalid).")
        sys.exit(1)

    if args.dry_run:
        logger.info("Dry run: would convert %d keys, skip %d.", len(converted), len(skipped))
        # Cap the listing so large models do not flood the log.
        for k in converted[:20]:
            logger.info("  convert: %s", k)
        if len(converted) > 20:
            logger.info("  ... and %d more", len(converted) - 20)
        return

    logger.info("Converted %d deq_scale keys to int64, skipped %d.", len(converted), len(skipped))





# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":

    main()