"""
-------------------------------------------------------------------------
This file is part of the MindStudio project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.
MindStudio is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""
import os
import types
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from safetensors.torch import load_file, save_file
from safetensors import safe_open
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_outputs import BaseModelOutputWithPast
from add_safetensors import find_file_with_pattern, get_weight_map
from example.common.security.path import json_safe_load, json_safe_dump
from example.common.security.path import get_valid_read_path
# File mode 0o400: read permission for the owner only.
READ_ONLY_PERMISSION = 0o400
# File mode 0o600: read and write permission for the owner only.
READ_WRITE_PERMISSION = 0o600
# 16 GiB expressed in bytes (16 * 1024 ** 3); passed as the size limit when
# validating the safetensors shard that is read in this module.
MAX_READ_FILE_SIZE_16G = 17179869184
def remove_zero_and_shift(matrix):
    """Drop the first zero of every row, shift the rest left and pad with 0.

    For each row of the 2-D tensor ``matrix`` the first element equal to 0 is
    removed, the remaining elements keep their order, and a single 0 is
    appended so the shape is unchanged.  ``argmax`` over an all-False row
    yields index 0, so a row that contains no zero loses its first element
    instead (i.e. it is simply shifted left by one).

    Args:
        matrix: 2-D tensor of shape ``(n, m)``.

    Returns:
        A new tensor with the same shape, dtype and device as ``matrix``.
    """
    n, m = matrix.shape
    # Column index of the first zero in each row (0 when the row has none).
    zero_pos = (matrix == 0).int().argmax(dim=1)
    col_indices = torch.arange(m, device=matrix.device).expand(n, -1)
    keep_mask = col_indices != zero_pos.unsqueeze(1)
    # Exactly one column is dropped per row, so the result reshapes cleanly.
    filtered = matrix[keep_mask].view(n, m - 1)
    # Build the padding in the input's own dtype/device.  Concatenating with a
    # float32 zero column would promote integer ids to float32 and corrupt any
    # value above 2**24 before it is cast back.
    padding = matrix.new_zeros((n, 1))
    return torch.cat([filtered, padding], dim=1)
def custom_model_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    """Run the decoder stack and return the raw hidden state of the last layer.

    Intended to be bound onto a decoder model instance (``self`` must expose
    ``config``, ``embed_tokens``, ``layers`` and ``_use_flash_attention_2``).
    No final normalization is applied here: the hidden state produced by the
    last entry of ``self.layers`` is returned as-is.

    Args:
        input_ids: Token ids; mutually exclusive with ``inputs_embeds``.
        attention_mask: Optional padding mask.
        position_ids: Optional positions; generated from the cache length
            when omitted.
        past_key_values: A ``Cache`` instance or a legacy cache structure.
        inputs_embeds: Pre-computed embeddings; mutually exclusive with
            ``input_ids``.
        use_cache, output_attentions, output_hidden_states, return_dict:
            Fall back to the corresponding ``self.config`` values when None.

    Returns:
        A ``BaseModelOutputWithPast`` or, when ``return_dict`` is falsy, a
        tuple of the non-None values among (hidden state, cache, all hidden
        states, all attentions).

    Raises:
        ValueError: If both or neither of ``input_ids`` / ``inputs_embeds``
            are given.
    """
    cfg = self.config
    if output_attentions is None:
        output_attentions = cfg.output_attentions
    if output_hidden_states is None:
        output_hidden_states = cfg.output_hidden_states
    if use_cache is None:
        use_cache = cfg.use_cache
    if return_dict is None:
        return_dict = cfg.use_return_dict

    # Exactly one of input_ids / inputs_embeds must be supplied.
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError(
            "You cannot specify both input_ids and inputs_embeds at the same time"
        )
    if input_ids is None and inputs_embeds is None:
        raise ValueError("You have to specify either input_ids or inputs_embeds")
    given_input = input_ids if input_ids is not None else inputs_embeds
    batch_size, seq_length = given_input.shape[:2]

    past_key_values_length = 0
    if use_cache:
        # Anything that is not already a Cache is treated as the legacy format.
        use_legacy_cache = not isinstance(past_key_values, Cache)
        if use_legacy_cache:
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
        past_key_values_length = past_key_values.get_usable_length(seq_length)

    if position_ids is None:
        first_position = past_key_values_length
        position_ids = torch.arange(
            first_position,
            first_position + seq_length,
            dtype=torch.long,
            device=given_input.device,
        ).unsqueeze(0)

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    if self._use_flash_attention_2:
        # Keep the 2-D mask only when it actually masks something out.
        if attention_mask is None or 0 not in attention_mask:
            attention_mask = None
    else:
        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
        )

    hidden_states = inputs_embeds
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = None
    # Position of the cache in a layer's output tuple depends on whether the
    # attention weights are returned in front of it.
    cache_index = 2 if output_attentions else 1

    for decoder_layer in self.layers:
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        layer_outputs = decoder_layer(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache = layer_outputs[cache_index]
        if output_attentions:
            all_self_attns = all_self_attns + (layer_outputs[1],)

    # Record the output of the final layer as well.
    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states,)

    next_cache = None
    if use_cache:
        if use_legacy_cache:
            next_cache = next_decoder_cache.to_legacy_cache()
        else:
            next_cache = next_decoder_cache

    if not return_dict:
        candidates = [hidden_states, next_cache, all_hidden_states, all_self_attns]
        return tuple(item for item in candidates if item is not None)

    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )
class SharedHead(nn.Module):
    """RMS-normalize hidden states and project them to vocabulary logits."""

    def __init__(self, config):
        """Build the norm and the bias-free output projection from ``config``.

        Reads ``config.hidden_size``, ``config.rms_norm_eps`` and
        ``config.vocab_size``.
        """
        super().__init__()
        hidden_size = config.hidden_size
        self.norm = DeepseekV3RMSNorm(hidden_size, eps=config.rms_norm_eps)
        self.head = nn.Linear(hidden_size, config.vocab_size, bias=False)

    def forward(self, hidden_states):
        """Return ``head(norm(hidden_states))``."""
        return self.head(self.norm(hidden_states))
class DeepseekV3RMSNorm(nn.Module):
    """Root-mean-square layer normalization with a learnable per-channel scale."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        DeepseekV3RMSNorm is equivalent to T5LayerNorm

        Args:
            hidden_size: Size of the last dimension to normalize.
            eps: Value added to the mean square before taking the reciprocal
                square root.
        """
        super().__init__()
        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states):
        """Normalize over the last dimension in float32, then apply the scale."""
        original_dtype = hidden_states.dtype
        # The statistics are computed in float32 regardless of the input dtype.
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        inverse_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalized = (upcast * inverse_rms).to(original_dtype)
        return self.weight * normalized
class MTPLayer(nn.Module):
    """Container for the extra modules used by the MTP branch.

    Holds two RMS norms (``enorm`` for token embeddings, ``hnorm`` for hidden
    states), a ``shared_head`` producing logits, the ``eh_proj`` projection
    from ``2 * hidden_size`` to ``hidden_size`` and an ``embed_tokens``
    embedding table.  The class defines no ``forward``; callers use the
    sub-modules directly.
    """

    def __init__(self, config):
        """Create the sub-modules from ``config``.

        Reads ``hidden_size``, ``rms_norm_eps``, ``vocab_size`` and
        ``pad_token_id`` from ``config``.
        """
        super().__init__()
        hidden_size = config.hidden_size
        norm_eps = config.rms_norm_eps

        self.enorm = DeepseekV3RMSNorm(hidden_size, eps=norm_eps)
        self.hnorm = DeepseekV3RMSNorm(hidden_size, eps=norm_eps)
        self.shared_head = SharedHead(config)
        # Takes the concatenation of two hidden_size-wide tensors as input.
        self.eh_proj = nn.Linear(2 * hidden_size, hidden_size, bias=False)
        self.embed_tokens = nn.Embedding(
            num_embeddings=config.vocab_size,
            embedding_dim=hidden_size,
            padding_idx=config.pad_token_id,
        )
class MTPModel(PreTrainedModel):
    """Causal LM wrapper that runs an additional MTP branch on every forward.

    The last decoder layer of the wrapped model is detached from the backbone
    and kept as ``mtp_decoder``; the backbone's ``forward`` is replaced by
    ``custom_model_forward`` so it returns un-normalized hidden states, which
    are then shared between the regular LM head and the MTP branch.
    """

    def __init__(self, config, model, mtp_layer):
        """Wrap ``model`` and attach ``mtp_layer``.

        Args:
            config: Model configuration passed to ``PreTrainedModel``.
            model: Causal LM exposing ``model`` (decoder stack with
                ``layers``), ``lm_head`` and ``device``.  Its decoder stack is
                modified in place: the last layer is removed and ``forward``
                is re-bound.
            mtp_layer: ``MTPLayer`` instance; moved to ``model.device``.
        """
        super().__init__(config)
        self.model = model.model
        self.vocab_size = config.vocab_size
        self.lm_head = model.lm_head
        self.post_init()
        # The last decoder layer becomes the MTP decoder and leaves the backbone.
        self.mtp_decoder = model.model.layers[-1]
        self.mtp_layer = mtp_layer.to(model.device)
        self.model.layers = self.model.layers[:-1]
        self.model.config.num_hidden_layers = self.model.config.num_hidden_layers - 1
        # The wrapper and the backbone may share one config object; decrement
        # it only once in that case so the count matches the remaining layers.
        if self.config is not self.model.config:
            self.config.num_hidden_layers = self.config.num_hidden_layers - 1
        self.model.forward = types.MethodType(custom_model_forward, self.model)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Compute LM logits and additionally execute the MTP branch.

        Args:
            input_ids: Token ids of shape ``(batch, seq)``.  Required, because
                the MTP branch re-embeds a shifted copy of them.
            attention_mask, position_ids, past_key_values, inputs_embeds,
            use_cache, output_attentions, output_hidden_states, return_dict:
                Forwarded to the backbone.
            labels: Optional targets; when given, a next-token cross-entropy
                loss over the main logits is returned.

        Returns:
            ``CausalLMOutputWithPast`` (or the equivalent tuple when
            ``return_dict`` is falsy) built from the main LM logits.  The MTP
            logits are computed but not part of the return value.

        Raises:
            ValueError: If ``input_ids`` is None.
        """
        if input_ids is None:
            raise ValueError(
                "MTPModel.forward requires input_ids: the MTP branch embeds shifted token ids"
            )

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # The patched backbone forward skips the final norm, so apply it here
        # for the LM head while keeping the raw states for the MTP branch.
        pre_hidden_states = outputs[0]
        hidden_states = self.model.norm(pre_hidden_states)
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        # MTP inputs: token ids shifted left by one, with the last slot filled
        # by the main head's greedy prediction for the final position.
        input_ids_mtp = remove_zero_and_shift(input_ids)
        position_ids = torch.arange(
            0,
            input_ids_mtp.shape[-1],
            dtype=torch.long,
            device=input_ids.device,
        ) + 1
        position_ids = position_ids.unsqueeze(0)
        input_ids_mtp[:, -1] = logits[:, -1, :].argmax(dim=1)

        input_embeds_mtp = self.mtp_layer.embed_tokens(input_ids_mtp)
        input_embeds_mtp = self.mtp_layer.enorm(input_embeds_mtp)
        hidden_states_mtp = self.mtp_layer.hnorm(pre_hidden_states)
        hidden_states_mtp = torch.cat([input_embeds_mtp, hidden_states_mtp], dim=-1)
        hidden_states_mtp = self.mtp_layer.eh_proj(hidden_states_mtp)

        attention_mask_mtp = _prepare_4d_causal_attention_mask(
            attention_mask,
            (input_ids.shape[:2]),
            input_embeds_mtp,
            0,
        )
        layer_outputs_mtp = self.mtp_decoder(
            hidden_states_mtp,
            attention_mask=attention_mask_mtp,
            position_ids=position_ids,
            past_key_value=None,
            output_attentions=False,
            use_cache=False,
        )
        # NOTE(review): the MTP logits are discarded; presumably the branch is
        # executed only so its modules see data (e.g. for calibration hooks) -
        # confirm with the caller before removing this call.
        _ = self.mtp_layer.shared_head(layer_outputs_mtp[0])

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
def warp_mtp_model(
    config,
    base_model,
    model_path,
    mtp_weight_file="model-00163-of-000163.safetensors",
    mtp_key_prefix="model.layers.61.",
):
    """Load the MTP layer weights from a checkpoint shard and wrap the model.

    A fresh ``MTPLayer`` is built from ``config`` and populated with the
    tensors of ``mtp_weight_file`` whose names, after removing
    ``mtp_key_prefix``, match a key of the layer's state dict.  Every other
    tensor in the shard is ignored.

    Args:
        config: Model configuration object.
        base_model (PreTrainedModel): The original base model to wrap.
        model_path (str): Directory containing the pretrained checkpoint.
        mtp_weight_file (str): Name of the safetensors shard inside
            ``model_path`` that stores the MTP weights.
        mtp_key_prefix (str): Substring removed from shard tensor names to
            obtain ``MTPLayer`` state-dict keys.

    Returns:
        MTPModel: ``base_model`` wrapped together with the loaded MTP layer.

    Raises:
        RuntimeError: Raised by ``load_state_dict`` (default strict mode) if
            the shard does not provide every ``MTPLayer`` parameter.
    """
    mtp_layer = MTPLayer(config)
    shard_path = get_valid_read_path(
        os.path.join(model_path, mtp_weight_file),
        size_max=MAX_READ_FILE_SIZE_16G,
        is_dir=False,
        check_user_stat=True
    )
    shard_weights = load_file(shard_path)

    # Build the key set once instead of materializing the state dict for
    # every tensor in the shard.
    expected_keys = set(mtp_layer.state_dict())
    mtp_state_dict = {}
    for key, value in shard_weights.items():
        local_key = key.replace(mtp_key_prefix, '')
        if local_key in expected_keys:
            mtp_state_dict[local_key] = value

    mtp_layer.load_state_dict(mtp_state_dict)
    return MTPModel(config, base_model, mtp_layer)
def _mtp_to_checkpoint_key(name):
    """Replace the wrapper names 'mtp_decoder' / 'mtp_layer' with 'model.layers.61'."""
    return name.replace('mtp_decoder', 'model.layers.61').replace('mtp_layer', 'model.layers.61')


def post_process_mtp_quant(model_path):
    """Rename MTP entries in exported quantization artifacts, in place.

    Three kinds of files under ``model_path`` are rewritten so that names
    containing ``mtp_decoder`` or ``mtp_layer`` use ``model.layers.61``:

    1. the ``quant_model_description*.json`` description file (its keys),
    2. the ``quant_model_weight_*.index.json`` index (keys of ``weight_map``),
    3. every safetensors shard that the index maps a renamed
       ``model.layers.61`` entry to (tensor names).

    Shards are temporarily made owner-writable for the rewrite and are always
    set back to read-only afterwards, even if saving fails.

    Args:
        model_path (str): Directory that holds the exported quantized model.
    """
    # 1. Description file: rename keys, keep values.
    desc_file = find_file_with_pattern(model_path, "quant_model_description*.json")
    description_data = json_safe_load(desc_file)
    renamed_description = {
        _mtp_to_checkpoint_key(key): value for key, value in description_data.items()
    }
    json_safe_dump(renamed_description, desc_file, indent=4)

    # 2. Weight index: rename keys and remember which shards hold MTP tensors.
    index_file = find_file_with_pattern(model_path, "quant_model_weight_*.index.json")
    index_data = json_safe_load(index_file)
    map_data = get_weight_map(index_file)
    renamed_weight_map = {}
    shards_to_rewrite = set()
    for key, shard_name in map_data.items():
        new_key = _mtp_to_checkpoint_key(key)
        if 'model.layers.61' in new_key:
            shards_to_rewrite.add(shard_name)
        renamed_weight_map[new_key] = shard_name
    index_data["weight_map"] = renamed_weight_map
    json_safe_dump(index_data, index_file, indent=4)

    # 3. Shards: load every tensor under its new name and write the file back.
    #    Sorted only to make the processing order deterministic.
    for tensor_file in sorted(shards_to_rewrite):
        tensor_path = os.path.join(model_path, tensor_file)
        # Read everything and close the reader before overwriting the file.
        with safe_open(tensor_path, framework="pt") as f:
            tensors = {
                _mtp_to_checkpoint_key(name): f.get_tensor(name) for name in f.keys()
            }
        os.chmod(tensor_path, READ_WRITE_PERMISSION, follow_symlinks=False)
        try:
            save_file(tensors, tensor_path)
        finally:
            # Never leave the shard writable, even when the save raised.
            os.chmod(tensor_path, READ_ONLY_PERMISSION, follow_symlinks=False)