from torch import nn
from .ddpm import DDPM
from .iddpm import IDDPM
from .rflow import RFlow
from .cogvideo_diffusion import CogVideoDiffusion
from .flow_match_discrete_scheduler import FlowMatchDiscreteScheduler
from .wan_flow_match_scheduler import WanFlowMatchScheduler
from .hunyuanvideo_i2v_diffusion import HunyuanVideoI2VDiffusion
from .opensoraplanv1_5_scheduler import OpenSoraPlanScheduler
from .opensora2_flow_match_scheduler import Opensora2FlowMatchScheduler
from .diffusers_scheduler import DIFFUSERS_SCHEDULE_MAPPINGS, DiffusersScheduler
from .hunyuan_video_15_scheduler import Hunyuan_15_FlowMatchDiscreteScheduler
# Registry of project-defined diffusion implementations, keyed by the
# ``model_id`` string that appears in a diffusion config. ``DiffusionModel``
# checks ``config.model_id`` against these keys first and falls back to
# ``DiffusersScheduler`` for any id that is not listed here.
# Keys must stay exactly as-is (including the "OpenSoraPlan" casing): they are
# matched verbatim against user-supplied config values.
DIFFUSION_MODEL_MAPPINGS = dict(
    ddpm=DDPM,
    iddpm=IDDPM,
    rflow=RFlow,
    cogvideo_diffusion=CogVideoDiffusion,
    flow_match_discrete_scheduler=FlowMatchDiscreteScheduler,
    wan_flow_match_scheduler=WanFlowMatchScheduler,
    hunyuanvideo_i2v_diffusion=HunyuanVideoI2VDiffusion,
    hunyuanvideo_15_diffusion=Hunyuan_15_FlowMatchDiscreteScheduler,
    OpenSoraPlan=OpenSoraPlanScheduler,
    opensora2_flow_match_scheduler=Opensora2FlowMatchScheduler,
)
class DiffusionModel:
    """
    Factory that builds a diffusion model or scheduler from a config object.

    Resolution order:
      1. If ``config.model_id`` is a key of ``DIFFUSION_MODEL_MAPPINGS``, the
         mapped class is instantiated with the fields of ``config.to_dict()``
         unpacked as keyword arguments.
      2. Otherwise the whole ``config.to_dict()`` result is passed, as a single
         positional argument, to ``DiffusersScheduler``.

    Args:
        config: object exposing a ``model_id`` attribute and a ``to_dict()``
            method. Example of what ``to_dict()`` might return:
            {
                "model_id": "ddpm",
                "num_timesteps": 1000,
                "beta_schedule": "linear",
                ...
            }
    """

    def __init__(self, config):
        # Single lookup instead of a membership test followed by indexing;
        # no registered value is None, so None means "not a custom model".
        custom_cls = DIFFUSION_MODEL_MAPPINGS.get(config.model_id)
        if custom_cls is None:
            # Unregistered ids are delegated to the diffusers-backed wrapper,
            # which receives the config dict as one positional argument.
            self.diffusion = DiffusersScheduler(config.to_dict())
        else:
            self.diffusion = custom_cls(**config.to_dict())

    def get_model(self):
        """Return the diffusion model / scheduler instance built in ``__init__``."""
        return self.diffusion