#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
namespace op_api {
// Integer layout codes handed to aclnnApplyRotaryPosEmb / aclnnApplyRotaryPosEmbV2.
// "BSND" and "BSH" share the same code; it is also what npu_apply_rotary_pos_emb uses
// for any layout string it does not explicitly recognise.
constexpr int64_t LAYOUT_BSND_BSH = 1;
constexpr int64_t LAYOUT_SBND = 2;
constexpr int64_t LAYOUT_BNSD = 3;
constexpr int64_t LAYOUT_TND = 4;

// Legacy path: launches aclnnApplyRotaryPosEmb (V1) on query/key with the given layout code
// and returns (query, key).
//
// No output tensors are allocated here. The same query/key handles passed to the kernel are
// returned.
// NOTE(review): this implies the kernel writes its result into query/key in place — confirm
// against the aclnnApplyRotaryPosEmb API reference.
std::tuple<at::Tensor, at::Tensor> _apply_rotary_pos_emb_v1(
    const at::Tensor &query,
    const at::Tensor &key,
    const at::Tensor &cos,
    const at::Tensor &sin,
    int64_t lay_out)
{
    EXEC_NPU_NO_FORMAT_CHECK_CMD(aclnnApplyRotaryPosEmb, query, key, cos, sin, lay_out);
    return std::tie(query, key);
}

// Applies rotary position embedding to query and key on the NPU.
//
// @param query       Tensor passed to the kernel and returned as the first result.
// @param key         Tensor passed to the kernel and returned as the second result.
// @param cos         Forwarded unchanged to the kernel.
// @param sin         Forwarded unchanged to the kernel.
// @param layout      "BNSD", "SBND" or "TND" select the matching layout code. Every other
//                    string (including "BSND" and "BSH") maps to LAYOUT_BSND_BSH; unknown
//                    strings are NOT rejected.
// @param rotary_mode Must be "half", "quarter" or "interleave"; anything else fails TORCH_CHECK.
// @return (query, key)
std::tuple<at::Tensor, at::Tensor> npu_apply_rotary_pos_emb(
    const at::Tensor &query,
    const at::Tensor &key,
    const at::Tensor &cos,
    const at::Tensor &sin,
    c10::string_view layout,
    c10::string_view rotary_mode)
{
    std::string layout_str = std::string(layout);
    // Kept as an owned std::string so a NUL-terminated char* can be handed to the V2 kernel below.
    std::string rotary_mode_str = std::string(rotary_mode);
    TORCH_CHECK(rotary_mode_str == "half" || rotary_mode_str == "quarter" || rotary_mode_str == "interleave",
        "The rotary_mode should be half/quarter/interleave, but got ", rotary_mode_str, OPS_ERROR(ErrCode::PARAM));

    // Map the layout name to its integer code; unmatched names fall back to BSND/BSH.
    int64_t lay_out = LAYOUT_BSND_BSH;
    if (layout_str == "BNSD") {
        lay_out = LAYOUT_BNSD;
    } else if (layout_str == "SBND") {
        lay_out = LAYOUT_SBND;
    } else if (layout_str == "TND") {
        lay_out = LAYOUT_TND;
    }

    // Presumably returns the V1 result when aclnnApplyRotaryPosEmbV2 cannot be resolved at
    // runtime (DO_COMPATIBILITY is defined elsewhere — confirm).
    // NOTE(review): the V1 call does not receive rotary_mode, so "quarter"/"interleave" are not
    // forwarded on that path — confirm V1 semantics before relying on those modes.
    DO_COMPATIBILITY(aclnnApplyRotaryPosEmbV2, _apply_rotary_pos_emb_v1(query, key, cos, sin, lay_out));

    // std::string::data() on a non-const string yields char* (C++17), so no const_cast is needed.
    char *rotary_mode_ptr = rotary_mode_str.data();
    EXEC_NPU_NO_FORMAT_CHECK_CMD(aclnnApplyRotaryPosEmbV2, query, key, cos, sin, lay_out, rotary_mode_ptr);
    return std::tie(query, key);
}
}  // namespace op_api