import torch
from mindspeed.op_builder import RotaryPositionEmbeddingOpBuilder

# Public API of this module: only the wrapper function is exported.
__all__ = ["npu_rotary_position_embedding"]


# Builder instance created once at import time; the wrapper below calls
# ``.load()`` on it each time it is invoked to obtain the op module.
rope_op_builder = RotaryPositionEmbeddingOpBuilder()


def npu_rotary_position_embedding(x, cos, sin, mode=0):
    """Forward the arguments to the loaded ``npu_rotary_position_embedding`` op.

    The op module is obtained from the module-level ``rope_op_builder`` via
    ``load()`` on every call, and the four arguments are passed through
    positionally without any validation or conversion here.

    Args:
        x: Input handed to the op unchanged (presumably the tensor to
            rotate -- TODO confirm against the op implementation).
        cos: Passed through unchanged (presumably the cosine table).
        sin: Passed through unchanged (presumably the sine table).
        mode: Passed through unchanged; defaults to ``0``. Its meaning is
            defined by the underlying op and is not visible in this module.

    Returns:
        Whatever the underlying op returns.
    """
    return rope_op_builder.load().npu_rotary_position_embedding(
        x, cos, sin, mode
    )