#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
#include "op_plugin/AclOpsInterface.h"
namespace {
// Translates the textual rotary mode into the boolean that npu_mrope forwards
// to the kernel as `is_neox_style`.
//
// Returns true for "half" and false for "interleave". Any other string fails
// the TORCH_CHECK below with a VALUE error code.
bool parse_rotary_mode_to_neox_style(const std::string &rotary_mode)
{
    const bool is_half = (rotary_mode == "half");
    TORCH_CHECK(
        is_half || rotary_mode == "interleave",
        "rotary_mode only support half or interleave",
        OPS_ERROR(ErrCode::VALUE));
    return is_half;
}
// Maps the textual cache mode onto the integer code handed to the kernel:
// "default" -> 0, "interleave" -> 1.
//
// Any other string fails the TORCH_CHECK below with a VALUE error code.
int64_t parse_cache_mode_to_int(const std::string &cache_mode)
{
    const bool wants_interleave = (cache_mode == "interleave");
    if (!wants_interleave && cache_mode != "default") {
        // Unconditional failure: the value is neither of the two accepted modes.
        TORCH_CHECK(false, "cache_mode only support default or interleave", OPS_ERROR(ErrCode::VALUE));
    }
    return wants_interleave ? 1 : 0;
}
}
namespace op_api {
using npu_preparation = at_npu::native::OpPreparation;
// Fallback path that launches the original aclnnRopeWithSinCosCache kernel
// (the variant without a cache_mode argument).
//
// All arguments are forwarded to the launch macro unchanged. The caller
// supplies `query_out` / `key_out`; presumably the kernel writes its results
// into them (confirm against the aclnn API contract).
//
// Returns a tuple holding copies of the `query_out` and `key_out` tensor handles.
std::tuple<at::Tensor, at::Tensor> _mrope_v1(
    const at::Tensor &positions,
    const at::Tensor &query,
    const at::Tensor &key,
    const at::Tensor &cos_sin_cache,
    at::IntArrayRef mrope_section,
    int64_t head_size,
    bool is_neox_style,
    at::Tensor &query_out,
    at::Tensor &key_out)
{
    EXEC_NPU_NO_FORMAT_CHECK_CMD(
        aclnnRopeWithSinCosCache,
        positions, query, key, cos_sin_cache,
        mrope_section, head_size, is_neox_style,
        query_out, key_out);

    return std::tuple<at::Tensor, at::Tensor>(query_out, key_out);
}
// Applies rotary position embedding to `query` and `key` using a combined
// cos/sin cache, preferring the aclnnRopeWithSinCosCacheV2 kernel.
//
// Parameters:
//   positions, query, key, cos_sin_cache - forwarded to the kernel unchanged.
//   head_size     - forwarded to the kernel unchanged.
//   mrope_section - optional section sizes; {0, 0, 0} is used when absent.
//   rotary_mode   - "half" (default when absent) or "interleave".
//   cache_mode    - "default" (default when absent) or "interleave".
//
// Returns a tuple (query_out, key_out) of freshly allocated tensors created
// with at::empty_like(query) and at::empty_like(key).
//
// Raises (through TORCH_CHECK):
//   - VALUE error when rotary_mode or cache_mode is not an accepted string.
//   - NOT_SUPPORT error when a non-default cache_mode is requested but the V2
//     kernel is not available in the current environment.
std::tuple<at::Tensor, at::Tensor> npu_mrope(
    const at::Tensor &positions,
    const at::Tensor &query,
    const at::Tensor &key,
    const at::Tensor &cos_sin_cache,
    int64_t head_size,
    c10::OptionalIntArrayRef mrope_section,
    c10::optional<c10::string_view> rotary_mode,
    c10::optional<c10::string_view> cache_mode)
{
    // at::IntArrayRef is a non-owning view. The default must therefore live in
    // static storage: a braced temporary such as at::IntArrayRef{0, 0, 0} would
    // be destroyed at the end of the statement and leave the view dangling
    // before the kernel launch reads it.
    static constexpr int64_t kDefaultMropeSection[] = {0, 0, 0};
    at::IntArrayRef mrope_section_value = mrope_section.value_or(at::IntArrayRef(kDefaultMropeSection));

    std::string rotary_mode_str = rotary_mode.has_value() ? std::string(rotary_mode.value()) : "half";
    std::string cache_mode_str = cache_mode.has_value() ? std::string(cache_mode.value()) : "default";
    bool is_neox_style = parse_rotary_mode_to_neox_style(rotary_mode_str);
    int64_t cache_mode_value = parse_cache_mode_to_int(cache_mode_str);

    // Availability is probed once per process. The v1 kernel takes no
    // cache_mode argument, so only the default mode (0) may fall back to it.
    static const bool is_mrope_v2_available = check_aclnn_kernel_available("aclnnRopeWithSinCosCacheV2");
    TORCH_CHECK(
        is_mrope_v2_available || cache_mode_value == 0,
        "npu_mrope: cache_mode='", cache_mode_str,
        "' requires aclnnRopeWithSinCosCacheV2, but current environment does not support it. "
        "Please upgrade CANN or use cache_mode='default'.",
        OPS_ERROR(ErrCode::NOT_SUPPORT));

    at::Tensor query_out = at::empty_like(query);
    at::Tensor key_out = at::empty_like(key);

    // NOTE(review): DO_COMPATIBILITY is expected to return the v1 fallback
    // expression from this function when the V2 symbol cannot be resolved;
    // confirm against the macro definition.
    DO_COMPATIBILITY(aclnnRopeWithSinCosCacheV2,
                     _mrope_v1(positions, query, key, cos_sin_cache, mrope_section_value,
                               head_size, is_neox_style, query_out, key_out));

    EXEC_NPU_NO_FORMAT_CHECK_CMD(aclnnRopeWithSinCosCacheV2, positions, query, key, cos_sin_cache,
                                 mrope_section_value, head_size, is_neox_style, cache_mode_value,
                                 query_out, key_out);
    return std::tie(query_out, key_out);
}
}