/**
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
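/*!
* \file kernel_utils.h
* \brief Shared constants, parameter structs and loop/tiling helper functions for the FAI kernel example.
*/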
#ifndef CATLASS_EXAMPLES_FAI_KERNEL_UTILS_H
#define CATLASS_EXAMPLES_FAI_KERNEL_UTILS_H
#include "catlass/catlass.hpp"
using namespace Catlass;
using namespace AscendC;
// Cube-to-vector ratio. Presumably two vector sub-blocks are paired with each cube core, which is
// consistent with ComputeParamQSeq splitting a Q tile between subBlockIdx 0 and 1 (TODO confirm).
constexpr uint32_t CV_RATIO{2U};
// Named literal 2 (presumably to avoid a bare magic number at call sites).
constexpr uint32_t NUM2{2U};
// Number of kernel tasks kept in flight; presumably related to RunInfo::taskIdMod3 (TODO confirm).
constexpr uint32_t KERNEL_TASK_NUM{3U};
/**
 * Returns the smaller of `a` and `b`.
 * Only `operator>` is used. When the values compare equal, or are unordered (e.g. a NaN is
 * involved), `a` is returned.
 */
template <typename T>
CATLASS_DEVICE T Min(T a, T b) {
    if (a > b) {
        return b;
    }
    return a;
}
/**
 * Bundle of global-memory addresses handed to the FAI kernel: inputs, output and tiling data.
 * Field roles in the comments below are inferred from their names.
 */
struct FAIKernelParams {
    GM_ADDR q;               // query input (presumably)
    GM_ADDR k;               // key input (presumably)
    GM_ADDR v;               // value input (presumably)
    GM_ADDR mask;            // attention mask (presumably)
    GM_ADDR blockTables;     // presumably the paged-KV block table (cf. ConstInfo::blockSize) - TODO confirm
    GM_ADDR actualQSeqlen;   // presumably per-batch actual Q sequence lengths
    GM_ADDR actualKvSeqlen;  // presumably per-batch actual KV sequence lengths
    GM_ADDR o;               // output (presumably)
    GM_ADDR tiling;          // tiling data

    // Null-initialize every address so that a default-constructed object never holds
    // indeterminate pointers. The previous empty body left all nine members uninitialized,
    // and reading them before assignment is undefined behavior.
    CATLASS_DEVICE
    FAIKernelParams()
        : q(nullptr),
          k(nullptr),
          v(nullptr),
          mask(nullptr),
          blockTables(nullptr),
          actualQSeqlen(nullptr),
          actualKvSeqlen(nullptr),
          o(nullptr),
          tiling(nullptr)
    {
    }

    // Member-wise constructor; each argument is stored in the member of the same name.
    CATLASS_DEVICE
    FAIKernelParams(GM_ADDR q_,
                    GM_ADDR k_,
                    GM_ADDR v_,
                    GM_ADDR mask_,
                    GM_ADDR blockTables_,
                    GM_ADDR actualQSeqlen_,
                    GM_ADDR actualKvSeqlen_,
                    GM_ADDR o_,
                    GM_ADDR tiling_)
        : q(q_),
          k(k_),
          v(v_),
          mask(mask_),
          blockTables(blockTables_),
          actualQSeqlen(actualQSeqlen_),
          actualKvSeqlen(actualKvSeqlen_),
          o(o_),
          tiling(tiling_)
    {
    }
};
// Synchronization constants. The meaning of SYNC_MODE's value is not visible in this file.
constexpr uint64_t SYNC_MODE{4U};
// Flag / event ID tables. All eleven IDs (0..10) are pairwise distinct across the tables.
// The C1/V1/C2/V2 prefixes presumably name cube stage 1, vector stage 1, cube stage 2 and
// vector stage 2, with one ID per buffer slot (TODO confirm against the kernel).
constexpr uint64_t SYNC_C1_V1_FLAG[]{0U, 1U};
constexpr uint64_t SYNC_V1_C2_FLAG[]{2U, 3U, 4U};
constexpr uint64_t SYNC_C2_V2_FLAG[]{5U, 6U};
constexpr uint64_t MM2_RES_INTRA_EVENT[]{7U, 8U};
constexpr uint64_t MM1_RES_INTRA_EVENT[]{9U, 10U};
/**
 * Coordinates of one cube-side work item. Meanings are inferred from the field names.
 */
struct CubeCoordInfo {
    uint32_t curBIdx;     // batch index (presumably)
    uint32_t qSeqCoord;   // position along the Q sequence axis (presumably)
    uint32_t kvSeqCoord;  // position along the KV sequence axis (presumably)
};
/**
 * Per-iteration loop parameters. Several fields are written by the Compute* helpers in this
 * header, as noted per group. Other meanings are inferred from names.
 */
struct RunParamStr {
    // Outer-loop position (batch, Q-sequence tile, KV head, group).
    int64_t batchOuterIdx;
    int64_t qSeqOuterAxisIdx;
    int64_t kvHeadsOuterIdx;
    int64_t groupIdx;
    // KV tile loop bounds, presumably half-open [start, end); written by ComputeKvSeqLoopInfo.
    int32_t kvSeqLoopStartIdx;
    int32_t kvSeqLoopEndIdx;
    // KV row bounds; written by ComputeKvSeqLoopInfo. The start defaults to 0.
    int64_t kvSeqAxisLineStartIdx{0};
    int64_t kvSeqAxisLineEndIdx;
    // Valid Q rows in the current tile and its two-way split; written by ComputeParamQSeq.
    uint32_t qSeqRealSize;
    uint32_t halfQSeqRealSize;
    uint32_t firstHalfQSeqRealSize;
    // Actual Q / KV sequence lengths; written by ComputeParamBatch.
    int64_t actualQSeqSize;
    int64_t actualKvSeqSize;
    // Number of Q tiles to iterate; written by ComputeQseqLoopInfo.
    int64_t qSeqLoopTimes;
};
/**
 * Per-task run state. The outer-loop indices and multi-core bookkeeping default to 0; all
 * other fields are left for the owner to fill. Meanings are inferred from names.
 */
struct RunInfo {
    // KV sequence range and loop bounds for this task.
    int64_t kvSeqAxisStartIdx;
    int64_t kvSeqAxisEndIdx;
    int64_t kvSeqLoopCount;
    int64_t kvSeqLoopStartIdx;
    int64_t kvSeqLoopLimit;
    // Outer-loop position.
    int64_t qSeqOuterAxisIdx{0};
    int64_t batchOuterIdx{0};
    int64_t kvHeadsOuterIdx{0};
    int64_t groupIdx{0};
    // Tile extents.
    int32_t qSeqRealSize;
    int32_t halfQSeqRealSize;
    int32_t firstHalfQSeqRealSize;
    int32_t kvSeqRealSize;
    // Task identity.
    int64_t taskId;
    int64_t multiCoreInnerIdx{0};
    // Actual sequence lengths.
    int64_t actualQSeqSize;
    int64_t actualKvSeqSize;
    // Cached residues, presumably used to pick ping-pong / triple-buffer slots (TODO confirm).
    uint8_t taskIdMod2;
    uint8_t taskIdMod3;
    uint8_t multiCoreIdxMod2{0};
    uint8_t multiCoreIdxMod3{0};
    // Offset into the block table (presumably paged KV).
    int64_t blockTableOffset;
};
/**
 * Kernel configuration values. Of the helpers in this header, ComputeParamBatch reads
 * qSeqlen / kvSeqlen and ComputeParamQSeq reads subBlockIdx. Other meanings are inferred
 * from names — confirm against the tiling code.
 *
 * NOTE(review): several fields look like duplicates of one another:
 *   - isActualLenDimsNull / isActualSeqLengthsNull
 *   - actualSeqLenSize / actualSeqLengthsSize
 *   - the matching KV variants
 * Confirm which ones are actually read.
 */
struct ConstInfo {
    // Tile sizes along the Q and KV sequence axes (presumably).
    uint32_t qSeqlenBase;
    uint32_t kvSeqlenBase;
    // Problem geometry.
    int64_t embed;
    int64_t groupSize;
    int64_t qHeads;
    int64_t kvHeads;
    int64_t qSeqlen;
    int64_t kvSeqlen;
    int64_t qSeqlenOuterSize;
    // Vector sub-block index of this core; ComputeParamQSeq treats 1 specially.
    uint8_t subBlockIdx;
    float scaleValue;
    // Actual-length inputs (first naming variant).
    bool isActualLenDimsNull;
    bool isActualLenDimsKVNull;
    uint32_t actualSeqLenSize;
    uint32_t actualSeqLenKVSize;
    // Block table / paging (presumably paged KV).
    uint32_t blockTableDim2;
    uint32_t blockSize;
    uint32_t paBlockNumSum;
    uint32_t headNumRatio;
    // Batch*head range assigned to this core (presumably).
    uint32_t bnAxisStartIdx;
    uint32_t bnAxisEndIdx;
    // Actual-length inputs (second naming variant).
    uint32_t actualSeqLengthsSize;
    uint32_t actualSeqLengthsKVSize;
    bool isActualSeqLengthsNull;
    bool isActualSeqLengthsKVNull;
    uint32_t batch;
    // Attention-mask extents.
    uint32_t attenMaskQSeqlen;
    uint32_t attenMaskKvSeqlen;
    // Multi-core split bounds. NOTE(review): declared volatile, presumably so every access is
    // a real memory access; volatile does not provide inter-core synchronization.
    volatile int64_t multiCoreInnerOffset;
    volatile int64_t multiCoreInnerLimit;
    uint32_t coreNum;
};
/**
 * Attention-mask description. Meanings are inferred from names.
 */
struct AttenMaskInfo {
    int64_t attenMaskShapeType;   // mask layout/shape selector (presumably an enum-like code)
    int64_t attenMaskQSeqlen;     // mask extent along Q
    int64_t attenMaskKvSeqlen;    // mask extent along KV
    int64_t attenMaskOffsetPre;   // precomputed mask offset (presumably in elements)
};
constexpr uint16_t SHIFT_NUM_6{6};
constexpr uint16_t ADD_NUM_63{63};
/**
 * Rounds `data` up to the next multiple of 64 (= 1 << SHIFT_NUM_6).
 * The arithmetic runs in `int` after integer promotion, so `data + 63` cannot overflow. The
 * result is then narrowed to uint16_t, so inputs above 65472 round to 65536 and wrap to 0.
 */
CATLASS_DEVICE constexpr uint16_t Align64Func(uint16_t data) {
    return static_cast<uint16_t>(((data + ADD_NUM_63) >> SHIFT_NUM_6) << SHIFT_NUM_6);
}
/**
 * Rounds `data` up to a multiple of `baseSize`.
 * The arithmetic runs in `int` after promotion; the result is narrowed back to uint16_t, so
 * values past 65535 wrap. `baseSize` must be non-zero (the division is unchecked).
 * NOTE(review): because (0 - 1) / baseSize truncates toward zero, Align(0, b) returns b for
 * every b > 1, but 0 for b == 1. By contrast, Align64Func(0) returns 0. Confirm callers expect
 * this before changing it.
 */
CATLASS_DEVICE constexpr uint16_t Align(uint16_t data, uint16_t baseSize) {
    return static_cast<uint16_t>(((data - 1) / baseSize + 1) * baseSize);
}
/**
 * Sets the actual Q / KV sequence lengths in `runParam` for the current batch.
 * This variant copies the fixed constInfo.qSeqlen / constInfo.kvSeqlen. It does not read
 * per-batch actual lengths or the attention-mask info (presumably a placeholder for
 * variable-length inputs — TODO confirm).
 *
 * @param runParam  receives actualQSeqSize and actualKvSeqSize.
 * @param constInfo source of qSeqlen and kvSeqlen.
 * The third parameter (attention-mask info) is currently unused; its name is commented out to
 * silence unused-parameter warnings without changing the signature.
 */
CATLASS_DEVICE inline void ComputeParamBatch(RunParamStr& runParam, const ConstInfo &constInfo,
    const AttenMaskInfo & /* attenMaskInfo */)
{
    runParam.actualQSeqSize = constInfo.qSeqlen;
    runParam.actualKvSeqSize = constInfo.kvSeqlen;
}
/**
 * Computes runParam.qSeqLoopTimes, the number of Q-sequence tiles (of qSeqlenTemplateType rows)
 * to iterate.
 * The default is CeilDiv(actualQSeqSize, qSeqlenTemplateType). When `lastBN` is true and
 * `nextQSeqAxisIdx` is non-zero, `nextQSeqAxisIdx` is used instead — presumably because this
 * core's work ends part-way through the last batch/head (TODO confirm with the scheduler).
 * `constInfo` is not read.
 */
template <uint32_t qSeqlenTemplateType>
CATLASS_DEVICE inline void ComputeQseqLoopInfo(RunParamStr& runParam, const ConstInfo &constInfo, bool lastBN,
    int64_t nextQSeqAxisIdx)
{
    constexpr int32_t qTile = static_cast<int32_t>(qSeqlenTemplateType);
    const int32_t fullLoopTimes = CeilDiv(runParam.actualQSeqSize, qTile);
    const bool useNextIdx = lastBN && (nextQSeqAxisIdx != 0);
    runParam.qSeqLoopTimes = useNextIdx ? nextQSeqAxisIdx : static_cast<int64_t>(fullLoopTimes);
}
/**
 * Computes the number of valid Q rows in tile `sOuterLoopIdx` and how they are split between
 * the two vector sub-blocks.
 *
 * Writes:
 *   - runParam.qSeqRealSize: min(qSeqlenTemplateType, actualQSeqSize - tile offset), or 0 when
 *     the tile starts at or beyond the end of the sequence.
 *   - runParam.firstHalfQSeqRealSize: the ceiling half, ceil(qSeqRealSize / 2).
 *   - runParam.halfQSeqRealSize: this sub-block's share — the ceiling half by default, the
 *     floor half when constInfo.subBlockIdx == 1.
 *
 * Fixes vs. the previous version:
 *   - The tile offset is computed in 64 bits. Previously uint32_t * uint32_t wrapped before
 *     being widened.
 *   - A non-positive remainder yields 0. Previously a negative remainder was cast to uint32_t,
 *     and Min then returned the full tile size.
 */
template <uint32_t qSeqlenTemplateType>
CATLASS_DEVICE inline void ComputeParamQSeq(RunParamStr& runParam, const ConstInfo &constInfo,
    uint32_t sOuterLoopIdx)
{
    constexpr int64_t qSeqlenBase = static_cast<int64_t>(qSeqlenTemplateType);
    const int64_t cubeSOuterOffset = static_cast<int64_t>(sOuterLoopIdx) * qSeqlenBase;
    // Rows left from this tile onward. A value <= 0 means the tile is past the end; this also
    // covers actualQSeqSize == 0.
    const int64_t remainingRows = runParam.actualQSeqSize - cubeSOuterOffset;
    runParam.qSeqRealSize = (remainingRows > 0)
        ? static_cast<uint32_t>(Min(qSeqlenBase, remainingRows))
        : 0U;

    // Two-way split: the ceiling half is always recorded in firstHalfQSeqRealSize.
    // Sub-block 1 takes the remaining floor half.
    runParam.halfQSeqRealSize = (runParam.qSeqRealSize + 1) >> 1;
    runParam.firstHalfQSeqRealSize = runParam.halfQSeqRealSize;
    if (constInfo.subBlockIdx == 1) {
        runParam.halfQSeqRealSize = runParam.qSeqRealSize - runParam.halfQSeqRealSize;
    }
}
/**
 * Sets up the KV-sequence loop to cover the whole actual KV length:
 *   - row range [0, actualKvSeqSize)
 *   - tile range [0, ceil(actualKvSeqSize / kvSeqlenTemplateType))
 * The tile end is narrowed to int32_t on assignment. `constInfo` is not read.
 */
template <uint32_t kvSeqlenTemplateType>
CATLASS_DEVICE inline void ComputeKvSeqLoopInfo(RunParamStr& runParam, const ConstInfo &constInfo)
{
    constexpr int32_t kvTile = static_cast<int32_t>(kvSeqlenTemplateType);
    const int64_t kvLineEnd = runParam.actualKvSeqSize;

    runParam.kvSeqAxisLineStartIdx = 0;
    runParam.kvSeqAxisLineEndIdx = kvLineEnd;
    runParam.kvSeqLoopStartIdx = 0;
    runParam.kvSeqLoopEndIdx = (kvLineEnd + kvTile - 1) / kvTile;
}
#endif