import os
from pathlib import Path
import torch
import torch_npu
from ..utils import ParametersInvalid, file_utils
def check_input_params(x, cos, sin, rotated_mode, head_first, fused):
    """Validate argument types and tensor ranks for the rotary embedding.

    Checks are performed in a fixed order and the first violation raises.

    Args:
        x: Must be a ``torch.Tensor`` with exactly 4 dimensions.
        cos: Must be a ``torch.Tensor`` with 2 or 4 dimensions.
        sin: Must be a ``torch.Tensor`` with 2 or 4 dimensions, and the same
            number of dimensions as ``cos``.
        rotated_mode: Must be a ``str``; its value is not checked here.
        head_first: Must be a ``bool``.
        fused: Must be a ``bool``.

    Raises:
        ParametersInvalid: If any of the rules above is violated.
    """
    # Type checks come first (in signature order) so that the rank checks
    # below can safely call ``.dim()`` on the tensor arguments.
    type_rules = (
        ("x", x, torch.Tensor, "torch.Tensor"),
        ("cos", cos, torch.Tensor, "torch.Tensor"),
        ("sin", sin, torch.Tensor, "torch.Tensor"),
        ("rotated_mode", rotated_mode, str, "str"),
        ("head_first", head_first, bool, "bool"),
        ("fused", fused, bool, "bool"),
    )
    for arg_name, value, expected_type, type_label in type_rules:
        if not isinstance(value, expected_type):
            raise ParametersInvalid(
                f"The data type of input {arg_name} must be {type_label}, but got {type(value)}."
            )

    x_rank = x.dim()
    cos_rank = cos.dim()
    sin_rank = sin.dim()

    if x_rank != 4:
        raise ParametersInvalid(f"The dimensional of input x must be a 4, but got {x_rank}.")

    # cos is checked before sin, matching the order of the parameters.
    for arg_name, rank in (("cos", cos_rank), ("sin", sin_rank)):
        if rank not in (2, 4):
            raise ParametersInvalid(
                f"The dimensional of input {arg_name} must be a 2 or 4, but got {rank}."
            )

    if cos_rank != sin_rank:
        raise ParametersInvalid(
            "The dimensional of input cos must be equal to the dimensional of input sin, "
            f"but {cos_rank} != {sin_rank}."
        )
def reshape_for_broadcast(x, cos, sin, head_first=False):
    """View ``cos`` and ``sin`` with a shape that broadcasts against ``x``.

    The target shape has the same number of axes as ``x``. It copies the size
    of ``x`` on two axes and uses 1 everywhere else:

    * ``head_first=True``: the last two axes of ``x`` are kept.
    * ``head_first=False``: axis 1 and the last axis of ``x`` are kept.

    Args:
        x: Tensor whose shape defines the broadcast target.
        cos: Tensor to be viewed with the target shape.
        sin: Tensor to be viewed with the target shape.
        head_first: Selects which axes of ``x`` are kept (see above).

    Returns:
        A ``(cos, sin)`` tuple of views with the target shape.
    """
    rank = x.ndim
    last_axis = rank - 1
    kept_axes = {rank - 2, last_axis} if head_first else {1, last_axis}

    target_shape = [
        size if axis in kept_axes else 1
        for axis, size in enumerate(x.shape)
    ]

    cos_view = cos.view(*target_shape)
    sin_view = sin.view(*target_shape)
    return cos_view, sin_view
def rotary_position_embedding(x: torch.Tensor,
                              cos: torch.Tensor,
                              sin: torch.Tensor,
                              rotated_mode: str = "rotated_half",
                              head_first: bool = False,
                              fused: bool = True) -> torch.Tensor:
    """
    Apply rotary position embeddings to a query or key tensor.

    The result is ``x * cos + rotate(x) * sin``, where ``rotate`` depends on
    ``rotated_mode``. The computation runs in the dtype of ``cos`` and the
    output is cast back to the dtype of ``x``.

    Args:
        x (torch.Tensor):
            Query or key tensor to apply rotary embeddings. x must be 4-dimensional.
            The supported layout: [B,N,S,D], [B,S,N,D], [S,B,N,D].
            Input x could be represented as [x_0, x_1, ... , x_d/2-1, x_d/2, x_d/2+1, ... , x_d-1].
        cos (torch.Tensor):
            Precomputed cos frequency tensor. cos must be 2 or 4-dimensional.
            Corresponding to the input x, the supported layout: [S,D], [1,1,S,D], [1,S,1,D], [S,1,1,D].
        sin (torch.Tensor):
            Precomputed sin frequency tensor. sin must be 2 or 4-dimensional and
            have the same number of dimensions as cos.
            Corresponding to the input x, the supported layout: [S,D], [1,1,S,D], [1,S,1,D], [S,1,1,D].
        rotated_mode (str):
            If `rotated_half`: rotate x to [-x_d/2, -x_d/2+1, ... , -x_d-1, x_0, x_1, ... , x_d/2-1].
            If `rotated_interleaved`: rotate x to [-x_1, x_0, -x_3, x_2, ... , -x_d-1, x_d-2].
        head_first (bool):
            In the layout of x, if N is before S, set to True; otherwise, set to False.
            Only used to reshape 2-dimensional cos/sin for broadcasting.
        fused (bool):
            If True, the computation is delegated to ``torch_npu.npu_rotary_mul``;
            otherwise it is carried out with plain torch tensor operations.

    Returns:
        (torch.Tensor): tensor with rotary embeddings applied, in the dtype of x.

    Raises:
        ParametersInvalid: If an argument fails validation or ``rotated_mode``
            is not one of the two supported values.
    """
    check_input_params(x, cos, sin, rotated_mode, head_first, fused)

    # 2-D cos/sin are expanded to 4-D so they broadcast against x.
    # NOTE(review): with head_first=False the broadcast shape is taken from
    # axis 1 of x, so an [S,B,N,D] input appears to require 4-D cos/sin --
    # confirm with callers.
    if cos.dim() == 2 and sin.dim() == 2:
        cos, sin = reshape_for_broadcast(x, cos, sin, head_first=head_first)

    # Map the public mode name onto the short name passed to the fused operator.
    operator_modes = {"rotated_half": "half", "rotated_interleaved": "interleave"}
    operator_mode = operator_modes.get(rotated_mode)
    if operator_mode is None:
        raise ParametersInvalid(f"Unsupported rotated_mode: {rotated_mode}. The supported "
                                "rotated_mode must be 'rotated_half' or 'rotated_interleaved'")

    x_in = x.to(cos.dtype)

    if fused:
        fused_out = torch_npu.npu_rotary_mul(x_in, cos, sin, operator_mode)
        return fused_out.type_as(x)

    leading_dims = x_in.shape[:-1]
    if operator_mode == "interleave":
        # Pair up neighbouring elements (x_0, x_1), (x_2, x_3), ... and map
        # each pair (a, b) to (-b, a).
        even_part, odd_part = x_in.reshape(*leading_dims, -1, 2).unbind(-1)
        rotated = torch.stack([-odd_part, even_part], dim=-1).flatten(3)
    else:
        # Split the last axis into two halves (a, b) and map them to (-b, a).
        front_half, back_half = x_in.reshape(*leading_dims, 2, -1).unbind(-2)
        rotated = torch.cat([-back_half, front_half], dim=-1)

    return (x_in * cos + rotated * sin).type_as(x)