import math
from math import pi
import functools
from typing import Optional, List
from beartype import beartype
from beartype.typing import Literal, Union, Optional
from einops import rearrange, repeat
import numpy as np
import torch
from torch import nn, einsum, broadcast_tensors, Tensor
from torch.cuda.amp import autocast
import torch.nn.functional as F
from mindspeed.ops.npu_rotary_position_embedding import npu_rotary_position_embedding
def get_3d_sincos_pos_embed(
    embed_dim,
    grid_size,
    cls_token=False,
    extra_tokens=0,
    interpolation_scale=(1.0, 1.0, 1.0),
    base_size=None,
) -> np.array:
    """
    Build a fixed sin/cos positional-embedding table for a 3D grid.

    embed_dim: output dimension for each position
    grid_size: indexable of three sizes; entries 0, 1, 2 are used as the
        lengths of the t, h and w coordinate axes
    cls_token, extra_tokens: when cls_token is truthy and extra_tokens > 0,
        extra_tokens all-zero rows are prepended to the table
    interpolation_scale: per-axis divisors applied to the raw coordinates
    base_size: optional per-axis sizes; when given, axis i is additionally
        multiplied by base_size[i] / grid_size[i]
    return:
    pos_embed: [T*H*W, embed_dim] or [extra_tokens + T*H*W, embed_dim]
    """
    coords = [
        np.arange(grid_size[axis], dtype=np.float32) / interpolation_scale[axis]
        for axis in range(3)
    ]
    if base_size is not None:
        for axis in range(3):
            coords[axis] *= base_size[axis] / grid_size[axis]
    grid_t, grid_h, grid_w = coords

    # The mesh is built with the w axis first, then h, then t.
    mesh = np.stack(np.meshgrid(grid_w, grid_h, grid_t), axis=0)
    mesh = mesh.reshape([3, 1, grid_size[2], grid_size[1], grid_size[0]])
    table = get_3d_sincos_pos_embed_from_grid(embed_dim, mesh)

    if not (cls_token and extra_tokens > 0):
        return table
    zero_rows = np.zeros([extra_tokens, embed_dim])
    return np.concatenate([zero_rows, table], axis=0)
def get_3d_sincos_pos_embed_from_grid(embed_dim, grid) -> np.array:
    """
    Encode three coordinate planes and join them along the feature axis.

    embed_dim: output dimension for each position; must be divisible by 3
    grid: indexable whose entries 0, 1 and 2 are the coordinate arrays; each
        one is encoded with embed_dim // 3 channels, in that order
    return: array whose second axis is the concatenation of the three encodings
    """
    if embed_dim % 3 != 0:
        raise ValueError("embed_dim must be divisible by 3")
    channels_per_axis = embed_dim // 3
    encoded = [
        get_1d_sincos_pos_embed_from_grid(channels_per_axis, grid[axis])
        for axis in range(3)
    ]
    return np.concatenate(encoded, axis=1)
def get_2d_sincos_pos_embed(
    embed_dim,
    grid_size,
    cls_token=False,
    extra_tokens=0,
    interpolation_scale=(1.0, 1.0),
    base_size=None,
) -> np.array:
    """
    Build a fixed sin/cos positional-embedding table for a 2D grid.

    embed_dim: output dimension for each position
    grid_size: int (used for both axes) or an indexable of two sizes whose
        entries 0 and 1 are the lengths of the h and w coordinate axes
    cls_token, extra_tokens: when cls_token is truthy and extra_tokens > 0,
        extra_tokens all-zero rows are prepended to the table
    interpolation_scale: per-axis divisors applied to the raw coordinates
    base_size: optional per-axis sizes; when given, axis i is additionally
        multiplied by base_size[i] / grid_size[i]
    return:
    pos_embed: [H*W, embed_dim] or [extra_tokens + H*W, embed_dim]
    """
    if isinstance(grid_size, int):
        grid_size = (grid_size, grid_size)

    coords = [
        np.arange(grid_size[axis], dtype=np.float32) / interpolation_scale[axis]
        for axis in range(2)
    ]
    if base_size is not None:
        for axis in range(2):
            coords[axis] *= base_size[axis] / grid_size[axis]
    grid_h, grid_w = coords

    # The mesh is built with the w axis first, then h.
    mesh = np.stack(np.meshgrid(grid_w, grid_h), axis=0)
    mesh = mesh.reshape([2, 1, grid_size[1], grid_size[0]])
    table = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)

    if not (cls_token and extra_tokens > 0):
        return table
    zero_rows = np.zeros([extra_tokens, embed_dim])
    return np.concatenate([zero_rows, table], axis=0)
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid) -> np.array:
    """
    Encode two coordinate planes and join them along the feature axis.

    embed_dim: output dimension for each position; must be divisible by 2
    grid: indexable whose entries 0 and 1 are the coordinate arrays; each one
        is encoded with embed_dim // 2 channels, in that order
    return: array whose second axis is the concatenation of both encodings
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")
    channels_per_axis = embed_dim // 2
    encoded = [
        get_1d_sincos_pos_embed_from_grid(channels_per_axis, grid[axis])
        for axis in range(2)
    ]
    return np.concatenate(encoded, axis=1)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos) -> np.array:
    """
    Sin/cos encoding of a set of scalar positions.

    embed_dim: output dimension for each position; must be even
    pos: array of positions; it is flattened, giving M entries
    out: (M, embed_dim) -- the first half of each row holds the sines, the
        second half the cosines
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    half_dim = embed_dim // 2
    # Exponents run over [0, 1) in steps of 1 / (embed_dim / 2).
    exponents = np.arange(half_dim, dtype=np.float64)
    exponents /= embed_dim / 2.0
    inv_freq = 1.0 / 10000**exponents

    flat_pos = pos.reshape(-1)
    angles = np.einsum("m,d->md", flat_pos, inv_freq)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)
def get_1d_sincos_pos_embed(
    embed_dim,
    grid_size,
    cls_token=False,
    extra_tokens=0,
    interpolation_scale=1.0,
    base_size=None,
) -> np.array:
    """
    Build a fixed sin/cos positional-embedding table for a 1D grid.

    embed_dim: output dimension for each position
    grid_size: int, number of positions
    cls_token, extra_tokens: when cls_token is truthy and extra_tokens > 0,
        extra_tokens all-zero rows are prepended to the table
    interpolation_scale: divisor applied to the raw coordinates
    base_size: optional size; when given, coordinates are additionally
        multiplied by base_size / grid_size
    return:
    pos_embed: [grid_size, embed_dim] or [extra_tokens + grid_size, embed_dim]
    """
    coords = np.arange(grid_size, dtype=np.float32) / interpolation_scale
    if base_size is not None:
        coords *= base_size / grid_size
    table = get_1d_sincos_pos_embed_from_grid(embed_dim, coords)

    if not (cls_token and extra_tokens > 0):
        return table
    zero_rows = np.zeros([extra_tokens, embed_dim])
    return np.concatenate([zero_rows, table], axis=0)
def get_meshgrid_nd(rope_sizes, dim=2, dtype=torch.float32):
    """
    Get n-D meshgrid of integer-valued coordinates.

    Args:
        rope_sizes (int | tuple of int | list of int): number of positions
            along each axis. A bare int is used for every one of the ``dim``
            axes; otherwise the first ``dim`` entries are read.
        dim (int): number of axes of the grid.
        dtype (torch.dtype): dtype of the returned coordinates.

    Returns:
        torch.Tensor: stacked "ij"-indexed grid of shape
        [dim, rope_sizes[0], ..., rope_sizes[dim - 1]]; slice ``k`` holds the
        coordinate along axis ``k``.
    """
    # A single int means "same size on every axis".
    if isinstance(rope_sizes, int):
        rope_sizes = (rope_sizes,) * dim

    # linspace(0, n, n + 1)[:n] yields 0, 1, ..., n - 1 in the requested dtype.
    axis_grid = [
        torch.linspace(0, rope_sizes[i], rope_sizes[i] + 1, dtype=dtype)[:rope_sizes[i]]
        for i in range(dim)
    ]
    grid = torch.meshgrid(*axis_grid, indexing="ij")
    return torch.stack(grid, dim=0)
def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[torch.FloatTensor, int],
    theta: float = 10000.0,
    theta_rescale_factor: float = 1.0,
    interpolation_factor: float = 1.0,
):
    """
    Precompute the cos/sin rotary tables for a set of 1D positions.

    Args:
        dim (int): Dimension of the frequency tensor.
        pos (int or torch.FloatTensor): Position indices [S]; an int ``n`` is
            expanded to positions ``0 .. n - 1``.
        theta (float, optional): Base of the frequency schedule. Defaults to 10000.0.
        theta_rescale_factor (float, optional): When not (numerically) 1.0, theta is
            multiplied by ``theta_rescale_factor ** (dim / (dim - 2))``. Defaults to 1.0.
        interpolation_factor (float, optional): Multiplier applied to the positions
            before the angles are computed. Defaults to 1.0.
    Returns:
        freqs_cos, freqs_sin: cosine and sine tables; every frequency is repeated
        twice along the last axis, so each is [S, 2 * (dim // 2)].
    """
    if isinstance(pos, int):
        pos = torch.arange(pos).float()

    rescale_is_identity = math.isclose(theta_rescale_factor, 1.0, rel_tol=1e-9)
    if not rescale_is_identity:
        theta *= theta_rescale_factor ** (dim / (dim - 2))

    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)

    angles = torch.outer(pos * interpolation_factor, inv_freq)
    cos_table = angles.cos().repeat_interleave(2, dim=1)
    sin_table = angles.sin().repeat_interleave(2, dim=1)
    return cos_table, sin_table
def get_nd_rotary_pos_embed(
    rope_dim_list,
    rope_sizes,
    theta=10000.0,
    theta_rescale_factor: Union[float, List[float]] = 1.0,
    interpolation_factor: Union[float, List[float]] = 1.0,
):
    """
    n-d rotary position embedding for tokens laid out on an n-d grid.

    Args:
        rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) is n.
        rope_sizes: rotary embed sizes for each dim, forwarded to get_meshgrid_nd.
        theta (float): Base of the frequency schedule. Defaults to 10000.0.
        theta_rescale_factor (float or list of float): Rescale factor for theta. A
            scalar or a one-element list is reused for every axis. Defaults to 1.0.
        interpolation_factor (float or list of float): Position multiplier. A scalar
            or a one-element list is reused for every axis. Defaults to 1.0.
    Returns:
        cos, sin (torch.Tensor, torch.Tensor): the per-axis tables from
        get_1d_rotary_pos_embed concatenated along dim 1; one row per grid position.
    Raises:
        ValueError: a per-axis argument does not end up with n entries.
    """
    num_axes = len(rope_dim_list)
    grid = get_meshgrid_nd(rope_sizes, dim=num_axes)

    def per_axis(value, label):
        # Broadcast scalars / singleton lists, then validate the length.
        if isinstance(value, (int, float)):
            value = [value] * len(rope_dim_list)
        elif isinstance(value, list) and len(value) == 1:
            value = [value[0]] * len(rope_dim_list)
        if len(value) != len(rope_dim_list):
            raise ValueError(f"len({label}): {len(value)} should equal to len(rope_dim_list): {len(rope_dim_list)}")
        return value

    theta_rescale_factor = per_axis(theta_rescale_factor, "theta_rescale_factor")
    interpolation_factor = per_axis(interpolation_factor, "interpolation_factor")

    cos_parts = []
    sin_parts = []
    for axis, rope_dim in enumerate(rope_dim_list):
        axis_cos, axis_sin = get_1d_rotary_pos_embed(
            rope_dim,
            grid[axis].reshape(-1),
            theta,
            theta_rescale_factor=theta_rescale_factor[axis],
            interpolation_factor=interpolation_factor[axis],
        )
        cos_parts.append(axis_cos)
        sin_parts.append(axis_sin)

    return torch.cat(cos_parts, dim=1), torch.cat(sin_parts, dim=1)
class PositionEmbedding(nn.Module):
    """
    Frozen lookup table of 2D sin/cos position embeddings.

    The table has max_num_patch_per_side ** 2 rows of width hidden_size and is
    stored as a parameter with requires_grad=False.
    """

    def __init__(self, max_num_patch_per_side, hidden_size):
        super().__init__()
        self.max_num_patch_per_side = max_num_patch_per_side
        self.hidden_size = hidden_size

        num_positions = max_num_patch_per_side ** 2
        empty_table = torch.zeros(num_positions, hidden_size)
        self.pos_embed = nn.Parameter(empty_table, requires_grad=False)
        self._init_weights()

    def _init_weights(self):
        """Fill the table with the output of get_2d_sincos_pos_embed."""
        sincos = get_2d_sincos_pos_embed(self.hidden_size, self.max_num_patch_per_side)
        sincos = torch.from_numpy(sincos).float()
        self.pos_embed.data.copy_(sincos)

    def forward(self, position_ids):
        """Index the table with position_ids and return the selected rows."""
        return self.pos_embed[position_ids]
class PositionEmbedding2D(nn.Module):
    """
    On-the-fly 2D sin/cos position embedding with memoized results.

    dim must be divisible by 4: each of the two grid axes is encoded with
    dim // 4 frequencies, as a block of sines followed by a block of cosines.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim
        if dim % 4 != 0:
            raise Exception("dim must be divisible by 4")
        half_dim = dim // 2
        exponents = torch.arange(0, half_dim, 2).float() / half_dim
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponents), persistent=False)

    def _get_sin_cos_emb(self, t: torch.Tensor):
        """Return [sin(t * inv_freq), cos(t * inv_freq)] joined on the last axis."""
        angles = torch.einsum("i,d->id", t, self.inv_freq)
        cos_part = torch.cos(angles)
        sin_part = torch.sin(angles)
        return torch.cat((sin_part, cos_part), dim=-1)

    # NOTE(review): lru_cache on a method keys on `self`, so cached entries keep
    # the module instance referenced, and callers receive the same tensor object
    # on every hit -- it should not be modified in place.
    @functools.lru_cache(maxsize=512)
    def _get_cached_emb(
        self,
        device: torch.device,
        dtype: torch.dtype,
        h: int,
        w: int,
        scale=1.0,
        base_size=None,
    ):
        """Compute (or fetch from the cache) the [1, h * w, dim] embedding."""
        rows = torch.arange(h, device=device) / scale
        cols = torch.arange(w, device=device) / scale
        if base_size is not None:
            rows *= base_size / h
            cols *= base_size / w

        # The column coordinates are passed first, so the first encoded half
        # of the output follows the w axis and the second half the h axis.
        first, second = torch.meshgrid(cols, rows, indexing="ij")
        first = first.t().reshape(-1)
        second = second.t().reshape(-1)

        first_emb = self._get_sin_cos_emb(first)
        second_emb = self._get_sin_cos_emb(second)
        combined = torch.concat([first_emb, second_emb], dim=-1)
        return combined.unsqueeze(0).to(dtype)

    def forward(
        self,
        x: torch.Tensor,
        h: int,
        w: int,
        scale: Optional[float] = 1.0,
        base_size: Optional[int] = None,
    ) -> torch.Tensor:
        """Return the embedding for an h x w grid on x's device and in x's dtype."""
        return self._get_cached_emb(x.device, x.dtype, h, w, scale, base_size)
def exists(val):
    """Return True for any value other than None."""
    return not (val is None)
def default(val, d):
    """Return val unless it is None, in which case return the fallback d."""
    if exists(val):
        return val
    return d
def rotate_half(x):
    """
    Map every adjacent pair (a, b) on the last axis to (-b, a).

    The last axis is viewed as consecutive pairs, each pair is rotated, and
    the result is flattened back to the input's last-axis length.
    """
    pairs = rearrange(x, '... (d r) -> ... d r', r=2).contiguous()
    first, second = pairs.unbind(dim=-1)
    rotated = torch.stack((-second, first), dim=-1)
    return rearrange(rotated, '... d r -> ... (d r)')
@autocast(enabled=False)
def apply_rotary_emb(freqs, t, start_index=0, scale=1., seq_dim=-2):
    """
    Rotate a slice of the feature axis of t by the angles in freqs.

    freqs: rotation angles; the last axis length is the number of rotated features
    t: tensor to rotate; features [start_index, start_index + rot_dim) on the
        last axis are rotated, everything outside that window is passed through
    start_index: first rotated feature index
    scale: multiplier applied to both the cos and the sin term
    seq_dim: axis of t holding the sequence; only consulted when t is 3-D, in
        which case freqs is cut down to its last t.shape[seq_dim] rows
    return: tensor shaped like t, cast back to t's original dtype
    """
    orig_dtype = t.dtype

    if t.ndim == 3:
        freqs = freqs[-t.shape[seq_dim]:]

    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim
    if rot_dim > t.shape[-1]:
        raise Exception(
            f"feature dimension {t.shape[-1]} is not "
            f"of sufficient size to rotate in all the positions {rot_dim}"
        )

    untouched_left = t[..., :start_index]
    window = t[..., start_index:end_index]
    untouched_right = t[..., end_index:]

    rotated = (window * freqs.cos() * scale) + (rotate_half(window) * freqs.sin() * scale)
    joined = torch.cat((untouched_left, rotated, untouched_right), dim=-1)
    return joined.type(orig_dtype)
class NpuRotaryEmbedding(nn.Module):
    """
    Rotary position embedding with optional xpos scaling and angle caching.

    The module owns a frequency vector (`freqs`), turns position tensors into
    rotation angles in `forward`, and offers helpers that apply those angles
    to query/key tensors through `apply_rotary_emb`.
    """

    @beartype
    def __init__(
        self,
        dim,
        custom_freqs: Optional[Tensor] = None,
        freqs_for: Union[
            Literal['lang'],
            Literal['pixel'],
            Literal['constant']
        ] = 'lang',
        theta=10000,
        max_freq=10,
        num_freqs=1,
        learned_freq=False,
        use_xpos=False,
        xpos_scale_base=512,
        interpolate_factor=1.,
        theta_rescale_factor=1.,
        seq_before_head_dim=False,
        cache_if_possible=True
    ):
        """
        dim: number of rotated features; dim // 2 frequencies are created
        custom_freqs: frequency tensor used verbatim when given
        freqs_for: frequency schedule ('lang', 'pixel' or 'constant') used when
            custom_freqs is not given
        theta: base for the 'lang' schedule
        max_freq: upper bound parameter of the 'pixel' schedule
        num_freqs: number of frequencies of the 'constant' schedule
        learned_freq: whether `freqs` requires gradients
        use_xpos: enable the xpos scale buffers and `rotate_queries_and_keys`
        xpos_scale_base: divisor used when computing the xpos power
        interpolate_factor: divisor applied to sequence positions; must be >= 1
        theta_rescale_factor: theta is multiplied by
            theta_rescale_factor ** (dim / (dim - 2))
        seq_before_head_dim: when True the default sequence axis is -3, else -2
        cache_if_possible: allow `forward` / `get_scale` to cache their results
        """
        super().__init__()
        theta *= theta_rescale_factor ** (dim / (dim - 2))
        self.freqs_for = freqs_for
        if exists(custom_freqs):
            freqs = custom_freqs
        elif freqs_for == 'lang':
            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()
        self.cache_if_possible = cache_if_possible
        self.tmp_store('cached_freqs', None)
        self.tmp_store('cached_scales', None)
        self.freqs = nn.Parameter(freqs, requires_grad=learned_freq)
        self.learned_freq = learned_freq
        # Zero-dim buffer whose only job is to report the module's device.
        self.tmp_store('dummy', torch.tensor(0))
        self.seq_before_head_dim = seq_before_head_dim
        self.default_seq_dim = -3 if seq_before_head_dim else -2
        if interpolate_factor < 1.:
            # Fixed message: the check rejects values below 1.
            raise Exception("interpolate_factor must be greater than or equal to 1.")
        self.interpolate_factor = interpolate_factor
        self.use_xpos = use_xpos
        if not use_xpos:
            self.tmp_store('scale', None)
            return
        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
        self.scale_base = xpos_scale_base
        self.tmp_store('scale', scale)
        # Only reached on the xpos path (the non-xpos path returned above).
        self.apply_rotary_emb = staticmethod(apply_rotary_emb)

    @property
    def device(self):
        """Device of the module, read from the `dummy` buffer."""
        return self.dummy.device

    def tmp_store(self, key, value):
        """Register `value` as a non-persistent buffer named `key`."""
        self.register_buffer(key, value, persistent=False)

    def get_seq_pos(self, seq_len, device, dtype, offset=0):
        """Positions offset .. offset + seq_len - 1 divided by interpolate_factor."""
        return (torch.arange(seq_len, device=device, dtype=dtype) + offset) / self.interpolate_factor

    def rotate_queries_or_keys(self, t, seq_dim=None, offset=0):
        """
        Rotate a single tensor; raises when the module was built with use_xpos.

        seq_dim defaults to `default_seq_dim`; offset shifts the positions.
        """
        seq_dim = default(seq_dim, self.default_seq_dim)
        if self.use_xpos:
            raise Exception(
                "you must use `.rotate_queries_and_keys` method "
                "instead and pass in both queries and keys, "
                "for length extrapolatable rotary embeddings"
            )
        device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]
        positions = self.get_seq_pos(seq_len, device=device, dtype=dtype, offset=offset)
        freqs = self.forward(positions, seq_len=seq_len, offset=offset)
        if seq_dim == -3:
            freqs = rearrange(freqs, 'n d -> n 1 d')
        return apply_rotary_emb(freqs, t, seq_dim=seq_dim)

    def rotate_queries_with_cached_keys(self, q, k, seq_dim=None, offset=0):
        """
        Rotate q and k when q covers only the trailing positions of k.

        q is rotated with an extra offset of k_len - q_len so its positions
        line up with the end of k. Raises when q is longer than k.
        """
        seq_dim = default(seq_dim, self.default_seq_dim)
        q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
        if q_len > k_len:
            # Fixed message: the original text was cut off after "q_len must ".
            raise Exception("q_len must be less than or equal to k_len")
        rotated_q = self.rotate_queries_or_keys(q, seq_dim=seq_dim, offset=k_len - q_len + offset)
        rotated_k = self.rotate_queries_or_keys(k, seq_dim=seq_dim, offset=offset)
        rotated_q = rotated_q.type(q.dtype)
        rotated_k = rotated_k.type(k.dtype)
        return rotated_q, rotated_k

    def rotate_queries_and_keys(self, q, k, seq_dim=None):
        """
        Rotate q and k with xpos scaling; requires use_xpos.

        q is scaled by `scale` and k by `scale ** -1`.
        """
        seq_dim = default(seq_dim, self.default_seq_dim)
        if not self.use_xpos:
            raise Exception("use_xpos must be true when we use rotate_queries_and_keys")
        device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]
        seq = self.get_seq_pos(seq_len, dtype=dtype, device=device)
        freqs = self.forward(seq, seq_len=seq_len)
        scale = self.get_scale(seq, seq_len=seq_len).to(dtype)
        if seq_dim == -3:
            freqs = rearrange(freqs, 'n d -> n 1 d')
            scale = rearrange(scale, 'n d -> n 1 d')
        rotated_q = apply_rotary_emb(freqs, q, scale=scale, seq_dim=seq_dim)
        rotated_k = apply_rotary_emb(freqs, k, scale=scale ** -1, seq_dim=seq_dim)
        rotated_q = rotated_q.type(q.dtype)
        rotated_k = rotated_k.type(k.dtype)
        return rotated_q, rotated_k

    @beartype
    def get_scale(
        self,
        t: Tensor,
        seq_len: Optional[int] = None,
        offset=0
    ):
        """
        Return the xpos scale for positions t; requires use_xpos.

        When caching is enabled and seq_len is given, a previously stored
        scale that is long enough is sliced and returned instead.
        """
        if not self.use_xpos:
            raise Exception("use_xpos must be true when we use get_scale method")
        should_cache = (
            self.cache_if_possible and
            exists(seq_len)
        )
        if (
            should_cache and
            exists(self.cached_scales) and
            (seq_len + offset) <= self.cached_scales.shape[0]
        ):
            return self.cached_scales[offset:(offset + seq_len)]
        # use_xpos is guaranteed here, so the scale is always computed.
        power = (t - len(t) // 2) / self.scale_base
        scale = self.scale ** rearrange(power, 'n -> n 1')
        scale = torch.cat((scale, scale), dim=-1)
        if should_cache:
            self.tmp_store('cached_scales', scale)
        return scale

    def get_axial_freqs(self, *dims):
        """
        Build angles for an axial grid with one axis per entry of dims.

        Positions are linspace(-1, 1) for the 'pixel' schedule and arange
        otherwise; per-axis angles are broadcast against each other and
        concatenated on the last axis.
        """
        Colon = slice(None)
        all_freqs = []
        for ind, dim in enumerate(dims):
            if self.freqs_for == 'pixel':
                pos = torch.linspace(-1, 1, steps=dim, device=self.device)
            else:
                pos = torch.arange(dim, device=self.device)
            freqs = self.forward(pos, seq_len=dim)
            # Insert singleton axes everywhere except this axis' own slot.
            all_axis = [None] * len(dims)
            all_axis[ind] = Colon
            new_axis_slice = (Ellipsis, *all_axis, Colon)
            all_freqs.append(freqs[new_axis_slice])
        all_freqs = broadcast_tensors(*all_freqs)
        return torch.cat(all_freqs, dim=-1)

    @autocast(enabled=False)
    def forward(
        self,
        t: Tensor,
        seq_len=None,
        offset=0
    ):
        """
        Turn positions t into rotation angles (each frequency repeated twice).

        seq_len / offset are only used for caching: when caching applies and a
        long enough table is stored, rows offset .. offset + seq_len - 1 of it
        are returned.
        """
        should_cache = (
            self.cache_if_possible and
            not self.learned_freq and
            exists(seq_len) and
            self.freqs_for != 'pixel'
        )
        if (
            should_cache and
            exists(self.cached_freqs) and
            (offset + seq_len) <= self.cached_freqs.shape[0]
        ):
            return self.cached_freqs[offset:(offset + seq_len)].detach()
        freqs = self.freqs
        freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
        freqs = repeat(freqs, '... n -> ... (n r)', r=2)
        # Cache hits index the table by absolute position, so only a table
        # computed without an offset may be stored; storing one computed at a
        # non-zero offset would hand wrong positions to later callers.
        if should_cache and offset == 0:
            self.tmp_store('cached_freqs', freqs.detach())
        return freqs
def broad_cat(tensors, dim=-1):
    """
    Concatenate tensors along ``dim`` after broadcasting every other axis.

    Args:
        tensors: sequence of tensors that all have the same number of axes.
        dim: axis to concatenate along; negative values count from the end.

    Returns:
        torch.Tensor: each input expanded so that every axis except ``dim``
        has the largest size found among the inputs, then concatenated.

    Raises:
        ValueError: ``tensors`` is empty or the inputs differ in rank.
    """
    if len(tensors) == 0:
        raise ValueError("broad_cat expects at least one tensor")

    ranks = {len(t.shape) for t in tensors}
    if len(ranks) != 1:
        raise ValueError("tensors must all have the same number of dimensions")
    rank = ranks.pop()
    dim = (dim + rank) if dim < 0 else dim

    # Largest size seen on every axis; the concat axis keeps each own size.
    target_shape = [max(sizes) for sizes in zip(*(t.shape for t in tensors))]

    expanded = []
    for t in tensors:
        shape = list(target_shape)
        shape[dim] = t.shape[dim]
        expanded.append(t.expand(*shape))
    return torch.cat(expanded, dim=dim)
class Rotary3DPositionEmbedding(nn.Module):
    """
    3D (t, h, w) rotary angle table with an optional learnable absolute
    position embedding, applied through the NPU rotary kernel.
    """

    def __init__(
        self,
        height,
        width,
        compressed_num_frames,
        hidden_size,
        hidden_size_head,
        text_length,
        theta=10000,
        rot_v=False,
        learnable_pos_embed=False,
    ):
        """
        height, width, compressed_num_frames: sizes of the h, w and t axes of
            the precomputed angle table
        hidden_size: width of the learnable position embedding (if enabled)
        hidden_size_head: per-head width; hidden_size_head // 4 features are
            assigned to t and hidden_size_head // 8 * 3 to each of h and w
        text_length: number of leading text positions
        theta: base of the frequency schedule
        rot_v: stored on the instance; not read inside this class
        learnable_pos_embed: create a zero-initialised trainable embedding of
            shape [1, height * width * compressed_num_frames + text_length, hidden_size]
        """
        super().__init__()
        self.rot_v = rot_v
        dim_t = hidden_size_head // 4
        dim_h = hidden_size_head // 8 * 3
        dim_w = hidden_size_head // 8 * 3

        def axis_angles(axis_dim, length):
            # Outer product of positions and inverse frequencies, with every
            # frequency repeated twice on the last axis.
            inv_freq = 1.0 / (theta ** (torch.arange(0, axis_dim, 2)[: (axis_dim // 2)].float() / axis_dim))
            positions = torch.arange(length, dtype=torch.float32)
            angles = torch.einsum("..., f -> ... f", positions, inv_freq)
            return repeat(angles, "... n -> ... (n r)", r=2)

        freqs_t = axis_angles(dim_t, compressed_num_frames)
        freqs_h = axis_angles(dim_h, height)
        freqs_w = axis_angles(dim_w, width)

        freqs = broad_cat(
            (freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]),
            dim=-1,
        )
        # Plain attribute (not a registered buffer), moved to the NPU once here.
        self.freqs = freqs.contiguous().npu()
        self.text_length = text_length
        if learnable_pos_embed:
            num_patches = int(height * width * compressed_num_frames + text_length)
            self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches, int(hidden_size)), requires_grad=True)
        else:
            self.pos_embedding = None

    def rotary(self, t, **kwargs):
        """
        Rotate t with the angles of the rope_T x rope_H x rope_W sub-grid.

        kwargs must provide "rope_T", "rope_H" and "rope_W". The angles are
        flattened to one row per grid position and given two leading
        singleton axes before being passed to the NPU kernel.
        """
        freqs = self.freqs[: kwargs["rope_T"], : kwargs["rope_H"], : kwargs["rope_W"]].contiguous()
        freqs = rearrange(freqs, "t h w d -> (t h w) d")
        freqs = freqs.unsqueeze(0).unsqueeze(0)
        # Only the raw angle table `self.freqs` is stored on the module, so
        # the cos/sin tables are derived from it here.
        freqs_cos = freqs.cos().to(t.dtype)
        freqs_sin = freqs.sin().to(t.dtype)
        return npu_rotary_position_embedding(t, freqs_cos, freqs_sin, mode=1)

    def position_embedding_forward(self, position_ids, **kwargs):
        """
        Return the first text_length + kwargs.get("seq_length", 0) rows of the
        learnable embedding, or None when it is disabled. position_ids is not
        used.
        """
        if self.pos_embedding is None:
            return None
        return self.pos_embedding[:, :self.text_length + kwargs.get("seq_length", 0)]

    def apply_rotary_pos_emb(self, x, freqs):
        """Rotate x by the given angles using the NPU kernel (mode=1)."""
        freqs_cos = freqs.cos().to(x.dtype)
        freqs_sin = freqs.sin().to(x.dtype)
        return npu_rotary_position_embedding(x, freqs_cos, freqs_sin, mode=1)

    def forward(self, rope_T, rope_H, rope_W):
        """
        Return angles of shape [text_length + T*H*W, 1, 1, D] for the
        requested sub-grid; the leading text_length rows are zeros.
        """
        freqs = self.freqs[: rope_T, : rope_H, : rope_W].contiguous()
        freqs = rearrange(freqs, "t h w d -> (t h w) d")
        freqs = freqs[:, None, None, :]
        freqs_text_padding = torch.zeros(
            [self.text_length, 1, 1, freqs.shape[-1]],
            device=freqs.device,
            dtype=freqs.dtype,
        )
        return torch.cat((freqs_text_padding, freqs), dim=0)
class RoPE3DSORA(nn.Module):
    """
    3D rotary embedding that splits the head dimension into three equal parts
    for the t, h and w axes and caches position grids and angle tables.
    """

    def __init__(self, head_dim, freq=10000.0, interpolation_scale=(1, 1, 1)):
        """
        head_dim: per-head width; must be a multiple of three
        freq: base of the frequency schedule
        interpolation_scale: (t, h, w) divisors applied to the positions
        """
        super().__init__()
        if head_dim % 3 != 0:
            raise ValueError("number of head dimensions should be a multiple of three")
        self.dim = head_dim // 3
        self.base = freq
        self.interpolation_scale_t, self.interpolation_scale_h, self.interpolation_scale_w = interpolation_scale
        self.cache = {}
        self.cache_positions = {}

    def check_type(self, param):
        """Convert a tensor argument to a Python scalar; pass others through."""
        if isinstance(param, torch.Tensor):
            param = param.item()
        return param

    def get_position(self, b, t, h, w, device):
        """
        Return ((pos_t, pos_h, pos_w), (max_t, max_h, max_w)) for a grid.

        Each pos_* has shape [t * h * w, b] and holds the index of every grid
        cell along that axis; results are cached per (b, t, h, w, device).
        """
        b = self.check_type(b)
        t = self.check_type(t)
        h = self.check_type(h)
        w = self.check_type(w)
        # The device is part of the key so a grid cached on one device is
        # never handed back for a different one.
        key = (b, t, h, w, device)
        if key not in self.cache_positions:
            x = torch.arange(w, device=device)
            y = torch.arange(h, device=device)
            z = torch.arange(t, device=device)
            pos = torch.cartesian_prod(z, y, x)
            pos = pos.reshape(t * h * w, 3).transpose(0, 1).reshape(3, -1, 1).contiguous().expand(3, -1, b).clone()
            poses = (pos[0].contiguous(), pos[1].contiguous(), pos[2].contiguous())
            max_poses = (int(poses[0].max()), int(poses[1].max()), int(poses[2].max()))
            self.cache_positions[key] = (poses, max_poses)
        return self.cache_positions[key]

    def get_freq(self, seq_len, pos1d, device, interpolation_scale=1):
        """
        Look up the angles for the indices in pos1d.

        A [seq_len, dim] table of angles (positions divided by
        interpolation_scale, times the inverse frequencies, duplicated on the
        last axis) is built once per cache key; the result has a singleton
        axis inserted before the feature axis.
        """
        # interpolation_scale is part of the key: two axes with the same
        # length but different scales must not share a table.
        key = (self.dim, seq_len, device, interpolation_scale)
        if key not in self.cache:
            inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
            t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) / interpolation_scale
            freqs = torch.einsum("i,j->ij", t, inv_freq)
            self.cache[key] = torch.cat((freqs, freqs), dim=-1)
        return F.embedding(pos1d, self.cache[key])[:, :, None, :]

    @staticmethod
    def rotate_half(x):
        """Map the last axis [x1, x2] (two halves) to [-x2, x1]."""
        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)

    def apply_rope1d(self, tokens, freq):
        """Rotate tokens by the angles in freq."""
        cos = freq.cos()
        sin = freq.sin()
        return (tokens * cos) + (self.rotate_half(tokens) * sin)

    def apply_rotary_pos_emb(self, tokens, freq):
        """
        Rotate tokens with 3D angles.

        tokens and freq are split into three equal chunks on the last axis
        (t, y, x); each chunk is rotated independently and the results are
        concatenated. Raises AssertionError when tokens.size(3) is not a
        multiple of three.
        """
        if tokens.size(3) % 3 != 0:
            raise AssertionError("number of dimensions should be a multiple of three")
        freq = freq.to(tokens.dtype)
        token_chunks = tokens.chunk(3, dim=-1)
        freq_chunks = freq.chunk(3, dim=-1)
        rotated = [
            self.apply_rope1d(chunk, angles)
            for chunk, angles in zip(token_chunks, freq_chunks)
        ]
        return torch.cat(rotated, dim=-1)

    def forward(self, b, t, h, w, device):
        """Return the concatenated (t, y, x) angles of shape [t*h*w, b, 1, 3 * dim]."""
        poses, max_poses = self.get_position(b, t, h, w, device)
        freq_t = self.get_freq(max_poses[0] + 1, poses[0], device, self.interpolation_scale_t)
        freq_y = self.get_freq(max_poses[1] + 1, poses[1], device, self.interpolation_scale_h)
        freq_x = self.get_freq(max_poses[2] + 1, poses[2], device, self.interpolation_scale_w)
        return torch.cat((freq_t, freq_y, freq_x), dim=-1)
class RoPE3DStepVideo(RoPE3DSORA):
    """
    Variant of RoPE3DSORA whose per-axis widths are given explicitly by
    ch_split instead of being three equal parts of the head dimension.
    """

    def __init__(self, ch_split, freq=10000.0):
        """
        ch_split: per-axis channel widths, used to split tokens and angles
        freq: base of the frequency schedule
        """
        # The parent is initialised with a placeholder head_dim of 3; the
        # per-axis widths actually used come from ch_split.
        super().__init__(head_dim=3)
        self.base = freq
        self.ch_split = ch_split

    def apply_rotary_pos_emb(self, tokens, freqs):
        """Rotate each ch_split chunk of tokens with the matching chunk of freqs."""
        freqs = freqs.to(tokens.dtype)
        token_chunks = torch.split(tokens, self.ch_split, dim=-1)
        freq_chunks = torch.split(freqs, self.ch_split, dim=-1)
        rotated = [
            self.apply_rope1d(chunk, angles)
            for chunk, angles in zip(token_chunks, freq_chunks)
        ]
        return torch.cat(rotated, dim=-1)

    def get_freq(self, seq_len, pos1d, dim, device):
        """
        Look up the angles for the indices in pos1d using a [seq_len, dim]
        table cached per (dim, seq_len, device); a singleton axis is inserted
        before the feature axis of the result.
        """
        key = (dim, seq_len, device)
        if key not in self.cache:
            inv_freq = 1.0 / (self.base ** (torch.arange(0, dim, 2).float().to(device) / dim))
            positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
            angles = torch.einsum("i,j->ij", positions, inv_freq)
            self.cache[key] = torch.cat((angles, angles), dim=-1)
        table = self.cache[key]
        return F.embedding(pos1d, table)[:, :, None, :]

    def forward(self, b, t, h, w, device):
        """Return the per-axis angles concatenated on the last axis."""
        poses, max_poses = self.get_position(b, t, h, w, device)
        per_axis = [
            self.get_freq(max_poses[axis] + 1, poses[axis], width, device)
            for axis, width in enumerate(self.ch_split)
        ]
        return torch.cat(per_axis, dim=-1)