import json
import math
import os
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple
import random
import numpy as np
import torch_npu
import torch
import torch.nn.functional as F
import torch.nn.init as init
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.embeddings import Timesteps, TimestepEmbedding
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNormSingle
from diffusers.models.attention import AdaLayerNorm, FeedForward
from diffusers.models.attention_processor import Attention
from diffusers.utils import BaseOutput, is_torch_version
from einops import rearrange, repeat
from torch import nn
from diffusers.utils.torch_utils import maybe_allow_in_graph
from mindspeed_mm.models.common.embeddings import PatchEmbed2D_3DsincosPE
try:
from diffusers.models.embeddings import PixArtAlphaTextProjection
except ImportError:
from diffusers.models.embeddings import \
CaptionProjection as PixArtAlphaTextProjection
def zero_module(module):
    """Set every parameter of ``module`` to zero in place and return the module.

    The write goes through a detached alias of each parameter, so it shares
    the parameter's storage but is not recorded by autograd.
    """
    for param in module.parameters():
        alias = param.detach()
        alias.zero_()
    return module
@maybe_allow_in_graph
class ProxyTokensTransformerBlock(nn.Module):
    r"""
    Transformer block combining per-window attention with attention over pooled "proxy" tokens.

    `forward` expects tokens grouped per window with a batch-major leading axis,
    i.e. `hidden_states` of shape `(batch * windows, tokens_per_window, dim)`, and runs:

    1. self-attention over proxy tokens pooled from every window (`proxy_attn1`);
    2. if `cross_attention_dim` is set, cross-attention from all tokens to the proxy tokens
       (`pvisual_attn2`), added through the zero-initialised `linear_1_visual`;
    3. adaLN-single modulated self-attention inside each window (`attn1`);
    4. if `shift_window` is set, a second window attention on a rolled token grid (`attn3`),
       added through the zero-initialised `linear_2`;
    5. if `cross_attention_dim` is set, cross-attention to `encoder_hidden_states` (`attn2`);
    6. an adaLN-single modulated feed-forward.

    Only `norm_type="ada_norm_single"` is supported by `forward`; any other value raises.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross
            attention. When `None`, both cross-attention layers are disabled.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function of the feed-forward.
        num_embeds_ada_norm (`int`, *optional*): Accepted for signature compatibility; not used by this block.
        attention_bias (`bool`, *optional*, defaults to `False`): Whether the attentions contain a bias.
        upcast_attention (`bool`, *optional*): Forwarded to every `Attention` layer.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Whether the layer norms use learnable elementwise affine parameters.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            Must be `"ada_norm_single"` for `forward` to work.
        norm_eps (`float`, *optional*, defaults to 1e-5): Epsilon of the layer norms.
        shift_window (`bool`, *optional*, defaults to `False`): Enables the shifted-window attention branch.
        final_dropout (`bool`, *optional*, defaults to `False`): Forwarded to `FeedForward`.
        attention_type (`str`, *optional*, defaults to `"default"`):
            Accepted for signature compatibility; not used by this block.
        compress_ratios (sequence of 3 `int`, *optional*):
            Window size along (frames, height, width). Defaults to `[1, 8, 8]`, matching `PTDiTDiffuser`.
        proxy_compress_ratios (sequence of 3 `int`, *optional*):
            Proxy pooling size along (frames, height, width). Defaults to `[1, 4, 4]`, matching `PTDiTDiffuser`.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        norm_eps: float = 1e-5,
        shift_window: bool = False,
        final_dropout: bool = False,
        attention_type: str = "default",
        compress_ratios=None,
        proxy_compress_ratios=None,
    ):
        super().__init__()
        self.use_ada_layer_norm_single = norm_type == "ada_norm_single"

        # Same fallbacks PTDiTDiffuser applies, so the block can be built stand-alone.
        compress_ratios = [1, 8, 8] if compress_ratios is None else compress_ratios
        proxy_compress_ratios = [1, 4, 4] if proxy_compress_ratios is None else proxy_compress_ratios
        self.compress_ratios = compress_ratios
        self.proxy_compress_ratios = proxy_compress_ratios

        def build_norm():
            return nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

        def build_attention(cross_dim=None):
            return Attention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                cross_attention_dim=cross_dim,
                upcast_attention=upcast_attention,
            )

        # NOTE: the construction order below is kept as in the original so that
        # parameter registration order and RNG consumption stay unchanged.
        self.proxy_norm1 = build_norm()
        self.norm1 = build_norm()
        self.proxy_cross_norm = build_norm()
        self.proxy_attn1 = build_attention()
        self.attn1 = build_attention()

        # norm2 feeds the feed-forward in `forward`, so it must exist even when
        # cross-attention is disabled (it used to be None in that case and crash).
        self.norm2 = build_norm()
        if cross_attention_dim is not None:
            self.norm_before = build_norm()
            self.pvisual_attn2 = build_attention(cross_attention_dim)
            self.attn2 = build_attention(cross_attention_dim)
        else:
            self.norm_before = None
            # Previously left undefined, which made `forward` raise AttributeError.
            self.pvisual_attn2 = None
            self.attn2 = None

        self.shift_window = shift_window
        if self.shift_window:
            self.attn3 = build_attention()
            self.shift_window_norm = build_norm()
            # Zero-initialised so the shifted-window branch starts as a no-op.
            self.linear_2 = zero_module(nn.Linear(dim, dim))

        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
        self.hid_dim = dim
        # Zero-initialised so the proxy cross-attention branch starts as a no-op.
        self.linear_1_visual = zero_module(nn.Linear(self.hid_dim, dim))

        # NOTE(review): the original non-unit temporal branch computed
        # `compress_ratios[0] // compress_ratios[0]`, which is always 1, so both
        # branches yielded a temporal scale of 1. It presumably meant
        # `// proxy_compress_ratios[0]`; that would change the number of proxy
        # tokens, so the effective behaviour is kept here. Confirm before changing.
        self.scale_ratio = (
            1,
            compress_ratios[1] // proxy_compress_ratios[1],
            compress_ratios[2] // proxy_compress_ratios[2],
        )

        if self.use_ada_layer_norm_single:
            self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
            if self.shift_window:
                self.proxy_scale_shift_table = nn.Parameter(torch.randn(3, dim) / dim**0.5)

        # Feed-forward chunking is off until `set_chunk_feed_forward` is called.
        self._chunk_size = None
        self._chunk_dim = 0

    def _roll_windows(self, hidden_states, number_proxys, compress_ratios, num_frames, height, width, inverse):
        """Reassemble windows into the full token grid, roll it by half a window, and re-window.

        `inverse=False` rolls by minus half a window on each axis; `inverse=True`
        rolls by plus half a window, undoing the former.
        """
        channels = hidden_states.shape[-1]
        ratio_f, ratio_h, ratio_w = compress_ratios[0], compress_ratios[1], compress_ratios[2]
        windows_f, windows_h, windows_w = num_frames // ratio_f, height // ratio_h, width // ratio_w
        sign = 1 if inverse else -1
        shifts = (sign * (ratio_f // 2), sign * (ratio_h // 2), sign * (ratio_w // 2))

        grid = rearrange(hidden_states, "(b p) n c -> b p n c", p=number_proxys)
        grid = grid.reshape(-1, windows_f, windows_h, windows_w, ratio_f, ratio_h, ratio_w, channels)
        grid = rearrange(grid, "b f h w x y z c -> b (f x) (h y) (w z) c")
        grid = torch.roll(grid, shifts=shifts, dims=(1, 2, 3))
        return rearrange(
            grid,
            "b (f x) (h y) (w z) c -> (b f h w) (x y z) c",
            f=windows_f,
            h=windows_h,
            w=windows_w,
        )

    def window_shift(self, hidden_states, number_proxys, compress_ratios, num_frames, height, width):
        """Return `hidden_states` re-windowed after rolling the token grid by minus half a window."""
        return self._roll_windows(
            hidden_states, number_proxys, compress_ratios, num_frames, height, width, inverse=False
        )

    def window_Ishift(self, after_shift_hidden_states, number_proxys, compress_ratios, num_frames, height, width):
        """Inverse of `window_shift`: roll the token grid back by plus half a window."""
        return self._roll_windows(
            after_shift_hidden_states, number_proxys, compress_ratios, num_frames, height, width, inverse=True
        )

    def get_proxy_token(self, hidden_states, compress_ratios, scale_ratio, window_number, mode='first'):
        """Pool proxy tokens out of each window.

        Each window of `prod(compress_ratios)` tokens is split per axis into
        `compress_ratios[i] // scale_ratio[i]` coarse positions times
        `scale_ratio[i]` fine positions. `mode="first"` keeps the first coarse
        position, `mode="mean"` averages over all coarse positions; either way
        `prod(scale_ratio)` tokens remain per window.

        Returns:
            Tensor of shape `(batch, window_number * prod(scale_ratio), dim)`.

        Raises:
            ValueError: if `mode` is neither `"first"` nor `"mean"`.
        """
        channels = hidden_states.shape[-1]
        coarse_f = compress_ratios[0] // scale_ratio[0]
        coarse_h = compress_ratios[1] // scale_ratio[1]
        coarse_w = compress_ratios[2] // scale_ratio[2]
        grouped = hidden_states.reshape(
            hidden_states.shape[0],
            coarse_f, scale_ratio[0],
            coarse_h, scale_ratio[1],
            coarse_w, scale_ratio[2],
            channels,
        )
        grouped = rearrange(grouped, "b f x h y w z c -> b (f h w) (x y z) c")
        if mode == 'first':
            pooled = grouped[:, 0, :, :]
        elif mode == 'mean':
            pooled = grouped.mean(1)
        else:
            raise ValueError(f"Unsupported proxy token mode: {mode!r} (expected 'first' or 'mean')")
        return rearrange(pooled, "(b p) n c -> b (p n) c", p=window_number)

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
        """Enable (or, with `chunk_size=None`, disable) chunked feed-forward along axis `dim`."""
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        shift_window_attention_mask: Optional[torch.FloatTensor] = None,
        compress_ratios: Optional[list] = None,
        proxy_compress_ratios: Optional[list] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        num_frames: int = 16,
        height: int = 32,
        width: int = 32,
    ) -> torch.FloatTensor:
        """Run the block.

        Args:
            hidden_states: `(batch * windows, tokens_per_window, dim)`, batch-major.
            attention_mask: forwarded to `proxy_attn1` and `attn1`.
            encoder_hidden_states: conditioning for `attn2` (used only with cross-attention).
            encoder_attention_mask: mask for `attn2`.
            timestep: adaLN-single modulation; reshaped to `(batch, 6, dim)`.
            shift_window_attention_mask: mask for `attn3` (used only with `shift_window`).
            compress_ratios: window size per axis; falls back to the constructor value.
            proxy_compress_ratios: accepted for interface compatibility; not used here.
            cross_attention_kwargs: extra kwargs for the attention layers; `"scale"` is
                also passed to the feed-forward and `"gligen"` is dropped.
            num_frames, height, width: token-grid size used to derive the window count.

        Returns:
            Tensor with the same shape as `hidden_states`.

        Raises:
            ValueError: if the block was not built with `norm_type="ada_norm_single"`,
                or if the chunked dimension is not divisible by the chunk size.
        """
        if not self.use_ada_layer_norm_single:
            raise ValueError("Incorrect norm used")
        if compress_ratios is None:
            compress_ratios = self.compress_ratios

        window_number = (
            (num_frames // compress_ratios[0])
            * (height // compress_ratios[1])
            * (width // compress_ratios[2])
        )
        batch_size = hidden_states.shape[0] // window_number

        # Tokens are laid out as "(b w) n c", so each sample's modulation must be
        # repeated for its own contiguous run of windows. `repeat_interleave`
        # does that; the former `.repeat(window_number, 1, 1)` tiled whole
        # batches and mixed up samples whenever batch_size > 1.
        modulation = timestep.reshape(batch_size, 6, -1).repeat_interleave(window_number, dim=0)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.scale_shift_table[None] + modulation
        ).chunk(6, dim=1)
        if self.shift_window:
            shift_sw_msa, scale_sw_msa, gate_sw_msa = (
                self.proxy_scale_shift_table[None] + modulation[:, :3, :]
            ).chunk(3, dim=1)

        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
        # "gligen" is not understood by the attention layers; drop it.
        cross_attention_kwargs.pop("gligen", None)

        # 1. Self-attention over the pooled proxy tokens.
        proxy_hidden_states = self.get_proxy_token(
            hidden_states, compress_ratios, self.scale_ratio, window_number, mode="mean"
        )
        proxy_attn_output = self.proxy_attn1(
            self.proxy_norm1(proxy_hidden_states),
            encoder_hidden_states=None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )
        proxy_hidden_states = proxy_attn_output + proxy_hidden_states
        if proxy_hidden_states.ndim == 4:
            proxy_hidden_states = proxy_hidden_states.squeeze(1)

        # 2. Every token cross-attends to the proxy tokens.
        if self.pvisual_attn2 is not None:
            all_tokens = rearrange(hidden_states, "(b w) n c -> b (w n) c", w=window_number)
            pvisual_attn_output = self.pvisual_attn2(
                self.norm_before(all_tokens),
                encoder_hidden_states=self.proxy_cross_norm(proxy_hidden_states),
                attention_mask=None,
                **cross_attention_kwargs,
            )
            pvisual_attn_output = rearrange(pvisual_attn_output, "b (w n) c -> (b w) n c", w=window_number)
            hidden_states = hidden_states + self.linear_1_visual(pvisual_attn_output)

        # 3. Modulated self-attention inside each window.
        norm_hidden_states = self.norm1(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
        norm_hidden_states = norm_hidden_states.squeeze(1)
        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )
        hidden_states = gate_msa * attn_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        # 4. Shifted-window attention.
        if self.shift_window:
            shifted = self.window_shift(hidden_states, window_number, compress_ratios, num_frames, height, width)
            shifted = self.shift_window_norm(shifted)
            shifted = shifted * (1 + scale_sw_msa) + shift_sw_msa
            shifted_attn_output = self.attn3(
                shifted,
                encoder_hidden_states=None,
                attention_mask=shift_window_attention_mask,
                **cross_attention_kwargs,
            )
            attn_output = self.window_Ishift(
                shifted_attn_output, window_number, compress_ratios, num_frames, height, width
            )
            attn_output = gate_sw_msa * attn_output
            hidden_states = self.linear_2(attn_output) + hidden_states

        # 5. Cross-attention to the encoder states over the full sequence.
        if self.attn2 is not None:
            hidden_states = rearrange(hidden_states, "(b p) n c -> b (p n) c", p=window_number)
            # No normalisation is applied to the query here.
            attn_output = self.attn2(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states
            hidden_states = rearrange(hidden_states, "b (p n) c -> (b p) n c", p=window_number)

        # 6. Modulated feed-forward, optionally chunked to save memory.
        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

        if self._chunk_size is not None:
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(
                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
                )
            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            ff_output = torch.cat(
                [
                    self.ff(hid_slice, scale=lora_scale)
                    for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
                ],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states, scale=lora_scale)

        ff_output = gate_mlp * ff_output
        hidden_states = ff_output + hidden_states
        return hidden_states
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
    """
    Timestep embedding with optional resolution / aspect-ratio conditioning, for PixArt-Alpha.

    Unlike a plain embedder, the last linear layer of the resolution and
    aspect-ratio embedders is zero-initialised here.

    Reference:
    https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
    """

    def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
        super().__init__()

        self.outdim = size_emb_dim
        self.use_additional_conditions = use_additional_conditions

        # Sinusoidal projection followed by an MLP for the diffusion timestep.
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

        if not use_additional_conditions:
            return

        # One shared sinusoidal projection, one MLP per size condition.
        self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
        self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
        # Zero the output layers so the size conditions contribute nothing at initialisation.
        for embedder in (self.resolution_embedder, self.aspect_ratio_embedder):
            embedder.linear_2 = zero_module(embedder.linear_2)

    def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
        """Return the timestep embedding, plus the concatenated size embeddings when enabled.

        `resolution` and `aspect_ratio` are only read when
        `use_additional_conditions` is true.
        """
        projected = self.time_proj(timestep).to(dtype=hidden_dtype)
        conditioning = self.timestep_embedder(projected)
        if not self.use_additional_conditions:
            return conditioning

        size_embeddings = []
        for values, embedder in (
            (resolution, self.resolution_embedder),
            (aspect_ratio, self.aspect_ratio_embedder),
        ):
            sinusoid = self.additional_condition_proj(values.flatten()).to(hidden_dtype)
            size_embeddings.append(embedder(sinusoid).reshape(batch_size, -1))
        return conditioning + torch.cat(size_embeddings, dim=1)
class AdaLayerNormSingle(nn.Module):
    r"""
    Norm layer adaptive layer norm single (adaLN-single).

    As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).

    NOTE: this definition shadows the `AdaLayerNormSingle` imported from
    `diffusers.models.normalization` at the top of the file; it uses the local
    `PixArtAlphaCombinedTimestepSizeEmbeddings`.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        use_additional_conditions (`bool`): To use additional conditions for normalization or not.
    """

    def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
        super().__init__()

        self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
            embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
        )

        self.silu = nn.SiLU()
        # Produces the six modulation vectors in one projection.
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)

    def forward(
        self,
        timestep: torch.Tensor,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        batch_size: Optional[int] = None,
        hidden_dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed `timestep` and project it to the adaLN modulation.

        Args:
            timestep: diffusion timesteps passed to the embedder.
            added_cond_kwargs: keyword arguments for the embedder (`resolution`,
                `aspect_ratio`). `None` is treated as "no additional conditions".
            batch_size: forwarded to the embedder.
            hidden_dtype: forwarded to the embedder.

        Returns:
            `(modulation, embedded_timestep)` where `modulation` is
            `linear(silu(embedded_timestep))`.
        """
        if added_cond_kwargs is None:
            # `**None` would raise TypeError; the embedder requires both keys.
            added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
        embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
        return self.linear(self.silu(embedded_timestep)), embedded_timestep
class PTDiTDiffuser(ModelMixin, ConfigMixin):
    """
    Video diffusion transformer built from a stack of `ProxyTokensTransformerBlock`s.

    The input latent `(batch, channel, frames, height, width)` is patch-embedded,
    regrouped into windows of `compress_ratios` tokens, processed by the
    transformer blocks, and projected back to
    `(batch, out_channels, frames, height, width)`.

    Only `norm_type="ada_norm_single"` is fully supported: the transformer blocks
    raise for any other value.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        num_heads: int = 16,
        head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        frame: Optional[int] = None,
        patch_size: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        upcast_attention: bool = False,
        norm_type: str = "layer_norm",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        attention_type: str = "default",
        caption_channels: int = None,
        shift_window: bool = False,
        compress_ratios: Optional[list] = None,
        proxy_compress_ratios: Optional[list] = None,
        **kwargs
    ):
        """Build the model.

        Raises:
            ValueError: if `sample_size` is not provided.
        """
        super().__init__()
        self.num_attention_heads = num_heads
        self.attention_head_dim = head_dim
        inner_dim = num_heads * head_dim
        self.norm_type = norm_type
        self.shift_window = shift_window
        self.compress_ratios = [1, 8, 8] if compress_ratios is None else compress_ratios
        self.proxy_compress_ratios = [1, 4, 4] if proxy_compress_ratios is None else proxy_compress_ratios
        # Report the ratios actually in effect (the raw arguments may be None).
        print(f'compress_ratios: {self.compress_ratios}')
        print(f'proxy_compress_ratios: {self.proxy_compress_ratios}')

        if sample_size is None:
            raise ValueError("PTDiTDiffuser over patched input must provide sample_size")

        # 1. Input embedding.
        self.height = sample_size
        self.width = sample_size
        self.sample_size = sample_size
        self.patch_size = patch_size
        interpolation_scale = max(self.sample_size // 64, 1)
        self.pos_embed = PatchEmbed2D_3DsincosPE(
            height=sample_size,
            width=sample_size,
            frame=frame,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=inner_dim,
            interpolation_scale=interpolation_scale,
        )

        # 2. Transformer blocks.
        self.transformer_blocks = nn.ModuleList(
            [
                ProxyTokensTransformerBlock(
                    inner_dim,
                    num_heads,
                    head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    upcast_attention=upcast_attention,
                    norm_type=norm_type,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                    shift_window=shift_window,
                    attention_type=attention_type,
                    compress_ratios=self.compress_ratios,
                    proxy_compress_ratios=self.proxy_compress_ratios,
                )
                for _ in range(num_layers)
            ]
        )

        # 3. Output head.
        self.out_channels = in_channels if out_channels is None else out_channels
        self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
        self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
        # Number of tokens in one window.
        self.compress_scale = self.compress_ratios[0] * self.compress_ratios[1] * self.compress_ratios[2]
        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)

        # 4. Timestep (and optional size) conditioning.
        self.adaln_single = None
        self.use_additional_conditions = False
        if norm_type == "ada_norm_single":
            self.use_additional_conditions = self.sample_size == 128
            self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)

        self.caption_projection = None
        if caption_channels is not None:
            self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)

        self.gradient_checkpointing = True
        # Cached shifted-window mask and the inputs it was built for.
        self.shift_window_attention_mask = None
        self._shift_window_mask_key = None

    def _set_gradient_checkpointing(self, module, value=False):
        """Toggle `gradient_checkpointing` on `module` if it has that attribute."""
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def get_shift_window_attention_mask(self, video_length, height, weight, compress_ratios, c_dtype, c_device):
        """Build and cache the shifted-window attention mask.

        Tokens of the rolled grid are labelled by region; two tokens of the same
        window may attend to each other only when their labels match. The result
        is stored in `self.shift_window_attention_mask` with shape
        `(windows, tokens_per_window, tokens_per_window)`, `0` where attention is
        allowed and `-1` elsewhere (the caller scales it).

        `weight` is the grid width (parameter name kept for compatibility).
        """
        img_mask = torch.zeros((1, video_length, height, weight))
        shift_size = [compress_ratios[0] // 2, compress_ratios[1] // 2, compress_ratios[2] // 2]

        def region_slices(axis):
            # Bulk, the part of the last window before the shift, and the wrapped-around tail.
            return (
                slice(0, -compress_ratios[axis]),
                slice(-compress_ratios[axis], -shift_size[axis]),
                slice(-shift_size[axis], None),
            )

        f_slices, h_slices, w_slices = region_slices(0), region_slices(1), region_slices(2)

        cnt = 0
        for f in f_slices:
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, f, h, w] = cnt
                    cnt += 1

        img_mask = img_mask.view(
            1,
            video_length // compress_ratios[0], compress_ratios[0],
            height // compress_ratios[1], compress_ratios[1],
            weight // compress_ratios[2], compress_ratios[2],
        )
        mask_windows = rearrange(img_mask, "b f x h y w z -> (b f h w) (x y z)")
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        # Differing labels -> -1; equal labels are already 0.
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-1.0))
        self.shift_window_attention_mask = attn_mask.to(device=c_device, dtype=c_dtype)
        self._shift_window_mask_key = (video_length, height, weight, tuple(compress_ratios), c_dtype, c_device)

    def extract_proxy_tokens(self, hidden_states, compress_ratios):
        """Split a `(b, c, f, h, w)` grid into windows, returning `(b * windows, c, x, y, z)`."""
        b, c, f, h, w = hidden_states.shape
        n_f, n_h, n_w = f // compress_ratios[0], h // compress_ratios[1], w // compress_ratios[2]
        hidden_states = hidden_states.reshape(
            b, c, n_f, compress_ratios[0], n_h, compress_ratios[1], n_w, compress_ratios[2]
        )
        return rearrange(hidden_states, "b c f x h y w z -> (b f h w) c x y z")

    def set_input_tensor(self, input_tensor):
        """
        Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func
        """
        self.input_tensor = input_tensor

    def zero_init_shift_window_attention_linear(self):
        """Zero the `linear_2` projection of every block that has one.

        Blocks built without `shift_window` have no `linear_2` and are skipped.
        """
        for block in self.transformer_blocks:
            linear = getattr(block, "linear_2", None)
            if linear is None:
                continue
            try:
                zero_module(linear)
            except ValueError as e:
                print(f"An error occurred while zeroing module: {e}")

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: Optional[torch.LongTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        added_cond_kwargs: Dict[str, torch.Tensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs
    ):
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(batch, channel, frames, height, width)`):
                Input latent.
            timestep ( `torch.LongTensor`, *optional*):
                Denoising step, embedded by `adaln_single`.
            encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for the cross attention layers.
            encoder_attention_mask ( `torch.Tensor`, *optional*):
                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:

                    * Mask `(batch, sequence_length)` True = keep, False = discard.
                    * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
                above. This bias will be added to the cross-attention scores.
            added_cond_kwargs (`Dict[str, torch.Tensor]`, *optional*):
                `resolution` / `aspect_ratio` conditions; required when additional conditions are enabled.
            cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
                A kwargs dictionary that if specified is passed along to the attention layers.
            attention_mask ( `torch.Tensor`, *optional*):
                A 2-D mask (`1` = keep, `0` = discard) is converted into an additive bias; other ranks are
                passed through unchanged.

        Returns:
            `torch.float32` tensor of shape
            `(batch, out_channels, frames, height, width)`.

        Raises:
            ValueError: if additional conditions are enabled and `added_cond_kwargs` is `None`.
        """
        dtype = self.pos_embed.proj.weight.dtype
        hidden_states = hidden_states.to(dtype)
        if timestep is not None:
            timestep = timestep.to(dtype)
        if encoder_hidden_states is not None:
            encoder_hidden_states = encoder_hidden_states.to(dtype)

        if added_cond_kwargs is None:
            # Must be checked before defaulting, otherwise the check can never fire.
            if self.adaln_single is not None and self.use_additional_conditions:
                raise ValueError(
                    "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
                )
            added_cond_kwargs = {"resolution": None, "aspect_ratio": None}

        # Convert 2-D keep/discard masks into additive biases with a broadcastable query axis.
        if attention_mask is not None and attention_mask.ndim == 2:
            attention_mask = (1 - attention_mask.to(dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # Token-grid size and window counts.
        video_length = hidden_states.shape[-3]
        height = hidden_states.shape[-2] // self.patch_size
        width = hidden_states.shape[-1] // self.patch_size
        compress_f = video_length // self.compress_ratios[0]
        compress_h = height // self.compress_ratios[1]
        compress_w = width // self.compress_ratios[2]
        number_proxy_tokens = compress_f * compress_h * compress_w

        # 1. Patch embedding.
        hidden_states = rearrange(hidden_states, "b c f h w -> b f c h w")
        hidden_states = self.pos_embed(hidden_states)
        # Needed below regardless of the norm type.
        batch_size = hidden_states.shape[0] // video_length

        if self.adaln_single is not None:
            timestep, embedded_timestep = self.adaln_single(
                timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
            )

        # 2. Regroup tokens per window: (b * windows, tokens_per_window, c).
        hidden_states = rearrange(hidden_states, "(b f) (h w) c -> b c f h w", f=video_length, h=height, w=width)
        hidden_states = self.extract_proxy_tokens(hidden_states, self.compress_ratios)
        hidden_states = hidden_states.flatten(2).transpose(1, 2)

        # 3. Shifted-window mask: only needed when the blocks use it, and rebuilt
        #    whenever the grid size, dtype or device differs from the cached one.
        shift_window_attention_mask = None
        if self.shift_window:
            mask_key = (
                video_length, height, width, tuple(self.compress_ratios),
                hidden_states.dtype, hidden_states.device,
            )
            if self.shift_window_attention_mask is None or getattr(self, "_shift_window_mask_key", None) != mask_key:
                self.get_shift_window_attention_mask(
                    video_length, height, width, self.compress_ratios, hidden_states.dtype, hidden_states.device
                )
            shift_window_attention_mask = self.shift_window_attention_mask.unsqueeze(0).repeat(batch_size, 1, 1, 1)
            shift_window_attention_mask = shift_window_attention_mask.reshape(
                batch_size * number_proxy_tokens, self.compress_scale, self.compress_scale
            ) * 10000.0

        # 4. Caption projection.
        if self.caption_projection is not None and encoder_hidden_states is not None:
            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])

        # 5. Transformer blocks.
        use_checkpointing = self.training and self.gradient_checkpointing
        ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
        for block in self.transformer_blocks:
            if use_checkpointing:
                # Positional order must match ProxyTokensTransformerBlock.forward.
                hidden_states = torch.utils.checkpoint.checkpoint(
                    block,
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    timestep,
                    shift_window_attention_mask,
                    self.compress_ratios,
                    self.proxy_compress_ratios,
                    cross_attention_kwargs,
                    video_length,
                    height,
                    width,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = block(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    timestep=timestep,
                    shift_window_attention_mask=shift_window_attention_mask,
                    compress_ratios=self.compress_ratios,
                    proxy_compress_ratios=self.proxy_compress_ratios,
                    cross_attention_kwargs=cross_attention_kwargs,
                    num_frames=video_length,
                    height=height,
                    width=width,
                )

        # 6. Output modulation and projection.
        if self.norm_type == "ada_norm_single":
            shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
            hidden_states = self.norm_out(hidden_states)
            hidden_states = rearrange(hidden_states, "(b p) n c -> b (p n) c", p=number_proxy_tokens)
            hidden_states = hidden_states * (1 + scale) + shift
            hidden_states = self.proj_out(hidden_states)
            hidden_states = hidden_states.squeeze(1)

        # 7. Undo the window grouping, then unpatchify.
        hidden_states = rearrange(hidden_states, "b (p n) c -> b p n c", p=number_proxy_tokens)
        hidden_states = hidden_states.reshape(
            -1, compress_f, compress_h, compress_w,
            self.compress_ratios[0], self.compress_ratios[1], self.compress_ratios[2],
            self.patch_size * self.patch_size * self.out_channels,
        )
        hidden_states = rearrange(hidden_states, "b f h w x y z c -> b (f x) (h y) (w z) c")
        hidden_states = hidden_states.reshape(
            shape=(-1, video_length, height, width, self.patch_size, self.patch_size, self.out_channels)
        )
        hidden_states = torch.einsum("nfhwpqc->ncfhpwq", hidden_states)
        output = hidden_states.reshape(
            shape=(-1, self.out_channels, video_length, height * self.patch_size, width * self.patch_size)
        )
        output = output.to(torch.float32)
        return output