/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
/*!
 * \file flash_attn_tiling.cpp
 * \brief Main entry point for FlashAttn tiling
 */
#include <cmath>
#include <register/op_impl_registry.h>
#include "log/log.h"
#include "op_host/tiling_templates_registry.h"
#include "flash_attn_tiling.h"
#include "flash_attn_tiling_common.h"
#include "../op_kernel/arch35/flash_attn_template_tiling_key.h"
#include "flash_attn_tiling_info_parser.h"
#include "checkers/fa_checker.h"
#include "../../common/op_host/fia_tiling_templates_registry.h"
using namespace ge;
using namespace AscendC;
using namespace Ops::Transformer::OpTiling;
namespace optiling {
using namespace flash_attn;
/**
 * @brief Reports whether the operator input should be treated as empty.
 *
 * Currently a placeholder: the context is not inspected and the answer is always false.
 *
 * @param context Tiling context (unused).
 * @return Always false.
 */
static bool IsEmptyInput(gert::TilingContext *context)
{
    // The parameter is intentionally unused; the cast only silences unused-parameter warnings.
    static_cast<void>(context);
    constexpr bool kInputIsEmpty = false;
    return kInputIsEmpty;
}
/**
 * @brief Tiling entry point for the FlashAttn operator.
 *
 * Flow:
 *   1. Reject a null context or missing platform info.
 *   2. On NpuArch::DAV_3510, return success immediately when IsEmptyInput() reports an empty input.
 *   3. Parse the context into a FaTilingInfo via FaInfoParser.
 *   4. Validate the parsed info with FAChecker.
 *   5. Dispatch to the registered tiling implementation through FiaTilingRegistry.
 *
 * @param context Tiling context supplied by the caller.
 * @return ge::GRAPH_SUCCESS on the empty-input shortcut; ge::GRAPH_FAILED when the context or
 *         platform info is null, or when parsing or checking fails; otherwise whatever
 *         FiaTilingRegistry::DoTilingImpl returns.
 */
ASCENDC_EXTERN_C ge::graphStatus TilingFlashAttn(gert::TilingContext *context)
{
    // context is dereferenced right below (GetPlatformInfo), so it must be validated first.
    // The op name is passed as a literal because the context itself cannot be used for logging here.
    OP_CHECK_IF(context == nullptr, OP_LOGE("FlashAttn", "context is null"), return ge::GRAPH_FAILED);

    // Routine entry trace: kept at debug level so normal tiling calls do not emit warnings.
    OP_LOGD(context, "FlashAttn TilingFlashAttn start.");

    auto platformInfoPtr = context->GetPlatformInfo();
    OP_CHECK_IF(platformInfoPtr == nullptr, OP_LOGE(context, "platformInfoPtr is null"), return ge::GRAPH_FAILED);

    auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
    // Empty-input shortcut is only evaluated on DAV_3510; && keeps IsEmptyInput from running elsewhere.
    if (ascendcPlatform.GetCurNpuArch() == NpuArch::DAV_3510 && IsEmptyInput(context)) {
        return ge::GRAPH_SUCCESS;
    }

    // Collect everything the tiling implementation needs from the context.
    FaTilingInfo faInfo;
    FaInfoParser faInfoParser(context);
    if (faInfoParser.Parse(faInfo) != ge::GRAPH_SUCCESS) {
        return ge::GRAPH_FAILED;
    }

    // Validate the parsed information before any tiling is attempted.
    // NOTE(review): any result of Init() is not inspected here — confirm it cannot fail.
    FAChecker faChecker;
    faChecker.Init(faInfo);
    if (faChecker.Process(faInfo) != ge::GRAPH_SUCCESS) {
        return ge::GRAPH_FAILED;
    }

    return FiaTilingRegistry::GetInstance().DoTilingImpl(context, &faInfo);
}
/**
 * @brief Tiling-parse hook for FlashAttn: fills FlashAttnCompileInfo from the platform description.
 *
 * Records the AIV/AIC core counts, SoC version, NPU architecture and the UB, L1, L0_C and L2
 * memory sizes reported by platform_ascendc::PlatformAscendC.
 *
 * @param context Tiling-parse context providing the platform info and the compile-info storage.
 * @return ge::GRAPH_FAILED when the context, the platform info or the compile info is null;
 *         ge::GRAPH_SUCCESS otherwise.
 */
ASCENDC_EXTERN_C ge::graphStatus TilingPrepareForFlashAttn(gert::TilingParseContext *context)
{
    // context is dereferenced right below, so it must be validated first. The op name is passed
    // as a literal because the context itself cannot be used for logging here.
    OP_CHECK_IF(context == nullptr, OP_LOGE("FlashAttn", "context is null"), return ge::GRAPH_FAILED);

    auto platformInfoPtr = context->GetPlatformInfo();
    OP_CHECK_IF(platformInfoPtr == nullptr, OP_LOGE(context, "platformInfoPtr is null"), return ge::GRAPH_FAILED);

    auto compileInfoPtr = context->GetCompiledInfo<FlashAttnCompileInfo>();
    OP_CHECK_IF(compileInfoPtr == nullptr, OP_LOGE(context, "compileInfoPtr is null"), return ge::GRAPH_FAILED);

    auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);

    // Core counts and chip identification.
    compileInfoPtr->aivNum = ascendcPlatform.GetCoreNumAiv();
    compileInfoPtr->aicNum = ascendcPlatform.GetCoreNumAic();
    compileInfoPtr->socVersion = ascendcPlatform.GetSocVersion();
    compileInfoPtr->npuArch = ascendcPlatform.GetCurNpuArch();

    // Memory sizes are written through the output argument of GetCoreMemSize.
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, compileInfoPtr->ubSize);
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::L1, compileInfoPtr->l1Size);
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::L0_C, compileInfoPtr->l0cSize);
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::L2, compileInfoPtr->l2CacheSize);

    return ge::GRAPH_SUCCESS;
}
// Registers the host-side tiling hooks for the FlashAttn operator: TilingFlashAttn as the tiling
// function and TilingPrepareForFlashAttn as the tiling-parse function, with FlashAttnCompileInfo
// as the compile-info type.
IMPL_OP_OPTILING(FlashAttn)
.Tiling(TilingFlashAttn)
.TilingParse<FlashAttnCompileInfo>(TilingPrepareForFlashAttn);
}