import torch
from torch import nn, Tensor, einsum, IntTensor, FloatTensor, BoolTensor
from torch.nn import Module
import torch.nn.functional as F
import random
from beartype import beartype
from beartype.typing import Tuple, Optional, List, Union
from einops.layers.torch import Rearrange
from einops import rearrange, repeat, reduce, pack, unpack
from modules.audio2motion.cfm.utils import exists, identity, default, divisible_by, is_odd, coin_flip, pack_one, unpack_one
from modules.audio2motion.cfm.utils import prob_mask_like, reduce_masks_with_and, interpolate_1d, curtail_or_pad, mask_from_start_end_indices, mask_from_frac_lengths
from modules.audio2motion.cfm.module import ConvPositionEmbed, LearnedSinusoidalPosEmb, Transformer
from torch.cuda.amp import autocast
class InContextTransformerAudio2Motion(Module):
    """Audio-conditioned transformer with in-context conditioning.

    For every frame the model input is the concatenation of three
    ``dim_in``-sized streams: the input ``x`` (kept only where the conditioning
    mask is True), the projected audio features, and the context ``cond`` (kept
    only where the conditioning mask is False).  The concatenation is projected
    to ``dim``, receives a residual convolutional position embedding, is run
    through a transformer whose adaptive RMSNorm is conditioned on a time
    embedding, and is finally projected back to ``dim_in``.
    """

    def __init__(
        self,
        *,
        dim_in = 64,
        dim_audio_in = 1024,
        dim = 1024,
        depth = 24,
        dim_head = 64,
        heads = 16,
        ff_mult = 4,
        ff_dropout = 0.,
        time_hidden_dim = None,
        conv_pos_embed_kernel_size = 31,
        conv_pos_embed_groups = None,
        attn_dropout = 0,
        attn_flash = False,
        attn_qk_norm = True,
        use_gateloop_layers = False,
        num_register_tokens = 16,
        frac_lengths_mask: Tuple[float, float] = (0.7, 1.),
    ):
        """Build all sub-modules.

        Args:
            dim_in: Feature size of ``x`` / ``cond`` and of the prediction.
                Falls back to ``dim`` when ``None``.
            dim_audio_in: Feature size of ``cond_audio``.  A linear projection
                to ``dim_in`` is created when it differs from ``dim_in``.
            dim: Width of the transformer.
            depth, dim_head, heads, ff_mult, ff_dropout, attn_dropout,
            attn_flash, attn_qk_norm, use_gateloop_layers, num_register_tokens:
                Forwarded unchanged to ``Transformer``.
            time_hidden_dim: Width of the time embedding fed to the adaptive
                RMSNorm.  Defaults to ``dim * 4``.
            conv_pos_embed_kernel_size, conv_pos_embed_groups: Forwarded to
                ``ConvPositionEmbed``.
            frac_lengths_mask: ``(low, high)`` range used when a conditioning
                mask has to be sampled during training.
        """
        super().__init__()

        dim_in = default(dim_in, dim)
        time_hidden_dim = default(time_hidden_dim, dim * 4)

        # Inputs are consumed as-is; the attribute is a hook for a future projection.
        self.proj_in = nn.Identity()

        # Scalar time -> embedding used to condition every transformer layer.
        self.sinu_pos_emb = nn.Sequential(
            LearnedSinusoidalPosEmb(dim),
            nn.Linear(dim, time_hidden_dim),
            nn.SiLU()
        )

        # Bring the audio features to the same width as the motion features.
        self.dim_audio_in = dim_audio_in
        if self.dim_audio_in != dim_in:
            self.to_cond_emb = nn.Linear(self.dim_audio_in, dim_in)
        else:
            self.to_cond_emb = nn.Identity()

        self.frac_lengths_mask = frac_lengths_mask

        # Three dim_in-wide streams (x, audio embedding, cond) are concatenated.
        self.to_embed = nn.Linear(dim_in * 2 + dim_in, dim)

        # Learned stand-in for ``cond`` on samples whose conditioning is dropped.
        self.null_cond = nn.Parameter(torch.zeros(dim_in))

        self.conv_embed = ConvPositionEmbed(
            dim = dim,
            kernel_size = conv_pos_embed_kernel_size,
            groups = conv_pos_embed_groups
        )

        self.transformer = Transformer(
            dim = dim,
            depth = depth,
            dim_head = dim_head,
            heads = heads,
            ff_mult = ff_mult,
            ff_dropout = ff_dropout,
            attn_dropout= attn_dropout,
            attn_flash = attn_flash,
            attn_qk_norm = attn_qk_norm,
            num_register_tokens = num_register_tokens,
            adaptive_rmsnorm = True,
            adaptive_rmsnorm_cond_dim_in = time_hidden_dim,
            use_gateloop_layers = use_gateloop_layers
        )

        dim_out = dim_in
        self.to_pred = nn.Linear(dim, dim_out, bias = False)

    @property
    def device(self):
        """Device of the first registered parameter."""
        return next(self.parameters()).device

    @torch.inference_mode()
    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 1.,
        **kwargs
    ):
        """Run inference with classifier-free-guidance style scaling.

        Args:
            *args, **kwargs: Forwarded to ``forward`` (``cond_drop_prob`` is
                supplied by this method and must not be passed).
            cond_scale: Guidance weight.  ``1.`` returns the conditioned
                prediction directly and skips the second forward pass.

        Returns:
            ``null + (conditioned - null) * cond_scale`` where ``conditioned``
            uses ``cond_drop_prob = 0.`` and ``null`` uses
            ``cond_drop_prob = 1.``.
        """
        logits = self.forward(*args, cond_drop_prob = 0., **kwargs)

        if cond_scale == 1.:
            return logits

        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale

    def _default_cond_mask(self, batch, seq_len, device):
        """Create the conditioning mask used when the caller supplies none.

        In eval mode every position is True.  In training mode a coin flip
        chooses between ``mask_from_frac_lengths`` (with per-sample fractions
        drawn uniformly from ``frac_lengths_mask``) and ``prob_mask_like``
        (with a single probability drawn from the same range).
        """
        if not self.training:
            return torch.ones((batch, seq_len), device = device, dtype = torch.bool)

        if coin_flip():
            frac_lengths = torch.zeros((batch,), device = device).float().uniform_(*self.frac_lengths_mask)
            return mask_from_frac_lengths(seq_len, frac_lengths)

        p_drop_prob_ = self.frac_lengths_mask[0] + random.random()*(self.frac_lengths_mask[1]-self.frac_lengths_mask[0])
        return prob_mask_like((batch, seq_len), p_drop_prob_, device)

    @staticmethod
    def _masked_mse(pred, target, loss_mask):
        """Mean over the batch of the per-sample MSE averaged over masked frames.

        Args:
            pred, target: Tensors of shape ``(batch, frames, features)``.
            loss_mask: Boolean ``(batch, frames)`` tensor; only True frames
                contribute.

        Returns:
            Scalar tensor.  A sample whose mask is entirely False contributes 0.
        """
        per_frame = F.mse_loss(pred, target, reduction = 'none').mean(dim = -1)
        per_frame = per_frame.masked_fill(~loss_mask, 0.)

        num = per_frame.sum(dim = -1)
        # Cast the integer count to float *before* clamping so that the epsilon
        # is not truncated to zero, which would yield 0 / 0 for an empty mask.
        den = loss_mask.sum(dim = -1).float().clamp(min = 1e-5)

        return (num / den).mean()

    def forward(
        self,
        x,
        *,
        times,
        cond_audio,
        self_attn_mask = None,
        cond_drop_prob = 0.1,
        target = None,
        cond = None,
        cond_mask = None,
        ret=None
    ):
        """Predict the output features, or the masked MSE loss when ``target`` is given.

        Args:
            x: Input of shape ``(batch, seq_len, dim_in)``.
            times: Time value(s); a 0-d tensor or a 1-element 1-d tensor is
                broadcast to the batch size.
            cond_audio: Audio features whose last dimension is
                ``dim_audio_in``.  If its length differs from ``seq_len`` it is
                resampled with ``interpolate_1d``.
            self_attn_mask: Optional mask handed to the transformer and
                combined with ``cond_mask`` for the loss.
            cond_drop_prob: Passed to ``prob_mask_like`` to draw a per-sample
                mask; selected samples have ``cond`` replaced by ``null_cond``.
            target: Optional regression target with the prediction's shape.
            cond: Optional context of the same shape as ``x``; ``x`` itself is
                used when omitted.
            cond_mask: Optional boolean ``(batch, seq_len)`` mask.  True marks
                positions where ``x`` is kept, ``cond`` is zeroed and the loss
                is evaluated.  Generated automatically when omitted.
            ret: Optional dict that receives ``'pred'`` and, on the masked
                loss path, ``'loss_mask'`` and ``'mse'``.

        Returns:
            The prediction of shape ``(batch, seq_len, dim_in)`` when
            ``target`` is ``None``; otherwise a scalar loss.
        """
        if ret is None:
            ret = {}

        x = self.proj_in(x)
        if exists(cond):
            cond = self.proj_in(cond)
        cond = default(cond, x)

        batch, seq_len, cond_dim = cond.shape
        assert cond_dim == x.shape[-1]

        # Broadcast a single time value across the batch.
        if times.ndim == 0:
            times = repeat(times, '-> b', b = cond.shape[0])
        if times.ndim == 1 and times.shape[0] == 1:
            times = repeat(times, '1 -> b', b = cond.shape[0])

        # Masks are created on the input's device rather than via ``self.device``
        # so this also works when the module itself exposes no parameters.
        if not exists(cond_mask):
            cond_mask = self._default_cond_mask(batch, seq_len, cond.device)

        cond_mask_with_pad_dim = rearrange(cond_mask, '... -> ... 1')

        # x and cond are complementary: x survives on masked positions,
        # cond survives on the remaining ones.
        x = x * cond_mask_with_pad_dim
        cond = cond * ~cond_mask_with_pad_dim

        # Replace the context by the learned null vector on selected samples.
        if cond_drop_prob > 0.:
            cond_drop_mask = prob_mask_like(cond.shape[:1], cond_drop_prob, cond.device)
            cond = torch.where(
                rearrange(cond_drop_mask, '... -> ... 1 1'),
                self.null_cond,
                cond
            )

        cond_audio_emb = self.to_cond_emb(cond_audio)
        cond_audio_emb_length = cond_audio_emb.shape[-2]

        if cond_audio_emb_length != seq_len:
            # Resample the audio features along time to match the sequence.
            cond_audio_emb = rearrange(cond_audio_emb, 'b n d -> b d n')
            cond_audio_emb = interpolate_1d(cond_audio_emb, seq_len)
            cond_audio_emb = rearrange(cond_audio_emb, 'b d n -> b n d')

            # The attention mask is resampled together with the audio features.
            if exists(self_attn_mask):
                self_attn_mask = interpolate_1d(self_attn_mask, seq_len)

        to_concat = [*filter(exists, (x, cond_audio_emb, cond))]
        embed = torch.cat(to_concat, dim = -1)

        x = self.to_embed(embed)
        x = self.conv_embed(x) + x

        time_emb = self.sinu_pos_emb(times)

        x = self.transformer(
            x,
            mask = self_attn_mask,
            adaptive_rmsnorm_cond = time_emb
        )

        x = self.to_pred(x)
        ret['pred'] = x

        if not exists(target):
            return x

        loss_mask = reduce_masks_with_and(cond_mask, self_attn_mask)

        if not exists(loss_mask):
            return F.mse_loss(x, target)

        ret['loss_mask'] = loss_mask

        loss = self._masked_mse(x, target, loss_mask)
        ret['mse'] = loss
        return loss
if __name__ == '__main__':
    # Smoke test: build the default model, push one random batch through the
    # guided-inference entry point and print the shape of the result.
    batch_size, num_frames = 2, 125

    net = InContextTransformerAudio2Motion()

    motion = torch.randn(batch_size, num_frames, 64)
    t = torch.rand(batch_size)
    audio = torch.rand(batch_size, num_frames, 1024)

    # The same motion tensor serves as both the input and the in-context condition.
    guided = net.forward_with_cond_scale(
        motion,
        times=t,
        cond_audio=audio,
        cond=motion,
    )
    print(guided.shape)