import torch
import torch_npu
import npu_ops_transformer_ext
import logging
# Emit INFO-level records so the logging.info() calls in this script are shown.
logging.basicConfig(level=logging.INFO)

# Dimensions of the random test input, which the script creates with shape
# (batch_size, sequence_len, head_num, hidden_dim).
batch_size = 2
sequence_len = 10
head_num = 2
hidden_dim = 8

# Added to the relative-error denominator so the division can never hit zero.
epsilon = 1e-5

# Data types to exercise. An ordered tuple is used instead of a set because
# set iteration order over dtype objects is not guaranteed, which would let
# the order of the test runs (and of the log output) differ between runs.
supported_dtypes = (torch.bfloat16, torch.float16)
# For every supported dtype: build a random input, compute a rotary-style
# reference on the CPU, run the custom NPU op on the same data, and log the
# absolute and relative differences between the two results.
for dtype in supported_dtypes:
    logging.info(f"start testing dtype : {'BF16' if dtype == torch.bfloat16 else 'HALF'}")

    # Random input of shape (batch_size, sequence_len, head_num, hidden_dim).
    # Named input_tensor so the builtin input() is not shadowed.
    input_tensor = torch.randn(batch_size, sequence_len, head_num, hidden_dim).to(dtype)

    # Synthetic "sin"/"cos" tables of shape (sequence_len, hidden_dim): row i
    # holds the constant i + 1 (sin) or i (cos). They are plain test data,
    # not real trigonometric values.
    sin = torch.arange(1, sequence_len + 1).unsqueeze(1).expand(sequence_len, hidden_dim).to(dtype)
    cos = torch.arange(0, sequence_len).unsqueeze(1).expand(sequence_len, hidden_dim).to(dtype)

    # Output buffer handed to the NPU op; same shape and dtype as the input.
    output = torch.empty_like(input_tensor)

    # --- CPU reference ---------------------------------------------------
    # Add batch and head axes, then keep every second column so the tables
    # have shape (1, sequence_len, 1, hidden_dim // 2) and broadcast against
    # the two halves of the input.
    sin_expanded = sin.unsqueeze(0).unsqueeze(2)[..., ::2]
    cos_expanded = cos.unsqueeze(0).unsqueeze(2)[..., ::2]

    # Split the last dimension into a first and a second half and rotate the
    # (first, second) pairs by the sin/cos factors.
    first_half, second_half = torch.chunk(input_tensor, chunks=2, dim=-1)
    rotated_first = first_half * cos_expanded - second_half * sin_expanded
    rotated_second = second_half * cos_expanded + first_half * sin_expanded
    output_cpu = torch.cat([rotated_first, rotated_second], dim=-1)

    # --- NPU run -----------------------------------------------------------
    input_npu = input_tensor.npu()
    sin_npu = sin.npu()
    cos_npu = cos.npu()
    output_npu = output.npu()

    # NOTE(review): the meaning of the leading literal 40 is not visible in
    # this file - presumably a launch parameter such as a block/core count;
    # confirm against the op definition. The result is expected to be written
    # into output_npu, since the return value is not used.
    torch.ops.npu_ops_transformer_ext.rotary_stride(40, input_npu, sin_npu, cos_npu, output_npu, hidden_dim)

    # --- Comparison --------------------------------------------------------
    # Copy the device result back to the host once and reuse it.
    npu_result = output_npu.cpu()
    abs_error = torch.abs(npu_result - output_cpu)
    # Relative error is measured against the CPU reference, not against the
    # value under test, so a wrong NPU result cannot shrink its own error.
    rel_error = abs_error / (torch.abs(output_cpu) + epsilon)

    logging.info(f"input tensor: \n{input_tensor}")
    logging.info(f"cpu result tensor: \n{output_cpu}")
    logging.info(f"npu result tensor: \n{output_npu}")
    logging.info(f"max absolute error: {abs_error.max().item():.8f}")
    logging.info(f"average absolute error: {abs_error.mean().item():.8f}")
    logging.info(f"max relative error: {rel_error.max().item():.8f}")
    logging.info(f"average relative error: {rel_error.mean().item():.8f}")