import torch


class StepVideoI2VProcessor:
    """Build image-conditioning latents for StepVideo image-to-video (I2V).

    Calling an instance performs these steps:

    1. Add Gaussian noise to the conditioning image (by default the first
       frame of each video), using a per-sample log-normal noise scale.
    2. Encode the noised image with the supplied VAE.
    3. Append zeros shaped like ``video_latents[:, 1:]`` along dim 1.

    NOTE: random dropout of the image latent is not implemented in this
    class; if it is required it has to happen in the caller.
    """

    def __init__(self, config):
        """Create the processor.

        Args:
            config: Accepted for interface compatibility; it is not read
                by this class.
        """
        super().__init__()

    @staticmethod
    def add_noise_to_image(image: torch.Tensor) -> torch.Tensor:
        """Return ``image`` plus Gaussian noise with a per-sample scale.

        One noise level is drawn per batch element as
        ``sigma = exp(N(-3.0, 0.5))`` and applied to every other dimension
        of that sample.

        Args:
            image: Tensor whose first dimension is the batch dimension.
                Any rank >= 1 is accepted.

        Returns:
            A new tensor with the same shape, dtype and device as ``image``.
        """
        batch_size = image.shape[0]
        # Sampled with the default (CPU) generator and then moved, so seeded
        # runs draw the same sigma values regardless of the image's device.
        log_sigma = torch.normal(mean=-3.0, std=0.5, size=(batch_size,)).to(image.device)
        sigma = torch.exp(log_sigma).to(image.dtype)
        # One trailing singleton axis per non-batch dimension, so sigma
        # broadcasts per sample whatever the rank of `image` is.
        sigma = sigma.view(-1, *([1] * (image.dim() - 1)))
        return image + torch.randn_like(image) * sigma

    def __call__(self, vae_model, videos: torch.Tensor, video_latents: torch.Tensor, images=None, **kwargs) -> dict:
        """Compute the conditioning latents for a batch of videos.

        Args:
            vae_model: Object exposing ``encode(tensor) -> tensor``.
            videos: Video tensor; dim 0 is the batch and dim 1 is sliced to
                obtain the first frame when ``images`` is not given.
            video_latents: Latents of ``videos``; only used to shape the
                zero padding (``video_latents[:, 1:]``).
            images: Optional conditioning images. When ``None`` the first
                frame of every video (``videos[:, 0:1]``) is used.
            **kwargs: Ignored.

        Returns:
            ``{"image_latents": tensor}`` where the tensor is the encoded
            image followed by the zero padding along dim 1, cast to
            ``videos.dtype``.
        """
        if images is None:
            images = videos[:, 0:1]
        images = self.add_noise_to_image(images.to(videos.dtype))

        img_emb = vae_model.encode(images).to(videos.device)

        # Tile a single shared encoding across the batch. An encoding that
        # already has one entry per video must NOT be tiled: doing so would
        # produce batch_size**2 rows and break the concatenation below.
        batch_size = videos.size(0)
        if img_emb.size(0) == 1 and batch_size != 1:
            img_emb = img_emb.repeat(batch_size, *([1] * (img_emb.dim() - 1)))

        padding_tensor = torch.zeros_like(video_latents[:, 1:])
        condition_hidden_states = torch.cat([img_emb, padding_tensor], dim=1)
        return {"image_latents": condition_hidden_states.to(videos.dtype)}