# xinlianglalala: internlm fit llama
# Commit 3aeb4a2b, created on 2023-09-12 (commit history)
# Copyright 2023 Huawei Technologies Co., Ltd

#

# Licensed under the Apache License, Version 2.0 (the "License");

# you may not use this file except in compliance with the License.

# You may obtain a copy of the License at

#

# http://www.apache.org/licenses/LICENSE-2.0

#

# Unless required by applicable law or agreed to in writing, software

# distributed under the License is distributed on an "AS IS" BASIS,

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

# See the License for the specific language governing permissions and

# limitations under the License.

# ============================================================================

"""

Convert internlm weight.

Support huggingface format.

"""



import os

import json

import argparse



import mindspore as ms





def read_json(path):
    """Load and return the JSON document stored at ``path``.

    Args:
        path: Location of the JSON file to read.

    Returns:
        The Python object produced by ``json.load`` for the file content.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # JSON text is UTF-8; decode explicitly instead of relying on the
    # platform's locale encoding, which differs between systems.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)





def name_replace(name: str):
    """Translate a huggingface parameter name into its mindspore form.

    Every huggingface fragment listed below is substituted, in table order,
    wherever it occurs in ``name``. A name containing none of the fragments
    is returned unchanged.
    """
    # (huggingface fragment, mindspore fragment); applied top to bottom.
    hf_to_ms = (
        ('embed_tokens.weight', 'tok_embeddings.embedding_weight'),
        ('.self_attn.q_proj.', '.attention.wq.'),
        ('.self_attn.k_proj.', '.attention.wk.'),
        ('.self_attn.v_proj.', '.attention.wv.'),
        ('.self_attn.o_proj.', '.attention.wo.'),
        ('.mlp.gate_proj.', '.feed_forward.w1.'),
        ('.mlp.down_proj.', '.feed_forward.w2.'),
        ('.mlp.up_proj.', '.feed_forward.w3.'),
        ('.input_layernorm.', '.attention_norm.'),
        ('.post_attention_layernorm.', '.ffn_norm.'),
        ('.norm.', '.norm_out.'),
    )
    for hf_fragment, ms_fragment in hf_to_ms:
        name = name.replace(hf_fragment, ms_fragment)
    return name





def convert_hf_ckpt(ckpt_dir, output_name, dtype=ms.float16):
    """Convert a huggingface checkpoint into a mindspore checkpoint file.

    Args:
        ckpt_dir: Directory passed to ``AutoModelForCausalLM.from_pretrained``.
            The output path is also built by joining onto this directory.
        output_name: Path of the checkpoint to write, joined onto ``ckpt_dir``.
        dtype: Mindspore dtype used when wrapping each parameter in a Tensor.

    Returns:
        bool: False if importing transformers or loading the model raised,
        True once the mindspore checkpoint has been saved.
    """
    print(f"Trying to convert huggingface checkpoint in '{ckpt_dir}'.", flush=True)
    try:
        # Imported lazily so a missing transformers package is reported
        # through the same failure path as a missing checkpoint.
        from transformers import AutoModelForCausalLM
        hf_model = AutoModelForCausalLM.from_pretrained(ckpt_dir, trust_remote_code=True)
    # pylint: disable=W0703
    except Exception as err:
        print(f"Do not find huggingface checkpoint in '{ckpt_dir}', Error {err}.", flush=True)
        return False

    layer_prefix = 'layers.'
    ms_params = []
    for hf_name, tensor in hf_model.named_parameters():
        ms_name = name_replace(hf_name)
        # A bare final-norm name has no leading dot, so name_replace leaves
        # it untouched; rename it here.
        if ms_name == 'norm.weight':
            ms_name = 'norm_out.weight'
        # Drop a leading 'layers.' prefix when the name starts with it.
        if ms_name.startswith(layer_prefix):
            ms_name = ms_name[len(layer_prefix):]
        array = tensor.detach().numpy()
        print(f'\rprocessing parameter: {ms_name} {array.shape}     ', end='', flush=True)
        ms_params.append({'name': ms_name, 'data': ms.Tensor(array, dtype=dtype)})

    target_path = os.path.join(ckpt_dir, output_name)
    ms.save_checkpoint(ms_params, target_path)
    print(f"\rConvert huggingface checkpoint finished, the mindspore checkpoint is saved in '{target_path}'.", flush=True)
    return True





if __name__ == "__main__":
    # Command-line entry point: convert a huggingface checkpoint directory
    # into a mindspore checkpoint file.
    parser = argparse.ArgumentParser(
        description="Convert a huggingface internlm checkpoint to a mindspore checkpoint.")
    parser.add_argument(
        '--torch_ckpt_dir', default='./internlm-chat/',
        help="directory containing the huggingface checkpoint")
    parser.add_argument(
        '--mindspore_ckpt_path', default='./internlm-chat.ckpt',
        help="output checkpoint path; it is joined onto --torch_ckpt_dir")
    args = parser.parse_args()
    # convert_hf_ckpt returns False when the checkpoint cannot be loaded.
    # Exit non-zero in that case so shell scripts and CI see the failure.
    if not convert_hf_ckpt(ckpt_dir=args.torch_ckpt_dir, output_name=args.mindspore_ckpt_path):
        raise SystemExit(1)