from typing_extensions import Literal
from checkpoint.sora_model.sora_model_converter import SoraModelConverter
from checkpoint.sora_model.convert_utils.save_load_utils import (
load_pt,
save_as_pt,
save_as_mm
)
from checkpoint.sora_model.convert_utils.cfg import ConvertConfig
from checkpoint.sora_model.convert_utils.utils import check_method_support
class LayerIndexConverter:
    """Read and rewrite the layer index embedded in sparse-block weight names.

    A layer-indexed name has the dotted form
    ``videodit_sparse_blocks.<index>.<rest>``: the first component is exactly
    ``videodit_sparse_blocks`` and the second component is a decimal integer.
    Any other name is treated as not belonging to an indexed layer.
    """

    # First dotted component that marks a layer-indexed weight name.
    _BLOCK_PREFIX = "videodit_sparse_blocks"

    @staticmethod
    def _split_indexed_name(name):
        """Return the dotted components of *name*, or ``None`` if it is not layer-indexed."""
        parts = name.split(".")
        # Compare the whole first component (not a bare ``startswith``) so that
        # names which merely share the prefix are not mistaken for block weights,
        # and require a decimal index so ``int()`` below cannot fail.
        if (
            len(parts) > 1
            and parts[0] == LayerIndexConverter._BLOCK_PREFIX
            and parts[1].isdecimal()
        ):
            return parts
        return None

    @staticmethod
    def get_layer_index(name):
        """Return the integer layer index encoded in *name*.

        Args:
            name: Dotted weight name.

        Returns:
            The layer index as an ``int``, or ``None`` when *name* is not a
            layer-indexed ``videodit_sparse_blocks`` weight.
        """
        parts = LayerIndexConverter._split_indexed_name(name)
        if parts is None:
            return None
        return int(parts[1])

    @staticmethod
    def convert_layer_index(name, new_layer_index):
        """Return *name* with its layer index replaced by *new_layer_index*.

        Args:
            name: Dotted weight name.
            new_layer_index: Value written (via ``str()``) in place of the
                current index.

        Returns:
            The rewritten name, or *name* unchanged when it is not a
            layer-indexed ``videodit_sparse_blocks`` weight.
        """
        parts = LayerIndexConverter._split_indexed_name(name)
        if parts is None:
            return name
        parts[1] = str(new_layer_index)
        return ".".join(parts)
class OpenSoraPlanConverter(SoraModelConverter):
    """Converter for OpenSoraPlan

    The constructor configures the instance for one model version ("v1.2",
    "v1.3" or "v1.5"): the list of supported conversion methods, the
    substring rename mapping applied to weight names, and the weight names
    registered under the ``"column_parallel_tp"`` / ``"row_parallel_tp"``
    entries of ``self.tp_split_mapping``. How those attributes are consumed
    is defined by ``SoraModelConverter`` (not visible in this file).
    """

    def __init__(self, version: Literal["v1.2", "v1.3", "v1.5"] = "v1.5") -> None:
        """Configure the converter for *version*.

        Args:
            version: OpenSoraPlan model version to convert.

        Raises:
            NotImplementedError: If *version* is not one of the supported
                versions.
        """
        super().__init__()
        self.version = version
        if self.version == "v1.2":
            self._configure_v1_2()
        elif self.version == "v1.3":
            self._configure_v1_3()
        elif self.version == "v1.5":
            self._configure_v1_5()
        else:
            raise NotImplementedError(f"version: {version} is not supported for OpenSoraPlanConverter")

    @staticmethod
    def _hf_str_replace_mapping(block_name: str) -> dict:
        """Build the v1.2/v1.3 rename mapping; only the block container name differs."""
        return {
            "transformer_blocks": block_name,
            "attn1": "self_atten",
            "attn2": "cross_atten",
            "to_q": "proj_q",
            "to_k": "proj_k",
            "to_v": "proj_v",
            "to_out.0": "proj_out",
            "to_out.1": "dropout",
        }

    def _add_dit_block_tp_names(self, block_name: str, num_layers: int) -> None:
        """Register the per-layer tensor-parallel weight names used by v1.2 and v1.3.

        Args:
            block_name: Name of the block container that prefixes every weight.
            num_layers: Number of consecutive layer indices (starting at 0).
        """
        for i in range(num_layers):
            layer = f"{block_name}.{i}"
            self.tp_split_mapping["column_parallel_tp"] += [
                f"{layer}.self_atten.proj_q.weight",
                f"{layer}.self_atten.proj_q.bias",
                f"{layer}.cross_atten.proj_q.weight",
                f"{layer}.cross_atten.proj_q.bias",
                f"{layer}.self_atten.proj_k.weight",
                f"{layer}.self_atten.proj_k.bias",
                f"{layer}.cross_atten.proj_k.weight",
                f"{layer}.cross_atten.proj_k.bias",
                f"{layer}.self_atten.proj_v.weight",
                f"{layer}.self_atten.proj_v.bias",
                f"{layer}.cross_atten.proj_v.weight",
                f"{layer}.cross_atten.proj_v.bias",
                f"{layer}.ff.net.0.proj.weight",
                f"{layer}.ff.net.0.proj.bias",
            ]
            # Only weights are registered here; no biases are listed for the
            # row-parallel entries.
            self.tp_split_mapping["row_parallel_tp"] += [
                f"{layer}.self_atten.proj_out.weight",
                f"{layer}.cross_atten.proj_out.weight",
                f"{layer}.ff.net.2.weight",
            ]

    def _configure_v1_2(self) -> None:
        """Set the v1.2 configuration: hf_to_mm/resplit with tensor parallelism."""
        self._supported_methods = ["hf_to_mm", "resplit"]
        self.hf_to_mm_str_replace_mapping = self._hf_str_replace_mapping("videodit_blocks")
        self._enable_tp = True
        self._add_dit_block_tp_names("videodit_blocks", num_layers=32)

    def _configure_v1_3(self) -> None:
        """Set the v1.3 configuration: as v1.2 on sparse blocks, plus PP/VPP settings."""
        self._supported_methods = ["hf_to_mm", "resplit"]
        self.hf_to_mm_str_replace_mapping = self._hf_str_replace_mapping("videodit_sparse_blocks")
        self._enable_tp = True
        self._add_dit_block_tp_names("videodit_sparse_blocks", num_layers=32)

        self._enable_pp = True
        self._enable_vpp = True
        # Weights that are not inside an indexed block: the first list is
        # named "pre process", the second "post process".
        # NOTE(review): presumably these are placed on the first / last
        # pipeline stage by the base class - confirm against SoraModelConverter.
        self.pre_process_weight_names = [
            "adaln_single.emb.timestep_embedder.linear_1.bias",
            "adaln_single.emb.timestep_embedder.linear_1.weight",
            "adaln_single.emb.timestep_embedder.linear_2.bias",
            "adaln_single.emb.timestep_embedder.linear_2.weight",
            "adaln_single.linear.bias",
            "adaln_single.linear.weight",
            "caption_projection.linear_1.bias",
            "caption_projection.linear_1.weight",
            "caption_projection.linear_2.bias",
            "caption_projection.linear_2.weight",
            "pos_embed.proj.bias",
            "pos_embed.proj.weight",
        ]
        self.post_preprocess_weight_names = ['scale_shift_table', 'proj_out.weight', 'proj_out.bias']
        self.layer_index_converter = LayerIndexConverter()

    def _configure_v1_5(self) -> None:
        """Set the v1.5 configuration: source_to_mm/resplit with tensor parallelism."""
        self._supported_methods = ["source_to_mm", "resplit"]
        self.str_replace_mapping = {
            "attn1.norm_q": "attn1.norm_proj_q",
            "attn1.norm_k": "attn1.norm_proj_k",
            "attn1.to_q": "attn1.proj_q",
            "attn1.to_k": "attn1.proj_k",
            "attn1.to_v": "attn1.proj_v",
            "attn1.add_q_proj": "attn1.added_proj_q",
            "attn1.add_k_proj": "attn1.added_proj_k",
            "attn1.add_v_proj": "attn1.added_proj_v",
            "attn1.to_out.0": "attn1.proj_out",
            "attn1.to_add_out": "attn1.added_proj_out",
            "attn1.norm_added_q": "attn1.norm_added_proj_q",
            "attn1.norm_added_k": "attn1.norm_added_proj_k",
        }
        self._enable_tp = True

        # Entry i is the number of blocks addressed as transformer_blocks.{i}.{j}.
        blocks_per_group = [2, 4, 6, 8, 6, 4, 2]
        self.tp_split_mapping["column_parallel_tp"] += [
            "norm_out.linear.weight",
            "norm_out.linear.bias",
        ]
        for i, nums in enumerate(blocks_per_group):
            for j in range(nums):
                block = f"transformer_blocks.{i}.{j}"
                # The two feed-forward input projections are registered
                # without a bias entry, unlike the attention projections.
                self.tp_split_mapping["column_parallel_tp"] += [
                    f"{block}.attn1.proj_q.weight",
                    f"{block}.attn1.proj_q.bias",
                    f"{block}.attn1.proj_k.weight",
                    f"{block}.attn1.proj_k.bias",
                    f"{block}.attn1.proj_v.weight",
                    f"{block}.attn1.proj_v.bias",
                    f"{block}.attn1.added_proj_q.weight",
                    f"{block}.attn1.added_proj_q.bias",
                    f"{block}.attn1.added_proj_k.weight",
                    f"{block}.attn1.added_proj_k.bias",
                    f"{block}.attn1.added_proj_v.weight",
                    f"{block}.attn1.added_proj_v.bias",
                    f"{block}.ff.net.0.proj.weight",
                    f"{block}.ff_enc.net.0.proj.weight",
                    f"{block}.norm1.linear.weight",
                    f"{block}.norm1.linear.bias",
                    f"{block}.norm2.linear.weight",
                    f"{block}.norm2.linear.bias",
                ]
                self.tp_split_mapping["row_parallel_tp"] += [
                    f"{block}.attn1.proj_out.weight",
                    f"{block}.attn1.added_proj_out.weight",
                    f"{block}.ff.net.2.weight",
                    f"{block}.ff_enc.net.2.weight",
                ]

    @check_method_support
    def source_to_mm(self, cfg: ConvertConfig):
        """Convert a source checkpoint into the split "mm" layout.

        Loads ``cfg.source_path``, renames weights with ``self.convert_mapping``
        and ``self.str_replace_mapping``, splits the result with
        ``self._mm_split`` according to ``cfg.target_parallel_config`` and
        writes the pieces to ``cfg.target_path``.

        Args:
            cfg: Conversion settings providing ``source_path``,
                ``target_parallel_config`` and ``target_path``.
        """
        state_dict = load_pt(cfg.source_path)
        state_dict = self._replace_state_dict(
            state_dict,
            self.convert_mapping,
            self.str_replace_mapping
        )
        state_dicts = self._mm_split(state_dict, cfg.target_parallel_config)
        save_as_mm(cfg.target_path, state_dicts)

    def _replace_state_dict(
        self,
        state_dict: dict,
        convert_mapping: dict = None,
        str_replace_mapping: dict = None,
    ):
        """Unwrap an ``"ema_state_dict"`` entry, then delegate to the base class.

        Args:
            state_dict: Loaded checkpoint; when it has an ``"ema_state_dict"``
                key, that entry is used instead of the checkpoint itself.
            convert_mapping: Forwarded unchanged to the base implementation.
            str_replace_mapping: Forwarded unchanged to the base implementation.

        Returns:
            Whatever the base class ``_replace_state_dict`` returns.
        """
        # NOTE(review): an empty "ema_state_dict" is still selected here,
        # whereas vae_convert skips empty ones - confirm this is intended.
        state_dict = state_dict.get("ema_state_dict", state_dict)
        return super()._replace_state_dict(state_dict, convert_mapping, str_replace_mapping)

    def vae_convert(
        self,
        cfg: ConvertConfig,
        use_ema_model: bool = True
    ):
        """Extract a weight dict from a checkpoint and save it with ``save_as_pt``.

        The saved object is chosen in this order:

        1. a non-empty ``"ema_state_dict"`` entry (only when *use_ema_model*
           is true), with ``"module."`` removed from its keys;
        2. the ``"state_dict"`` entry, or its ``"gen_model"`` sub-entry when
           present;
        3. the loaded checkpoint unchanged.

        Args:
            cfg: Conversion settings providing ``source_path`` and
                ``target_path``.
            use_ema_model: Whether to prefer the EMA weights when available.
        """
        state_dict = load_pt(cfg.source_path)
        if (
            "ema_state_dict" in state_dict
            and len(state_dict["ema_state_dict"]) > 0
            and use_ema_model
        ):
            state_dict = state_dict["ema_state_dict"]
            # NOTE(review): str.replace removes every "module." occurrence in
            # a key, not only a leading prefix - confirm no key relies on an
            # inner "module." segment.
            state_dict = {key.replace("module.", ""): value for key, value in state_dict.items()}
        elif "state_dict" in state_dict:
            if "gen_model" in state_dict["state_dict"]:
                state_dict = state_dict["state_dict"]["gen_model"]
            else:
                state_dict = state_dict["state_dict"]
        save_as_pt(state_dict, cfg.target_path)