import torch
import torch.nn as nn
import torch_npu
class Model(nn.Module):
"""
Rotary Position Embedding (RoPE).
Formula:
x1, x2 = chunk(x, 2, dim=-1)
x_new = concat(-x2, x1)
output = cos * x + sin * x_new
On NPU, use torch_npu.npu_rotary_mul(x, cos, sin).
Constraints:
- x, cos, sin must be 4D tensors
- Last dimension must be multiple of 64
- cos and sin last dim must match x last dim
"""
def __init__(self):
super(Model, self).__init__()
def forward(self, x, cos, sin):
return torch_npu.npu_rotary_mul(x, cos, sin)
def get_inputs():
batch = 16
num_heads = 48
seq_len = 1000
head_dim = 128
x = torch.randn(batch, num_heads, seq_len, head_dim, dtype=torch.float16)
cos = torch.randn(1, 1, seq_len, head_dim, dtype=torch.float16).clamp(-1, 1)
sin = torch.randn(1, 1, seq_len, head_dim, dtype=torch.float16).clamp(-1, 1)
return [x, cos, sin]
def get_init_inputs():
return []