#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/examples/offline_inference/save_sharded_state.py
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Saves each worker's model state dict directly to a checkpoint, which enables a
fast load path for large tensor-parallel models where each worker only needs to
read its own shard rather than the entire checkpoint.
Sparse-Compress-Quantization state dict could also be saved via this script.

Example usage:

python save_sharded_state_310.py \
    --model /path/to/load \
    --tensor-parallel-size 8 \
    --output /path/to/save \
    --enable-compress \
    --compress-process-num 8 \
    --enforce-eager \
    --dtype float16 \
    --quantization ascend

Then, the model can be loaded with

llm = LLM(
    model="/path/to/save",
    load_format="sharded_state",
    tensor_parallel_size=8,
    quantization="ascend",
)
"""

import dataclasses
import json
import multiprocessing as mp
import os
import shutil
from pathlib import Path

import torch
from vllm import LLM, EngineArgs
from vllm.distributed.parallel_state import destroy_distributed_environment, destroy_model_parallel
from vllm.utils.argparse_utils import FlexibleArgumentParser

SUPPORTED_COMPRESS_QUANT_TYPE = ["W8A8S", "W16A16S"]
DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
QUANTIZATION_UPDATE_MAP = {"W8A8S": "W8A8SC", "W16A16S": "W16A16SC"}


class FileHandler:
    @staticmethod
    def validate_path(path: str, must_exist: bool = True, check_writable: bool = False) -> Path:
        """
        Comprehensive path validation.
        - Checks existence
        - Checks write permissions for the target or its parent
        """
        p = Path(path)
        if must_exist and not p.exists():
            raise FileNotFoundError(f"Error: Path '{path}' does not exist.")

        if check_writable:
            # Check the directory itself if it exists, otherwise check the parent
            target = p if p.exists() else p.parent
            if not os.access(target, os.W_OK):
                raise PermissionError(f"Permission Denied: No write access to '{target}'.")
        return p

    @staticmethod
    def safe_copy(src: Path, dst: Path):
        """Copies files or directories with permission handling."""
        try:
            if src.is_dir():
                # dirs_exist_ok=True prevents errors if the destination directory exists
                shutil.copytree(src, dst, dirs_exist_ok=True)
            else:
                # copy2 preserves metadata (timestamps, permissions)
                shutil.copy2(src, dst)
        except (PermissionError, OSError) as e:
            print(f"Warning: Failed to copy {src} due to: {e}")


def clean_up():
    """Clean up VLLM resources"""
    destroy_model_parallel()
    destroy_distributed_environment()
    torch.npu.empty_cache()


def parse_args():
    parser = FlexibleArgumentParser()
    EngineArgs.add_cli_args(parser)
    parser.add_argument("--output", "-o", required=True, type=str, help="path to output checkpoint")
    parser.add_argument(
        "--enable-compress",
        action="store_true",
    )
    parser.add_argument(
        "--compress-process-num",
        type=int,
        default=1,
    )
    return parser.parse_args()


def get_quant_description(json_file: str) -> dict:
    """
    Extract quantization description from JSON configuration file.

    Args:
        json_file: Path to the JSON configuration file

    Returns:
        dict: Quantization descriptor dictionary

    Raises:
        FileNotFoundError: If the JSON file does not exist
        RuntimeError: If JSON parsing fails or required keys are missing
    """
    config_path = Path(json_file)
    if not config_path.exists():
        raise FileNotFoundError(f"Model configuration file not found: {json_file}")
    try:
        with config_path.open("r", encoding="utf-8") as file:
            quant_desc = json.load(file)
    except json.JSONDecodeError as e:
        raise RuntimeError(f"Invalid JSON format in {json_file}: {e}")

    return quant_desc


def update_quant_description(ori_json_file: str, target_json_file: str) -> None:
    """
    Update quantization types in JSON configuration file based on update mapping.

    Args:
        ori_json_file: Path to the JSON configuration file
        target_json_file: Path to the JSON configuration file to be saved

    Raises:
        FileNotFoundError: If the JSON file does not exist
        RuntimeError: If JSON parsing fails or required keys are missing
    """
    config_path = Path(ori_json_file)
    try:
        with config_path.open("r", encoding="utf-8") as file:
            json_data = json.load(file)
    except (FileNotFoundError, json.JSONDecodeError) as e:
        raise RuntimeError(f"Failed to read configuration file {ori_json_file}: {e}")

    original_quant_type = json_data.get("model_quant_type")
    if not original_quant_type or original_quant_type not in QUANTIZATION_UPDATE_MAP:
        raise RuntimeError(
            f"Cannot update quantization type. "
            f"Original type '{original_quant_type}' not found or not supported for update in {ori_json_file}."
        )
    updated_quant_type = QUANTIZATION_UPDATE_MAP[original_quant_type]

    updated_config = {"model_quant_type": updated_quant_type, "version": "1.0.0"}

    for key, value in json_data.items():
        if key.endswith(".weight") and value == original_quant_type:
            updated_config[key] = updated_quant_type
        elif key not in ("model_quant_type", "version"):
            updated_config[key] = value

    try:
        new_file_path = Path(target_json_file)
        with new_file_path.open("w", encoding="utf-8") as file:
            json.dump(updated_config, file, indent=2, ensure_ascii=False)
        os.remove(ori_json_file)
    except OSError as e:
        raise RuntimeError(f"Failed to write updated configuration to {target_json_file}: {e}")


def weight_compress_worker(file_path: str, quant_desc: dict, process_num: int) -> bool:
    """
    Worker logic for multiprocessing.
    Note: Imports are inside the worker to save memory in the main process.

    Returns:
        bool: True if processing succeeded, False otherwise.
    """
    import safetensors
    import safetensors.torch
    from msmodelslim.pytorch.weight_compression import CompressConfig, Compressor

    p = Path(file_path)
    if not p.exists():
        print(f"Error: File not found, failed to compress: {file_path}")
        return False

    try:
        state_dict = safetensors.torch.load_file(str(p))

        compress_config = CompressConfig(
            do_pseudo_sparse=False,
            sparse_ratio=1,
            is_debug=True,
            record_detail_root=str(p.parent),
            multiprocess_num=process_num,
        )
        compressor = Compressor(compress_config, weight=state_dict, quant_model_description=quant_desc)
        compressor.run()
        if p.exists():
            os.remove(p)
        compressor.export_safetensors(str(p.parent), safetensors_name=p.name)
        return True
    except Exception as e:
        print(f"Error processing Rank file {file_path}: {e}")
        return False


def main(args):
    # 1. Initial Validation
    # Validate early so the script doesn't fail after hours of inference
    output_dir = FileHandler.validate_path(args.output, must_exist=False, check_writable=True)
    model_dir = FileHandler.validate_path(args.model, must_exist=True)

    # 2. Run VLLM Engine and save sharded states
    engine_args = EngineArgs.from_cli_args(args)
    llm = LLM(**dataclasses.asdict(engine_args))

    output_dir.mkdir(parents=True, exist_ok=True)
    llm.llm_engine.engine_core.save_sharded_state(path=str(output_dir))

    del llm
    clean_up()

    # 3. Migrate Metadata (Excluding large weights)
    for item in model_dir.iterdir():
        if item.suffix not in (".bin", ".pt", ".safetensors"):
            FileHandler.safe_copy(item, output_dir / item.name)

    # 4. Compression Logic
    parameters_map_fpath = output_dir / "parameters_type_map.json"
    if args.enable_compress:
        quant_desc_file = output_dir / "quant_model_description.json"
        backup_quant_desc_file = output_dir / "ori_quant_model_description.json"
        if quant_desc_file.exists():
            os.rename(str(quant_desc_file), str(backup_quant_desc_file))
        quant_desc = get_quant_description(str(parameters_map_fpath))
        quant_type = quant_desc["model_quant_type"]
        if quant_type in SUPPORTED_COMPRESS_QUANT_TYPE:
            # TODO: Implement w16a16sc
            if quant_type == "W16A16S":
                raise NotImplementedError("W16A16SC is not supported yet.")

            tasks = []
            for i in range(args.tensor_parallel_size):
                file_name = DEFAULT_PATTERN.format(rank=i, part="0")
                full_path = output_dir / file_name

                p = mp.Process(
                    target=weight_compress_worker, args=(str(full_path), quant_desc, args.compress_process_num)
                )
                tasks.append(p)
                p.start()

            for p in tasks:
                p.join()

            update_quant_description(str(backup_quant_desc_file), str(quant_desc_file))
            print("Compression completed successfully.")
        else:
            print(f"Skipping compression: Unsupported type {quant_type}")
    if parameters_map_fpath.exists():
        os.remove(parameters_map_fpath)


if __name__ == "__main__":
    args = parse_args()
    main(args)