* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file flash_attn.cpp
* \brief FlashAttn kernel main entry (4 fields)
*/
#include "kernel_operator.h"
#include "arch35/flash_attn_entry_regbase.h"
#include "arch35/flash_attn_template_tiling_key.h"
#include "arch35/flash_attn_tiling_data.h"
using namespace AscendC;
/**
 * FlashAttn kernel entry point.
 *
 * All four template parameters are forwarded unchanged to flash_attn_kernel_run:
 *   inOutLayoutType - query/output layout selector (compared against InOutLayoutType_TND below)
 *   KvLayoutType    - key/value layout selector (compared against KvLayoutType_NO_PA below)
 *   hasAttenMask    - presumably whether an attention mask is applied (only forwarded here)
 *   config          - additional kernel configuration value (only forwarded here)
 *
 * Every tensor/buffer argument is a raw global-memory byte pointer. The element type used
 * for both data-type template arguments of flash_attn_kernel_run is chosen at build time
 * from the ORIG_DTYPE_Q macro: bfloat16_t for DT_BF16, half for DT_FLOAT16.
 *
 * NOTE(review): if ORIG_DTYPE_Q matches neither DT_BF16 nor DT_FLOAT16, no compute call is
 * compiled in and the kernel does nothing — confirm the build never produces that variant.
 * NOTE(review): flash_attn_kernel_run receives the pointers in a different order than this
 * signature (attnMask ahead of the sequence-length buffers, blockTable and sinks after them,
 * metadata last); keep both lists in sync when editing either one.
 */
template <uint8_t inOutLayoutType, uint8_t KvLayoutType, bool hasAttenMask, uint8_t config>
__global__ __aicore__ void
flash_attn(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *value, __gm__ uint8_t *blockTable,
           __gm__ uint8_t *cuSeqLensQ, __gm__ uint8_t *cuSeqLensKv, __gm__ uint8_t *sequsedQ,
           __gm__ uint8_t *sequsedKv, __gm__ uint8_t *sinks, __gm__ uint8_t *attnMask, __gm__ uint8_t *metadata,
           __gm__ uint8_t *attnOut, __gm__ uint8_t *softmaxLse, __gm__ uint8_t *workspace, __gm__ uint8_t *tiling)
{
    // Register FlashAttnTilingData as the default tiling structure for this kernel.
    REGISTER_TILING_DEFAULT(optiling::FlashAttnTilingData);
    // Portion of the workspace returned by GetUserWorkspace; handed to the compute routine.
    __gm__ uint8_t *userWorkspace = GetUserWorkspace(workspace);
    // Mixed AIC/AIV kernel type (presumably one cube core per two vector cores — confirm in CANN docs).
    KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2);

    // Only the TND in/out layout combined with the NO_PA key/value layout issues dc_preload
    // on the four sequence-length buffers (presumably a data-cache warm-up hint).
    constexpr bool kPreloadSeqLens =
        (inOutLayoutType == InOutLayoutType_TND) && (KvLayoutType == KvLayoutType_NO_PA);
    if constexpr (kPreloadSeqLens) {
        dc_preload(reinterpret_cast<__gm__ uint64_t *>(cuSeqLensQ), 0);
        dc_preload(reinterpret_cast<__gm__ uint64_t *>(cuSeqLensKv), 0);
        dc_preload(reinterpret_cast<__gm__ uint64_t *>(sequsedQ), 0);
        dc_preload(reinterpret_cast<__gm__ uint64_t *>(sequsedKv), 0);
    }

    // Pick the compute element type from the build-time query dtype.
#if (ORIG_DTYPE_Q == DT_BF16)
    using QkvElemType = bfloat16_t;
#elif (ORIG_DTYPE_Q == DT_FLOAT16)
    using QkvElemType = half;
#endif

#if (ORIG_DTYPE_Q == DT_BF16) || (ORIG_DTYPE_Q == DT_FLOAT16)
    flash_attn_kernel_run<QkvElemType, QkvElemType, inOutLayoutType, KvLayoutType, hasAttenMask, config>(
        query, key, value, attnMask, cuSeqLensQ, cuSeqLensKv, sequsedQ, sequsedKv, blockTable, sinks, attnOut,
        softmaxLse, userWorkspace, tiling, metadata);
#endif
}