import torch
from torch import Tensor, nn
from mindspeed_mm.models.ae.movqvae import Encoder
from mindspeed_mm.models.common.checkpoint import load_checkpoint
from mindspeed_mm.models.common.distrib import DiagonalGaussianDistribution
class FluxVae(nn.Module):
    """Encoder-only VAE wrapper that maps images to scaled latents.

    The module owns a single ``Encoder`` and two scalars, ``scale_factor`` and
    ``shift_factor``, which :meth:`encode` applies to the latent as
    ``scale_factor * (z - shift_factor)``. No decoder is built here.
    """

    def __init__(
        self,
        resolution=256,
        in_channels=3,
        ch=128,
        out_ch=3,
        ch_mult=None,
        num_res_blocks=2,
        z_channels=16,
        scale_factor=0.3611,
        shift_factor=0.1159,
        attn_resolutions=None,
        use_sdp_attention=True,
        from_pretrained: str = None,
        **kwargs
    ):
        """Build the encoder and optionally load pretrained weights.

        Args:
            resolution, in_channels, ch, out_ch, num_res_blocks, z_channels,
            use_sdp_attention: forwarded unchanged to ``Encoder``.
            ch_mult: forwarded to ``Encoder``; ``None`` selects ``[1, 2, 4, 4]``.
            attn_resolutions: forwarded to ``Encoder``; ``None`` selects ``[0]``.
            scale_factor: multiplier applied to the shifted latent in ``encode``.
            shift_factor: offset subtracted from the latent in ``encode``.
            from_pretrained: if not ``None``, handed to ``load_checkpoint``
                together with this module after construction.
            **kwargs: accepted and ignored, so extra config keys do not raise.
        """
        super().__init__()
        # Resolve list defaults here rather than in the signature so that no
        # mutable default object is shared between instances.
        if ch_mult is None:
            ch_mult = [1, 2, 4, 4]
        if attn_resolutions is None:
            attn_resolutions = [0]
        self.encoder = Encoder(
            resolution=resolution,
            in_channels=in_channels,
            ch=ch,
            out_ch=out_ch,
            ch_mult=ch_mult,
            num_res_blocks=num_res_blocks,
            z_channels=z_channels,
            attn_resolutions=attn_resolutions,
            use_sdp_attention=use_sdp_attention,
        )
        self.scale_factor = scale_factor
        self.shift_factor = shift_factor
        if from_pretrained is not None:
            load_checkpoint(self, from_pretrained)

    @staticmethod
    def _select_input(x, kwargs):
        """Pick the encoder input from ``x`` or the ``images``/``padded_images`` keys.

        Precedence is ``x``, then ``kwargs['images']``, then
        ``kwargs['padded_images']``; the first one that is not ``None`` wins.

        Raises:
            ValueError: if none of the three candidates is provided.
        """
        if x is not None:
            return x
        # Compare against None explicitly: `a or b` would evaluate bool(a),
        # which raises RuntimeError for tensors with more than one element.
        for key in ('images', 'padded_images'):
            candidate = kwargs.get(key)
            if candidate is not None:
                return candidate
        raise ValueError(
            "FluxVae.encode requires `x`, `images` or `padded_images`."
        )

    def encode(self, x: Tensor = None, **kwargs) -> Tensor:
        """Encode the input and return the shifted, scaled latent.

        Args:
            x: encoder input. When ``None``, the input is taken from
                ``kwargs['images']`` or, failing that, ``kwargs['padded_images']``.
            **kwargs: only ``images`` and ``padded_images`` are consulted.

        Returns:
            ``scale_factor * (mode - shift_factor)``, where ``mode`` is
            ``DiagonalGaussianDistribution(encoder(x)).mode()``.

        Raises:
            ValueError: if no input is supplied through any of the three routes.
        """
        x = self._select_input(x, kwargs)
        moments = self.encoder(x)
        posterior = DiagonalGaussianDistribution(moments)
        # NOTE(review): mode() is presumably the deterministic mean of the
        # posterior (no sampling) - confirm against DiagonalGaussianDistribution.
        latent = posterior.mode()
        return self.scale_factor * (latent - self.shift_factor)