from typing_extensions import Literal
import torch
from safetensors.torch import load_file
from checkpoint.sora_model.sora_model_converter import SoraModelConverter
from checkpoint.sora_model.convert_utils.cfg import ConvertConfig
from checkpoint.sora_model.convert_utils.tp_patterns import TPPattern
from checkpoint.sora_model.convert_utils.save_load_utils import load_pt, save_as_mm
from checkpoint.sora_model.convert_utils.utils import check_method_support
class XMLPFusedRowParallelTP(TPPattern):
def __init__(self, hidden_size=3096) -> None:
self.hidden_size = hidden_size
def split(self, weight, tp_size):
hidden_size = self.hidden_size
wx = weight[:, :hidden_size]
wmlp = weight[:, hidden_size:]
wxs = torch.chunk(wx, tp_size, dim=1)
wmlps = torch.chunk(wmlp, tp_size, dim=1)
weights = [torch.cat([wxs[i], wmlps[i]], dim=1) for i in range(tp_size)]
return weights
def merge(self, weights):
hidden_size = self.hidden_size // len(weights)
wxs = [weight[:, :hidden_size] for weight in weights]
wmlps = [weight[:, hidden_size:] for weight in weights]
weight = torch.cat([
torch.cat(wxs, dim=1),
torch.cat(wmlps, dim=1)
], dim=1)
return weight
class HunyuanVideoConverter(SoraModelConverter):
"""Converter for HunyuanVideo"""
_supported_methods = ["resplit", "layerzero_to_mm", "merge_lora_to_base"]
_enable_tp = True
convert_mapping = {
"txt_in.t_embedder.mlp.0.weight": "txt_in.t_embedder.time_embed.0.weight",
"txt_in.t_embedder.mlp.0.bias": "txt_in.t_embedder.time_embed.0.bias",
"txt_in.t_embedder.mlp.2.weight": "txt_in.t_embedder.time_embed.2.weight",
"txt_in.t_embedder.mlp.2.bias": "txt_in.t_embedder.time_embed.2.bias",
"time_in.mlp.0.weight": "time_in.time_embed.0.weight",
"time_in.mlp.0.bias": "time_in.time_embed.0.bias",
"time_in.mlp.2.weight": "time_in.time_embed.2.weight",
"time_in.mlp.2.bias": "time_in.time_embed.2.bias",
"vector_in.in_layer.weight": "vector_in.fc1.weight",
"vector_in.in_layer.bias": "vector_in.fc1.bias",
"vector_in.out_layer.weight": "vector_in.fc2.weight",
"vector_in.out_layer.bias": "vector_in.fc2.bias",
"guidance_in.mlp.0.weight": "guidance_in.time_embed.0.weight",
"guidance_in.mlp.0.bias": "guidance_in.time_embed.0.bias",
"guidance_in.mlp.2.weight": "guidance_in.time_embed.2.weight",
"guidance_in.mlp.2.bias": "guidance_in.time_embed.2.bias",
"final_layer.linear.weight": "proj_out.weight",
"final_layer.linear.bias": "proj_out.bias",
"final_layer.adaLN_modulation.1.weight": "adaLN_modulation.1.weight",
"final_layer.adaLN_modulation.1.bias": "adaLN_modulation.1.bias"
}
tp_split_mapping = {
"column_parallel_tp": [
"vector_in.fc1.weight",
"vector_in.fc1.bias",
"proj_out.weight",
"proj_out.bias",
"adaLN_modulation.1.weight",
"adaLN_modulation.1.bias",
],
"row_parallel_tp": [
"vector_in.fc2.weight"
],
"qkv_fused_column_tp": []
}
spec_tp_split_mapping = {}
lora_target_modules = [
"linear", "fc1", "fc2", "img_attn_qkv", "img_attn_proj", "txt_attn_qkv", "txt_attn_proj", "linear1_qkv", "linear1_mlp", "linear2", "proj_out"
]
def __init__(self, version: Literal["t2v", "i2v", "t2v-lora", "i2v-lora"] = "t2v") -> None:
self.version = version
self.double_stream_layers = 20
self.single_stream_layers = 40
self.num_heads = 24
self.head_dim = 128
x_mlp_fused_row_parallel_tp_pattern = XMLPFusedRowParallelTP(hidden_size=self.num_heads * self.head_dim)
self.spec_tp_split_mapping = {x_mlp_fused_row_parallel_tp_pattern: []}
for index in range(self.double_stream_layers):
self.tp_split_mapping["column_parallel_tp"] += [
f"double_blocks.{index}.img_mod.linear.weight",
f"double_blocks.{index}.img_mod.linear.bias",
f"double_blocks.{index}.img_mlp.fc1.weight",
f"double_blocks.{index}.img_mlp.fc1.bias",
f"double_blocks.{index}.txt_mod.linear.weight",
f"double_blocks.{index}.txt_mod.linear.bias",
f"double_blocks.{index}.txt_mlp.fc1.weight",
f"double_blocks.{index}.txt_mlp.fc1.bias",
]
self.tp_split_mapping["row_parallel_tp"] += [
f"double_blocks.{index}.img_attn_proj.weight",
f"double_blocks.{index}.img_mlp.fc2.weight",
f"double_blocks.{index}.txt_attn_proj.weight",
f"double_blocks.{index}.txt_mlp.fc2.weight",
]
self.tp_split_mapping["qkv_fused_column_tp"] += [
f"double_blocks.{index}.img_attn_qkv.weight",
f"double_blocks.{index}.img_attn_qkv.bias",
f"double_blocks.{index}.txt_attn_qkv.weight",
f"double_blocks.{index}.txt_attn_qkv.bias",
]
for index in range(self.single_stream_layers):
self.tp_split_mapping["column_parallel_tp"] += [
f"single_blocks.{index}.modulation.linear.weight",
f"single_blocks.{index}.modulation.linear.bias",
f"single_blocks.{index}.linear1_mlp.weight",
f"single_blocks.{index}.linear1_mlp.bias"
]
self.tp_split_mapping["qkv_fused_column_tp"] += [
f"single_blocks.{index}.linear1_qkv.weight",
f"single_blocks.{index}.linear1_qkv.bias"
]
self.spec_tp_split_mapping[x_mlp_fused_row_parallel_tp_pattern] += [
f"single_blocks.{index}.linear2.weight"
]
if self.version == "t2v":
self._supported_methods = ["resplit", "layerzero_to_mm", "merge_lora_to_base", "source_to_mm"]
self._enable_tp = True
self._enable_pp = False
self._enable_vpp = False
elif self.version == "i2v":
self._supported_methods = ["resplit", "layerzero_to_mm", "merge_lora_to_base", "source_to_mm"]
self._enable_tp = False
self._enable_pp = False
self._enable_vpp = False
elif self.version == "t2v-lora":
self._supported_methods = []
self._enable_tp = False
self._enable_pp = False
self._enable_vpp = False
elif self.version == "i2v-lora":
self._supported_methods = ["source_to_mm"]
self._enable_tp = False
self._enable_pp = False
self._enable_vpp = False
i2v_source_lora_prefix = "Hunyuan_video_I2V"
self.str_replace_mapping = {
f"{i2v_source_lora_prefix}_lora_": "",
"single_blocks_": "single_blocks.",
"double_blocks_": "double_blocks.",
"_img_attn_proj": ".img_attn_proj",
"_img_attn_qkv": ".img_attn_qkv",
"_img_mlp_fc": ".img_mlp.fc",
"_txt_mlp_fc": ".txt_mlp.fc",
"_img_mod": ".img_mod",
"_txt": ".txt",
"_modulation": ".modulation",
"_linear": ".linear",
"lora_down": "lora_A.default",
"lora_up": "lora_B.default",
"_individual_token_refiner_blocks_": ".individual_token_refiner.blocks.",
"_mlp_fc": ".mlp.fc",
"vector_in_in_layer": "vector_in.fc1",
"vector_in_out_layer": "vector_in.fc2",
"final_layer.linear": "proj_out",
}
def t2v_text_encoder(self, cfg: ConvertConfig):
"""convert text encooder for t2v task"""
if self.version != "t2v":
raise RuntimeError(f"task: {self.version} do not need to convert text encoder")
from transformers import (
AutoProcessor,
LlavaForConditionalGeneration
)
processor = AutoProcessor.from_pretrained(cfg.source_path)
model = LlavaForConditionalGeneration.from_pretrained(
cfg.source_path,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
local_files_only=True
)
model.language_model.save_pretrained(cfg.target_path)
processor.tokenizer.save_pretrained(cfg.target_path)
@check_method_support
def source_to_mm(self, cfg: ConvertConfig):
"""convert ckpt from source code to megatron format"""
if cfg.source_path.endswith("safetensors"):
state_dict = load_file(cfg.source_path)
else:
state_dict = load_pt(cfg.source_path, module_name='module')
state_dict = self._replace_state_dict(state_dict, self.convert_mapping, self.str_replace_mapping)
state_dicts = self._mm_split(state_dict, cfg.target_parallel_config)
save_as_mm(cfg.target_path, state_dicts)
def _replace_state_dict(self, state_dict: dict, convert_mapping: dict = None, str_replace_mapping: dict = None):
if self.version.endswith("lora"):
names = list(state_dict.keys())
for name in names:
if ".alpha" in name:
state_dict.pop(name)
state_dict = super()._replace_state_dict(state_dict, convert_mapping, str_replace_mapping)
state_dict = self._split_qkv_mlp_fused_column_linear(state_dict)
return state_dict
def _split_qkv_mlp_fused_column_linear(self, state_dict: dict):
"""split qkv_mlp fused linear in single stream blocks int qkv part and mlp part"""
hidden_size = self.num_heads * self.head_dim
if self.version == "i2v-lora":
names = list(state_dict.keys())
for name in names:
if "linear1" in name:
if ".lora_A" in name:
lora_a = state_dict.pop(name)
state_dict[name.replace("linear1", "linear1_qkv")] = lora_a
state_dict[name.replace("linear1", "linear1_mlp")] = lora_a
else:
lora_b = state_dict.pop(name)
w_qkv = lora_b[:hidden_size * 3]
w_mlp = lora_b[hidden_size * 3:]
state_dict[name.replace("linear1", "linear1_qkv")] = w_qkv
state_dict[name.replace("linear1", "linear1_mlp")] = w_mlp
elif "single_blocks.0.linear1.weight" in state_dict.keys():
for index in range(self.single_stream_layers):
weight1 = state_dict.pop(f"single_blocks.{index}.linear1.weight")
bias1 = state_dict.pop(f"single_blocks.{index}.linear1.bias")
state_dict[f"single_blocks.{index}.linear1_qkv.weight"] = weight1[:hidden_size * 3]
state_dict[f"single_blocks.{index}.linear1_mlp.weight"] = weight1[hidden_size * 3:]
state_dict[f"single_blocks.{index}.linear1_qkv.bias"] = bias1[:hidden_size * 3]
state_dict[f"single_blocks.{index}.linear1_mlp.bias"] = bias1[hidden_size * 3:]
return state_dict