from typing import Optional, Tuple, Dict
from einops import rearrange, repeat
import torch
from torch import nn
import torch.nn.functional as F
from diffusers.models.embeddings import SinusoidalPositionalEmbedding, PixArtAlphaTextProjection
from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormSingle
from megatron.core import mpu, tensor_parallel
from megatron.training import get_args
from megatron.legacy.model.enums import AttnType
from mindspeed_mm.models.common.ffn import FeedForward
from mindspeed_mm.models.common.communications import split_forward_gather_backward, gather_forward_split_backward
from mindspeed_mm.models.common.module import MultiModalModule
from mindspeed_mm.models.common.embeddings.patch_embeddings import VideoPatchEmbed2D
from mindspeed_mm.models.common.attention import ParallelAttention, ParallelMultiHeadAttentionSBH
from mindspeed_mm.models.common.embeddings.pos_embeddings import RoPE3DSORA
class VideoDiT(MultiModalModule):
    """
    A video DiT (diffusion transformer) model for video generation.

    ``forward`` consumes a 5-D latent tensor (``latents.shape`` is unpacked into
    five dimensions), patchifies it with ``VideoPatchEmbed2D``, runs the tokens
    through ``num_layers`` ``VideoDiTBlock`` layers in (seq, batch, hidden)
    layout and un-patchifies the result back to a 5-D tensor.

    Args:
        num_layers: The number of layers for VideoDiTBlock.
        num_heads: The number of heads to use for multi-head attention.
        head_dim: The number of channels in each head.
        in_channels: The number of channels in the input.
        out_channels: The number of channels in the output. Falls back to
            ``in_channels`` when left as None.
        dropout: The dropout probability to use.
        cross_attention_dim: The number of prompt dimensions to use.
        attention_bias: Whether to use bias in VideoDiTBlock's attention.
        fa_layout: Layout string forwarded to the self-attention module.
        input_size: The (frames, height, width) shape of the latents.
        patch_size_thw: The (t, h, w) shape of the patches.
        activation_fn: The name of the activation function used in VideoDiTBlock.
        norm_type: Only 'ada_norm', 'ada_norm_zero' and 'ada_norm_single' are accepted here.
        num_embeds_ada_norm: The number of diffusion steps used during training. Required when
            `norm_type` is 'ada_norm' or 'ada_norm_zero'.
        norm_elementwise_affine: Whether to use learnable elementwise affine parameters for normalization.
        norm_eps: The eps of the normalization.
        caption_channels: Input feature size of the caption projection; no projection is built when None.
        use_rope: When True the patch embedding does not add absolute position embeddings.
        interpolation_scale: The (t, h, w) scale for interpolation.
    """

    def __init__(
        self,
        num_layers: int = 1,
        num_heads: int = 16,
        head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        dropout: float = 0.0,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        fa_layout: str = "sbh",
        input_size: Tuple[int] = None,
        patch_size_thw: Tuple[int] = None,
        activation_fn: str = "geglu",
        norm_type: str = "layer_norm",
        num_embeds_ada_norm: Optional[int] = None,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        caption_channels: int = None,
        use_rope: bool = False,
        interpolation_scale: Tuple[float] = None,
        **kwargs
    ):
        """
        Build the patch embedding, RoPE module, transformer blocks and output head.

        Raises:
            ValueError: If `patch_size_thw`, `input_size` or `interpolation_scale` is missing,
                if `num_embeds_ada_norm` is missing for 'ada_norm'/'ada_norm_zero', or if
                selective recompute is requested.
            NotImplementedError: For unsupported `norm_type` values or when
                `distribute_saved_activations` is enabled.
        """
        super().__init__(config=None)

        # These three arguments are unpacked / indexed below; fail with a clear
        # message instead of an opaque TypeError on None.
        if patch_size_thw is None:
            raise ValueError("`patch_size_thw` must be provided as a (t, h, w) tuple.")
        if norm_type not in ["ada_norm", "ada_norm_zero", "ada_norm_single"]:
            raise NotImplementedError(
                f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'."
            )
        elif norm_type in ["ada_norm", "ada_norm_zero"] and num_embeds_ada_norm is None:
            raise ValueError(
                f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
            )
        if input_size is None:
            raise ValueError("`input_size` must be provided as a (frames, height, width) tuple.")
        if interpolation_scale is None:
            raise ValueError("`interpolation_scale` must be provided as a (t, h, w) tuple.")

        self.patch_size_t, self.patch_size_h, self.patch_size_w = patch_size_thw
        self.norm_type = norm_type
        self.in_channels = in_channels
        # The output projections multiply by out_channels, so None cannot be kept.
        self.out_channels = in_channels if out_channels is None else out_channels
        self.num_layers = num_layers
        inner_dim = num_heads * head_dim

        args = get_args()
        # Sequence parallelism is only meaningful with more than one TP rank.
        self.sequence_parallel = args.sequence_parallel
        if mpu.get_tensor_model_parallel_world_size() <= 1:
            self.sequence_parallel = False
        self.recompute_granularity = args.recompute_granularity
        self.distribute_saved_activations = args.distribute_saved_activations
        self.recompute_method = args.recompute_method
        self.recompute_num_layers = args.recompute_num_layers
        self.task = getattr(args.mm.model, "task", "t2v")
        if self.recompute_granularity == "selective":
            raise ValueError("recompute_granularity does not support selective mode in VideoDiT")
        if self.distribute_saved_activations:
            raise NotImplementedError("distribute_saved_activations is currently not supported")
        self.enable_context_parallelism = mpu.get_context_parallel_world_size() > 1

        # Module creation order below is kept as in the original so parameter
        # registration order and RNG consumption during init are unchanged.
        self.pos_embed = VideoPatchEmbed2D(
            num_frames=input_size[0],
            height=input_size[1],
            width=input_size[2],
            patch_size_t=self.patch_size_t,
            patch_size=self.patch_size_h,
            in_channels=in_channels,
            embed_dim=inner_dim,
            interpolation_scale=(interpolation_scale[1], interpolation_scale[2]),
            interpolation_scale_t=interpolation_scale[0],
            use_abs_pos=not use_rope,
        )
        # NOTE(review): the RoPE module is built and handed to every block even
        # when `use_rope` is False — confirm this is intended.
        self.rope = RoPE3DSORA(
            head_dim=head_dim,
            interpolation_scale=interpolation_scale
        )

        self.videodit_blocks = nn.ModuleList(
            [
                VideoDiTBlock(
                    dim=inner_dim,
                    num_heads=num_heads,
                    head_dim=head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    fa_layout=fa_layout,
                    norm_type=norm_type,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                    rope=self.rope,
                    interpolation_scale=interpolation_scale,
                    enable_context_parallelism=self.enable_context_parallelism,
                    sequence_parallel=self.sequence_parallel,
                )
                for _ in range(num_layers)
            ]
        )

        # Output head: each token is projected back to one patch of out_channels.
        patch_out_dim = self.patch_size_t * self.patch_size_h * self.patch_size_w * self.out_channels
        self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
        if norm_type != "ada_norm_single":
            self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
            self.proj_out_2 = nn.Linear(inner_dim, patch_out_dim)
        else:
            self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim ** 0.5)
            self.proj_out = nn.Linear(inner_dim, patch_out_dim)
            setattr(self.scale_shift_table, "sequence_parallel", self.sequence_parallel)

        self.adaln_single = None
        if norm_type == "ada_norm_single":
            self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=False)
            for param in self.adaln_single.parameters():
                setattr(param, "sequence_parallel", self.sequence_parallel)

        self.caption_projection = None
        if caption_channels is not None:
            self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)

    def forward(
        self,
        latents: torch.Tensor,
        timestep: Optional[torch.Tensor] = None,
        prompt: Optional[torch.Tensor] = None,
        video_mask: Optional[torch.Tensor] = None,
        prompt_mask: Optional[torch.Tensor] = None,
        added_cond_kwargs: Dict[str, torch.Tensor] = None,
        class_labels: Optional[torch.Tensor] = None,
        use_image_num: Optional[int] = 0,
        **kwargs
    ) -> torch.Tensor:
        """
        Args:
            latents: 5-D tensor; dim 0 is batch, dim 2 is frames + use_image_num and the last two
                dims are height and width (dim 1 is presumably channels — confirm with caller).
            timestep: Used to indicate denoising step. Fed to AdaLayerNormSingle when it is configured.
            prompt: Conditional embeddings for the cross attention layers.
            video_mask: Optional 4-D mask over (batch, frames + use_image_num, height, width);
                non-zero entries are kept.
            prompt_mask: Optional cross-attention mask applied to prompt; non-zero entries are kept.
            added_cond_kwargs: Ignored; it is overwritten with empty resolution/aspect_ratio entries.
            class_labels: Used to indicate class labels conditioning.
            use_image_num: The number of trailing images appended after the video frames.

        Returns:
            The un-patchified prediction; video and image outputs are concatenated on dim 2.
        """
        batch_size, _, t, _, _ = latents.shape
        frames = t - use_image_num
        # A single frame without extra images is treated as an image-only sample
        # (not under context parallelism).
        single_image = frames == 1 and use_image_num == 0 and not self.enable_context_parallelism

        vid_mask, img_mask = None, None
        prompt_vid_mask, prompt_img_mask = None, None
        # prompt_mask defaults to None, so guard before reshaping it.
        if prompt_mask is not None:
            prompt_mask = prompt_mask.view(batch_size, -1, prompt_mask.shape[-1])

        # 1. Turn the keep-masks into additive biases (keep -> 0, discard -> -10000).
        if video_mask is not None and video_mask.ndim == 4:
            video_mask = video_mask.to(self.dtype)
            vid_mask = video_mask[:, :frames]
            img_mask = video_mask[:, frames:]
            if vid_mask.numel() > 0:
                # Pad the first frame so the temporal length matches the pooling below.
                vid_mask_first_frame = vid_mask[:, :1].repeat(1, self.patch_size_t - 1, 1, 1)
                vid_mask = torch.cat([vid_mask_first_frame, vid_mask], dim=1)
                vid_mask = vid_mask.unsqueeze(1)
                # Max-pool with the patch size: a patch is kept if any element in it is kept.
                vid_mask = F.max_pool3d(vid_mask, kernel_size=(self.patch_size_t, self.patch_size_h, self.patch_size_w),
                                        stride=(self.patch_size_t, self.patch_size_h, self.patch_size_w))
                vid_mask = rearrange(vid_mask, 'b 1 t h w -> (b 1) 1 (t h w)')
            if img_mask.numel() > 0:
                img_mask = F.max_pool2d(img_mask, kernel_size=(self.patch_size_h, self.patch_size_w),
                                        stride=(self.patch_size_h, self.patch_size_w))
                img_mask = rearrange(img_mask, 'b i h w -> (b i) 1 (h w)')
            vid_mask = (1 - vid_mask.bool().to(self.dtype)) * -10000.0 if vid_mask.numel() > 0 else None
            img_mask = (1 - img_mask.bool().to(self.dtype)) * -10000.0 if img_mask.numel() > 0 else None
            if single_image:
                img_mask = vid_mask
                vid_mask = None

        if prompt_mask is not None and prompt_mask.ndim == 3:
            prompt_mask = (1 - prompt_mask.to(self.dtype)) * -10000.0
            in_t = prompt_mask.shape[1]
            prompt_vid_mask = prompt_mask[:, :in_t - use_image_num]
            prompt_vid_mask = rearrange(prompt_vid_mask, 'b 1 l -> (b 1) 1 l') if prompt_vid_mask.numel() > 0 else None
            prompt_img_mask = prompt_mask[:, in_t - use_image_num:]
            prompt_img_mask = rearrange(prompt_img_mask, 'b i l -> (b i) 1 l') if prompt_img_mask.numel() > 0 else None
            if single_image:
                prompt_img_mask = prompt_vid_mask
                prompt_vid_mask = None

        # 2. Convert the biases to boolean masks (True == masked out) and expand
        #    them to one row per query token. Every branch tolerates a missing mask.
        if vid_mask is not None:
            vid_mask = vid_mask.bool().repeat(1, vid_mask.shape[-1], 1)
            if prompt_vid_mask is not None:
                prompt_vid_mask = prompt_vid_mask.bool().repeat(1, vid_mask.shape[-2], 1)
        elif prompt_vid_mask is not None:
            prompt_vid_mask = prompt_vid_mask.bool()
        if img_mask is not None:
            img_mask = img_mask.bool().repeat(1, img_mask.shape[-1], 1)
            if prompt_img_mask is not None:
                prompt_img_mask = prompt_img_mask.bool().repeat(1, img_mask.shape[-2], 1)
        # For t2v, a mask that masks nothing is dropped entirely.
        if self.task == "t2v" and vid_mask is not None and not torch.any(vid_mask.bool()):
            vid_mask = None

        # 3. Patchify and embed the inputs.
        frames = ((frames - 1) // self.patch_size_t + 1) if frames % 2 == 1 else frames // self.patch_size_t
        height, width = latents.shape[-2] // self.patch_size_h, latents.shape[-1] // self.patch_size_w
        added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
        latents_vid, latents_img, prompt_vid, prompt_img, timestep_vid, timestep_img, \
            embedded_timestep_vid, embedded_timestep_img = \
            self._operate_on_patched_inputs(
                latents=latents,
                prompt=prompt,
                timestep=timestep,
                added_cond_kwargs=added_cond_kwargs,
                batch_size=batch_size,
                frames=frames,
                use_image_num=use_image_num
            )

        # The blocks work in (seq, batch, hidden) layout.
        # NOTE(review): only the video branch is rearranged; the image branch is
        # handed to the blocks as returned by pos_embed — confirm that path.
        latents_vid = rearrange(latents_vid, 'b s h -> s b h', b=batch_size).contiguous()
        prompt_vid = rearrange(prompt_vid, 'b s h -> s b h', b=batch_size).contiguous()

        # VideoDiTBlock reshapes the AdaLN-single timestep as (6, batch, dim).
        # The (batch, 6 * dim) embedding therefore has to be split per chunk and
        # transposed first; doing this only under context parallelism (as before)
        # mixed the six chunks across samples whenever batch > 1.
        if timestep_vid is not None:
            timestep_vid = timestep_vid.view(latents_vid.shape[1], 6, -1).transpose(0, 1).contiguous()
        if self.enable_context_parallelism and latents_vid is not None and prompt_vid is not None:
            latents_vid = split_forward_gather_backward(latents_vid, mpu.get_context_parallel_group(), dim=0,
                                                        grad_scale='down')
        if self.sequence_parallel:
            latents_vid = tensor_parallel.scatter_to_sequence_parallel_region(latents_vid)
            prompt_vid = tensor_parallel.scatter_to_sequence_parallel_region(prompt_vid)

        # Sizes travel as tensors so they can be passed through the checkpoint call.
        frames = torch.tensor(frames)
        height = torch.tensor(height)
        width = torch.tensor(width)

        rotary_pos_emb = self.rope(batch_size, frames, height, width, latents.device)
        if mpu.get_context_parallel_world_size() > 1:
            rotary_pos_emb = rotary_pos_emb.chunk(mpu.get_context_parallel_world_size(), dim=0)[mpu.get_context_parallel_rank()]

        # 4. Transformer blocks. Both execution paths receive the same arguments;
        #    the plain loop previously omitted rotary_pos_emb.
        vid_inputs = dict(
            video_mask=vid_mask,
            prompt=prompt_vid,
            prompt_mask=prompt_vid_mask,
            timestep=timestep_vid,
            class_labels=class_labels,
            frames=frames,
            height=height,
            width=width,
            rotary_pos_emb=rotary_pos_emb,
        )
        img_inputs = dict(
            video_mask=img_mask,
            prompt=prompt_img,
            prompt_mask=prompt_img_mask,
            timestep=timestep_img,
            class_labels=class_labels,
            frames=torch.tensor(1),
            height=height,
            width=width,
            rotary_pos_emb=rotary_pos_emb,
        )
        if self.recompute_granularity == "full":
            if latents_vid is not None:
                latents_vid = self._checkpointed_forward(latents_vid, **vid_inputs)
            if latents_img is not None:
                latents_img = self._checkpointed_forward(latents_img, **img_inputs)
        else:
            for block in self.videodit_blocks:
                if latents_vid is not None:
                    latents_vid = block(latents_vid, **vid_inputs)
                if latents_img is not None:
                    latents_img = block(latents_img, **img_inputs)

        # 5. Output head and un-patchify.
        output_vid, output_img = None, None
        if latents_vid is not None:
            output_vid = self._get_output_for_patched_inputs(
                latents=latents_vid,
                timestep=timestep_vid,
                class_labels=class_labels,
                embedded_timestep=embedded_timestep_vid,
                num_frames=frames,
                height=height,
                width=width,
            )
        if latents_img is not None:
            output_img = self._get_output_for_patched_inputs(
                latents=latents_img,
                timestep=timestep_img,
                class_labels=class_labels,
                embedded_timestep=embedded_timestep_img,
                num_frames=1,
                height=height,
                width=width,
            )
            if use_image_num != 0:
                output_img = rearrange(output_img, '(b i) c 1 h w -> b c i h w', i=use_image_num)

        if output_vid is not None and output_img is not None:
            output = torch.cat([output_vid, output_img], dim=2)
        elif output_vid is not None:
            output = output_vid
        elif output_img is not None:
            output = output_img
        return output

    def _get_block(self, layer_number):
        """Return the transformer block at index `layer_number`."""
        return self.videodit_blocks[layer_number]

    def _checkpointed_forward(
            self,
            latents,
            video_mask,
            prompt,
            prompt_mask,
            timestep,
            class_labels,
            frames,
            height,
            width,
            rotary_pos_emb):
        """
        Forward method with activation checkpointing.

        "uniform" checkpoints the stack in chunks of `recompute_num_layers`;
        "block" checkpoints only the first `recompute_num_layers` layers and
        runs the remaining ones normally.

        Raises:
            ValueError: For an unknown recompute method, or a non-positive
                `recompute_num_layers` in uniform mode.
        """

        def custom(start, end):
            def custom_forward(*args, **kwargs):
                x_, *args = args
                for index in range(start, end):
                    layer = self._get_block(index)
                    x_ = layer(x_, *args, **kwargs)
                return x_

            return custom_forward

        # Positional order matches VideoDiTBlock.forward after `latents`.
        block_inputs = (
            prompt,
            video_mask,
            prompt_mask,
            timestep,
            class_labels,
            frames,
            height,
            width,
            rotary_pos_emb,
        )

        if self.recompute_method == "uniform":
            chunk = self.recompute_num_layers
            if chunk is None or chunk < 1:
                # A step of 0 would never terminate; None cannot be added.
                raise ValueError("recompute_num_layers must be a positive integer for uniform recompute.")
            for start in range(0, self.num_layers, chunk):
                # Clamp the last chunk so a non-divisible layer count does not
                # index past the end of the block list.
                end = min(start + chunk, self.num_layers)
                latents = tensor_parallel.checkpoint(
                    custom(start, end),
                    self.distribute_saved_activations,
                    latents,
                    *block_inputs
                )
        elif self.recompute_method == "block":
            for layer_num in range(self.num_layers):
                if layer_num < self.recompute_num_layers:
                    latents = tensor_parallel.checkpoint(
                        custom(layer_num, layer_num + 1),
                        self.distribute_saved_activations,
                        latents,
                        *block_inputs
                    )
                else:
                    block = self._get_block(layer_num)
                    latents = block(
                        latents,
                        video_mask=video_mask,
                        prompt=prompt,
                        prompt_mask=prompt_mask,
                        timestep=timestep,
                        class_labels=class_labels,
                        frames=frames,
                        height=height,
                        width=width,
                        rotary_pos_emb=rotary_pos_emb
                    )
        else:
            raise ValueError("Invalid activation recompute method.")
        return latents

    @property
    def dtype(self) -> torch.dtype:
        """The dtype of the module (assuming that all the module parameters have the same dtype)."""
        # This property is read several times per forward pass, so only the
        # first parameter/buffer is pulled instead of materialising all of them.
        first_param = next(self.parameters(), None)
        if first_param is not None:
            return first_param.dtype
        first_buffer = next(self.buffers(), None)
        if first_buffer is not None:
            return first_buffer.dtype
        # A module with neither parameters nor buffers has no dtype of its own.
        return torch.get_default_dtype()

    def _operate_on_patched_inputs(self, latents, prompt, timestep, added_cond_kwargs, batch_size, frames,
                                   use_image_num):
        """
        Embed the latents and split timestep / prompt conditioning into video and image parts.

        Returns:
            A tuple (latents_vid, latents_img, prompt_vid, prompt_img, timestep_vid,
            timestep_img, embedded_timestep_vid, embedded_timestep_img); entries that
            do not apply are None.
        """
        latents_vid, latents_img = self.pos_embed(latents.to(self.dtype), frames)
        has_video = latents_vid is not None
        has_image = latents_img is not None

        timestep_vid, timestep_img = None, None
        embedded_timestep_vid, embedded_timestep_img = None, None
        prompt_vid, prompt_img = None, None

        if self.adaln_single is not None:
            timestep, embedded_timestep = self.adaln_single(timestep, added_cond_kwargs, batch_size=batch_size,
                                                            hidden_dtype=self.dtype)
            if not has_video:
                timestep_img = timestep
                embedded_timestep_img = embedded_timestep
            else:
                timestep_vid = timestep
                embedded_timestep_vid = embedded_timestep
                if has_image:
                    # Every image of a sample shares that sample's timestep.
                    timestep_img = repeat(timestep, 'b d -> (b i) d', i=use_image_num).contiguous()
                    embedded_timestep_img = repeat(embedded_timestep, 'b d -> (b i) d', i=use_image_num).contiguous()

        if self.caption_projection is not None:
            prompt = self.caption_projection(prompt)
            if not has_video:
                prompt_img = rearrange(prompt, 'b 1 l d -> (b 1) l d')
            else:
                # The first prompt entry belongs to the video, the rest to the images.
                prompt_vid = rearrange(prompt[:, :1], 'b 1 l d -> (b 1) l d')
                if has_image:
                    prompt_img = rearrange(prompt[:, 1:], 'b i l d -> (b i) l d')

        return latents_vid, latents_img, prompt_vid, prompt_img, timestep_vid, timestep_img, embedded_timestep_vid, embedded_timestep_img

    def _get_output_for_patched_inputs(self, latents, timestep, class_labels, embedded_timestep, num_frames,
                                       height=None, width=None):
        """
        Apply the final modulation and projection, then fold the tokens back into patches.

        Returns:
            Tensor reshaped to (-1, out_channels, num_frames * patch_size_t,
            height * patch_size_h, width * patch_size_w).
        """
        # Undo the sequence / context parallel splits before the output head.
        if self.sequence_parallel:
            latents = tensor_parallel.gather_from_sequence_parallel_region(latents,
                                                                           tensor_parallel_output_grad=False)
        batch_size = latents.shape[1]
        latents = rearrange(latents, 's b h -> b s h', b=batch_size).contiguous()
        if self.enable_context_parallelism:
            latents = gather_forward_split_backward(latents, mpu.get_context_parallel_group(), dim=1,
                                                    grad_scale='up')

        if self.norm_type != "ada_norm_single":
            conditioning = self.videodit_blocks[0].norm1.emb(timestep, class_labels, hidden_dtype=self.dtype)
            shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
            latents = self.norm_out(latents) * (1 + scale[:, None]) + shift[:, None]
            latents = self.proj_out_2(latents)
        else:
            shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
            latents = self.norm_out(latents)
            latents = latents * (1 + scale) + shift
            latents = self.proj_out(latents)
            latents = latents.squeeze(1)

        # Without AdaLN-single the token grid is taken to be square.
        if self.adaln_single is None:
            height = width = int(latents.shape[1] ** 0.5)

        # Un-patchify: (n, t, h, w, pt, ph, pw, c) -> (n, c, t*pt, h*ph, w*pw).
        latents = latents.reshape(shape=(-1, num_frames, height, width, self.patch_size_t,
                                         self.patch_size_h, self.patch_size_w, self.out_channels))
        latents = torch.einsum("nthwopqc->nctohpwq", latents)
        output = latents.reshape(shape=(-1, self.out_channels, num_frames * self.patch_size_t,
                                        height * self.patch_size_h, width * self.patch_size_w))
        return output
class VideoDiTBlock(nn.Module):
    """
    A basic DiT block for video generation: self-attention, cross-attention and
    a feed-forward layer, each with a residual connection. Inputs are handled in
    (seq, batch, hidden) layout (the batch size is read from dim 1).

    Args:
        dim: The number of channels in the input and output.
        num_heads: The number of heads to use for multi-head attention.
        head_dim: The number of channels in each head.
        dropout: The dropout probability to use.
        cross_attention_dim: The number of prompt dimensions to use.
        activation_fn: The name of the activation function used in the feed-forward layer.
        num_embeds_ada_norm: The number of diffusion steps used during training. Required when
            `norm_type` is 'ada_norm' or 'ada_norm_zero'.
        attention_bias: Whether to use bias for QKV in the attention layers.
        attention_out_bias: Whether to use bias in the attention output projections.
        fa_layout: Layout string forwarded to the self-attention module.
        norm_type: can be 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'.
        norm_elementwise_affine: Whether to use learnable elementwise affine parameters for normalization.
        norm_eps: The eps of the normalization.
        final_dropout: Forwarded to the feed-forward layer.
        positional_embeddings: 'sinusoidal' enables a sinusoidal positional embedding.
        num_positional_embeddings: Maximum sequence length of that embedding.
        ada_norm_continous_conditioning_embedding_dim: Conditioning size for 'ada_norm_continuous'.
        ada_norm_bias: Bias flag for 'ada_norm_continuous'.
        ff_inner_dim: Inner size of the feed-forward layer.
        ff_bias: Whether the feed-forward layer uses bias.
        interpolation_scale: The scale for interpolation (accepted but not used by this block).
        enable_context_parallelism: Stored on the block; not read in `forward`.
        sequence_parallel: Tagged onto `scale_shift_table` for the training framework.
        rope: Rotary position embedding module handed to the self-attention layer.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        head_dim: int,
        dropout: float = 0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        attention_out_bias: bool = True,
        fa_layout: str = "sbh",
        norm_type: str = "layer_norm",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        final_dropout: bool = False,
        positional_embeddings: Optional[str] = None,
        num_positional_embeddings: Optional[int] = None,
        ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
        ada_norm_bias: Optional[int] = None,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        interpolation_scale: Tuple[float] = None,
        enable_context_parallelism: bool = False,
        sequence_parallel: bool = False,
        rope=None,
    ):
        """
        Raises:
            ValueError: If `num_embeds_ada_norm` is missing for 'ada_norm'/'ada_norm_zero',
                or `positional_embeddings` is set without `num_positional_embeddings`.
        """
        super().__init__()
        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )
        self.norm_type = norm_type
        self.sequence_parallel = sequence_parallel

        if positional_embeddings and (num_positional_embeddings is None):
            raise ValueError("If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined.")
        if positional_embeddings == "sinusoidal":
            self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
        else:
            self.pos_embed = None

        # 1. Self-attention
        if norm_type == "ada_norm":
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif norm_type == "ada_norm_zero":
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        elif norm_type == "ada_norm_continuous":
            self.norm1 = AdaLayerNormContinuous(
                dim,
                ada_norm_continous_conditioning_embedding_dim,
                norm_elementwise_affine,
                norm_eps,
                ada_norm_bias,
                "rms_norm",
            )
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

        self.enable_context_parallelism = enable_context_parallelism
        self.self_atten = ParallelAttention(
            query_dim=dim,
            key_dim=None,
            num_attention_heads=num_heads,
            hidden_size=head_dim * num_heads,
            proj_q_bias=attention_bias,
            proj_k_bias=attention_bias,
            proj_v_bias=attention_bias,
            proj_out_bias=attention_out_bias,
            dropout=dropout,
            is_qkv_concat=False,
            attention_type=AttnType.self_attn,
            fa_layout=fa_layout,
            rope=rope
        )

        # 2. Cross-attention
        if norm_type == "ada_norm":
            self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif norm_type == "ada_norm_continuous":
            self.norm2 = AdaLayerNormContinuous(
                dim,
                ada_norm_continous_conditioning_embedding_dim,
                norm_elementwise_affine,
                norm_eps,
                ada_norm_bias,
                "rms_norm",
            )
        else:
            self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)

        self.cross_atten = ParallelMultiHeadAttentionSBH(
            query_dim=dim,
            key_dim=cross_attention_dim,
            num_heads=num_heads,
            head_dim=head_dim,
            dropout=dropout,
            proj_qkv_bias=attention_bias,
            proj_out_bias=attention_out_bias
        )

        # 3. Feed-forward
        if norm_type == "ada_norm_continuous":
            self.norm3 = AdaLayerNormContinuous(
                dim,
                ada_norm_continous_conditioning_embedding_dim,
                norm_elementwise_affine,
                norm_eps,
                ada_norm_bias,
                "layer_norm",
            )
        elif norm_type in ("ada_norm_zero", "ada_norm", "layer_norm"):
            self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
        else:
            # 'layer_norm_i2vgen' and 'ada_norm_single' have no third norm; the
            # attribute is still defined so it always exists on the block.
            self.norm3 = None

        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

        # 4. Learned shift/scale/gate offsets for AdaLN-single (6 rows: msa and mlp).
        if norm_type == "ada_norm_single":
            self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim ** 0.5)
            setattr(self.scale_shift_table, "sequence_parallel", self.sequence_parallel)

        # Stored by set_chunk_feed_forward; not consulted in forward.
        self._chunk_size = None
        self._chunk_dim = 0

    def forward(
        self,
        latents: torch.Tensor,
        prompt: Optional[torch.Tensor] = None,
        video_mask: Optional[torch.Tensor] = None,
        prompt_mask: Optional[torch.Tensor] = None,
        timestep: Optional[torch.Tensor] = None,
        class_labels: Optional[torch.Tensor] = None,
        frames: Optional[torch.Tensor] = None,
        height: Optional[torch.Tensor] = None,
        width: Optional[torch.Tensor] = None,
        rotary_pos_emb=None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """
        Run self-attention, cross-attention and feed-forward with residuals.

        Args:
            latents: Hidden states; dim 1 is read as the batch size.
            prompt: Keys/values for the cross-attention layer.
            video_mask: Optional self-attention mask, reshaped to (batch, 1, -1, last_dim).
            prompt_mask: Optional cross-attention mask, reshaped the same way.
            timestep: Conditioning for the adaptive norms; for 'ada_norm_single' it is
                reshaped to (6, batch, -1).
            class_labels: Class conditioning for 'ada_norm_zero'.
            frames, height, width: Accepted so the positional call signature used by the
                checkpointed path keeps working; they are not used by this block.
            rotary_pos_emb: Forwarded to the self-attention layer.
            added_cond_kwargs: Must hold "pooled_text_emb" for 'ada_norm_continuous'.

        Raises:
            ValueError: If `norm_type` is not one of the supported values.
        """
        # frames/height/width were previously converted with `.item()` and then
        # never used; that conversion crashed on the default None (and on plain
        # ints), so it has been dropped.
        batch_size = latents.shape[1]

        # 1. Self-attention
        if self.norm_type == "ada_norm":
            norm_latents = self.norm1(latents, timestep)
        elif self.norm_type == "ada_norm_zero":
            norm_latents, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                latents, timestep, class_labels, hidden_dtype=latents.dtype
            )
        elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
            norm_latents = self.norm1(latents)
        elif self.norm_type == "ada_norm_continuous":
            norm_latents = self.norm1(latents, added_cond_kwargs["pooled_text_emb"])
        elif self.norm_type == "ada_norm_single":
            # Each chunk is (1, batch, dim) and broadcasts over the sequence dim.
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                self.scale_shift_table[:, None] + timestep.reshape(6, batch_size, -1)
            ).chunk(6, dim=0)
            norm_latents = self.norm1(latents)
            norm_latents = norm_latents * (1 + scale_msa) + shift_msa
        else:
            raise ValueError("Incorrect norm used")

        if self.pos_embed is not None:
            norm_latents = self.pos_embed(norm_latents)

        if video_mask is not None:
            video_mask = video_mask.view(batch_size, 1, -1, video_mask.shape[-1])

        attn_output = self.self_atten(
            query=norm_latents,
            key=None,
            mask=video_mask,
            input_layout="sbh",
            rotary_pos_emb=rotary_pos_emb
        )
        if self.norm_type == "ada_norm_zero":
            # NOTE(review): unsqueeze(1) is the batch-first broadcast pattern;
            # confirm it is what is wanted for (seq, batch, hidden) inputs.
            attn_output = gate_msa.unsqueeze(1) * attn_output
        elif self.norm_type == "ada_norm_single":
            attn_output = gate_msa * attn_output

        latents = attn_output + latents
        if latents.ndim == 4:
            latents = latents.squeeze(1)

        # 2. Cross-attention
        if self.norm_type == "ada_norm":
            norm_latents = self.norm2(latents, timestep)
        elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
            norm_latents = self.norm2(latents)
        elif self.norm_type == "ada_norm_single":
            # AdaLN-single feeds the un-normalised states to cross-attention and
            # uses norm2 before the feed-forward layer instead.
            norm_latents = latents
        elif self.norm_type == "ada_norm_continuous":
            norm_latents = self.norm2(latents, added_cond_kwargs["pooled_text_emb"])
        else:
            raise ValueError("Incorrect norm")

        if self.pos_embed is not None and self.norm_type != "ada_norm_single":
            norm_latents = self.pos_embed(norm_latents)

        if prompt_mask is not None:
            prompt_mask = prompt_mask.view(batch_size, 1, -1, prompt_mask.shape[-1])

        attn_output = self.cross_atten(
            query=norm_latents,
            key=prompt,
            mask=prompt_mask
        )
        latents = attn_output + latents

        # 3. Feed-forward
        if self.norm_type == "ada_norm_continuous":
            norm_latents = self.norm3(latents, added_cond_kwargs["pooled_text_emb"])
        elif not self.norm_type == "ada_norm_single":
            # NOTE(review): norm3 is None for 'layer_norm_i2vgen', so this call
            # fails for that norm type — behaviour left as in the original.
            norm_latents = self.norm3(latents)

        if self.norm_type == "ada_norm_zero":
            norm_latents = norm_latents * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        if self.norm_type == "ada_norm_single":
            norm_latents = self.norm2(latents)
            norm_latents = norm_latents * (1 + scale_mlp) + shift_mlp

        ff_output = self.ff(norm_latents)
        if self.norm_type == "ada_norm_zero":
            ff_output = gate_mlp.unsqueeze(1) * ff_output
        elif self.norm_type == "ada_norm_single":
            ff_output = gate_mlp * ff_output

        latents = ff_output + latents
        if latents.ndim == 4:
            latents = latents.squeeze(1)
        return latents

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
        """Record the feed-forward chunk size and the dimension to chunk along."""
        self._chunk_size = chunk_size
        self._chunk_dim = dim