import random
import torch
class CogVideoXI2VProcessor:
"""
The I2V Processor of CogVideoX:
1. add noise to first frame
2. encode the first frame
3. random dropout image latent
Args:
config (dict): the processor config
{
"noised_image_all_concat": False,
"noised_image_dropout": 0.05,
"noised_image_input": True
}
"""
def __init__(self, config):
self.noised_image_all_concat = config.get("noised_image_all_concat", False)
self.noised_image_dropout = config.get("noised_image_dropout", 0.05)
self.noised_image_input = config.get("noised_image_input", True)
@staticmethod
def add_noise_to_image(image):
sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device)
sigma = torch.exp(sigma).to(image.dtype)
image_noise = torch.randn_like(image) * sigma[:, None, None, None, None]
image = image + image_noise
return image
def __call__(self, vae_model, videos, video_latents, images=None, **kwargs):
if images is None:
images = videos[:, 0:1]
images = images.to(videos.dtype)
if random.random() < self.noised_image_dropout:
image_latents = torch.zeros_like(video_latents)
else:
if self.noised_image_input:
images = self.add_noise_to_image(images)
image_latents = vae_model.encode(images, enable_cp=False)
if self.noised_image_all_concat:
image_latents = image_latents.repeat(1, 1, video_latents.size(2), 1, 1)
else:
image_latents = torch.concat(
[
image_latents,
torch.zeros_like(video_latents[:, :, 1:])
],
dim=2
)
return {"masked_video": image_latents}