import re
from abc import ABC, abstractmethod
from typing import Callable, Iterable, List, Tuple, Any, Dict
import torch
from torch import Tensor
from checkpoint.common.constant import DIGIT_FMT
from checkpoint.common.mm_types import STATE_DICT_T, QV_NAME_T, QKV_NAME_T
def interleaved_qkv_to_concated(megatron_qkv: Tensor, num_key_value_heads: int, split_size: List[int]) -> Tensor:
"""mindspeed-mm中的qkv排布和vllm中qkv的排布不同,因此需要重新转换,qwen2vl的转换逻辑如下:
qkv = [nq1,k1,v1, nq2,k2,v2, nq3,k3,v3, nq4,k4,v4]
new_qkv = [nq1,nq2,nq3,nq4,k1,k2,k3,k4,v1,v2,v3,v4]
"""
groups_qkv = [torch.split(group, split_size) for group in torch.chunk(megatron_qkv, num_key_value_heads)]
return torch.cat([head for group_head in zip(*groups_qkv) for head in group_head])
def concated_qkv_to_interleaved(qkv: Tensor, num_key_value_heads: int, split_size: List[int]) -> Tensor:
"""
qkv : [nq1,nq2,nq3,nq4,k1,k2,k3,k4]
new_qkv : [nq1,k1,v1, nq2,k2,v2, nq3,k3,v3, nq4,k4,v4]
"""
qkv_split = torch.split(qkv, split_size)
qkv_chunk_group = [torch.chunk(i, num_key_value_heads) for i in qkv_split]
return torch.cat([i for j in zip(*qkv_chunk_group) for i in j])
def merge_to_interleaved_qkv(q: Tensor, k: Tensor, v: Tensor, group: int) -> Tensor:
"""原顺序: q=[nq1,nq2,nq3,...], k=[k1,k2,k3,...], v=[v1,v2,v3,...]
转换后顺序: qkv = [nq1,k1,v1,nq2,k2,v2,nq3,k3,v3,...]
这里nq1,group query attention时,多个q共享一个k/v, k1/v1即一个key head对应的head_dim维度
"""
qkv_chunks = [torch.chunk(x, chunks=group, dim=0) for x in [q, k, v]]
return torch.concat([head for qkv_group in zip(*qkv_chunks) for head in qkv_group], dim=0)
def split_to_interleaved_qkv(value: Tensor, num_heads: int, q_size: int, k_size: int, v_size: int) -> Tuple[
Tensor, Tensor, Tensor]:
qkv_chunks = torch.chunk(value, num_heads, dim=0)
q_parts, k_parts, v_parts = zip(*[torch.split(chunk, [q_size, k_size, v_size]) for chunk in qkv_chunks])
return torch.cat(q_parts), torch.cat(k_parts), torch.cat(v_parts)
def get_layer_num(pattern: str, names: Iterable[str]) -> int:
"""在names中找到pattern正则匹配到的数数字的最大值"""
layer_ids = [int(re.findall(pattern, k)[0]) for k in names if re.match(pattern, k)]
if len(layer_ids) == 0:
return 0
return max(layer_ids) + 1
def get_layer_and_expert_num(pattern: str, names: Iterable[str]) -> Tuple[int, int]:
"""仅在含有expert相关的参数时调用,匹配到的第一个数字为层数,第二个数字为专家数"""
layer_num = max(int(re.findall(pattern, k)[0][0]) for k in names if re.match(pattern, k)) + 1
expert_num = max(int(re.findall(pattern, k)[0][1]) for k in names if re.match(pattern, k)) + 1
return layer_num, expert_num
def resize_embedding(tensor: Tensor, new_vocab_size: int, original_vocab_size: int = None) -> Tensor:
"""
调整嵌入权重矩阵的大小
Args:
tensor: 原始嵌入权重,形状为 [vocab_size, hidden_size]
new_vocab_size: 目标词汇表大小
original_vocab_size: 原始词汇表大小(revert时使用)
"""
old_vocab_size, hidden_size = tensor.shape
target_size = original_vocab_size if original_vocab_size is not None else new_vocab_size
if target_size == old_vocab_size:
return tensor.clone()
new_tensor = torch.empty(target_size, hidden_size, device=tensor.device, dtype=tensor.dtype)
copy_size = min(old_vocab_size, target_size)
new_tensor[:copy_size] = tensor[:copy_size].clone()
if target_size > old_vocab_size:
mean = tensor.mean().item()
std = tensor.std().item()
torch.nn.init.normal_(new_tensor[old_vocab_size:], mean=mean, std=std)
return new_tensor
class Operator(ABC):
@abstractmethod
def apply(self, weights: STATE_DICT_T):
"""原地修改weights字典"""
pass
@abstractmethod
def revert(self, weights: STATE_DICT_T):
"""原地修改weights字典"""
pass
class ZeroWeightOp(Operator):
"""
权重置零操作,将权重值赋值为零或向权重字典添加全零权重
权重转换需要把q,k,v权重合并成一个完整的qkv大矩阵。
而由于在Qwen2.5-Omni模型的音频模型中,k没有bias,所以需要将k的bias以全零权重的形式添加到权重字典
"""
def __init__(self, name: str, shape: List[int], layers: int):
"""
初始化权重信息
参数:
- name (str): 权重的名称。
- shape (List[int]): 权重的的形状,以整数列表形式表示。
- layers (int): 模型的层数。
"""
self.name = name
self.shape = shape
self.layers = layers
def apply(self, weights: STATE_DICT_T):
"""向weights字典中添加权重或修改权重为0"""
for layer_num in range(self.layers):
tensor = torch.zeros(self.shape)
name = self.name.replace(DIGIT_FMT, str(layer_num))
weights[name] = tensor
def revert(self, weights: STATE_DICT_T) -> None:
"""反向操作:删除由 apply 添加的零权重键"""
for layer_num in range(self.layers):
name = self.name.replace(DIGIT_FMT, str(layer_num))
weights.pop(name, None)
class RenameOp(Operator):
def __init__(self, patterns: Tuple[Tuple[str, str], ...]):
self.patterns = patterns
@staticmethod
def replace_parentheses(text):
"""
使用序号替换文本中的括号内容。
该方法使用正则表达式匹配文本中所有的括号内容,并使用一个递增的序号替换它们。
序号以反斜杠开头,如 \1, \2 等。
参数:
text -- 待处理的文本字符串。
返回:
替换后的文本字符串。
doctest 示例:
>>> RenameOp.replace_parentheses("这是一个测试 (示例) 和另一个 (例子)")
'这是一个测试 \\1 和另一个 \\2'
>>> RenameOp.replace_parentheses("无括号文本")
'无括号文本'
>>> RenameOp.replace_parentheses("(第一个) 包含多个 (括号内容) 示例")
'\\1 包含多个 \\2 示例'
"""
count = 0
def repl(match):
"""
替换匹配到的括号内容为序号。
参数:
match -- 正则表达式匹配到的对象。
返回:
包含序号的字符串,如 \1, \2 等。
"""
nonlocal count
count += 1
return f'\\{count}'
return re.sub(r'\([^)]*\)', repl, text)
def apply(self, weights: STATE_DICT_T):
mapping = {}
for weight_name in weights:
new_name = weight_name
for pattern, replacement in self.patterns:
replacement = self.replace_parentheses(replacement)
new_name = re.sub(pattern, replacement, new_name)
if new_name != weight_name:
mapping[weight_name] = new_name
for old_name, new_name in mapping.items():
weights[new_name] = weights.pop(old_name)
def revert(self, weights: STATE_DICT_T):
self.patterns = tuple((right, left) for left, right in self.patterns)
self.apply(weights=weights)
class MergeOp(Operator):
def __init__(self, raw_names: Any, new_name: str):
self.raw_names = raw_names
self.new_name = new_name
@abstractmethod
def merge(self, tensors: List[Tensor]) -> Tensor:
pass
@abstractmethod
def split(self, tensor: Tensor) -> List[Tensor]:
pass
def apply(self, weights: STATE_DICT_T):
layer_num = get_layer_num(self.raw_names[0], weights.keys())
for num in range(layer_num):
tensors = [weights.pop(name.replace(DIGIT_FMT, str(num)))
for name in self.raw_names
if name.replace(DIGIT_FMT, str(num)) in weights.keys()]
if tensors:
merged_tensor = self.merge(tensors)
weights[self.new_name.replace(DIGIT_FMT, str(num))] = merged_tensor
def revert(self, weights: STATE_DICT_T) -> None:
"""反向操作:拆分合并后的权重并恢复原始键"""
layer_num = get_layer_num(self.new_name, weights.keys())
for num in range(layer_num):
merged_name = self.new_name.replace(DIGIT_FMT, str(num))
merged_tensor = weights.pop(merged_name)
split_tensors = self.split(merged_tensor)
if len(split_tensors) != len(self.raw_names):
raise ValueError(f"Split tensor count {len(split_tensors)} != raw names count {len(self.raw_names)}")
for raw_name, tensor in zip(self.raw_names, split_tensors):
weights[raw_name.replace(DIGIT_FMT, str(num))] = tensor
class UpGateMergeOp(MergeOp):
"""GLU门控单元中权重合并逻辑,从transformers的gate up两个小矩阵到megatron的mlp中linear_fc1中的一个大矩阵是直接concat的"""
def merge(self, tensors: List[Tensor]) -> Tensor:
return torch.concat(tensors, dim=0)
def split(self, tensor: Tensor) -> List[Tensor]:
return list(torch.chunk(tensor, chunks=2, dim=0))
class QKVDirectMergeOp(MergeOp):
"""qkv中权重合并逻辑,从transformers的q1q2/k1k2/v1v2(或q1q2/k1k2v1v2)独立小矩阵到megatron的linear_qkv一个大矩阵是q1q2k1k2v1v2直接拼接的"""
def merge(self, tensors: List[Tensor]) -> Tensor:
return torch.concat(tensors, dim=0)
def split(self, tensor: Tensor) -> List[Tensor]:
raise NotImplementedError("The method has not yet been implemented")
class ExpertUpGateMergeOp(UpGateMergeOp):
def apply(self, weights: STATE_DICT_T):
layer_num, expert_num = get_layer_and_expert_num(self.raw_names[0], weights.keys())
def replace_layer_expert(name, id_layer, id_expert):
return name.replace(DIGIT_FMT, str(id_layer), 1).replace(DIGIT_FMT, str(id_expert), 1)
for layer_id in range(layer_num):
for expert_id in range(expert_num):
tensors = [weights.pop(replace_layer_expert(name, layer_id, expert_id))
for name in self.raw_names
if replace_layer_expert(name, layer_id, expert_id) in weights.keys()]
if tensors:
merged_tensor = self.merge(tensors)
weights[replace_layer_expert(self.new_name, layer_id, expert_id)] = merged_tensor
class QKVMergeOp(MergeOp):
"""qkv中权重合并逻辑,从transformers的q1q2/k1k2/v1v2三个小矩阵到megatron的linear_qkv一个大矩阵是q1k1v1q2k2v2交织拼接的"""
def __init__(self, raw_names: QKV_NAME_T, new_name: str, group: int, q_size: int, k_size: int, v_size: int):
self.group = group
self.q_size = q_size
self.k_size = k_size
self.v_size = v_size
super().__init__(raw_names, new_name)
def merge(self, tensors: QKV_NAME_T) -> Tensor:
return merge_to_interleaved_qkv(*tensors, group=self.group)
def split(self, tensor: Tensor) -> List[Tensor]:
qkv = split_to_interleaved_qkv(tensor, self.group, self.q_size, self.k_size, self.v_size)
return list(qkv)
class QVToQKVMergeOp(MergeOp):
"""构造k权重,并执行qkv中权重合并逻辑,从transformers的q1q2/k1k2/v1v2三个小矩阵到megatron的linear_qkv一个大矩阵是q1k1v1q2k2v2交织拼接的"""
def __init__(self, raw_names: QV_NAME_T, new_name: str, group: int, q_size: int, k_size: int, v_size: int):
self.group = group
self.q_size = q_size
self.k_size = k_size
self.v_size = v_size
super().__init__(raw_names, new_name)
def merge(self, tensors: List[Tensor]) -> Tensor:
tensors = [tensors[0], torch.zeros_like(tensors[0]), tensors[1]]
return merge_to_interleaved_qkv(*tensors, group=self.group)
def split(self, tensor: Tensor) -> List[Tensor]:
qkv = split_to_interleaved_qkv(tensor, self.group, self.q_size, self.k_size, self.v_size)
return [qkv[0], qkv[2]]
class RelocateOp(Operator):
"""处理 QKV 权重在直接拼接和交织结构之间的转换,对于vit,transformers和megatron都是一个大矩阵,但transformer中的qkv是q1q2k1k2v1v2直接拼接,而megatron中是q1k1v1q2k2v2交织拼接"""
def __init__(self, name: str, new_name: str, group: int, split_size: List[int]):
"""
Args:
name: 权重名称模板(如 "decoder.layers.{}.self_attention.linear_qkv.weight")
group: 分组数(如多头注意力的头数)
split_size: 拆分维度 [q_size, k_size, v_size]
"""
self.name = name
self.new_name = new_name
self.group = group
self.split_size = split_size
def apply(self, weights: STATE_DICT_T) -> None:
"""正向操作:将直接拼接的 Q/K/V 转换为交织结构"""
self._apply_transformation(weights, concated_qkv_to_interleaved)
def revert(self, weights: STATE_DICT_T) -> None:
self.name = self.new_name
"""反向操作:将交织结构恢复为直接拼接的 Q/K/V"""
self.split_size = [i // self.group for i in self.split_size]
self._apply_transformation(weights, interleaved_qkv_to_concated)
def _apply_transformation(
self,
weights: STATE_DICT_T,
transform_func: Callable[[Tensor, int, List[int]], Tensor]
) -> None:
"""统一处理权重重新排布逻辑"""
layer_num = get_layer_num(self.name, weights.keys())
for num in range(layer_num):
name = self.name.replace(DIGIT_FMT, str(num))
tensor = weights.get(name)
if tensor is not None:
weights[name] = transform_func(tensor, self.group, self.split_size)
class TieOp(Operator):
"""对于tie word embedding,其output layer是直接复用的word_embedding的权重"""
def __init__(self, raw_name: str, new_name: str):
self.raw_name = raw_name
self.new_name = new_name
def apply(self, weights: STATE_DICT_T):
weights[self.new_name] = weights.get(self.raw_name)
def revert(self, weights: STATE_DICT_T) -> None:
pass
class ResizeEmbedOp(Operator):
"""调整词嵌入权重(word_embeddings.weight)的形状至指定词汇表大小"""
def __init__(self, name: str, new_vocab_size: int):
"""Args: name:嵌入权重名称, new_vocab_size:目标词汇表大小"""
self.name = name
self.new_vocab_size = new_vocab_size
self.original_vocab_size = None
def apply(self, weights: STATE_DICT_T) -> None:
"""正向操作:将嵌入权重调整为新词汇表大小"""
if self.name not in weights:
raise KeyError(f"weight: {self.name} not exist in state dict")
tensor = weights[self.name]
if tensor.ndim != 2:
raise ValueError(f"weight: {self.name} must be 2D tensor, but the actual shape is {tensor.shape}")
if self.original_vocab_size is None:
self.original_vocab_size = tensor.shape[0]
weights[self.name] = resize_embedding(
tensor,
new_vocab_size=self.new_vocab_size
)
def revert(self, weights: STATE_DICT_T) -> None:
"""反向操作:将嵌入权重恢复为原始大小"""
if self.name not in weights:
raise KeyError(f"weight: {self.name} not exist in state dict")
if self.original_vocab_size is None:
raise RuntimeError("No apply operation has been performed, so revert cannot be executed. ")
tensor = weights[self.name]
weights[self.name] = resize_embedding(
tensor,
new_vocab_size=None,
original_vocab_size=self.original_vocab_size
)
class BaseSplit(ABC):
"""Abstract base class for tensor splitting strategies."""
@staticmethod
@abstractmethod
def split(tp_size: int, value: Tensor) -> List[Tensor]:
"""Split tensor into TP_SIZE partitions.
Args:
tp_size: Number of tensor parallel partitions
value: Tensor to split
Returns:
List of split tensors, one per TP rank
"""
pass
@staticmethod
@abstractmethod
def merge(tp_values: List[Tensor]) -> Tensor:
"""Merge split tensors back into original tensor.
Args:
tp_values: List of tensors from TP ranks
Returns:
Merged tensor
"""
pass
class ColSplit(BaseSplit):
"""按列切分权重"""
@staticmethod
def split(tp_size: int, value: Tensor) -> List[Tensor]:
size_per_tp = value.shape[1] // tp_size
return [
value[:, rank * size_per_tp:(rank + 1) * size_per_tp]
for rank in range(tp_size)
]
@staticmethod
def merge(tp_values: List[Tensor]) -> Tensor:
return torch.cat(tp_values, dim=1)
class RowSplit(BaseSplit):
"""按行切分偏置"""
@staticmethod
def split(tp_size: int, value: Tensor) -> List[Tensor]:
size_per_tp = value.shape[0] // tp_size
return [
value[rank * size_per_tp:(rank + 1) * size_per_tp]
for rank in range(tp_size)
]
@staticmethod
def merge(tp_values: List[Tensor]) -> Tensor:
return torch.cat(tp_values, dim=0)
def unaligned_divide(numerator: int, world_size: int, rank: int) -> int:
res = numerator // world_size
if rank < numerator % world_size:
res += 1
return res
class UnalignedEmbeddingSplit(BaseSplit):
"""按行切分嵌入权重,处理不均匀分布的情况"""
@staticmethod
def split(tp_size: int, value: Tensor) -> List[Tensor]:
value_per_tp = []
start_idx = 0
for tp_rank in range(tp_size):
size_per_tp = unaligned_divide(value.shape[0], tp_size, tp_rank)
value_per_tp.append(value[start_idx:start_idx + size_per_tp])
start_idx += size_per_tp
return value_per_tp
@staticmethod
def merge(tp_values: List[Tensor]) -> Tensor:
return torch.cat(tp_values, dim=0)
class UnalignedRowSplit(BaseSplit):
"""按行切分权重,处理不均匀分布的情况"""
def __init__(self, num_attention_heads: int):
self.head = num_attention_heads
def split(self, tp_size: int, value: Tensor) -> List[Tensor]:
values = torch.chunk(value, 3, dim=0)
value_per_tp = []
start_idx = 0
for tp_rank in range(tp_size):
heads_per_tp = unaligned_divide(self.head, tp_size, tp_rank)
size_per_tp = heads_per_tp * (value.shape[0] // 3 // self.head)
w_q = values[0][start_idx:start_idx + size_per_tp]
w_k = values[1][start_idx:start_idx + size_per_tp]
w_v = values[2][start_idx:start_idx + size_per_tp]
value_per_tp.append(torch.cat((w_q, w_k, w_v), dim=0))
start_idx += size_per_tp
return value_per_tp
@staticmethod
def merge(tp_values: List[Tensor]) -> Tensor:
return torch.cat(tp_values, dim=0)
class UnalignedColSplit(BaseSplit):
"""按列切分权重,处理不均匀分布的情况"""
def __init__(self, num_attention_heads: int):
self.head = num_attention_heads
def split(self, tp_size: int, value: Tensor) -> List[Tensor]:
value_per_tp = []
start_idx = 0
for tp_rank in range(tp_size):
heads_per_tp = unaligned_divide(self.head, tp_size, tp_rank)
size_per_tp = heads_per_tp * (value.shape[1] // self.head)
value_per_tp.append(value[:, start_idx: start_idx + size_per_tp])
start_idx += size_per_tp
return value_per_tp
@staticmethod
def merge(tp_values: List[Tensor]) -> Tensor:
return torch.cat(tp_values, dim=1)
class GLUSplit(BaseSplit):
"""GLU 通用切分基类"""
@staticmethod
def split(tp_size: int, value: Tensor) -> List[Tensor]:
size_per_tp = value.shape[0] // tp_size // 2
gate, up = torch.chunk(value, 2, dim=0)
return [
torch.cat((
gate[rank * size_per_tp: (rank + 1) * size_per_tp],
up[rank * size_per_tp: (rank + 1) * size_per_tp]
), dim=0)
for rank in range(tp_size)
]
@staticmethod
def merge(tp_values: List[Tensor]) -> Tensor:
split_results = [torch.chunk(v, 2, dim=0) for v in tp_values]
all_gates = [chunk[0] for chunk in split_results]
all_ups = [chunk[1] for chunk in split_results]
return torch.cat([
torch.cat(all_gates, dim=0),
torch.cat(all_ups, dim=0)
], dim=0)
class ExpertSplitOp(Operator):
"""moe experts split op"""
def __init__(self, raw_name: Any, new_name: str, num_experts: int):
self.raw_name = raw_name
self.new_name = new_name
self.num_experts = num_experts
def apply(self, weights: STATE_DICT_T):
if self.num_experts > 0:
experts_tensor = weights.pop(self.raw_name)
for num in range(self.num_experts):
split_tensors = experts_tensor[num]
weights[self.new_name.replace(DIGIT_FMT, str(num))] = split_tensors.T
def revert(self, weights: STATE_DICT_T):
if self.num_experts > 0:
split_tensors_list = []
for num in range(self.num_experts):
split_tensors = weights.pop(self.new_name.replace(DIGIT_FMT, str(num))).T
split_tensors_list.append(split_tensors)
ori_dtype = split_tensors_list[0].dtype
need_trans_dtype = ori_dtype == torch.bfloat16
if need_trans_dtype:
split_tensors_list = [x.to(torch.float32) for x in split_tensors_list]
merged_tensor = torch.stack(split_tensors_list, dim=0)
if need_trans_dtype:
merged_tensor = merged_tensor.to(ori_dtype)
weights[self.raw_name] = merged_tensor
TP_PATTERN_T = Dict[str, BaseSplit]