import os
from collections import deque
import torch.nn as nn
import torch
from megatron.core import mpu
from mindspeed_mm.models.common.checkpoint import load_checkpoint
from mindspeed_mm.models.common.module import MultiModalModule
from mindspeed_mm.models.common.normalize import normalize
from mindspeed_mm.models.common.attention import WfCausalConv3dAttnBlock
from mindspeed_mm.models.common.distrib import DiagonalGaussianDistribution
from mindspeed_mm.models.common.wavelet import (
HaarWaveletTransform2D,
HaarWaveletTransform3D,
InverseHaarWaveletTransform2D,
InverseHaarWaveletTransform3D
)
from mindspeed_mm.models.common.conv import Conv2d, WfCausalConv3d
from mindspeed_mm.models.common.resnet_block import ResnetBlock2D, ResnetBlock3D
from mindspeed_mm.models.common.updownsample import (
Upsample,
Downsample,
Spatial2xTime2x3DDownsample,
CachedCausal3DUpsample
)
from mindspeed_mm.utils.utils import (
is_context_parallel_initialized,
initialize_context_parallel,
get_context_parallel_group,
get_context_parallel_rank,
)
from mindspeed_mm.models.common.communications import split_tensor, collect_tensors_across_ranks
WFVAE_MODULE_MAPS = {
"Conv2d": Conv2d,
"ResnetBlock2D": ResnetBlock2D,
"CausalConv3d": WfCausalConv3d,
"AttnBlock3D": WfCausalConv3dAttnBlock,
"ResnetBlock3D": ResnetBlock3D,
"Downsample": Downsample,
"HaarWaveletTransform2D": HaarWaveletTransform2D,
"HaarWaveletTransform3D": HaarWaveletTransform3D,
"InverseHaarWaveletTransform2D": InverseHaarWaveletTransform2D,
"InverseHaarWaveletTransform3D": InverseHaarWaveletTransform3D,
"Spatial2xTime2x3DDownsample": Spatial2xTime2x3DDownsample,
"Upsample": Upsample,
"Spatial2xTime2x3DUpsample": CachedCausal3DUpsample,
"WfCausalConv3dAttnBlock": WfCausalConv3dAttnBlock,
}
def model_name_to_cls(model_name):
if model_name in WFVAE_MODULE_MAPS:
return WFVAE_MODULE_MAPS[model_name]
else:
raise ValueError(f"Model name {model_name} not supported")
class Encoder(MultiModalModule):
def __init__(
self,
latent_dim: int = 8,
base_channels: int = 128,
num_resblocks: int = 2,
energy_flow_hidden_size: int = 64,
dropout: float = 0.0,
use_attention: bool = True,
atten_block: str = "WfCausalConv3dAttnBlock",
norm_type: str = "groupnorm",
l1_dowmsample_block: str = "Downsample",
l1_downsample_wavelet: str = "HaarWaveletTransform2D",
l2_dowmsample_block: str = "Spatial2xTime2x3DDownsample",
l2_downsample_wavelet: str = "HaarWaveletTransform3D",
enable_vae_cp: bool = False,
training: bool = False
) -> None:
super().__init__(config=None)
self.activation = nn.SiLU()
self.down1 = nn.Sequential(
Conv2d(24, base_channels, kernel_size=3, stride=1, padding=1),
*[
ResnetBlock2D(
in_channels=base_channels,
out_channels=base_channels,
dropout=dropout,
norm_type=norm_type,
)
for _ in range(num_resblocks)
],
model_name_to_cls(l1_dowmsample_block)(
in_channels=base_channels,
out_channels=base_channels,
conv_type="WfCausalConv3d"
),
)
self.down2 = nn.Sequential(
Conv2d(
base_channels + energy_flow_hidden_size,
base_channels * 2,
kernel_size=3,
stride=1,
padding=1,
),
*[
ResnetBlock3D(
in_channels=base_channels * 2,
out_channels=base_channels * 2,
dropout=dropout,
norm_type=norm_type,
conv_type="WfCausalConv3d",
enable_vae_cp=enable_vae_cp
)
for _ in range(num_resblocks)
],
model_name_to_cls(l2_dowmsample_block)(base_channels * 2, base_channels * 2, conv_type="WfCausalConv3d",
enable_vae_cp=enable_vae_cp),
)
if l1_downsample_wavelet.endswith("2D"):
l1_channels = 12
elif l1_downsample_wavelet.endswith("3D"):
l1_channels = 24
else:
raise ValueError('l1_downsample_wavelet only Support `HaarWaveletTransform2D` and `HaarWaveletTransform3D`')
self.connect_l1 = Conv2d(
l1_channels, energy_flow_hidden_size, kernel_size=3, stride=1, padding=1
)
self.connect_l2 = Conv2d(
24, energy_flow_hidden_size, kernel_size=3, stride=1, padding=1
)
mid_layers = [
ResnetBlock3D(
in_channels=base_channels * 2 + energy_flow_hidden_size,
out_channels=base_channels * 4,
dropout=dropout,
norm_type=norm_type,
conv_type="WfCausalConv3d",
enable_vae_cp=enable_vae_cp
),
ResnetBlock3D(
in_channels=base_channels * 4,
out_channels=base_channels * 4,
dropout=dropout,
norm_type=norm_type,
conv_type="WfCausalConv3d",
enable_vae_cp=enable_vae_cp
),
]
if use_attention:
mid_layers.insert(
1, model_name_to_cls(atten_block)(
in_channels=base_channels * 4,
out_channels=base_channels * 4,
norm_type=norm_type)
)
self.mid = nn.Sequential(*mid_layers)
self.norm_out = normalize(base_channels * 4, norm_type=norm_type)
self.conv_out = WfCausalConv3d(
base_channels * 4, latent_dim * 2, kernel_size=3, stride=1, padding=1, parallel_input=False
)
if not training:
self.wavelet_tranform_in = model_name_to_cls(l2_downsample_wavelet)()
self.wavelet_transform_l1 = model_name_to_cls(l1_downsample_wavelet)()
self.wavelet_transform_l2 = model_name_to_cls(l2_downsample_wavelet)()
def forward(self, coeffs):
if not self.training:
coeffs = split_tensor(coeffs, get_context_parallel_group(), get_context_parallel_rank(), first_padding=3)
coeffs = self.wavelet_tranform_in(coeffs)
else:
wt = HaarWaveletTransform3D().to(coeffs.device, dtype=coeffs.dtype)
coeffs = wt(coeffs)
l1_coeffs = coeffs[:, :3]
l1_coeffs = self.wavelet_transform_l1(l1_coeffs)
l1 = self.connect_l1(l1_coeffs)
l2_coeffs = self.wavelet_transform_l2(l1_coeffs[:, :3])
l2 = self.connect_l2(l2_coeffs)
h = self.down1(coeffs)
h = torch.concat([h, l1], dim=1)
h = self.down2(h)
h = torch.concat([h, l2], dim=1)
h = self.mid(h)
h = self.norm_out(h)
h = self.activation(h)
h = torch.cat(collect_tensors_across_ranks(h, get_context_parallel_group()), dim=2)
h = self.conv_out(h)
return h
class Decoder(MultiModalModule):
def __init__(
self,
latent_dim: int = 8,
base_channels: int = 128,
num_resblocks: int = 2,
dropout: float = 0.0,
energy_flow_hidden_size: int = 128,
use_attention: bool = True,
atten_block: str = "WfCausalConv3dAttnBlock",
norm_type: str = "groupnorm",
t_interpolation: str = "nearest",
connect_res_layer_num: int = 2,
l1_upsample_block: str = "Upsample",
l1_upsample_wavelet: str = "InverseHaarWaveletTransform2D",
l2_upsample_block: str = "Spatial2xTime2x3DUpsample",
l2_upsample_wavelet: str = "InverseHaarWaveletTransform3D",
enable_vae_cp: bool = False,
training: bool = False
) -> None:
super().__init__(config=None)
self.energy_flow_hidden_size = energy_flow_hidden_size
self.activation = nn.SiLU()
self.conv_in = WfCausalConv3d(
latent_dim, base_channels * 4, kernel_size=3, stride=1, padding=1
)
mid_layers = [
ResnetBlock3D(
in_channels=base_channels * 4,
out_channels=base_channels * 4,
dropout=dropout,
norm_type=norm_type,
conv_type="WfCausalConv3d"
),
ResnetBlock3D(
in_channels=base_channels * 4,
out_channels=base_channels * 4 + energy_flow_hidden_size,
dropout=dropout,
norm_type=norm_type,
conv_type="WfCausalConv3d"
),
]
if use_attention:
mid_layers.insert(
1, model_name_to_cls(atten_block)(
in_channels=base_channels * 4,
out_channels=base_channels * 4,
norm_type=norm_type)
)
self.mid = nn.Sequential(*mid_layers)
self.up2 = nn.Sequential(
*[
ResnetBlock3D(
in_channels=base_channels * 4,
out_channels=base_channels * 4,
dropout=dropout,
norm_type=norm_type,
conv_type="WfCausalConv3d"
)
for _ in range(num_resblocks)
],
model_name_to_cls(l2_upsample_block)(
base_channels * 4, base_channels * 4, t_interpolation=t_interpolation
),
ResnetBlock3D(
in_channels=base_channels * 4,
out_channels=base_channels * 4 + energy_flow_hidden_size,
dropout=dropout,
norm_type=norm_type,
conv_type="WfCausalConv3d"
),
)
self.up1 = nn.Sequential(
*[
ResnetBlock3D(
in_channels=base_channels * (4 if i == 0 else 2),
out_channels=base_channels * 2,
dropout=dropout,
norm_type=norm_type,
conv_type="WfCausalConv3d"
)
for i in range(num_resblocks)
],
model_name_to_cls(l1_upsample_block)(in_channels=base_channels * 2, out_channels=base_channels * 2),
ResnetBlock3D(
in_channels=base_channels * 2,
out_channels=base_channels * 2,
dropout=dropout,
norm_type=norm_type,
conv_type="WfCausalConv3d",
enable_vae_cp=enable_vae_cp
),
)
self.layer = nn.Sequential(
*[
ResnetBlock3D(
in_channels=base_channels * (2 if i == 0 else 1),
out_channels=base_channels,
dropout=dropout,
norm_type=norm_type,
conv_type="WfCausalConv3d"
)
for i in range(2)
],
)
if l1_upsample_wavelet.endswith("2D"):
l1_channels = 12
elif l1_upsample_wavelet.endswith("3D"):
l1_channels = 24
else:
raise ValueError('l1_upsample_wavelet only Support `InverseHaarWaveletTransform2D` and `InverseHaarWaveletTransform3D`')
self.connect_l1 = nn.Sequential(
*[
ResnetBlock3D(
in_channels=energy_flow_hidden_size,
out_channels=energy_flow_hidden_size,
dropout=dropout,
norm_type=norm_type,
conv_type="WfCausalConv3d",
enable_vae_cp=enable_vae_cp
)
for _ in range(connect_res_layer_num)
],
Conv2d(energy_flow_hidden_size, l1_channels, kernel_size=3, stride=1, padding=1),
)
self.connect_l2 = nn.Sequential(
*[
ResnetBlock3D(
in_channels=energy_flow_hidden_size,
out_channels=energy_flow_hidden_size,
dropout=dropout,
norm_type=norm_type,
conv_type="WfCausalConv3d"
)
for _ in range(connect_res_layer_num)
],
Conv2d(energy_flow_hidden_size, 24, kernel_size=3, stride=1, padding=1),
)
self.norm_out = normalize(base_channels, norm_type=norm_type)
self.conv_out = Conv2d(base_channels, 24, kernel_size=3, stride=1, padding=1)
if not training:
self.inverse_wavelet_transform_out = model_name_to_cls(l2_upsample_wavelet)()
self.inverse_wavelet_transform_l1 = model_name_to_cls(l1_upsample_wavelet)()
self.inverse_wavelet_transform_l2 = model_name_to_cls(l2_upsample_wavelet)()
def forward(self, z):
h = self.conv_in(z)
h = self.mid(h)
l2_coeffs = self.connect_l2(h[:, -self.energy_flow_hidden_size:])
l2 = self.inverse_wavelet_transform_l2(l2_coeffs)
h = self.up2(h[:, : -self.energy_flow_hidden_size])
l1_coeffs = h[:, -self.energy_flow_hidden_size:]
l1_coeffs = self.connect_l1(l1_coeffs)
l1_coeffs[:, :3] = l1_coeffs[:, :3] + l2
l1 = self.inverse_wavelet_transform_l1(l1_coeffs)
h = self.up1(h[:, : -self.energy_flow_hidden_size])
h = self.layer(h)
h = self.norm_out(h)
h = self.activation(h)
h = self.conv_out(h)
h[:, :3] = h[:, :3] + l1
if not self.training:
dec = self.inverse_wavelet_transform_out(h)
else:
wt = InverseHaarWaveletTransform3D().to(h.device, dtype=h.dtype)
dec = wt(h)
return dec
class WFVAE(MultiModalModule):
def __init__(
self,
from_pretrained: str = None,
latent_dim: int = 8,
base_channels: int = 128,
encoder_num_resblocks: int = 2,
encoder_energy_flow_hidden_size: int = 64,
decoder_num_resblocks: int = 2,
decoder_energy_flow_hidden_size: int = 128,
use_attention: bool = True,
atten_block: str = "WfCausalConv3dAttnBlock",
dropout: float = 0.0,
norm_type: str = "groupnorm",
t_interpolation: str = "nearest",
vae_scale_factor: list = None,
use_tiling: bool = False,
connect_res_layer_num: int = 2,
l1_dowmsample_block: str = "Downsample",
l1_downsample_wavelet: str = "HaarWaveletTransform2D",
l2_dowmsample_block: str = "Spatial2xTime2x3DDownsample",
l2_downsample_wavelet: str = "HaarWaveletTransform3D",
l1_upsample_block: str = "Upsample",
l1_upsample_wavelet: str = "InverseHaarWaveletTransform2D",
l2_upsample_block: str = "Spatial2xTime2x3DUpsample",
l2_upsample_wavelet: str = "InverseHaarWaveletTransform3D",
scale: list = None,
shift: list = None,
dtype: str = "fp32",
vae_cp_size: int = 1,
t_chunk_enc: int = 8,
t_chunk_dec: int = 2,
t_upsample_times=2,
**kwargs
) -> None:
super().__init__(config=None)
self.t_chunk_enc = t_chunk_enc
self.t_chunk_dec = t_chunk_dec
self.t_upsample_times = t_upsample_times
self.use_quant_layer = False
self.vae_scale_factor = vae_scale_factor
self.vae_cp_size = vae_cp_size
if self.vae_cp_size > 0:
if not is_context_parallel_initialized():
initialize_context_parallel(self.vae_cp_size)
self.enable_vae_cp = self.vae_cp_size > 1
if self.enable_vae_cp and mpu.get_pipeline_model_parallel_world_size() > 1:
raise ValueError("VAE-CP can not be enabled with PP.")
self.encoder = Encoder(
latent_dim=latent_dim,
base_channels=base_channels,
num_resblocks=encoder_num_resblocks,
energy_flow_hidden_size=encoder_energy_flow_hidden_size,
dropout=dropout,
use_attention=use_attention,
norm_type=norm_type,
atten_block=atten_block,
l1_dowmsample_block=l1_dowmsample_block,
l1_downsample_wavelet=l1_downsample_wavelet,
l2_dowmsample_block=l2_dowmsample_block,
l2_downsample_wavelet=l2_downsample_wavelet,
enable_vae_cp=self.enable_vae_cp,
training=kwargs.get("training", False)
)
self.decoder = Decoder(
latent_dim=latent_dim,
base_channels=base_channels,
num_resblocks=decoder_num_resblocks,
energy_flow_hidden_size=decoder_energy_flow_hidden_size,
dropout=dropout,
use_attention=use_attention,
norm_type=norm_type,
t_interpolation=t_interpolation,
connect_res_layer_num=connect_res_layer_num,
atten_block=atten_block,
l1_upsample_block=l1_upsample_block,
l1_upsample_wavelet=l1_upsample_wavelet,
l2_upsample_block=l2_upsample_block,
l2_upsample_wavelet=l2_upsample_wavelet,
training=kwargs.get("training", False)
)
self.set_parameters(dtype, scale, shift)
self.configure_cache_offset(l1_dowmsample_block)
if from_pretrained is not None:
load_checkpoint(self, from_pretrained)
self.enable_tiling(use_tiling)
if self.enable_vae_cp:
self.dp_group_nums = torch.distributed.get_world_size() // mpu.get_data_parallel_world_size()
def set_parameters(self, dtype, scale, shift):
self.dtype = dtype
if scale is not None and shift is not None:
self.register_buffer('shift', torch.tensor(shift, dtype=self.dtype)[None, :, None, None, None])
self.register_buffer('scale', torch.tensor(scale, dtype=self.dtype)[None, :, None, None, None])
else:
self.register_buffer("shift", torch.zeros(1, 8, 1, 1, 1))
self.register_buffer("scale", torch.tensor([0.18215] * 8)[None, :, None, None, None])
def configure_cache_offset(self, l1_dowmsample_block):
if l1_dowmsample_block == "Downsample":
self.temporal_uptimes = 4
self._set_cache_offset([self.decoder.up2, self.decoder.connect_l2, self.decoder.conv_in, self.decoder.mid],
1)
self._set_cache_offset(
[self.decoder.up2[-2:], self.decoder.up1, self.decoder.connect_l1, self.decoder.layer],
self.t_upsample_times)
else:
self.temporal_uptimes = 8
self._set_cache_offset([self.decoder.up2, self.decoder.connect_l2, self.decoder.conv_in, self.decoder.mid],
1)
self._set_cache_offset([self.decoder.up2[-2:], self.decoder.connect_l1, self.decoder.up1], 2)
self._set_cache_offset([self.decoder.up1[-2:], self.decoder.layer], 4)
def get_encoder(self):
if self.use_quant_layer:
return [self.quant_conv, self.encoder]
return [self.encoder]
def get_decoder(self):
if self.use_quant_layer:
return [self.post_quant_conv, self.decoder]
return [self.decoder]
def _empty_causal_cached(self, parent):
for _, module in parent.named_modules():
if hasattr(module, 'causal_cached'):
module.causal_cached = deque()
def _set_first_chunk(self, is_first_chunk=True):
for module in self.modules():
if hasattr(module, 'is_first_chunk'):
module.is_first_chunk = is_first_chunk
def _set_causal_cached(self, enable_cached=True):
for _, module in self.named_modules():
if hasattr(module, 'enable_cached'):
module.enable_cached = enable_cached
def _set_cache_offset(self, modules, cache_offset=0):
for module in modules:
for submodule in module.modules():
if hasattr(submodule, 'cache_offset'):
submodule.cache_offset = cache_offset
def build_chunk_start_end(self, t, decoder_mode=False):
start_end = [[0, 1]]
start = 1
end = start
while True:
if start >= t:
break
end = min(t, end + (self.t_chunk_dec if decoder_mode else self.t_chunk_enc))
start_end.append([start, end])
start = end
return start_end
def encode(self, x):
if self.enable_vae_cp and self.vae_cp_size % self.dp_group_nums == 0 and self.vae_cp_size > self.dp_group_nums:
data_list = [torch.empty_like(x) for _ in range(self.vae_cp_size)]
data_list[get_context_parallel_rank()] = x
torch.distributed.all_gather(data_list, x, group=get_context_parallel_group())
data_list = data_list[::self.dp_group_nums]
latents = []
for data in data_list:
latents.append(self._encode(data))
return latents[get_context_parallel_rank() // self.dp_group_nums]
else:
return self._encode(x)
def _encode(self, x):
self._empty_causal_cached(self.encoder)
self._set_first_chunk(True)
dtype = x.dtype
if self.use_tiling:
h = self.tile_encode(x)
else:
h = self.encoder(x)
if self.use_quant_layer:
h = self.quant_conv(h)
self._empty_causal_cached(self.encoder)
posterior = DiagonalGaussianDistribution(h)
if not self.training:
return (posterior.sample() - self.shift.to(x.device, dtype=dtype)) * self.scale.to(x.device, dtype)
else:
return posterior
def tile_encode(self, x):
b, c, t, h, w = x.shape
start_end = self.build_chunk_start_end(t)
result = []
for idx, (start, end) in enumerate(start_end):
self._set_first_chunk(idx == 0)
chunk = x[:, :, start:end, :, :]
chunk = self.encoder(chunk)
if self.use_quant_layer:
chunk = self.quant_conv(chunk)
result.append(chunk)
return torch.cat(result, dim=2)
def decode(self, z):
self._empty_causal_cached(self.decoder)
self._set_first_chunk(True)
if not self.training:
z = z / self.scale.to(z.device, dtype=z.dtype) + self.shift.to(z.device, dtype=z.dtype)
if self.use_tiling:
dec = self.tile_decode(z)
else:
if self.use_quant_layer:
z = self.post_quant_conv(z)
dec = self.decoder(z)
return dec
def tile_decode(self, x):
b, c, t, h, w = x.shape
start_end = self.build_chunk_start_end(t, decoder_mode=True)
result = []
for idx, (start, end) in enumerate(start_end):
self._set_first_chunk(idx == 0)
if idx != 0 and end + 1 < t:
chunk = x[:, :, start:end + 1, :, :]
else:
chunk = x[:, :, start:end, :, :]
if self.use_quant_layer:
chunk = self.post_quant_conv(chunk)
chunk = self.decoder(chunk)
if idx != 0 and end + 1 < t:
chunk = chunk[:, :, :-self.temporal_uptimes]
result.append(chunk.clone())
else:
result.append(chunk.clone())
return torch.cat(result, dim=2)
def forward(self, x, sample_posterior=True):
posterior = self.encode(x)
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
dec = self.decode(z)
return dec, posterior
def get_last_layer(self):
if hasattr(self.decoder.conv_out, "conv"):
return self.decoder.conv_out.conv.weight
else:
return self.decoder.conv_out.weight
def enable_tiling(self, use_tiling: bool = False):
self.use_tiling = use_tiling
self._set_causal_cached(use_tiling)
def disable_tiling(self):
self.enable_tiling(False)