import importlib
from math import sqrt
from functools import partial
from collections import namedtuple
import torch
import torch.nn.functional as F
from torch.nn import Module, ModuleList
from torch import nn, einsum, Tensor
from einops import rearrange, pack, unpack
from einops.layers.torch import Rearrange
from typing import Optional
from torchtyping import TensorType
from rotary_embedding_torch import RotaryEmbedding
import torch_npu
from torch_npu.contrib.function import matmul_transpose
from torch._jit_internal import BroadcastingList2
from torch.overrides import has_torch_function_unary, handle_torch_function
from torch.nn.modules.utils import _list_with_default
# Decoding state threaded between calls of TaylorSeriesLinearAttn.forward:
#   seq_len    - number of tokens processed so far; used as the rotary offset
#   last_token - the latest raw (pre-shift) input token, needed by token shifting
#   kv_cumsum  - running sum of (Taylor-expanded key)^T @ value
#   k_cumsum   - running sum of the Taylor-expanded keys
Cache = namedtuple('Cache', 'seq_len last_token kv_cumsum k_cumsum')
def exists(v):
    """Return True for any value other than ``None`` (falsy values like 0 count)."""
    return not (v is None)
def default(v, d):
    """Return ``v`` unless it is ``None``, in which case return the fallback ``d``."""
    if v is None:
        return d
    return v
def shift(t):
    """Token shift: delay the second half of the feature axis by one position.

    The last axis is split in two with ``chunk(2)``. The first half is kept as-is.
    The second half is moved one step forward along the sequence axis (dim -2):
    a zero row is inserted at the front and the final row is dropped. The halves
    are then concatenated again, so the output has the same shape as ``t``.
    """
    keep, delayed = t.chunk(2, dim = -1)
    # pad spec (last-dim left, last-dim right, seq left, seq right): +1 zero row in front, crop 1 at the end
    delayed = F.pad(delayed, (0, 0, 1, -1), value = 0.)
    return torch.cat([keep, delayed], dim = -1)
class RMSNorm(Module):
    """Normalize the last axis to unit L2 norm, rescale by sqrt(dim), apply a learned gain.

    Args:
        dim: size of the last axis; also the length of the learned ``gamma`` vector,
            which is initialised to ones.
    """

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        # sqrt(dim) makes a unit-L2 vector have unit RMS
        self.scale = dim ** 0.5

    def forward(self, x):
        unit = F.normalize(x, dim = -1)
        return self.gamma * unit * self.scale
def second_taylor_expansion(
    x: Tensor,
    remove_even_power_dups = False
):
    """Second-order Taylor feature map for exp(q . k).

    For two vectors q and k of size d, the features satisfy
    ``f(q) . f(k) == 1 + q.k + (q.k) ** 2 / 2``.

    Args:
        x: tensor of shape (..., d).
        remove_even_power_dups: if True, the quadratic block keeps only the d diagonal
            terms and the d*(d-1)/2 strictly-upper-triangle terms. The upper terms are
            scaled by sqrt(2), so dot products are unchanged. Otherwise all d*d
            outer-product terms are kept.

    Returns:
        Tensor of shape (..., 1 + d + d*d). With ``remove_even_power_dups`` the shape
        is (..., 1 + d + d*(d+1)/2). The features are laid out as
        [constant 1, linear terms, quadratic terms].
    """
    dim = x.shape[-1]
    lead_shape = x.shape[:-1]

    # flatten all leading axes: (..., d) -> (N, d); numel() also handles lead_shape == ()
    flat = x.reshape(lead_shape.numel(), dim)
    num_rows = flat.shape[0]

    x0 = flat.new_ones((num_rows, 1))
    x1 = flat
    # 1/sqrt(2) on each side gives the 1/2 Taylor coefficient of the quadratic term
    x2 = einsum('... i, ... j -> ... i j', flat, flat) * (0.5 ** 0.5)

    if remove_even_power_dups:
        x2_diagonal = torch.diagonal(x2, dim1 = -2, dim2 = -1)
        # build the mask on x's device so boolean indexing does not cross devices
        mask = torch.ones(x2.shape[-2:], dtype = torch.bool, device = x.device).triu(1)
        # x_i x_j == x_j x_i: keep one copy of each off-diagonal pair, weighted by sqrt(2)
        x2_upper_triangle = x2[:, mask] * sqrt(2)
        x2 = torch.cat((x2_diagonal, x2_upper_triangle), dim = -1)
    else:
        x2 = x2.reshape(num_rows, -1)

    out = torch.cat((x0, x1, x2), dim = -1)
    return out.reshape(*lead_shape, out.shape[-1])
class TaylorSeriesLinearAttn(Module):
    """Linear attention that replaces the softmax kernel with a 2nd-order Taylor feature map.

    Processing steps:
    1. Queries and keys are projected.
    2. They are optionally rotated with rotary embeddings.
    3. Queries alone are scaled by ``dim_head ** -0.5``.
    4. Both are expanded with ``second_taylor_expansion`` (phi).
    5. Attention is computed as ``phi(q) @ (phi(k)^T v)``, divided by
       ``phi(q) . sum(phi(k))`` (clamped below by ``eps``).

    Args:
        dim: input/output feature size.
        dim_head: per-head size of queries, keys and values before the Taylor expansion.
        heads: number of query heads.
        causal: autoregressive attention. Requires ``pytorch-fast-transformers``.
        one_headed_kv: use a single key/value head shared by all query heads.
        rotary_emb: apply rotary embeddings to q and k. Not allowed with cross attention.
        combine_heads: project merged heads back to ``dim`` (with dropout). Otherwise the
            output keeps ``dim_head * heads`` features.
        gate_value_heads: multiply each head's output by a sigmoid gate computed from
            the (possibly token-shifted) input.
        prenorm: apply ``RMSNorm`` before the projections.
        shift_tokens: apply ``shift`` (token shift) to the input.
        dropout: dropout probability after the output projection.
        remove_even_power_dups: forwarded to ``second_taylor_expansion``.

    Raises:
        ImportError: if ``causal`` is set and ``fast_transformers`` cannot be imported.
    """

    def __init__(
        self,
        dim,
        *,
        dim_head = 16,
        heads = 8,
        causal = False,
        one_headed_kv = False,
        rotary_emb = False,
        combine_heads = True,
        gate_value_heads = False,
        prenorm = False,
        shift_tokens = False,
        dropout = 0.,
        remove_even_power_dups = False
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        dim_inner = dim_head * heads

        self.shift_tokens = shift_tokens
        self.norm = RMSNorm(dim) if prenorm else nn.Identity()

        self.heads = heads
        self.dim_hidden = dim_inner

        self.taylor_expand_fn = partial(second_taylor_expansion, remove_even_power_dups = remove_even_power_dups)

        self.causal = causal
        self.causal_linear_attn_fn = None

        if causal:
            # Raise (instead of exiting the interpreter) so callers can handle the
            # missing optional dependency themselves.
            try:
                from fast_transformers.causal_product import CausalDotProduct
            except ImportError as err:
                raise ImportError(
                    'pytorch-fast-transformers must be installed. `pip install pytorch-fast-transformers` first'
                ) from err
            self.causal_linear_attn_fn = CausalDotProduct.apply

        kv_heads = 1 if one_headed_kv else heads
        dim_kv_inner = dim_head * kv_heads

        self.rotary_emb = RotaryEmbedding(dim_head) if rotary_emb else None
        self.one_headed_kv = one_headed_kv

        self.to_q = nn.Sequential(
            nn.Linear(dim, dim_inner, bias = False),
        )

        # keys and values in one projection: (b, n, dim) -> (2, b, kv_heads, n, dim_head)
        self.to_kv = nn.Sequential(
            nn.Linear(dim, dim_kv_inner * 2, bias = False),
            Rearrange('b n (kv h d) -> kv b h n d', kv = 2, h = kv_heads)
        )

        # optional per-head output gate: (b, n, dim) -> (b, heads, n, 1)
        self.to_v_gates = nn.Sequential(
            nn.Linear(dim, heads, bias = False),
            nn.Sigmoid(),
            Rearrange('b n h -> b h n 1')
        ) if gate_value_heads else None

        self.merge_heads = Rearrange('b h n d -> b n (h d)')

        self.to_out = nn.Identity()

        if combine_heads:
            self.to_out = nn.Sequential(
                nn.Linear(dim_inner, dim, bias = False),
                nn.Dropout(dropout)
            )

    def forward(
        self,
        x: TensorType['batch', 'seq', 'dim', float],
        mask: Optional[TensorType['batch', 'seq', bool]] = None,
        context: Optional[TensorType['batch', 'target_seq', 'dim', float]] = None,
        eps: float = 1e-5,
        cache: Optional[Cache] = None,
        return_cache = False
    ):
        """
        Attend over ``x``, or over ``context`` for cross attention, in linear time.

        Args:
            x: input of shape (batch, seq, dim).
            mask: boolean keep-mask over keys/values. It is applied to k and v, so under
                cross attention its length must match ``context``. Not allowed when causal.
            context: optional key/value source for cross attention. Non-causal only.
            eps: lower clamp for the attention normaliser.
            cache: state returned by a previous causal call. When a cache is given,
                exactly one new token must be passed.
            return_cache: also return the updated ``Cache``. Causal only.

        Returns:
            ``out``, or ``(out, new_cache)`` when ``return_cache`` is set.
            ``out`` has shape (batch, seq, dim) when ``combine_heads`` is set, and
            (batch, seq, heads * dim_head) otherwise.

        einops
        b - batch
        h - heads
        d - query / key head dimension
        e - value head dimension
        n - source query sequence length
        m - target key / value sequence length
        """
        orig_input, seq_len, is_cross_attn = x, x.shape[-2], exists(context)
        assert not (exists(self.rotary_emb) and is_cross_attn), 'rotary embedding does not work with cross attention'

        if self.shift_tokens:
            # when decoding, prepend the cached previous token so the shift sees it
            if exists(cache):
                x, ps = pack([cache.last_token, x], 'b * d')
            x = shift(x)
            if exists(cache):
                _, x = unpack(x, ps, 'b * d')

        normed = self.norm(x)

        q = self.to_q(normed)
        b, n, hd = q.shape
        h = self.heads
        d = hd // h
        # fused NPU reshape + permute: (b, n, h*d) -> (b, n, h, d) -> (b, h, n, d);
        # presumably equivalent to rearrange(q, 'b n (h d) -> b h n d') - confirm with torch_npu docs
        q = torch_npu.npu_confusion_transpose(q, (0, 2, 1, 3), (b, n, h, d), transpose_first=False)

        k, v = self.to_kv(default(context, normed))

        if exists(self.rotary_emb):
            rotate_fn = self.rotary_emb.rotate_queries_or_keys
            if exists(cache):
                # continue positions after the tokens already seen
                rotate_fn = partial(rotate_fn, offset = cache.seq_len)
            q, k = map(rotate_fn, (q, k))

        q = q * self.scale
        q, k = map(self.taylor_expand_fn, (q, k))

        if self.causal:
            assert not exists(mask), 'masking does not make sense for autoregressive linear attention'
            assert not is_cross_attn, 'causal does not make sense with cross attention'

            if self.one_headed_kv:
                # materialise the single kv head for every query head: (b, 1, n, d) -> (b, h, n, d)
                k, v = map(lambda t: t.repeat(1, self.heads, 1, 1), (k, v))

            if exists(cache):
                # incremental decoding: update the running sums with one new token
                assert seq_len == 1
                old_seq_len, _, kv_cumsum_cache, k_cumsum_cache = cache

                kv = torch.matmul(k.permute(0, 1, 3, 2), v)
                kv_cumsum = kv + kv_cumsum_cache
                k_cumsum = k + k_cumsum_cache

                num = torch.matmul(q, kv_cumsum)
                den = einsum('... n d, ... n d -> ... n', q, k_cumsum)
                den = rearrange(den, '... -> ... 1')

                out = num / den.clamp(min = eps)

                if return_cache:
                    new_cache = Cache(old_seq_len + 1, orig_input, kv_cumsum, k_cumsum)

            else:
                # full sequence: prefix sums via the fast_transformers causal kernel
                num = self.causal_linear_attn_fn(q, k, v)

                k_cumsum = k.cumsum(dim = -2)
                den = einsum('... n d, ... n d -> ... n', q, k_cumsum)
                den = rearrange(den, '... -> ... 1')

                out = num / den.clamp(min = eps)

                if return_cache:
                    new_kv_cache = einsum('b h n d, b h n e -> b h d e', k, v)
                    new_k_cumsum_cache = k_cumsum[..., -1:, :]
                    new_cache = Cache(seq_len, orig_input[:, -1:], new_kv_cache, new_k_cumsum_cache)

        else:
            assert not return_cache, 'cache is only needed for autoregressive'

            if exists(mask):
                # zero masked keys/values (after expansion, so the constant feature is removed too)
                mask = rearrange(mask, 'b n -> b 1 n 1')
                k = k.masked_fill(~mask, 0.)
                v = v.masked_fill(~mask, 0.)

            if self.one_headed_kv:
                k, v = map(lambda t: rearrange(t, 'b 1 n d -> b n d'), (k, v))

                # kv = k^T v ('b m d, b m e -> b d e'), using the same transpose pattern
                # as the multi-head branch below
                kv = torch.matmul(v.transpose(-1, -2), k).transpose(-1, -2)
                qk_inv = 1. / einsum('b h n d, b m d -> b h n', q, k).clamp(min = eps)
                out = einsum('b h n d, b d e, b h n -> b h n e', q, kv, qk_inv)
            else:
                # kv = (v^T k)^T = k^T v, shape (b, h, d, e)
                kv = torch.matmul(v.permute(0, 1, 3, 2), k).permute(0, 1, 3, 2)
                qk_inv = 1. / einsum('b h n d, b h m d -> b h n', q, k).clamp(min = eps)
                out = einsum('b h n d, b h d e, b h n -> b h n e', q, kv, qk_inv)

        if exists(self.to_v_gates):
            out = out * self.to_v_gates(x)

        out = self.merge_heads(out)
        out = self.to_out(out)

        if not return_cache:
            return out

        return out, new_cache
def adaptive_avg_pool2d(input: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
    r"""2D adaptive average pooling that always runs, and returns, in float32.

    This behaves like ``torch.nn.functional.adaptive_avg_pool2d``, with one difference:
    ``input`` is converted to ``torch.float32`` before pooling, so the result is float32
    whatever the input dtype. NOTE(review): this is presumably a backend dtype-support
    workaround - confirm before relying on it for float64 inputs.

    Args:
        input: tensor whose last two axes are pooled.
        output_size: target size, either a single int or a pair. A ``None`` entry keeps
            the corresponding input size.
    """
    if not has_torch_function_unary(input):
        resolved_size = _list_with_default(output_size, input.size())
        as_fp32 = input.to(torch.float32)
        return torch._C._nn.adaptive_avg_pool2d(as_fp32, resolved_size)
    # let __torch_function__ overrides (tensor subclasses) handle the call
    return handle_torch_function(adaptive_avg_pool2d, (input,), input, output_size)