#include "torch_npu/csrc/framework/utils/RandomOpAdapter.h"
#include "torch_npu/csrc/aten/CustomFunctions.h"
#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
#include "op_plugin/utils/custom_functions/opapi/update_op_api_common.h"
namespace op_api {
#if VERSION_BETWEEN(V2R1, VERSION_NEWEST)
// The next three constants are not referenced anywhere in the visible part of this file.
const static int FLASH_THRESHOLD = 512;
const static int64_t PFA_SPARSE_HIGH_PRECISION_NO_MASK = 10;
const static int64_t PFA_SPARSE_HIGH_PRECISION_BAND = 14;
// Axis indices passed to at::Tensor::size() by the shape helpers below.
const static int64_t DIM_0 = 0;
const static int64_t DIM_1 = 1;
const static int64_t DIM_2 = 2;
const static int64_t DIM_3 = 3;
const static int64_t DIM_4 = 4;
// Expected query rank for each layout entry in get_query_and_attention_out_layout.
const static int64_t DIM_NUMS_3 = 3;
const static int64_t DIM_NUMS_4 = 4;
// Ranks that get_value_d compares value.dim() against when a block table is supplied.
const static int SHAPE_3D = 3;
const static int SHAPE_4D = 4;
const static int SHAPE_5D = 5;
using namespace at_npu::native;
// Short alias used for the tensor allocation helpers (apply_tensor_without_format).
using npu_preparation = at_npu::native::OpPreparation;
static std::pair<std::string, std::string> get_query_and_attention_out_layout(
const at::Tensor query,
std::string input_layout_str)
{
struct parserLayout {
std::string qLayout;
std::string outLayout;
int32_t qDim;
};
const std::map<std::string, parserLayout> LAYOUT_MAP = {
{"BSH", {"BSH", "BSH", DIM_NUMS_3}},
{"BSND", {"BSND", "BSND", DIM_NUMS_4}},
{"BNSD", {"BNSD", "BNSD", DIM_NUMS_4}},
{"TND", {"TND", "TND", DIM_NUMS_3}},
{"NTD", {"NTD", "NTD", DIM_NUMS_3}},
{"BNSD_BSND", {"BNSD", "BSND", DIM_NUMS_4}},
{"BSH_BNSD", {"BSH", "BNSD", DIM_NUMS_3}},
{"BSND_BNSD", {"BSND", "BNSD", DIM_NUMS_4}},
{"NTD_TND", {"NTD", "TND", DIM_NUMS_3}},
{"BSH_NBSD", {"BSH", "NBSD", DIM_NUMS_3}},
{"BSND_NBSD", {"BSND", "NBSD", DIM_NUMS_4}},
{"BNSD_NBSD", {"BNSD", "NBSD", DIM_NUMS_4}},
{"TND_NTD", {"TND", "NTD", DIM_NUMS_3}},
{"NSD", {"NSD", "NSD", DIM_NUMS_3}}
};
std::string query_layout = "BSH";
std::string attention_out_layout = "BSH";
int32_t query_dim;
auto it = LAYOUT_MAP.find(input_layout_str);
if (it != LAYOUT_MAP.end()) {
query_layout = it->second.qLayout;
attention_out_layout = it->second.outLayout;
query_dim = it->second.qDim;
TORCH_CHECK(query.dim() == query_dim,
"query's dim should be consistent with that of Layout", OPS_ERROR(ErrCode::VALUE));
} else {
TORCH_CHECK(
false,
"layout only support BSH, BSND, TND, NTD, BNSD_BSND, BSH_BNSD, BSND_BNSD, NTD_TND, ",
"BSH_NBSD, BSND_NBSD, BNSD_NBSD, TND_NTD, but got ",
query_layout,
OPS_ERROR(ErrCode::VALUE)
);
}
return {query_layout, attention_out_layout};
}
/**
 * Extract (batch, heads, sequence length, head dim) from a query tensor.
 *
 * Supported query layouts: "BSH", "BSND", "BNSD" and "NSD".
 *  - "BSH": heads is taken from `num_heads` and head dim is size(2) / num_heads.
 *  - "NSD": batch is fixed to 1.
 * Any other layout fails via TORCH_CHECK.
 *
 * @return {b, n, s, d}
 */
static std::tuple<int64_t, int64_t, int64_t, int64_t> get_query_b_n_s_d(
    const at::Tensor &query,
    std::string query_layout,
    int64_t num_heads)
{
    if (query_layout == "BSH") {
        const int64_t batch = query.size(DIM_0);
        const int64_t seq_len = query.size(DIM_1);
        const int64_t head_dim = query.size(DIM_2) / num_heads;
        return {batch, num_heads, seq_len, head_dim};
    }
    if (query_layout == "BSND") {
        const int64_t batch = query.size(DIM_0);
        const int64_t seq_len = query.size(DIM_1);
        const int64_t heads = query.size(DIM_2);
        const int64_t head_dim = query.size(DIM_3);
        return {batch, heads, seq_len, head_dim};
    }
    if (query_layout == "BNSD") {
        const int64_t batch = query.size(DIM_0);
        const int64_t seq_len = query.size(DIM_2);
        const int64_t heads = query.size(DIM_1);
        const int64_t head_dim = query.size(DIM_3);
        return {batch, heads, seq_len, head_dim};
    }
    if (query_layout == "NSD") {
        const int64_t seq_len = query.size(DIM_1);
        const int64_t heads = query.size(DIM_0);
        const int64_t head_dim = query.size(DIM_2);
        return {int64_t{1}, heads, seq_len, head_dim};
    }

    TORCH_CHECK(
        false,
        "It is not supported in get_query_b_n_s_d function, layout ",
        query_layout,
        OPS_ERROR(ErrCode::VALUE)
    );
    // Only reached if the check above does not throw.
    return {int64_t{0}, int64_t{0}, int64_t{0}, int64_t{0}};
}
/**
 * Extract (total tokens, heads, head dim) from a query tensor laid out as
 * "TND" or "NTD". Any other layout fails via TORCH_CHECK.
 *
 * @return {t, n, d}
 */
static std::tuple<int64_t, int64_t, int64_t> get_query_t_n_d(
    const at::Tensor &query,
    std::string query_layout)
{
    if (query_layout == "TND") {
        const int64_t tokens = query.size(DIM_0);
        const int64_t heads = query.size(DIM_1);
        const int64_t head_dim = query.size(DIM_2);
        return {tokens, heads, head_dim};
    }
    if (query_layout == "NTD") {
        const int64_t tokens = query.size(DIM_1);
        const int64_t heads = query.size(DIM_0);
        const int64_t head_dim = query.size(DIM_2);
        return {tokens, heads, head_dim};
    }

    TORCH_CHECK(
        false,
        "It is not supported in get_query_t_n_d function, layout ",
        query_layout,
        OPS_ERROR(ErrCode::VALUE)
    );
    // Only reached if the check above does not throw.
    return {int64_t{0}, int64_t{0}, int64_t{0}};
}
/**
 * Work out the head dimension of `value`.
 *
 * With a block table the result depends only on the rank of `value`
 * (3, 4 or 5 dimensions; anything else fails via TORCH_CHECK).
 * Without a block table `value` must have the same rank as `query`, and the
 * result is read according to `query_layout`; a layout that matches none of
 * the known names yields 0.
 *
 * @return head dimension of value, or 0 when it could not be determined.
 */
static int64_t get_value_d(
    const c10::optional<at::Tensor> &block_table,
    const at::Tensor &query,
    const at::Tensor &value,
    std::string query_layout,
    int64_t kv_num_heads)
{
    if (!block_table.has_value()) {
        TORCH_CHECK(
            value.dim() == query.dim(),
            "when Page Attention not enabled, value'dim should equal to query's dim!",
            OPS_ERROR(ErrCode::VALUE)
        );
        if (query_layout == "BSH") {
            // Hidden size is heads * head dim, so divide the heads back out.
            return value.size(DIM_2) / kv_num_heads;
        }
        if (query_layout == "BSND" || query_layout == "BNSD") {
            return value.size(DIM_3);
        }
        if (query_layout == "TND" || query_layout == "NTD" || query_layout == "NSD") {
            return value.size(DIM_2);
        }
        return 0;
    }

    // Block table present: dispatch on the rank of value.
    if (value.dim() == SHAPE_3D) {
        return value.size(DIM_2) / kv_num_heads;
    }
    if (value.dim() == SHAPE_4D) {
        return value.size(DIM_3);
    }
    if (value.dim() == SHAPE_5D) {
        return value.size(DIM_2) * value.size(DIM_4);
    }
    TORCH_CHECK(
        false,
        "when Page Attention enabled, value's dim should be 3/4/5, but got ",
        value.dim(),
        OPS_ERROR(ErrCode::VALUE)
    );
    return 0;
}
/**
 * Multiplier applied to the value head dimension: 8 when `value` has
 * scalar type at::kInt, otherwise 1.
 */
static int get_change_d_scale(
    const at::Tensor &value)
{
    const static int defaultScale = 1;
    const static int int32Scale = 8;
    return (value.scalar_type() == at::kInt) ? int32Scale : defaultScale;
}
/**
 * Allocate the attention output tensor for the requested output layout.
 *
 * The shape is derived from the query (read according to `query_layout`) with
 * the last dimension replaced by the value head dimension. When either the
 * value-derived size or the query-derived size is 0, the query-derived size
 * is used instead. The tensor uses the query's options and dtype.
 *
 * @param attention_out_layout One of "BSH", "BSND", "BNSD", "NBSD", "TND", "NTD", "NSD".
 * @param query                Query tensor.
 * @param query_layout         Layout of `query`.
 * @param num_heads            Number of query heads.
 * @param valueD               Head dimension of value (already scaled by the caller).
 * @return Newly allocated output tensor. For an unrecognised output layout a
 *         tensor shaped like `query` is returned.
 */
static at::Tensor infer_attention_out_shape(
    std::string attention_out_layout,
    const at::Tensor &query,
    std::string query_layout,
    int64_t num_heads,
    int64_t valueD)
{
    const auto out_options = query.options().dtype(query.dtype());
    // Sizes stay in int64_t throughout; narrowing to int would truncate large extents.
    const auto resolve_out_d = [valueD](int64_t query_d) -> int64_t {
        return (valueD == 0 || query_d == 0) ? query_d : valueD;
    };

    if (attention_out_layout == "BSH") {
        auto [b, n1, s1, d1] = get_query_b_n_s_d(query, query_layout, num_heads);
        const int64_t query_h = query.size(DIM_2);
        const int64_t value_h = num_heads * valueD;
        const int64_t outH = (value_h == 0 || query_h == 0) ? query_h : value_h;
        return OpPreparation::apply_tensor_without_format({b, s1, outH}, out_options);
    }
    if (attention_out_layout == "BSND") {
        auto [b, n1, s1, d1] = get_query_b_n_s_d(query, query_layout, num_heads);
        const int64_t outD = resolve_out_d(d1);
        return OpPreparation::apply_tensor_without_format({b, s1, n1, outD}, out_options);
    }
    if (attention_out_layout == "BNSD") {
        auto [b, n1, s1, d1] = get_query_b_n_s_d(query, query_layout, num_heads);
        const int64_t outD = resolve_out_d(d1);
        return OpPreparation::apply_tensor_without_format({b, n1, s1, outD}, out_options);
    }
    if (attention_out_layout == "NBSD") {
        auto [b, n1, s1, d1] = get_query_b_n_s_d(query, query_layout, num_heads);
        const int64_t outD = resolve_out_d(d1);
        return OpPreparation::apply_tensor_without_format({n1, b, s1, outD}, out_options);
    }
    if (attention_out_layout == "TND") {
        auto [t, n1, d1] = get_query_t_n_d(query, query_layout);
        const int64_t outD = resolve_out_d(d1);
        return OpPreparation::apply_tensor_without_format({t, n1, outD}, out_options);
    }
    if (attention_out_layout == "NTD") {
        auto [t, n1, d1] = get_query_t_n_d(query, query_layout);
        const int64_t outD = resolve_out_d(d1);
        return OpPreparation::apply_tensor_without_format({n1, t, outD}, out_options);
    }
    if (attention_out_layout == "NSD") {
        auto [b, n1, s1, d1] = get_query_b_n_s_d(query, query_layout, num_heads);
        const int64_t outD = resolve_out_d(d1);
        return OpPreparation::apply_tensor_without_format({n1, s1, outD}, out_options);
    }

    // Fallback for an unrecognised output layout: mirror the query. Allocated only
    // here so the recognised layouts do not pay for a tensor that is thrown away.
    return npu_preparation::apply_tensor_without_format(query);
}
/**
 * Allocate the softmax log-sum-exp output tensor (Float).
 *
 * For input layouts "TND", "NTD", "TND_NTD" and "NTD_TND" the shape is
 * {t, n, 1}; for every other input layout it is {b, n, s, 1}. The sizes are
 * read from `query` according to `query_layout`.
 */
static at::Tensor infer_lse_out_shape(
    std::string input_layout_str,
    const at::Tensor &query,
    std::string query_layout,
    int64_t num_heads)
{
    const bool token_major_layout =
        input_layout_str == "TND" || input_layout_str == "NTD" ||
        input_layout_str == "TND_NTD" || input_layout_str == "NTD_TND";

    if (token_major_layout) {
        auto [tokens, heads, head_dim] = get_query_t_n_d(query, query_layout);
        return npu_preparation::apply_tensor_without_format(
            {tokens, heads, 1}, c10::dtype(c10::ScalarType::Float));
    }

    auto [batch, heads, seq_len, head_dim] = get_query_b_n_s_d(query, query_layout, num_heads);
    return npu_preparation::apply_tensor_without_format(
        {batch, heads, seq_len, 1}, c10::dtype(c10::ScalarType::Float));
}
/**
 * Allocate the two outputs of the attention operator: the attention result
 * and the softmax log-sum-exp tensor.
 *
 * Steps:
 *  1. require num_heads > 0 and default kv_num_heads to num_heads when it is 0;
 *  2. resolve the query / output layouts from `input_layout_str`;
 *  3. derive the value head dimension and scale it by get_change_d_scale(value);
 *  4. allocate the attention output. When the query dtype is at::kChar the
 *     output dtype is that of `query_rope` if supplied, otherwise Half; for any
 *     other query dtype the output matches the shape-inferred tensor;
 *  5. allocate softmax_lse with the inferred shape when `softmax_lse_flag` is
 *     set, otherwise a Float tensor of shape {0}.
 *
 * @return {attention output, softmax_lse}
 */
std::tuple<at::Tensor, at::Tensor> construct_fia_output_tensor_experimental(
    const at::Tensor &query,
    const at::Tensor &value,
    std::string input_layout_str,
    const c10::optional<at::Tensor> &block_table,
    int64_t num_heads,
    int64_t kv_num_heads,
    bool softmax_lse_flag,
    const c10::optional<at::Tensor> &query_rope)
{
    TORCH_CHECK(
        num_heads > 0,
        "num_heads should be greater than 0, but the actual value is",
        num_heads,
        OPS_ERROR(ErrCode::VALUE)
    );
    if (kv_num_heads == 0) {
        kv_num_heads = num_heads;
    }

    const auto layouts = get_query_and_attention_out_layout(query, input_layout_str);
    const std::string &query_layout = layouts.first;
    const std::string &attention_out_layout = layouts.second;

    int64_t valueD = get_value_d(block_table, query, value, query_layout, kv_num_heads);
    valueD *= get_change_d_scale(value);

    // Tensor whose sizes/options describe the attention output for this layout.
    const at::Tensor shape_probe =
        infer_attention_out_shape(attention_out_layout, query, query_layout, num_heads, valueD);

    const bool is_char_query = (query.dtype() == at::kChar);
    at::Tensor output;
    if (!is_char_query) {
        output = npu_preparation::apply_tensor_without_format(shape_probe);
    } else if (query_rope.has_value()) {
        const at::Tensor &query_rope_tensor = query_rope.value();
        output = npu_preparation::apply_tensor_without_format(
            shape_probe.sizes(), c10::dtype(query_rope_tensor.dtype()));
    } else {
        output = npu_preparation::apply_tensor_without_format(
            shape_probe.sizes(), c10::dtype(c10::ScalarType::Half));
    }

    at::Tensor softmax_lse;
    if (!softmax_lse_flag) {
        softmax_lse = npu_preparation::apply_tensor_without_format({0}, c10::dtype(c10::ScalarType::Float));
    } else {
        softmax_lse = infer_lse_out_shape(input_layout_str, query, query_layout, num_heads);
    }
    return std::make_tuple(output, softmax_lse);
}
// Runs the aclnnAttentionPioneer operator and returns freshly allocated outputs.
//
// The output and softmax_lse tensors are allocated by
// construct_fia_output_tensor_experimental from query/value/layout/head counts.
// Operator inputs that this entry point does not expose are passed as
// default-constructed tensors / empty SymInt array refs and zero-valued modes
// (pse_type is 1).
//
// Fails via TORCH_CHECK when query_rope or key_rope is supplied while
// op_plugin::utils::is_gte_cann_version_810rc1() reports false.
//
// Returns {attention output, softmax_lse}.
std::tuple<at::Tensor, at::Tensor> _npu_attention_pioneer(
const at::Tensor &query, const at::Tensor &key, const at::Tensor &value,
const c10::optional<at::Tensor> &atten_mask,
c10::OptionalArrayRef<c10::SymInt> &actual_seq_lengths,
c10::OptionalArrayRef<c10::SymInt> &actual_seq_lengths_kv,
const c10::optional<at::Tensor> &block_table,
const c10::optional<at::Tensor> &query_rope,
const c10::optional<at::Tensor> &key_rope,
const c10::optional<at::Tensor> &key_sink,
const c10::optional<at::Tensor> &key_rope_sink,
const c10::optional<at::Tensor> &value_sink,
int64_t num_heads, double scale,
int64_t pre_tokens, int64_t next_tokens,
c10::string_view input_layout, int64_t num_key_value_heads,
int64_t sparse_mode, int64_t block_size,
bool softmax_lse_flag)
{
// Own a std::string copy of the layout so a C string pointer can be handed to the operator call.
std::string input_layout_str = std::string(input_layout);
std::tuple<at::Tensor, at::Tensor> fia_output = op_api::construct_fia_output_tensor_experimental(query, value, input_layout_str,
block_table, num_heads, num_key_value_heads,
softmax_lse_flag, query_rope);
at::Tensor output = std::get<0>(fia_output);
at::Tensor softmax_lse = std::get<1>(fia_output);
char *input_layout_ptr = const_cast<char *>(input_layout_str.c_str());
// key and value are passed to the operator as single-element tensor lists.
at::TensorList valueTensors = value;
at::TensorList keyTensors = key;
// Placeholders for operator inputs that are not exposed by this function.
at::Tensor pse_shift;
at::Tensor deq_scale1;
at::Tensor quant_scale1;
at::Tensor deq_scale2;
at::Tensor quant_scale2;
at::Tensor quant_offset2;
at::Tensor antiquant_scale;
at::Tensor antiquant_offset;
at::Tensor q_padding_size;
at::Tensor kv_padding_size;
at::Tensor key_antiquant_scale;
at::Tensor key_antiquant_offset;
at::Tensor value_antiquant_scale;
at::Tensor value_antiquant_offset;
at::Tensor key_shared_prefix;
at::Tensor value_shared_prefix;
at::SymIntArrayRef default_actual_shared_prefix;
at::Tensor key_rope_antiquant_scale;
at::Tensor dequant_scale_query;
at::Tensor learnablesink;
at::SymIntArrayRef default_q_start_idx;
at::SymIntArrayRef default_kv_start_idx;
int64_t inner_precise = 0;
int64_t antiquant_mode = 0;
int64_t key_antiquant_mode = 0;
int64_t value_antiquant_mode = 0;
int64_t query_quant_mode = 0;
int64_t pse_type = 1;
// Rope inputs are rejected when the CANN version check reports a version below 8.1.RC1.
// NOTE(review): the message mentions key_rope_antiquant_scale, but only query_rope and key_rope are tested here.
if (!op_plugin::utils::is_gte_cann_version_810rc1() &&
(query_rope.has_value() || key_rope.has_value())) {
TORCH_CHECK(false, "query_rope|key_rope|key_rope_antiquant_scale has value, CANN version should greater than 8.1.RC1!"
+ OPS_ERROR(ErrCode::NOT_SUPPORT));
}
// Argument order is positional; presumably it must match the aclnnAttentionPioneer signature (macro definition not visible here).
EXEC_NPU_NO_FORMAT_CHECK_CMD(aclnnAttentionPioneer, query, keyTensors, valueTensors, pse_shift, atten_mask, actual_seq_lengths, actual_seq_lengths_kv,
deq_scale1, quant_scale1, deq_scale2, quant_scale2, quant_offset2, antiquant_scale, antiquant_offset, block_table, q_padding_size, kv_padding_size,
key_antiquant_scale, key_antiquant_offset, value_antiquant_scale, value_antiquant_offset, key_shared_prefix, value_shared_prefix, default_actual_shared_prefix,
query_rope, key_rope, key_rope_antiquant_scale, dequant_scale_query, learnablesink, default_q_start_idx, default_kv_start_idx, key_sink, key_rope_sink, value_sink,
num_heads, scale, pre_tokens, next_tokens, input_layout_ptr, num_key_value_heads, sparse_mode, inner_precise, block_size, antiquant_mode, softmax_lse_flag, key_antiquant_mode,
value_antiquant_mode, query_quant_mode, pse_type, output, softmax_lse);
return std::tuple<at::Tensor, at::Tensor>(output, softmax_lse);
}
// Out-variant of the aclnnAttentionPioneer call: results are written into the
// caller supplied attention_out and softmax_lse tensors.
//
// When `workspace` is supplied, its storage pointer and byte size
// (numel * element_size) are handed to EXEC_UPDATE_NPU_NO_FORMAT_CHECK_CMD;
// otherwise the call goes through EXEC_NPU_NO_FORMAT_CHECK_CMD.
// Operator inputs that this entry point does not expose are passed as
// default-constructed tensors / empty SymInt array refs and zero-valued modes
// (pse_type is 1).
//
// Fails via TORCH_CHECK when query_rope or key_rope is supplied while
// op_plugin::utils::is_gte_cann_version_810rc1() reports false.
//
// Returns references to attention_out and softmax_lse.
std::tuple<at::Tensor &, at::Tensor &> _npu_attention_pioneer_out(
const at::Tensor &query, const at::Tensor &key, const at::Tensor &value,
const c10::optional<at::Tensor> &atten_mask,
c10::OptionalArrayRef<c10::SymInt> &actual_seq_lengths,
c10::OptionalArrayRef<c10::SymInt> &actual_seq_lengths_kv,
const c10::optional<at::Tensor> &block_table,
const c10::optional<at::Tensor> &query_rope,
const c10::optional<at::Tensor> &key_rope,
const c10::optional<at::Tensor> &key_sink,
const c10::optional<at::Tensor> &key_rope_sink,
const c10::optional<at::Tensor> &value_sink,
int64_t num_heads, double scale,
int64_t pre_tokens, int64_t next_tokens,
c10::string_view input_layout, int64_t num_key_value_heads,
int64_t sparse_mode, int64_t block_size,
bool softmax_lse_flag,
const c10::optional<at::Tensor> &workspace,
at::Tensor &attention_out,
at::Tensor &softmax_lse)
{
// Own a std::string copy of the layout so a C string pointer can be handed to the operator call.
std::string input_layout_str = std::string(input_layout);
char *input_layout_ptr = const_cast<char *>(input_layout_str.c_str());
// key and value are passed to the operator as single-element tensor lists.
at::TensorList valueTensors = value;
at::TensorList keyTensors = key;
// Placeholders for operator inputs that are not exposed by this function.
at::Tensor pse_shift;
at::Tensor deq_scale1;
at::Tensor quant_scale1;
at::Tensor deq_scale2;
at::Tensor quant_scale2;
at::Tensor quant_offset2;
at::Tensor antiquant_scale;
at::Tensor antiquant_offset;
at::Tensor q_padding_size;
at::Tensor kv_padding_size;
at::Tensor key_antiquant_scale;
at::Tensor key_antiquant_offset;
at::Tensor value_antiquant_scale;
at::Tensor value_antiquant_offset;
at::Tensor key_shared_prefix;
at::Tensor value_shared_prefix;
at::SymIntArrayRef default_actual_shared_prefix;
at::Tensor key_rope_antiquant_scale;
at::Tensor dequant_scale_query;
at::Tensor learnablesink;
at::SymIntArrayRef default_q_start_idx;
at::SymIntArrayRef default_kv_start_idx;
int64_t inner_precise = 0;
int64_t antiquant_mode = 0;
int64_t key_antiquant_mode = 0;
int64_t value_antiquant_mode = 0;
int64_t query_quant_mode = 0;
int64_t pse_type = 1;
// Rope inputs are rejected when the CANN version check reports a version below 8.1.RC1.
// NOTE(review): the message mentions key_rope_antiquant_scale, but only query_rope and key_rope are tested here.
if (!op_plugin::utils::is_gte_cann_version_810rc1() &&
(query_rope.has_value() || key_rope.has_value())) {
TORCH_CHECK(false, "query_rope|key_rope|key_rope_antiquant_scale has value, CANN version should greater than 8.1.RC1!"
+ OPS_ERROR(ErrCode::NOT_SUPPORT));
}
if (workspace.has_value()) {
// NOTE(review): storage().data() is the base of the storage; confirm the workspace tensor never carries a non-zero storage offset.
void* workspace_addr = const_cast<void *>(workspace.value().storage().data());
// Workspace size in bytes.
uint64_t workspace_size = static_cast<uint64_t>(workspace.value().numel() * workspace.value().element_size());
EXEC_UPDATE_NPU_NO_FORMAT_CHECK_CMD(aclnnAttentionPioneer, workspace_addr, workspace_size, query, keyTensors, valueTensors, pse_shift, atten_mask, actual_seq_lengths, actual_seq_lengths_kv,
deq_scale1, quant_scale1, deq_scale2, quant_scale2, quant_offset2, antiquant_scale, antiquant_offset, block_table, q_padding_size, kv_padding_size,
key_antiquant_scale, key_antiquant_offset, value_antiquant_scale, value_antiquant_offset, key_shared_prefix, value_shared_prefix, default_actual_shared_prefix,
query_rope, key_rope, key_rope_antiquant_scale, dequant_scale_query, learnablesink, default_q_start_idx, default_kv_start_idx, key_sink, key_rope_sink, value_sink,
num_heads, scale, pre_tokens, next_tokens, input_layout_ptr, num_key_value_heads, sparse_mode, inner_precise, block_size, antiquant_mode, softmax_lse_flag, key_antiquant_mode,
value_antiquant_mode, query_quant_mode, pse_type, attention_out, softmax_lse);
} else {
// No caller supplied workspace: use the regular execution macro.
EXEC_NPU_NO_FORMAT_CHECK_CMD(aclnnAttentionPioneer, query, keyTensors, valueTensors, pse_shift, atten_mask, actual_seq_lengths, actual_seq_lengths_kv,
deq_scale1, quant_scale1, deq_scale2, quant_scale2, quant_offset2, antiquant_scale, antiquant_offset, block_table, q_padding_size, kv_padding_size,
key_antiquant_scale, key_antiquant_offset, value_antiquant_scale, value_antiquant_offset, key_shared_prefix, value_shared_prefix, default_actual_shared_prefix,
query_rope, key_rope, key_rope_antiquant_scale, dequant_scale_query, learnablesink, default_q_start_idx, default_kv_start_idx, key_sink, key_rope_sink, value_sink,
num_heads, scale, pre_tokens, next_tokens, input_layout_ptr, num_key_value_heads, sparse_mode, inner_precise, block_size, antiquant_mode, softmax_lse_flag, key_antiquant_mode,
value_antiquant_mode, query_quant_mode, pse_type, attention_out, softmax_lse);
}
return std::tuple<at::Tensor&, at::Tensor&>(attention_out, softmax_lse);
}
// Queries the workspace size for an aclnnAttentionPioneer call with the given
// arguments (via EXEC_GET_MAX_WORKSPACE_CMD) and returns a newly allocated
// 1-D tensor with that many elements, using the query's options and dtype.
//
// Output tensors are allocated with construct_fia_output_tensor_experimental
// only so they can be passed to the size query; they are not returned.
//
// Fails via TORCH_CHECK when query_rope or key_rope is supplied while
// op_plugin::utils::is_gte_cann_version_810rc1() reports false.
at::Tensor _npu_attention_pioneer_get_max_workspace(
const at::Tensor &query, const at::Tensor &key, const at::Tensor &value,
const c10::optional<at::Tensor> &atten_mask,
c10::OptionalArrayRef<c10::SymInt> &actual_seq_lengths,
c10::OptionalArrayRef<c10::SymInt> &actual_seq_lengths_kv,
const c10::optional<at::Tensor> &block_table,
const c10::optional<at::Tensor> &query_rope,
const c10::optional<at::Tensor> &key_rope,
const c10::optional<at::Tensor> &key_sink,
const c10::optional<at::Tensor> &key_rope_sink,
const c10::optional<at::Tensor> &value_sink,
int64_t num_heads, double scale,
int64_t pre_tokens, int64_t next_tokens,
c10::string_view input_layout, int64_t num_key_value_heads,
int64_t sparse_mode, int64_t block_size,
bool softmax_lse_flag)
{
// Own a std::string copy of the layout so a C string pointer can be handed to the operator call.
std::string input_layout_str = std::string(input_layout);
std::tuple<at::Tensor, at::Tensor> fia_output = op_api::construct_fia_output_tensor_experimental(query, value, input_layout_str,
block_table, num_heads, num_key_value_heads,
softmax_lse_flag, query_rope);
at::Tensor output = std::get<0>(fia_output);
at::Tensor softmax_lse = std::get<1>(fia_output);
char *input_layout_ptr = const_cast<char *>(input_layout_str.c_str());
// key and value are passed to the operator as single-element tensor lists.
at::TensorList valueTensors = value;
at::TensorList keyTensors = key;
// Placeholders for operator inputs that are not exposed by this function.
at::Tensor pse_shift;
at::Tensor deq_scale1;
at::Tensor quant_scale1;
at::Tensor deq_scale2;
at::Tensor quant_scale2;
at::Tensor quant_offset2;
at::Tensor antiquant_scale;
at::Tensor antiquant_offset;
at::Tensor q_padding_size;
at::Tensor kv_padding_size;
at::Tensor key_antiquant_scale;
at::Tensor key_antiquant_offset;
at::Tensor value_antiquant_scale;
at::Tensor value_antiquant_offset;
at::Tensor key_shared_prefix;
at::Tensor value_shared_prefix;
at::SymIntArrayRef default_actual_shared_prefix;
at::Tensor key_rope_antiquant_scale;
at::Tensor dequant_scale_query;
at::Tensor learnablesink;
at::SymIntArrayRef default_q_start_idx;
at::SymIntArrayRef default_kv_start_idx;
int64_t inner_precise = 0;
int64_t antiquant_mode = 0;
int64_t key_antiquant_mode = 0;
int64_t value_antiquant_mode = 0;
int64_t query_quant_mode = 0;
int64_t pse_type = 1;
// Rope inputs are rejected when the CANN version check reports a version below 8.1.RC1.
// NOTE(review): the message mentions key_rope_antiquant_scale, but only query_rope and key_rope are tested here.
if (!op_plugin::utils::is_gte_cann_version_810rc1() &&
(query_rope.has_value() || key_rope.has_value())) {
TORCH_CHECK(false, "query_rope|key_rope|key_rope_antiquant_scale has value, CANN version should greater than 8.1.RC1!"
+ OPS_ERROR(ErrCode::NOT_SUPPORT));
}
// Argument order is positional; presumably it must match the aclnnAttentionPioneer signature (macro definition not visible here).
uint64_t workspace_size = EXEC_GET_MAX_WORKSPACE_CMD(aclnnAttentionPioneer, query, keyTensors, valueTensors, pse_shift, atten_mask, actual_seq_lengths, actual_seq_lengths_kv,
deq_scale1, quant_scale1, deq_scale2, quant_scale2, quant_offset2, antiquant_scale, antiquant_offset, block_table, q_padding_size, kv_padding_size,
key_antiquant_scale, key_antiquant_offset, value_antiquant_scale, value_antiquant_offset, key_shared_prefix, value_shared_prefix, default_actual_shared_prefix,
query_rope, key_rope, key_rope_antiquant_scale, dequant_scale_query, learnablesink, default_q_start_idx, default_kv_start_idx, key_sink, key_rope_sink, value_sink,
num_heads, scale, pre_tokens, next_tokens, input_layout_ptr, num_key_value_heads, sparse_mode, inner_precise, block_size, antiquant_mode, softmax_lse_flag, key_antiquant_mode,
value_antiquant_mode, query_quant_mode, pse_type, output, softmax_lse);
// NOTE(review): workspace_size (uint64_t) is list-initialised into a signed size list, and the element
// count is multiplied by the query's element size when the tensor is consumed as a byte buffer — confirm
// both the narrowing and the resulting over-allocation are intended.
at::Tensor workspace_tensor = npu_preparation::apply_tensor_without_format({workspace_size}, query.options().dtype(query.dtype()));
return workspace_tensor;
}
/**
 * Allocate (without running the operator) the attention output and
 * softmax_lse tensors that a call with these arguments would produce.
 * Thin wrapper over construct_fia_output_tensor_experimental that converts
 * the layout string view into a std::string.
 *
 * @return {attention output, softmax_lse}
 */
std::tuple<at::Tensor, at::Tensor> _npu_attention_pioneer_infer_output(
    const at::Tensor &query,
    const at::Tensor &value,
    c10::string_view input_layout,
    const c10::optional<at::Tensor> &block_table,
    int64_t num_heads,
    int64_t num_key_value_heads,
    bool softmax_lse_flag,
    const c10::optional<at::Tensor> &query_rope)
{
    return op_api::construct_fia_output_tensor_experimental(
        query,
        value,
        std::string(input_layout),
        block_table,
        num_heads,
        num_key_value_heads,
        softmax_lse_flag,
        query_rope);
}
/**
 * SymInt entry point: forwards every argument unchanged to
 * op_api::_npu_attention_pioneer and returns its result.
 * The sequence-length array refs are received by value here and passed on as lvalues.
 */
std::tuple<at::Tensor, at::Tensor> _npu_attention_pioneer_symint(
    const at::Tensor &query, const at::Tensor &key, const at::Tensor &value,
    const c10::optional<at::Tensor> &atten_mask,
    c10::OptionalArrayRef<c10::SymInt> actual_seq_lengths,
    c10::OptionalArrayRef<c10::SymInt> actual_seq_lengths_kv,
    const c10::optional<at::Tensor> &block_table,
    const c10::optional<at::Tensor> &query_rope,
    const c10::optional<at::Tensor> &key_rope,
    const c10::optional<at::Tensor> &key_sink,
    const c10::optional<at::Tensor> &key_rope_sink,
    const c10::optional<at::Tensor> &value_sink,
    int64_t num_heads, double scale,
    int64_t pre_tokens, int64_t next_tokens,
    c10::string_view input_layout, int64_t num_key_value_heads,
    int64_t sparse_mode, int64_t block_size,
    bool softmax_lse_flag)
{
    return op_api::_npu_attention_pioneer(
        query, key, value,
        atten_mask,
        actual_seq_lengths, actual_seq_lengths_kv,
        block_table,
        query_rope, key_rope,
        key_sink, key_rope_sink, value_sink,
        num_heads, scale,
        pre_tokens, next_tokens,
        input_layout, num_key_value_heads,
        sparse_mode, block_size,
        softmax_lse_flag);
}
/**
 * SymInt entry point for the out-variant: forwards every argument unchanged
 * to op_api::_npu_attention_pioneer_out and returns the references it yields.
 * The sequence-length array refs are received by value here and passed on as lvalues.
 */
std::tuple<at::Tensor &, at::Tensor &> _npu_attention_pioneer_out_symint(
    const at::Tensor &query, const at::Tensor &key, const at::Tensor &value,
    const c10::optional<at::Tensor> &atten_mask,
    c10::OptionalArrayRef<c10::SymInt> actual_seq_lengths,
    c10::OptionalArrayRef<c10::SymInt> actual_seq_lengths_kv,
    const c10::optional<at::Tensor> &block_table,
    const c10::optional<at::Tensor> &query_rope,
    const c10::optional<at::Tensor> &key_rope,
    const c10::optional<at::Tensor> &key_sink,
    const c10::optional<at::Tensor> &key_rope_sink,
    const c10::optional<at::Tensor> &value_sink,
    int64_t num_heads, double scale,
    int64_t pre_tokens, int64_t next_tokens,
    c10::string_view input_layout, int64_t num_key_value_heads,
    int64_t sparse_mode, int64_t block_size,
    bool softmax_lse_flag,
    const c10::optional<at::Tensor> &workspace,
    at::Tensor &attention_out,
    at::Tensor &softmax_lse)
{
    return op_api::_npu_attention_pioneer_out(
        query, key, value,
        atten_mask,
        actual_seq_lengths, actual_seq_lengths_kv,
        block_table,
        query_rope, key_rope,
        key_sink, key_rope_sink, value_sink,
        num_heads, scale,
        pre_tokens, next_tokens,
        input_layout, num_key_value_heads,
        sparse_mode, block_size,
        softmax_lse_flag,
        workspace,
        attention_out, softmax_lse);
}
/**
 * SymInt entry point for the workspace query: forwards every argument
 * unchanged to op_api::_npu_attention_pioneer_get_max_workspace and returns
 * the tensor it produces.
 * The sequence-length array refs are received by value here and passed on as lvalues.
 */
at::Tensor _npu_attention_pioneer_get_max_workspace_symint(
    const at::Tensor &query, const at::Tensor &key, const at::Tensor &value,
    const c10::optional<at::Tensor> &atten_mask,
    c10::OptionalArrayRef<c10::SymInt> actual_seq_lengths,
    c10::OptionalArrayRef<c10::SymInt> actual_seq_lengths_kv,
    const c10::optional<at::Tensor> &block_table,
    const c10::optional<at::Tensor> &query_rope,
    const c10::optional<at::Tensor> &key_rope,
    const c10::optional<at::Tensor> &key_sink,
    const c10::optional<at::Tensor> &key_rope_sink,
    const c10::optional<at::Tensor> &value_sink,
    int64_t num_heads, double scale,
    int64_t pre_tokens, int64_t next_tokens,
    c10::string_view input_layout, int64_t num_key_value_heads,
    int64_t sparse_mode, int64_t block_size,
    bool softmax_lse_flag)
{
    return op_api::_npu_attention_pioneer_get_max_workspace(
        query, key, value,
        atten_mask,
        actual_seq_lengths, actual_seq_lengths_kv,
        block_table,
        query_rope, key_rope,
        key_sink, key_rope_sink, value_sink,
        num_heads, scale,
        pre_tokens, next_tokens,
        input_layout, num_key_value_heads,
        sparse_mode, block_size,
        softmax_lse_flag);
}
#endif
}