import random
from typing import Any, Dict, Optional, Tuple, List
from dataclasses import dataclass
from einops import rearrange, repeat
import torch
from torch import nn
import torch.nn.functional as F
from diffusers.utils import is_torch_version
from megatron.legacy.model.rms_norm import RMSNorm
from megatron.core import mpu, tensor_parallel
from megatron.training import get_args
from megatron.training.arguments import core_transformer_config_from_args
from mindspeed_mm.models.common.module import MultiModalModule
from mindspeed_mm.models.common.embeddings import PatchEmbed2D, CombinedTimestepTextProjEmbeddings
from mindspeed_mm.models.common.ffn import FeedForward
from mindspeed_mm.models.common.normalize import FP32LayerNorm
from mindspeed_mm.models.common.attention import MultiHeadSparseMMAttentionSBH
from mindspeed_mm.models.common.communications import split_forward_gather_backward, gather_forward_split_backward
def create_custom_forward(module, return_dict=None):
    """Wrap ``module`` in a plain positional-argument callable.

    Args:
        module: Any callable (typically an ``nn.Module``).
        return_dict: When not ``None``, forwarded to ``module`` as the
            ``return_dict`` keyword argument on every call.

    Returns:
        A function ``custom_forward(*inputs)`` that calls ``module`` with the
        given positional inputs.
    """

    def custom_forward(*inputs):
        # Only pass ``return_dict`` through when the caller explicitly set it.
        if return_dict is None:
            return module(*inputs)
        return module(*inputs, return_dict=return_dict)

    return custom_forward
def zero_initialized_skip_connection(module_cls):
    """Return a factory that builds ``module_cls`` layers initialised as a pass-through.

    The produced layer maps an input of width ``2 * out_features`` to
    ``out_features``: the weight columns for the first half of the input are
    set to the identity, the columns for the second half to zero, and the bias
    (if any) to zero. The freshly built layer therefore returns the first half
    of its input unchanged.

    Args:
        module_cls: ``nn.Linear`` or a subclass of it.

    Returns:
        A callable with the same arguments as ``module_cls`` that returns the
        initialised module.

    Raises:
        TypeError: If ``module_cls`` is not a subclass of ``nn.Linear``.
        ValueError: (from the returned factory) if ``in_features`` is not
            exactly twice ``out_features``.
    """
    if not issubclass(module_cls, nn.Linear):
        raise TypeError(f"Expected module_cls to be nn.Linear, but got {module_cls.__name__}.")

    def zero_init(*args, **kwargs):
        layer = module_cls(*args, **kwargs)
        n_in, n_out = layer.in_features, layer.out_features
        if n_in != 2 * n_out:
            raise ValueError("Expected in_features to be twice out_features, "
                             f"but got in_features={n_in} and out_features={n_out}.")

        weight = layer.weight.data
        # Left half: identity (keeps the first input chunk); right half: zeros.
        weight[:, :n_out] = torch.eye(n_out, dtype=layer.weight.dtype)
        weight[:, n_out:] = 0.0
        if layer.bias is not None:
            layer.bias.data.fill_(0.0)
        return layer

    return zero_init
def maybe_clamp_tensor(x, max_value=65504.0, min_value=-65504.0, training=True):
if not training and x.dtype == torch.float16:
x.nan_to_num_(posinf=max_value, neginf=min_value).clamp_(min_value, max_value)
return x
@dataclass
class BlockForwardInputs:
    """Per-forward values shared by every transformer block call."""

    # Timestep embedding handed to each block's modulation layers.
    embedded_timestep: torch.Tensor
    # Latent grid extent after patchification, as computed by the model's forward.
    frames: int
    height: int
    width: int
    # Rotary embedding tensors produced by the RoPE module and forwarded to attention.
    video_rotary_emb: Tuple[torch.Tensor]
@dataclass
class PositionParams:
    """Arguments consumed by ``PositionGetter3D.__call__``."""

    # Batch size.
    b: int
    # Temporal, vertical and horizontal extent of the position grid.
    t: int
    h: int
    w: int
    # Device on which the position tensors are created.
    device: str
    # When true (and uniform RoPE is enabled) the getter samples random offsets.
    training: bool = True
class SparseUMMDiT(MultiModalModule):
    """
    A video DiT model for video generation built from stages of ``SparseMMDiTBlock``.

    The stages listed in ``num_layers`` are split into an encoder half, a single
    middle stage and a decoder half. With ``skip_connection`` enabled, every
    decoder stage first fuses its input with the output saved from an encoder
    stage (most recently saved first).

    Args:
        num_heads: The number of heads to use for multi-head attention.
        head_dim: The number of channels in each head. The hidden size is
            ``num_heads * head_dim``.
        in_channels: The number of channels in the input.
        out_channels: The number of channels in the output. Defaults to
            ``in_channels`` when ``None``.
        num_layers: Number of blocks per stage. Must have odd length and only
            even entries. Defaults to ``[2, 4, 8, 4, 2]``.
        sparse_n: Sparse factor per stage; must have the same length as
            ``num_layers``. Defaults to ``[1, 4, 16, 4, 1]`` and is replaced by
            all ones when ``sparse1d`` is false.
        double_ff: Whether blocks use a separate feed-forward for the text
            stream (always disabled for the very last block).
        dropout: The dropout probability to use.
        attention_bias: Whether to use bias in the blocks' attention projections.
        sample_size_h / sample_size_w / sample_size_t: Maximum extents handed
            to ``PositionGetter3D``.
        patch_size_thw: The shape of the patches; index 0 is the temporal patch
            size and index 1 the spatial patch size.
        activation_fn: The name of the activation function used in the blocks.
        norm_elementwise_affine: Whether to use learnable elementwise affine
            parameters for normalization.
        norm_eps: The eps of the normalization.
        caption_channels: Input width of the caption projection layer.
        interpolation_scale_h / _w / _t: The RoPE interpolation scales.
        sparse1d: Whether sparse 1D attention is enabled.
        pooled_projection_dim: Width of the pooled text projection.
        timestep_embed_dim: Width of the timestep embedding.
        norm_cls: ``'rms_norm'`` or ``'fp32_layer_norm'``.
        skip_connection: Whether encoder outputs are fused into decoder stages.
        explicit_uniform_rope: Forwarded to ``PositionGetter3D``.
        skip_connection_zero_init: Use identity/zero initialised skip linears
            (without a preceding norm) instead of norm + default linear.

    Raises:
        ValueError: For ``selective`` recompute granularity, an unknown
            ``norm_cls`` or inconsistent ``num_layers`` / ``sparse_n``.
        NotImplementedError: If ``distribute_saved_activations`` is set.
    """

    def __init__(
        self,
        num_heads: int = 16,
        head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: Optional[List[int]] = None,
        sparse_n: Optional[List[int]] = None,
        double_ff: bool = False,
        dropout: float = 0.0,
        attention_bias: bool = True,
        sample_size_h: Optional[int] = None,
        sample_size_w: Optional[int] = None,
        sample_size_t: Optional[int] = None,
        patch_size_thw: Optional[Tuple[int]] = None,
        activation_fn: str = "gelu-approximate",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-6,
        caption_channels: int = None,
        interpolation_scale_h: float = 1.0,
        interpolation_scale_w: float = 1.0,
        interpolation_scale_t: float = 1.0,
        sparse1d: bool = False,
        pooled_projection_dim: int = 1024,
        timestep_embed_dim: int = 512,
        norm_cls: str = 'fp32_layer_norm',
        skip_connection: bool = False,
        explicit_uniform_rope: bool = False,
        skip_connection_zero_init: bool = True,
        **kwargs
    ):
        super().__init__(config=None)
        args = get_args()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.sequence_parallel = args.sequence_parallel
        self.gradient_checkpointing = True
        self.recompute_granularity = args.recompute_granularity
        self.distribute_saved_activations = args.distribute_saved_activations
        self.recompute_num_layers = args.recompute_num_layers
        if self.recompute_granularity == "selective":
            raise ValueError("recompute_granularity does not support selective mode in VideoDiT")
        if self.distribute_saved_activations:
            raise NotImplementedError("distribute_saved_activations is currently not supported")

        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        hidden_size = num_heads * head_dim
        self.num_layers = num_layers if num_layers is not None else [2, 4, 8, 4, 2]
        self.sparse_n = sparse_n if sparse_n is not None else [1, 4, 16, 4, 1]
        self.patch_size_t = patch_size_thw[0]
        self.patch_size = patch_size_thw[1]
        self.skip_connection = skip_connection
        self.skip_connection_zero_init = skip_connection_zero_init

        if norm_cls == 'rms_norm':
            self.norm_cls = RMSNorm
        elif norm_cls == 'fp32_layer_norm':
            self.norm_cls = FP32LayerNorm
        else:
            # Fail here instead of with an AttributeError on first use.
            raise ValueError(f"Unsupported norm_cls: {norm_cls}")

        # Validate the resolved lists so the ``None`` defaults work as well.
        num_stages = len(self.num_layers)
        if num_stages != len(self.sparse_n):
            raise ValueError("num_layers and sparse_n must have the same length")
        if num_stages % 2 != 1:
            raise ValueError("num_layers must have odd length")
        if any(i % 2 != 0 for i in self.num_layers):
            raise ValueError("num_layers must have even numbers")

        if not sparse1d:
            self.sparse_n = [1] * num_stages

        interpolation_scale_thw = (interpolation_scale_t, interpolation_scale_h, interpolation_scale_w)

        self.patch_embed = PatchEmbed2D(
            patch_size=self.patch_size,
            in_channels=in_channels,
            embed_dim=hidden_size,
        )
        self.time_text_embed = CombinedTimestepTextProjEmbeddings(
            timestep_embed_dim=timestep_embed_dim,
            embedding_dim=timestep_embed_dim,
            pooled_projection_dim=pooled_projection_dim
        )
        self.caption_projection = nn.Linear(caption_channels, hidden_size)
        self.rope = RoPE3D(interpolation_scale_thw=interpolation_scale_thw)
        self.position_getter = PositionGetter3D(
            sample_size_t, sample_size_h, sample_size_w, explicit_uniform_rope, atten_layout="SBH"
        )

        self.transformer_blocks = []
        self.skip_norm_linear = []
        self.skip_norm_linear_enc = []
        for idx, (num_layer, stage_sparse_n) in enumerate(zip(self.num_layers, self.sparse_n)):
            is_last_stage = idx == num_stages - 1
            # Decoder stages (those after the middle stage) get skip-fusion layers.
            if self.skip_connection and idx > num_stages // 2:
                skip_connection_linear = (
                    zero_initialized_skip_connection(nn.Linear)
                    if self.skip_connection_zero_init else nn.Linear
                )
                self.skip_norm_linear.append(
                    nn.Sequential(
                        self.norm_cls(
                            hidden_size * 2,
                            eps=norm_eps,
                            sequence_parallel=self.sequence_parallel,
                        ) if not self.skip_connection_zero_init else nn.Identity(),
                        skip_connection_linear(hidden_size * 2, hidden_size),
                    )
                )
                self.skip_norm_linear_enc.append(
                    nn.Sequential(
                        self.norm_cls(
                            hidden_size * 2,
                            eps=norm_eps,
                            sequence_parallel=self.sequence_parallel,
                        ) if not self.skip_connection_zero_init else nn.Identity(),
                        skip_connection_linear(hidden_size * 2, hidden_size),
                    )
                )
            stage_blocks = nn.ModuleList(
                [
                    SparseMMDiTBlock(
                        dim=hidden_size,
                        num_heads=num_heads,
                        head_dim=head_dim,
                        timestep_embed_dim=timestep_embed_dim,
                        dropout=dropout,
                        activation_fn=activation_fn,
                        attention_bias=attention_bias,
                        norm_elementwise_affine=norm_elementwise_affine,
                        norm_eps=norm_eps,
                        interpolation_scale_thw=interpolation_scale_thw,
                        # The very last block never gets a separate text feed-forward.
                        double_ff=double_ff and not (is_last_stage and i == num_layer - 1),
                        sparse1d=sparse1d if stage_sparse_n > 1 else False,
                        sparse_n=stage_sparse_n,
                        # Odd-indexed blocks of a sparse stage use the grouped variant.
                        sparse_group=i % 2 == 1 if stage_sparse_n > 1 else False,
                        context_pre_only=is_last_stage and (i == num_layer - 1),
                        norm_cls=norm_cls,
                    )
                    for i in range(num_layer)
                ]
            )
            self.transformer_blocks.append(stage_blocks)
        self.transformer_blocks = nn.ModuleList(self.transformer_blocks)

        if self.skip_connection:
            self.skip_norm_linear = nn.ModuleList(self.skip_norm_linear)
            self.skip_norm_linear_enc = nn.ModuleList(self.skip_norm_linear_enc)

        self.norm_final = self.norm_cls(
            hidden_size, eps=norm_eps, sequence_parallel=self.sequence_parallel,
        )
        self.norm_out = AdaNorm(
            embedding_dim=timestep_embed_dim,
            output_dim=hidden_size * 2,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            norm_cls=norm_cls,
        )
        # Use the resolved ``self.out_channels`` so ``out_channels=None`` works.
        self.proj_out = nn.Linear(
            hidden_size, self.patch_size_t * self.patch_size * self.patch_size * self.out_channels
        )

        # Tag the parameters of these modules with the sequence-parallel flag.
        modules = [self.norm_final]
        if self.skip_connection:
            modules += [self.skip_norm_linear, self.skip_norm_linear_enc]
        for module in modules:
            for _, param in module.named_parameters():
                setattr(param, "sequence_parallel", self.sequence_parallel)

    def prepare_sparse_mask(self, attention_mask, encoder_attention_mask, sparse_n):
        """Build the boolean attention masks for one sparse factor.

        The video mask is padded along its last dimension to a multiple of
        ``sparse_n * sparse_n`` (padding value ``-10000.0``), rearranged into
        the plain and the grouped sparse layouts, concatenated with the
        (repeated) text mask and converted to ``bool`` (any nonzero entry
        becomes ``True``).

        Args:
            attention_mask: Video mask; a singleton dim is inserted at dim 1
                and the result must match the pattern ``b 1 1 L``.
            encoder_attention_mask: Text mask, unsqueezed the same way.
            sparse_n: Sparse factor.

        Returns:
            A dict mapping ``sparse_group`` (``False`` / ``True``) to the
            corresponding boolean mask.
        """
        attention_mask = attention_mask.unsqueeze(1)
        encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
        length = attention_mask.shape[-1]
        if length % (sparse_n * sparse_n) == 0:
            pad_len = 0
        else:
            pad_len = sparse_n * sparse_n - length % (sparse_n * sparse_n)

        attention_mask_sparse = F.pad(attention_mask, (0, pad_len, 0, 0), value=-10000.0)
        attention_mask_sparse_1d = rearrange(
            attention_mask_sparse,
            'b 1 1 (g k) -> (k b) 1 1 g',
            k=sparse_n
        )
        attention_mask_sparse_1d_group = rearrange(
            attention_mask_sparse,
            'b 1 1 (n m k) -> (m b) 1 1 (n k)',
            m=sparse_n,
            k=sparse_n
        )
        encoder_attention_mask_sparse = encoder_attention_mask.repeat(sparse_n, 1, 1, 1)
        attention_mask_sparse_1d = torch.cat([attention_mask_sparse_1d, encoder_attention_mask_sparse], dim=-1)
        attention_mask_sparse_1d_group = torch.cat([attention_mask_sparse_1d_group, encoder_attention_mask_sparse], dim=-1)

        def get_attention_mask(mask, repeat_num):
            # Cast to bool and tile along dim 2 ``repeat_num`` times.
            mask = mask.to(torch.bool)
            mask = mask.repeat(1, 1, repeat_num, 1)
            return mask

        attention_mask_sparse_1d = get_attention_mask(
            attention_mask_sparse_1d, attention_mask_sparse_1d.shape[-1]
        )
        attention_mask_sparse_1d_group = get_attention_mask(
            attention_mask_sparse_1d_group, attention_mask_sparse_1d_group.shape[-1]
        )

        return {
            False: attention_mask_sparse_1d,
            True: attention_mask_sparse_1d_group
        }

    @property
    def dtype(self) -> torch.dtype:
        """The dtype of the module (assuming that all the module parameters have the same dtype)."""
        params = tuple(self.parameters())
        if len(params) > 0:
            return params[0].dtype
        else:
            buffers = tuple(self.buffers())
            return buffers[0].dtype

    def _operate_on_enc(
        self, hidden_states, encoder_hidden_states, inputs: BlockForwardInputs
    ):
        """Run the encoder stages (the first ``len(num_layers) // 2`` stages).

        Returns:
            ``(hidden_states, encoder_hidden_states, skip_connections)`` where
            ``skip_connections`` holds one ``[hidden_states, encoder_hidden_states]``
            pair per stage when ``skip_connection`` is enabled (empty otherwise).
        """
        skip_connections = []
        for _, stage_block in enumerate(self.transformer_blocks[:len(self.num_layers) // 2]):
            for _, block in enumerate(stage_block):
                attention_mask = self.sparse_mask[block.sparse_n][block.sparse_group]
                hidden_states, encoder_hidden_states = self.block_forward(
                    block=block,
                    hidden_states=hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    inputs=inputs
                )
            if self.skip_connection:
                skip_connections.append([hidden_states, encoder_hidden_states])
        return hidden_states, encoder_hidden_states, skip_connections

    def _operate_on_mid(
        self, hidden_states, encoder_hidden_states, inputs: BlockForwardInputs
    ):
        """Run the single middle stage and return both updated streams."""
        for _, block in enumerate(self.transformer_blocks[len(self.num_layers) // 2]):
            attention_mask = self.sparse_mask[block.sparse_n][block.sparse_group]
            hidden_states, encoder_hidden_states = self.block_forward(
                block=block,
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                inputs=inputs
            )
        return hidden_states, encoder_hidden_states

    def _operate_on_dec(
        self, hidden_states, skip_connections, encoder_hidden_states, inputs: BlockForwardInputs
    ):
        """Run the decoder stages (everything after the middle stage).

        With ``skip_connection`` enabled, each stage pops the most recently
        saved encoder pair, concatenates it on the last dim and projects the
        result back through the matching skip layers before running its blocks.
        ``skip_connections`` is consumed (mutated) in the process.
        """
        for idx, stage_block in enumerate(self.transformer_blocks[len(self.num_layers) // 2 + 1:]):
            if self.skip_connection:
                skip_hidden_states, skip_encoder_hidden_states = skip_connections.pop()
                hidden_states = torch.cat([hidden_states, skip_hidden_states], dim=-1)
                hidden_states = self.skip_norm_linear[idx](hidden_states)
                encoder_hidden_states = torch.cat([encoder_hidden_states, skip_encoder_hidden_states], dim=-1)
                encoder_hidden_states = self.skip_norm_linear_enc[idx](encoder_hidden_states)
            for _, block in enumerate(stage_block):
                attention_mask = self.sparse_mask[block.sparse_n][block.sparse_group]
                hidden_states, encoder_hidden_states = self.block_forward(
                    block=block,
                    hidden_states=hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    inputs=inputs
                )
        return hidden_states, encoder_hidden_states

    def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, pooled_projections):
        """Embed the video, timestep/pooled text and caption inputs.

        Returns:
            ``(hidden_states, encoder_hidden_states, timesteps_emb)``.

        Raises:
            AssertionError: If dim 1 of ``pooled_projections`` or of the
                projected ``encoder_hidden_states`` is not of size 1.
        """
        hidden_states = self.patch_embed(hidden_states.to(self.dtype))
        if pooled_projections.shape[1] != 1:
            raise AssertionError("Pooled projection should have shape (b, 1, 1, d)")
        pooled_projections = pooled_projections.squeeze(1)
        timesteps_emb = self.time_text_embed(timestep, pooled_projections)

        encoder_hidden_states = self.caption_projection(encoder_hidden_states)
        if encoder_hidden_states.shape[1] != 1:
            raise AssertionError("Encoder hidden states should have shape (b, 1, l, d)")
        encoder_hidden_states = encoder_hidden_states.squeeze(1)
        return hidden_states, encoder_hidden_states, timesteps_emb

    def _get_output_for_patched_inputs(
        self, hidden_states, embedded_timestep, frames, height, width
    ):
        """Normalise, project and un-patchify the final hidden states.

        Returns:
            Tensor reshaped to ``(-1, out_channels, frames * patch_size_t,
            height * patch_size, width * patch_size)``.
        """
        hidden_states = self.norm_final(hidden_states)
        hidden_states = self.norm_out(hidden_states, temb=embedded_timestep)
        if self.sequence_parallel:
            hidden_states = tensor_parallel.gather_from_sequence_parallel_region(hidden_states,
                                                                                 tensor_parallel_output_grad=False)
        # Back from sequence-first to batch-first layout.
        hidden_states = rearrange(hidden_states, 's b h -> b s h', b=hidden_states.shape[1]).contiguous()
        hidden_states = self.proj_out(hidden_states)
        hidden_states = hidden_states.reshape(
            -1, frames, height, width,
            self.patch_size_t, self.patch_size, self.patch_size, self.out_channels
        )
        # Interleave each patch axis with its grid axis and move channels to dim 1.
        hidden_states = torch.einsum("nthwopqc -> nctohpwq", hidden_states)
        output = hidden_states.reshape(
            -1,
            self.out_channels,
            frames * self.patch_size_t,
            height * self.patch_size,
            width * self.patch_size
        )
        return output

    def block_forward(
        self,
        block,
        hidden_states,
        attention_mask,
        encoder_hidden_states,
        inputs: BlockForwardInputs,
    ):
        """Call one block, using activation checkpointing while training.

        Returns:
            ``(hidden_states, encoder_hidden_states)`` as produced by ``block``.
        """
        if self.training and self.gradient_checkpointing:
            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
            # Arguments are passed positionally, in the order of the block's forward.
            hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(block),
                hidden_states,
                attention_mask,
                encoder_hidden_states,
                inputs,
                **ckpt_kwargs
            )
        else:
            hidden_states, encoder_hidden_states = block(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                inputs=inputs
            )
        return hidden_states, encoder_hidden_states

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: Optional[torch.Tensor] = None,
        prompt: Optional[torch.Tensor] = None,
        video_mask: Optional[torch.Tensor] = None,
        prompt_mask: Optional[torch.Tensor] = None,
        **kwargs
    ) -> torch.Tensor:
        """Run the full model.

        Args:
            hidden_states: 5D input unpacked as ``(batch, channels, frames,
                height, width)``.
            timestep: Timestep input for the timestep/text embedding.
            prompt: Indexable; ``prompt[0]`` holds the text hidden states and
                ``prompt[1]`` (if present) the pooled projections.
            video_mask: Video attention mask.
            prompt_mask: Indexable; ``prompt_mask[0]`` is the text mask.

        Returns:
            The un-patchified output tensor.
        """
        pooled_projections = prompt[1] if len(prompt) > 1 else None
        encoder_hidden_states = prompt[0]
        encoder_attention_mask = prompt_mask[0]
        attention_mask = video_mask
        batch_size, c, frames, height, width = hidden_states.shape

        if encoder_hidden_states.ndim == 3:
            encoder_hidden_states = encoder_hidden_states.unsqueeze(1)
        encoder_attention_mask = encoder_attention_mask.view(batch_size, -1, encoder_attention_mask.shape[-1])

        if mpu.get_context_parallel_world_size() > 1:
            # Each context-parallel rank handles a slice along dim 2.
            frames //= mpu.get_context_parallel_world_size()
            hidden_states = split_forward_gather_backward(hidden_states, mpu.get_context_parallel_group(), dim=2,
                                                          grad_scale='down')
            encoder_hidden_states = split_forward_gather_backward(encoder_hidden_states, mpu.get_context_parallel_group(),
                                                                  dim=2, grad_scale='down')

        if attention_mask is not None and attention_mask.ndim == 4:
            attention_mask = attention_mask.to(self.dtype)
            attention_mask = attention_mask.unsqueeze(1)
            # Pool the mask down to patch resolution.
            attention_mask = F.max_pool3d(
                attention_mask,
                kernel_size=(self.patch_size_t, self.patch_size, self.patch_size),
                stride=(self.patch_size_t, self.patch_size, self.patch_size)
            )
            attention_mask = rearrange(attention_mask, 'b 1 t h w -> (b 1) 1 (t h w)')
            # Additive form: 0 where the mask is nonzero, -10000 elsewhere.
            attention_mask = (1 - attention_mask.bool().to(self.dtype)) * -10000.0

        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 3:
            encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0

        frames = ((frames - 1) // self.patch_size_t + 1) if frames % 2 == 1 else frames // self.patch_size_t
        height, width = height // self.patch_size, width // self.patch_size

        if pooled_projections is not None and pooled_projections.ndim == 2:
            pooled_projections = pooled_projections.unsqueeze(1)

        hidden_states, encoder_hidden_states, embedded_timestep = self._operate_on_patched_inputs(
            hidden_states, encoder_hidden_states, timestep, pooled_projections
        )

        # Switch both streams to sequence-first layout.
        hidden_states = rearrange(hidden_states, 'b s h -> s b h', b=batch_size).contiguous()
        encoder_hidden_states = rearrange(encoder_hidden_states, 'b s h -> s b h', b=batch_size).contiguous()

        # One mask pair per distinct sparse factor; read back by the stage helpers.
        self.sparse_mask = {}
        if attention_mask.device != encoder_attention_mask.device:
            encoder_attention_mask = encoder_attention_mask.to(attention_mask.device)
        for sparse_n in list(set(self.sparse_n)):
            self.sparse_mask[sparse_n] = self.prepare_sparse_mask(attention_mask, encoder_attention_mask, sparse_n)

        pos_thw = self.position_getter(
            PositionParams(
                b=batch_size,
                t=frames * mpu.get_context_parallel_world_size(),
                h=height,
                w=width,
                device=hidden_states.device,
                training=self.training
            )
        )
        video_rotary_emb = self.rope(self.head_dim, pos_thw, hidden_states.device)

        if self.sequence_parallel:
            hidden_states = tensor_parallel.scatter_to_sequence_parallel_region(hidden_states)
            encoder_hidden_states = tensor_parallel.scatter_to_sequence_parallel_region(encoder_hidden_states)

        inputs = BlockForwardInputs(
            embedded_timestep=embedded_timestep,
            frames=frames,
            height=height,
            width=width,
            video_rotary_emb=video_rotary_emb
        )

        hidden_states, encoder_hidden_states, skip_connections = self._operate_on_enc(
            hidden_states, encoder_hidden_states, inputs
        )
        hidden_states, encoder_hidden_states = self._operate_on_mid(
            hidden_states, encoder_hidden_states, inputs
        )
        hidden_states, encoder_hidden_states = self._operate_on_dec(
            hidden_states, skip_connections, encoder_hidden_states, inputs
        )

        output = self._get_output_for_patched_inputs(
            hidden_states, embedded_timestep, frames, height, width
        )

        if mpu.get_context_parallel_world_size() > 1:
            output = gather_forward_split_backward(output, mpu.get_context_parallel_group(), dim=2,
                                                   grad_scale='up')
        return output
class SparseMMDiTBlock(nn.Module):
    """Transformer block that jointly updates a video stream and a text stream.

    Each call applies timestep-modulated normalisation, joint sparse attention
    and a gated feed-forward to both streams. Residual additions are carried
    out in float32 and cast back to the incoming dtype.

    Args:
        dim: Hidden width of both streams.
        num_heads: Number of attention heads.
        head_dim: Channels per head.
        timestep_embed_dim: Width of the timestep embedding fed to the
            modulation layers.
        dropout: Dropout probability for attention and feed-forward.
        activation_fn: Activation name for the feed-forward layers.
        attention_bias: Bias flag for the QKV projections.
        attention_out_bias: Bias flag for the attention output projection.
        norm_elementwise_affine: Forwarded to the modulation norms.
        norm_eps: Epsilon for the norms.
        final_dropout: Forwarded to the feed-forward layers.
        ff_inner_dim: Optional inner width of the feed-forward layers.
        ff_bias: Bias flag for the feed-forward layers.
        context_pre_only: When true the text stream is not updated by this block.
        interpolation_scale_thw: Accepted for interface compatibility; not
            read by this block.
        double_ff: Use a dedicated feed-forward (``ff_enc``) for the text stream.
        sparse1d / sparse_n / sparse_group: Sparse attention configuration,
            forwarded to the attention module.
        norm_cls: ``'rms_norm'`` or ``'fp32_layer_norm'``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        head_dim: int,
        timestep_embed_dim: int,
        dropout=0.0,
        activation_fn: str = "geglu",
        attention_bias: bool = False,
        attention_out_bias: bool = True,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        final_dropout: bool = False,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = False,
        context_pre_only: bool = False,
        interpolation_scale_thw: Tuple[int] = (1, 1, 1),
        double_ff: bool = False,
        sparse1d: bool = False,
        sparse_n: int = 2,
        sparse_group: bool = False,
        norm_cls: str = 'fp32_layer_norm',
    ):
        super().__init__()
        self.sparse1d = sparse1d
        self.sparse_n = sparse_n
        self.sparse_group = sparse_group
        self.head_dim = head_dim
        self.double_ff = double_ff
        self.context_pre_only = context_pre_only

        if norm_cls == 'rms_norm':
            self.norm_cls = RMSNorm
        elif norm_cls == 'fp32_layer_norm':
            self.norm_cls = FP32LayerNorm

        # Settings shared by every feed-forward created below.
        ff_kwargs = dict(
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

        self.norm1 = OpenSoraNormZero(
            timestep_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True, norm_cls=norm_cls
        )
        self.attn1 = MultiHeadSparseMMAttentionSBH(
            query_dim=dim,
            key_dim=None,
            num_heads=num_heads,
            head_dim=head_dim,
            added_kv_proj_dim=dim,
            dropout=dropout,
            proj_qkv_bias=attention_bias,
            proj_out_bias=attention_out_bias,
            context_pre_only=context_pre_only,
            qk_norm='rms_norm',
            eps=norm_eps,
            sparse1d=sparse1d,
            sparse_n=sparse_n,
            sparse_group=sparse_group,
            is_cross_attn=False,
        )
        self.norm2 = OpenSoraNormZero(
            timestep_embed_dim,
            dim,
            norm_elementwise_affine,
            norm_eps,
            bias=True,
            norm_cls=norm_cls,
            context_pre_only=(not double_ff and context_pre_only)
        )
        self.ff = FeedForward(dim, **ff_kwargs)
        if self.double_ff:
            self.ff_enc = FeedForward(dim, **ff_kwargs)

    @staticmethod
    def _gated_residual(residual, gate, update, out_dtype):
        """Return ``residual + gate * update`` computed in float32, cast to ``out_dtype``."""
        fused = residual.float() + gate * update.float()
        return fused.to(out_dtype)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        inputs: Optional[BlockForwardInputs] = None,
    ) -> torch.FloatTensor:
        """Apply attention and feed-forward to both streams.

        Args:
            hidden_states: Video stream; its first dim is used as the video
                sequence length when splitting the shared feed-forward output.
            attention_mask: Mask forwarded to the attention module.
            encoder_hidden_states: Text stream.
            inputs: Timestep embedding, grid sizes and rotary embeddings.

        Returns:
            ``(hidden_states, encoder_hidden_states)``.

        Raises:
            ValueError: If the attention gates are not float32.
            AssertionError: If the feed-forward gates are not float32.
        """
        vis_seq_length, _ = hidden_states.shape[:2]

        hidden_states = maybe_clamp_tensor(hidden_states, training=self.training)
        encoder_hidden_states = maybe_clamp_tensor(encoder_hidden_states, training=self.training)

        # ---- attention ----
        norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
            hidden_states, encoder_hidden_states, inputs.embedded_timestep
        )
        attn_hidden_states, attn_encoder_hidden_states = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            frames=inputs.frames,
            height=inputs.height,
            width=inputs.width,
            attention_mask=attention_mask,
            video_rotary_emb=inputs.video_rotary_emb,
        )

        weight_dtype = hidden_states.dtype
        if gate_msa.dtype != torch.float32 or enc_gate_msa.dtype != torch.float32:
            raise ValueError("Gate must be float32.")

        hidden_states = self._gated_residual(hidden_states, gate_msa, attn_hidden_states, weight_dtype)
        if not self.context_pre_only:
            encoder_hidden_states = self._gated_residual(
                encoder_hidden_states, enc_gate_msa, attn_encoder_hidden_states, weight_dtype
            )

        hidden_states = maybe_clamp_tensor(hidden_states, training=self.training)
        if not self.context_pre_only:
            encoder_hidden_states = maybe_clamp_tensor(encoder_hidden_states, training=self.training)

        # ---- feed-forward ----
        norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
            hidden_states, encoder_hidden_states, inputs.embedded_timestep
        )

        weight_dtype = hidden_states.dtype
        if gate_ff.dtype != torch.float32 or enc_gate_ff.dtype != torch.float32:
            raise AssertionError("Gate FFN should be float32")

        if self.double_ff:
            # Separate feed-forward per stream.
            vis_ff_output = self.ff(norm_hidden_states)
            hidden_states = self._gated_residual(hidden_states, gate_ff, vis_ff_output, weight_dtype)
            if self.ff_enc is not None:
                enc_ff_output = self.ff_enc(norm_encoder_hidden_states)
                encoder_hidden_states = self._gated_residual(
                    encoder_hidden_states, enc_gate_ff, enc_ff_output, weight_dtype
                )
        else:
            # One shared feed-forward over both streams concatenated along dim 0.
            joint_input = torch.cat([norm_hidden_states, norm_encoder_hidden_states], dim=0)
            ff_output = self.ff(joint_input)
            hidden_states = self._gated_residual(
                hidden_states, gate_ff, ff_output[:vis_seq_length], weight_dtype
            )
            encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
            if not self.context_pre_only:
                encoder_hidden_states = self._gated_residual(
                    encoder_hidden_states, enc_gate_ff, ff_output[vis_seq_length:], weight_dtype
                )

        return hidden_states, encoder_hidden_states
class AdaNorm(nn.Module):
    r"""
    Norm layer modified to incorporate timestep embeddings.

    The (optionally looked-up) timestep embedding is passed through SiLU and a
    column-parallel linear layer, gathered across the tensor-parallel group and
    split into a shift and a scale that modulate the normalised input.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        num_embeddings (`int`, *optional*): The size of the embeddings dictionary.
            When given, an `nn.Embedding` turns `timestep` into the embedding.
        output_dim (`int`, *optional*): Width of the linear projection; defaults
            to `embedding_dim * 2`. The norm operates on `output_dim // 2` features.
        norm_elementwise_affine (`bool`, defaults to `False`): Accepted for
            interface compatibility; not read by this layer.
        norm_eps (`float`, defaults to `1e-5`): Epsilon of the norm.
        norm_cls (`str`, defaults to `'fp32_layer_norm'`): `'rms_norm'` or
            `'fp32_layer_norm'`.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_embeddings: Optional[int] = None,
        output_dim: Optional[int] = None,
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-5,
        norm_cls: str = 'fp32_layer_norm',
    ):
        super().__init__()
        args = get_args()
        self.sequence_parallel = args.sequence_parallel

        # The modulation projection itself runs without sequence parallelism.
        config = core_transformer_config_from_args(args)
        config.sequence_parallel = False

        output_dim = output_dim or embedding_dim * 2

        self.emb = nn.Embedding(num_embeddings, embedding_dim) if num_embeddings is not None else None
        self.silu = nn.SiLU()
        self.linear = tensor_parallel.ColumnParallelLinear(
            embedding_dim,
            output_dim,
            config=config,
            init_method=config.init_method,
            gather_output=False
        )

        if norm_cls == 'rms_norm':
            self.norm_cls = RMSNorm
        elif norm_cls == 'fp32_layer_norm':
            self.norm_cls = FP32LayerNorm
        self.norm = self.norm_cls(
            output_dim // 2, eps=norm_eps
        )

    def forward(
        self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Return ``norm(x) * (1 + scale) + shift`` in the dtype of ``x``.

        Args:
            x: Input to normalise and modulate.
            timestep: Index into the embedding table; only used when the layer
                was built with ``num_embeddings``.
            temb: Pre-computed timestep embedding; overwritten when the
                embedding table is present.

        Raises:
            ValueError: If shift or scale are not float32.
        """
        if self.emb is not None:
            temb = self.emb(timestep)

        # The parallel linear returns a tuple; element 0 is the output.
        temb = self.linear(self.silu(temb))[0]
        mappings = tensor_parallel.mappings
        if self.sequence_parallel:
            temb = mappings.all_gather_last_dim_from_tensor_parallel_region(temb)
        else:
            temb = mappings.gather_from_tensor_model_parallel_region(temb)
        temb = temb.float()

        shift, scale = temb.chunk(2, dim=1)
        # Add a leading broadcast dim.
        shift = shift[None, :, :]
        scale = scale[None, :, :]
        if shift.dtype != torch.float32 or scale.dtype != torch.float32:
            raise ValueError("Shift and scale must be float32.")

        weight_dtype = x.dtype
        modulated = self.norm(x).float() * (1 + scale) + shift
        return modulated.to(weight_dtype)
class OpenSoraNormZero(nn.Module):
    """Timestep-conditioned normalisation for the video and text streams.

    A SiLU + column-parallel linear projection of the timestep embedding yields
    six chunks: shift, scale and gate for the video stream followed by shift,
    scale and gate for the text stream.

    Args:
        timestep_embed_dim: Width of the incoming timestep embedding.
        embedding_dim: Hidden width of the streams being normalised.
        elementwise_affine: Accepted for interface compatibility; not read here.
        eps: Epsilon of the norms.
        bias: Accepted for interface compatibility; not read here.
        norm_cls: ``'rms_norm'`` or ``'fp32_layer_norm'``.
        context_pre_only: When true no norm is created for the text stream and
            it is passed through un-normalised.

    Raises:
        ValueError: If ``norm_cls`` is not one of the supported names.
    """

    def __init__(
        self,
        timestep_embed_dim: int,
        embedding_dim: int,
        elementwise_affine: bool = True,
        eps: float = 1e-5,
        bias: bool = True,
        norm_cls: str = 'fp32_layer_norm',
        context_pre_only: bool = False,
    ) -> None:
        super().__init__()
        args = get_args()
        self.sequence_parallel = args.sequence_parallel
        config = core_transformer_config_from_args(args)

        if norm_cls == 'rms_norm':
            self.norm_cls = RMSNorm
        elif norm_cls == 'fp32_layer_norm':
            self.norm_cls = FP32LayerNorm
        else:
            # Fail with a clear message rather than an AttributeError below.
            raise ValueError(f"Unsupported norm_cls: {norm_cls}")

        self.silu = nn.SiLU()
        # The modulation projection is built with sequence parallelism switched
        # off; the flag is restored on the config afterwards.
        config.sequence_parallel = False
        self.linear = tensor_parallel.ColumnParallelLinear(
            timestep_embed_dim,
            6 * embedding_dim,
            config=config,
            init_method=config.init_method,
            gather_output=False
        )
        config.sequence_parallel = self.sequence_parallel

        self.norm = self.norm_cls(embedding_dim, eps=eps, sequence_parallel=self.sequence_parallel)
        self.norm_enc = None
        if not context_pre_only:
            self.norm_enc = self.norm_cls(embedding_dim, eps=eps, sequence_parallel=self.sequence_parallel)

    def forward(
        self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Modulate both streams and return their gates.

        Args:
            hidden_states: Video stream.
            encoder_hidden_states: Text stream.
            temb: Timestep embedding.

        Returns:
            ``(hidden_states, encoder_hidden_states, gate, enc_gate)``. Both
            streams are cast to the dtype ``hidden_states`` had on entry; the
            gates are float32 with a leading broadcast dim.

        Raises:
            ValueError: If any modulation chunk is not float32.
        """
        # The parallel linear returns a tuple; element 0 is the output.
        temb = self.linear(self.silu(temb))[0]
        temb = temb.float()
        if self.sequence_parallel:
            temb = tensor_parallel.mappings.all_gather_last_dim_from_tensor_parallel_region(temb)
        else:
            temb = tensor_parallel.mappings.gather_from_tensor_model_parallel_region(temb)

        shift, scale, gate, enc_shift, enc_scale, enc_gate = temb.chunk(6, dim=1)
        if not all(value.dtype == torch.float32 for value in [shift, scale, gate, enc_shift, enc_scale, enc_gate]):
            raise ValueError("Shift, scale and gate must be float32.")

        weight_dtype = hidden_states.dtype
        hidden_states = self.norm(hidden_states).float() * (1 + scale)[None, :, :] + shift[None, :, :]
        if self.norm_enc is not None:
            encoder_hidden_states = (
                self.norm_enc(encoder_hidden_states).float() * (1 + enc_scale)[None, :, :]
                + enc_shift[None, :, :]
            )
        return (
            hidden_states.to(weight_dtype),
            encoder_hidden_states.to(weight_dtype),
            gate[None, :, :],
            enc_gate[None, :, :],
        )
class RoPE3D(torch.nn.Module):
    """Builds per-axis (t, y, x) rotary cos/sin tables and gathers them by position.

    Args:
        freq: Base of the inverse-frequency progression.
        F0: Stored on the module; not read by the methods in this class.
        interpolation_scale_thw: Divisors applied to the t, h and w positions
            before the frequencies are computed.
    """

    def __init__(self, freq=10000.0, F0=1.0, interpolation_scale_thw=(1, 1, 1)):
        super().__init__()
        self.base = freq
        self.F0 = F0
        self.interpolation_scale_t = interpolation_scale_thw[0]
        self.interpolation_scale_h = interpolation_scale_thw[1]
        self.interpolation_scale_w = interpolation_scale_thw[2]
        self.cache = {}

    def get_cos_sin(self, D, seq_start, seq_end, device, interpolation_scale=1):
        """Return cached ``(cos, sin)`` tables of shape ``(seq_end - seq_start, D)``.

        The cache key includes the interpolation scale and the device so that
        axes sharing ``D`` and the same range but using different scales (or
        living on different devices) never reuse each other's tables.
        """
        key = (D, seq_start, seq_end, interpolation_scale, str(device))
        if key not in self.cache:
            inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
            t = torch.arange(seq_start, seq_end, device=device, dtype=torch.float32) / interpolation_scale
            freqs = torch.einsum("i,j->ij", t, inv_freq)
            # Duplicate the frequencies so the table spans all D channels.
            freqs = torch.cat((freqs, freqs), dim=-1)
            self.cache[key] = (freqs.cos(), freqs.sin())
        return self.cache[key]

    def forward(self, dim, positions, device):
        """
        input:
            * dim: head_dim, must be a multiple of 16
            * positions: tuple ``(poses, min_poses, max_poses)`` where ``poses``
              holds three 2D index tensors (t, y and x position of each token)
              and ``min_poses`` / ``max_poses`` give the range of each axis
        output:
            * ``(cos_t, sin_t, cos_y, sin_y, cos_x, sin_x)`` gathered per token

        Raises:
            ValueError: If ``dim`` is not a multiple of 16.
            AssertionError: If ``poses`` does not hold three 2D tensors.
        """
        if dim % 16 != 0:
            raise ValueError("number of dimensions should be a multiple of 16")
        # Of every 16 channels, 4 go to the temporal axis and 6 to each spatial axis.
        D_t = dim // 16 * 4
        D = dim // 16 * 6

        poses, min_poses, max_poses = positions
        if len(poses) != 3 or poses[0].ndim != 2:
            raise AssertionError("poses shape error")

        cos_t, sin_t = self.get_cos_sin(D_t, min_poses[0], max_poses[0], device, self.interpolation_scale_t)
        cos_y, sin_y = self.get_cos_sin(D, min_poses[1], max_poses[1], device, self.interpolation_scale_h)
        cos_x, sin_x = self.get_cos_sin(D, min_poses[2], max_poses[2], device, self.interpolation_scale_w)

        cos_t, sin_t = compute_rope1d(poses[0], cos_t, sin_t)
        cos_y, sin_y = compute_rope1d(poses[1], cos_y, sin_y)
        cos_x, sin_x = compute_rope1d(poses[2], cos_x, sin_x)
        return cos_t, sin_t, cos_y, sin_y, cos_x, sin_x
def compute_rope1d(pos1d, cos, sin):
    """Gather rows of the cos/sin tables for every position index.

    Args:
        pos1d: 2D integer index tensor (ntokens x batch_size).
        cos: Cosine table indexed along its first dim.
        sin: Sine table indexed along its first dim.

    Returns:
        ``(cos, sin)`` each of shape ``pos1d.shape + (1, table_width)``.

    Raises:
        AssertionError: If ``pos1d`` is not 2D.
    """
    if pos1d.ndim != 2:
        raise AssertionError("pos1d.ndim must be 2")

    def gather(table):
        # Look up one table row per index, then add a singleton dim before the channels.
        return F.embedding(pos1d, table)[:, :, None, :]

    return gather(cos), gather(sin)
class PositionGetter3D:
    """Returns (and caches) the t/y/x position indices of the patches of a 3D grid.

    Args:
        max_t, max_h, max_w: Upper bounds used when sampling a random start
            offset per axis (only with ``explicit_uniform_rope`` in training).
        explicit_uniform_rope: Enable random start offsets during training.
        atten_layout: ``"SBH"`` (positions shaped tokens x batch) or ``"BSH"``
            (batch x tokens).
    """

    def __init__(self, max_t, max_h, max_w, explicit_uniform_rope=False, atten_layout="BSH"):
        self.cache_positions = {}
        self.atten_layout = atten_layout
        self.max_t = max_t
        self.max_h = max_h
        self.max_w = max_w
        self.explicit_uniform_rope = explicit_uniform_rope

    @staticmethod
    def check_type(param):
        """Convert a single-element tensor to a Python scalar; pass other values through."""
        if isinstance(param, torch.Tensor):
            param = param.item()
        return param

    def get_positions(self, b, t, h, w, device):
        """Return three index tensors (t, y, x) enumerating a ``t x h x w`` grid.

        Indices start at zero on every axis and are expanded over the batch
        according to ``atten_layout``.

        Raises:
            ValueError: For an unsupported ``atten_layout``.
        """
        x = torch.arange(w, device=device)
        y = torch.arange(h, device=device)
        z = torch.arange(t, device=device)
        pos = torch.cartesian_prod(z, y, x)
        if self.atten_layout == "SBH":
            pos = pos.reshape(t * h * w, 3).transpose(0, 1).reshape(3, -1, 1).contiguous().expand(3, -1, b).clone()
        elif self.atten_layout == "BSH":
            pos = pos.reshape(t * h * w, 3).transpose(0, 1).reshape(3, 1, -1).contiguous().expand(3, b, -1).clone()
        else:
            raise ValueError(f"Unsupported layout type: {self.atten_layout}")
        return (pos[0].contiguous(), pos[1].contiguous(), pos[2].contiguous())

    def __call__(self, params: "PositionParams"):
        """Return ``(poses, min_poses, max_poses)`` for the requested grid.

        ``poses`` are the index tensors from ``get_positions``; ``min_poses`` /
        ``max_poses`` are the per-axis start and end offsets. Results are
        cached per batch size, offsets and device.
        """
        b = self.check_type(params.b)
        t = self.check_type(params.t)
        h = self.check_type(params.h)
        w = self.check_type(params.w)
        device = params.device
        training = params.training

        use_random_offset = self.explicit_uniform_rope and training
        s_t = random.randint(0, self.max_t - t) if use_random_offset else 0
        e_t = s_t + t
        s_h = random.randint(0, self.max_h - h) if use_random_offset else 0
        e_h = s_h + h
        s_w = random.randint(0, self.max_w - w) if use_random_offset else 0
        e_w = s_w + w

        # The device is part of the key so tensors cached for one device are
        # never handed out for another.
        key = (b, s_t, e_t, s_h, e_h, s_w, e_w, str(device))
        if key not in self.cache_positions:
            poses = self.get_positions(b, t, h, w, device)
            max_poses = (e_t, e_h, e_w)
            min_poses = (s_t, s_h, s_w)
            self.cache_positions[key] = (poses, min_poses, max_poses)
        return self.cache_positions[key]