import math
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
def get_3d_sincos_pos_embed(
    embed_dim: int,
    spatial_size: Union[int, Tuple[int, int]],
    temporal_size: int,
    spatial_interpolation_scale: float = 1.0,
    temporal_interpolation_scale: float = 1.0,
) -> np.ndarray:
    r"""
    Build a fixed 3D (temporal + spatial) sine-cosine positional embedding table.

    Along the last axis the result holds `embed_dim // 4` temporal channels followed by `3 * embed_dim // 4`
    spatial channels.

    Args:
        embed_dim (`int`):
            Total number of channels per position. Must be divisible by 16: the spatial part (3/4 of the
            channels) is halved once per spatial axis and once more into sine/cosine features.
        spatial_size (`int` or `Tuple[int, int]`):
            Size of the spatial grid. An `int` is used for both axes. For a tuple, spatial positions are laid
            out row-major with the `spatial_size[0]` axis varying fastest, i.e. the tuple is read as
            `(width, height)` for tokens flattened as `h * width + w`.
        temporal_size (`int`):
            Number of temporal positions.
        spatial_interpolation_scale (`float`, defaults to 1.0):
            Divisor applied to the integer spatial coordinates before they are encoded.
        temporal_interpolation_scale (`float`, defaults to 1.0):
            Divisor applied to the integer temporal coordinates before they are encoded.

    Returns:
        `np.ndarray`: Array of shape `[temporal_size, spatial_size[0] * spatial_size[1], embed_dim]`.

    Raises:
        ValueError: If `embed_dim` is not divisible by 4, or is divisible by 4 but not by 16.
    """
    if embed_dim % 4 != 0:
        raise ValueError("`embed_dim` must be divisible by 4")
    # Fail early with an accurate message: the nested 2D/1D helpers would otherwise reject such sizes with
    # a misleading "divisible by 2" error after the grids have already been built.
    if embed_dim % 16 != 0:
        raise ValueError(
            "`embed_dim` must be divisible by 16 so that the spatial (3/4) and temporal (1/4) parts "
            "split evenly into sine/cosine halves"
        )
    if isinstance(spatial_size, int):
        spatial_size = (spatial_size, spatial_size)

    embed_dim_spatial = 3 * embed_dim // 4
    embed_dim_temporal = embed_dim // 4

    # 1. Spatial table. `meshgrid` uses its default "xy" indexing, so the width coordinates go first and
    # the resulting planes have shape [spatial_size[1], spatial_size[0]].
    grid_h = np.arange(spatial_size[1], dtype=np.float32) / spatial_interpolation_scale
    grid_w = np.arange(spatial_size[0], dtype=np.float32) / spatial_interpolation_scale
    grid = np.meshgrid(grid_w, grid_h)
    grid = np.stack(grid, axis=0)
    grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
    pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)

    # 2. Temporal table.
    grid_t = np.arange(temporal_size, dtype=np.float32) / temporal_interpolation_scale
    pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)

    # 3. Broadcast both tables to [T, H*W, *] and join them on the channel axis.
    pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
    pos_embed_spatial = np.repeat(pos_embed_spatial, temporal_size, axis=0)  # [T, H*W, 3 * D // 4]

    pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
    pos_embed_temporal = np.repeat(
        pos_embed_temporal, spatial_size[0] * spatial_size[1], axis=1
    )  # [T, H*W, D // 4]

    pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)  # [T, H*W, D]
    return pos_embed
def get_2d_sincos_pos_embed(
    embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
):
    """
    Build a 2D sine-cosine positional embedding table for a (height, width) grid.

    grid_size: int (used for both axes) or a (height, width) pair. Coordinates along each axis are rescaled by
    `base_size / axis_length` and then divided by `interpolation_scale` before being encoded.

    return: pos_embed: [grid_size[0]*grid_size[1], embed_dim], or [extra_tokens+grid_size[0]*grid_size[1], embed_dim]
    with `extra_tokens` all-zero rows prepended when `cls_token` is set and `extra_tokens > 0`.
    """
    if isinstance(grid_size, int):
        grid_size = (grid_size, grid_size)
    n_rows, n_cols = grid_size[0], grid_size[1]

    row_coords = np.arange(n_rows, dtype=np.float32) / (n_rows / base_size) / interpolation_scale
    col_coords = np.arange(n_cols, dtype=np.float32) / (n_cols / base_size) / interpolation_scale

    # Width coordinates go first (default "xy" indexing). The reshape lists the two spatial axes swapped,
    # but it leaves the element order untouched and the 2D helper flattens each coordinate plane anyway.
    mesh = np.stack(np.meshgrid(col_coords, row_coords), axis=0)
    mesh = mesh.reshape([2, 1, n_cols, n_rows])

    table = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if not (cls_token and extra_tokens > 0):
        return table

    # Reserve zero embeddings at the front for the extra (e.g. class) tokens.
    zero_rows = np.zeros([extra_tokens, embed_dim])
    return np.concatenate([zero_rows, table], axis=0)
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Encode two coordinate planes with 1D sine-cosine embeddings and join them on the channel axis.

    embed_dim: output dimension per position (must be even); each plane receives embed_dim // 2 channels
    grid: indexable whose items 0 and 1 are the coordinate arrays; channels for grid[0] come first
    out: (N, embed_dim), where N is the number of elements in each coordinate array
    """
    half_dim, remainder = divmod(embed_dim, 2)
    if remainder != 0:
        raise ValueError("embed_dim must be divisible by 2")

    per_plane = [
        get_1d_sincos_pos_embed_from_grid(half_dim, coords) for coords in (grid[0], grid[1])
    ]
    return np.concatenate(per_plane, axis=1)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Encode scalar positions with sine/cosine features at geometrically spaced frequencies.

    embed_dim: output dimension for each position (must be even)
    pos: positions to be encoded; any array-like (ndarray, list, tuple or scalar) of any shape, flattened to (M,)
    out: (M, D) array whose first D/2 columns are the sines and last D/2 columns the cosines

    Raises ValueError if embed_dim is odd.
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    # Frequencies 1 / 10000^(2i / D) for i = 0 .. D/2 - 1.
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    # np.asarray is a no-op for ndarrays and lets plain Python sequences/scalars through as well.
    pos = np.asarray(pos).reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
class PatchEmbed2D_3DsincosPE(nn.Module):
    """Per-frame 2D patch embedding with an optional fixed 3D sin-cos positional embedding.

    Every frame of the input is patchified independently by a strided `nn.Conv2d`. With
    `pos_embed_type="sincos"` a joint temporal + spatial table built by `get_3d_sincos_pos_embed` is
    added to the resulting tokens.

    Args:
        height, width: reference input size; divided by `patch_size` to size the precomputed table.
        frame: reference number of frames; divided by `t_patch_size` to size the precomputed table.
        t_patch_size: only used to derive the reference frame count (the projection itself is 2D).
        patch_size: kernel size and stride of the patch projection.
        in_channels: number of channels of each input frame.
        embed_dim: number of output channels per token.
        layer_norm: if True, apply a non-affine `LayerNorm(embed_dim)` after the projection.
        flatten: if True, flatten the projected feature map into a token sequence.
        bias: whether the projection has a bias term.
        interpolation_scale: divisor applied to the spatial coordinates of the positional table.
        time_interpolation_scale: divisor applied to the temporal coordinates of the positional table.
        pos_embed_type: `"sincos"` or `None` (no positional embedding).

    Raises:
        ValueError: if `pos_embed_type` is neither `None` nor `"sincos"`.
    """

    def __init__(
        self,
        height=64,
        width=64,
        frame=1,
        t_patch_size=1,
        patch_size=2,
        in_channels=4,
        embed_dim=1152,
        layer_norm=False,
        flatten=True,
        bias=True,
        interpolation_scale=1,
        time_interpolation_scale=1,
        pos_embed_type="sincos",
    ):
        super().__init__()

        self.flatten = flatten
        self.layer_norm = layer_norm

        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
        )
        if layer_norm:
            self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        else:
            self.norm = None

        self.patch_size = patch_size
        # Reference grid, in patches / frames, that the precomputed table is built for.
        self.height, self.width = height // patch_size, width // patch_size
        self.frame = frame // t_patch_size
        self.base_size = height // patch_size
        self.interpolation_scale = interpolation_scale
        self.time_interpolation_scale = time_interpolation_scale

        if pos_embed_type is None:
            self.pos_embed = None
        elif pos_embed_type == "sincos":
            # `get_3d_sincos_pos_embed` lays spatial positions out with `spatial_size[0]` varying fastest,
            # while the tokens from `flatten(2)` have width varying fastest, so the helper must receive
            # (width, height). Square grids are unaffected by the order.
            pos_embed = get_3d_sincos_pos_embed(
                embed_dim,
                spatial_size=(self.width, self.height),
                temporal_size=self.frame,
                spatial_interpolation_scale=self.interpolation_scale,
                temporal_interpolation_scale=self.time_interpolation_scale,
            )
            self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0))
        else:
            raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")

    def forward(self, latent):
        """Embed `latent` of shape (batch, frame, channels, height, width).

        Returns a tensor of shape (batch * frame, tokens, embed_dim) in the dtype of the input when
        `flatten` is True.
        """
        d_dtype = latent.dtype
        batch_size, frame, dim, height, width = latent.shape

        # Fold frames into the batch axis so that each frame is patchified on its own.
        latent = self.proj(latent.reshape(-1, dim, height, width))
        height, width = height // self.patch_size, width // self.patch_size

        if self.flatten:
            latent = latent.flatten(2).transpose(1, 2)  # (B*F, C, H, W) -> (B*F, H*W, C)
        if self.layer_norm:
            latent = self.norm(latent)
        if self.pos_embed is None:
            # Hand back the caller's dtype, as the positional path below does.
            return latent.to(d_dtype)

        embed_dim = self.pos_embed.shape[-1]
        if self.height != height or self.width != width or self.frame != frame:
            # The input grid differs from the precomputed one: rebuild the table for this call,
            # again in (width, height) order to match the token layout.
            pos_embed = get_3d_sincos_pos_embed(
                embed_dim=embed_dim,
                spatial_size=(width, height),
                temporal_size=frame,
                spatial_interpolation_scale=self.interpolation_scale,
                temporal_interpolation_scale=self.time_interpolation_scale,
            )
            pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0)
        else:
            pos_embed = self.pos_embed

        # NOTE(review): this reshape assumes channels-last tokens, i.e. `flatten=True`; with
        # `flatten=False` the channel axis is not last here - confirm that combination is never used.
        latent = latent.reshape(batch_size, frame, -1, embed_dim).float()
        pos_embed = pos_embed.to(latent.device)
        # Add in float32, then return to the input dtype.
        latent = (latent + pos_embed).to(d_dtype)
        return latent.reshape(batch_size * frame, -1, embed_dim)