* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef OP_API_INC_LEVEL0_FLASH_ATTN_H_
#define OP_API_INC_LEVEL0_FLASH_ATTN_H_
#include <array>
#include "opdev/op_executor.h"
namespace l0op {
* @brief FlashAttn level-0 operator。
* 封装FlashAttn算子的底层调度,完成InferShape与Kernel Launch注册。
* 该接口为内部接口,仅供aclnn层调用。
*
* @param q query tensor
* @param k key tensor
* @param v value tensor
* @param blockTableOptional 分页KV缓存块映射表(可选,INT32)
* @param cuSeqlensQOptional query累积序列长度tensor(可选,INT32)
* @param cuSeqlensKvOptional kv累积序列长度tensor(可选,INT32)
* @param sequsedQOptional query各batch实际序列长度tensor(可选,INT32)
* @param sequsedKvOptional kv各batch实际序列长度tensor(可选,INT32)
* @param sinksOptional 可学习sink权重(可选,FLOAT32)
* @param attnMaskOptional attnMask参数(可选,INT8)
* @param metadataOptional 预计算tiling元数据(可选,INT32)
* @param softmaxScale softmax缩放系数(float)
* @param maskMode 掩码模式(int64_t)
* @param winLeft 左窗口大小(int64_t)
* @param winRight 右窗口大小(int64_t)
* @param maxSeqlenQ query最大序列长度(int32_t)
* @param maxSeqlenKV kv最大序列长度(int32_t)
* @param layoutQ query布局字符串
* @param layoutKv kv布局字符串
* @param layoutOut 输出布局字符串
* @param returnSoftmaxLse 是否输出softmax_lse(int64_t)
* @param executor op执行器
* @return std::array<const aclTensor*, 2> [attnOut, softmaxLse]
* 任意元素为nullptr表示对应输出的InferShape或Launch失败。
*/
const std::array<const aclTensor *, 2> FlashAttn(
const aclTensor *q,
const aclTensor *k,
const aclTensor *v,
const aclTensor *blockTableOptional,
const aclTensor *cuSeqlensQOptional,
const aclTensor *cuSeqlensKvOptional,
const aclTensor *sequsedQOptional,
const aclTensor *sequsedKvOptional,
const aclTensor *sinksOptional,
const aclTensor *attnMaskOptional,
const aclTensor *metadataOptional,
double softmaxScale,
int32_t maskMode,
int32_t winLeft,
int32_t winRight,
int32_t maxSeqlenQ,
int32_t maxSeqlenKV,
const char *layoutQ,
const char *layoutKv,
const char *layoutOut,
int32_t returnSoftmaxLse,
aclOpExecutor *executor);
}
#endif