import os
import sys
import types
import logging as logger
import torch
import transformers
from .models import get_megatron_model
from .models import get_huggingface_model
# NOTE: ``logger`` is the standard ``logging`` module itself (imported under
# that alias), so these two calls configure the process-wide root logger.
# An empty format string is passed to basicConfig -- presumably so records are
# printed without any prefix (TODO confirm) -- and the root level is set to INFO.
logger.basicConfig(format="")
logger.getLogger().setLevel(logger.INFO)
def add_arguments(parser):
    """Register the command-line options of the HF checkpoint loader on *parser*.

    Most options are added to an argument group titled ``'Llama-2 HF loader.'``;
    ``--add-qkv-bias``, ``--add-dense-bias`` and ``--params-dtype`` are added to
    *parser* directly.

    Args:
        parser: An ``argparse.ArgumentParser`` (or any object offering
            ``add_argument_group`` and ``add_argument``).
    """

    def _parse_bool(value):
        # ``type=bool`` turns every non-empty string -- including "False" --
        # into True. Spelled-out false values are recognised explicitly; any
        # other input keeps the previous truthiness behaviour.
        if isinstance(value, str) and value.strip().lower() in ('false', 'f', '0', 'no', 'n', 'off'):
            return False
        return bool(value)

    group = parser.add_argument_group(title='Llama-2 HF loader.')

    group.add_argument(
        '--true-vocab-size',
        type=int,
        default=None,
        help='original size of vocab, if specified will trim padding from embedding table.',
    )
    group.add_argument(
        '--vocab-file',
        type=str,
        default=None,
        help='Path to the vocab file. If specified will use this to get vocab size and '
        'trim padding from the embedding table.',
    )
    group.add_argument('--tokenizer-model', required=True, help='Sentencepiece tokenizer model.')
    group.add_argument('--megatron-path', type=str, default=None, help='Base directory of deepspeed repository')
    group.add_argument("--w-pack", type=_parse_bool, help='True is w_pack weight for llm', default=False)
    parser.add_argument('--add-qkv-bias', action='store_true', help='Add bias for attention qkv', default=False)
    parser.add_argument('--add-dense-bias', action='store_true', help='Add bias for attention dense', default=False)
    parser.add_argument('--params-dtype', type=str, help='Set weight dtype', default='fp16')
    group.add_argument(
        '--make-vocab-size-divisible-by',
        type=int,
        default=1,
        help='Pad the vocab size to be divisible by this value.This is added for computational efficiency reasons.',
    )
    group.add_argument('--use-mcore-models', action='store_true', help='Use the implementation from megatron core')
    group.add_argument('--post-norm', action='store_true', help='post norm after attention or mlp.', default=False)
    group.add_argument('--moe-grouped-gemm', action='store_true', help='Usr moe grouped gemm.')
    group.add_argument(
        "--moe-tp-extend-ep",
        action='store_true',
        help="use tp group to extend experts parallelism instead of sharding weight tensor of experts in tp group",
    )
    group.add_argument(
        '--spec',
        type=str,
        default=None,
        nargs='*',
        help='Specify the <module_location function_name> pair '
        'that returns a spec to customize transformer layer, depending on the use case.',
    )
    group.add_argument("--noop-layers", type=str, default=None, help='Specity the noop layers.')
def verify_transformers_version(version=None):
    """Raise if the transformers release is older than 4.31.

    Args:
        version: Dotted version string to check. When omitted,
            ``transformers.__version__`` is used.

    Raises:
        ValueError: If the (major, minor) pair is lower than (4, 31).
    """
    if version is None:
        version = transformers.__version__
    version_idents = version.split('.')
    major_version = int(version_idents[0])
    minor_version = int(version_idents[1])
    # Compare as a tuple: testing the minor number on its own would reject
    # newer major releases such as 5.0.
    if (major_version, minor_version) < (4, 31):
        raise ValueError("the version transformers should greater or equal 4.31")
def build_metadata(args, margs):
    """Assemble the metadata namespace that is sent ahead of the weights.

    Args:
        args: Loader command-line namespace; ``model_type`` and ``spec`` are read.
        margs: Megatron argument namespace the remaining fields are copied from.

    Returns:
        A ``types.SimpleNamespace`` with the collected fields. ``spec`` ends up
        holding ``margs.spec`` because the later assignment wins.
    """
    md = types.SimpleNamespace(model_type=args.model_type, spec=args.spec)

    # Fields copied verbatim under the same name.
    for field in (
        "num_layers",
        "hidden_size",
        "seq_length",
        "num_attention_heads",
        "max_position_embeddings",
        "tokenizer_type",
        "iteration",
        "params_dtype",
        "bert_binary_head",
    ):
        setattr(md, field, getattr(margs, field))

    # Fields that are renamed or need a constant / fallback value.
    md.output_layer = margs.untie_embeddings_and_output_weights
    md.position_embedding_type = margs.position_embedding_type
    md.linear_bias = margs.add_bias_linear
    md.norm_has_bias = False
    md.add_output_layer_bias = getattr(margs, "add_output_layer_bias", False)
    md.swiglu = margs.swiglu
    md.previous_tensor_parallel_size = margs.tensor_model_parallel_size
    md.previous_pipeline_parallel_size = margs.pipeline_model_parallel_size
    md.true_vocab_size = margs.vocab_size
    md.make_vocab_size_divisible_by = margs.make_vocab_size_divisible_by
    md.checkpoint_args = margs
    md.consumed_train_samples = 0
    md.consumed_valid_samples = 0

    for field in ("embed_layernorm", "disable_bias_linear", "moe_grouped_gemm", "moe_tp_extend_ep", "spec"):
        setattr(md, field, getattr(margs, field))

    # Optional fields: fall back to the given default when margs lacks them.
    optional_fields = (
        ("num_experts", None),
        ("n_shared_experts", None),
        ("qk_layernorm", False),
        ("moe_intermediate_size", None),
        ("first_k_dense_replace", None),
        ("moe_layer_freq", None),
        ("multi_latent_attention", False),
        ("cla_share_factor", 1),
        ("q_lora_rank", None),
    )
    for field, fallback in optional_fields:
        setattr(md, field, getattr(margs, field, fallback))

    # Extra dimensions are only recorded when multi_latent_attention is set.
    if md.multi_latent_attention:
        for field in ("qk_pos_emb_head_dim", "qk_head_dim", "q_lora_rank", "kv_lora_rank", "v_head_dim"):
            setattr(md, field, getattr(margs, field, None))

    if hasattr(margs, 'normalization'):
        md.norm_has_bias = margs.normalization == "LayerNorm"

    return md
def get_message_preprocess(model, md):
    """Collect the embedding tensors into the first weight message.

    Args:
        model: Model wrapper exposing the ``get_embedding_*`` / ``has_embedding_*``
            accessors used below.
        md: Metadata namespace; ``position_embedding_type`` is read.

    Returns:
        dict with ``"word embeddings"``, plus the embedding norm weight/bias when
        the model reports such a module, plus ``"position embeddings"`` when the
        position embedding type is ``'learned_absolute'``.

    Raises:
        ValueError: If the position embedding type is not ``'learned_absolute'``
            but the model still reports a position-embeddings module.
    """
    message = {"word embeddings": model.get_embedding_word_embeddings_weight()}
    if model.has_embedding_word_embeddings_norm_module():
        message["word embeddings norm_w"] = model.get_embedding_word_embeddings_norm_weight()
        message["word embeddings norm_b"] = model.get_embedding_word_embeddings_norm_bias()
    if md.position_embedding_type == 'learned_absolute':
        message["position embeddings"] = model.get_embedding_position_embeddings_weight()
    elif model.has_embedding_position_embeddings_module():
        # The failure here is an unexpected module being present, so the
        # message states that (it previously claimed the opposite).
        raise ValueError("model should not have position_embeddings")
    return message
def get_message_layer_norm(message, model, layer_idx, md, args=None):
    """Add the norm tensors of one transformer layer to *message*.

    Args:
        message: dict that is filled in place and returned.
        model: Model wrapper exposing the ``get_layers_*layernorm*`` accessors.
        layer_idx: Index of the layer to read.
        md: Metadata namespace; ``norm_has_bias`` is read.
        args: Loader command-line namespace; ``post_norm`` is read. When omitted
            (or lacking the attribute) post-norm is treated as disabled, which
            matches the default of the ``--post-norm`` option.

    Returns:
        The same *message* dict.
    """
    message["input norm weight"] = model.get_layers_input_layernorm_weight(layer_idx=layer_idx)

    # ``args`` is optional in the signature, so it must not be dereferenced blindly.
    if getattr(args, "post_norm", False):
        message["post norm weight"] = model.get_layers_self_attention_post_attention_layernorm_weight(
            layer_idx=layer_idx
        )
        message["pre mlp norm weight"] = model.get_layers_self_attention_pre_mlp_layernorm_weight(layer_idx=layer_idx)
        message["post mlp norm weight"] = model.get_layers_self_attention_post_mlp_layernorm_weight(layer_idx=layer_idx)
    else:
        # Without post-norm the pre-MLP norm is what gets sent as "post norm".
        message["post norm weight"] = model.get_layers_self_attention_pre_mlp_layernorm_weight(layer_idx=layer_idx)

    if md.norm_has_bias:
        message["input norm bias"] = model.get_layers_input_layernorm_bias(layer_idx=layer_idx)
        # NOTE(review): the bias is taken from the pre-MLP norm even when
        # post_norm is enabled -- confirm this is intended.
        message["post norm bias"] = model.get_layers_self_attention_pre_mlp_layernorm_bias(layer_idx=layer_idx)

    return message
def get_message_layer_attn(message, model, layer_idx, md=None, args=None):
    """Add the self-attention tensors of one transformer layer to *message*.

    Args:
        message: dict that is filled in place and returned.
        model: Model wrapper exposing ``get_args()`` and the
            ``get_layers_self_attention_*`` accessors.
        layer_idx: Index of the layer to read.
        md: Metadata namespace; ``linear_bias`` is read. Treated as "no linear
            bias" when omitted.
        args: Loader command-line namespace; ``add_qkv_bias`` and
            ``add_dense_bias`` are read. Both are treated as off when omitted,
            matching the defaults of the corresponding options.

    Returns:
        The same *message* dict. ``"qkv weight"`` and ``"dense weight"`` are
        always present; layernorm, up-projection and bias entries depend on the
        model arguments and on the flags above.
    """
    qkv_weight = [model.get_layers_self_attention_linear_qkv_weight(layer_idx=layer_idx)]
    dense_weight = [model.get_layers_self_attention_linear_proj_weight(layer_idx=layer_idx)]

    # Read the model arguments once instead of calling get_args() per check.
    margs = model.get_args()
    qk_layernorm = getattr(margs, "qk_layernorm", False)
    multi_latent_attention = getattr(margs, "multi_latent_attention", False)
    q_lora_rank = getattr(margs, "q_lora_rank", None)

    if qk_layernorm:
        if multi_latent_attention:
            # The q layernorm is only sent when a q LoRA rank is configured.
            if q_lora_rank:
                message["q layernorm"] = model.get_layers_self_attention_q_layernorm_weight(layer_idx=layer_idx)
            message["kv layernorm"] = model.get_layers_self_attention_kv_layernorm_weight(layer_idx=layer_idx)
        else:
            message["q layernorm"] = model.get_layers_self_attention_q_layernorm_weight(layer_idx=layer_idx)
            message["k layernorm"] = model.get_layers_self_attention_k_layernorm_weight(layer_idx=layer_idx)

    if multi_latent_attention:
        if q_lora_rank:
            message["linear qb weight"] = model.get_layers_self_attention_linear_q_up_proj_weight(layer_idx=layer_idx)
        message["linear kvb weight"] = model.get_layers_self_attention_linear_kv_up_proj_weight(layer_idx=layer_idx)

    # ``args`` and ``md`` are optional in the signature, so read them defensively.
    if getattr(args, "add_qkv_bias", False):
        message["qkv bias"] = model.get_layers_self_attention_linear_qkv_bias(layer_idx=layer_idx)
    if getattr(args, "add_dense_bias", False):
        message["dense bias"] = model.get_layers_self_attention_linear_proj_bias(layer_idx=layer_idx)

    if getattr(md, "linear_bias", False):
        message["qkv bias"] = model.get_layers_self_attention_linear_qkv_bias(layer_idx=layer_idx)
        message["dense bias"] = model.get_layers_self_attention_linear_proj_bias(layer_idx=layer_idx)

    # Single-entry lists are concatenated so the message holds new tensors
    # (qkv along dim 0, dense along dim 1), exactly as before.
    message["qkv weight"] = torch.cat(qkv_weight, dim=0)
    message["dense weight"] = torch.cat(dense_weight, dim=1)

    return message
def _get_message_layer_mlp(message, model, layer_idx, md=None, tp_size=1, is_moe_mlp=False, **kwargs):
    """Add the MLP tensors of one layer (or of one MoE expert) to *message*.

    Args:
        message: dict that is filled in place and returned.
        model: Model wrapper exposing the ``get_layers_mlp_*`` accessors.
        layer_idx: Index of the layer to read.
        md: Metadata namespace; ``linear_bias`` and ``swiglu`` are read, so a
            real namespace is needed despite the ``None`` default.
        tp_size: Kept for interface compatibility. Exactly one weight/bias is
            collected per call, so the split below no longer depends on it.
        is_moe_mlp: Use the ``experts`` accessors instead of the dense ones.
        **kwargs: Forwarded to every accessor (e.g. ``expert_idx``).

    Returns:
        The same *message* dict. With ``md.swiglu`` the fc1 weight (and bias)
        is split in two halves along dim 0 and stored under ``... W`` / ``... V``;
        otherwise it is stored whole.
    """
    if is_moe_mlp:
        mlp_l0_weight = [model.get_layers_mlp_experts_linear_fc1_weight(layer_idx=layer_idx, **kwargs)]
        mlp_l1_weight = [model.get_layers_mlp_experts_linear_fc2_weight(layer_idx=layer_idx, **kwargs)]
    else:
        mlp_l0_weight = [model.get_layers_mlp_linear_fc1_weight(layer_idx=layer_idx, **kwargs)]
        mlp_l1_weight = [model.get_layers_mlp_linear_fc2_weight(layer_idx=layer_idx, **kwargs)]

    mlp_l0_bias = []
    if md.linear_bias:
        if is_moe_mlp:
            mlp_l0_bias.append(model.get_layers_mlp_experts_linear_fc1_bias(layer_idx=layer_idx, **kwargs))
        else:
            mlp_l0_bias.append(model.get_layers_mlp_linear_fc1_bias(layer_idx=layer_idx, **kwargs))

    if md.swiglu:
        # Split every collected weight. Indexing with range(tp_size) raised
        # IndexError for tp_size > 1 because only one entry is ever collected.
        weight_halves = [torch.chunk(weight, 2, dim=0) for weight in mlp_l0_weight]
        message["mlp l0 weight W"] = torch.cat([w[0] for w in weight_halves], dim=0)
        message["mlp l0 weight V"] = torch.cat([w[1] for w in weight_halves], dim=0)
    else:
        message["mlp l0 weight"] = torch.cat(mlp_l0_weight, dim=0)

    message["mlp l1 weight"] = torch.cat(mlp_l1_weight, dim=1)

    if md.linear_bias:
        if is_moe_mlp:
            message["mlp l1 bias"] = model.get_layers_mlp_experts_linear_fc2_bias(layer_idx=layer_idx, **kwargs)
        else:
            message["mlp l1 bias"] = model.get_layers_mlp_linear_fc2_bias(layer_idx=layer_idx, **kwargs)
        if md.swiglu:
            bias_halves = [torch.chunk(bias, 2, dim=0) for bias in mlp_l0_bias]
            message["mlp l0 bias W"] = torch.cat([b[0] for b in bias_halves], dim=0)
            message["mlp l0 bias V"] = torch.cat([b[1] for b in bias_halves], dim=0)
        else:
            message["mlp l0 bias"] = torch.cat(mlp_l0_bias, dim=0)

    return message
def get_message_layer_mlp(message, model, layer_idx, md=None, tp_size=1):
    """Add the MLP (dense or MoE) tensors of one transformer layer to *message*.

    A layer is handled as MoE when ``layer_idx >= first_k_dense_replace`` and
    ``layer_idx`` is a multiple of ``moe_layer_freq`` (both read from *model*);
    every other layer is delegated to ``_get_message_layer_mlp``.

    Args:
        message: dict that is filled in place and returned.
        model: Model wrapper exposing ``get_args()`` and the MLP accessors.
        layer_idx: Index of the layer to read.
        md: Metadata namespace; ``swiglu`` is read here and the namespace is
            forwarded to ``_get_message_layer_mlp``.
        tp_size: Forwarded to ``_get_message_layer_mlp``. The shared-experts
            split below no longer depends on it.

    Returns:
        The same *message* dict. For MoE layers all tensors live in the nested
        ``message["mlp_moe"]`` dict.
    """
    margs = model.get_args()
    first_k_dense_replace = model.get_first_k_dense_replace()
    moe_layer_freq = model.get_moe_layer_freq()
    shared_expert_gate = getattr(margs, 'shared_expert_gate', None)

    is_moe_layer = layer_idx >= first_k_dense_replace and layer_idx % moe_layer_freq == 0
    if not is_moe_layer:
        return _get_message_layer_mlp(message, model, layer_idx, md=md, tp_size=tp_size)

    moe_message = {}
    message["mlp_moe"] = moe_message
    moe_message["mlp router weight"] = model.get_layers_mlp_router_weight(layer_idx=layer_idx)

    if shared_expert_gate:
        moe_message["mlp shared_expert_gate weight"] = model.get_layers_mlp_shared_experts_gate_weight_module(
            layer_idx=layer_idx
        )

    if getattr(margs, "n_shared_experts", None) is not None:
        shared_experts_fc1_weight = [model.get_layers_mlp_shared_experts_linear_fc1_weight(layer_idx=layer_idx)]
        shared_experts_fc2_weight = [model.get_layers_mlp_shared_experts_linear_fc2_weight(layer_idx=layer_idx)]
        if md.swiglu:
            # Split every collected weight. Indexing with range(tp_size) raised
            # IndexError for tp_size > 1 because only one entry is collected.
            fc1_halves = [torch.chunk(weight, 2, dim=0) for weight in shared_experts_fc1_weight]
            moe_message["mlp shared experts linear fc1 weight W"] = torch.cat([w[0] for w in fc1_halves], dim=0)
            moe_message["mlp shared experts linear fc1 weight V"] = torch.cat([w[1] for w in fc1_halves], dim=0)
        else:
            moe_message["mlp shared experts linear fc1 weight"] = torch.cat(shared_experts_fc1_weight, dim=0)
        moe_message["mlp shared experts linear fc2 weight"] = torch.cat(shared_experts_fc2_weight, dim=1)

    if margs.moe_grouped_gemm:
        # Grouped GEMM: the two expert weight modules are sent as returned.
        moe_message["mlp experts weight1 module"] = model.get_layers_mlp_experts_weight1_module(layer_idx=layer_idx)
        moe_message["mlp experts weight2 module"] = model.get_layers_mlp_experts_weight2_module(layer_idx=layer_idx)
    else:
        # One nested message per expert.
        for expert_idx in range(margs.num_experts):
            kwargs = {'expert_idx': expert_idx}
            moe_message[f"expert {expert_idx}"] = _get_message_layer_mlp(
                {}, model, layer_idx, md=md, tp_size=tp_size, is_moe_mlp=True, **kwargs
            )

    return message
def get_message_postprocess(model, md):
    """Build the message carrying the final layernorm.

    Args:
        model: Model wrapper exposing ``get_final_layernorm_weight`` / ``_bias``.
        md: Metadata namespace; ``norm_has_bias`` is read.

    Returns:
        dict with ``"weight"`` and, when ``md.norm_has_bias`` is set, ``"bias"``.
    """
    final_norm = {"weight": model.get_final_layernorm_weight()}
    if not md.norm_has_bias:
        return final_norm
    final_norm["bias"] = model.get_final_layernorm_bias()
    return final_norm
def get_message_output_layer(model, md):
    """Build the message carrying the output layer, if there is one to send.

    Args:
        model: Model wrapper exposing ``get_output_layer_weight`` / ``_bias``.
        md: Metadata namespace; ``output_layer`` and ``add_output_layer_bias``
            are read.

    Returns:
        ``None`` when ``md.output_layer`` is falsy; otherwise a dict with
        ``"weight"`` and, when ``md.add_output_layer_bias`` is set, ``"bias"``.
    """
    if not md.output_layer:
        return None

    output_layer = {"weight": model.get_output_layer_weight()}
    if md.add_output_layer_bias:
        output_layer["bias"] = model.get_output_layer_bias()
    return output_layer
def to_detach(message):
    """Replace every non-dict value of *message* with ``value.detach()``, in place.

    Nested dicts are processed recursively; nothing is returned.
    """
    for key in list(message):
        entry = message[key]
        if isinstance(entry, dict):
            # Descend into nested messages instead of detaching the dict itself.
            to_detach(entry)
            continue
        message[key] = entry.detach()
def _load_checkpoint(model_provider, queue, args):
    """Load the HF model and stream its weights through *queue*.

    What is put on the queue, in order: the metadata namespace, an
    ``"embeddings"`` message, one ``"transformer layer <i>"`` message per layer,
    a ``"final norm"`` message, an optional ``"output layer"`` message and
    finally the string ``"done"``. Each message dict gets its label stored
    under the ``"name"`` key before it is sent.

    Args:
        model_provider: Passed through to ``get_megatron_model``.
        queue: Object with a ``put`` method receiving the items above.
        args: Loader command-line namespace.
    """
    verify_transformers_version()

    # Extend the import path: the directory two levels above this file is
    # appended, and an explicit Megatron path (if given) is put first.
    two_levels_up = os.path.abspath(
        os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)
    )
    sys.path.append(two_levels_up)
    if args.megatron_path is not None:
        sys.path.insert(0, args.megatron_path)

    model_hf = get_huggingface_model(args)
    args_hf = model_hf.get_args()
    # Copy selected command-line options onto the HF argument object.
    for option in ("moe_grouped_gemm", "moe_tp_extend_ep", "spec"):
        setattr(args_hf, option, getattr(args, option))

    model_mg = get_megatron_model(model_provider, args_cmd=args)
    model_mg.initialize_megatron_args(args_hf, queue)

    model_mg.set_tensor_model_parallel_world_size(model_mg.args.tensor_model_parallel_size)
    model_mg.set_expert_model_parallel_world_size(model_mg.args.expert_model_parallel_size)
    model_mg.set_pipeline_model_parallel_world_size(model_mg.args.pipeline_model_parallel_size)
    model_mg.set_virtual_pipeline_model_parallel_world_size(model_mg.args.virtual_pipeline_model_parallel_size)
    model_mg.set_tensor_model_parallel_rank(0)
    model_mg.set_pipeline_model_parallel_rank(0)

    margs = model_mg.get_args()
    md = build_metadata(args, margs)
    queue.put(md)

    model_hf.get_modules_from_pretrained()
    model_mg.get_modules_from_config()
    model_mg.update_module(model_hf)

    def send(name, payload):
        # Label the payload and hand it to the consumer.
        logger.info(f"sending {name}")
        payload["name"] = name
        queue.put(payload)

    send("embeddings", get_message_preprocess(model_mg, md))

    for layer_idx in range(margs.num_layers):
        layer_message = get_message_layer_norm({}, model_mg, layer_idx, md, args)
        layer_message = get_message_layer_attn(layer_message, model_mg, layer_idx, md, args)
        layer_message = get_message_layer_mlp(layer_message, model_mg, layer_idx, md)
        to_detach(layer_message)
        send(f"transformer layer {layer_idx}", layer_message)

    send("final norm", get_message_postprocess(model_mg, md))

    output_layer_message = get_message_output_layer(model_mg, md)
    if output_layer_message is not None:
        send("output layer", output_layer_message)

    queue.put("done")
def load_checkpoint(model_provider, queue, args):
    """Run ``_load_checkpoint`` and signal failures on *queue*.

    If loading raises an ``Exception``, the string ``"exit"`` is put on the
    queue before the exception is re-raised unchanged.
    """
    try:
        _load_checkpoint(model_provider, queue, args)
    except Exception:
        # Put the "exit" marker on the queue before propagating the error.
        queue.put("exit")
        raise
    return None