* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file aclnn_flash_attn.h
* \brief
*/
#ifndef OP_API_INC_LEVEL2_ACLNN_FLASH_ATTN_H_
#define OP_API_INC_LEVEL2_ACLNN_FLASH_ATTN_H_
#include "aclnn/aclnn_base.h"
#ifdef __cplusplus
extern "C" {
#endif
* @brief aclnnFlashAttn的第一段接口,根据具体的计算流程,计算workspace大小。
* @domain aclnn_ops_train_infer
*
* @param q [IN] query tensor。数据类型FLOAT16/BFLOAT16,ND格式。
* layout由layoutQ决定,支持BSND/BNSD/TND。
* BSND: shape=(B, S_q, N_q, D)
* BNSD: shape=(B, N_q, S_q, D)
* TND: shape=(T_q, N_q, D),T_q为总token数
* @param k [IN] key tensor。数据类型与q一致,layout由layoutKv决定。
* BSND: shape=(B, S_kv, N_kv, D)
* TND: shape=(T_kv, N_kv, D)
* PA_ND: shape=(NumBlocks, BlockSize, N_kv, D),分页注意力格式
* PA_Nz: shape=(NumBlocks, N_kv, BlockSize/16, D, 16),分页注意力NZ格式
* @param v [IN] value tensor。数据类型与q一致,layout由layoutKv决定,shape同k。
* @param blockTableOptional [IN] 可选。分页KV缓存的块映射表,数据类型INT32。
* shape=(B, MaxBlocksPerSeq),与layoutKv=PA_ND/PA_Nz配合使用。
* @param cuSeqlensQOptional [IN] 可选。query的累积序列长度(前缀和),数据类型INT32。
* shape=(B+1,),cu_seqlens_q[i+1]-cu_seqlens_q[i]为第i个sample的q序列长度。
* TND layout变长场景时有效。与seqused_q互斥。
* @param cuSeqlensKvOptional [IN] 可选。kv的累积序列长度(前缀和),数据类型INT32。
* shape=(B+1,),与cuSeqlensQOptional配合使用。
* @param sequsedQOptional [IN] 可选。各batch中query的实际序列长度,数据类型INT32。
* shape=(B,),padded batch模式下有效。与cuSeqlensQOptional互斥。
* @param sequsedKvOptional [IN] 可选。各batch中kv的实际序列长度,数据类型INT32。
* shape=(B,),与sequsedQOptional配合使用。
* @param sinksOptional [IN] 可选。可学习的sink注意力权重,数据类型FLOAT32。
* @param attnMaskOptional [IN] 可选。attnmask的参数,数据类型INT8。
* @param metadataOptional [IN] 可选。预计算的tiling切分方案,数据类型INT32,由上游算子传入。
* 当该输入不为nullptr时,tiling侧跳过切分方案计算,直接使用该元数据。
* @param softmaxScale [IN] ATTR可选。softmax缩放系数,对应老接口的scaleValue。
* 数据类型DOUBLE。默认值0.0表示使用1/sqrt(D)。
* @param maskMode [IN] ATTR可选。掩码模式,对应老接口的sparseMode。数据类型INT。
* 0: 不使用掩码;1: 因果掩码(下三角);2: 非因果掩码(上三角);
* 3: prefix/band掩码;4: 滑动窗口掩码(使用winLeft/winRight)。
* @param winLeft [IN] ATTR可选。左侧注意力窗口大小(maskMode=4时的winLefts)。数据类型INT。
* @param winRight [IN] ATTR可选。右侧注意力窗口大小(maskMode=4时的winRight)。数据类型INT。
* @param layoutQ [IN] ATTR可选。query的数据布局,支持"BSND"/"BNSD"/"TND"(大小写不敏感)。
* @param layoutKv [IN] ATTR可选。key/value的数据布局,支持"BSND"/"TND"/"PA_ND"/"PA_Nz"。
* @param layoutOut [IN] ATTR可选。输出的数据布局,支持"BSND"/"BNSD"/"TND"(大小写不敏感)。
* @param returnSoftmaxLse [IN] ATTR可选。是否输出softmax_lse。INT类型,1表示输出,0表示不输出。
* 训练正向传播时置1,推理时置0。
* @param attnOut [OUT] 必选输出。attention计算结果,数据类型与q一致,layout由layoutOut决定。
* @param softmaxLseOptional [OUT] 可选输出。softmax的log-sum-exp值,FLOAT32类型。
* returnSoftmaxLse=1时有效,shape取决于layout和序列长度:
* TND: (T_q, N_q) 或 (N_q, T_q);其他layout: (B, N_q, S_q)。
* @param workspaceSize [OUT] workspace大小(字节数)。
* @param executor [OUT] op执行器句柄,供第二段接口使用。
* @return aclnnStatus 执行状态。ACLNN_SUCCESS表示成功。
*/
aclnnStatus aclnnFlashAttnGetWorkspaceSize(
const aclTensor *q, const aclTensor *k, const aclTensor *v, const aclTensor *blockTableOptional,
const aclTensor *cuSeqlensQOptional, const aclTensor *cuSeqlensKvOptional, const aclTensor *sequsedQOptional,
const aclTensor *sequsedKvOptional, const aclTensor *sinksOptional, const aclTensor *attnMaskOptional,
const aclTensor *metadataOptional, double softmaxScale, int64_t maskMode, int64_t winLeft, int64_t winRight,
int64_t maxSeqlenQ, int64_t maxSeqlenKV, const char *layoutQ, const char *layoutKv, const char *layoutOut,
int64_t returnSoftmaxLse, const aclTensor *attnOut, const aclTensor *softmaxLseOptional,
uint64_t *workspaceSize, aclOpExecutor **executor);
* @brief aclnnFlashAttn的第二段接口,用于执行计算。
* @param workspace [IN] 由第一段接口计算得到的workspace设备内存指针。
* @param workspaceSize [IN] workspace大小(字节数)。
* @param executor [IN] 第一段接口输出的op执行器句柄。
* @param stream [IN] 用于执行计算的acl stream。
* @return aclnnStatus 执行状态。
*/
aclnnStatus aclnnFlashAttn(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, const aclrtStream stream);
#ifdef __cplusplus
}
#endif
#endif