from einops import rearrange, repeat
import torch
import torch.nn as nn
from mindspeed_mm.models.common.checkpoint import load_checkpoint
from mindspeed_mm.models.common.resnet_block import ResnetBlock2D
from mindspeed_mm.models.common.attention import Conv2dAttnBlock
from mindspeed_mm.models.common.normalize import normalize
from mindspeed_mm.models.common.activations import Sigmoid
class MOVQ(nn.Module):
    """Image encoder plus vector quantizer that maps images to codebook indices.

    Only the encoding path is implemented here: ``encode`` runs ``Encoder``,
    projects its output into the codebook space with ``quant_conv`` and asks
    ``VectorQuantizer`` for the nearest codebook entry of every spatial position.

    Args:
        from_pretrained: Optional checkpoint passed to ``load_checkpoint`` once
            all sub-modules have been built.
        double_z: Forwarded to ``Encoder``; when True the encoder's ``conv_out``
            emits ``2 * z_channels`` channels.
        z_channels: Latent channel count produced by the encoder (per half when
            ``double_z`` is set).
        resolution, in_channels, out_ch, ch, ch_mult, num_res_blocks,
        attn_resolutions, dropout: Forwarded unchanged to ``Encoder``.
        n_embed: Number of codebook entries.
        embed_dim: Dimensionality of each codebook entry.
        **kwargs: Accepted and ignored so extra configuration keys do not break
            construction.
    """

    def __init__(
        self,
        from_pretrained: str = None,
        double_z=False,
        z_channels=4,
        resolution=256,
        in_channels=3,
        out_ch=3,
        ch=256,
        ch_mult=None,
        num_res_blocks=2,
        attn_resolutions=None,
        dropout=0.0,
        n_embed=16384,
        embed_dim=4,
        **kwargs
    ):
        super().__init__()
        self.encoder = Encoder(
            double_z=double_z,
            z_channels=z_channels,
            resolution=resolution,
            in_channels=in_channels,
            out_ch=out_ch,
            ch=ch,
            ch_mult=ch_mult,
            num_res_blocks=num_res_blocks,
            attn_resolutions=attn_resolutions,
            dropout=dropout,
        )
        self.quantize = VectorQuantizer(n_embed, embed_dim)
        # Encoder.conv_out emits 2 * z_channels when double_z is set, so the
        # projection must accept that many channels; otherwise encode() would
        # fail with a channel mismatch.
        encoder_out_channels = 2 * z_channels if double_z else z_channels
        self.quant_conv = torch.nn.Conv2d(encoder_out_channels, embed_dim, 1)
        # Not used by encode(). NOTE(review): presumably kept so checkpoints that
        # contain it load cleanly - confirm before removing.
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)
        if from_pretrained is not None:
            load_checkpoint(self, from_pretrained)

    def encode(self, x):
        """Encode an image batch into codebook indices.

        Args:
            x: Image batch; ``VectorQuantizer`` requires the projected latent to
                be 4-D (``b c h w``), so ``x`` is expected to be ``(B, C, H, W)``.

        Returns:
            The 1-D index tensor returned by ``VectorQuantizer.forward``: one
            index per latent position, flattened in ``b h w`` order.
        """
        h = self.encoder(x)
        h = self.quant_conv(h)
        info = self.quantize(h)
        return info
class Encoder(nn.Module):
    """Convolutional down-sampling encoder.

    Layout: ``conv_in`` -> ``len(ch_mult)`` resolution levels -> middle stack
    (``ResnetBlock2D`` / ``Conv2dAttnBlock`` / ``ResnetBlock2D``) ->
    ``norm_out`` -> ``nonlinearity`` -> ``conv_out``.

    Each level holds ``num_res_blocks`` ``ResnetBlock2D`` modules. If the
    level's tracked resolution is in ``attn_resolutions``, each ResNet block is
    followed by a ``Conv2dAttnBlock``. Every level except the last ends with a
    ``Downsample``.

    Args:
        ch: Base channel count; level ``i`` outputs ``ch * ch_mult[i]`` channels.
        out_ch: Not used by the encoder; accepted for config compatibility.
        ch_mult: Per-level channel multipliers (default ``(1, 2, 4, 8)``).
        num_res_blocks: ``ResnetBlock2D`` modules per level.
        attn_resolutions: Resolutions at which attention is inserted (default
            ``[32]``). The resolution starts at ``resolution`` and is halved
            after every down-sampled level.
        dropout: Dropout passed to every ``ResnetBlock2D``.
        resamp_with_conv: Forwarded to ``Downsample`` (strided conv when True,
            average pooling when False).
        in_channels: Channel count of the input.
        resolution: Nominal input size; used to decide where attention goes and
            stored as ``self.resolution``.
        z_channels: Latent channel count; ``conv_out`` emits twice this when
            ``double_z`` is True.
        double_z: See ``z_channels``.
        use_sdp_attention: Forwarded to every ``Conv2dAttnBlock``.
        **ignore_kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``ch_mult`` is empty (no level could be built).
    """

    def __init__(self, *, ch, out_ch, ch_mult=None, num_res_blocks=2,
                 attn_resolutions=None, dropout=0.0, resamp_with_conv=True, in_channels=3,
                 resolution=256, z_channels=4, double_z=True, use_sdp_attention=False,
                 **ignore_kwargs):
        super().__init__()
        if ch_mult is None:
            ch_mult = (1, 2, 4, 8)
        if attn_resolutions is None:
            attn_resolutions = [32]
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        if self.num_resolutions == 0:
            # Without at least one level there is no channel count for the
            # middle stack (block_in would never be assigned).
            raise ValueError("ch_mult must contain at least one channel multiplier")
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        # Level i consumes the previous level's width; the first level consumes ch.
        in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock2D(in_channels=block_in, out_channels=block_out, dropout=dropout, act_type="swish"))
                block_in = block_out
                if curr_res in attn_resolutions:
                    # One attention block per ResNet block; forward() indexes
                    # attn with the same i_block as block.
                    attn.append(Conv2dAttnBlock(block_in, block_in, use_sdp_attention=use_sdp_attention))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock2D(in_channels=block_in, out_channels=block_in, dropout=dropout, act_type="swish")
        self.mid.attn_1 = Conv2dAttnBlock(block_in, block_in, use_sdp_attention=use_sdp_attention)
        self.mid.block_2 = ResnetBlock2D(in_channels=block_in, out_channels=block_in, dropout=dropout, act_type="swish")

        self.norm_out = normalize(block_in)
        # NOTE(review): reference VQGAN encoders use swish (x * sigmoid(x)) here;
        # confirm that the project's Sigmoid implements that, not a plain sigmoid.
        self.nonlinearity = Sigmoid()
        self.conv_out = nn.Conv2d(block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Run the encoder and return the ``conv_out`` feature map.

        The output has ``2 * z_channels`` channels if ``double_z`` was set,
        otherwise ``z_channels``. It is spatially down-sampled once per level
        except the last.
        """
        # Keep only the current activation. The encoder has no skip
        # connections, so retaining every intermediate (as a list would) only
        # raises peak memory.
        h = self.conv_in(x)
        for i_level, level in enumerate(self.down):
            for i_block in range(self.num_res_blocks):
                h = level.block[i_block](h)
                if len(level.attn) > 0:
                    h = level.attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = level.downsample(h)

        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        h = self.norm_out(h)
        h = self.nonlinearity(h)
        h = self.conv_out(h)
        return h
class Downsample(nn.Module):
    """Halve the spatial size of a 4-D feature map.

    When ``with_conv`` is True, one zero column is padded on the right and one
    zero row on the bottom, then a learned 3x3 stride-2 convolution is applied.
    Otherwise a parameter-free 2x2 average pool is used.
    """

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if not with_conv:
            return
        self.conv = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=2, padding=0
        )

    def forward(self, x):
        functional = torch.nn.functional
        if not self.with_conv:
            return functional.avg_pool2d(x, kernel_size=2, stride=2)
        # Conv2d's own ``padding`` is symmetric, so the asymmetric
        # (right / bottom only) padding is applied manually.
        right_bottom = (0, 1, 0, 1)
        padded = functional.pad(x, right_bottom, mode="constant", value=0)
        return self.conv(padded)
class VectorQuantizer(nn.Module):
    """Nearest-neighbour lookup into a learned codebook.

    Holds an ``n_e x e_dim`` embedding table. For every spatial position of an
    input feature map it returns the index of the closest codebook row under
    squared Euclidean distance.
    """

    def __init__(self, n_e, e_dim):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.embedding = nn.Embedding(n_e, e_dim)
        # Re-initialise the codebook uniformly in [-1/n_e, 1/n_e).
        bound = 1.0 / n_e
        self.embedding.weight.data.uniform_(-bound, bound)

    def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
        """Return the nearest-codebook index for each position of ``z``.

        Args:
            z: Feature map in ``b c h w`` layout (enforced by the rearrange
                pattern); ``c`` is expected to equal ``e_dim``.
            temp, rescale_logits, return_logits: Accepted for interface
                compatibility; not used.

        Returns:
            1-D index tensor of length ``b * h * w`` (``torch.argmin`` output),
            flattened in ``b h w`` order.
        """
        codebook = self.embedding.weight
        # Channels-last, so that each row of ``flat`` is one spatial position.
        channels_last = rearrange(z, 'b c h w -> b h w c').contiguous()
        flat = channels_last.view(-1, self.e_dim)
        # Expanded squared distance ||z||^2 + ||e||^2 - 2 z.e, shape (positions, n_e).
        z_norm_sq = torch.sum(flat ** 2, dim=1, keepdim=True)
        code_norm_sq = torch.sum(codebook ** 2, dim=1)
        cross = torch.einsum('bd,dn->bn', flat, codebook.t())
        distances = z_norm_sq + code_norm_sq - 2 * cross
        return torch.argmin(distances, dim=1)