from typing import Optional, Dict
import torch
from torch import nn
from diffusers.models.activations import get_activation
from diffusers.models.normalization import AdaGroupNorm
from diffusers.models.attention_processor import SpatialNorm
from mindspeed_mm.utils.utils import video_to_image
from mindspeed_mm.models.common.conv import (
CausalConv3d,
SafeConv3d,
ContextParallelCausalConv3d,
WfCausalConv3d,
TimePaddingCausalConv3d
)
from mindspeed_mm.models.common.updownsample import UpsampleCausal3D, DownsampleCausal3D
from mindspeed_mm.models.common.normalize import normalize
from mindspeed_mm.models.common.activations import Sigmoid, get_activation_layer
class ResnetBlock2D(nn.Module):
    """Pre-activation 2D residual block.

    Computes ``skip(x) + conv2(dropout(act(norm2(conv1(act(norm1(x)))))))``,
    where ``skip`` is the identity when input and output channel counts match
    and a learned convolution otherwise.

    ``forward`` is wrapped by the project helper ``video_to_image``
    (presumably it folds a video tensor into a batch of 2D frames and back —
    TODO confirm against its definition).

    Args:
        dropout: Dropout probability applied between the two convolutions.
        in_channels: Number of input channels.
        out_channels: Number of output channels; ``None`` falls back to
            ``in_channels``.
        kernel_size: Kernel size of ``conv1``/``conv2`` (and ``conv_shortcut``).
        stride: Stride of ``conv1``/``conv2`` (and ``conv_shortcut``).
            NOTE(review): both convs use it, so residual and skip shapes only
            line up for ``stride == 1``.
        padding: Padding of ``conv1``/``conv2`` (and ``conv_shortcut``).
        conv_shortcut: When channel counts differ, project the skip path with
            a ``kernel_size`` convolution instead of a 1x1 convolution.
        num_groups: Forwarded to ``normalize`` for both norm layers.
        eps: Forwarded to ``normalize`` for both norm layers.
        affine: Forwarded to ``normalize`` for both norm layers.
        norm_type: Normalization kind forwarded to ``normalize``.
        act_type: Activation name resolved through ``get_activation_layer``.
    """

    def __init__(
        self,
        dropout,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        padding=1,
        conv_shortcut=False,
        num_groups=32,
        eps=1e-6,
        affine=True,
        norm_type="groupnorm",
        act_type="silu"
    ):
        super().__init__()
        self.in_channels = in_channels
        # Resolve the ``None`` default once and build every layer from the
        # resolved value (previously the raw, possibly-None argument was used).
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.norm1 = normalize(in_channels, num_groups, eps, affine, norm_type=norm_type)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.norm2 = normalize(out_channels, num_groups, eps, affine, norm_type=norm_type)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(
            out_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding
        )
        self.activation = get_activation_layer(act_type)()
        # A projection on the skip path is only needed when channel counts differ.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding
                )
            else:
                self.nin_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
                )

    @video_to_image
    def forward(self, x):
        """Return ``skip(x) + residual(x)`` for input ``x``."""
        h = self.norm1(x)
        h = self.activation(h)
        h = self.conv1(h)
        h = self.norm2(h)
        h = self.activation(h)
        h = self.dropout(h)
        h = self.conv2(h)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)
        return x + h
class ResnetBlock3D(nn.Module):
    """Pre-activation residual block built on causal 3D convolutions.

    Computes ``skip(x) + conv2(dropout(act(norm2(conv1(act(norm1(x)))))))``
    with SiLU activations. ``skip`` is the identity when channel counts match
    and a learned convolution otherwise.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels; ``None`` falls back to
            ``in_channels``.
        kernel_size: Kernel size of ``conv1``/``conv2`` (and ``conv_shortcut``).
        stride: Stride of ``conv1`` only.
        padding: Padding of ``conv1``/``conv2`` (and ``conv_shortcut``).
        num_groups: Forwarded to ``normalize`` for both norm layers.
        eps: Forwarded to ``normalize`` for both norm layers.
        affine: Forwarded to ``normalize`` for both norm layers.
        conv_shortcut: When channel counts differ, project the skip path with
            a ``kernel_size`` convolution instead of a 1x1 convolution.
        dropout: Dropout probability applied between the two convolutions.
        norm_type: Normalization kind forwarded to ``normalize``.
        conv_type: ``"WfCausalConv3d"`` (uses ``ContextParallelCausalConv3d``
            when ``enable_vae_cp`` is set) or ``"CausalConv3d"``.
        enable_vae_cp: Select the context-parallel conv for ``"WfCausalConv3d"``.

    Raises:
        ValueError: If ``conv_type`` is not one of the supported names.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        padding=1,
        num_groups=32,
        eps=1e-6,
        affine=True,
        conv_shortcut=False,
        dropout=0,
        norm_type="groupnorm",
        conv_type="CausalConv3d",
        enable_vae_cp=False
    ):
        super().__init__()
        self.in_channels = in_channels
        # Resolve the ``None`` default once and build every layer from the
        # resolved value (previously the raw, possibly-None argument was used).
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.norm1 = normalize(in_channels, num_groups, eps, affine, norm_type=norm_type)
        self.enable_vae_cp = enable_vae_cp
        if conv_type == "WfCausalConv3d":
            conv_cls = ContextParallelCausalConv3d if self.enable_vae_cp else WfCausalConv3d
        elif conv_type == "CausalConv3d":
            conv_cls = CausalConv3d
        else:
            raise ValueError(f"Unsupported convolution type: {conv_type}")
        self.conv1 = conv_cls(in_channels, out_channels, kernel_size, stride=stride, padding=padding)
        self.norm2 = normalize(out_channels, num_groups, eps, affine, norm_type=norm_type)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = conv_cls(out_channels, out_channels, kernel_size, padding=padding)
        self.activation = nn.SiLU()
        # A projection on the skip path is only needed when channel counts differ.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = conv_cls(in_channels, out_channels, kernel_size, padding=padding)
            else:
                self.nin_shortcut = conv_cls(in_channels, out_channels, kernel_size=1, padding=0)

    @staticmethod
    def _first(output):
        """Return ``output[0]`` if a conv returned a tuple, else ``output`` itself."""
        # Some conv variants return a tuple whose first element is the tensor.
        return output[0] if isinstance(output, tuple) else output

    def forward(self, x):
        """Return ``skip(x) + residual(x)`` for input ``x``."""
        h = self.activation(self.norm1(x))
        h = self._first(self.conv1(h))
        h = self.activation(self.norm2(h))
        h = self.dropout(h)
        h = self._first(self.conv2(h))
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self._first(self.conv_shortcut(x))
            else:
                x = self._first(self.nin_shortcut(x))
        return x + h
class ContextParallelResnetBlock3D(nn.Module):
    """Residual block built on context-parallel causal 3D convolutions.

    Computes ``skip(x) + conv2(dropout(act(norm2(conv1(act(norm1(x))) + temb_term))))``
    and additionally returns the per-conv caches produced by ``conv1``,
    ``conv2`` and (if used) ``conv_shortcut``.

    Args (keyword-only):
        in_channels: Number of input channels.
        out_channels: Number of output channels; ``None`` falls back to
            ``in_channels``.
        conv_shortcut: When channel counts differ, project the skip path with
            a kernel-3 ``ContextParallelCausalConv3d`` instead of a 1x1
            ``SafeConv3d``.
        dropout: Dropout probability applied before ``conv2``.
        temb_channels: Width of the time embedding; a ``temb_proj`` linear
            layer is created only when this is > 0.
        zq_ch, add_conv, gather_norm: Forwarded to ``normalization`` for both
            norm layers.
        normalization: Factory used to build ``norm1`` and ``norm2``.
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels,
        conv_shortcut=False,
        dropout,
        temb_channels=512,
        zq_ch=None,
        add_conv=False,
        gather_norm=False,
        normalization=normalize,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.norm1 = normalization(
            in_channels,
            zq_ch=zq_ch,
            add_conv=add_conv,
            gather=gather_norm,
        )
        self.conv1 = ContextParallelCausalConv3d(
            chan_in=in_channels,
            chan_out=out_channels,
            kernel_size=3,
        )
        # NOTE(review): temb_proj only exists when temb_channels > 0; calling
        # forward with a non-None temb otherwise raises AttributeError.
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = normalization(
            out_channels,
            zq_ch=zq_ch,
            add_conv=add_conv,
            gather=gather_norm,
        )
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = ContextParallelCausalConv3d(
            chan_in=out_channels,
            chan_out=out_channels,
            kernel_size=3,
        )
        # A projection on the skip path is only needed when channel counts differ.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = ContextParallelCausalConv3d(
                    chan_in=in_channels,
                    chan_out=out_channels,
                    kernel_size=3,
                )
            else:
                self.nin_shortcut = SafeConv3d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                )
        # Project-defined activation used for both the feature map and temb.
        # NOTE(review): confirm whether this is plain sigmoid or x * sigmoid(x).
        self.act = Sigmoid()

    def forward(self, x, temb=None, zq=None, clear_fake_cp_cache=True, enable_cp=True, is_encode=True,
                conv_cache: Optional[Dict[str, torch.Tensor]] = None, use_conv_cache=False):
        """Run the residual block.

        Args:
            x: Input feature map.
            temb: Optional time embedding; when given it is projected by
                ``temb_proj`` and added after ``conv1`` (broadcast over the
                last three axes).
            zq: Optional conditioning; when given it is passed to both norm
                layers together with the cache/CP flags.
            clear_fake_cp_cache: Forwarded to the norms (as
                ``clear_fake_cp_cache``) and convs (as ``clear_cache``).
            enable_cp: Context-parallel flag for the norms and the shortcut
                conv; for ``conv1``/``conv2`` it is used only when
                ``is_encode`` is true (otherwise they get ``True``).
            is_encode: See ``enable_cp``.
            conv_cache: Previous caches keyed by ``"conv1"``, ``"conv2"`` and
                ``"conv_shortcut"``; ``None`` is treated as empty.
            use_conv_cache: Forwarded to each ``ContextParallelCausalConv3d``.

        Returns:
            ``(output, new_conv_cache)`` where ``new_conv_cache`` holds the
            second return value of each context-parallel conv that ran.
        """
        new_conv_cache = {}
        conv_cache = conv_cache or {}
        h = x
        if zq is not None:
            h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, enable_cp=enable_cp)
        else:
            h = self.norm1(h)
        h = self.act(h)
        # When decoding, conv1/conv2 always run with CP enabled.
        conv_enable_cp = enable_cp if is_encode else True
        h, new_conv_cache["conv1"] = self.conv1(h, clear_cache=clear_fake_cp_cache, enable_cp=conv_enable_cp,
                                                conv_cache=conv_cache.get("conv1"),
                                                use_conv_cache=use_conv_cache)
        if temb is not None:
            h = h + self.temb_proj(self.act(temb))[:, :, None, None, None]
        if zq is not None:
            h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, enable_cp=enable_cp)
        else:
            h = self.norm2(h)
        h = self.act(h)
        h = self.dropout(h)
        h, new_conv_cache["conv2"] = self.conv2(h, clear_cache=clear_fake_cp_cache, enable_cp=conv_enable_cp,
                                                conv_cache=conv_cache.get("conv2"),
                                                use_conv_cache=use_conv_cache)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                # NOTE(review): unlike conv1/conv2 this uses enable_cp rather than
                # conv_enable_cp, so during decode the two paths may differ — confirm
                # this is intentional.
                x, new_conv_cache["conv_shortcut"] = self.conv_shortcut(x,
                                                                        clear_cache=clear_fake_cp_cache,
                                                                        enable_cp=enable_cp,
                                                                        conv_cache=conv_cache.get("conv_shortcut"),
                                                                        use_conv_cache=use_conv_cache)
            else:
                x = self.nin_shortcut(x)
        return x + h, new_conv_cache
class ResnetBlockCausal3D(nn.Module):
    r"""
    A Resnet block built on time-padded causal 3D convolutions.

    Computes ``(skip(x) + conv2(dropout(act(norm2'(conv1(act(norm1(x))))))))
    / output_scale_factor``. Here ``norm2'`` is ``norm2`` combined with the
    time-embedding mode selected by ``time_embedding_norm``:

    * ``"default"``: the projected embedding is added before ``norm2``.
    * ``"scale_shift"``: the embedding is projected to ``2 * out_channels``
      and applied as ``h * (1 + scale) + shift`` after ``norm2``.
    * ``"ada_group"`` / ``"spatial"``: the raw embedding conditions
      ``norm1``/``norm2`` directly (no projection layer).

    Optional ``up``/``down`` resampling is applied to both the skip and the
    residual path before ``conv1``.

    Args (keyword-only):
        in_channels: Number of input channels.
        out_channels: Output channels of ``conv1``; ``None`` means ``in_channels``.
        conv_shortcut: Stored as ``use_conv_shortcut``; the skip projection is
            controlled by ``use_in_shortcut``.
        dropout: Dropout probability before ``conv2``.
        temb_channels: Time-embedding width; ``None`` disables the projection.
        groups: Group count for ``norm1``.
        groups_out: Group count for ``norm2``; ``None`` means ``groups``.
        pre_norm: Ignored; the block is always pre-norm.
        eps: Epsilon for the normalization layers.
        non_linearity: Activation name passed to ``get_activation``.
        skip_time_act: Skip the activation on ``temb`` before projection.
        time_embedding_norm: One of the modes listed above.
        kernel: Unused.
        output_scale_factor: Divisor applied to the block output.
        use_in_shortcut: Force/disable the 1x1 skip conv; ``None`` enables it
            when ``in_channels != conv_3d_out_channels``.
        up: Upsample with ``UpsampleCausal3D``.
        down: Downsample with ``DownsampleCausal3D`` (only if ``up`` is false).
        conv_shortcut_bias: Bias flag of the skip conv.
        conv_3d_out_channels: Output channels of ``conv2``; ``None`` means
            ``out_channels``.

    Raises:
        ValueError: If ``temb_channels`` is set and ``time_embedding_norm`` is
            not a supported mode.
    """

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: Optional[int] = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        temb_channels: int = 512,
        groups: int = 32,
        groups_out: Optional[int] = None,
        pre_norm: bool = True,
        eps: float = 1e-6,
        non_linearity: str = "swish",
        skip_time_act: bool = False,
        time_embedding_norm: str = "default",
        kernel: Optional[torch.FloatTensor] = None,
        output_scale_factor: float = 1.0,
        use_in_shortcut: Optional[bool] = None,
        up: bool = False,
        down: bool = False,
        conv_shortcut_bias: bool = True,
        conv_3d_out_channels: Optional[int] = None,
    ):
        super().__init__()
        # The block is always pre-norm; the ``pre_norm`` argument is accepted
        # for signature compatibility only.
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.up = up
        self.down = down
        self.output_scale_factor = output_scale_factor
        self.time_embedding_norm = time_embedding_norm
        self.skip_time_act = skip_time_act
        linear_cls = nn.Linear
        if groups_out is None:
            groups_out = groups

        if self.time_embedding_norm == "ada_group":
            self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
        elif self.time_embedding_norm == "spatial":
            self.norm1 = SpatialNorm(in_channels, temb_channels)
        else:
            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = TimePaddingCausalConv3d(in_channels, out_channels, kernel_size=3, stride=1)

        if temb_channels is not None:
            if self.time_embedding_norm == "default":
                self.time_emb_proj = linear_cls(temb_channels, out_channels)
            elif self.time_embedding_norm == "scale_shift":
                # Projects to concatenated (scale, shift), split in forward.
                self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
            elif self.time_embedding_norm in ("ada_group", "spatial"):
                self.time_emb_proj = None
            else:
                raise ValueError(f"Unknown time_embedding_norm : {self.time_embedding_norm} ")
        else:
            self.time_emb_proj = None

        if self.time_embedding_norm == "ada_group":
            self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
        elif self.time_embedding_norm == "spatial":
            self.norm2 = SpatialNorm(out_channels, temb_channels)
        else:
            self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)

        self.dropout = torch.nn.Dropout(dropout)
        conv_3d_out_channels = conv_3d_out_channels or out_channels
        self.conv2 = TimePaddingCausalConv3d(out_channels, conv_3d_out_channels, kernel_size=3, stride=1)
        self.nonlinearity = get_activation(non_linearity)

        self.upsample = self.downsample = None
        if self.up:
            self.upsample = UpsampleCausal3D(in_channels, use_conv=False)
        elif self.down:
            self.downsample = DownsampleCausal3D(in_channels, use_conv=False, name="op")

        self.use_in_shortcut = self.in_channels != conv_3d_out_channels if use_in_shortcut is None else use_in_shortcut
        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = TimePaddingCausalConv3d(
                in_channels,
                conv_3d_out_channels,
                kernel_size=1,
                stride=1,
                bias=conv_shortcut_bias,
            )

    def forward(
        self,
        input_tensor: torch.FloatTensor,
        temb: torch.FloatTensor,
        scale: float = 1.0,
    ) -> torch.FloatTensor:
        """Apply the block.

        Args:
            input_tensor: Input feature map (assumed ``(B, C, T, H, W)`` since
                the convs are 3D — TODO confirm with callers).
            temb: Time embedding or ``None``.
            scale: Forwarded to the up/down-sampling layers.

        Returns:
            ``(skip(input_tensor) + residual) / output_scale_factor``.
        """
        temb_conditions_norm = self.time_embedding_norm in ("ada_group", "spatial")

        hidden_states = input_tensor
        if temb_conditions_norm:
            hidden_states = self.norm1(hidden_states, temb)
        else:
            hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        if self.upsample is not None:
            # Presumably mirrors diffusers' workaround for nearest upsampling
            # failing on large non-contiguous batches.
            if hidden_states.shape[0] >= 64:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            input_tensor = self.upsample(input_tensor, scale=scale)
            hidden_states = self.upsample(hidden_states, scale=scale)
        elif self.downsample is not None:
            input_tensor = self.downsample(input_tensor, scale=scale)
            hidden_states = self.downsample(hidden_states, scale=scale)

        hidden_states = self.conv1(hidden_states)

        if self.time_emb_proj is not None:
            if not self.skip_time_act:
                temb = self.nonlinearity(temb)
            # nn.Linear.forward takes a single input (passing ``scale`` raised
            # TypeError). Expand to 5D so the (B, C) projection broadcasts over
            # the (T, H, W) axes of the 3D-conv feature map.
            temb = self.time_emb_proj(temb)[:, :, None, None, None]

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        if temb_conditions_norm:
            hidden_states = self.norm2(hidden_states, temb)
        else:
            hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            # Separate names so the ``scale`` argument is not shadowed.
            temb_scale, temb_shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + temb_scale) + temb_shift

        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        return (input_tensor + hidden_states) / self.output_scale_factor