#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
namespace op_api {
// Short alias for at_npu::native::OpPreparation, used below to allocate output tensors.
using npu_preparation = at_npu::native::OpPreparation;
inline bool _transform_bias_rescale_qkv_fallback_condition()
{
static const bool is_aclnn_kernel_available = check_aclnn_kernel_available("aclnnTransformBiasRescaleQkv");
if (!is_aclnn_kernel_available) {
TORCH_NPU_WARN_ONCE("CAUTION: The operator aten::_transform_bias_rescale_qkv is currently "
"not supported on the NPU backend. Now this operator will fallback to run on the CPU "
"and may have performance implications. Please try to update your CANN version.");
return true;
}
return false;
}
/// Computes the shape shared by the q, k and v outputs of _transform_bias_rescale_qkv.
///
/// The result is {B, num_head, T, D / num_head} with B = qkv.size(0),
/// T = qkv.size(1) and D = qkv_bias.size(0) / 3.
///
/// @param qkv       Input tensor; its first two dimensions provide B and T.
/// @param qkv_bias  Bias tensor; its first dimension is treated as 3 * D
///                  (presumably the concatenated Q/K/V width -- confirm against the kernel spec).
/// @param num_head  Number of heads; must be non-zero and must divide D evenly.
/// @return Output shape {B, num_head, T, D / num_head}.
/// @throws Raises through TORCH_CHECK when num_head is 0, when qkv_bias.size(0)
///         is not a multiple of 3, or when D is not a multiple of num_head.
c10::SmallVector<int64_t, SIZE> _transform_bias_rescale_qkv_out_size(
    const at::Tensor& qkv,
    const at::Tensor& qkv_bias,
    const int64_t num_head)
{
    // Checked first: num_head is used as a divisor below.
    TORCH_CHECK(num_head != 0, "num_head of _transform_bias_rescale_qkv must not be 0.", OPS_ERROR(ErrCode::VALUE));

    const auto B = qkv.size(0);
    const auto T = qkv.size(1);

    // Named `three_d` rather than `_3D`: identifiers that start with an underscore
    // followed by an uppercase letter are reserved for the implementation.
    const auto three_d = qkv_bias.size(0);
    // Without this check the integer division below would silently truncate a
    // malformed bias width and yield a wrong output shape.
    TORCH_CHECK(three_d % 3 == 0,
                "qkv_bias.size(0) of _transform_bias_rescale_qkv must be divisible by 3.",
                OPS_ERROR(ErrCode::VALUE));
    const auto D = three_d / 3;

    TORCH_CHECK(D % num_head == 0, "D of _transform_bias_rescale_qkv must divide evenly by num_head", OPS_ERROR(ErrCode::VALUE));
    const auto dim_per_head = D / num_head;

    c10::SmallVector<int64_t, SIZE> output_shape = {B, num_head, T, dim_per_head};
    return output_shape;
}
/// NPU entry point for aten::_transform_bias_rescale_qkv.
///
/// When the aclnn kernel is unavailable, the inputs are copied to the CPU, the
/// ATen implementation is invoked there, and the three results are copied back
/// to the device of `qkv`. Otherwise three output tensors are allocated with
/// the dtype/options of `qkv` and filled by aclnnTransformBiasRescaleQkv.
///
/// @param qkv        Input tensor.
/// @param qkv_bias   Bias tensor.
/// @param num_heads  Number of heads, forwarded to the kernel / CPU implementation.
/// @return Tuple (q, k, v) located on the device of `qkv`.
std::tuple<at::Tensor, at::Tensor, at::Tensor> _transform_bias_rescale_qkv(
    const at::Tensor& qkv,
    const at::Tensor& qkv_bias,
    const int64_t num_heads)
{
    if (_transform_bias_rescale_qkv_fallback_condition()) {
        // CPU fallback: host round-trip around the ATen implementation.
        at::Tensor host_qkv = qkv.cpu();
        at::Tensor host_bias = qkv_bias.cpu();
        std::tuple<at::Tensor, at::Tensor, at::Tensor> host_out =
            at::_transform_bias_rescale_qkv(host_qkv, host_bias, num_heads);

        const auto target_device = qkv.device();
        at::Tensor q_dev = std::get<0>(host_out).to(target_device);
        at::Tensor k_dev = std::get<1>(host_out).to(target_device);
        at::Tensor v_dev = std::get<2>(host_out).to(target_device);
        return std::make_tuple(std::move(q_dev), std::move(k_dev), std::move(v_dev));
    }

    // Native path: all three outputs share one shape and one set of tensor options.
    auto out_shape = _transform_bias_rescale_qkv_out_size(qkv, qkv_bias, num_heads);
    const auto out_options = qkv.options().dtype(qkv.scalar_type());

    at::Tensor q_out = npu_preparation::apply_tensor_without_format(out_shape, out_options);
    at::Tensor k_out = npu_preparation::apply_tensor_without_format(out_shape, out_options);
    at::Tensor v_out = npu_preparation::apply_tensor_without_format(out_shape, out_options);

    EXEC_NPU_CMD(aclnnTransformBiasRescaleQkv, qkv, qkv_bias, num_heads, q_out, k_out, v_out);
    return std::make_tuple(std::move(q_out), std::move(k_out), std::move(v_out));
}
}