from torch import nn
from megatron.training.utils import print_rank_0
from megatron.core import mpu
from mindspeed_mm.models.common.checkpoint import load_checkpoint
from .dits import (
Latte,
VideoDitSparse,
SatDiT,
VideoDitSparseI2V,
PTDiT,
HunyuanVideoDiT,
HunyuanVideo15DiT,
WanDiT,
StepVideoDiT,
SparseUMMDiT,
MMDiT,
VACEModel
)
PREDICTOR_MODEL_MAPPINGS = {
"videoditsparse": VideoDitSparse,
"videoditsparsei2v": VideoDitSparseI2V,
"latte": Latte,
"satdit": SatDiT,
"ptdit": PTDiT,
"hunyuanvideodit": HunyuanVideoDiT,
"hunyuanvideo15dit": HunyuanVideo15DiT,
"wandit": WanDiT,
"stepvideodit": StepVideoDiT,
"SparseUMMDiT": SparseUMMDiT,
"mmdit": MMDiT,
"vace": VACEModel
}
class PredictModel(nn.Module):
"""
The backnone of the denoising model
PredictModel is the factory class for all unets and dits
Args:
config[dict]: for Instantiating an atomic methods
"""
def __init__(self, config):
super().__init__()
model_cls = PREDICTOR_MODEL_MAPPINGS[config.model_id]
config = self._build_predictor_layers_config(config)
self.predictor = model_cls(**config.to_dict())
if hasattr(config, "from_pretrained") and config.from_pretrained is not None:
assign = False
if hasattr(self.predictor, "device") and self.predictor.device.type == "meta":
assign = True
load_checkpoint(self.predictor, config.from_pretrained, assign=assign)
print_rank_0("load predictor's checkpoint sucessfully")
if hasattr(self.predictor, "post_init"):
self.predictor.post_init()
def get_model(self):
return self.predictor
def _build_predictor_layers_config(self, config):
if mpu.get_pipeline_model_parallel_world_size() <= 1:
return config
self.pp_size = mpu.get_pipeline_model_parallel_world_size()
self.enable_vpp = mpu.get_virtual_pipeline_model_parallel_world_size() is not None
if self.enable_vpp:
self.vpp_rank = mpu.get_virtual_pipeline_model_parallel_rank()
self.vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()
self.pp_rank = mpu.get_pipeline_model_parallel_rank()
print(f"current pp_size: {self.pp_size}, pp_rank: {self.pp_rank}",
f"vpp_size:{self.vpp_size}, vpp_rank: {self.vpp_rank}" if self.enable_vpp else None)
if not hasattr(config, "pipeline_num_layers"):
raise ValueError(f"The `pipeline_num_layers` must be specified in the config for pipeline parallel")
if mpu.is_pipeline_first_stage():
if self.enable_vpp:
if sum(sum(pipeline_num_layer) for pipeline_num_layer in config.pipeline_num_layers) != config.num_layers:
raise ValueError(f"The sum of `pipeline_num_layers` must be equal to the `num_layers`")
else:
if sum(config.pipeline_num_layers) != config.num_layers:
raise ValueError(f"The sum of `pipeline_num_layers` must be equal to the `num_layers`")
if self.enable_vpp:
if self.vpp_size != len(config.pipeline_num_layers):
raise ValueError(f"The vp_size {self.vpp_size} must be equal to the number of layers of "
f"pipeline_num_layers {len(config.pipeline_num_layers)}.")
for vp_rank in range(self.vpp_size):
if self.pp_size != len(config.pipeline_num_layers[vp_rank]):
raise ValueError(f"The pp_size {self.pp_size} must be equal to the number of stages of "
f"pipeline_num_layers {len(config.pipeline_num_layers[vp_rank])}.")
if self.vpp_size * self.pp_size != len(config.pipeline_num_layers) * len(config.pipeline_num_layers[0]):
raise ValueError(f"The product of vpp_size and pp_size must be equal to the num of stages in "
f"pipeline_num_layers of predictor config, "
f"but got vpp_size: {self.vpp_size}, pp_size: {self.pp_size}, "
f"and total num of stages is "
f"{len(config.pipeline_num_layers) * len(config.pipeline_num_layers[0])}")
else:
if self.pp_size != len(config.pipeline_num_layers):
raise ValueError(f"The pp_size should be qual to the num of predictor pipeline layers: "
f"{len(config.pipeline_num_layers)}")
if self.enable_vpp:
pipeline_start_idx = (sum(sum(config.pipeline_num_layers[i]) for i in range(self.vpp_rank)) +
sum(config.pipeline_num_layers[self.vpp_rank][:self.pp_rank]))
pipeline_end_idx = pipeline_start_idx + config.pipeline_num_layers[self.vpp_rank][self.pp_rank]
local_num_layers = config.pipeline_num_layers[self.vpp_rank][self.pp_rank]
else:
pipeline_start_idx = sum(config.pipeline_num_layers[:self.pp_rank])
pipeline_end_idx = sum(config.pipeline_num_layers[:self.pp_rank + 1])
local_num_layers = config.pipeline_num_layers[self.pp_rank]
if local_num_layers <= 0:
raise ValueError(f"for pp_rank {self.pp_rank}, the predictor layer is {local_num_layers}, "
f"which is invalid. ")
config.num_layers = local_num_layers
config.pre_process = mpu.is_pipeline_first_stage()
config.post_process = mpu.is_pipeline_last_stage()
config.global_layer_idx = tuple(range(pipeline_start_idx, pipeline_end_idx))
return config