from safetensors.torch import load_file
from checkpoint.sora_model.sora_model_converter import SoraModelConverter
from checkpoint.sora_model.convert_utils.cfg import ConvertConfig
from checkpoint.sora_model.convert_utils.utils import check_method_support
from checkpoint.sora_model.convert_utils.save_load_utils import save_as_mm
class OpenSoraConverter(SoraModelConverter):
"""Converter for OpenSora"""
_supported_methods = ["hf_to_mm", "layerzero_to_mm"]
_enable_tp = False
_enable_pp = False
_enable_vpp = False
hf_to_mm_convert_mapping = {
"time_in.in_layer.bias": "time_in.mlp.0.bias",
"time_in.in_layer.weight": "time_in.mlp.0.weight",
"time_in.out_layer.bias": "time_in.mlp.2.bias",
"time_in.out_layer.weight": "time_in.mlp.2.weight",
"vector_in.in_layer.bias": "vector_in.fc1.bias",
"vector_in.in_layer.weight": "vector_in.fc1.weight",
"vector_in.out_layer.bias": "vector_in.fc2.bias",
"vector_in.out_layer.weight": "vector_in.fc2.weight",
}
def __init__(self) -> None:
super().__init__()
double_stream_layers = 19
single_stream_layers = 38
for i in range(double_stream_layers):
self.hf_to_mm_convert_mapping.update({
f"double_blocks.{i}.img_mod.lin.bias": f"double_blocks.{i}.img_mod.linear.bias",
f"double_blocks.{i}.img_mod.lin.weight": f"double_blocks.{i}.img_mod.linear.weight",
f"double_blocks.{i}.img_attn.q_proj.bias": f"double_blocks.{i}.img_attn.proj_q.bias",
f"double_blocks.{i}.img_attn.q_proj.weight": f"double_blocks.{i}.img_attn.proj_q.weight",
f"double_blocks.{i}.img_attn.k_proj.bias": f"double_blocks.{i}.img_attn.proj_k.bias",
f"double_blocks.{i}.img_attn.k_proj.weight": f"double_blocks.{i}.img_attn.proj_k.weight",
f"double_blocks.{i}.img_attn.v_proj.bias": f"double_blocks.{i}.img_attn.proj_v.bias",
f"double_blocks.{i}.img_attn.v_proj.weight": f"double_blocks.{i}.img_attn.proj_v.weight",
f"double_blocks.{i}.img_attn.proj.bias": f"double_blocks.{i}.img_attn.proj_out.bias",
f"double_blocks.{i}.img_attn.proj.weight": f"double_blocks.{i}.img_attn.proj_out.weight",
f"double_blocks.{i}.img_attn.norm.query_norm.scale": f"double_blocks.{i}.img_attn.q_norm.weight",
f"double_blocks.{i}.img_attn.norm.key_norm.scale": f"double_blocks.{i}.img_attn.k_norm.weight",
f"double_blocks.{i}.img_mlp.0.bias": f"double_blocks.{i}.img_mlp.fc1.bias",
f"double_blocks.{i}.img_mlp.0.weight": f"double_blocks.{i}.img_mlp.fc1.weight",
f"double_blocks.{i}.img_mlp.2.bias": f"double_blocks.{i}.img_mlp.fc2.bias",
f"double_blocks.{i}.img_mlp.2.weight": f"double_blocks.{i}.img_mlp.fc2.weight",
f"double_blocks.{i}.txt_mod.lin.bias": f"double_blocks.{i}.txt_mod.linear.bias",
f"double_blocks.{i}.txt_mod.lin.weight": f"double_blocks.{i}.txt_mod.linear.weight",
f"double_blocks.{i}.txt_attn.q_proj.bias": f"double_blocks.{i}.txt_attn.proj_q.bias",
f"double_blocks.{i}.txt_attn.q_proj.weight": f"double_blocks.{i}.txt_attn.proj_q.weight",
f"double_blocks.{i}.txt_attn.k_proj.bias": f"double_blocks.{i}.txt_attn.proj_k.bias",
f"double_blocks.{i}.txt_attn.k_proj.weight": f"double_blocks.{i}.txt_attn.proj_k.weight",
f"double_blocks.{i}.txt_attn.v_proj.bias": f"double_blocks.{i}.txt_attn.proj_v.bias",
f"double_blocks.{i}.txt_attn.v_proj.weight": f"double_blocks.{i}.txt_attn.proj_v.weight",
f"double_blocks.{i}.txt_attn.proj.bias": f"double_blocks.{i}.txt_attn.proj_out.bias",
f"double_blocks.{i}.txt_attn.proj.weight": f"double_blocks.{i}.txt_attn.proj_out.weight",
f"double_blocks.{i}.txt_attn.norm.query_norm.scale": f"double_blocks.{i}.txt_attn.q_norm.weight",
f"double_blocks.{i}.txt_attn.norm.key_norm.scale": f"double_blocks.{i}.txt_attn.k_norm.weight",
f"double_blocks.{i}.txt_mlp.0.bias": f"double_blocks.{i}.txt_mlp.fc1.bias",
f"double_blocks.{i}.txt_mlp.0.weight": f"double_blocks.{i}.txt_mlp.fc1.weight",
f"double_blocks.{i}.txt_mlp.2.bias": f"double_blocks.{i}.txt_mlp.fc2.bias",
f"double_blocks.{i}.txt_mlp.2.weight": f"double_blocks.{i}.txt_mlp.fc2.weight"
})
for i in range(single_stream_layers):
self.hf_to_mm_convert_mapping.update({
f"single_blocks.{i}.norm.query_norm.scale": f"single_blocks.{i}.q_norm.weight",
f"single_blocks.{i}.norm.key_norm.scale": f"single_blocks.{i}.k_norm.weight",
f"single_blocks.{i}.modulation.lin.bias": f"single_blocks.{i}.modulation.linear.bias",
f"single_blocks.{i}.modulation.lin.weight": f"single_blocks.{i}.modulation.linear.weight"
})
@check_method_support
def hf_to_mm(self, cfg: ConvertConfig):
state_dict = load_file(cfg.source_path)
state_dict = self._replace_state_dict(
state_dict,
self.hf_to_mm_convert_mapping,
self.hf_to_mm_str_replace_mapping
)
state_dicts = self._mm_split(state_dict, cfg.target_parallel_config)
save_as_mm(cfg.target_path, state_dicts)