#!/usr/bin/env python
"""
Merge LoRA safetensors weights with base HF model and save as HF format.
"""
import os
import sys
from pathlib import Path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import argparse
import json
import shutil
import torch
import torch_npu
from safetensors import safe_open
from safetensors.torch import save_file


def get_args():
    """Parse the command-line options of the LoRA merge script.

    Returns:
        argparse.Namespace: carries ``base_hf_dir``, ``lora_safetensors`` and
        ``save_merged_hf_dir`` (all required strings), ``lora_target_modules``
        (list of strings, or ``None`` when the flag is omitted),
        ``lora_alpha`` (int, default 16), ``lora_r`` (int, default 8) and
        ``device`` (``"cpu"`` or ``"npu"``, default ``"cpu"``).
    """
    # One (flag, add_argument keyword options) pair per option, in --help order.
    option_table = (
        ("--base_hf_dir", {
            "type": str,
            "required": True,
            "help": "Path to the base HF model directory",
        }),
        ("--lora_safetensors", {
            "type": str,
            "required": True,
            "help": "Path to the LoRA safetensors file",
        }),
        ("--save_merged_hf_dir", {
            "type": str,
            "required": True,
            "help": "Path to save the merged HF model",
        }),
        ("--lora_target_modules", {
            "type": str,
            "nargs": '+',
            "default": None,
            "help": "LoRA target modules (auto-detect if not specified)",
        }),
        ("--lora_alpha", {
            "type": int,
            "default": 16,
            "help": "The lora_alpha config value",
        }),
        ("--lora_r", {
            "type": int,
            "default": 8,
            "help": "The lora_r config value",
        }),
        ("--device", {
            "type": str,
            "default": "cpu",
            "choices": ["cpu", "npu"],
            "help": "Device to use for LoRA merging computation (default: cpu)",
        }),
    )

    cli = argparse.ArgumentParser(description="Merge LoRA safetensors weights with base HF model")
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()


def merge_lora_to_base(base_state_dict, lora_state_dict, target_layers, scaling, device="npu"):
    """Fold LoRA deltas into the matching base weights, in place.

    A layer ``name`` from ``target_layers`` is merged only when
    ``"<name>.weight"`` exists in ``base_state_dict`` and both
    ``"<name>.lora_A.default.weight"`` and ``"<name>.lora_B.default.weight"``
    exist in ``lora_state_dict``; any other layer is left untouched.  For a
    merged layer the stored weight becomes ``W + scaling * (B @ A)``.

    The arithmetic is done in float32 on ``device``; the result is moved back
    to the CPU and cast to the dtype the base weight had before the merge, so
    the checkpoint keeps a uniform dtype (previously every merged tensor was
    forced to bfloat16 regardless of the base model's dtype).

    Args:
        base_state_dict: mapping of parameter name -> tensor; modified in place.
        lora_state_dict: mapping of LoRA parameter name -> tensor.
        target_layers: iterable of layer name prefixes to merge.
        scaling: multiplier applied to ``B @ A`` (the caller passes
            ``lora_alpha / lora_r``).
        device: device on which the merge arithmetic runs.

    Returns:
        The same ``base_state_dict`` object that was passed in.
    """
    for target_layer in target_layers:
        base_key = f"{target_layer}.weight"
        if base_key not in base_state_dict:
            # This layer's base weight lives in another shard (or nowhere).
            continue

        lora_a = lora_state_dict.get(f"{target_layer}.lora_A.default.weight")
        lora_b = lora_state_dict.get(f"{target_layer}.lora_B.default.weight")
        if lora_a is None or lora_b is None:
            # An incomplete A/B pair cannot form a delta; keep the base weight.
            continue

        base_weight = base_state_dict[base_key]
        original_dtype = base_weight.dtype

        # Accumulate in float32 so low-precision storage dtypes do not lose
        # the (often small) LoRA update during the addition.
        # NOTE(review): assumes B @ A has the same shape as the base weight
        # (plain 2-D linear layers) - confirm for other layer types.
        delta = lora_b.to(device=device, dtype=torch.float32) @ lora_a.to(device=device, dtype=torch.float32)
        merged_weight = base_weight.to(device=device, dtype=torch.float32) + scaling * delta

        base_state_dict[base_key] = merged_weight.to(device="cpu", dtype=original_dtype)
    return base_state_dict


def main():
    """Merge a LoRA adapter into a base HF checkpoint and save the result.

    Steps:
      1. Find the base model's safetensors shards, either through
         ``model.safetensors.index.json`` or, when no index exists, a single
         ``model.safetensors`` file.
      2. Load every LoRA tensor to the CPU and collect target layer names from
         the ``lora_A`` keys, optionally filtered by ``--lora_target_modules``
         (substring match on the layer name).
      3. Copy every file/directory of the base directory whose name does not
         start with ``model.safetensors`` (config, tokenizer, ...) into the
         output directory, plus the shard index when there is one.
      4. Load each shard, merge the LoRA deltas into it and save it under the
         same file name in the output directory.

    A warning is written to stderr when no target layer was detected or when
    some target layers could not be merged into any shard, because in that
    case the saved model is (partly) identical to the base model.

    Raises:
        FileNotFoundError: if the base directory has neither a shard index
            nor a ``model.safetensors`` file.
        ZeroDivisionError: if ``--lora_r`` is 0.
    """
    args = get_args()

    base_hf_dir = Path(args.base_hf_dir)
    lora_safetensors = Path(args.lora_safetensors)
    save_merged_hf_dir = Path(args.save_merged_hf_dir)
    lora_target_modules = args.lora_target_modules
    device = args.device
    scaling = args.lora_alpha / args.lora_r

    print("=" * 60)
    print("Loading base HF model weights...")
    print("=" * 60)

    index_file = base_hf_dir / "model.safetensors.index.json"
    single_file = base_hf_dir / "model.safetensors"
    has_index = index_file.is_file()
    if has_index:
        with open(index_file, "r", encoding="utf-8") as f:
            weight_map = json.load(f)["weight_map"]
        # dict.fromkeys de-duplicates shard names while keeping first-seen order.
        shard_files = list(dict.fromkeys(weight_map.values()))
        total_weights = len(weight_map)
    elif single_file.is_file():
        # Unsharded checkpoint: treat the lone weight file as the only shard.
        shard_files = [single_file.name]
        with safe_open(single_file, framework="pt", device="cpu") as f:
            total_weights = len(list(f.keys()))
    else:
        raise FileNotFoundError(
            f"Neither {index_file.name} nor {single_file.name} found in {base_hf_dir}"
        )

    print(f"Total weights: {total_weights}")
    print(f"Shard files: {len(shard_files)}")

    print(f"\nLoading LoRA weights from: {lora_safetensors}")
    lora_state_dict = {}
    with safe_open(lora_safetensors, framework="pt", device="cpu") as f:
        for key in f.keys():
            lora_state_dict[key] = f.get_tensor(key)
    print(f"LoRA keys: {len(lora_state_dict)}")

    # A target layer is the part of a lora_A key before ".lora_"; when
    # --lora_target_modules is given, keep only layers containing one of them.
    target_layers = set()
    for name in lora_state_dict:
        if ".lora_A.default.weight" not in name:
            continue
        layer_name = name.split(".lora_")[0]
        if lora_target_modules is None or any(mod in layer_name for mod in lora_target_modules):
            target_layers.add(layer_name)
    print(f"LoRA target layers: {len(target_layers)}")
    if not target_layers:
        print(
            "WARNING: no LoRA target layers detected; the output will be a copy of the base model",
            file=sys.stderr,
        )

    save_merged_hf_dir.mkdir(parents=True, exist_ok=True)

    print("\nCopying config and tokenizer...")
    for item in base_hf_dir.iterdir():
        # Weight shards and the index are written separately below.
        if item.name.startswith("model.safetensors"):
            continue
        if item.is_file():
            shutil.copy2(item, save_merged_hf_dir / item.name)
        elif item.is_dir():
            shutil.copytree(item, save_merged_hf_dir / item.name, dirs_exist_ok=True)

    if has_index:
        # Shard file names are preserved, so the index can be reused verbatim.
        shutil.copy2(index_file, save_merged_hf_dir / "model.safetensors.index.json")

    print("\nMerging and saving shards...")
    merged_layers = set()
    for shard_file in shard_files:
        print(f"  Processing {shard_file}...")

        shard_path = base_hf_dir / shard_file
        shard_state_dict = {}
        with safe_open(shard_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                shard_state_dict[key] = f.get_tensor(key)

        # Layers that will really be merged in this shard: base weight present
        # here and both LoRA factors available (the condition applied by
        # merge_lora_to_base), so the reported count is accurate.
        mergeable = {
            layer for layer in target_layers
            if f"{layer}.weight" in shard_state_dict
            and f"{layer}.lora_A.default.weight" in lora_state_dict
            and f"{layer}.lora_B.default.weight" in lora_state_dict
        }

        shard_state_dict = merge_lora_to_base(shard_state_dict, lora_state_dict, target_layers, scaling, device)
        merged_layers |= mergeable

        if mergeable:
            print(f"    Merged {len(mergeable)} layers")

        save_file(shard_state_dict, save_merged_hf_dir / shard_file, metadata={"format": "pt"})
        # Drop the shard before loading the next one to keep peak memory low.
        del shard_state_dict
        if device == "npu":
            torch.npu.empty_cache()

    unmerged_layers = target_layers - merged_layers
    if unmerged_layers:
        examples = ", ".join(sorted(unmerged_layers)[:5])
        print(
            f"WARNING: {len(unmerged_layers)} of {len(target_layers)} LoRA target layers were not merged "
            f"(no matching base weight or missing lora_B), e.g. {examples}",
            file=sys.stderr,
        )

    print(f"\nMerge complete! Saved to {save_merged_hf_dir}")


# Run the merge only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()