from typing import Dict, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from diffusers.utils import logging
from diffusers.utils.accelerate_utils import apply_forward_hook
from diffusers.models.attention_processor import Attention, SpatialNorm
from mindspeed_mm.models.common.checkpoint import load_checkpoint
from mindspeed_mm.models.common.module import MultiModalModule
from mindspeed_mm.models.common.distrib import DiagonalGaussianDistribution
from mindspeed_mm.models.common.conv import TimePaddingCausalConv3d
from mindspeed_mm.models.common.resnet_block import ResnetBlockCausal3D
from mindspeed_mm.models.common.updownsample import UpsampleCausal3D, DownsampleCausal3D
logger = logging.get_logger(__name__)
def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_size: int = None):
    """Build an additive, frame-level causal attention mask.

    The sequence is laid out as ``n_frame`` consecutive groups of ``n_hw``
    tokens. Entry ``[q, k]`` of the mask is ``0`` when the frame index of
    token ``k`` is not greater than the frame index of token ``q`` and
    ``-inf`` otherwise.

    Args:
        n_frame: Number of frames.
        n_hw: Number of tokens per frame.
        dtype: Floating dtype of the returned mask.
        device: Device on which the mask is created.
        batch_size: When given, the mask is expanded (without copying) to
            ``(batch_size, seq_len, seq_len)``.

    Returns:
        A ``(seq_len, seq_len)`` tensor, or ``(batch_size, seq_len, seq_len)``
        when ``batch_size`` is not ``None``, with ``seq_len = n_frame * n_hw``.
    """
    seq_len = n_frame * n_hw
    # Frame index of every token position in the flattened sequence.
    token_frame = torch.div(torch.arange(seq_len, device=device), n_hw, rounding_mode='floor')
    # True where the key token belongs to a later frame than the query token.
    is_future = token_frame.unsqueeze(1) < token_frame.unsqueeze(0)
    mask = torch.zeros((seq_len, seq_len), dtype=dtype, device=device)
    mask.masked_fill_(is_future, float("-inf"))
    if batch_size is None:
        return mask
    return mask.unsqueeze(0).expand(batch_size, -1, -1)
class UNetMidBlockCausal3D(nn.Module):
    """
    A 3D UNet mid-block [`UNetMidBlockCausal3D`] with multiple residual blocks and optional attention blocks.

    The block starts with one `ResnetBlockCausal3D` and is followed by
    `num_layers` stages, each made of an optional `Attention` layer and another
    `ResnetBlockCausal3D`. Every resnet maps `in_channels` to `in_channels`.

    Args:
        in_channels: Channel count of the input (and output) feature map.
        temb_channels: Passed to every resnet as `temb_channels`; also used as
            `spatial_norm_dim` of the attention layers when
            `resnet_time_scale_shift == "spatial"`. May be `None`.
        dropout: Dropout value forwarded to the resnets.
        num_layers: Number of (attention, resnet) stages after the first resnet.
        resnet_eps: Epsilon forwarded to the resnets and attention layers.
        resnet_time_scale_shift: Forwarded to the resnets as `time_embedding_norm`.
        resnet_act_fn: Forwarded to the resnets as `non_linearity`.
        resnet_groups: Group count forwarded to the resnets. `None` falls back
            to `min(in_channels // 4, 32)`.
        attn_groups: Group count for the attention norm. When `None` it
            defaults to `resnet_groups` if `resnet_time_scale_shift == "default"`,
            otherwise it stays `None`.
        resnet_pre_norm: Forwarded to the resnets as `pre_norm`.
        add_attention: Whether attention layers are created. When `False` the
            attention slots hold `None` and are skipped in `forward`.
        attention_head_dim: Per-head width; the number of heads is
            `in_channels // attention_head_dim`. `None` falls back to `in_channels`.
        output_scale_factor: Forwarded to the resnets and, as
            `rescale_output_factor`, to the attention layers.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: Optional[int],
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: Optional[int] = 32,
        attn_groups: Optional[int] = None,
        resnet_pre_norm: bool = True,
        add_attention: bool = True,
        attention_head_dim: int = 1,
        output_scale_factor: float = 1.0,
    ):
        super().__init__()
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
        self.add_attention = add_attention

        if attn_groups is None:
            attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None

        # There is always at least one resnet.
        resnets = [
            ResnetBlockCausal3D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )
        ]
        attentions = []

        if attention_head_dim is None:
            # `Logger.warn` is a deprecated alias; `warning` is the supported name.
            logger.warning(
                f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
            )
            attention_head_dim = in_channels

        for _ in range(num_layers):
            if self.add_attention:
                attentions.append(
                    Attention(
                        in_channels,
                        heads=in_channels // attention_head_dim,
                        dim_head=attention_head_dim,
                        rescale_output_factor=output_scale_factor,
                        eps=resnet_eps,
                        norm_num_groups=attn_groups,
                        spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
                        residual_connection=True,
                        bias=True,
                        upcast_softmax=True,
                        _from_deprecated_attn_block=True,
                    )
                )
            else:
                # Keep the slot so attentions and resnets[1:] stay aligned in forward().
                attentions.append(None)

            resnets.append(
                ResnetBlockCausal3D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

    def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
        """Run the first resnet, then every (attention, resnet) stage in order.

        Args:
            hidden_states: 5-D feature map, unpacked below as `(B, C, T, H, W)`.
            temb: Optional embedding handed to every resnet and attention layer.

        Returns:
            The transformed feature map.
        """
        hidden_states = self.resnets[0](hidden_states, temb)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            if attn is not None:
                B, _, T, H, W = hidden_states.shape
                # Attention runs on a flat token sequence ordered frame-major.
                hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c")
                attention_mask = prepare_causal_attention_mask(
                    T, H * W, hidden_states.dtype, hidden_states.device, batch_size=B
                )
                hidden_states = attn(hidden_states, temb=temb, attention_mask=attention_mask)
                hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=T, h=H, w=W)
            hidden_states = resnet(hidden_states, temb)
        return hidden_states
class DownEncoderBlockCausal3D(nn.Module):
    """Encoder stage: a stack of `ResnetBlockCausal3D` layers plus an optional downsampler.

    The first resnet maps `in_channels` to `out_channels`; every following
    resnet maps `out_channels` to `out_channels`. All resnets are built with
    `temb_channels=None`. When `add_downsample` is true, a single
    `DownsampleCausal3D` is appended; otherwise `self.downsamplers` is `None`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor: float = 1.0,
        add_downsample: bool = True,
        downsample_stride: int = 2,
        downsample_padding: int = 1,
    ):
        super().__init__()

        layers = [
            ResnetBlockCausal3D(
                in_channels=in_channels if layer_idx == 0 else out_channels,
                out_channels=out_channels,
                temb_channels=None,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )
            for layer_idx in range(num_layers)
        ]
        self.resnets = nn.ModuleList(layers)

        self.downsamplers = None
        if add_downsample:
            reducer = DownsampleCausal3D(
                out_channels,
                use_conv=True,
                out_channels=out_channels,
                padding=downsample_padding,
                name="op",
                stride=downsample_stride,
            )
            self.downsamplers = nn.ModuleList([reducer])

    def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
        """Apply every resnet (with `temb=None`) and then every downsampler, if any."""
        out = hidden_states
        for layer in self.resnets:
            out = layer(out, temb=None, scale=scale)

        if self.downsamplers is None:
            return out

        for reducer in self.downsamplers:
            out = reducer(out, scale)
        return out
class UpDecoderBlockCausal3D(nn.Module):
    """Decoder stage: a stack of `ResnetBlockCausal3D` layers plus an optional upsampler.

    The first resnet maps `in_channels` to `out_channels`; the remaining ones
    keep `out_channels`. When `add_upsample` is true, a single
    `UpsampleCausal3D` built with `upsample_factor=upsample_scale_factor` is
    appended; otherwise `self.upsamplers` is `None`. `resolution_idx` is only
    stored on the instance.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        resolution_idx: Optional[int] = None,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor: float = 1.0,
        add_upsample: bool = True,
        upsample_scale_factor=(2, 2, 2),
        temb_channels: Optional[int] = None,
    ):
        super().__init__()

        layers = [
            ResnetBlockCausal3D(
                in_channels=in_channels if layer_idx == 0 else out_channels,
                out_channels=out_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )
            for layer_idx in range(num_layers)
        ]
        self.resnets = nn.ModuleList(layers)

        self.upsamplers = None
        if add_upsample:
            expander = UpsampleCausal3D(
                out_channels,
                use_conv=True,
                out_channels=out_channels,
                upsample_factor=upsample_scale_factor,
            )
            self.upsamplers = nn.ModuleList([expander])

        self.resolution_idx = resolution_idx

    def forward(
        self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
    ) -> torch.FloatTensor:
        """Apply every resnet (with `temb` and `scale`) and then every upsampler, if any."""
        out = hidden_states
        for layer in self.resnets:
            out = layer(out, temb=temb, scale=scale)

        if self.upsamplers is None:
            return out

        for expander in self.upsamplers:
            out = expander(out)
        return out
class EncoderCausal3D(nn.Module):
    r"""
    The `EncoderCausal3D` layer of a variational autoencoder that encodes its input into a latent representation.

    Pipeline: `conv_in` -> every block in `down_blocks` -> `mid_block` ->
    `conv_norm_out` (GroupNorm) -> `conv_act` (SiLU) -> `conv_out`.
    `conv_out` produces `2 * out_channels` channels when `double_z` is true,
    otherwise `out_channels`.

    Only `time_compression_ratio == 4` is supported; any other value raises
    `ValueError` while the down blocks are being built.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        double_z: bool = True,
        mid_block_add_attention=True,
        time_compression_ratio: int = 4,
        spatial_compression_ratio: int = 8,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = TimePaddingCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
        self.down_blocks = nn.ModuleList([])

        last_index = len(block_out_channels) - 1
        stage_width = block_out_channels[0]
        for idx, block_name in enumerate(down_block_types):
            incoming_width, stage_width = stage_width, block_out_channels[idx]

            spatial_stages = int(np.log2(spatial_compression_ratio))
            temporal_stages = int(np.log2(time_compression_ratio))
            if time_compression_ratio != 4:
                raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")

            # Spatial reduction happens in the first stages; temporal reduction
            # in the last stages before (and excluding) the final block.
            shrink_hw = bool(idx < spatial_stages)
            shrink_t = bool(idx >= (last_index - temporal_stages) and idx != last_index)
            stride = (2 if shrink_t else 1,) + ((2, 2) if shrink_hw else (1, 1))

            if block_name.startswith("UNetRes"):
                block_name = block_name[7:]
            if block_name != "DownEncoderBlockCausal3D":
                raise ValueError(f"{block_name} does not exist.")

            self.down_blocks.append(
                DownEncoderBlockCausal3D(
                    num_layers=self.layers_per_block,
                    in_channels=incoming_width,
                    out_channels=stage_width,
                    dropout=0.0,
                    add_downsample=bool(shrink_hw or shrink_t),
                    downsample_stride=stride,
                    resnet_eps=1e-6,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    downsample_padding=0,
                    resnet_time_scale_shift="default",
                )
            )

        deepest_width = block_out_channels[-1]
        self.mid_block = UNetMidBlockCausal3D(
            in_channels=deepest_width,
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default",
            attention_head_dim=deepest_width,
            resnet_groups=norm_num_groups,
            temb_channels=None,
            add_attention=mid_block_add_attention,
        )

        self.conv_norm_out = nn.GroupNorm(num_channels=deepest_width, num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()

        if double_z:
            conv_out_channels = 2 * out_channels
        else:
            conv_out_channels = out_channels
        self.conv_out = TimePaddingCausalConv3d(deepest_width, conv_out_channels, kernel_size=3)

    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        r"""The forward method of the `EncoderCausal3D` class."""
        features = self.conv_in(sample)

        for stage in self.down_blocks:
            features = stage(features)

        features = self.mid_block(features)

        # Output head: norm -> activation -> projection.
        features = self.conv_act(self.conv_norm_out(features))
        return self.conv_out(features)
class DecoderCausal3D(nn.Module):
    r"""
    The `DecoderCausal3D` layer of a variational autoencoder that decodes its latent representation into an output sample.

    Pipeline: `conv_in` -> `mid_block` -> every block in `up_blocks` ->
    `conv_norm_out` -> `conv_act` (SiLU) -> `conv_out`. `conv_norm_out` is a
    `SpatialNorm` when `norm_type == "spatial"` and a `GroupNorm` otherwise.

    Only `time_compression_ratio == 4` is supported; any other value raises
    `ValueError` while the up blocks are being built.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        norm_type: str = "group",
        mid_block_add_attention=True,
        time_compression_ratio: int = 4,
        spatial_compression_ratio: int = 8,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        deepest_width = block_out_channels[-1]
        self.conv_in = TimePaddingCausalConv3d(in_channels, deepest_width, kernel_size=3, stride=1)
        self.up_blocks = nn.ModuleList([])

        # An embedding width is only configured for the "spatial" norm type.
        temb_channels = in_channels if norm_type == "spatial" else None

        self.mid_block = UNetMidBlockCausal3D(
            in_channels=deepest_width,
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
            attention_head_dim=deepest_width,
            resnet_groups=norm_num_groups,
            temb_channels=temb_channels,
            add_attention=mid_block_add_attention,
        )

        widths = list(reversed(block_out_channels))
        last_index = len(block_out_channels) - 1
        stage_width = widths[0]
        for idx, block_name in enumerate(up_block_types):
            incoming_width, stage_width = stage_width, widths[idx]

            spatial_stages = int(np.log2(spatial_compression_ratio))
            temporal_stages = int(np.log2(time_compression_ratio))
            if time_compression_ratio != 4:
                raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")

            grow_hw = bool(idx < spatial_stages)
            grow_t = bool(idx >= last_index - temporal_stages and idx != last_index)
            scale_factor = (2 if grow_t else 1,) + ((2, 2) if grow_hw else (1, 1))

            if block_name.startswith("UNetRes"):
                block_name = block_name[7:]
            if block_name != "UpDecoderBlockCausal3D":
                raise ValueError(f"{block_name} does not exist.")

            self.up_blocks.append(
                UpDecoderBlockCausal3D(
                    num_layers=self.layers_per_block + 1,
                    in_channels=incoming_width,
                    out_channels=stage_width,
                    resolution_idx=None,
                    dropout=0.0,
                    add_upsample=bool(grow_hw or grow_t),
                    upsample_scale_factor=scale_factor,
                    resnet_eps=1e-6,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    resnet_time_scale_shift=norm_type,
                    temb_channels=temb_channels,
                )
            )

        shallowest_width = block_out_channels[0]
        if norm_type == "spatial":
            self.conv_norm_out = SpatialNorm(shallowest_width, temb_channels)
        else:
            self.conv_norm_out = nn.GroupNorm(num_channels=shallowest_width, num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = TimePaddingCausalConv3d(shallowest_width, out_channels, kernel_size=3)

        self.gradient_checkpointing = False

    def forward(
        self,
        sample: torch.FloatTensor,
        latent_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        r"""The forward method of the `DecoderCausal3D` class."""
        sample = self.conv_in(sample)
        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

        if self.training and self.gradient_checkpointing:

            def run(module, *inputs):
                # Wrap the module call so checkpoint() receives a plain function.
                def custom_forward(*args):
                    return module(*args)

                return torch.utils.checkpoint.checkpoint(custom_forward, *inputs)

        else:

            def run(module, *inputs):
                return module(*inputs)

        # middle
        sample = run(self.mid_block, sample, latent_embeds)
        sample = sample.to(upscale_dtype)

        # up
        for stage in self.up_blocks:
            sample = run(stage, sample, latent_embeds)

        # post-process: the norm only receives latent_embeds when they are given
        if latent_embeds is not None:
            sample = self.conv_norm_out(sample, latent_embeds)
        else:
            sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        return self.conv_out(sample)
class AutoencoderKLHunyuanVideo(MultiModalModule):
    r"""
    A VAE model with KL loss for encoding images/videos into latents and decoding latent representations into images/videos.

    The model wraps an `EncoderCausal3D` / `DecoderCausal3D` pair with 1x1x1
    `quant_conv` / `post_quant_conv` layers. Optional batch slicing, spatial
    tiling and temporal tiling can be toggled to process inputs piecewise.

    Args:
        from_pretrained: Optional checkpoint location handed to `load_checkpoint`.
        in_channels / out_channels: Channel counts of the encoder input and decoder output.
        down_block_types / up_block_types: Block type names for the encoder / decoder.
        block_out_channels: Per-stage channel widths.
        layers_per_block: Layers per encoder block (the decoder uses one more).
        act_fn: Activation name forwarded to encoder and decoder.
        latent_channels: Channel count of the latent.
        norm_num_groups: Group count for the group norms.
        sample_size: Spatial tile size in sample space. A list/tuple is reduced
            to its first element.
        sample_tsize: Temporal tile size in sample space.
        scaling_factor: Factor `encode` multiplies the latents by.
        force_upcast: Stored on the instance; not used inside this class.
        spatial_compression_ratio / time_compression_ratio: Forwarded to the
            encoder and decoder; the temporal ratio also sizes latent tiles.
        mid_block_add_attention: Whether mid blocks contain attention.
        enable_tiling: Turn on spatial and temporal tiling after construction.
        use_slicing / use_spatial_tiling / use_temporal_tiling: Initial state
            of the corresponding switches.
    """

    def __init__(
        self,
        from_pretrained: str = None,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = ("DownEncoderBlockCausal3D",),
        up_block_types: Tuple[str] = ("UpDecoderBlockCausal3D",),
        block_out_channels: Tuple[int] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        sample_size: int = 32,
        sample_tsize: int = 64,
        scaling_factor: float = 0.18215,
        force_upcast: bool = True,
        spatial_compression_ratio: int = 8,
        time_compression_ratio: int = 4,
        mid_block_add_attention: bool = True,
        enable_tiling: bool = False,
        use_slicing: bool = False,
        use_spatial_tiling: bool = False,
        use_temporal_tiling: bool = False,
        **kwargs
    ):
        super().__init__(config=None)

        self.time_compression_ratio = time_compression_ratio
        self.spatial_compression_ratio = spatial_compression_ratio

        self.encoder = EncoderCausal3D(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            double_z=True,
            time_compression_ratio=time_compression_ratio,
            spatial_compression_ratio=spatial_compression_ratio,
            mid_block_add_attention=mid_block_add_attention,
        )

        self.decoder = DecoderCausal3D(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
            time_compression_ratio=time_compression_ratio,
            spatial_compression_ratio=spatial_compression_ratio,
            mid_block_add_attention=mid_block_add_attention,
        )

        self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
        self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)

        self.use_slicing = use_slicing
        self.use_spatial_tiling = use_spatial_tiling
        self.use_temporal_tiling = use_temporal_tiling

        # Temporal tile sizes in sample space and in latent space.
        self.tile_sample_min_tsize = sample_tsize
        self.tile_latent_min_tsize = sample_tsize // time_compression_ratio

        # Reduce a list/tuple to a scalar BEFORE storing it: every user of
        # tile_sample_min_size compares it with an int or multiplies it by a
        # float, both of which fail for a list.
        if isinstance(sample_size, (list, tuple)):
            sample_size = sample_size[0]
        self.tile_sample_min_size = sample_size
        self.tile_latent_min_size = int(sample_size / (2 ** (len(block_out_channels) - 1)))
        self.tile_overlap_factor = 0.25

        self.scaling_factor = scaling_factor
        self.force_upcast = force_upcast

        if from_pretrained is not None:
            load_checkpoint(self, from_pretrained)

        if enable_tiling:
            self.enable_tiling()

    def _set_gradient_checkpointing(self, module, value=False):
        """Set `gradient_checkpointing` on encoder/decoder modules; ignore any other module."""
        if isinstance(module, (EncoderCausal3D, DecoderCausal3D)):
            module.gradient_checkpointing = value

    def enable_temporal_tiling(self, use_tiling: bool = True):
        """Switch temporal tiling on (default) or off."""
        self.use_temporal_tiling = use_tiling

    def disable_temporal_tiling(self):
        """Switch temporal tiling off."""
        self.enable_temporal_tiling(False)

    def enable_spatial_tiling(self, use_tiling: bool = True):
        """Switch spatial tiling on (default) or off."""
        self.use_spatial_tiling = use_tiling

    def disable_spatial_tiling(self):
        """Switch spatial tiling off."""
        self.enable_spatial_tiling(False)

    def enable_tiling(self, use_tiling: bool = True):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger videos.
        """
        self.enable_spatial_tiling(use_tiling)
        self.enable_temporal_tiling(use_tiling)

    def disable_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.disable_spatial_tiling()
        self.disable_temporal_tiling()

    def enable_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.use_slicing = True

    def disable_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.use_slicing = False

    def _encode_posterior(self, x: torch.FloatTensor) -> DiagonalGaussianDistribution:
        """Encode `x` and return the posterior, choosing the tiled/sliced path from the current switches.

        Temporal tiling takes precedence over spatial tiling, which takes
        precedence over batch slicing.
        """
        if self.use_temporal_tiling and x.shape[2] > self.tile_sample_min_tsize:
            return self.temporal_tiled_encode(x)

        if self.use_spatial_tiling and (
            x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size
        ):
            return self.spatial_tiled_encode(x)

        if self.use_slicing and x.shape[0] > 1:
            # Encode one batch element at a time to cap peak memory.
            h = torch.cat([self.encoder(x_slice) for x_slice in x.split(1)])
        else:
            h = self.encoder(x)

        moments = self.quant_conv(h)
        return DiagonalGaussianDistribution(moments)

    @apply_forward_hook
    def encode(
        self, x: torch.FloatTensor,
        do_sample: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> torch.FloatTensor:
        """
        Encode a batch of images/videos into latents.

        Args:
            x (`torch.FloatTensor`): Input batch of images/videos.
            do_sample (`bool`): When `True`, draw the latent with `posterior.sample()`;
                otherwise use `posterior.mode()`.
            generator (`torch.Generator`, *optional*): Accepted for interface
                compatibility. NOTE(review): it is not forwarded to
                `posterior.sample()` on this path.

        Returns:
            `torch.FloatTensor`: The latent multiplied by `self.scaling_factor`.
        """
        posterior = self._encode_posterior(x)
        z = posterior.sample() if do_sample else posterior.mode()
        return z * self.scaling_factor

    def _decode(self, z: torch.FloatTensor) -> torch.FloatTensor:
        """Decode `z`, using temporal or spatial tiling when enabled and the latent exceeds a tile."""
        if self.use_temporal_tiling and z.shape[2] > self.tile_latent_min_tsize:
            return self.temporal_tiled_decode(z)

        if self.use_spatial_tiling and (
            z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size
        ):
            return self.spatial_tiled_decode(z)

        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    @apply_forward_hook
    def decode(
        self, z: torch.FloatTensor, generator=None
    ) -> torch.FloatTensor:
        """
        Decode a batch of images/videos.

        Args:
            z (`torch.FloatTensor`): Input batch of latent vectors. The latent is
                decoded as given; `self.scaling_factor` is not undone here.
            generator: Unused.

        Returns:
            dec (`torch.FloatTensor`): Decode images/videos
        """
        if self.use_slicing and z.shape[0] > 1:
            decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
            decoded = torch.cat(decoded_slices)
        else:
            decoded = self._decode(z)
        return decoded

    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        """Linearly cross-fade the last rows of `a` into the first rows of `b` (dim -2).

        `b` is modified in place and returned.
        """
        blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
        for y in range(blend_extent):
            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
        return b

    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        """Linearly cross-fade the last columns of `a` into the first columns of `b` (dim -1).

        `b` is modified in place and returned.
        """
        blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
        return b

    def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        """Linearly cross-fade the last frames of `a` into the first frames of `b` (dim -3).

        `b` is modified in place and returned.
        """
        blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
        for x in range(blend_extent):
            b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (x / blend_extent)
        return b

    def _stitch_spatial_tiles(self, rows, blend_extent: int, row_limit: int) -> torch.Tensor:
        """Blend each tile with its upper and left neighbour, crop it, and concatenate the grid."""
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # Blend the tile above and the tile to the left into the
                # current tile (in place), then keep only the non-overlap part.
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=-1))
        return torch.cat(result_rows, dim=-2)

    def _stitch_temporal_tiles(self, tiles, blend_extent: int, t_limit: int) -> torch.Tensor:
        """Blend each temporal tile with its predecessor, crop it, and concatenate along dim 2."""
        result_row = []
        for i, tile in enumerate(tiles):
            if i > 0:
                tile = self.blend_t(tiles[i - 1], tile, blend_extent)
                result_row.append(tile[:, :, :t_limit, :, :])
            else:
                # The first tile keeps one extra leading frame.
                result_row.append(tile[:, :, :t_limit + 1, :, :])
        return torch.cat(result_row, dim=2)

    def spatial_tiled_encode(
        self, x: torch.FloatTensor, return_moments: bool = False
    ) -> Union[DiagonalGaussianDistribution, torch.FloatTensor]:
        r"""
        Encode a batch of images/videos using a tiled encoder.

        The input is split into overlapping spatial tiles that are encoded
        independently; the results are blended in the overlap regions and
        concatenated. The result of tiled encoding differs from non-tiled
        encoding because each tile is encoded without seeing its neighbours.

        Args:
            x (`torch.FloatTensor`): Input batch of images/videos.
            return_moments (`bool`, *optional*, defaults to `False`):
                Return the stitched moments tensor instead of wrapping it in a
                `DiagonalGaussianDistribution`.

        Returns:
            The posterior distribution, or the raw moments tensor when
            `return_moments` is `True`.
        """
        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
        row_limit = self.tile_latent_min_size - blend_extent

        # Split the input into overlapping tiles and encode each separately.
        rows = []
        for i in range(0, x.shape[-2], overlap_size):
            row = []
            for j in range(0, x.shape[-1], overlap_size):
                tile = x[:, :, :, i: i + self.tile_sample_min_size, j: j + self.tile_sample_min_size]
                tile = self.encoder(tile)
                tile = self.quant_conv(tile)
                row.append(tile)
            rows.append(row)

        moments = self._stitch_spatial_tiles(rows, blend_extent, row_limit)
        if return_moments:
            return moments
        return DiagonalGaussianDistribution(moments)

    def spatial_tiled_decode(self, z: torch.FloatTensor) -> torch.FloatTensor:
        r"""
        Decode a batch of images/videos using a tiled decoder.

        Args:
            z (`torch.FloatTensor`): Input batch of latent vectors.

        Returns:
            dec(`torch.FloatTensor`): Decode images/videos
        """
        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
        row_limit = self.tile_sample_min_size - blend_extent

        # Split the latent into overlapping tiles and decode each separately.
        rows = []
        for i in range(0, z.shape[-2], overlap_size):
            row = []
            for j in range(0, z.shape[-1], overlap_size):
                tile = z[:, :, :, i: i + self.tile_latent_min_size, j: j + self.tile_latent_min_size]
                tile = self.post_quant_conv(tile)
                row.append(self.decoder(tile))
            rows.append(row)

        return self._stitch_spatial_tiles(rows, blend_extent, row_limit)

    def temporal_tiled_encode(self, x: torch.FloatTensor) -> DiagonalGaussianDistribution:
        """Encode `x` in overlapping temporal tiles and return the blended posterior.

        Each tile spans `tile_sample_min_tsize + 1` frames; for every tile after
        the first, the first encoded frame is dropped before blending.
        """
        T = x.shape[2]
        overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor)
        t_limit = self.tile_latent_min_tsize - blend_extent

        row = []
        for i in range(0, T, overlap_size):
            tile = x[:, :, i: i + self.tile_sample_min_tsize + 1, :, :]
            if self.use_spatial_tiling and (
                tile.shape[-1] > self.tile_sample_min_size or tile.shape[-2] > self.tile_sample_min_size
            ):
                # spatial_tiled_encode already applies quant_conv.
                tile = self.spatial_tiled_encode(tile, return_moments=True)
            else:
                tile = self.encoder(tile)
                tile = self.quant_conv(tile)
            if i > 0:
                tile = tile[:, :, 1:, :, :]
            row.append(tile)

        moments = self._stitch_temporal_tiles(row, blend_extent, t_limit)
        return DiagonalGaussianDistribution(moments)

    def temporal_tiled_decode(self, z: torch.FloatTensor) -> torch.FloatTensor:
        """Decode `z` in overlapping temporal tiles and return the blended result.

        Each tile spans `tile_latent_min_tsize + 1` latent frames; for every
        tile after the first, the first decoded frame is dropped before blending.
        """
        T = z.shape[2]
        overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor)
        t_limit = self.tile_sample_min_tsize - blend_extent

        row = []
        for i in range(0, T, overlap_size):
            tile = z[:, :, i: i + self.tile_latent_min_tsize + 1, :, :]
            if self.use_spatial_tiling and (
                tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size
            ):
                decoded = self.spatial_tiled_decode(tile)
            else:
                tile = self.post_quant_conv(tile)
                decoded = self.decoder(tile)
            if i > 0:
                decoded = decoded[:, :, 1:, :, :]
            row.append(decoded)

        return self._stitch_temporal_tiles(row, blend_extent, t_limit)

    def forward(
        self,
        sample: torch.FloatTensor,
        do_sample: bool = False,
        generator: Optional[torch.Generator] = None,
    ) -> torch.FloatTensor:
        r"""
        Encode `sample` to a latent and decode it back.

        Args:
            sample (`torch.FloatTensor`): Input sample.
            do_sample (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior; otherwise its mode is used.
            generator (`torch.Generator`, *optional*):
                Forwarded to `posterior.sample` as `generator=` when not `None`.
                NOTE(review): assumes `DiagonalGaussianDistribution.sample`
                accepts that keyword - confirm.

        Returns:
            `torch.FloatTensor`: The reconstruction.
        """
        # `encode` returns an already sampled and scaled tensor, not an object
        # with a `latent_dist`, so the posterior is obtained directly. The
        # latent is passed to `decode` unscaled because `decode` does not undo
        # `scaling_factor`.
        posterior = self._encode_posterior(sample)
        if do_sample:
            if generator is None:
                z = posterior.sample()
            else:
                z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec