import os
import types
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from safetensors.torch import load_file, save_file
from safetensors import safe_open
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_outputs import BaseModelOutputWithPast
from add_safetensors import find_file_with_pattern, get_weight_map
from example.common.security.path import json_safe_load, json_safe_dump
from example.common.security.path import get_valid_read_path
READ_ONLY_PERMISSION = 0o400
READ_WRITE_PERMISSION = 0o600
MAX_READ_FILE_SIZE_16G = 17179869184
def remove_zero_and_shift(matrix):
n, m = matrix.shape
zero_pos = (matrix == 0).int().argmax(dim=1)
col_indices = torch.arange(m, device=matrix.device).expand(n, -1)
mask = (col_indices != zero_pos.unsqueeze(1))
filtered = matrix[mask].view(n, m - 1)
result = torch.cat([filtered, torch.zeros(n, 1, device=matrix.device)], dim=1)
return result.to(matrix)
def custom_model_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time"
)
elif input_ids is not None:
batch_size, seq_length = input_ids.shape[:2]
elif inputs_embeds is not None:
batch_size, seq_length = inputs_embeds.shape[:2]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
past_key_values_length = 0
if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_usable_length(seq_length)
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if self._use_flash_attention_2:
attention_mask = (
attention_mask
if (attention_mask is not None and 0 in attention_mask)
else None
)
else:
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
hidden_states = inputs_embeds
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = None
if use_cache:
next_cache = (
next_decoder_cache.to_legacy_cache()
if use_legacy_cache
else next_decoder_cache
)
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class SharedHead(nn.Module):
def __init__(self, config):
super().__init__()
self.norm = DeepseekV3RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def forward(self, hidden_states):
normalized_states = self.norm(hidden_states)
logits = self.head(normalized_states)
return logits
class DeepseekV3RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
DeepseekV3RMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class MTPLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.enorm = DeepseekV3RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.hnorm = DeepseekV3RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.shared_head = SharedHead(config)
self.eh_proj = nn.Linear(
config.hidden_size * 2,
config.hidden_size,
bias=False
)
self.embed_tokens = nn.Embedding(
config.vocab_size, config.hidden_size, config.pad_token_id
)
class MTPModel(PreTrainedModel):
def __init__(self, config, model, mtp_layer):
super().__init__(config)
self.model = model.model
self.vocab_size = config.vocab_size
self.lm_head = model.lm_head
self.post_init()
self.mtp_decoder = model.model.layers[-1]
self.mtp_layer = mtp_layer.to(model.device)
self.model.layers = self.model.layers[:-1]
self.model.config.num_hidden_layers = self.model.config.num_hidden_layers - 1
self.config.num_hidden_layers = self.config.num_hidden_layers - 1
self.model.forward = types.MethodType(custom_model_forward, self.model)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pre_hidden_states = outputs[0]
hidden_states = self.model.norm(pre_hidden_states)
logits = self.lm_head(hidden_states)
logits = logits.float()
input_ids_mtp = remove_zero_and_shift(input_ids)
position_ids = torch.arange(
0,
input_ids_mtp.shape[-1],
dtype=torch.long,
device=input_ids.device,
) + 1
position_ids = position_ids.unsqueeze(0)
logits[:, -1, :].argmax(dim=1)
input_ids_mtp[:, -1] = logits[:, -1, :].argmax(dim=1)
input_embeds_mtp = self.mtp_layer.embed_tokens(input_ids_mtp)
input_embeds_mtp = self.mtp_layer.enorm(input_embeds_mtp)
hidden_states_mtp = self.mtp_layer.hnorm(pre_hidden_states)
hidden_states_mtp = torch.cat([input_embeds_mtp, hidden_states_mtp], dim=-1)
hidden_states_mtp = self.mtp_layer.eh_proj(hidden_states_mtp)
attention_mask_mtp = _prepare_4d_causal_attention_mask(
attention_mask,
(input_ids.shape[:2]),
input_embeds_mtp,
0,
)
layer_outputs_mtp = self.mtp_decoder(
hidden_states_mtp,
attention_mask=attention_mask_mtp,
position_ids=position_ids,
past_key_value=None,
output_attentions=False,
use_cache=False,
)
logits_mtp = self.mtp_layer.shared_head(layer_outputs_mtp[0])
loss = None
if labels is not None:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def warp_mtp_model(config, base_model, model_path):
"""
从预训练模型加载MTP层权重,并封装为MTP_Model
Args:
config (Config): 模型配置对象
base_model (PreTrainedModel): 原始基础模型
model_path (str): 预训练模型路径
mtp_layer_class (nn.Module): MTP层类(默认: MTP_Layer)
Returns:
MTP_Model: 封装后的MTP模型
"""
mtp_layer = MTPLayer(config)
mtp_safetensor = os.path.join(model_path, "model-00163-of-000163.safetensors")
mtp_safetensor = get_valid_read_path(
mtp_safetensor,
size_max=MAX_READ_FILE_SIZE_16G,
is_dir=False,
check_user_stat=True
)
mtp_weight = load_file(mtp_safetensor)
new_state_dict = {}
for key, value in mtp_weight.items():
new_key = key.replace('model.layers.61.', '')
if new_key in mtp_layer.state_dict().keys():
new_state_dict[new_key] = value
mtp_layer.load_state_dict(new_state_dict)
warpped_model = MTPModel(config, base_model, mtp_layer)
return warpped_model
def post_process_mtp_quant(model_path):
safetensor_to_check = set()
desc_file = find_file_with_pattern(model_path, "quant_model_description*.json")
description_data = json_safe_load(desc_file)
new_state_dict = {}
for key, value in description_data.items():
new_key = key.replace('mtp_decoder', 'model.layers.61').replace('mtp_layer', 'model.layers.61')
new_state_dict[new_key] = value
json_safe_dump(new_state_dict, desc_file, indent=4)
index_file = find_file_with_pattern(model_path, "quant_model_weight_*.index.json")
index_data = json_safe_load(index_file)
map_data = get_weight_map(index_file)
new_state_dict = {}
for key, value in map_data.items():
new_key = key.replace('mtp_decoder', 'model.layers.61').replace('mtp_layer', 'model.layers.61')
if 'model.layers.61' in new_key:
safetensor_to_check.add(value)
new_state_dict[new_key] = value
index_data["weight_map"] = new_state_dict
json_safe_dump(index_data, index_file, indent=4)
for tensor_file in safetensor_to_check:
tensor_path = os.path.join(model_path, tensor_file)
with safe_open(tensor_path, framework="pt") as f:
tensors = {}
for name in f.keys():
new_name = name.replace('mtp_decoder', 'model.layers.61').replace('mtp_layer', 'model.layers.61')
tensors[new_name] = f.get_tensor(name)
os.chmod(tensor_path, READ_WRITE_PERMISSION, follow_symlinks=False)
save_file(tensors, tensor_path)
os.chmod(tensor_path, READ_ONLY_PERMISSION, follow_symlinks=False)