# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
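Convert MindSpore checkpoints (a single ``.ckpt`` file or a directory of safetensors files) into HuggingFace safetensors shards plus a ``model.safetensors.index.json``.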
"""
import json
import os
from collections import defaultdict
from glob import glob
import warnings
from safetensors.torch import save_file
import numpy as np
import torch

import mindspore as ms
from mindspore.ops.operations import Cast
from mindformers.tools.utils import set_safe_mode_for_file_or_dir

# Run the MindSpore ops used by this script on the CPU.
ms.set_device(device_target='CPU')
# Cast primitive placed on the CPU; used throughout to upcast MindSpore tensors
# to float32 before exporting them with `.numpy()`.
cpu_cast = Cast().set_device('CPU')

# Short dtype names -> torch dtypes.
# NOTE(review): not referenced by the conversion functions shown here; presumably
# used by a command-line entry point to parse a dtype option — confirm.
dtype_map = {
    'fp32': torch.float32,
    'bf16': torch.bfloat16,
    'fp16': torch.float16
}

# Configuration used by `convert_ms_to_pt` when no config is passed.
# NOTE(review): the values look like a DeepSeek-V3-style MLA/MoE layout; confirm
# before reusing them for a different model.
default_config = {
    'num_routed_experts': 256,  # routed experts per MoE layer; extra router rows are dropped
    'n_head': 128,  # attention heads, used to reshape the MLA projections per head
    'qk_nope_head_dim': 128,  # per-head rows of the non-rotary ("nope") q/k projections
    'qk_rope_head_dim': 64,  # per-head rows of the rotary ("rope") q projection
    'v_head_dim': 128,  # per-head rows of the value projection
    'num_layers': 61,  # decoder layers model.layers.0 .. num_layers - 1
    'num_nextn_predict_layers': 1,  # MTP layers, exported as model.layers.num_layers ...
    'first_k_dense_replace': 3,  # layers with id < this value use a dense MLP instead of MoE
    'dtype': torch.bfloat16,  # dtype of every exported torch tensor
    'use_grouped_gemm': True,  # True: routed experts stored as fused ffn.w1 / ffn.w2 tensors
    'load_format': "safetensors"  # "safetensors" (directory input) or "ckpt" (single file)
}


def str2bool(b: str):
    """Convert the string ``"true"`` / ``"false"`` (case-insensitive) to a bool.

    Args:
        b: Text to interpret.

    Returns:
        ``True`` for ``"true"``, ``False`` for ``"false"`` (any letter case).

    Raises:
        ValueError: If ``b`` is neither ``"true"`` nor ``"false"``. ``ValueError``
            is used so that, when this function is an argparse ``type=`` callable,
            a bad value produces a normal usage error instead of a traceback.
            It subclasses ``Exception``, so existing ``except Exception`` handlers
            still catch it.
    """
    normalized = b.lower()
    if normalized == "true":
        return True
    if normalized == "false":
        return False
    raise ValueError(f"Invalid Bool Value: {b!r}")


def plain_name_replace(weight_name: str):
    """Rename the embedding and final-norm weights from MindSpore to HuggingFace.

    Any other name is returned unchanged.
    """
    renames = (
        ('tok_embeddings.embedding_weight', 'embed_tokens.weight'),
        ('model.norm_out.weight', 'model.norm.weight'),
    )
    renamed = weight_name
    for ms_fragment, hf_fragment in renames:
        renamed = renamed.replace(ms_fragment, hf_fragment)
    return renamed


def mla_name_replace(weight_name: str):
    """Rename MindSpore attention (MLA) and per-layer norm weights to HuggingFace names.

    Substitutions are applied in the listed order; names that match none of them
    are returned unchanged.
    """
    substitutions = (
        ('.attention.q_proj.', '.self_attn.q_proj.'),
        ('.attention.q2l_proj.', '.self_attn.q_a_proj.'),
        ('.attention.lq_norm.', '.self_attn.q_a_layernorm.'),
        ('.attention.l2q_proj.', '.self_attn.q_b_proj.'),
        ('.attention.kv2l.', '.self_attn.kv_a_proj_with_mqa.'),
        ('.attention.lkv_norm.', '.self_attn.kv_a_layernorm.'),
        ('.attention.lkv2kv.', '.self_attn.kv_b_proj.'),
        ('.attention.wo.', '.self_attn.o_proj.'),
        ('.attention_norm.', '.input_layernorm.'),
        ('.ffn_norm.', '.post_attention_layernorm.'),
    )
    result = weight_name
    for source, target in substitutions:
        result = result.replace(source, target)
    return result


def mlp_name_replace(weight_name: str, use_grouped_gemm: bool = True):
    """Rename MindSpore MLP / MoE weights to HuggingFace names.

    Args:
        weight_name: MindSpore weight name.
        use_grouped_gemm: Selects which MindSpore naming of the router weight and
            router correction bias is recognised (grouped-GEMM layout when True).

    Returns:
        The renamed weight name (unchanged if nothing matches).
    """
    rules = [
        ('feed_forward.w1.', 'mlp.gate_proj.'),
        ('feed_forward.w2.', 'mlp.down_proj.'),
        ('feed_forward.w3.', 'mlp.up_proj.'),
        ('feed_forward.shared_experts.w1.', 'mlp.shared_experts.gate_proj.'),
        ('feed_forward.shared_experts.w2.', 'mlp.shared_experts.down_proj.'),
        ('feed_forward.shared_experts.w3.', 'mlp.shared_experts.up_proj.'),
    ]

    if use_grouped_gemm:
        router_weight_name = 'feed_forward.routed_experts.router_dense.weight'
        router_bias_name = 'feed_forward.routed_experts.topk_bias'
    else:
        router_weight_name = 'feed_forward.routed_experts.router.dense.weight'
        router_bias_name = 'feed_forward.routed_experts.router.router.topk_bias'
    rules.append((router_weight_name, 'mlp.gate.weight'))
    rules.append((router_bias_name, 'mlp.gate.e_score_correction_bias'))

    for old_part, new_part in rules:
        weight_name = weight_name.replace(old_part, new_part)
    return weight_name


def mtp_name_replace(weight_name: str, current_layer_id: int, mtp_layer_id: int):
    """Rename MindSpore MTP (multi-token prediction) weights to HuggingFace names.

    Weights of MindSpore MTP block ``mtp_layer_id`` are moved under
    ``model.layers.<current_layer_id>``. ``norm_emb`` is rewritten before
    ``norm`` because the latter is a prefix of the former.
    """
    fuser_prefix = f"model.mtp_hidden_fusers.{mtp_layer_id}"
    layer_prefix = f"model.layers.{current_layer_id}"
    ordered_rules = (
        (f"{fuser_prefix}.norm_emb", f"{layer_prefix}.enorm"),
        (f"{fuser_prefix}.norm", f"{layer_prefix}.hnorm"),
        (f"{fuser_prefix}.dense", f"{layer_prefix}.eh_proj"),
        (f"model.mtp_norms.{mtp_layer_id}", f"{layer_prefix}.shared_head.norm"),
    )
    for before, after in ordered_rules:
        weight_name = weight_name.replace(before, after)
    return weight_name


def load_data_ms(file_name):
    """Load one MindSpore safetensors file with ``ms.load_checkpoint``.

    Args:
        file_name: Path of the ``.safetensors`` file.

    Returns:
        The mapping returned by ``ms.load_checkpoint``; callers in this file use
        it as a dict of weight name -> tensor (``.keys()`` / ``dict.update``).
    """
    return ms.load_checkpoint(file_name, format="safetensors")


def layers_model_file_map(file_path, config):
    """Map each layer to the set of safetensors files that hold its weights.

    The weight -> file index is read from ``param_name_map.json`` or, if that is
    absent, ``ms-model.safetensors.index.json`` in ``file_path``. Either file may
    wrap the mapping in a top-level ``"weight_map"`` key. Without any index file,
    the keys of the first ``*.safetensors`` file (sorted by name) are used.

    Keys of the returned ``defaultdict(set)``:

    * ``int`` layer id ``i`` for ``model.layers.<i>.*`` weights;
    * ``num_layers + j`` for the MTP weights ``model.mtp_hidden_fusers.<j>.*`` and
      ``model.mtp_norms.<j>.*``, so they are loaded together with MTP layer
      ``num_layers + j``;
    * the full weight name for every other weight (embedding, final norm, lm_head).

    Args:
        file_path: Directory containing the MindSpore safetensors files.
        config: Conversion options; only ``num_layers`` is read.

    Returns:
        ``defaultdict(set)`` from the keys above to absolute file paths.

    Raises:
        ValueError: If there is no index file and no ``*.safetensors`` file.
    """
    num_layers = config["num_layers"]
    layer_st_map = defaultdict(set)
    weight_map_file = os.path.join(file_path, "param_name_map.json")
    if not os.path.exists(weight_map_file):
        weight_map_file = os.path.join(file_path, "ms-model.safetensors.index.json")

    if os.path.exists(weight_map_file):
        with open(weight_map_file, encoding="utf-8") as f:
            weights_map = json.load(f)
        try:
            # HF-style index files nest the mapping under "weight_map".
            weights_map = weights_map["weight_map"]
        except KeyError:
            pass
    else:
        warnings.warn(f"Cannot find weight map file either param_name_map.json or "
                      f"ms-model.safetensors.index.json in path {file_path}, "
                      f"Trying to load one safetensor file ...")
        files = sorted(glob(os.path.join(file_path, "*.safetensors")))
        if not files:
            raise ValueError(f"No safetensors files found in path {file_path}")

        # basename instead of split("/") so non-POSIX path separators work too.
        weight_file = os.path.basename(files[0])
        keys = load_data_ms(os.path.join(file_path, weight_file)).keys()
        weights_map = {k: weight_file for k in keys}

    mtp_prefixes = ("model.mtp_hidden_fusers.", "model.mtp_norms.")
    for weight_key, value in weights_map.items():
        st_file = os.path.join(file_path, value)
        if weight_key.startswith("model.layers."):
            # "model.layers.<i>.<rest>" -> i
            layer_st_map[int(weight_key.split('.')[2])].add(st_file)
        elif weight_key.startswith(mtp_prefixes):
            # "model.mtp_hidden_fusers.<j>..." / "model.mtp_norms.<j>..." -> MTP layer.
            # mtp_norms must be grouped here too: _mtp_ms_to_pt pops it while
            # converting layer num_layers + j, so its file has to be loaded then.
            layer_st_map[num_layers + int(weight_key.split('.')[2])].add(st_file)
        else:
            layer_st_map[weight_key].add(st_file)
    return layer_st_map


def read_matched_file(layer_st_map, layer_list, is_first, is_last):
    """Load and merge every file that holds weights of the requested layers.

    Args:
        layer_st_map: Map produced by ``layers_model_file_map`` (layer id or
            weight name -> set of file paths).
        layer_list: Layer ids whose files should be loaded.
        is_first: Also load the file(s) holding the token embedding.
        is_last: Also load the file(s) holding the final norm and lm_head.

    Returns:
        Dict of MindSpore weight name -> tensor, merged over all loaded files.
        Each file is loaded at most once.
    """
    lookup_keys = list(layer_list)
    if is_first:
        lookup_keys.append("model.tok_embeddings.embedding_weight")
    if is_last:
        lookup_keys += ["model.norm_out.weight", "lm_head.weight"]

    files_to_load = {path for key in lookup_keys for path in layer_st_map[key]}

    merged_weights = {}
    for path in files_to_load:
        merged_weights.update(load_data_ms(path))
    return merged_weights


def _mla_ms_to_pt(layer_id, ms_layer_weights, config):
    """Processing weights in MLA module.

    Pops the MindSpore attention (MLA) weights and the two per-layer norms of
    layer ``layer_id`` out of ``ms_layer_weights`` (the dict is mutated). The
    split projections are re-fused into the HuggingFace layout, and the result
    is returned as HuggingFace weight name -> torch tensor of dtype
    ``config['dtype']``.

    Fusions performed:

    * ``l2q_nope_proj`` + ``l2q_pe_proj``  -> ``self_attn.q_b_proj``
    * ``kv2l_latent_kv`` + ``kv2l_k_pe``   -> ``self_attn.kv_a_proj_with_mqa``
    * ``lkv2kv_k_nope`` + ``lkv2kv_v``     -> ``self_attn.kv_b_proj``

    All other weights are only renamed (via ``mla_name_replace``) and cast.

    Args:
        layer_id: Index ``i`` of the ``model.layers.<i>`` entry to convert.
        ms_layer_weights: Mapping of MindSpore weight name -> tensor; consumed
            entries are removed from it.
        config: Reads ``n_head``, ``qk_nope_head_dim``, ``qk_rope_head_dim``,
            ``v_head_dim`` and ``dtype``.

    Returns:
        Dict of HuggingFace weight name -> torch tensor.

    Raises:
        KeyError: If any expected MindSpore key of this layer is missing.
    """
    n_head = config['n_head']
    qk_nope_head_dim = config['qk_nope_head_dim']
    qk_rope_head_dim = config['qk_rope_head_dim']
    v_head_dim = config['v_head_dim']
    dtype = config['dtype']

    # Split MindSpore projections that are re-fused into single HF weights below.
    qk_nope_key = f"model.layers.{layer_id}.attention.l2q_nope_proj.weight"
    qk_rope_key = f"model.layers.{layer_id}.attention.l2q_pe_proj.weight"
    latent_kv_key = f"model.layers.{layer_id}.attention.kv2l_latent_kv.weight"
    k_rope_key = f"model.layers.{layer_id}.attention.kv2l_k_pe.weight"
    k_nope_key = f"model.layers.{layer_id}.attention.lkv2kv_k_nope.weight"
    v_key = f"model.layers.{layer_id}.attention.lkv2kv_v.weight"

    # Names of the HF-side weights. kv_a_proj_key, q_b_proj_key and kv_b_proj_key
    # are never popped; they only serve as templates for the fused output names.
    q_a_proj_key = f"model.layers.{layer_id}.attention.q2l_proj.weight"
    kv_a_proj_key = f"model.layers.{layer_id}.attention.kv2l.weight"
    o_proj_key = f"model.layers.{layer_id}.attention.wo.weight"
    q_a_layernorm_key = f"model.layers.{layer_id}.attention.lq_norm.weight"
    kv_a_layernorm_key = f"model.layers.{layer_id}.attention.lkv_norm.weight"
    q_b_proj_key = f"model.layers.{layer_id}.attention.l2q_proj.weight"
    kv_b_proj_key = f"model.layers.{layer_id}.attention.lkv2kv.weight"
    input_norm_key = f"model.layers.{layer_id}.attention_norm.weight"
    post_attn_norm_key = f"model.layers.{layer_id}.ffn_norm.weight"

    # Pop each weight (removing it from ms_layer_weights) and upcast it to float32
    # on the CPU before exporting to numpy.
    # NOTE(review): the float32 detour presumably exists because numpy has no
    # bfloat16 dtype — confirm.
    qk_nope = cpu_cast(ms_layer_weights.pop(qk_nope_key), ms.float32).numpy()
    qk_rope = cpu_cast(ms_layer_weights.pop(qk_rope_key), ms.float32).numpy()
    latent_kv = cpu_cast(ms_layer_weights.pop(latent_kv_key), ms.float32).numpy()
    k_rope = cpu_cast(ms_layer_weights.pop(k_rope_key), ms.float32).numpy()
    k_nope = cpu_cast(ms_layer_weights.pop(k_nope_key), ms.float32).numpy()
    v = cpu_cast(ms_layer_weights.pop(v_key), ms.float32).numpy()

    q_a_proj = cpu_cast(ms_layer_weights.pop(q_a_proj_key), ms.float32).numpy()
    o_proj = cpu_cast(ms_layer_weights.pop(o_proj_key), ms.float32).numpy()
    q_a_layernorm = cpu_cast(ms_layer_weights.pop(q_a_layernorm_key), ms.float32).numpy()
    kv_a_layernorm = cpu_cast(ms_layer_weights.pop(kv_a_layernorm_key), ms.float32).numpy()
    input_norm = cpu_cast(ms_layer_weights.pop(input_norm_key), ms.float32).numpy()
    post_attn_norm = cpu_cast(ms_layer_weights.pop(post_attn_norm_key), ms.float32).numpy()


    # defaultdict() without a default_factory; it behaves like a plain dict.
    mla_weight_dict = defaultdict()
    # merge qk_nope, qk_rope into q_b_proj
    # Within each head the rotary rows are stored as 2 blocks of
    # qk_rope_head_dim // 2 rows; permuting (2, half) -> (half, 2) interleaves them
    # back row by row.
    # NOTE(review): presumably undoes a de-interleave applied when converting
    # HF -> MindSpore — confirm.
    qk_rope = torch.from_numpy(qk_rope).to(dtype).reshape(n_head, 2, qk_rope_head_dim // 2, -1)
    qk_rope = qk_rope.permute(0, 2, 1, 3).reshape(n_head, qk_rope_head_dim, -1)
    qk_nope = torch.from_numpy(qk_nope).to(dtype).reshape(n_head, qk_nope_head_dim, -1)
    # Per head: nope rows then rope rows; heads flattened into the row axis, giving
    # (n_head * (qk_nope_head_dim + qk_rope_head_dim), in_features).
    q_b_proj = torch.cat([qk_nope, qk_rope], dim=1).reshape(-1, qk_nope.shape[-1])
    q_b_proj_key = mla_name_replace(q_b_proj_key)
    mla_weight_dict[q_b_proj_key] = q_b_proj.clone()

    # merge latent_kv, k_rope into kv_a_proj
    # Same half-block -> interleaved reordering for the rotary key rows (no head
    # axis here: all rows are split into 2 halves).
    k_rope = torch.from_numpy(k_rope).to(dtype).reshape(2, k_rope.shape[0] // 2, -1).permute(1, 0, 2)
    k_rope = k_rope.reshape(-1, k_rope.shape[-1])
    latent_kv = torch.from_numpy(latent_kv).to(dtype)
    # Latent KV rows first, rotary key rows appended after them.
    kv_a_proj = torch.cat([latent_kv, k_rope], dim=0)
    kv_a_proj_key = mla_name_replace(kv_a_proj_key)
    mla_weight_dict[kv_a_proj_key] = kv_a_proj.clone()

    # merge k_nope, v into kv_b_proj
    # Per head: k_nope rows then v rows; heads flattened into the row axis.
    k_nope = torch.from_numpy(k_nope).to(dtype).reshape(n_head, qk_nope_head_dim, -1)
    v = torch.from_numpy(v).to(dtype).reshape(n_head, v_head_dim, -1)
    kv_b_proj = torch.cat([k_nope, v], dim=1).reshape(-1, k_nope.shape[-1])
    kv_b_proj_key = mla_name_replace(kv_b_proj_key)
    mla_weight_dict[kv_b_proj_key] = kv_b_proj.clone()

    # process q_a_proj, o_proj, and layernorms
    # These are renamed and cast only; clone() gives each output its own storage.
    q_a_proj_key = mla_name_replace(q_a_proj_key)
    mla_weight_dict[q_a_proj_key] = torch.from_numpy(q_a_proj).to(dtype).clone()
    o_proj_key = mla_name_replace(o_proj_key)
    mla_weight_dict[o_proj_key] = torch.from_numpy(o_proj).to(dtype).clone()
    q_a_layernorm_key = mla_name_replace(q_a_layernorm_key)
    mla_weight_dict[q_a_layernorm_key] = torch.from_numpy(q_a_layernorm).to(dtype).clone()
    kv_a_layernorm_key = mla_name_replace(kv_a_layernorm_key)
    mla_weight_dict[kv_a_layernorm_key] = torch.from_numpy(kv_a_layernorm).to(dtype).clone()
    input_norm_key = mla_name_replace(input_norm_key)
    mla_weight_dict[input_norm_key] = torch.from_numpy(input_norm).to(dtype).clone()
    post_attn_norm_key = mla_name_replace(post_attn_norm_key)
    mla_weight_dict[post_attn_norm_key] = torch.from_numpy(post_attn_norm).to(dtype).clone()

    return mla_weight_dict


def _mlp_ms_to_pt(layer_id, ms_layer_weights, config):
    """Processing weights in MLP/MoE module.

    Pops the feed-forward weights of layer ``layer_id`` out of
    ``ms_layer_weights`` (the dict is mutated). The result is returned as
    HuggingFace weight name -> torch tensor of dtype ``config['dtype']``.

    * ``layer_id < first_k_dense_replace``: dense MLP; ``w1`` / ``w3`` / ``w2``
      become ``gate_proj`` / ``up_proj`` / ``down_proj``.
    * Otherwise MoE: router weight and correction bias, shared experts, and the
      routed experts, emitted as one ``mlp.experts.<e>.{gate,up,down}_proj.weight``
      triple per expert. ``use_grouped_gemm`` selects the MindSpore layout:
      fused ``ffn.w1`` (gate and up) / ``ffn.w2`` (down) tensors when True,
      separate ``ffn.w1/w3/w2.weight`` tensors otherwise.

    Args:
        layer_id: Index ``i`` of the ``model.layers.<i>`` entry to convert.
        ms_layer_weights: Mapping of MindSpore weight name -> tensor; consumed
            entries are removed from it.
        config: Reads ``num_routed_experts``, ``first_k_dense_replace``,
            ``dtype`` and ``use_grouped_gemm``.

    Returns:
        Dict of HuggingFace weight name -> torch tensor.

    Raises:
        KeyError: If any expected MindSpore key of this layer is missing.
    """
    num_routed_experts = config['num_routed_experts']
    first_k_dense_replace = config['first_k_dense_replace']
    dtype = config['dtype']
    use_grouped_gemm = config['use_grouped_gemm']

    # defaultdict() without a default_factory; it behaves like a plain dict.
    mlp_weight_dict = defaultdict()
    if layer_id < first_k_dense_replace:
        # Dense MLP layer: rename and cast the three projections.
        gate_proj_key = f"model.layers.{layer_id}.feed_forward.w1.weight"
        up_proj_key = f"model.layers.{layer_id}.feed_forward.w3.weight"
        down_proj_key = f"model.layers.{layer_id}.feed_forward.w2.weight"
        gate_proj = cpu_cast(ms_layer_weights.pop(gate_proj_key), ms.float32).numpy()
        up_proj = cpu_cast(ms_layer_weights.pop(up_proj_key), ms.float32).numpy()
        down_proj = cpu_cast(ms_layer_weights.pop(down_proj_key), ms.float32).numpy()

        gate_proj_key = mlp_name_replace(gate_proj_key)
        up_proj_key = mlp_name_replace(up_proj_key)
        down_proj_key = mlp_name_replace(down_proj_key)
        mlp_weight_dict[gate_proj_key] = torch.from_numpy(gate_proj).to(dtype).clone()
        mlp_weight_dict[up_proj_key] = torch.from_numpy(up_proj).to(dtype).clone()
        mlp_weight_dict[down_proj_key] = torch.from_numpy(down_proj).to(dtype).clone()
    else:
        # MoE layer. The router names differ between the two MindSpore layouts.
        if use_grouped_gemm:
            router_weight_key = f"model.layers.{layer_id}.feed_forward.routed_experts.router_dense.weight"
            router_correct_bias_key = f"model.layers.{layer_id}.feed_forward.routed_experts.topk_bias"
        else:
            router_weight_key = f"model.layers.{layer_id}.feed_forward.routed_experts.router.dense.weight"
            router_correct_bias_key = f"model.layers.{layer_id}.feed_forward.routed_experts.router.router.topk_bias"
        shared_experts_gate_proj_key = f"model.layers.{layer_id}.feed_forward.shared_experts.w1.weight"
        shared_experts_up_proj_key = f"model.layers.{layer_id}.feed_forward.shared_experts.w3.weight"
        shared_experts_down_proj_key = f"model.layers.{layer_id}.feed_forward.shared_experts.w2.weight"
        # Router rows / bias entries beyond num_routed_experts are discarded.
        # NOTE(review): presumably padding experts added on the MindSpore side — confirm.
        router_weight = cpu_cast(ms_layer_weights.pop(router_weight_key), ms.float32).numpy()
        router_weight = router_weight[:num_routed_experts, :]
        router_correct_bias = cpu_cast(ms_layer_weights.pop(router_correct_bias_key), ms.float32).numpy()
        router_correct_bias = router_correct_bias[:num_routed_experts]
        shared_experts_gate_proj = cpu_cast(ms_layer_weights.pop(shared_experts_gate_proj_key), ms.float32).numpy()
        shared_experts_up_proj = cpu_cast(ms_layer_weights.pop(shared_experts_up_proj_key), ms.float32).numpy()
        shared_experts_down_proj = cpu_cast(ms_layer_weights.pop(shared_experts_down_proj_key), ms.float32).numpy()

        # replace name and store
        router_weight_key = mlp_name_replace(router_weight_key, use_grouped_gemm)
        router_correct_bias_key = mlp_name_replace(router_correct_bias_key, use_grouped_gemm)
        shared_experts_gate_proj_key = mlp_name_replace(shared_experts_gate_proj_key)
        shared_experts_up_proj_key = mlp_name_replace(shared_experts_up_proj_key)
        shared_experts_down_proj_key = mlp_name_replace(shared_experts_down_proj_key)
        mlp_weight_dict[router_weight_key] = torch.from_numpy(router_weight).to(dtype).clone()
        mlp_weight_dict[router_correct_bias_key] = torch.from_numpy(router_correct_bias).to(dtype).clone()
        mlp_weight_dict[shared_experts_gate_proj_key] = torch.from_numpy(shared_experts_gate_proj).to(dtype).clone()
        mlp_weight_dict[shared_experts_up_proj_key] = torch.from_numpy(shared_experts_up_proj).to(dtype).clone()
        mlp_weight_dict[shared_experts_down_proj_key] = torch.from_numpy(shared_experts_down_proj).to(dtype).clone()

        # routed experts
        if use_grouped_gemm:
            weight1_key = f"model.layers.{layer_id}.feed_forward.routed_experts.ffn.w1"
            weight2_key = f"model.layers.{layer_id}.feed_forward.routed_experts.ffn.w2"
            weight1 = cpu_cast(ms_layer_weights.pop(weight1_key), ms.float32).numpy()
            weight2 = cpu_cast(ms_layer_weights.pop(weight2_key), ms.float32).numpy()
            # split then transpose back
            # w1 is viewed as (num_routed_experts, -1, last_dim); its last axis is
            # halved: first half -> gate_proj, second half -> up_proj. Axes 1 and 2
            # are then swapped for each expert.
            expert_gate_proj, expert_up_proj = np.split(
                weight1.reshape(num_routed_experts, -1, weight1.shape[-1]), 2, axis=-1)
            expert_gate_proj = np.swapaxes(expert_gate_proj, 1, 2)
            expert_up_proj = np.swapaxes(expert_up_proj, 1, 2)
            # transpose back
            # w2 is viewed as (num_routed_experts, -1, last_dim), then axes 1 and 2 are swapped.
            expert_down_proj = np.swapaxes(
                weight2.reshape(num_routed_experts, -1, weight2.shape[-1]), 1, 2)
        else:
            # Separate per-projection tensors, reshaped per expert below.
            expert_gate_proj_key = f"model.layers.{layer_id}.feed_forward.routed_experts.ffn.w1.weight"
            expert_up_proj_key = f"model.layers.{layer_id}.feed_forward.routed_experts.ffn.w3.weight"
            expert_down_proj_key = f"model.layers.{layer_id}.feed_forward.routed_experts.ffn.w2.weight"
            expert_gate_proj = cpu_cast(ms_layer_weights.pop(expert_gate_proj_key), ms.float32).numpy()
            expert_up_proj = cpu_cast(ms_layer_weights.pop(expert_up_proj_key), ms.float32).numpy()
            expert_down_proj = cpu_cast(ms_layer_weights.pop(expert_down_proj_key), ms.float32).numpy()
        # Bring both layouts to (num_routed_experts, -1, last_dim) torch tensors.
        expert_gate_proj = torch.from_numpy(expert_gate_proj).to(dtype).reshape(num_routed_experts,
                                                                                -1, expert_gate_proj.shape[-1])
        expert_up_proj = torch.from_numpy(expert_up_proj).to(dtype).reshape(num_routed_experts,
                                                                            -1, expert_up_proj.shape[-1])
        expert_down_proj = torch.from_numpy(expert_down_proj).to(dtype).reshape(num_routed_experts,
                                                                                -1, expert_down_proj.shape[-1])

        # One HF tensor per expert. .contiguous() because slices of the
        # swapped-axis arrays may not be contiguous.
        for expert_id in range(num_routed_experts):
            gate_proj_key = f"model.layers.{layer_id}.mlp.experts.{expert_id}.gate_proj.weight"
            up_proj_key = f"model.layers.{layer_id}.mlp.experts.{expert_id}.up_proj.weight"
            down_proj_key = f"model.layers.{layer_id}.mlp.experts.{expert_id}.down_proj.weight"
            mlp_weight_dict[gate_proj_key] = expert_gate_proj[expert_id, ...].clone().contiguous()
            mlp_weight_dict[up_proj_key] = expert_up_proj[expert_id, ...].clone().contiguous()
            mlp_weight_dict[down_proj_key] = expert_down_proj[expert_id, ...].clone().contiguous()

    return mlp_weight_dict


def _mtp_ms_to_pt(layer_id, ms_layer_weights, config):
    """Convert the MTP-specific weights of layer ``layer_id`` to HuggingFace names.

    The MTP block index is ``layer_id - config["num_layers"]``. Its fuser norms,
    projection and output norm are popped from ``ms_layer_weights``. The shared
    token embedding and lm_head are read with ``get`` (left in place) and emitted
    again under ``model.layers.<layer_id>``.

    Returns:
        Dict of HuggingFace weight name -> torch tensor of dtype ``config['dtype']``.
    """
    target_dtype = config['dtype']
    mtp_index = layer_id - config["num_layers"]

    def _as_torch(ms_tensor):
        # Upcast on CPU so numpy can hold it, then cast to the export dtype.
        fp32_array = cpu_cast(ms_tensor, ms.float32).numpy()
        return torch.from_numpy(fp32_array).to(target_dtype).clone()

    owned_ms_keys = (
        f"model.mtp_hidden_fusers.{mtp_index}.norm_emb.weight",
        f"model.mtp_hidden_fusers.{mtp_index}.norm.weight",
        f"model.mtp_hidden_fusers.{mtp_index}.dense.weight",
        f"model.mtp_norms.{mtp_index}.weight",
    )
    converted = defaultdict()
    for ms_key in owned_ms_keys:
        hf_key = mtp_name_replace(ms_key, layer_id, mtp_index)
        converted[hf_key] = _as_torch(ms_layer_weights.pop(ms_key))

    # The embedding and lm_head are shared with the main model, so they are only read.
    shared_pairs = (
        ("model.tok_embeddings.embedding_weight", f"model.layers.{layer_id}.embed_tokens.weight"),
        ("lm_head.weight", f"model.layers.{layer_id}.shared_head.head.weight"),
    )
    for ms_key, hf_key in shared_pairs:
        converted[hf_key] = _as_torch(ms_layer_weights.get(ms_key))

    return converted


def _model_preprocess_ms_to_pt(ms_layer_weights, config):
    """Convert the token-embedding weight to its HuggingFace name and dtype.

    The MindSpore tensor is read with ``get`` rather than popped, so it stays in
    ``ms_layer_weights`` for the MTP conversion, which reads it again.
    """
    target_dtype = config['dtype']
    ms_name = "model.tok_embeddings.embedding_weight"
    fp32_array = cpu_cast(ms_layer_weights.get(ms_name), ms.float32).numpy()

    converted = defaultdict()
    converted[plain_name_replace(ms_name)] = torch.from_numpy(fp32_array).to(target_dtype).clone()
    return converted


def _model_postprocess_ms_to_pt(ms_layer_weights, config):
    """Convert the final norm and lm_head weights to HuggingFace names and dtype.

    Both tensors are read with ``get``, so ``ms_layer_weights`` is not modified.
    """
    target_dtype = config['dtype']
    converted = defaultdict()
    for ms_name in ("model.norm_out.weight", "lm_head.weight"):
        fp32_array = cpu_cast(ms_layer_weights.get(ms_name), ms.float32).numpy()
        converted[plain_name_replace(ms_name)] = torch.from_numpy(fp32_array).to(target_dtype).clone()
    return converted


def get_torch_storage_size(tensor):
    """Return the number of bytes occupied by ``tensor``'s own elements.

    Computed as ``numel() * element_size()`` instead of from the underlying
    storage. A view or slice shares its parent's storage, so
    ``untyped_storage().nbytes()`` would report the parent's full size and inflate
    the ``total_size`` written to the index. For the freshly cloned, contiguous
    tensors built in this file both measures agree. This form also does not
    depend on ``untyped_storage()``, so it works on older torch versions.

    Args:
        tensor: Any ``torch.Tensor``.

    Returns:
        Size in bytes as an ``int``.
    """
    return tensor.numel() * tensor.element_size()


def ms_ckpt_convertor(input_path, output_path, config):
    """Convert a single-file MindSpore ``.ckpt`` checkpoint to HuggingFace safetensors.

    The whole checkpoint is loaded at once. One
    ``model-<i>-of-<total>.safetensors`` shard is written per layer (decoder
    layers followed by ``num_nextn_predict_layers`` MTP layers), together with
    ``model.safetensors.index.json``. Everything goes into ``output_path``. The
    embedding is stored in the first shard and the final norm / lm_head in the
    last one.

    Args:
        input_path: Path of the ``.ckpt`` file.
        output_path: Directory that receives the shards and the index; it must
            already exist (``convert_ms_to_pt`` creates it).
        config: Conversion options (see ``default_config``).

    Raises:
        ValueError: If ``input_path`` is a directory.
    """
    # for .ckpt format checkpoints, only single file is valid
    if os.path.isdir(input_path):
        raise ValueError("File in `.ckpt` format is valid to convert checkpoints, but get a directory!")
    ms_weights = ms.load_checkpoint(input_path, format='ckpt')

    num_layers = config["num_layers"]
    total_num_layers = num_layers + config["num_nextn_predict_layers"]
    last_layer_id = total_num_layers - 1

    converted_st_map = defaultdict()
    converted_st_map["weight_map"] = defaultdict()
    converted_st_map["metadata"] = defaultdict()

    total_size = 0
    for layer_id in range(total_num_layers):
        pt_layer_weights = defaultdict()
        if layer_id == 0:
            pt_layer_weights.update(_model_preprocess_ms_to_pt(ms_weights, config))
        pt_layer_weights.update(_mla_ms_to_pt(layer_id, ms_weights, config))
        pt_layer_weights.update(_mlp_ms_to_pt(layer_id, ms_weights, config))
        if layer_id >= num_layers:
            pt_layer_weights.update(_mtp_ms_to_pt(layer_id, ms_weights, config))
        if layer_id == last_layer_id:
            pt_layer_weights.update(_model_postprocess_ms_to_pt(ms_weights, config))

        saving_file_name = f"model-{layer_id+1:05d}-of-{total_num_layers:05d}.safetensors"
        for name, tensor in pt_layer_weights.items():
            converted_st_map["weight_map"][name] = saving_file_name
            total_size += get_torch_storage_size(tensor)
        # Save into output_path so the shards sit next to the index that
        # references them by bare file name.
        save_file(pt_layer_weights, os.path.join(output_path, saving_file_name))
        print(f"saving weights in layer-{layer_id} to file {saving_file_name}")

    converted_st_map["metadata"]["total_size"] = total_size
    converted_model_index_file = os.path.join(output_path, "model.safetensors.index.json")
    with open(converted_model_index_file, "w") as f:
        json_string = json.dumps(converted_st_map, default=lambda x: x.__dict__, sort_keys=False, indent=2)
        f.write(json_string)
    set_safe_mode_for_file_or_dir(converted_model_index_file)


def ms_safetensors_convertor(input_path, output_path, config):
    """Convert a directory of MindSpore safetensors files to HuggingFace safetensors.

    Weights are processed layer by layer. Only the files that hold the current
    layer (per ``layers_model_file_map``) are loaded; they are converted and
    written as one ``model-<i>-of-<total>.safetensors`` shard into
    ``output_path``. ``model.safetensors.index.json`` is written there as well.

    Args:
        input_path: Directory containing the MindSpore safetensors files.
        output_path: Directory that receives the output; it must already exist
            (``convert_ms_to_pt`` creates it).
        config: Conversion options (see ``default_config``).
    """
    # try to get weight-file map
    layer_st_map = layers_model_file_map(input_path, config)

    num_layers = config["num_layers"]
    total_num_layers = num_layers + config["num_nextn_predict_layers"]
    last_layer_id = total_num_layers - 1

    converted_st_map = defaultdict()
    converted_st_map["weight_map"] = defaultdict()
    converted_st_map["metadata"] = defaultdict()

    total_size = 0
    for layer_id in range(total_num_layers):
        is_mtp_layer = layer_id >= num_layers
        # The first layer needs the embedding, and the last layer needs the final
        # norm and lm_head. MTP layers re-emit the shared embedding and lm_head,
        # so they need both. is_last is keyed on the last layer id, not only on
        # "is an MTP layer", so this still works with num_nextn_predict_layers == 0.
        ms_layer_weights = read_matched_file(
            layer_st_map, [layer_id],
            is_first=layer_id == 0 or is_mtp_layer,
            is_last=layer_id == last_layer_id or is_mtp_layer)

        pt_layer_weights = defaultdict()
        if layer_id == 0:
            pt_layer_weights.update(_model_preprocess_ms_to_pt(ms_layer_weights, config))
        pt_layer_weights.update(_mla_ms_to_pt(layer_id, ms_layer_weights, config))
        pt_layer_weights.update(_mlp_ms_to_pt(layer_id, ms_layer_weights, config))
        if is_mtp_layer:
            pt_layer_weights.update(_mtp_ms_to_pt(layer_id, ms_layer_weights, config))
        if layer_id == last_layer_id:
            pt_layer_weights.update(_model_postprocess_ms_to_pt(ms_layer_weights, config))

        saving_file_name = f"model-{layer_id+1:05d}-of-{total_num_layers:05d}.safetensors"
        for name, tensor in pt_layer_weights.items():
            converted_st_map["weight_map"][name] = saving_file_name
            total_size += get_torch_storage_size(tensor)
        save_file(pt_layer_weights, os.path.join(output_path, saving_file_name))
        print(f"saving weights in layer-{layer_id} to file {saving_file_name}")
        # Release this layer's tensors so they are not kept alive while the next
        # layer's files are being loaded.
        del ms_layer_weights, pt_layer_weights

    converted_st_map["metadata"]["total_size"] = total_size
    converted_model_index_file = os.path.join(output_path, "model.safetensors.index.json")
    with open(converted_model_index_file, "w") as f:
        json_string = json.dumps(converted_st_map, default=lambda x: x.__dict__, sort_keys=False, indent=2)
        f.write(json_string)
    set_safe_mode_for_file_or_dir(converted_model_index_file)


def convert_ms_to_pt(input_path, output_path, config=None):
    """Convert a MindSpore checkpoint into HuggingFace safetensors.

    Args:
        input_path: A ``.ckpt`` file when ``config['load_format'] == "ckpt"``, or a
            directory of safetensors files when it is ``"safetensors"``.
        output_path: Output directory; created if it does not exist.
        config: Conversion options; ``None`` uses ``default_config``.

    Raises:
        ValueError: If ``config['load_format']`` is neither ``"ckpt"`` nor
            ``"safetensors"``. This is checked before anything is created on disk.
    """
    if config is None:
        config = default_config

    load_format = config['load_format']
    supported_formats = ("ckpt", "safetensors")
    if load_format not in supported_formats:
        # Fail loudly instead of reporting success without converting anything.
        raise ValueError(f"Unsupported load_format '{load_format}', expected one of {supported_formats}.")

    os.makedirs(output_path, exist_ok=True)
    print(f"Loading mindspore checkpoint in '{input_path}' ...")

    if load_format == "ckpt":
        ms_ckpt_convertor(input_path, output_path, config)
    else:
        ms_safetensors_convertor(input_path, output_path, config)

    print("Finish converting mindspore checkpoints into Huggingface checkpoints!")