* Copyright (c) Huawei Technologies Co., Ltd. 2026-2026. All rights reserved.
* MindIE is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include <torch/library.h>
#include "torch_npu/csrc/framework/utils/OpAdapter.h"
#include "torch_npu/csrc/core/npu/NPUFormat.h"
#include "pytorch_npu_helper.h"
#include "quant_flash_attn.h"
using namespace at;
// Name of the aclnn operator launched by quant_flash_attn_impl_npu; passed as the template argument of EXEC_NPU_CMD.
constexpr std::string_view QUANT_FLASH_ATTN_NAME = "aclnnQuantFlashAttn";
/// Returns true when `tensor` is stored with scalar type torch.uint8 (at::ScalarType::Byte).
inline bool IsByteTensor(const at::Tensor &tensor) {
    const at::ScalarType storedType = tensor.scalar_type();
    return storedType == at::ScalarType::Byte;
}
/**
 * Validates the dtype annotation supplied for a q/k/v tensor.
 *
 * Only tensors stored as torch.uint8 are checked. For them `dtype` must be present and must equal
 * CANN_DTYPE_FLOAT4_E2M1 or CANN_DTYPE_HIFLOAT8. Tensors of any other storage type pass unchecked.
 *
 * @param name   short tensor name ("q", "k" or "v") used to build the error message.
 * @param tensor tensor whose storage scalar type is inspected.
 * @param dtype  caller-supplied dtype code for `tensor`.
 * Fails a TORCH_CHECK when a uint8 tensor has no dtype or an unsupported one.
 */
inline void CheckQkvDtype(const char *name, const at::Tensor &tensor, const c10::optional<int64_t> &dtype) {
    if (IsByteTensor(tensor)) {
        TORCH_CHECK(dtype.has_value(), name, "_dtype can not be None when ", name, " is torch.uint8.");
        // Safe to dereference: the check above guarantees a value is present.
        const int64_t dtypeCode = *dtype;
        const bool isSupported = (dtypeCode == CANN_DTYPE_FLOAT4_E2M1) || (dtypeCode == CANN_DTYPE_HIFLOAT8);
        TORCH_CHECK(isSupported, name, "_dtype must be torch_npu.float4_e2m1fn_x2 or torch_npu.hifloat8.");
    }
}
/**
 * Validates the dtype annotation supplied for a q/k/v descale tensor.
 *
 * Only tensors stored as torch.uint8 are checked. For them `dtype` must be present and must equal
 * CANN_DTYPE_FLOAT8_E8M0. Tensors of any other storage type pass unchecked.
 *
 * @param name   tensor name ("q_descale", "k_descale" or "v_descale") used in the error message.
 * @param tensor descale tensor whose storage scalar type is inspected.
 * @param dtype  caller-supplied dtype code for `tensor`.
 * Fails a TORCH_CHECK when a uint8 descale has no dtype or an unsupported one.
 */
inline void CheckDescaleDtype(const char *name, const at::Tensor &tensor, const c10::optional<int64_t> &dtype) {
    if (IsByteTensor(tensor)) {
        TORCH_CHECK(dtype.has_value(), name, "_dtype can not be None when ", name, " is torch.uint8.");
        // Safe to dereference: the check above guarantees a value is present.
        const bool isE8M0 = *dtype == CANN_DTYPE_FLOAT8_E8M0;
        TORCH_CHECK(isE8M0, name, "_dtype must be torch_npu.float8_e8m0fnu.");
    }
}
/**
 * Computes the shape of the attention output tensor.
 *
 * T/N (for "TND") or B/S/N (for "BSND"; any other layout_q is read as BNSD) are taken from `query`
 * according to `layout_q`. The head dimension D is taken from the last axis of `value`. When `q_dtype`
 * is CANN_DTYPE_FLOAT4_E2M1, D is doubled (presumably because two fp4 values are packed into each
 * uint8 element; confirm against the kernel specification).
 *
 * The result is ordered by `layout_out`: "TND" -> {T, N, D}, "BNSD" -> {B, N, S, D}, anything else
 * -> {B, S, N, D}. T is only populated for a TND query, and B/S only for a 4-D query layout. A
 * layout_q/layout_out combination that changes rank therefore yields zero-sized dimensions.
 */
inline at::SmallVector<int64_t, 8> GetQfaAttentionOutSize(const at::Tensor &query, const at::Tensor &value,
    const c10::optional<int64_t> &q_dtype, const std::string &layout_q, const std::string &layout_out) {
    int64_t tSize = 0;
    int64_t nSize = 0;
    int64_t sSize = 0;
    int64_t bSize = 0;
    if (layout_q == "TND") {
        tSize = query.size(0);
        nSize = query.size(1);
    } else if (layout_q == "BSND") {
        bSize = query.size(0);
        sSize = query.size(1);
        nSize = query.size(2);
    } else {
        bSize = query.size(0);
        nSize = query.size(1);
        sSize = query.size(2);
    }
    // `value` follows layout_kv, not layout_q, so its rank need not match the query's (e.g. a 4-D
    // value next to a 3-D TND query). D is the innermost axis in the TND/BSND/BNSD layouts, so read
    // it from the end rather than at an index derived from layout_q.
    const int64_t qDtypeRatio = (q_dtype.has_value() && q_dtype.value() == CANN_DTYPE_FLOAT4_E2M1) ? 2 : 1;
    const int64_t dSize = qDtypeRatio * value.size(-1);
    if (layout_out == "TND") {
        return {tSize, nSize, dSize};
    }
    if (layout_out == "BNSD") {
        return {bSize, nSize, sSize, dSize};
    }
    return {bSize, sSize, nSize, dSize};
}
/**
 * Computes the shape of the softmax LSE output tensor.
 *
 * Returns {0} when `return_softmax_lse` is 0 (LSE not requested). Otherwise:
 *   - a 3-D query, read as (T, N, D), gives {N, T};
 *   - layout_q == "BSND" gives {B, N, S};
 *   - any other layout (read as BNSD) gives {B, N, S}.
 */
inline at::SmallVector<int64_t, 4> GetQfaSoftmaxLseSize(
    const at::Tensor &query, const std::string &layout_q, int64_t return_softmax_lse) {
    const bool lseRequested = return_softmax_lse != 0;
    if (!lseRequested) {
        return {0};
    }
    const bool isThreeDim = query.dim() == 3;
    if (isThreeDim) {
        return {query.size(1), query.size(0)};
    }
    const int64_t batch = query.size(0);
    if (layout_q == "BSND") {
        return {batch, query.size(2), query.size(1)};
    }
    return {batch, query.size(1), query.size(2)};
}
/**
 * NPU implementation of quant_flash_attn: validates dtype annotations, allocates the outputs and
 * launches the aclnnQuantFlashAttn operator.
 *
 * Tensors stored as torch.uint8 must carry their real element type in the matching *_dtype argument:
 * q/k/v accept float4_e2m1fn_x2 or hifloat8, and the descales accept float8_e8m0fnu. Every
 * q/k/v/descale tensor is forwarded through MakeTensorWrapper with its dtype. Absent optional tensors
 * are passed to the operator as undefined at::Tensor. The layout strings are passed as C strings that
 * stay valid for the duration of the call.
 *
 * @return (attn_out, softmax_lse). attn_out is bfloat16 with the shape from GetQfaAttentionOutSize.
 *         softmax_lse is float32 with the shape from GetQfaSoftmaxLseSize ({0} when
 *         return_softmax_lse == 0). Both are allocated with query's NPU storage format.
 */
std::tuple<at::Tensor, at::Tensor> quant_flash_attn_impl_npu(const at::Tensor &query, const at::Tensor &key,
    const at::Tensor &value, const at::Tensor &q_descale, const at::Tensor &k_descale, const at::Tensor &v_descale,
    int64_t q_quant_mode, int64_t k_quant_mode, int64_t v_quant_mode, const c10::optional<at::Tensor> &block_table,
    const c10::optional<at::Tensor> &cu_seqlens_q, const c10::optional<at::Tensor> &cu_seqlens_kv,
    const c10::optional<at::Tensor> &seqused_q, const c10::optional<at::Tensor> &seqused_kv,
    const c10::optional<at::Tensor> &sinks, const c10::optional<at::Tensor> &attn_mask,
    const c10::optional<at::Tensor> &metadata, const c10::optional<int64_t> &q_dtype,
    const c10::optional<int64_t> &k_dtype, const c10::optional<int64_t> &v_dtype,
    const c10::optional<int64_t> &q_descale_dtype, const c10::optional<int64_t> &k_descale_dtype,
    const c10::optional<int64_t> &v_descale_dtype, int64_t quant_block_size_qs, int64_t quant_block_size_ks,
    int64_t quant_block_size_vs, double softmax_scale, int64_t mask_mode, int64_t win_left, int64_t win_right,
    int64_t max_seqlen_q, int64_t max_seqlen_kv, std::string layout_q, std::string layout_kv, std::string layout_out,
    int64_t softmax_precision, int64_t return_softmax_lse) {
    CheckQkvDtype("q", query, q_dtype);
    CheckQkvDtype("k", key, k_dtype);
    CheckQkvDtype("v", value, v_dtype);
    CheckDescaleDtype("q_descale", q_descale, q_descale_dtype);
    CheckDescaleDtype("k_descale", k_descale, k_descale_dtype);
    CheckDescaleDtype("v_descale", v_descale, v_descale_dtype);

    // Allocate the outputs on query's own device. A bare "npu" device carries no index, which PyTorch
    // resolves to the *current* device. That can differ from query's device when several cards are
    // used in one process. Assumes query is an NPU tensor, as required by the aclnn launch below.
    auto outputOptions = query.options();
    auto attentionOutSize = GetQfaAttentionOutSize(query, value, q_dtype, layout_q, layout_out);
    auto softmaxLseSize = GetQfaSoftmaxLseSize(query, layout_q, return_softmax_lse);
    at::Tensor attn_out = at_npu::native::empty_with_format(
        attentionOutSize, outputOptions.dtype(at::kBFloat16), at_npu::native::get_npu_format(query));
    at::Tensor softmax_lse = at_npu::native::empty_with_format(
        softmaxLseSize, outputOptions.dtype(at::kFloat), at_npu::native::get_npu_format(query));

    // The std::string parameters own these buffers until this function returns.
    const char *layoutQPtr = layout_q.c_str();
    const char *layoutKvPtr = layout_kv.c_str();
    const char *layoutOutPtr = layout_out.c_str();

    // Absent optional inputs are forwarded as undefined tensors.
    auto blockTableTensor = block_table.value_or(at::Tensor());
    auto cuSeqlensQTensor = cu_seqlens_q.value_or(at::Tensor());
    auto cuSeqlensKvTensor = cu_seqlens_kv.value_or(at::Tensor());
    auto sequsedQTensor = seqused_q.value_or(at::Tensor());
    auto sequsedKvTensor = seqused_kv.value_or(at::Tensor());
    auto sinksTensor = sinks.value_or(at::Tensor());
    auto attnMaskTensor = attn_mask.value_or(at::Tensor());
    auto metadataTensor = metadata.value_or(at::Tensor());

    // Pair each quantized tensor with its user-supplied dtype code for the operator.
    auto queryWrapper = MakeTensorWrapper(query, q_dtype);
    auto keyWrapper = MakeTensorWrapper(key, k_dtype);
    auto valueWrapper = MakeTensorWrapper(value, v_dtype);
    auto qDescaleWrapper = MakeTensorWrapper(q_descale, q_descale_dtype);
    auto kDescaleWrapper = MakeTensorWrapper(k_descale, k_descale_dtype);
    auto vDescaleWrapper = MakeTensorWrapper(v_descale, v_descale_dtype);

    EXEC_NPU_CMD<QUANT_FLASH_ATTN_NAME>(queryWrapper, keyWrapper, valueWrapper, qDescaleWrapper, kDescaleWrapper,
        vDescaleWrapper, blockTableTensor, cuSeqlensQTensor, cuSeqlensKvTensor, sequsedQTensor, sequsedKvTensor,
        sinksTensor, attnMaskTensor, metadataTensor, q_quant_mode, k_quant_mode, v_quant_mode, quant_block_size_qs,
        quant_block_size_ks, quant_block_size_vs, softmax_scale, mask_mode, win_left, win_right, max_seqlen_q,
        max_seqlen_kv, layoutQPtr, layoutKvPtr, layoutOutPtr, softmax_precision, return_softmax_lse, attn_out,
        softmax_lse);
    return std::make_tuple(attn_out, softmax_lse);
}