from typing import Optional, Union, List, Callable
import math
import inspect
import torch
import torch.nn as nn
from mindspeed_mm.tasks.inference.pipeline.pipeline_base import MMPipeline
from mindspeed_mm.tasks.inference.pipeline.pipeline_mixin.encode_mixin import MMEncoderMixin
from mindspeed_mm.tasks.inference.pipeline.pipeline_mixin.inputs_checks_mixin import InputsCheckMixin
from mindspeed_mm.tasks.inference.pipeline.patchs.sora_patchs import replace_with_fp32_forwards
from mindspeed_mm.utils.mask_utils import MaskProcessor, MaskCompressor, TYPE_TO_STR
from mindspeed_mm.data.data_utils.constants import INPUT_MASK, MASKED_VIDEO
from mindspeed_mm.tasks.inference.pipeline.utils.sora_utils import get_pixel_values, get_mask_type_cond_indices, get_video_transform, get_resize_transform
class OpenSoraPlanPipeline(MMPipeline, InputsCheckMixin, MMEncoderMixin):
def __init__(self, vae, text_encoder, tokenizer, scheduler, predict_model, config=None):
self.vae = vae
self.text_encoder = text_encoder
self.tokenizer = tokenizer
self.scheduler = scheduler
self.predict_model = predict_model
text_encoder.use_attention_mask = config.use_attention_mask
self.num_frames, self.height, self.width = config.input_size
self.version = config.version
self.model_type = config.model_type
if self.model_type == "i2v":
self.mask_processor = MaskProcessor(min_clear_ratio=0.5, max_clear_ratio=0.5)
self.mask_compressor = MaskCompressor(ae_stride_t=self.vae.vae_scale_factor[0],
ae_stride_h=self.vae.vae_scale_factor[1],
ae_stride_w=self.vae.vae_scale_factor[2])
replace_with_fp32_forwards()
@torch.no_grad()
def __call__(self,
prompt,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt: Optional[str] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
eta: float = 0.0,
num_images_per_prompt: Optional[int] = 1,
guidance_scale: float = 4.5,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
max_sequence_length: Optional[int] = 300,
clean_caption: bool = True,
enable_temporal_attentions: bool = True,
added_cond_kwargs: dict = None,
use_prompt_template: bool = True,
use_prompt_preprocess: bool = True,
device: torch.device = "npu",
**kwargs,
):
if use_prompt_template:
prompt, negative_prompt = self.prompt_template(positive_prompt=prompt, negative_prompt=negative_prompt)
self.text_prompt_checks(prompt, negative_prompt, prompt_embeds, negative_prompt_embeds)
self.generate_params_checks(self.height, self.width)
if self.model_type == "i2v":
conditional_pixel_values_path = kwargs.get("conditional_pixel_values_path", None)
mask_type = kwargs.get("mask_type", None)
conditional_pixel_values_indices = kwargs.get("conditional_pixel_values_indices", None)
crop_for_hw = kwargs.get("crop_for_hw", False)
max_hxw = kwargs.get("max_hxw", 236544)
pixel_values = []
pixel_values_indices = []
for pixel_values_path_i in conditional_pixel_values_path:
self.i2v_prompt_checks(pixel_values_path_i, mask_type)
mask_type, pixel_values_indices_i = get_mask_type_cond_indices(mask_type,
pixel_values_path_i,
conditional_pixel_values_indices,
self.num_frames)
pixel_values_i = get_pixel_values(pixel_values_path_i, self.num_frames)
pixel_values += pixel_values_i
pixel_values_indices.append(pixel_values_indices_i)
min_height = min([pixels.shape[2] for pixels in pixel_values])
min_width = min([pixels.shape[3] for pixels in pixel_values])
resize_transform = get_resize_transform(
ori_height=min_height,
ori_width=min_width,
height=self.height,
width=self.width,
crop_for_hw=crop_for_hw,
max_hxw=max_hxw,
)
video_transform = get_video_transform()
pixel_values = torch.cat([resize_transform(pixels) for pixels in pixel_values])
height, width = pixel_values.shape[-2], pixel_values.shape[-1]
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
do_classifier_free_guidance = guidance_scale > 1.0
if not isinstance(self.tokenizer, list):
prompt_embeds, prompt_embeds_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = self.encode_texts(
prompt=prompt,
negative_prompt=negative_prompt,
device=device,
do_classifier_free_guidance=do_classifier_free_guidance,
max_length=max_sequence_length,
clean_caption=clean_caption,
use_prompt_preprocess=use_prompt_preprocess
)
else:
prompt_embeds, prompt_embeds_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = [], [], [], []
for tokenizer, text_encoder in zip(self.tokenizer, self.text_encoder):
prompt_embed, attention_mask, negative_prompt_embed, negative_attention_mask = self.encode_texts(
tokenizer=tokenizer,
text_encoder=text_encoder,
prompt=prompt,
negative_prompt=negative_prompt,
device=device,
do_classifier_free_guidance=do_classifier_free_guidance,
max_length=tokenizer.model_max_length,
clean_caption=clean_caption,
use_prompt_preprocess=use_prompt_preprocess
)
prompt_embeds.append(prompt_embed)
prompt_embeds_attention_mask.append(attention_mask)
negative_prompt_embeds.append(negative_prompt_embed)
negative_prompt_attention_mask.append(negative_attention_mask)
if do_classifier_free_guidance:
if isinstance(prompt_embeds, list):
prompt_embeds = [
torch.cat([negative_prompt_embed, prompt_embed], dim=0)
for negative_prompt_embed, prompt_embed in zip(negative_prompt_embeds, prompt_embeds)
]
prompt_embeds_attention_mask = [
torch.cat([negative_attention_mask, attention_mask], dim=0)
for negative_attention_mask, attention_mask in zip(negative_prompt_attention_mask, prompt_embeds_attention_mask)
]
else:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_embeds_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_embeds_attention_mask],
dim=0)
if self.model_type == "i2v":
masked_pixel_values, mask = self.get_masked_pixel_values_mask(
pixel_values,
pixel_values_indices,
mask_type,
batch_size,
num_images_per_prompt,
self.num_frames,
height,
width,
video_transform,
prompt_embeds.dtype,
device
)
i2v_kwargs = {MASKED_VIDEO: masked_pixel_values, INPUT_MASK: mask}
latent_channels = self.predict_model.in_channels
batch_size = batch_size * num_images_per_prompt
shape = (
batch_size,
latent_channels,
(math.ceil((int(self.num_frames) - 1) / self.vae.vae_scale_factor[0]) + 1) if int(
self.num_frames) % 2 == 1 else math.ceil(int(self.num_frames) / self.vae.vae_scale_factor[0]),
math.ceil(int(self.height) / self.vae.vae_scale_factor[1]),
math.ceil(int(self.width) / self.vae.vae_scale_factor[2]),
)
latents = self.prepare_latents(
shape,
generator=generator,
device=device,
dtype=prompt_embeds[0].dtype
if isinstance(prompt_embeds, list) else prompt_embeds.dtype,
latents=latents
)
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
if not isinstance(prompt_embeds, list):
if prompt_embeds.ndim == 3:
prompt_embeds = prompt_embeds.unsqueeze(1)
if prompt_embeds_attention_mask.ndim == 2:
prompt_embeds_attention_mask = prompt_embeds_attention_mask.unsqueeze(1)
model_kwargs = {"prompt": prompt_embeds,
"added_cond_kwargs": added_cond_kwargs,
"enable_temporal_attentions": enable_temporal_attentions,
"prompt_mask": prompt_embeds_attention_mask,
"return_dict": False}
if self.model_type == "i2v":
model_kwargs.update(i2v_kwargs)
latents = self.scheduler.sample(model=self.predict_model, shape=shape, latents=latents, model_kwargs=model_kwargs,
extra_step_kwargs=extra_step_kwargs)
video = self.decode_latents(latents.to(self.vae.dtype))
return video
def prepare_extra_step_kwargs(self, generator, eta):
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def prompt_template(self, positive_prompt, negative_prompt):
positive_template_list = []
negative_template_list = []
if not negative_prompt:
negative_prompt = ""
if self.version == "v1.2":
positive_template = "(masterpiece), (best quality), (ultra-detailed), {}. emotional, harmonious, vignette, " \
"4k epic detailed, shot on kodak, 35mm photo, sharp focus, high budget, cinemascope, moody, epic, gorgeous"
negative_template = "nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, " \
"fewer digits, cropped, worst quality, low quality, normal quality, " \
"jpeg artifacts, signature, watermark, username, blurry"
elif self.version == "v1.3":
positive_template = """
high quality, high aesthetic, {}
"""
negative_template = """
nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality,
low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry.
"""
elif self.version == "v1.5":
positive_template = """
high quality, {}
"""
negative_template = """
Worst quality, Normal quality, Low quality, Low res, Blurry, Jpeg artifacts, Grainy, watermark, banner,
Cropped, Out of frame, Out of focus, Bad anatomy, Bad proportions, Deformed, Disconnected limbs, Disfigured,
username, error, sketch, duplicate, ugly, monochrome, horror, geometry, mutation, disgusting, overexposed, underexposed.
"""
else:
positive_template = "{}"
negative_template = "{}"
if isinstance(positive_prompt, (list, tuple)):
for positive_prompt_i in positive_prompt:
positive_template_i = positive_template.format(positive_prompt_i)
negative_template_i = negative_template + negative_prompt
positive_template_list.append(positive_template_i)
negative_template_list.append(negative_template_i)
return positive_template_list, negative_template_list
else:
positive_template_i = positive_template.format(positive_prompt)
negative_template_i = negative_template + negative_prompt
return [positive_template_i], [negative_template_i]
def get_masked_pixel_values_mask(
self,
conditional_pixel_values,
conditional_pixel_values_indices,
mask_type,
batch_size,
num_samples_per_prompt,
num_frames,
height,
width,
video_transform,
weight_dtype,
device
):
if device is None:
device = getattr(self, '_execution_device', None) or getattr(self, 'device', None) or torch.device('cuda')
conditional_pixel_values = conditional_pixel_values.to(device=device, dtype=weight_dtype)
masked_pixel_values = []
mask = []
for i in range(batch_size):
conditional_pixel_values_i = conditional_pixel_values[i].unsqueeze(0)
conditional_pixel_values_indices_i = conditional_pixel_values_indices[i]
if conditional_pixel_values_i.shape[0] == num_frames:
inpaint_cond_data = self.mask_processor(conditional_pixel_values_i, mask_type=mask_type)
masked_pixel_values_i, mask_i = inpaint_cond_data['masked_pixel_values'], inpaint_cond_data['mask']
else:
input_pixel_values = torch.zeros([num_frames, 3, height, width], device=device, dtype=weight_dtype)
input_mask = torch.ones([num_frames, 1, height, width], device=device, dtype=weight_dtype)
input_pixel_values[conditional_pixel_values_indices_i] = conditional_pixel_values_i
input_mask[conditional_pixel_values_indices_i] = 0
masked_pixel_values_i = input_pixel_values * (input_mask < 0.5)
mask_i = input_mask
print('conditional_pixel_values_indices_i', conditional_pixel_values_indices_i)
print('mask_type', TYPE_TO_STR[mask_type])
masked_pixel_values_i = video_transform(masked_pixel_values_i)
masked_pixel_values_i = masked_pixel_values_i.unsqueeze(0).repeat(num_samples_per_prompt, 1, 1, 1,
1).transpose(1, 2).contiguous()
mask_i = mask_i.unsqueeze(0).repeat(num_samples_per_prompt, 1, 1, 1, 1).transpose(1, 2).contiguous()
masked_pixel_values.append(masked_pixel_values_i)
mask.append(mask_i)
masked_pixel_values = torch.cat(masked_pixel_values, dim=0)
mask = torch.cat(mask, dim=0)
masked_pixel_values = masked_pixel_values.to(self.vae.dtype)
masked_pixel_values = self.vae.encode(masked_pixel_values)
mask = self.mask_compressor(mask)
masked_pixel_values = torch.cat(
[masked_pixel_values] * 2) if do_classifier_free_guidance else masked_pixel_values
mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
masked_pixel_values = masked_pixel_values.to(weight_dtype)
mask = mask.to(weight_dtype)
return masked_pixel_values, mask