#include "torch_npu/csrc/framework/utils/RandomOpAdapter.h"
#include "torch_npu/csrc/aten/CustomFunctions.h"
#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
#include "op_plugin/utils/custom_functions/opapi/update_op_api_common.h"
namespace op_api {
#if VERSION_BETWEEN(V2R1, VERSION_NEWEST)
// ---- File-local constants ----------------------------------------------------
// Not referenced in this section of the file (kept for other users of this TU).
static constexpr int FLASH_THRESHOLD = 512;
static constexpr int64_t PFA_SPARSE_HIGH_PRECISION_NO_MASK = 10;
static constexpr int64_t PFA_SPARSE_HIGH_PRECISION_BAND = 14;

// Axis indices passed to Tensor::size().
static constexpr int64_t DIM_0 = 0;
static constexpr int64_t DIM_1 = 1;
static constexpr int64_t DIM_2 = 2;
static constexpr int64_t DIM_3 = 3;
static constexpr int64_t DIM_4 = 4;

// Expected query ranks for 3-D (BSH/TND/NTD/NSD) and 4-D (BSND/BNSD) layouts.
static constexpr int64_t DIM_NUMS_3 = 3;
static constexpr int64_t DIM_NUMS_4 = 4;

// Accepted value ranks when Page Attention (block_table) is enabled.
static constexpr int64_t PA_BBH_DIMS = 3;
static constexpr int64_t PA_BNBD_DIMS = 4;
static constexpr int64_t PA_NZ_DIMS = 5;

using namespace at_npu::native;
using npu_preparation = at_npu::native::OpPreparation;
static std::pair<std::string, std::string> get_query_and_attention_out_layout(
const at::Tensor query,
std::string input_layout_str)
{
struct parserLayout {
std::string qLayout;
std::string outLayout;
int32_t qDim;
};
const std::map<std::string, parserLayout> LAYOUT_MAP = {
{"BSH", {"BSH", "BSH", DIM_NUMS_3}},
{"BSND", {"BSND", "BSND", DIM_NUMS_4}},
{"BNSD", {"BNSD", "BNSD", DIM_NUMS_4}},
{"TND", {"TND", "TND", DIM_NUMS_3}},
{"NTD", {"NTD", "NTD", DIM_NUMS_3}},
{"BNSD_BSND", {"BNSD", "BSND", DIM_NUMS_4}},
{"BSH_BNSD", {"BSH", "BNSD", DIM_NUMS_3}},
{"BSND_BNSD", {"BSND", "BNSD", DIM_NUMS_4}},
{"NTD_TND", {"NTD", "TND", DIM_NUMS_3}},
{"BSH_NBSD", {"BSH", "NBSD", DIM_NUMS_3}},
{"BSND_NBSD", {"BSND", "NBSD", DIM_NUMS_4}},
{"BNSD_NBSD", {"BNSD", "NBSD", DIM_NUMS_4}},
{"TND_NTD", {"TND", "NTD", DIM_NUMS_3}},
{"NSD", {"NSD", "NSD", DIM_NUMS_3}}
};
std::string query_layout = "BSH";
std::string attention_out_layout = "BSH";
int32_t query_dim;
auto it = LAYOUT_MAP.find(input_layout_str);
if (it != LAYOUT_MAP.end()) {
query_layout = it->second.qLayout;
attention_out_layout = it->second.outLayout;
query_dim = it->second.qDim;
TORCH_CHECK(query.dim() == query_dim,
"query's dim should be consistent with that of Layout", OPS_ERROR(ErrCode::VALUE));
} else {
TORCH_CHECK(
false,
"layout only support BSH, BSND, TND, NTD, BNSD_BSND, BSH_BNSD, BSND_BNSD, NTD_TND, ",
"BSH_NBSD, BSND_NBSD, BNSD_NBSD, TND_NTD, but got ",
query_layout,
OPS_ERROR(ErrCode::VALUE)
);
}
return {query_layout, attention_out_layout};
}
/**
 * Decode a batched query into (batch, heads, seq_len, head_dim).
 *
 * Supported layouts: BSH (head_dim = H / num_heads), BSND, BNSD, and NSD (batch
 * reported as 1). Any other layout raises via TORCH_CHECK.
 *
 * @param query       query tensor in `query_layout`.
 * @param query_layout one of "BSH", "BSND", "BNSD", "NSD".
 * @param num_heads   head count, only consulted for BSH.
 * @return {b, n1, s1, d1}.
 */
static std::tuple<int64_t, int64_t, int64_t, int64_t> get_query_b_n_s_d(
    const at::Tensor &query,
    std::string query_layout,
    int64_t num_heads)
{
    if (query_layout == "BSH") {
        // The hidden axis packs all heads, so split it evenly.
        return {query.size(DIM_0), num_heads, query.size(DIM_1), query.size(DIM_2) / num_heads};
    }
    if (query_layout == "BSND") {
        return {query.size(DIM_0), query.size(DIM_2), query.size(DIM_1), query.size(DIM_3)};
    }
    if (query_layout == "BNSD") {
        return {query.size(DIM_0), query.size(DIM_1), query.size(DIM_2), query.size(DIM_3)};
    }
    if (query_layout == "NSD") {
        // No batch axis in this layout: treat as a single batch.
        return {1, query.size(DIM_0), query.size(DIM_1), query.size(DIM_2)};
    }
    TORCH_CHECK(false, "It is not supported in get_query_b_n_s_d function, layout ",
        query_layout, OPS_ERROR(ErrCode::VALUE));
    return {0, 0, 0, 0};
}
/**
 * Decode a packed (variable-length) query into (tokens, heads, head_dim).
 *
 * @param query        query tensor in TND or NTD layout.
 * @param query_layout "TND" or "NTD"; anything else raises via TORCH_CHECK.
 * @return {t, n1, d1}.
 */
static std::tuple<int64_t, int64_t, int64_t> get_query_t_n_d(
    const at::Tensor &query,
    std::string query_layout)
{
    const bool tokens_first = (query_layout == "TND");
    const bool heads_first = (query_layout == "NTD");
    if (!tokens_first && !heads_first) {
        TORCH_CHECK(false, "It is not supported in get_query_t_n_d function, layout ",
            query_layout, OPS_ERROR(ErrCode::VALUE));
        return {0, 0, 0};
    }
    // TND and NTD differ only by the order of the first two axes; D is always last.
    const int64_t token_axis = tokens_first ? DIM_0 : DIM_1;
    const int64_t head_axis = tokens_first ? DIM_1 : DIM_0;
    return {query.size(token_axis), query.size(head_axis), query.size(DIM_2)};
}
/**
 * Derive the per-head size of `value`.
 *
 * With Page Attention (block_table present) the value layout is inferred from its rank:
 * 3 -> H / kv_num_heads, 4 -> last axis, 5 -> size(2) * size(4); other ranks raise.
 * Without Page Attention, value must have the same rank as query and is read using
 * query_layout; an unrecognized layout yields 0 (callers then fall back to query's size).
 *
 * @param block_table  optional page table; its presence selects the Page Attention path.
 * @param query        query tensor (rank used for validation).
 * @param value        value tensor.
 * @param query_layout query layout string.
 * @param kv_num_heads KV head count used to split a packed hidden axis.
 * @return value head dimension (0 if it could not be derived).
 */
static int64_t get_value_d(
    const c10::optional<at::Tensor> &block_table,
    const at::Tensor &query,
    const at::Tensor &value,
    std::string query_layout,
    int64_t kv_num_heads)
{
    if (!block_table.has_value()) {
        TORCH_CHECK(
            value.dim() == query.dim(),
            "when Page Attention not enabled, value'dim should equal to query's dim!",
            OPS_ERROR(ErrCode::VALUE)
        );
        if (query_layout == "BSH") {
            return value.size(DIM_2) / kv_num_heads;
        }
        if (query_layout == "BSND" || query_layout == "BNSD") {
            return value.size(DIM_3);
        }
        if (query_layout == "TND" || query_layout == "NTD" || query_layout == "NSD") {
            return value.size(DIM_2);
        }
        return 0;
    }

    switch (value.dim()) {
        case PA_BBH_DIMS:
            return value.size(DIM_2) / kv_num_heads;
        case PA_BNBD_DIMS:
            return value.size(DIM_3);
        case PA_NZ_DIMS:
            return value.size(DIM_2) * value.size(DIM_4);
        default:
            break;
    }
    TORCH_CHECK(
        false,
        "when Page Attention enabled, value's dim should be 3/4/5, but got ",
        value.dim(),
        OPS_ERROR(ErrCode::VALUE)
    );
    return 0;
}
/**
 * Multiplier applied to the value head dimension to account for packed storage.
 *
 * Returns 8 when value is stored as int32, 2 when the effective value dtype (value_dtype
 * override, else value's own scalar type) maps to an FP4 ACL type, and 1 otherwise.
 * NOTE(review): presumably these reflect how many logical elements share one stored
 * element (e.g. 4-bit data packed into int32 / bytes) — confirm with the kernel spec.
 */
static int get_change_d_scale_v2(
    const at::Tensor &value,
    c10::optional<int64_t> value_dtype)
{
    constexpr int kScaleDefault = 1;
    constexpr int kScaleInt32 = 8;
    constexpr int kScaleFp4 = 2;

    if (value.scalar_type() == at::kInt) {
        return kScaleInt32;
    }
    const int64_t effective_dtype = value_dtype.value_or(static_cast<int64_t>(value.scalar_type()));
    const aclDataType acl_type = c10_npu::GetAclDataType(effective_dtype);
    const bool is_fp4 = (acl_type == aclDataType::ACL_FLOAT4_E1M2) ||
        (acl_type == aclDataType::ACL_FLOAT4_E2M1);
    return is_fp4 ? kScaleFp4 : kScaleDefault;
}
/**
 * Allocate the attention-output tensor shaped according to attention_out_layout.
 *
 * Batch / head / sequence (or token) sizes are decoded from `query` using query_layout.
 * The last dimension is valueD (BSH: num_heads * valueD). If that derived size is 0
 * or the query's corresponding size is 0, the query's own size is used instead.
 * The tensor uses query's options (dtype/device).
 *
 * @param attention_out_layout one of BSH, BSND, BNSD, NBSD, NSD, TND, NTD.
 * @param query                query tensor.
 * @param query_layout         layout of query.
 * @param num_heads            number of query heads.
 * @param valueD               value head dimension (0 if unknown).
 * @return newly allocated output tensor; for an unrecognized attention_out_layout,
 *         a tensor with query's shape and options.
 */
static at::Tensor infer_attention_out_shape(
    std::string attention_out_layout,
    const at::Tensor &query,
    std::string query_layout,
    int64_t num_heads,
    int64_t valueD)
{
    const auto out_options = query.options().dtype(query.dtype());
    // Sizes stay int64_t end to end (previous `int` locals could silently truncate).
    const auto resolve_out_size = [](int64_t derived, int64_t query_size) -> int64_t {
        return (derived == 0 || query_size == 0) ? query_size : derived;
    };

    if (attention_out_layout == "BSH") {
        auto [b, n1, s1, d1] = get_query_b_n_s_d(query, query_layout, num_heads);
        const int64_t out_h = resolve_out_size(num_heads * valueD, query.size(DIM_2));
        return npu_preparation::apply_tensor_without_format({b, s1, out_h}, out_options);
    }
    if (attention_out_layout == "BSND" || attention_out_layout == "BNSD" ||
        attention_out_layout == "NBSD" || attention_out_layout == "NSD") {
        auto [b, n1, s1, d1] = get_query_b_n_s_d(query, query_layout, num_heads);
        const int64_t out_d = resolve_out_size(valueD, d1);
        if (attention_out_layout == "BSND") {
            return npu_preparation::apply_tensor_without_format({b, s1, n1, out_d}, out_options);
        }
        if (attention_out_layout == "BNSD") {
            return npu_preparation::apply_tensor_without_format({b, n1, s1, out_d}, out_options);
        }
        if (attention_out_layout == "NBSD") {
            return npu_preparation::apply_tensor_without_format({n1, b, s1, out_d}, out_options);
        }
        // NSD
        return npu_preparation::apply_tensor_without_format({n1, s1, out_d}, out_options);
    }
    if (attention_out_layout == "TND" || attention_out_layout == "NTD") {
        auto [t, n1, d1] = get_query_t_n_d(query, query_layout);
        const int64_t out_d = resolve_out_size(valueD, d1);
        if (attention_out_layout == "TND") {
            return npu_preparation::apply_tensor_without_format({t, n1, out_d}, out_options);
        }
        return npu_preparation::apply_tensor_without_format({n1, t, out_d}, out_options);
    }
    // Only allocated on this fallback path; previously a query-sized tensor was
    // allocated up front and immediately discarded by every branch above.
    return npu_preparation::apply_tensor_without_format(query);
}
/**
 * Allocate the float32 softmax log-sum-exp output.
 *
 * Packed layouts (TND, NTD, TND_NTD, NTD_TND) get shape {t, n1, 1}. All other
 * layouts get {b, n1, s1, 1}. Sizes are decoded from `query` with query_layout.
 *
 * @param input_layout_str full layout string supplied by the caller.
 * @param query            query tensor.
 * @param query_layout     query layout extracted from input_layout_str.
 * @param num_heads        number of query heads (used for BSH decoding).
 * @return newly allocated float32 LSE tensor.
 */
static at::Tensor infer_lse_out_shape(
    std::string input_layout_str,
    const at::Tensor &query,
    std::string query_layout,
    int64_t num_heads)
{
    // The previous function-scope b/n1/s1/d1/t locals were never read: the structured
    // bindings below shadowed them. They are removed.
    const auto lse_options = c10::dtype(c10::ScalarType::Float);
    const bool is_packed_layout = input_layout_str == "TND" || input_layout_str == "NTD" ||
        input_layout_str == "TND_NTD" || input_layout_str == "NTD_TND";
    if (is_packed_layout) {
        auto [t, n1, d1] = get_query_t_n_d(query, query_layout);
        return npu_preparation::apply_tensor_without_format({t, n1, 1}, lse_options);
    }
    auto [b, n1, s1, d1] = get_query_b_n_s_d(query, query_layout, num_heads);
    return npu_preparation::apply_tensor_without_format({b, n1, s1, 1}, lse_options);
}
/**
 * Allocate the (attention_out, softmax_lse) tensors for npu_fused_infer_attention_score_v2.
 *
 * Output shape comes from infer_attention_out_shape. Output dtype is chosen as follows:
 *   - quant_scale_out given: out_dtype if set, else int8 (kChar);
 *   - query is int8, float8_e4m3fn, or HiFloat8 (uint8 storage with query_dtype mapping to
 *     ACL_HIFLOAT8): out_dtype if set, else query_rope's dtype if given, else float16;
 *   - otherwise: query's dtype.
 * softmax_lse is float32 shaped by infer_lse_out_shape when return_softmax_lse is true,
 * and a {0}-shaped float32 tensor otherwise.
 *
 * @param num_key_value_heads 0 means "same as num_query_heads".
 * @throws c10::Error (TORCH_CHECK) if num_query_heads <= 0 or layout/rank checks fail.
 */
std::tuple<at::Tensor, at::Tensor> construct_fia_output_tensor_v2(
    const at::Tensor &query,
    const at::Tensor &value,
    c10::optional<int64_t> query_dtype,
    c10::optional<int64_t> value_dtype,
    std::string input_layout_str,
    const c10::optional<at::Tensor> &quant_scale_out,
    const c10::optional<at::Tensor> &block_table,
    int64_t num_query_heads,
    int64_t num_key_value_heads,
    bool return_softmax_lse,
    const c10::optional<at::Tensor> &query_rope,
    c10::optional<int64_t> out_dtype)
{
    TORCH_CHECK(
        num_query_heads > 0,
        "num_heads should be greater than 0, but the actual value is ",
        num_query_heads,
        OPS_ERROR(ErrCode::VALUE)
    );
    num_key_value_heads = (num_key_value_heads == 0) ? num_query_heads : num_key_value_heads;
    auto [query_layout, attention_out_layout] = get_query_and_attention_out_layout(query, input_layout_str);
    int64_t valueD = get_value_d(block_table, query, value, query_layout, num_key_value_heads);
    valueD *= get_change_d_scale_v2(value, value_dtype);
    at::Tensor tmp_output = infer_attention_out_shape(attention_out_layout, query, query_layout, num_query_heads, valueD);

    const bool is_hifloat8_input = query.dtype() == at::kByte && query_dtype.has_value() &&
        c10_npu::GetAclDataType(query_dtype.value()) == aclDataType::ACL_HIFLOAT8;
    const bool is_low_precision_query = query.dtype() == at::kChar ||
        query.dtype() == at::ScalarType::Float8_e4m3fn || is_hifloat8_input;

    at::Tensor output;
    if (quant_scale_out.has_value()) {
        at::ScalarType output_type = at::ScalarType::Char;
        if (out_dtype.has_value()) {
            output_type = c10_npu::GetATenDType(out_dtype.value());
        }
        output = npu_preparation::apply_tensor_without_format(tmp_output.sizes(), output_type);
    } else if (is_low_precision_query) {
        if (out_dtype.has_value()) {
            at::ScalarType output_type = c10_npu::GetATenDType(out_dtype.value());
            output = npu_preparation::apply_tensor_without_format(tmp_output.sizes(), output_type);
        } else if (query_rope.has_value()) {
            output = npu_preparation::apply_tensor_without_format(tmp_output.sizes(), c10::dtype(query_rope->dtype()));
        } else {
            output = npu_preparation::apply_tensor_without_format(tmp_output.sizes(), c10::dtype(c10::ScalarType::Half));
        }
    } else {
        // tmp_output is a fresh tensor with exactly the required sizes and query's options,
        // so reuse it rather than allocating an identical second tensor.
        output = std::move(tmp_output);
    }

    at::Tensor softmax_lse;
    if (return_softmax_lse) {
        softmax_lse = infer_lse_out_shape(input_layout_str, query, query_layout, num_query_heads);
    } else {
        softmax_lse = npu_preparation::apply_tensor_without_format({0}, c10::dtype(c10::ScalarType::Float));
    }
    return std::tuple<at::Tensor, at::Tensor>(output, softmax_lse);
}
// Functional entry point of npu_fused_infer_attention_score_v2.
// Allocates attention_out / softmax_lse with construct_fia_output_tensor_v2 and then
// launches aclnnFusedInferAttentionScoreV4 on non-Ascend950 SoCs or
// aclnnFusedInferAttentionScoreV5 on Ascend950.
// Returns {attention_out, softmax_lse}; softmax_lse is a {0}-shaped float32 tensor when
// return_softmax_lse is false.
// NOTE(review): key_shared_prefix_dtype, value_shared_prefix_dtype and
// dequant_scale_key_rope_dtype are accepted but not used here. On the V4 path
// quant_scale_p is not forwarded (the undefined quant_scale1 occupies that slot).
// The aclnn macro arguments are positional: keep their order in sync with the API.
std::tuple<at::Tensor, at::Tensor> npu_fused_infer_attention_score_v2_symint(
const at::Tensor &query, const at::Tensor &key, const at::Tensor &value,
const c10::optional<at::Tensor> &query_rope,
const c10::optional<at::Tensor> &key_rope,
const c10::optional<at::Tensor> &pse_shift,
const c10::optional<at::Tensor> &atten_mask,
c10::OptionalArrayRef<c10::SymInt> actual_seq_qlen,
c10::OptionalArrayRef<c10::SymInt> actual_seq_kvlen,
const c10::optional<at::Tensor> &block_table,
const c10::optional<at::Tensor> &dequant_scale_query,
const c10::optional<at::Tensor> &dequant_scale_key,
const c10::optional<at::Tensor> &dequant_offset_key,
const c10::optional<at::Tensor> &dequant_scale_value,
const c10::optional<at::Tensor> &dequant_offset_value,
const c10::optional<at::Tensor> &dequant_scale_key_rope,
const c10::optional<at::Tensor> &quant_scale_out,
const c10::optional<at::Tensor> &quant_offset_out,
const c10::optional<at::Tensor> &quant_scale_p,
const c10::optional<at::Tensor> &learnable_sink,
int64_t num_query_heads, int64_t num_key_value_heads, double softmax_scale,
int64_t pre_tokens, int64_t next_tokens, c10::string_view input_layout,
int64_t sparse_mode, int64_t block_size,
int64_t query_quant_mode, int64_t key_quant_mode, int64_t value_quant_mode,
int64_t inner_precise, bool return_softmax_lse,
c10::optional<int64_t> query_dtype, c10::optional<int64_t> key_dtype, c10::optional<int64_t> value_dtype,
c10::optional<int64_t> query_rope_dtype, c10::optional<int64_t> key_rope_dtype,
c10::optional<int64_t> key_shared_prefix_dtype, c10::optional<int64_t> value_shared_prefix_dtype,
c10::optional<int64_t> dequant_scale_query_dtype, c10::optional<int64_t> dequant_scale_key_dtype,
c10::optional<int64_t> dequant_scale_value_dtype, c10::optional<int64_t> dequant_scale_key_rope_dtype,
c10::optional<int64_t> out_dtype)
{
std::string input_layout_str = std::string(input_layout);
// Allocate outputs (shape/dtype inference shared with the infer_output / workspace paths).
std::tuple<at::Tensor, at::Tensor> fia_output = op_api::construct_fia_output_tensor_v2(query, value, query_dtype, value_dtype, input_layout_str, quant_scale_out, block_table,
num_query_heads, num_key_value_heads, return_softmax_lse, query_rope, out_dtype);
at::Tensor output = std::get<0>(fia_output);
at::Tensor softmax_lse = std::get<1>(fia_output);
// input_layout_str outlives the aclnn calls below, so its c_str() stays valid.
char *input_layout_ptr = const_cast<char *>(input_layout_str.c_str());
// Inputs this op does not expose: passed to aclnn as undefined tensors / zero attributes.
at::Tensor default_actual_shared_prefix_len {nullptr};
at::Tensor default_q_start_idx {nullptr};
at::Tensor default_kv_start_idx {nullptr};
at::Tensor dequant_scale1;
at::Tensor quant_scale1;
at::Tensor dequant_scale2;
at::Tensor antiquant_scale;
at::Tensor antiquant_offset;
at::Tensor query_padding_size;
at::Tensor kv_padding_size;
at::Tensor key_shared_prefix;
at::Tensor value_shared_prefix;
int64_t antiquant_mode = 0;
int64_t default_pse_type_value = 0;
// key/value are forwarded as single-element tensor lists.
at::TensorList valueTensors = value;
at::TensorList keyTensors = key;
// Pair each tensor with its optional dtype override via make_wrapper
// (presumably so the aclnn call sees the override dtype — confirm in make_wrapper).
TensorWrapper query_wrapper = make_wrapper(query, query_dtype);
TensorListWrapper keyTensors_wrapper = make_wrapper(keyTensors, key_dtype);
TensorListWrapper valueTensors_wrapper = make_wrapper(valueTensors, value_dtype);
TensorWrapper outTensor_wrapper = make_wrapper(output, out_dtype);
// Absent optionals are replaced by an undefined tensor before wrapping.
at::Tensor null_tensor;
auto query_rope_tmp = query_rope.has_value() ? query_rope.value() : null_tensor;
TensorWrapper query_rope_wrapper = make_wrapper(query_rope_tmp, query_rope_dtype);
auto key_rope_tmp = key_rope.has_value() ? key_rope.value() : null_tensor;
TensorWrapper key_rope_wrapper = make_wrapper(key_rope_tmp, key_rope_dtype);
auto dequant_scale_query_tmp = dequant_scale_query.has_value() ? dequant_scale_query.value() : null_tensor;
TensorWrapper dequant_scale_query_wrapper = make_wrapper(dequant_scale_query_tmp, dequant_scale_query_dtype);
auto dequant_scale_key_tmp = dequant_scale_key.has_value() ? dequant_scale_key.value() : null_tensor;
TensorWrapper dequant_scale_key_wrapper = make_wrapper(dequant_scale_key_tmp, dequant_scale_key_dtype);
auto dequant_scale_value_tmp = dequant_scale_value.has_value() ? dequant_scale_value.value() : null_tensor;
TensorWrapper dequant_scale_value_wrapper = make_wrapper(dequant_scale_value_tmp, dequant_scale_value_dtype);
// V5 (Ascend950 only) additionally takes quant_scale_p, q/kv start indices and a pse type.
if (c10_npu::GetSocVersion() != c10_npu::SocVersion::Ascend950) {
EXEC_NPU_NO_FORMAT_CHECK_CMD(aclnnFusedInferAttentionScoreV4, query_wrapper, keyTensors_wrapper, valueTensors_wrapper, pse_shift, atten_mask, actual_seq_qlen, actual_seq_kvlen, dequant_scale1, quant_scale1, dequant_scale2,
quant_scale_out, quant_offset_out, antiquant_scale, antiquant_offset, block_table, query_padding_size, kv_padding_size, dequant_scale_key_wrapper, dequant_offset_key, dequant_scale_value_wrapper,
dequant_offset_value, key_shared_prefix, value_shared_prefix, default_actual_shared_prefix_len, query_rope_wrapper, key_rope_wrapper, dequant_scale_key_rope, dequant_scale_query_wrapper, learnable_sink, num_query_heads, softmax_scale, pre_tokens, next_tokens, input_layout_ptr,
num_key_value_heads, sparse_mode, inner_precise, block_size, antiquant_mode, return_softmax_lse, key_quant_mode, value_quant_mode, query_quant_mode, outTensor_wrapper, softmax_lse);
} else {
EXEC_NPU_NO_FORMAT_CHECK_CMD(aclnnFusedInferAttentionScoreV5, query_wrapper, keyTensors_wrapper, valueTensors_wrapper, pse_shift, atten_mask, actual_seq_qlen, actual_seq_kvlen, dequant_scale1, quant_scale_p, dequant_scale2,
quant_scale_out, quant_offset_out, antiquant_scale, antiquant_offset, block_table, query_padding_size, kv_padding_size, dequant_scale_key_wrapper, dequant_offset_key, dequant_scale_value_wrapper,
dequant_offset_value, key_shared_prefix, value_shared_prefix, default_actual_shared_prefix_len, query_rope_wrapper, key_rope_wrapper, dequant_scale_key_rope, dequant_scale_query_wrapper, learnable_sink, default_q_start_idx, default_kv_start_idx, num_query_heads, softmax_scale, pre_tokens, next_tokens, input_layout_ptr,
num_key_value_heads, sparse_mode, inner_precise, block_size, antiquant_mode, return_softmax_lse, key_quant_mode, value_quant_mode, query_quant_mode, default_pse_type_value, outTensor_wrapper, softmax_lse);
}
return std::tuple<at::Tensor, at::Tensor>(output, softmax_lse);
}
// Out-variant of npu_fused_infer_attention_score_v2: writes into the caller-provided
// attention_out / softmax_lse tensors instead of allocating them.
//
// When `workspace` is given, the kernel is launched with EXEC_UPDATE_NPU_NO_FORMAT_CHECK_CMD
// using that buffer's address and byte size. Otherwise EXEC_NPU_NO_FORMAT_CHECK_CMD is used.
// Non-Ascend950 SoCs call aclnnFusedInferAttentionScoreV4; Ascend950 calls
// aclnnFusedInferAttentionScoreV5.
//
// Returns references to attention_out and softmax_lse.
// NOTE(review): quant_scale_p is only forwarded on the V5 (Ascend950) path.
std::tuple<at::Tensor &, at::Tensor &> npu_fused_infer_attention_score_v2_out_symint(
    const at::Tensor &query, const at::Tensor &key, const at::Tensor &value,
    const c10::optional<at::Tensor> &query_rope,
    const c10::optional<at::Tensor> &key_rope,
    const c10::optional<at::Tensor> &pse_shift,
    const c10::optional<at::Tensor> &atten_mask,
    c10::OptionalArrayRef<c10::SymInt> actual_seq_qlen,
    c10::OptionalArrayRef<c10::SymInt> actual_seq_kvlen,
    const c10::optional<at::Tensor> &block_table,
    const c10::optional<at::Tensor> &dequant_scale_query,
    const c10::optional<at::Tensor> &dequant_scale_key,
    const c10::optional<at::Tensor> &dequant_offset_key,
    const c10::optional<at::Tensor> &dequant_scale_value,
    const c10::optional<at::Tensor> &dequant_offset_value,
    const c10::optional<at::Tensor> &dequant_scale_key_rope,
    const c10::optional<at::Tensor> &quant_scale_out,
    const c10::optional<at::Tensor> &quant_offset_out,
    const c10::optional<at::Tensor> &quant_scale_p,
    const c10::optional<at::Tensor> &learnable_sink,
    int64_t num_query_heads, int64_t num_key_value_heads, double softmax_scale,
    int64_t pre_tokens, int64_t next_tokens, c10::string_view input_layout,
    int64_t sparse_mode, int64_t block_size,
    int64_t query_quant_mode, int64_t key_quant_mode, int64_t value_quant_mode,
    int64_t inner_precise, bool return_softmax_lse,
    c10::optional<int64_t> query_dtype, c10::optional<int64_t> key_dtype, c10::optional<int64_t> value_dtype,
    c10::optional<int64_t> query_rope_dtype, c10::optional<int64_t> key_rope_dtype,
    c10::optional<int64_t> key_shared_prefix_dtype, c10::optional<int64_t> value_shared_prefix_dtype,
    c10::optional<int64_t> dequant_scale_query_dtype, c10::optional<int64_t> dequant_scale_key_dtype,
    c10::optional<int64_t> dequant_scale_value_dtype, c10::optional<int64_t> dequant_scale_key_rope_dtype,
    c10::optional<int64_t> out_dtype,
    const c10::optional<at::Tensor> &workspace,
    at::Tensor &attention_out,
    at::Tensor &softmax_lse)
{
    std::string input_layout_str = std::string(input_layout);
    char *input_layout_ptr = const_cast<char *>(input_layout_str.c_str());
    // Inputs this op does not expose: passed to aclnn as undefined tensors / zero attributes.
    at::Tensor default_actual_shared_prefix_len {nullptr};
    at::Tensor default_q_start_idx {nullptr};
    at::Tensor default_kv_start_idx {nullptr};
    at::Tensor dequant_scale1;
    at::Tensor quant_scale1;
    at::Tensor dequant_scale2;
    at::Tensor antiquant_scale;
    at::Tensor antiquant_offset;
    at::Tensor query_padding_size;
    at::Tensor kv_padding_size;
    at::Tensor key_shared_prefix;
    at::Tensor value_shared_prefix;
    int64_t antiquant_mode = 0;
    int64_t default_pse_type_value = 0;
    // key/value are forwarded as single-element tensor lists.
    at::TensorList valueTensors = value;
    at::TensorList keyTensors = key;
    TensorWrapper query_wrapper = make_wrapper(query, query_dtype);
    TensorListWrapper keyTensors_wrapper = make_wrapper(keyTensors, key_dtype);
    TensorListWrapper valueTensors_wrapper = make_wrapper(valueTensors, value_dtype);
    TensorWrapper outTensor_wrapper = make_wrapper(attention_out, out_dtype);
    at::Tensor null_tensor;
    auto query_rope_tmp = query_rope.has_value() ? query_rope.value() : null_tensor;
    TensorWrapper query_rope_wrapper = make_wrapper(query_rope_tmp, query_rope_dtype);
    auto key_rope_tmp = key_rope.has_value() ? key_rope.value() : null_tensor;
    TensorWrapper key_rope_wrapper = make_wrapper(key_rope_tmp, key_rope_dtype);
    auto dequant_scale_query_tmp = dequant_scale_query.has_value() ? dequant_scale_query.value() : null_tensor;
    TensorWrapper dequant_scale_query_wrapper = make_wrapper(dequant_scale_query_tmp, dequant_scale_query_dtype);
    auto dequant_scale_key_tmp = dequant_scale_key.has_value() ? dequant_scale_key.value() : null_tensor;
    TensorWrapper dequant_scale_key_wrapper = make_wrapper(dequant_scale_key_tmp, dequant_scale_key_dtype);
    auto dequant_scale_value_tmp = dequant_scale_value.has_value() ? dequant_scale_value.value() : null_tensor;
    TensorWrapper dequant_scale_value_wrapper = make_wrapper(dequant_scale_value_tmp, dequant_scale_value_dtype);

    // Resolve the caller-provided workspace once for both SoC paths.
    // Tensor::data_ptr() honors the tensor's storage offset. storage().data() points at the
    // start of the whole storage, so a workspace that is a view into a larger buffer would
    // have been addressed at the wrong place while its size was taken from the view.
    void *workspace_addr = nullptr;
    uint64_t workspace_size = 0;
    if (workspace.has_value()) {
        const at::Tensor &workspace_tensor = workspace.value();
        workspace_addr = workspace_tensor.data_ptr();
        workspace_size = static_cast<uint64_t>(workspace_tensor.numel() * workspace_tensor.element_size());
    }

    if (c10_npu::GetSocVersion() != c10_npu::SocVersion::Ascend950) {
        if (workspace.has_value()) {
            EXEC_UPDATE_NPU_NO_FORMAT_CHECK_CMD(aclnnFusedInferAttentionScoreV4, workspace_addr, workspace_size, query_wrapper, keyTensors_wrapper, valueTensors_wrapper, pse_shift, atten_mask, actual_seq_qlen, actual_seq_kvlen, dequant_scale1, quant_scale1, dequant_scale2,
                quant_scale_out, quant_offset_out, antiquant_scale, antiquant_offset, block_table, query_padding_size, kv_padding_size, dequant_scale_key_wrapper, dequant_offset_key, dequant_scale_value_wrapper,
                dequant_offset_value, key_shared_prefix, value_shared_prefix, default_actual_shared_prefix_len, query_rope_wrapper, key_rope_wrapper, dequant_scale_key_rope, dequant_scale_query_wrapper, learnable_sink, num_query_heads, softmax_scale, pre_tokens, next_tokens, input_layout_ptr,
                num_key_value_heads, sparse_mode, inner_precise, block_size, antiquant_mode, return_softmax_lse, key_quant_mode, value_quant_mode, query_quant_mode, outTensor_wrapper, softmax_lse);
        } else {
            EXEC_NPU_NO_FORMAT_CHECK_CMD(aclnnFusedInferAttentionScoreV4, query_wrapper, keyTensors_wrapper, valueTensors_wrapper, pse_shift, atten_mask, actual_seq_qlen, actual_seq_kvlen, dequant_scale1, quant_scale1, dequant_scale2,
                quant_scale_out, quant_offset_out, antiquant_scale, antiquant_offset, block_table, query_padding_size, kv_padding_size, dequant_scale_key_wrapper, dequant_offset_key, dequant_scale_value_wrapper,
                dequant_offset_value, key_shared_prefix, value_shared_prefix, default_actual_shared_prefix_len, query_rope_wrapper, key_rope_wrapper, dequant_scale_key_rope, dequant_scale_query_wrapper, learnable_sink, num_query_heads, softmax_scale, pre_tokens, next_tokens, input_layout_ptr,
                num_key_value_heads, sparse_mode, inner_precise, block_size, antiquant_mode, return_softmax_lse, key_quant_mode, value_quant_mode, query_quant_mode, outTensor_wrapper, softmax_lse);
        }
    } else {
        if (workspace.has_value()) {
            EXEC_UPDATE_NPU_NO_FORMAT_CHECK_CMD(aclnnFusedInferAttentionScoreV5, workspace_addr, workspace_size, query_wrapper, keyTensors_wrapper, valueTensors_wrapper, pse_shift, atten_mask, actual_seq_qlen, actual_seq_kvlen, dequant_scale1, quant_scale_p, dequant_scale2,
                quant_scale_out, quant_offset_out, antiquant_scale, antiquant_offset, block_table, query_padding_size, kv_padding_size, dequant_scale_key_wrapper, dequant_offset_key, dequant_scale_value_wrapper,
                dequant_offset_value, key_shared_prefix, value_shared_prefix, default_actual_shared_prefix_len, query_rope_wrapper, key_rope_wrapper, dequant_scale_key_rope, dequant_scale_query_wrapper, learnable_sink, default_q_start_idx, default_kv_start_idx, num_query_heads, softmax_scale, pre_tokens, next_tokens, input_layout_ptr,
                num_key_value_heads, sparse_mode, inner_precise, block_size, antiquant_mode, return_softmax_lse, key_quant_mode, value_quant_mode, query_quant_mode, default_pse_type_value, outTensor_wrapper, softmax_lse);
        } else {
            EXEC_NPU_NO_FORMAT_CHECK_CMD(aclnnFusedInferAttentionScoreV5, query_wrapper, keyTensors_wrapper, valueTensors_wrapper, pse_shift, atten_mask, actual_seq_qlen, actual_seq_kvlen, dequant_scale1, quant_scale_p, dequant_scale2,
                quant_scale_out, quant_offset_out, antiquant_scale, antiquant_offset, block_table, query_padding_size, kv_padding_size, dequant_scale_key_wrapper, dequant_offset_key, dequant_scale_value_wrapper,
                dequant_offset_value, key_shared_prefix, value_shared_prefix, default_actual_shared_prefix_len, query_rope_wrapper, key_rope_wrapper, dequant_scale_key_rope, dequant_scale_query_wrapper, learnable_sink, default_q_start_idx, default_kv_start_idx, num_query_heads, softmax_scale, pre_tokens, next_tokens, input_layout_ptr,
                num_key_value_heads, sparse_mode, inner_precise, block_size, antiquant_mode, return_softmax_lse, key_quant_mode, value_quant_mode, query_quant_mode, default_pse_type_value, outTensor_wrapper, softmax_lse);
        }
    }
    return std::tuple<at::Tensor&, at::Tensor&>(attention_out, softmax_lse);
}
// Query the aclnn API for the workspace size this configuration needs and allocate a
// tensor to hold it, for later use as `workspace` in
// npu_fused_infer_attention_score_v2_out_symint.
// The output tensors are allocated only so the size query sees realistic
// shapes/dtypes; they are not returned.
// V4 is queried on non-Ascend950 SoCs and V5 on Ascend950, mirroring the launch paths.
// NOTE(review): the _out variant passes numel * element_size (bytes) as the workspace
// size, but here workspace_size (presumably bytes) is used as an ELEMENT count of
// query's dtype. For dtypes wider than 1 byte the buffer is therefore element_size times
// larger than required. A kByte tensor would be exact, but confirm that the Python
// meta/fake registration of this op agrees before changing the dtype.
at::Tensor _npu_fused_infer_attention_score_v2_get_max_workspace_symint(
const at::Tensor &query, const at::Tensor &key, const at::Tensor &value,
const c10::optional<at::Tensor> &query_rope,
const c10::optional<at::Tensor> &key_rope,
const c10::optional<at::Tensor> &pse_shift,
const c10::optional<at::Tensor> &atten_mask,
c10::OptionalArrayRef<c10::SymInt> actual_seq_qlen,
c10::OptionalArrayRef<c10::SymInt> actual_seq_kvlen,
const c10::optional<at::Tensor> &block_table,
const c10::optional<at::Tensor> &dequant_scale_query,
const c10::optional<at::Tensor> &dequant_scale_key,
const c10::optional<at::Tensor> &dequant_offset_key,
const c10::optional<at::Tensor> &dequant_scale_value,
const c10::optional<at::Tensor> &dequant_offset_value,
const c10::optional<at::Tensor> &dequant_scale_key_rope,
const c10::optional<at::Tensor> &quant_scale_out,
const c10::optional<at::Tensor> &quant_offset_out,
const c10::optional<at::Tensor> &quant_scale_p,
const c10::optional<at::Tensor> &learnable_sink,
int64_t num_query_heads, int64_t num_key_value_heads, double softmax_scale,
int64_t pre_tokens, int64_t next_tokens, c10::string_view input_layout,
int64_t sparse_mode, int64_t block_size,
int64_t query_quant_mode, int64_t key_quant_mode, int64_t value_quant_mode,
int64_t inner_precise, bool return_softmax_lse,
c10::optional<int64_t> query_dtype, c10::optional<int64_t> key_dtype, c10::optional<int64_t> value_dtype,
c10::optional<int64_t> query_rope_dtype, c10::optional<int64_t> key_rope_dtype,
c10::optional<int64_t> key_shared_prefix_dtype, c10::optional<int64_t> value_shared_prefix_dtype,
c10::optional<int64_t> dequant_scale_query_dtype, c10::optional<int64_t> dequant_scale_key_dtype,
c10::optional<int64_t> dequant_scale_value_dtype, c10::optional<int64_t> dequant_scale_key_rope_dtype,
c10::optional<int64_t> out_dtype)
{
std::string input_layout_str = std::string(input_layout);
// Same output allocation as the functional entry point, so the size query matches it.
std::tuple<at::Tensor, at::Tensor> fia_output = op_api::construct_fia_output_tensor_v2(query, value, query_dtype, value_dtype, input_layout_str, quant_scale_out, block_table,
num_query_heads, num_key_value_heads, return_softmax_lse, query_rope, out_dtype);
at::Tensor output = std::get<0>(fia_output);
at::Tensor softmax_lse = std::get<1>(fia_output);
char *input_layout_ptr = const_cast<char *>(input_layout_str.c_str());
// Inputs this op does not expose: passed to aclnn as undefined tensors / zero attributes.
at::Tensor default_actual_shared_prefix_len {nullptr};
at::Tensor default_q_start_idx {nullptr};
at::Tensor default_kv_start_idx {nullptr};
at::Tensor dequant_scale1;
at::Tensor quant_scale1;
at::Tensor dequant_scale2;
at::Tensor antiquant_scale;
at::Tensor antiquant_offset;
at::Tensor query_padding_size;
at::Tensor kv_padding_size;
at::Tensor key_shared_prefix;
at::Tensor value_shared_prefix;
int64_t antiquant_mode = 0;
int64_t default_pse_type_value = 0;
// key/value are forwarded as single-element tensor lists.
at::TensorList valueTensors = value;
at::TensorList keyTensors = key;
TensorWrapper query_wrapper = make_wrapper(query, query_dtype);
TensorListWrapper keyTensors_wrapper = make_wrapper(keyTensors, key_dtype);
TensorListWrapper valueTensors_wrapper = make_wrapper(valueTensors, value_dtype);
TensorWrapper outTensor_wrapper = make_wrapper(output, out_dtype);
at::Tensor null_tensor;
auto query_rope_tmp = query_rope.has_value() ? query_rope.value() : null_tensor;
TensorWrapper query_rope_wrapper = make_wrapper(query_rope_tmp, query_rope_dtype);
auto key_rope_tmp = key_rope.has_value() ? key_rope.value() : null_tensor;
TensorWrapper key_rope_wrapper = make_wrapper(key_rope_tmp, key_rope_dtype);
auto dequant_scale_query_tmp = dequant_scale_query.has_value() ? dequant_scale_query.value() : null_tensor;
TensorWrapper dequant_scale_query_wrapper = make_wrapper(dequant_scale_query_tmp, dequant_scale_query_dtype);
auto dequant_scale_key_tmp = dequant_scale_key.has_value() ? dequant_scale_key.value() : null_tensor;
TensorWrapper dequant_scale_key_wrapper = make_wrapper(dequant_scale_key_tmp, dequant_scale_key_dtype);
auto dequant_scale_value_tmp = dequant_scale_value.has_value() ? dequant_scale_value.value() : null_tensor;
TensorWrapper dequant_scale_value_wrapper = make_wrapper(dequant_scale_value_tmp, dequant_scale_value_dtype);
uint64_t workspace_size = 0;
if (c10_npu::GetSocVersion() != c10_npu::SocVersion::Ascend950) {
workspace_size = EXEC_GET_MAX_WORKSPACE_CMD(aclnnFusedInferAttentionScoreV4, query_wrapper, keyTensors_wrapper, valueTensors_wrapper, pse_shift, atten_mask, actual_seq_qlen, actual_seq_kvlen, dequant_scale1, quant_scale1, dequant_scale2,
quant_scale_out, quant_offset_out, antiquant_scale, antiquant_offset, block_table, query_padding_size, kv_padding_size, dequant_scale_key_wrapper, dequant_offset_key, dequant_scale_value_wrapper,
dequant_offset_value, key_shared_prefix, value_shared_prefix, default_actual_shared_prefix_len, query_rope_wrapper, key_rope_wrapper, dequant_scale_key_rope, dequant_scale_query_wrapper, learnable_sink, num_query_heads, softmax_scale, pre_tokens, next_tokens, input_layout_ptr,
num_key_value_heads, sparse_mode, inner_precise, block_size, antiquant_mode, return_softmax_lse, key_quant_mode, value_quant_mode, query_quant_mode, outTensor_wrapper, softmax_lse);
} else {
workspace_size = EXEC_GET_MAX_WORKSPACE_CMD(aclnnFusedInferAttentionScoreV5, query_wrapper, keyTensors_wrapper, valueTensors_wrapper, pse_shift, atten_mask, actual_seq_qlen, actual_seq_kvlen, dequant_scale1, quant_scale_p, dequant_scale2,
quant_scale_out, quant_offset_out, antiquant_scale, antiquant_offset, block_table, query_padding_size, kv_padding_size, dequant_scale_key_wrapper, dequant_offset_key, dequant_scale_value_wrapper,
dequant_offset_value, key_shared_prefix, value_shared_prefix, default_actual_shared_prefix_len, query_rope_wrapper, key_rope_wrapper, dequant_scale_key_rope, dequant_scale_query_wrapper, learnable_sink, default_q_start_idx, default_kv_start_idx, num_query_heads, softmax_scale, pre_tokens, next_tokens, input_layout_ptr,
num_key_value_heads, sparse_mode, inner_precise, block_size, antiquant_mode, return_softmax_lse, key_quant_mode, value_quant_mode, query_quant_mode, default_pse_type_value, outTensor_wrapper, softmax_lse);
}
// 1-D tensor of workspace_size elements in query's dtype (see NOTE above on units).
at::Tensor workspace_tensor = npu_preparation::apply_tensor_without_format({workspace_size}, query.options().dtype(query.dtype()));
return workspace_tensor;
}
/**
 * Allocate (attention_out, softmax_lse) for npu_fused_infer_attention_score_v2 without
 * launching the kernel. Thin adapter over construct_fia_output_tensor_v2 that converts
 * the string_view layout into the std::string it expects.
 */
std::tuple<at::Tensor, at::Tensor> _npu_fused_infer_attention_score_v2_infer_output(
    const at::Tensor &query,
    const at::Tensor &value,
    c10::optional<int64_t> query_dtype,
    c10::optional<int64_t> value_dtype,
    c10::string_view input_layout,
    const c10::optional<at::Tensor> &quant_scale_out,
    const c10::optional<at::Tensor> &block_table,
    int64_t num_query_heads,
    int64_t num_key_value_heads,
    bool return_softmax_lse,
    const c10::optional<at::Tensor> &query_rope,
    c10::optional<int64_t> out_dtype)
{
    return op_api::construct_fia_output_tensor_v2(
        query, value, query_dtype, value_dtype, std::string(input_layout),
        quant_scale_out, block_table, num_query_heads, num_key_value_heads,
        return_softmax_lse, query_rope, out_dtype);
}
#endif
}