"""
Lookup Free Quantization
Proposed in https://arxiv.org/abs/2310.05737
basically a 2-level FSQ (Finite Scalar Quantization) with entropy loss
https://arxiv.org/abs/2309.15505
"""
import torch
from einops import rearrange
from torch.nn import Module
def binary_entropy(prob):
    """Elementwise entropy (natural log) of a Bernoulli variable.

    ``prob`` is treated as the probability of the "on" outcome. The clamped
    ``log`` helper keeps the result finite when ``prob`` is exactly 0 or 1.
    """
    complement = 1 - prob
    on_term = prob * log(prob)
    off_term = complement * log(complement)
    return -on_term - off_term
def log(t, eps=1e-20):
    """Natural logarithm of ``t`` after flooring its values at ``eps``.

    The floor avoids ``-inf`` / ``nan`` results for zero or negative entries.
    """
    floored = torch.clamp(t, min=eps)
    return torch.log(floored)
def decimal_to_bits(x: torch.LongTensor, bits: int) -> torch.FloatTensor:
    """Expand integer indices into their ``bits``-wide -1 / +1 bit codes.

    A trailing dimension of size ``bits`` is appended; position ``i`` holds
    +1 when bit ``2 ** i`` of the index is set and -1 otherwise (least
    significant bit first).
    """
    # Build the place values on the default device, then match x's dtype/device.
    place_values = (2 ** torch.arange(bits)).to(x)
    is_set = torch.bitwise_and(x.unsqueeze(-1), place_values).ne(0)
    # Map {False, True} -> {-1., +1.}
    return is_set.float() * 2 - 1
def bits_to_decimal(x: torch.FloatTensor) -> torch.LongTensor:
    """Collapse the last dimension of a bit-code tensor into integer indices.

    Entries strictly greater than zero count as a set bit; position ``i`` of
    the last dimension carries weight ``2 ** i`` (least significant first).
    """
    positive = torch.gt(x, 0).long()
    width = positive.shape[-1]
    weights = (2 ** torch.arange(width)).to(positive)
    return torch.sum(positive * weights, dim=-1)
class LFQY(Module):
    """Lookup-free quantizer: every feature channel becomes a -1 / +1 bit.

    Each of the ``dim`` channels is quantized independently, so the implicit
    codebook has ``2 ** dim`` entries. A code's integer index is formed by
    ``bits_to_decimal`` (channel ``i`` carries weight ``2 ** i``).

    Args:
        dim: Size of the last (feature) dimension, i.e. bits per code.
        entropy_loss_weight: Scale applied to the auxiliary entropy loss.
        diversity_gamma: Weight of the codebook-entropy term inside the
            auxiliary loss.
    """

    def __init__(self, dim, entropy_loss_weight=0.1, diversity_gamma=1.0):
        super().__init__()
        self.dim = dim
        self.diversity_gamma = diversity_gamma
        self.entropy_loss_weight = entropy_loss_weight

    def indices_to_codes(self, indices):
        """Return the -1 / +1 codes (last dim ``self.dim``) for integer indices."""
        return decimal_to_bits(indices, self.dim)

    def forward(self, x, mask=None, inv_temperature=1.):
        """Quantize ``x`` to binary codes with a straight-through gradient.

        einstein notation
        b - batch
        n - sequence (or flattened spatial dimensions)
        d - feature dimension, which is also log2(codebook size)

        Args:
            x: Tensor whose last dimension equals ``self.dim``.
            mask: Accepted but currently ignored by this implementation.
            inv_temperature: Value ``x`` is divided by before the ``tanh`` /
                ``sigmoid`` squashing (note: used as a divisor).

        Returns:
            A tuple ``(z, entropy_aux_loss, indices)`` where ``z`` has the
            values of the -1 / +1 codes but the gradient of the ``tanh``
            relaxation, ``entropy_aux_loss`` is already scaled by
            ``entropy_loss_weight`` (zeros outside training mode), and
            ``indices`` are the integer code indices.
        """
        assert x.shape[-1] == self.dim, (
            f"expected last dimension {self.dim}, got {x.shape[-1]}"
        )

        soft = torch.tanh(x / inv_temperature)
        # Threshold with `> 0` instead of torch.sign: sign(0) is 0, which is
        # not a valid code and would disagree with the index encoding, where
        # a non-positive entry is bit 0, i.e. code -1.
        quantized = (x > 0).to(x.dtype) * 2 - 1
        # Straight-through estimator: forward value is `quantized`, gradient
        # flows through `soft`.
        z = soft + (quantized - soft).detach()
        indices = bits_to_decimal(quantized)

        if self.training:
            prob = torch.sigmoid(x / inv_temperature)
            bit_entropy = binary_entropy(prob).sum(-1).mean()
            # Average the per-bit probabilities over every leading dimension.
            avg_prob = prob.flatten(0, -2).mean(0)
            codebook_entropy = binary_entropy(avg_prob).sum()
            # 1. entropy will be nudged to be low for each bit,
            #    so each scalar commits to one latent binary bit or the other.
            # 2. codebook entropy will be nudged to be high,
            #    to encourage all codes to be uniformly used.
            entropy_aux_loss = bit_entropy - self.diversity_gamma * codebook_entropy
        else:
            entropy_aux_loss = torch.zeros(1).to(z)

        entropy_aux_loss = entropy_aux_loss * self.entropy_loss_weight
        return z, entropy_aux_loss, indices

    def get_codebook_entry(self, encoding_indices):
        """Alias of ``indices_to_codes``."""
        return self.indices_to_codes(encoding_indices)