from collections import defaultdict
import torch.nn as nn
from mindspeed_mm.utils.extra_processor.i2v_processors import I2VProcessor
from mindspeed_mm.models.ae.diffusers_ae_model import DiffusersAEModel
from mindspeed_mm.models.ae.casualvae import CausalVAE
from mindspeed_mm.models.ae.wfvae import WFVAE
from mindspeed_mm.models.ae.contextparallel_causalvae import ContextParallelCasualVAE
from mindspeed_mm.models.ae.autoencoder_kl_hunyuanvideo import AutoencoderKLHunyuanVideo
from mindspeed_mm.models.ae.wan_video_vae import WanVideoVAE
from mindspeed_mm.models.ae.stepvideo_vae import StepVideoVae
from mindspeed_mm.models.ae.movqvae import MOVQ
from mindspeed_mm.models.ae.flux_vae import FluxVae
from mindspeed_mm.models.ae.hunyuanvideo_15_vae import AutoencoderKLConv3D
# Registry of natively supported autoencoder backends, keyed by ``model_id``.
# ``AEModel`` instantiates the mapped class directly with the config dict;
# any ``model_id`` not listed here is handed to ``DiffusersAEModel`` instead.
# NOTE(review): keys are config-facing ids. "casualvae" and
# "contextparallelcasualvae" spell "casual" unlike the class names; do not
# rename them without migrating existing configs.
AE_MODEL_MAPPINGS = {
    "casualvae": CausalVAE,
    "wfvae": WFVAE,
    "contextparallelcasualvae": ContextParallelCasualVAE,
    "autoencoder_kl_hunyuanvideo": AutoencoderKLHunyuanVideo,
    "wan_video_vae": WanVideoVAE,
    "stepvideovae": StepVideoVae,
    "movqvae": MOVQ,
    "flux_vae": FluxVae,
    "hunyuanvideo_15_vae": AutoencoderKLConv3D,
}
class AEModel(nn.Module):
    """Wrapper that builds an autoencoder backend from a config and exposes encode/decode.

    The backend is selected by ``config.model_id``. Ids registered in
    ``AE_MODEL_MAPPINGS`` are constructed directly from the config dict. Any other
    id is delegated to ``DiffusersAEModel``. An optional ``i2v_processor`` config
    section builds a processor that is run on every encoded sample.
    """

    def __init__(self, config):
        """Build the backend model and the optional I2V processor.

        Args:
            config: Object exposing ``to_dict()`` and a ``model_id`` attribute.
        """
        super().__init__()
        self.config = config.to_dict()
        # The i2v section configures the processor, not the AE backend, so it is
        # removed before the remaining keys are forwarded to the backend constructor.
        self.i2v_processor_config = self.config.pop("i2v_processor", None)
        if self.i2v_processor_config is not None:
            self.i2v_processor = I2VProcessor(self.i2v_processor_config).get_processor()
        else:
            self.i2v_processor = None

        model_cls = AE_MODEL_MAPPINGS.get(config.model_id)
        if model_cls is not None:
            self.model = model_cls(**self.config)
        else:
            self.model = DiffusersAEModel(
                model_name=config.model_id, config=self.config
            )

    def get_model(self):
        """Return the underlying backend autoencoder."""
        return self.model

    def encode(self, x, **kwargs):
        """Encode a single input, or each element of a list/tuple of inputs.

        Args:
            x: A single input, or a list/tuple of inputs encoded one at a time.
            **kwargs: Extra arguments forwarded to the I2V processor. For batched
                input, a list/tuple kwarg whose length equals ``len(x)`` is split
                so that sample ``i`` receives element ``i``. Every other kwarg is
                passed unchanged to each sample.

        Returns:
            For a single input, the ``(latents, i2v_results)`` pair produced by
            ``_single_encode``. Here ``i2v_results`` is ``None`` when no processor
            is configured.

            For batched input, ``(latents_list, i2v_results)``. ``latents_list``
            holds one entry per sample. ``i2v_results`` is a ``defaultdict(list)``
            that maps each key returned by the processor to its per-sample values.
            It is empty (not ``None``) when no processor is configured.

        NOTE(review): samples whose processor output is ``None`` are skipped, so
        the per-key lists can then be shorter than ``latents_list``.
        """
        if not isinstance(x, (list, tuple)):
            return self._single_encode(x, **kwargs)

        batch_size = len(x)
        video_latents = []
        i2v_results = defaultdict(list)
        for i, video in enumerate(x):
            # Batch-aligned list/tuple kwargs are indexed per sample; anything else
            # (scalars, mismatched lengths) is shared by every sample.
            kwargs_i = {
                key: value[i] if isinstance(value, (list, tuple)) and len(value) == batch_size else value
                for key, value in kwargs.items()
            }
            latents, results = self._single_encode(video, **kwargs_i)
            video_latents.append(latents)
            if results is None:
                continue
            for key, value in results.items():
                i2v_results[key].append(value)
        return video_latents, i2v_results

    def _single_encode(self, x, **kwargs):
        """Encode one input and, if configured, run the I2V processor on it.

        Returns:
            ``(latents, i2v_results)``. ``latents`` is ``self.model.encode(x)``.
            ``i2v_results`` is the processor's output, or ``None`` when no
            processor is configured.
        """
        _video_latents = self.model.encode(x)
        _i2v_results = None
        if self.i2v_processor is not None:
            _i2v_results = self.i2v_processor(
                vae_model=self.model,
                videos=x,
                video_latents=_video_latents,
                **kwargs
            )
        return _video_latents, _i2v_results

    def decode(self, x):
        """Decode ``x`` with the backend model."""
        return self.model.decode(x)

    def forward(self, x):
        """Not supported; use ``encode`` / ``decode`` instead."""
        raise NotImplementedError("forward function is not implemented")