"""
Rotary Position Embedding implementation of different types along with helper functions
"""
from typing import Optional, Tuple, Union, List
import torch
import torch_npu
__all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb", "apply_fused_qkv_rotary_pos_emb"]
class RotaryPositionEmbedding(torch.nn.Module):
    """
    Rotary Position Embedding (RoPE) angle generator, see https://arxiv.org/abs/2104.09864.

    The inverse frequencies are stored in the ``inv_freq`` buffer. Calling the module returns
    the per-position rotation angles (position index times inverse frequency); cos/sin are
    not applied here.
    """

    def __init__(
        self,
        dim: int,
        rotary_percent: float = 1.0,
        seq_len_interpolation_factor: Optional[int] = None,
        pretrained_max_position_embeddings: Optional[int] = None,
        rotary_base: float = 10000.0,
        interleaved: bool = False,
    ):
        """
        Parameters
        ----------
        dim: int
            Rotary embedding dimension.
        rotary_percent: float, default = 1.0
            Fraction of `dim` that is rotated. Values below 1.0 shrink the rotary dimension
            to `int(dim * rotary_percent)`.
        seq_len_interpolation_factor: int, default = None
            If set (together with `pretrained_max_position_embeddings`), position indices are
            scaled down following the interpolation trick of https://arxiv.org/abs/2306.15595
        pretrained_max_position_embeddings: int, default = None
            Maximum number of positions seen during pre-training, before interpolation.
        rotary_base: float, default = 10000.0
            Base of the geometric progression of rotary frequencies.
        interleaved: bool, default = False
            Whether the generated angles use the interleaved layout.
        """
        super().__init__()
        # Only part of the dimension is rotated when rotary_percent < 1.
        rotary_dim = int(dim * rotary_percent) if rotary_percent < 1.0 else dim
        self.seq_len_interpolation_factor = seq_len_interpolation_factor
        self.rotary_base = rotary_base
        # Exponents 0, 2/d, 4/d, ... created directly on the current NPU device.
        exponents = (
            torch.arange(0, rotary_dim, 2, dtype=torch.float32, device=torch.npu.current_device())
            / rotary_dim
        )
        self.register_buffer("inv_freq", 1.0 / (self.rotary_base**exponents))
        self.pretrained_max_position_embeddings = pretrained_max_position_embeddings
        self.interleaved = interleaved

    def forward(self, max_seq_len: int, offset: int = 0):
        """
        Build the rotary angle table for `max_seq_len` positions.

        The computation is sensitive to reduced precision, so it runs with autocast
        explicitly disabled.

        Parameters
        ----------
        max_seq_len: int
            Number of positions to generate.
        offset: int, default = 0
            Constant added to every position index.

        Returns
        -------
        torch.Tensor
            Angles of shape `[max_seq_len, 1, 1, 2 * inv_freq.numel()]`.
        """
        with torch.autocast(enabled=False, device_type="npu"):
            positions = torch.arange(
                max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype
            )
            positions = positions + offset

            max_pos = self.pretrained_max_position_embeddings
            factor = self.seq_len_interpolation_factor
            if max_pos is not None and factor is not None:
                if max_seq_len > max_pos * factor:
                    # Requested length exceeds the interpolated range: scale by the
                    # actual length ratio instead of the fixed factor.
                    scale = 1 / (max_seq_len / max_pos)
                else:
                    # Fixed linear scaling by the configured factor.
                    scale = 1 / factor
                positions *= scale

            # Outer product of position indices and inverse frequencies.
            angles = torch.einsum("i , j -> i j", positions, self.inv_freq)

        if self.interleaved:
            # Repeat each angle twice in adjacent slots: [a0, a0, a1, a1, ...].
            column = angles.view(-1, 1)
            emb = torch.stack((column, column), dim=-1).view(angles.shape[0], -1)
        else:
            # Repeat the whole angle vector: [a0, a1, ..., a0, a1, ...].
            emb = torch.cat((angles, angles), dim=-1)
        return emb.reshape(emb.size(0), 1, 1, emb.size(1))
def _rotate_half_thd(x: torch.Tensor, interleaved: bool) -> torch.Tensor:
if not interleaved:
x1, x2 = torch.chunk(x, 2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
x1 = x[:, :, ::2]
x2 = x[:, :, 1::2]
x_new = torch.stack((-x2, x1), dim=-1)
return x_new.view(x_new.shape[0], x_new.shape[1], -1)
def _npu_rotary_mul(
    t: torch.Tensor,
    cos_: torch.Tensor,
    sin_: torch.Tensor,
    interleaved: bool = False,
) -> torch.Tensor:
    """Call `torch_npu.npu_rotary_mul` on contiguous copies of the operands.

    Args:
        t: torch.Tensor. Tensor to rotate.
        cos_: torch.Tensor. Cosine table passed to the kernel.
        sin_: torch.Tensor. Sine table passed to the kernel.
        interleaved: bool. Selects the kernel mode: "interleave" if True, "half" otherwise.
    Returns:
        Tensor: Whatever `torch_npu.npu_rotary_mul` returns for these operands.
    """
    if interleaved:
        rotary_mode = "interleave"
    else:
        rotary_mode = "half"
    operands = (t.contiguous(), cos_.contiguous(), sin_.contiguous())
    return torch_npu.npu_rotary_mul(*operands, rotary_mode)
def _apply_fused_rope_4d(
    t: torch.Tensor,
    freqs: torch.Tensor,
    tensor_format: str,
    interleaved: bool,
) -> torch.Tensor:
    """Apply RoPE to a 4-D tensor, using the NPU kernel whenever it is selected.

    Args:
        t: torch.Tensor. Input tensor; only the first `freqs.shape[-1]` channels of the last
            dimension are rotated, the rest are passed through.
        freqs: torch.Tensor. Rotation angles; transposed on dims 0/1 for "bshd".
        tensor_format: str. "sbhd" or "bshd".
        interleaved: bool. Whether to use interleaved rotary position embedding.
    Returns:
        Tensor: Rotated channels concatenated with the untouched channels.
    """
    if tensor_format == "bshd":
        freqs = freqs.transpose(0, 1)

    dtype = t.dtype
    cos_ = torch.cos(freqs).to(dtype)
    sin_ = torch.sin(freqs).to(dtype)

    rot_dim = freqs.shape[-1]
    rotated_in, passthrough = t[..., :rot_dim], t[..., rot_dim:]

    # Index of the batch dimension in `freqs` after the optional transpose above.
    batch_axis = {"sbhd": 1, "bshd": 0}.get(tensor_format)
    # Interleaved mode with a non-singleton batch dimension in freqs takes the pure-PyTorch
    # path. NOTE(review): presumably a kernel limitation - confirm.
    needs_fallback = (
        interleaved and batch_axis is not None and freqs.shape[batch_axis] != 1
    )

    if needs_fallback:
        rotated = (rotated_in * cos_) + (_rotate_half(rotated_in, interleaved) * sin_)
    else:
        rotated = _npu_rotary_mul(
            rotated_in.contiguous(), cos_.contiguous(), sin_.contiguous(), interleaved
        )
    return torch.cat([rotated, passthrough], dim=-1)
def _apply_fused_rope_thd_single(
    x: torch.Tensor,
    freqs: torch.Tensor,
    interleaved: bool = False,
) -> torch.Tensor:
    """Apply RoPE through the NPU kernel to one sequence of a packed ("thd") batch.

    Args:
        x: torch.Tensor. One sequence; a singleton dimension is inserted at index 1 before
            the kernel call and removed again afterwards.
        freqs: torch.Tensor. Rotation angles for this sequence; `freqs.shape[-1]` channels
            are rotated.
        interleaved: bool. Whether to use interleaved rotary position embedding.
    Returns:
        Tensor: Tensor with the same shape as `x`.
    """
    rot_dim = freqs.shape[-1]
    x_4d = x.unsqueeze(1)
    dtype = x_4d.dtype
    rotated = _npu_rotary_mul(
        x_4d[..., :rot_dim].contiguous(),
        torch.cos(freqs).to(dtype),
        torch.sin(freqs).to(dtype),
        interleaved,
    )
    # Re-attach the channels beyond rot_dim, which are not rotated.
    return torch.cat([rotated, x_4d[..., rot_dim:]], dim=-1).squeeze(1)
class FusedRoPEFunc(torch.autograd.Function):
    """Function for FusedRoPE on NPU.

    Uses torch_npu.npu_rotary_mul as the fused kernel. The Python-level preprocessing
    (start_positions, context parallel, cu_seqlens) is handled before calling the kernel.
    The backward pass is written by hand: it applies the same rotation with the sine
    term negated to the incoming gradient.
    """

    @staticmethod
    def forward(
        ctx,
        t: torch.Tensor,
        freqs: torch.Tensor,
        start_positions: Union[torch.Tensor, None] = None,
        tensor_format: str = "sbhd",
        interleaved: bool = False,
        cu_seqlens: Union[torch.Tensor, None] = None,
        cp_size: int = 1,
        cp_rank: int = 0,
    ) -> torch.Tensor:
        """Apply RoPE to `t`.

        Parameters
        ----------
        t : torch.Tensor
            Input tensor in the layout named by `tensor_format`.
        freqs : torch.Tensor
            Rotation angles; converted to float32 if given in another dtype.
        start_positions : torch.Tensor, default = None
            Per-sequence offsets into `freqs`.
        tensor_format : {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
            Layout of `t`.
        interleaved : bool, default = False
            Whether to use interleaved rotary position embedding.
        cu_seqlens : torch.Tensor, default = None
            Cumulative sequence lengths; required when `tensor_format` is 'thd'.
        cp_size : int, default = 1
            Context parallel world size.
        cp_rank : int, default = 0
            Context parallel rank.

        Returns
        -------
        torch.Tensor
            Tensor with RoPE applied.
        """
        if freqs.dtype != torch.float32:
            freqs = freqs.float()
        assert tensor_format in (
            "sbhd",
            "bshd",
            "thd",
        ), f"Unsupported tensor_format: {tensor_format}."

        if tensor_format == "thd":
            assert cu_seqlens is not None, "cu_seqlens must not be None when tensor_format is 'thd'."
            # Per-rank sequence lengths: the cumulative lengths are divided by cp_size.
            cu_seqlens_cp = cu_seqlens // cp_size
            seqlens = (cu_seqlens_cp[1:] - cu_seqlens_cp[:-1]).tolist()
            output = torch.cat(
                [
                    _apply_fused_rope_thd_single(
                        x,
                        _get_freqs_on_this_cp_rank(
                            freqs[start_positions[idx] :] if start_positions is not None else freqs,
                            x.size(0),
                            cp_size,
                            cp_rank,
                        ),
                        interleaved=interleaved,
                    )
                    for idx, x in enumerate(torch.split(t, seqlens))
                ]
            )
            # backward re-derives the per-sequence angles from these; cu_seqlens itself is
            # not needed again because the split sizes are kept in ctx.seqlens.
            ctx.save_for_backward(freqs, start_positions)
            ctx.seqlens = seqlens
        else:
            seqlen = t.size(0) if tensor_format == "sbhd" else t.size(1)
            if start_positions is not None:
                max_offset = torch.max(start_positions)
                assert (
                    max_offset + seqlen * cp_size <= freqs.shape[0]
                ), f"Rotary Embeddings only supported up to {freqs.shape[0]} sequence length!"
                # One angle window per sequence, stacked along dim 1 of freqs.
                freqs = torch.concatenate(
                    [freqs[i : i + seqlen * cp_size] for i in start_positions], dim=1
                )
            freqs = _get_freqs_on_this_cp_rank(freqs, seqlen, cp_size, cp_rank)
            output = _apply_fused_rope_4d(t, freqs, tensor_format, interleaved)
            ctx.save_for_backward(freqs)

        ctx.tensor_format = tensor_format
        ctx.interleaved = interleaved
        ctx.cp_size = cp_size
        ctx.cp_rank = cp_rank
        ctx.t_shape = t.shape
        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        """Compute the gradient with respect to `t`.

        Returns one entry per `forward` input after `ctx` (8 in total); only the first
        (`t`) carries a gradient, all others are None.
        """
        tensor_format = ctx.tensor_format
        interleaved = ctx.interleaved
        # forward takes 8 inputs after ctx; the 7 that follow `t` get no gradient.
        no_grads = (None,) * 7

        if tensor_format == "thd":
            freqs, start_positions = ctx.saved_tensors
            seqlens = ctx.seqlens
            cp_size = ctx.cp_size
            cp_rank = ctx.cp_rank
            grad_inputs = []
            for idx, x_grad in enumerate(torch.split(grad_output, seqlens)):
                seq_freqs = _get_freqs_on_this_cp_rank(
                    freqs[start_positions[idx] :] if start_positions is not None else freqs,
                    x_grad.size(0),
                    cp_size,
                    cp_rank,
                )
                rot_dim = seq_freqs.shape[-1]
                # Collapse the angle table to three dims so it broadcasts against x_grad.
                table_shape = (seq_freqs.shape[0], 1, rot_dim)
                cos_ = torch.cos(seq_freqs).reshape(table_shape).to(x_grad.dtype)
                sin_ = torch.sin(seq_freqs).reshape(table_shape).to(x_grad.dtype)
                x_grad_rot = x_grad[..., :rot_dim]
                # Same rotation as forward, with the sine term negated.
                x_grad_rot = (x_grad_rot * cos_) - (_rotate_half_thd(x_grad_rot, interleaved) * sin_)
                if rot_dim < x_grad.shape[-1]:
                    # Channels beyond rot_dim were passed through unchanged in forward.
                    grad_inputs.append(torch.cat([x_grad_rot, x_grad[..., rot_dim:]], dim=-1))
                else:
                    grad_inputs.append(x_grad_rot)
            return (torch.cat(grad_inputs), *no_grads)

        (freqs,) = ctx.saved_tensors
        if tensor_format == "bshd":
            freqs = freqs.transpose(0, 1)
        cos_ = torch.cos(freqs).to(grad_output.dtype)
        sin_ = torch.sin(freqs).to(grad_output.dtype)
        rot_dim = freqs.shape[-1]
        grad_output_rot = grad_output[..., :rot_dim].contiguous()

        # Mirror the kernel selection used in the forward pass.
        use_npu_kernel = True
        if interleaved:
            if tensor_format == "sbhd" and freqs.shape[1] != 1:
                use_npu_kernel = False
            elif tensor_format == "bshd" and freqs.shape[0] != 1:
                use_npu_kernel = False

        if use_npu_kernel:
            grad_input_rot = _npu_rotary_mul(grad_output_rot, cos_, -sin_, interleaved)
        else:
            grad_input_rot = (grad_output_rot * cos_) - (
                _rotate_half(grad_output_rot, interleaved) * sin_
            )

        if rot_dim < grad_output.shape[-1]:
            grad_input = torch.cat([grad_input_rot, grad_output[..., rot_dim:]], dim=-1)
        else:
            grad_input = grad_input_rot
        return (grad_input, *no_grads)
class FusedQKVRoPEFunc(torch.autograd.Function):
    """Function for FusedQKVRoPE on NPU.

    Splits the packed tensor along its last dimension into Q, K and V, applies RoPE to Q
    and K independently and returns V untouched.
    """

    @staticmethod
    def forward(
        ctx,
        qkv: torch.Tensor,
        q_freqs: torch.Tensor,
        k_freqs: torch.Tensor,
        qkv_split_arg_list: List[int],
        start_positions: Union[torch.Tensor, None] = None,
        tensor_format: str = "sbhd",
        interleaved: bool = False,
        cp_size: int = 1,
        cp_rank: int = 0,
    ) -> torch.Tensor:
        """Apply RoPE to the Q and K slices of `qkv` and return `(q, k, v)`."""
        # Angles are always processed in float32.
        if q_freqs.dtype != torch.float32:
            q_freqs = q_freqs.float()
        if k_freqs.dtype != torch.float32:
            k_freqs = k_freqs.float()

        assert tensor_format in (
            "sbhd",
            "bshd",
        ), f"Unsupported tensor_format: {tensor_format}."
        assert qkv.is_contiguous(), "QKV Tensor should be contiguous."
        assert q_freqs.is_contiguous(), "q_freqs Tensor should be contiguous."
        assert k_freqs.is_contiguous(), "k_freqs Tensor should be contiguous."

        q_dim, k_dim, v_dim = qkv_split_arg_list
        q, k, v = torch.split(qkv, [q_dim, k_dim, v_dim], dim=-1)

        seqlen = qkv.size(0) if tensor_format == "sbhd" else qkv.size(1)

        if start_positions is not None:
            window = seqlen * cp_size
            max_offset = torch.max(start_positions)
            assert (
                max_offset + window <= q_freqs.shape[0]
            ), f"Rotary Embeddings only supported up to {q_freqs.shape[0]} sequence length!"
            # One angle window per sequence, stacked along dim 1.
            q_freqs = torch.concatenate(
                [q_freqs[i : i + window] for i in start_positions], dim=1
            )
            k_freqs = torch.concatenate(
                [k_freqs[i : i + window] for i in start_positions], dim=1
            )

        q_freqs = _get_freqs_on_this_cp_rank(q_freqs, seqlen, cp_size, cp_rank)
        k_freqs = _get_freqs_on_this_cp_rank(k_freqs, seqlen, cp_size, cp_rank)

        q = _apply_fused_rope_4d(q, q_freqs, tensor_format, interleaved)
        k = _apply_fused_rope_4d(k, k_freqs, tensor_format, interleaved)

        ctx.save_for_backward(q_freqs, k_freqs)
        ctx.tensor_format = tensor_format
        ctx.qkv_split_arg_list = qkv_split_arg_list
        ctx.cp_size = cp_size
        ctx.cp_rank = cp_rank
        ctx.interleaved = interleaved
        return q, k, v

    @staticmethod
    def backward(
        ctx, grad_q: torch.Tensor, grad_k: torch.Tensor, grad_v: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Return the gradient for `qkv` followed by None for the other eight inputs."""
        q_freqs, k_freqs = ctx.saved_tensors
        tensor_format = ctx.tensor_format
        interleaved = ctx.interleaved
        # Unpacked as in forward; the individual sizes are not needed below.
        _q_dim, _k_dim, _v_dim = ctx.qkv_split_arg_list

        if tensor_format == "bshd":
            q_freqs = q_freqs.transpose(0, 1)
            k_freqs = k_freqs.transpose(0, 1)

        # Kernel choice is derived from q_freqs and shared by both Q and K.
        batch_axis = {"sbhd": 1, "bshd": 0}.get(tensor_format)
        use_npu_kernel = not (
            interleaved and batch_axis is not None and q_freqs.shape[batch_axis] != 1
        )

        def _grad_wrt_input(grad: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
            # Rotate the gradient with the sine term negated and re-attach pass-through channels.
            cos_ = torch.cos(freqs).to(grad.dtype)
            sin_ = torch.sin(freqs).to(grad.dtype)
            rot_dim = freqs.shape[-1]
            rotated = grad[..., :rot_dim].contiguous()
            if use_npu_kernel:
                rotated = _npu_rotary_mul(rotated, cos_, -sin_, interleaved)
            else:
                rotated = (rotated * cos_) - (_rotate_half(rotated, interleaved) * sin_)
            if rot_dim < grad.shape[-1]:
                return torch.cat([rotated, grad[..., rot_dim:]], dim=-1)
            return rotated

        grad_q_out = _grad_wrt_input(grad_q, q_freqs)
        grad_k_out = _grad_wrt_input(grad_k, k_freqs)

        # V was passed through in forward, so its gradient is forwarded unchanged.
        grad_input = torch.cat([grad_q_out, grad_k_out, grad_v], dim=-1)
        return grad_input, None, None, None, None, None, None, None, None
def _rotate_half(x: torch.Tensor, interleaved: bool) -> torch.Tensor:
"""Change sign so the last dimension becomes [-odd, +even]
Args:
x: torch.Tensor. Input tensor.
interleaved: bool. Whether to use interleaved rotary position embedding.
Returns:
Tensor: Tensor rotated half.
"""
if not interleaved:
x1, x2 = torch.chunk(x, 2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
x1 = x[:, :, :, ::2]
x2 = x[:, :, :, 1::2]
x_new = torch.stack((-x2, x1), dim=-1)
return x_new.view(x_new.shape[0], x_new.shape[1], x_new.shape[2], -1)
def _apply_rotary_pos_emb_base(
    t: torch.Tensor,
    freqs: torch.Tensor,
    tensor_format: str = "sbhd",
    interleaved: bool = False,
) -> torch.Tensor:
    """
    Reference (unfused) application of rotary positional embedding to a tensor.

    Parameters
    ----------
    t : torch.Tensor
        Input tensor of shape `[s, b, h, d]` or `[b, s, h, d]`, on which rotary positional
        embedding will be applied.
    freqs : torch.Tensor
        Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` or `[s2, b, 1, d2]`
        and dtype 'float', with `s2 >= s` and `d2 <= d`.
    tensor_format : {'sbhd', 'bshd'}, default = 'sbhd'
        Should be `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is of shape
        `[seq, bs, ...]`.
    interleaved : bool, default = False
        Whether to use interleaved rotary position embedding.

    Returns
    -------
    torch.Tensor
        The first `d2` channels rotated, concatenated with the remaining channels unchanged.
    """
    # For batch-first input, swap the first two dims of freqs so it broadcasts against t.
    if tensor_format == "bshd":
        freqs = freqs.transpose(0, 1)

    rot_dim = freqs.shape[-1]
    rotated, untouched = t[..., :rot_dim], t[..., rot_dim:]

    # cos/sin are evaluated on freqs as given and only then cast to the input dtype.
    cos_ = torch.cos(freqs).to(t.dtype)
    sin_ = torch.sin(freqs).to(t.dtype)

    rotated = (rotated * cos_) + (_rotate_half(rotated, interleaved) * sin_)
    return torch.cat((rotated, untouched), dim=-1)
def _get_freqs_on_this_cp_rank(
freqs: torch.Tensor, seqlen: int, cp_size: int, cp_rank: int
) -> torch.Tensor:
"""Get the position embedding on the current context parallel rank.
Args:
freqs: torch.Tensor. Positional embedding tensor of shape `[s2, 1, 1, d2]`.
seqlen: int. Length of the current sequence.
cp_size: int. Context parallel world size.
cp_rank: int. Context parallel rank.
"""
if cp_size > 1:
cp_seg = seqlen // 2
full_seqlen = cp_size * seqlen
return torch.cat(
[
freqs[cp_rank * cp_seg : (cp_rank + 1) * cp_seg],
freqs[full_seqlen - (cp_rank + 1) * cp_seg : full_seqlen - cp_rank * cp_seg],
]
)
return freqs[:seqlen]
def apply_rotary_pos_emb(
    t: torch.Tensor,
    freqs: torch.Tensor,
    tensor_format: str = "sbhd",
    start_positions: Union[torch.Tensor, None] = None,
    interleaved: bool = False,
    fused: bool = False,
    cu_seqlens: Union[torch.Tensor, None] = None,
    cp_size: int = 1,
    cp_rank: int = 0,
) -> torch.Tensor:
    """
    Apply rotary positional embedding tensor to the input tensor.

    Support matrix:
    Fused (NPU):
        qkv_formats:      "thd", "bshd", "sbhd"
        context parallel: yes
        start_positions:  yes
        interleaving:     yes
    Unfused (NPU):
        qkv_formats:      "thd", "bshd", "sbhd"
        context parallel: yes
        start_positions:  yes
        interleaving:     yes

    Parameters
    ----------
    t : torch.Tensor
        Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which
        rotary positional embedding will be applied.
    freqs : torch.Tensor
        Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
        with `s2 >= s` and `d2 <= d`.
    tensor_format : {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
        `bshd` if `t` is of shape `[bs, seq, ...]`, `sbhd` if `t` is of shape
        `[seq, bs, ...]`, or `thd` for packed sequences (requires `cu_seqlens`).
        All three formats are handled by both the fused and the unfused path.
    start_positions : torch.Tensor, default = None.
        Tokens in a sequence `i` should be applied with position encoding offset by
        `start_positions[i]`. If `start_positions=None`, there's no offset.
    interleaved : bool, default = False
        Whether to use interleaved rotary position embedding.
    fused : bool, default = False
        Whether to use the fused RoPE implementation (`FusedRoPEFunc`).
    cu_seqlens : torch.Tensor, default = None.
        Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and
        dtype torch.int32. Only used when `tensor_format` is 'thd'.
        Should be `cu_seqlens_padded` when cp_size > 1.
    cp_size : int, default = 1.
        Context parallel world size. Used to select this rank's slice of `freqs` for
        every tensor format, in both the fused and the unfused path.
    cp_rank : int, default = 0.
        Context parallel rank. Used together with `cp_size`.

    Returns
    -------
    torch.Tensor
        Tensor with rotary positional embedding applied.

    Raises
    ------
    AssertionError
        If `tensor_format` is 'thd' and `cu_seqlens` is None, or if `start_positions`
        would index past the end of `freqs`.
    ValueError
        If `tensor_format` is not one of the supported values (unfused path).
    """
    assert (
        tensor_format != "thd" or cu_seqlens is not None
    ), "cu_seqlens must not be None when tensor_format is 'thd'."

    if fused:
        return FusedRoPEFunc.apply(
            t, freqs, start_positions, tensor_format, interleaved, cu_seqlens, cp_size, cp_rank
        )

    # Unfused packed-sequence path: rotate each sequence separately and concatenate.
    if tensor_format == "thd":
        cu_seqlens = cu_seqlens // cp_size
        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        return torch.cat(
            [
                _apply_rotary_pos_emb_base(
                    # Insert a singleton dim so each sequence looks like an `sbhd` tensor.
                    x.unsqueeze(1),
                    _get_freqs_on_this_cp_rank(
                        (
                            freqs[start_positions[idx] :] if start_positions is not None else freqs
                        ),
                        x.size(0),
                        cp_size,
                        cp_rank,
                    ),
                    interleaved=interleaved,
                )
                for idx, x in enumerate(torch.split(t, seqlens))
            ]
        ).squeeze(1)

    # Unfused `sbhd` / `bshd` path.
    if tensor_format == "sbhd":
        seqlen = t.size(0)
    elif tensor_format == "bshd":
        seqlen = t.size(1)
    else:
        raise ValueError(f"Unsupported tensor_format: {tensor_format}.")

    if start_positions is not None:
        max_offset = torch.max(start_positions)
        assert (
            max_offset + seqlen * cp_size <= freqs.shape[0]
        ), f"Rotary Embeddings only supported up to {freqs.shape[0]} sequence length!"
        # One angle window per sequence, stacked along dim 1 of freqs.
        freqs = torch.concatenate([freqs[i : i + seqlen * cp_size] for i in start_positions], dim=1)

    return _apply_rotary_pos_emb_base(
        t,
        _get_freqs_on_this_cp_rank(freqs, seqlen, cp_size, cp_rank),
        tensor_format,
        interleaved=interleaved,
    )
def apply_fused_qkv_rotary_pos_emb(
    qkv: torch.Tensor,
    q_freqs: torch.Tensor,
    k_freqs: torch.Tensor,
    qkv_split_arg_list: List[int],
    tensor_format: str = "sbhd",
    start_positions: Union[torch.Tensor, None] = None,
    interleaved: bool = False,
    cu_seqlens: Union[torch.Tensor, None] = None,
    cp_size: int = 1,
    cp_rank: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Apply rotary positional embedding to the Q and K parts of a packed qkv tensor.

    Support matrix:
    Fused (NPU):
        Training:
            qkv_formats:      "bshd", "sbhd"
            context parallel: yes
            start_positions:  no
            interleaving:     yes
        Inference:
            qkv_formats:      "bshd", "sbhd"
            context parallelism: no
            start_positions:  yes
            interleaving:     yes

    Parameters
    ----------
    qkv : torch.Tensor
        Input tensor of shape `[s, b, h, d]` or `[b, s, h, d]` with q, k and v
        concatenated along the last dimension.
    q_freqs : torch.Tensor
        Rotary positional embedding Q tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
        with `s2 >= s` and `d2 <= d`.
    k_freqs : torch.Tensor
        Rotary positional embedding K tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
        with `s2 >= s` and `d2 <= d`.
    qkv_split_arg_list : List[int]
        Three integers giving the sizes of the q, k and v slices along the last dimension
        of `qkv`, in that order. They should sum to the size of that dimension.
    tensor_format : {'sbhd', 'bshd'}, default = 'sbhd'
        `bshd` if `qkv` is of shape `[bs, seq, ...]`, or `sbhd` if `qkv` is of shape
        `[seq, bs, ...]`.
    start_positions : torch.Tensor, default = None.
        Tokens in a sequence `i` should be applied with position encoding offset by
        `start_positions[i]`. If `start_positions=None`, there's no offset.
        Cannot be combined with `cp_size > 1`.
    interleaved : bool, default = False
        Whether to use interleaved rotary position embedding.
    cu_seqlens : torch.Tensor, default = None.
        Accepted in the signature but not referenced by this function.
    cp_size : int, default = 1.
        Context parallel world size.
    cp_rank : int, default = 0.
        Context parallel rank.

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        The `(q, k, v)` tensors produced by `FusedQKVRoPEFunc`.
    """
    assert (
        cp_size <= 1 or start_positions is None
    ), "start_positions != None with CP SIZE > 1 is not supported!"
    assert tensor_format != "thd", "'thd' tensor_format not supported currently."

    # Positional order expected by FusedQKVRoPEFunc.forward (after ctx).
    func_args = (
        qkv,
        q_freqs,
        k_freqs,
        qkv_split_arg_list,
        start_positions,
        tensor_format,
        interleaved,
        cp_size,
        cp_rank,
    )
    return FusedQKVRoPEFunc.apply(*func_args)