import torch
import safetensors
from safetensors.torch import load_file
from checkpoint.sora_model.sora_model_converter import SoraModelConverter
from checkpoint.sora_model.convert_utils.cfg import ConvertConfig, ParallelConfig
from checkpoint.sora_model.convert_utils.utils import check_method_support, flip_mapping
from checkpoint.sora_model.convert_utils.save_load_utils import save_as_mm, load_from_hf, load_from_mm, save_as_pt
class VACEConverter(SoraModelConverter):
    """Checkpoint converter for VACE models.

    Three key-naming layouts are handled:

    * ``hf``: the original ("civitai"-style) layout, e.g.
      ``blocks.0.self_attn.q.weight``.
    * ``hf_diffusers``: the diffusers layout, e.g.
      ``blocks.0.attn1.to_q.weight``.
    * ``mm``: the target layout, in which every key is prefixed with
      ``vace_dit.`` or ``wan_dit.`` depending on whether the source key
      contains ``"vace"``.

    The diffusers layout is the pivot: ``hf`` checkpoints are first renamed
    to diffusers names and then to ``mm`` names (and the reverse on export).

    NOTE(review): ``_replace_state_dict``, ``_mm_split`` and ``_mm_merge``
    are not defined in the visible part of this file; they are presumably
    inherited from ``SoraModelConverter`` -- confirm.
    """

    _supported_methods = ["hf_to_mm", "mm_to_hf", "hf_diffusers_to_mm", "mm_to_hf_diffusers"]
    # Feature switches read outside this class; presumably tensor / pipeline /
    # virtual-pipeline parallel support -- TODO confirm against the base class.
    _enable_tp = False
    _enable_pp = False
    _enable_vpp = False

    # Key prefixes that separate the two sub-models inside an mm checkpoint.
    _VACE_PREFIX = "vace_dit."
    _WAN_PREFIX = "wan_dit."

    def __init__(self) -> None:
        """Build the key-renaming tables used by the conversion methods."""
        super().__init__()

        # diffusers -> mm, whole-key renames (non-block weights).
        self.hf_to_mm_convert_mapping = {
            "condition_embedder.text_embedder.linear_1.bias": "text_embedding.linear_1.bias",
            "condition_embedder.text_embedder.linear_1.weight": "text_embedding.linear_1.weight",
            "condition_embedder.text_embedder.linear_2.bias": "text_embedding.linear_2.bias",
            "condition_embedder.text_embedder.linear_2.weight": "text_embedding.linear_2.weight",
            "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
            "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
            "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
            "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
            "condition_embedder.time_proj.bias": "time_projection.1.bias",
            "condition_embedder.time_proj.weight": "time_projection.1.weight",
            "condition_embedder.image_embedder.ff.net.0.proj.weight": "img_emb.proj.1.weight",
            "condition_embedder.image_embedder.ff.net.0.proj.bias": "img_emb.proj.1.bias",
            "condition_embedder.image_embedder.ff.net.2.weight": "img_emb.proj.3.weight",
            "condition_embedder.image_embedder.ff.net.2.bias": "img_emb.proj.3.bias",
            "condition_embedder.image_embedder.norm1.weight": "img_emb.proj.0.weight",
            "condition_embedder.image_embedder.norm1.bias": "img_emb.proj.0.bias",
            "condition_embedder.image_embedder.norm2.weight": "img_emb.proj.4.weight",
            "condition_embedder.image_embedder.norm2.bias": "img_emb.proj.4.bias",
            "condition_embedder.image_embedder.pos_embed": "img_emb.emb_pos",
            "scale_shift_table": "head.modulation",
            "proj_out.bias": "head.head.bias",
            "proj_out.weight": "head.head.weight",
        }

        # diffusers -> mm, substring renames for the base (non-VACE) blocks.
        self.hf_to_mm_str_replace_mapping_wan = {
            "attn1.norm_q": "self_attn.q_norm",
            "attn1.norm_k": "self_attn.k_norm",
            "attn2.norm_q": "cross_attn.q_norm",
            "attn2.norm_k": "cross_attn.k_norm",
            "attn1.to_q.": "self_attn.proj_q.",
            "attn1.to_k.": "self_attn.proj_k.",
            "attn1.to_v.": "self_attn.proj_v.",
            "attn1.to_out.0.": "self_attn.proj_out.",
            "attn2.to_q.": "cross_attn.proj_q.",
            "attn2.to_k.": "cross_attn.proj_k.",
            "attn2.to_v.": "cross_attn.proj_v.",
            "attn2.add_k_proj": "cross_attn.k_img",
            "attn2.add_v_proj": "cross_attn.v_img",
            "attn2.norm_added_k": "cross_attn.k_norm_img",
            "attn2.to_out.0.": "cross_attn.proj_out.",
            ".ffn.net.0.proj.": ".ffn.0.",
            ".ffn.net.2.": ".ffn.2.",
            "scale_shift_table": "modulation",
            ".norm2.": ".norm3.",
        }

        # diffusers -> mm, substring renames for the VACE blocks. Same targets
        # as the table above but nested under ``wan_dit_block``, plus the
        # proj_in / proj_out -> before_proj / after_proj renames.
        self.hf_to_mm_str_replace_mapping_vace = {
            ".proj_in.": ".before_proj.",
            ".proj_out.": ".after_proj.",
            "attn1.norm_q": "wan_dit_block.self_attn.q_norm",
            "attn1.norm_k": "wan_dit_block.self_attn.k_norm",
            "attn2.norm_q": "wan_dit_block.cross_attn.q_norm",
            "attn2.norm_k": "wan_dit_block.cross_attn.k_norm",
            "attn1.to_q.": "wan_dit_block.self_attn.proj_q.",
            "attn1.to_k.": "wan_dit_block.self_attn.proj_k.",
            "attn1.to_v.": "wan_dit_block.self_attn.proj_v.",
            "attn1.to_out.0.": "wan_dit_block.self_attn.proj_out.",
            "attn2.to_q.": "wan_dit_block.cross_attn.proj_q.",
            "attn2.to_k.": "wan_dit_block.cross_attn.proj_k.",
            "attn2.to_v.": "wan_dit_block.cross_attn.proj_v.",
            "attn2.add_k_proj": "wan_dit_block.cross_attn.k_img",
            "attn2.add_v_proj": "wan_dit_block.cross_attn.v_img",
            "attn2.norm_added_k": "wan_dit_block.cross_attn.k_norm_img",
            "attn2.to_out.0.": "wan_dit_block.cross_attn.proj_out.",
            ".ffn.net.0.proj.": ".wan_dit_block.ffn.0.",
            ".ffn.net.2.": ".wan_dit_block.ffn.2.",
            "scale_shift_table": "wan_dit_block.modulation",
            ".norm2.": ".wan_dit_block.norm3.",
        }

        # original (civitai) -> diffusers, whole-key renames.
        self.hf_civitai_to_diffusers_convert_mapping = {
            "text_embedding.0.bias": "condition_embedder.text_embedder.linear_1.bias",
            "text_embedding.0.weight": "condition_embedder.text_embedder.linear_1.weight",
            "text_embedding.2.bias": "condition_embedder.text_embedder.linear_2.bias",
            "text_embedding.2.weight": "condition_embedder.text_embedder.linear_2.weight",
            "time_embedding.0.bias": "condition_embedder.time_embedder.linear_1.bias",
            "time_embedding.0.weight": "condition_embedder.time_embedder.linear_1.weight",
            "time_embedding.2.bias": "condition_embedder.time_embedder.linear_2.bias",
            "time_embedding.2.weight": "condition_embedder.time_embedder.linear_2.weight",
            "time_projection.1.bias": "condition_embedder.time_proj.bias",
            "time_projection.1.weight": "condition_embedder.time_proj.weight",
            "img_emb.proj.1.bias": "condition_embedder.image_embedder.ff.net.0.proj.bias",
            "img_emb.proj.1.weight": "condition_embedder.image_embedder.ff.net.0.proj.weight",
            "img_emb.proj.3.bias": "condition_embedder.image_embedder.ff.net.2.bias",
            "img_emb.proj.3.weight": "condition_embedder.image_embedder.ff.net.2.weight",
            "img_emb.proj.0.bias": "condition_embedder.image_embedder.norm1.bias",
            "img_emb.proj.0.weight": "condition_embedder.image_embedder.norm1.weight",
            "img_emb.proj.4.bias": "condition_embedder.image_embedder.norm2.bias",
            "img_emb.proj.4.weight": "condition_embedder.image_embedder.norm2.weight",
            "img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed",
            "head.modulation": "scale_shift_table",
            "head.head.bias": "proj_out.bias",
            "head.head.weight": "proj_out.weight",
        }

        # original (civitai) -> diffusers, substring renames.
        self.hf_civitai_to_diffusers_replace_mapping = {
            ".cross_attn.k.": ".attn2.to_k.",
            ".cross_attn.norm_k.weight": ".attn2.norm_k.weight",
            ".cross_attn.norm_q.weight": ".attn2.norm_q.weight",
            ".cross_attn.o.": ".attn2.to_out.0.",
            ".cross_attn.q.": ".attn2.to_q.",
            ".cross_attn.v.": ".attn2.to_v.",
            ".ffn.0.": ".ffn.net.0.proj.",
            ".ffn.2.": ".ffn.net.2.",
            ".modulation": ".scale_shift_table",
            ".norm3.": ".norm2.",
            ".self_attn.k.": ".attn1.to_k.",
            ".self_attn.norm_k.weight": ".attn1.norm_k.weight",
            ".self_attn.norm_q.weight": ".attn1.norm_q.weight",
            ".self_attn.o.": ".attn1.to_out.0.",
            ".self_attn.q.": ".attn1.to_q.",
            ".self_attn.v.": ".attn1.to_v.",
            ".after_proj.": ".proj_out.",
            ".before_proj.": ".proj_in.",
            ".cross_attn.k_img.": ".attn2.add_k_proj.",
            ".cross_attn.v_img.": ".attn2.add_v_proj.",
            ".cross_attn.norm_k_img.weight": ".attn2.norm_added_k.weight",
        }

    @check_method_support
    def hf_diffusers_to_mm(self, cfg: ConvertConfig):
        """Convert a diffusers-layout checkpoint at ``cfg.source_path`` to mm.

        The result is written to ``cfg.target_path``; nothing is returned.
        """
        state_dict = load_from_hf(cfg.source_path)
        self._hf_to_mm_state(cfg, state_dict)

    @check_method_support
    def mm_to_hf_diffusers(self, cfg: ConvertConfig):
        """Convert an mm checkpoint at ``cfg.source_path`` to diffusers layout.

        The result is written to ``cfg.target_path`` via ``save_as_pt``.
        """
        state_dict = self._load_mm_as_diffusers(cfg)
        save_as_pt(state_dict, cfg.target_path)

    @check_method_support
    def hf_to_mm(self, cfg: ConvertConfig):
        """Convert an original-layout checkpoint at ``cfg.source_path`` to mm.

        Keys are first renamed to the diffusers layout, then handled exactly
        like ``hf_diffusers_to_mm``.
        """
        state_dict = load_from_hf(cfg.source_path)
        state_dict = self._replace_state_dict(
            state_dict,
            self.hf_civitai_to_diffusers_convert_mapping,
            self.hf_civitai_to_diffusers_replace_mapping,
        )
        self._hf_to_mm_state(cfg, state_dict)

    @check_method_support
    def mm_to_hf(self, cfg: ConvertConfig):
        """Convert an mm checkpoint at ``cfg.source_path`` to original layout.

        Keys go mm -> diffusers -> original; the result is written to
        ``cfg.target_path`` via ``save_as_pt``.
        """
        state_dict = self._load_mm_as_diffusers(cfg)
        state_dict = self._replace_state_dict(
            state_dict,
            flip_mapping(self.hf_civitai_to_diffusers_convert_mapping),
            flip_mapping(self.hf_civitai_to_diffusers_replace_mapping),
        )
        save_as_pt(state_dict, cfg.target_path)

    def _load_mm_as_diffusers(self, cfg: ConvertConfig) -> dict:
        """Load the mm checkpoint, merge it, and rename keys to diffusers."""
        state_dict = load_from_mm(cfg.source_path)
        state_dict = self._mm_merge(state_dict)
        return self._mm_to_hf_state(cfg, state_dict)

    @staticmethod
    def _select_by_prefix(state_dict: dict, prefix: str) -> dict:
        """Return the entries whose key starts with *prefix*, prefix removed."""
        cut = len(prefix)
        return {
            key[cut:]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }

    def _hf_to_mm_state(self, cfg: ConvertConfig, state_dict: dict):
        """Rename a diffusers-layout state dict to mm layout and save it.

        Keys containing ``"vace"`` are renamed with the VACE table and
        prefixed with ``vace_dit.``; every other key is renamed with the base
        table and prefixed with ``wan_dit.``. The combined dict is passed
        through ``self._mm_split`` with ``cfg.target_parallel_config`` and
        written to ``cfg.target_path``.
        """
        vace_state_dict = {key: value for key, value in state_dict.items() if "vace" in key}
        wan_state_dict = {key: value for key, value in state_dict.items() if "vace" not in key}

        vace_state_dict = self._replace_state_dict(
            vace_state_dict,
            self.hf_to_mm_convert_mapping,
            self.hf_to_mm_str_replace_mapping_vace,
        )
        wan_state_dict = self._replace_state_dict(
            wan_state_dict,
            self.hf_to_mm_convert_mapping,
            self.hf_to_mm_str_replace_mapping_wan,
        )

        new_state_dict = {
            **{f"{self._VACE_PREFIX}{key}": value for key, value in vace_state_dict.items()},
            **{f"{self._WAN_PREFIX}{key}": value for key, value in wan_state_dict.items()},
        }
        new_state_dict = self._mm_split(new_state_dict, cfg.target_parallel_config)
        save_as_mm(cfg.target_path, new_state_dict)

    def _mm_to_hf_state(self, cfg: ConvertConfig, state_dict: dict) -> dict:
        """Rename a merged mm state dict back to the diffusers layout.

        Args:
            cfg: Unused here; kept so the signature mirrors
                ``_hf_to_mm_state``.
            state_dict: mm-layout mapping whose keys start with ``vace_dit.``
                or ``wan_dit.``. Keys with neither prefix are dropped.

        Returns:
            A single dict holding the renamed base-model and VACE entries.
        """
        # Match on the prefix itself, not a substring: a fixed-offset slice is
        # only valid when the marker sits at position 0 of the key.
        vace_state_dict = self._select_by_prefix(state_dict, self._VACE_PREFIX)
        wan_state_dict = self._select_by_prefix(state_dict, self._WAN_PREFIX)

        mm_to_hf_convert_mapping = flip_mapping(self.hf_to_mm_convert_mapping)
        vace_state_dict = self._replace_state_dict(
            vace_state_dict,
            mm_to_hf_convert_mapping,
            flip_mapping(self.hf_to_mm_str_replace_mapping_vace),
        )
        wan_state_dict = self._replace_state_dict(
            wan_state_dict,
            mm_to_hf_convert_mapping,
            flip_mapping(self.hf_to_mm_str_replace_mapping_wan),
        )
        return {**wan_state_dict, **vace_state_dict}