import torch
from checkpoint.sora_model.convert_utils.tp_patterns import TPPattern
from checkpoint.sora_model.sora_model_converter import SoraModelConverter
from checkpoint.sora_model.convert_utils.cfg import ConvertConfig
from checkpoint.sora_model.convert_utils.utils import check_method_support
from checkpoint.sora_model.convert_utils.save_load_utils import load_from_hf, save_as_mm
class KVfusedColumnTP(TPPattern):
    """Tensor-parallel split/merge pattern for a fused key/value weight.

    The full weight is treated as two equal halves stacked along dim 0
    (named K and V here). Each tensor-parallel shard holds its slice of the
    first half followed by its slice of the second half, again along dim 0.
    """

    @staticmethod
    def split(weight, tp_size):
        """Split a fused KV weight into ``tp_size`` per-rank shards.

        Args:
            weight: Tensor whose dim 0 is the concatenation of the K half
                and the V half.
            tp_size: Number of tensor-parallel ranks.

        Returns:
            A list of ``tp_size`` tensors; shard ``i`` is
            ``cat([K_i, V_i], dim=0)``.

        Raises:
            ValueError: If ``tp_size`` is not positive or dim 0 of ``weight``
                cannot be divided evenly into ``2 * tp_size`` pieces.
        """
        rows = weight.shape[0]
        # torch.chunk silently returns fewer (or unevenly sized) chunks when
        # the dimension is not divisible, which would either fail later with
        # an IndexError or produce mismatched shards. Fail early instead.
        if tp_size < 1 or rows % (2 * tp_size) != 0:
            raise ValueError(
                f"Cannot split fused KV weight with dim-0 size {rows} "
                f"evenly across tp_size={tp_size}."
            )

        k_full, v_full = torch.chunk(weight, 2, dim=0)
        k_shards = torch.chunk(k_full, tp_size, dim=0)
        v_shards = torch.chunk(v_full, tp_size, dim=0)
        return [
            torch.cat([k_shard, v_shard], dim=0)
            for k_shard, v_shard in zip(k_shards, v_shards)
        ]

    @staticmethod
    def merge(weights):
        """Merge per-rank shards back into a single fused KV weight.

        Args:
            weights: Sequence of shards, each laid out as
                ``cat([K_i, V_i], dim=0)``.

        Returns:
            One tensor with all K slices first, followed by all V slices,
            concatenated along dim 0.
        """
        k_parts = []
        v_parts = []
        for shard in weights:
            halves = torch.chunk(shard, 2, dim=0)
            k_parts.append(halves[0])
            v_parts.append(halves[1])
        # Concatenating [K_0..K_n, V_0..V_n] in one go is equivalent to
        # concatenating the merged K block with the merged V block.
        return torch.cat(k_parts + v_parts, dim=0)
class LayerIndexConverter:
    """Read and rewrite the layer index embedded in a parameter name.

    Only names starting with ``transformer_blocks`` are considered to carry
    a layer index; the index is the second dot-separated component.
    """

    _BLOCK_PREFIX = "transformer_blocks"

    @staticmethod
    def get_layer_index(name):
        """Return the integer layer index of ``name``, or ``None``.

        ``None`` is returned for names that do not start with the
        transformer-block prefix.
        """
        if not name.startswith(LayerIndexConverter._BLOCK_PREFIX):
            return None
        return int(name.split(".")[1])

    @staticmethod
    def convert_layer_index(name, new_layer_index):
        """Return ``name`` with its layer index replaced by ``new_layer_index``.

        Names that do not start with the transformer-block prefix are
        returned unchanged.
        """
        if not name.startswith(LayerIndexConverter._BLOCK_PREFIX):
            return name
        segments = name.split(".")
        segments[1] = str(new_layer_index)
        return ".".join(segments)
class StepVideoConverter(SoraModelConverter):
    """Converter for StepVideo.

    Holds the HF -> MM parameter-name mapping and the tensor-parallel split
    tables for the StepVideo transformer, and implements the ``hf_to_mm``
    conversion entry point.

    The class-level dictionaries below contain only the entries that are not
    tied to a transformer block. Per-block entries are added in ``__init__``
    on per-instance copies, so the class-level tables are never mutated.
    """

    _supported_methods = ["hf_to_mm", "resplit"]
    # Parallelism feature flags (presumably consumed by the base class --
    # TODO confirm).
    _enable_tp = True
    _enable_pp = True
    _enable_vpp = False

    # HF name -> MM name for weights outside the transformer blocks.
    hf_to_mm_convert_mapping = {
        "pos_embed.proj.bias": "pos_embed.proj.bias",
        "pos_embed.proj.weight": "pos_embed.proj.weight",
        "scale_shift_table": "scale_shift_table",
        "adaln_single.emb.timestep_embedder.linear_1.bias": "adaln_single.emb.timestep_embedder.linear_1.bias",
        "adaln_single.emb.timestep_embedder.linear_1.weight": "adaln_single.emb.timestep_embedder.linear_1.weight",
        "adaln_single.emb.timestep_embedder.linear_2.bias": "adaln_single.emb.timestep_embedder.linear_2.bias",
        "adaln_single.emb.timestep_embedder.linear_2.weight": "adaln_single.emb.timestep_embedder.linear_2.weight",
        "caption_projection.linear_1.bias": "caption_projection.linear_1.bias",
        "caption_projection.linear_1.weight": "caption_projection.linear_1.weight",
        "caption_projection.linear_2.bias": "caption_projection.linear_2.bias",
        "caption_projection.linear_2.weight": "caption_projection.linear_2.weight",
        "clip_projection.bias": "clip_projection.bias",
        "clip_projection.weight": "clip_projection.weight",
        "proj_out.bias": "proj_out.bias",
        "proj_out.weight": "proj_out.weight"
    }

    # Weights that live before the transformer blocks (presumably placed on
    # the first pipeline stage -- TODO confirm against the base class).
    pre_process_weight_names = [
        "pos_embed.proj.bias", "pos_embed.proj.weight",
        "adaln_single.emb.timestep_embedder.linear_1.bias", "adaln_single.emb.timestep_embedder.linear_1.weight",
        "adaln_single.emb.timestep_embedder.linear_2.bias", "adaln_single.emb.timestep_embedder.linear_2.weight",
        "adaln_single.linear.weight", "adaln_single.linear.bias",
        "caption_projection.linear_1.bias", "caption_projection.linear_1.weight",
        "caption_projection.linear_2.bias", "caption_projection.linear_2.weight",
        "clip_projection.bias", "clip_projection.weight"
    ]

    # Weights that live after the transformer blocks (presumably placed on
    # the last pipeline stage -- TODO confirm against the base class).
    post_preprocess_weight_names = [
        "scale_shift_table",
        "proj_out.bias", "proj_out.weight"
    ]

    # MM weight names grouped by the tensor-parallel split pattern applied
    # to them.
    tp_split_mapping = {
        "column_parallel_tp": [
            "adaln_single.linear.weight",
            "adaln_single.linear.bias",
        ],
        "row_parallel_tp": [],
        "qkv_fused_column_tp": []
    }

    # Extra split pattern for fused KV weights, keyed by the pattern object.
    kv_fused_column_tp = KVfusedColumnTP()
    spec_tp_split_mapping = {kv_fused_column_tp: []}

    layer_index_converter = LayerIndexConverter()

    def __init__(self) -> None:
        """Build the per-layer name mapping and TP split tables."""
        super().__init__()
        num_layers = 48  # number of transformer blocks registered below
        self.num_heads = 48

        # Take per-instance copies before extending. Mutating the class-level
        # containers through ``self`` would make every new instance append
        # the same per-layer names again to the shared lists.
        self.hf_to_mm_convert_mapping = dict(self.hf_to_mm_convert_mapping)
        self.tp_split_mapping = {
            pattern: list(names)
            for pattern, names in self.tp_split_mapping.items()
        }
        self.spec_tp_split_mapping = {
            pattern: list(names)
            for pattern, names in self.spec_tp_split_mapping.items()
        }

        for index in range(num_layers):
            block = f"transformer_blocks.{index}"
            self.hf_to_mm_convert_mapping.update({
                f"{block}.attn1.k_norm.weight": f"{block}.attn1.k_norm.weight",
                f"{block}.attn1.q_norm.weight": f"{block}.attn1.q_norm.weight",
                f"{block}.attn1.wo.weight": f"{block}.attn1.proj_out.weight",
                f"{block}.attn1.wqkv.weight": f"{block}.attn1.proj_qkv.weight",
                f"{block}.attn2.k_norm.weight": f"{block}.attn2.k_norm.weight",
                f"{block}.attn2.q_norm.weight": f"{block}.attn2.q_norm.weight",
                f"{block}.attn2.wkv.weight": f"{block}.attn2.proj_kv.weight",
                f"{block}.attn2.wo.weight": f"{block}.attn2.proj_out.weight",
                f"{block}.attn2.wq.weight": f"{block}.attn2.proj_q.weight",
                f"{block}.ff.net.0.proj.weight": f"{block}.ff.net.0.proj.weight",
                f"{block}.ff.net.2.weight": f"{block}.ff.net.2.weight",
                f"{block}.norm1.bias": f"{block}.norm1.bias",
                f"{block}.norm1.weight": f"{block}.norm1.weight",
                f"{block}.norm2.bias": f"{block}.norm2.bias",
                f"{block}.norm2.weight": f"{block}.norm2.weight",
                f"{block}.scale_shift_table": f"{block}.scale_shift_table"
            })

            self.tp_split_mapping["column_parallel_tp"].extend([
                f"{block}.ff.net.0.proj.weight",
                f"{block}.attn2.proj_q.weight",
            ])
            self.tp_split_mapping["row_parallel_tp"].extend([
                f"{block}.attn1.proj_out.weight",
                f"{block}.attn2.proj_out.weight",
                f"{block}.ff.net.2.weight",
            ])
            self.tp_split_mapping["qkv_fused_column_tp"].append(
                f"{block}.attn1.proj_qkv.weight"
            )
            self.spec_tp_split_mapping[self.kv_fused_column_tp].append(
                f"{block}.attn2.proj_kv.weight"
            )

    @check_method_support
    def hf_to_mm(self, cfg: ConvertConfig):
        """Convert an HF checkpoint into split MM checkpoints.

        Loads the state dict from ``cfg.source_path``, renames its keys,
        re-orders the fused attention weights, splits the result according
        to ``cfg.target_parallel_config`` and saves it to ``cfg.target_path``.
        """
        state_dict = load_from_hf(cfg.source_path)
        state_dict = self._replace_state_dict(
            state_dict,
            self.hf_to_mm_convert_mapping,
            self.hf_to_mm_str_replace_mapping
        )
        state_dict = self._xfuse_to_mm(state_dict)
        state_dicts = self._mm_split(state_dict, cfg.target_parallel_config)
        save_as_mm(cfg.target_path, state_dicts)

    def _xfuse_to_mm(self, state_dict):
        """Re-order fused QKV / KV weights from per-head to per-part layout.

        Entries listed under ``qkv_fused_column_tp`` are permuted with three
        parts and entries listed for ``kv_fused_column_tp`` with two parts.
        ``state_dict`` is updated in place and also returned.
        """
        def head_weight_permute(weight, fuse_num):
            # Split dim 0 into one chunk per head, split every head chunk
            # into ``fuse_num`` parts, then gather part ``i`` of all heads
            # together so the result is [part 0 of all heads, part 1 of all
            # heads, ...]. Presumably the source layout interleaves q/k/v
            # per head -- TODO confirm against the HF checkpoint.
            per_head = torch.chunk(weight, self.num_heads)
            per_head_parts = [
                torch.chunk(head_weight, fuse_num, dim=0)
                for head_weight in per_head
            ]
            grouped = [
                torch.cat([parts[i] for parts in per_head_parts], dim=0)
                for i in range(fuse_num)
            ]
            return torch.cat(grouped, dim=0).clone()

        # Build the lookup sets once instead of scanning the name lists for
        # every key in the state dict.
        qkv_names = set(self.tp_split_mapping["qkv_fused_column_tp"])
        kv_names = set(self.spec_tp_split_mapping[self.kv_fused_column_tp])

        # Only values of existing keys are replaced, so iterating the dict
        # while assigning is safe.
        for key in state_dict:
            if key in qkv_names:
                state_dict[key] = head_weight_permute(state_dict[key], fuse_num=3)
            elif key in kv_names:
                state_dict[key] = head_weight_permute(state_dict[key], fuse_num=2)
        return state_dict