/**
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "kernel_operator.h"
#include "mixkernels/utils/common/kernel/kernel_utils.h"
#include "mixkernels/gating/tiling/tiling_data.h"
using namespace AscendC;
// Element sizes in bytes, used to turn element counts into UB buffer sizes.
constexpr int32_t INT32_SIZE = sizeof(int32_t);
constexpr int32_t INT64_SIZE = sizeof(int64_t);
// Depth of every TQue used by the kernel (single buffering).
constexpr int32_t BUFFER_NUM = 1;
// Small multipliers used for buffer offsets and ping-pong arrays.
constexpr int64_t DOUBLE = 2;
constexpr int64_t TRIPLE = 3;
// MrgSort4 merges at most this many sorted queues per call.
constexpr int64_t MAX_SORT_QUEUE_NUM = 4;
// Number of floats occupied by one sort proposal (one element) in UB/GM.
constexpr int64_t SORT_STRUCT_MULTIPLE = 8;
// UB bytes reserved for the IBSet/IBWait scratch tensor.
constexpr int32_t SYNC_UB_BYTES = 32 * 32;
// Number of input elements handled per tile.
constexpr int64_t TILE_NUM = 512;
// Number of floats needed to hold one tile in proposal layout.
// Derived from SORT_STRUCT_MULTIPLE instead of repeating the literal 8.
constexpr int64_t STRUCT_TILE_NUM = SORT_STRUCT_MULTIPLE * TILE_NUM;
// Proposal slots that carry the (negated) expert id score and the original index.
constexpr int32_t TOPK_PROPOSAL_IDX = 4;
constexpr int32_t IDX_PROPOSAL_IDX = 5;
/**
 * Gating kernel.
 *
 * Sorts the flattened top-k expert assignments (`topk`, whose values are used as expert ids)
 * together with their payload indices (`idxArr`) by expert id. The sort is proposal based:
 * per-tile sorts run in UB, then a multi-way merge runs in a GM workspace. The payload of the
 * sorted order is written to `originalIndex`. When cumSumNum > 0 the kernel also writes:
 *   - `tokenIndex` = floor(originalIndex / topkExpertNum)
 *   - `cumSum`     = inclusive prefix sum of the per-expert counts
 *
 * Core roles (see Process()):
 *   - blockIdx == 0                   : sorts its slice, waits for the other sorters, runs the
 *                                        global merge and writes originalIndex / tokenIndex.
 *   - 0 < blockIdx < actualCoreNum    : sorts its slice and signals core 0; with cumSumNum > 0
 *                                        also publishes its expert histogram.
 *   - blockIdx == actualCoreNum       : waits for every histogram and writes cumSum.
 *
 * @tparam CumSumNumType element type of the cumSum output (int32_t / int64_t at the entry point).
 */
template <typename CumSumNumType>
class Gating {
public:
    __aicore__ inline Gating() {}

    /**
     * Reads tiling parameters, binds all GM tensors and allocates the UB buffers.
     *
     * @param topk                 GM int32 expert ids, topKNum elements.
     * @param idxArr               GM int32 payload indices, topKNum elements.
     * @param tokenIndex           GM int32 output, topKNum elements (written only if cumSumNum > 0).
     * @param CumSum               GM CumSumNumType output, cumSumNum elements.
     * @param originalIndex        GM int32 output: payload of the sorted order, topKNum elements.
     * @param globalSortWorkspace  GM float scratch holding two proposal buffers of
     *                             SORT_STRUCT_MULTIPLE * topKNumPadded floats each (merge ping-pong).
     * @param cumSumWorkspace      GM int32 scratch: one padded histogram row per sorting core.
     * @param syncWorkspace        GM int32 scratch used by IBSet/IBWait.
     * @param tiling_data          host-computed tiling parameters.
     */
    __aicore__ inline void Init(GM_ADDR topk, GM_ADDR idxArr,
                                GM_ADDR tokenIndex, GM_ADDR CumSum,
                                GM_ADDR originalIndex, GM_ADDR globalSortWorkspace,
                                GM_ADDR cumSumWorkspace, GM_ADDR syncWorkspace,
                                AtbOps::GatingTilingData *tiling_data)
    {
        InitParams(tiling_data);
        topkGm.SetGlobalBuffer((__gm__ int32_t *)topk, topKNum);
        idxArrGm.SetGlobalBuffer((__gm__ int32_t *)idxArr, topKNum);
        tokenIndexGm.SetGlobalBuffer((__gm__ int32_t *)tokenIndex, topKNum);
        cumSumGm.SetGlobalBuffer((__gm__ CumSumNumType *)CumSum, cumSumNum);
        globalSortBlock.SetGlobalBuffer((__gm__ float *)globalSortWorkspace, SORT_STRUCT_MULTIPLE * topKNumPadded);
        globalSortBlock2.SetGlobalBuffer((__gm__ float *)globalSortWorkspace + SORT_STRUCT_MULTIPLE * topKNumPadded,
                                         SORT_STRUCT_MULTIPLE * topKNumPadded);
        cumSumBlock.SetGlobalBuffer((__gm__ int32_t *)cumSumWorkspace, actualCoreNum * cumSumNum32BytesPadded);
        syncGm.SetGlobalBuffer((__gm__ int32_t *)syncWorkspace, syncSize);
        originalGm.SetGlobalBuffer((__gm__ int32_t *)originalIndex, topKNum);
        InitPipe();
    }

    /**
     * Runs this core's role (see class comment). Cross-core ordering uses IBSet/IBWait on syncGm.
     * Event id 0 means "slice sorted" (waited on by core 0). Event id actualCoreNum means
     * "histogram published" (waited on by the aggregator core).
     */
    __aicore__ inline void Process()
    {
        if (blockIdx == actualCoreNum) {
            // Aggregator core: only needed when a cumsum is requested.
            if (cumSumNum > 0) {
                for (int32_t i = 0; i < actualCoreNum; ++i) {
                    auto sync_buf = syncTQue.AllocTensor<int32_t>();
                    IBWait(syncGm, sync_buf, i, actualCoreNum);
                    syncTQue.FreeTensor(sync_buf);
                }
                calculateAndCopy2CumSumGm();
            }
        } else if (blockIdx > 0) {
            // Worker sorter: sort own slice, tell core 0, then publish the histogram.
            PartSort();
            auto sync_buf = syncTQue.AllocTensor<int32_t>();
            IBSet(syncGm, sync_buf, blockIdx, 0);
            syncTQue.FreeTensor(sync_buf);
            if (cumSumNum > 0) {
                PartCumSum();
                auto sync_buf2 = syncTQue.AllocTensor<int32_t>();
                IBSet(syncGm, sync_buf2, blockIdx, actualCoreNum);
                syncTQue.FreeTensor(sync_buf2);
            }
        } else {
            // Core 0: sort own slice, publish histogram, wait for all slices, then merge globally.
            PartSort();
            if (cumSumNum > 0) {
                PartCumSum();
                auto sync_buf2 = syncTQue.AllocTensor<int32_t>();
                IBSet(syncGm, sync_buf2, blockIdx, actualCoreNum);
                syncTQue.FreeTensor(sync_buf2);
            }
            if (actualCoreNum > 1) {
                for (int32_t i = 1; i < actualCoreNum; ++i) {
                    auto sync_buf = syncTQue.AllocTensor<int32_t>();
                    IBWait(syncGm, sync_buf, i, 0);
                    syncTQue.FreeTensor(sync_buf);
                }
            }
            GlobalTensor<float> resultGlobalSortBlock = GlobalSort();
            CopyGm2Gm(resultGlobalSortBlock, originalGm);
        }
    }

private:
    /** Bundle of references handed from GlobalSort() to GlobalMrgSort() for one merge group. */
    struct GmsParams {
        int (&gmsLengths)[MAX_SORT_QUEUE_NUM];      // element count of each input queue
        int (&gmsCurrentHead)[MAX_SORT_QUEUE_NUM];  // element offset of each queue in the source buffer
        int &queueNum;                              // number of valid queues (1..MAX_SORT_QUEUE_NUM)
        LocalTensor<float> &srcLocalTensor;
        LocalTensor<float> &dstLocalTensor;
        bool &gmTensorIndex;                        // selects which buffLocal entry is the source
        GlobalTensor<float> (&buffLocal)[DOUBLE];   // GM ping-pong buffers
    };

    /** Allocates every UB queue/buffer used by the kernel. */
    __aicore__ inline void InitPipe()
    {
        pipe.InitBuffer(syncTQue, BUFFER_NUM, SYNC_UB_BYTES);
        pipe.InitBuffer(inQueueTopK, BUFFER_NUM, TILE_NUM * INT32_SIZE);
        pipe.InitBuffer(inQueueIdxArr, BUFFER_NUM, TILE_NUM * INT32_SIZE);
        pipe.InitBuffer(inQueueWorkspace, BUFFER_NUM, MAX_SORT_QUEUE_NUM * STRUCT_TILE_NUM * INT32_SIZE);
        pipe.InitBuffer(outQueueWorkspace, BUFFER_NUM, DOUBLE * STRUCT_TILE_NUM * INT32_SIZE);
        pipe.InitBuffer(outQueueTopK, BUFFER_NUM, STRUCT_TILE_NUM * INT32_SIZE);
        pipe.InitBuffer(inQueueCumsumPart, BUFFER_NUM, cumSumNum32BytesPadded * INT32_SIZE);
        pipe.InitBuffer(outQueueCumsumPartAddSrc0, BUFFER_NUM, cumSumNum32BytesPadded * INT32_SIZE);
        pipe.InitBuffer(outQueueCumsumPartAddSrc1, BUFFER_NUM, cumSumNum32BytesPadded * INT32_SIZE);
        pipe.InitBuffer(tileNumTempBuf, TILE_NUM * INT32_SIZE);
        pipe.InitBuffer(structTileNumTempBuf, STRUCT_TILE_NUM * INT32_SIZE);
        pipe.InitBuffer(expertNumTempBuf, cumSumNum32BytesPadded * sizeof(CumSumNumType));
        // CopyUb2GmPad stages SORT_STRUCT_MULTIPLE elements of type T here. It is instantiated
        // with int64_t for the cumSum output, so size it for the widest element type.
        pipe.InitBuffer(CopyUb2GmPadtemp, SORT_STRUCT_MULTIPLE * INT64_SIZE);
    }

    /** Copies the tiling fields this core needs into members. */
    __aicore__ inline void InitParams(AtbOps::GatingTilingData *tiling_data)
    {
        blockIdx = GetBlockIdx();
        topkExpertNum = tiling_data->topkExpertNum;
        topKNum = tiling_data->topKNum;
        cumSumNum = tiling_data->cumSumNum;
        cumSumNum32BytesPadded = tiling_data->cumSumNum32BytesPadded;
        actualCoreNum = tiling_data->actualCoreNum;
        tailBlockDataSize = tiling_data->tailBlockDataSize;
        syncSize = tiling_data->syncSize;
        blockNumPerCore = tiling_data->blockNumPerCore[blockIdx];
        offSet = tiling_data->offSetPerCore[blockIdx];
        topKNumPadded = tiling_data->topKNumPadded;
    }

    /**
     * Sorts this core's slice tile by tile and writes each sorted tile into globalSortBlock.
     * The last tile of the last sorting core holds tailBlockDataSize elements; all others hold TILE_NUM.
     */
    __aicore__ inline void PartSort()
    {
        int32_t executeTimes = blockNumPerCore;
        int32_t tailNum = blockIdx == (actualCoreNum - 1) ? tailBlockDataSize : TILE_NUM;
        for (uint32_t i = 0; i < executeTimes; i++) {
            uint32_t processNum = i == executeTimes - 1 ? tailNum : TILE_NUM;
            CopyIn(i, processNum);
            Compute(i, processNum);
            CopyOut(i, processNum);
        }
    }

    /** Publishes this core's expert histogram when a cumsum is requested. */
    __aicore__ inline void PartCumSum()
    {
        if (cumSumNum > 0) {
            ComputeCumSumPart();
        }
    }

    /**
     * Merges the per-tile sorted runs stored in GM into one sorted sequence.
     *
     * Each pass merges groups of up to MAX_SORT_QUEUE_NUM runs of `blockSize` tiles. The passes
     * ping-pong between globalSortBlock and globalSortBlock2.
     *
     * @return the GM buffer that holds the fully merged sequence.
     */
    __aicore__ inline GlobalTensor<float> GlobalSort()
    {
        bool switchFlag = false;
        LocalTensor<float> srcLocalTensor = inQueueWorkspace.AllocTensor<float>();
        LocalTensor<float> dstLocalTensor = outQueueWorkspace.AllocTensor<float>();
        GlobalTensor<float> sortedGlobal[2] = {globalSortBlock, globalSortBlock2};
        int32_t orderBlock = topKNumPadded / TILE_NUM;
        int32_t globalSortBlockCount = MAX_SORT_QUEUE_NUM;
        // Zero-initialised: GlobalMrgSort copies all MAX_SORT_QUEUE_NUM slots even when fewer
        // queues are active, so the unused slots must not be indeterminate.
        int32_t length[MAX_SORT_QUEUE_NUM] = {0};
        int32_t currentHead[MAX_SORT_QUEUE_NUM] = {0};
        for (int32_t blockSize = 1; blockSize < orderBlock; blockSize *= globalSortBlockCount) {
            for (int32_t tileIndex = 0; tileIndex < orderBlock; tileIndex += blockSize * globalSortBlockCount) {
                int32_t mrgTileNum = orderBlock - tileIndex < blockSize * globalSortBlockCount ?
                    (orderBlock - tileIndex) : (blockSize * globalSortBlockCount);
                int32_t queueNum = (mrgTileNum + blockSize - 1) / blockSize;
                // int32_t (was uint16_t) so large block sizes cannot silently truncate.
                int32_t lastQueTileNum = mrgTileNum % blockSize == 0 ? blockSize : mrgTileNum % blockSize;
                for (int i = 0; i < queueNum; i++) {
                    currentHead[i] = TILE_NUM * (tileIndex + i * blockSize);
                }
                for (int i = 0; i < queueNum - 1; i++) {
                    length[i] = TILE_NUM * blockSize;
                }
                length[queueNum - 1] = TILE_NUM * lastQueTileNum;
                GmsParams params{length, currentHead, queueNum,
                                 srcLocalTensor, dstLocalTensor, switchFlag, sortedGlobal};
                GlobalMrgSort(params);
                PipeBarrier<PIPE_V>();
            }
            switchFlag = !switchFlag;
        }
        inQueueWorkspace.FreeTensor(srcLocalTensor);
        outQueueWorkspace.FreeTensor(dstLocalTensor);
        return switchFlag ? globalSortBlock2 : globalSortBlock;
    }

    /** State of one multi-way merge. */
    struct MrgQueue {
        int32_t queueNum[MAX_SORT_QUEUE_NUM];     // remaining element count per queue
        int32_t queueLength;                      // number of queues still active
        int32_t currentHead[MAX_SORT_QUEUE_NUM];  // next unread element offset per queue
        int32_t totalMrgLen;                      // elements already written to the destination
        int32_t originalPosition;                 // destination start offset (head of queue 0)
    };

    /**
     * Merges the queues described by `params` from the source GM buffer into the other GM buffer.
     * The output starts at the position of the first queue.
     */
    __aicore__ inline void GlobalMrgSort(GmsParams &params)
    {
        LocalTensor<float> srcLocalTensor {params.srcLocalTensor};
        LocalTensor<float> dstLocalTensor {params.dstLocalTensor};
        GlobalTensor<float> srcGmTensor = params.buffLocal[params.gmTensorIndex];
        GlobalTensor<float> dstGmTensor = params.buffLocal[!params.gmTensorIndex];
        MrgQueue queueToMrg{
            {params.gmsLengths[0], params.gmsLengths[1], params.gmsLengths[2], params.gmsLengths[3]},
            params.queueNum,
            {params.gmsCurrentHead[0], params.gmsCurrentHead[1], params.gmsCurrentHead[2], params.gmsCurrentHead[3]},
            0,
            params.gmsCurrentHead[0]
        };
        if (queueToMrg.queueLength == 1) {
            // Nothing to merge: move the single run to the destination buffer unchanged.
            HandleSingleQueue(queueToMrg, dstLocalTensor, srcGmTensor, dstGmTensor);
            return;
        }
        while (queueToMrg.queueLength > 1) {
            PerformMergeSortStep(queueToMrg, srcLocalTensor, dstLocalTensor, srcGmTensor, dstGmTensor);
            PipeBarrier<PIPE_ALL>();
        }
        // Append whatever is left in the last surviving queue.
        if (queueToMrg.queueNum[0] > 0) {
            HandleSingleQueue(queueToMrg, dstLocalTensor, srcGmTensor, dstGmTensor);
        }
    }

    /**
     * Copies the remaining queueNum[0] elements of queue 0 from srcGm (at currentHead[0]) to dstGm
     * (at originalPosition + totalMrgLen). The copy goes through UB in TILE_NUM-element chunks.
     */
    __aicore__ inline void HandleSingleQueue(MrgQueue &queueToMrg, LocalTensor<float> &dstLocalTensor,
                                             GlobalTensor<float> &srcGmTensor, GlobalTensor<float> &dstGmTensor)
    {
        int32_t repeatTimes = (queueToMrg.queueNum[0] + TILE_NUM - 1) / TILE_NUM;
        int32_t tailNum = queueToMrg.queueNum[0] % TILE_NUM == 0 ? TILE_NUM : queueToMrg.queueNum[0] % TILE_NUM;
        for (int i = 0; i < repeatTimes; i++) {
            int32_t executeNum = i == repeatTimes - 1 ? tailNum : TILE_NUM;
            DataCopy(dstLocalTensor,
                     srcGmTensor[SORT_STRUCT_MULTIPLE * queueToMrg.currentHead[0] +
                                 SORT_STRUCT_MULTIPLE * i * TILE_NUM],
                     static_cast<uint32_t>(SORT_STRUCT_MULTIPLE * executeNum));
            PipeBarrier<PIPE_ALL>();
            DataCopy(dstGmTensor[SORT_STRUCT_MULTIPLE * queueToMrg.originalPosition +
                                 SORT_STRUCT_MULTIPLE * queueToMrg.totalMrgLen +
                                 SORT_STRUCT_MULTIPLE * i * TILE_NUM],
                     dstLocalTensor,
                     static_cast<uint32_t>(SORT_STRUCT_MULTIPLE * executeNum));
            // The next chunk's load reuses dstLocalTensor: wait for this store first.
            SetFlag<HardEvent::MTE3_MTE2>(EVENT_ID0);
            WaitFlag<HardEvent::MTE3_MTE2>(EVENT_ID0);
        }
    }

    /**
     * One merge step over the active queues. It works as follows:
     *   1. Load up to maxSortLengthArr[i] elements of each queue into UB slot i.
     *   2. Merge them with MrgSort4. The `true` argument to MrgSort4Info enables exhausted
     *      suspension, per the Ascend C docs.
     *   3. Append the merged output to dstGm.
     *   4. Advance each queue by the count GetMrgSortResult reports.
     *   5. Drop at most one exhausted queue.
     */
    __aicore__ inline void PerformMergeSortStep(MrgQueue &queueToMrg, LocalTensor<float> &srcLocalTensor,
                                                LocalTensor<float> &dstLocalTensor, GlobalTensor<float> &srcGmTensor,
                                                GlobalTensor<float> &dstGmTensor)
    {
        // Zero-initialised: MrgSort4Info receives all four entries, including inactive ones.
        uint16_t tmpSortLen[MAX_SORT_QUEUE_NUM] = {0};
        for (int i = 0; i < queueToMrg.queueLength; i++) {
            uint16_t sortLength = queueToMrg.queueNum[i] < maxSortLengthArr[i] ? queueToMrg.queueNum[i] :
                maxSortLengthArr[i];
            tmpSortLen[i] = sortLength;
            int32_t gmStartPosition = queueToMrg.currentHead[i];
            DataCopy(srcLocalTensor[i * STRUCT_TILE_NUM], srcGmTensor[SORT_STRUCT_MULTIPLE * gmStartPosition],
                     static_cast<uint32_t>(SORT_STRUCT_MULTIPLE * sortLength));
        }
        MrgSort4Info mrgParams(tmpSortLen, true, validQueueArr[queueToMrg.queueLength], 1);
        MrgSortSrcList<float> srcList(srcLocalTensor[0], srcLocalTensor[STRUCT_TILE_NUM],
                                      srcLocalTensor[DOUBLE * STRUCT_TILE_NUM],
                                      srcLocalTensor[TRIPLE * STRUCT_TILE_NUM]);
        SetFlag<HardEvent::MTE2_V>(EVENT_ID0);
        WaitFlag<HardEvent::MTE2_V>(EVENT_ID0);
        MrgSort4<float>(dstLocalTensor, srcList, mrgParams);
        PipeBarrier<PIPE_ALL>();
        uint16_t sortedLen[MAX_SORT_QUEUE_NUM];
        GetMrgSortResult(sortedLen[0], sortedLen[1], sortedLen[2], sortedLen[3]);
        const uint16_t localMrgLen = sortedLen[0] + sortedLen[1] + sortedLen[2] + sortedLen[3];
        DataCopy(dstGmTensor[SORT_STRUCT_MULTIPLE * queueToMrg.originalPosition +
                             SORT_STRUCT_MULTIPLE * queueToMrg.totalMrgLen],
                 dstLocalTensor, static_cast<uint32_t>(SORT_STRUCT_MULTIPLE * localMrgLen));
        queueToMrg.totalMrgLen += localMrgLen;
        for (int i = 0; i < queueToMrg.queueLength; i++) {
            queueToMrg.queueNum[i] -= sortedLen[i];
            queueToMrg.currentHead[i] += sortedLen[i];
        }
        // Remove the first exhausted queue by shifting the later ones down.
        for (int i = 0; i < queueToMrg.queueLength; i++) {
            if (queueToMrg.queueNum[i] == 0) {
                for (int j = i; j < MAX_SORT_QUEUE_NUM - 1; j++) {
                    queueToMrg.queueNum[j] = queueToMrg.queueNum[j + 1];
                    queueToMrg.currentHead[j] = queueToMrg.currentHead[j + 1];
                }
                queueToMrg.queueNum[queueToMrg.queueLength - 1] = 0;
                queueToMrg.queueLength -= 1;
                break;
            }
        }
    }

    /**
     * Loads tile `processIndex` of this core's slice (processNum elements of topk and idxArr) into UB.
     * The GM read is rounded up to a whole 32-byte block. The over-read tail is then overwritten
     * with paddingValueInt through a masked Duplicate on the last 8-element block.
     */
    __aicore__ inline void CopyIn(uint32_t processIndex, uint32_t processNum)
    {
        uint32_t paddingNum = (processNum * INT32_SIZE) % 32 == 0 ?
            0 : (32 - (processNum * INT32_SIZE) % 32) / INT32_SIZE;
        LocalTensor<int32_t> topkLocal = inQueueTopK.AllocTensor<int32_t>();
        LocalTensor<int32_t> idxArrLocal = inQueueIdxArr.AllocTensor<int32_t>();
        DataCopy(topkLocal, topkGm[offSet + processIndex * TILE_NUM], processNum + paddingNum);
        DataCopy(idxArrLocal, idxArrGm[offSet + processIndex * TILE_NUM], processNum + paddingNum);
        SetFlag<HardEvent::MTE2_S>(EVENT_ID0);
        WaitFlag<HardEvent::MTE2_S>(EVENT_ID0);
        if (paddingNum != 0) {
            // Mask selects the top `paddingNum` lanes of the final 8-lane block.
            uint64_t bit_mask = 1;
            const uint32_t countPreBlock = 8;
            const uint32_t dstRepeatStride = (processNum + paddingNum) / countPreBlock - 1;
            bit_mask <<= paddingNum;
            bit_mask -= 1;
            bit_mask <<= countPreBlock - paddingNum;
            uint64_t mask[2]{ bit_mask, 0 };
            Duplicate(topkLocal[dstRepeatStride * countPreBlock], paddingValueInt, mask, 1, 1, 1);
            Duplicate(idxArrLocal[dstRepeatStride * countPreBlock], paddingValueInt, mask, 1, 1, 1);
        }
        SetFlag<HardEvent::MTE2_S>(EVENT_ID0);
        WaitFlag<HardEvent::MTE2_S>(EVENT_ID0);
        inQueueTopK.EnQue(topkLocal);
        inQueueIdxArr.EnQue(idxArrLocal);
    }

    /**
     * Sorts one tile in UB. The steps are:
     *   1. Optionally count expert ids into selectedExpertCount.
     *   2. Build proposals: score = -(float)expertId, payload = raw idxArr bits.
     *   3. Pad to TILE_NUM with paddingValueFloat.
     *   4. Sort the proposals with RpSort16 followed by MergeSort4Queue.
     * The score is negated, presumably so the descending proposal sort yields ascending expert ids.
     * The sorted tile is enqueued on outQueueWorkspace for CopyOut.
     */
    __aicore__ inline void Compute(uint32_t processIndex, uint32_t processNum)
    {
        const uint32_t paddingNum = (processNum * INT32_SIZE) % 32 == 0 ?
            0 : (32 - (processNum * INT32_SIZE) % 32) / INT32_SIZE;
        LocalTensor<int32_t> topkLocalInt = inQueueTopK.DeQue<int32_t>();
        LocalTensor<int32_t> idxArrLocalInt = inQueueIdxArr.DeQue<int32_t>();
        // The index payload travels through the float sort as raw bits.
        LocalTensor<float> idxArrLocal = idxArrLocalInt.ReinterpretCast<float>();
        if (cumSumNum > 0) {
            SetFlag<HardEvent::V_S>(EVENT_ID0);
            WaitFlag<HardEvent::V_S>(EVENT_ID0);
            for (uint32_t i = 0; i < processNum; i++) {
                uint32_t expertIndex = topkLocalInt.GetValue(i);
                // Only [0, cumSumNum) is ever read back (ComputeCumSumPart). Skipping other ids,
                // including negative ones, prevents out-of-bounds writes to the counter array.
                if (expertIndex < static_cast<uint32_t>(cumSumNum) && expertIndex < EXPERT_COUNT_CAPACITY) {
                    selectedExpertCount[expertIndex]++;
                }
            }
        }
        LocalTensor<float> groupSortLocal = structTileNumTempBuf.Get<float>();
        LocalTensor<float> topkLocal_float = tileNumTempBuf.Get<float>();
        Cast(topkLocal_float, topkLocalInt, RoundMode::CAST_NONE, processNum + paddingNum);
        Duplicate<float>(topkLocal_float[processNum + paddingNum], paddingValueFloat,
                         TILE_NUM - (processNum + paddingNum));
        Duplicate<float>(idxArrLocal[processNum + paddingNum], paddingValueFloat,
                         TILE_NUM - (processNum + paddingNum));
        // Muls reads what Cast and Duplicate just wrote.
        PipeBarrier<PIPE_V>();
        float factor = -1.0;
        Muls(topkLocal_float, topkLocal_float, factor, TILE_NUM);
        PipeBarrier<PIPE_V>();
        LocalTensor<float> sortLocal = outQueueWorkspace.AllocTensor<float>();
        uint32_t repeatTimes = (TILE_NUM) / 16;
        ProposalConcat<float>(sortLocal, topkLocal_float, repeatTimes, TOPK_PROPOSAL_IDX);
        ProposalConcat<float>(sortLocal, idxArrLocal, repeatTimes, IDX_PROPOSAL_IDX);
        PipeBarrier<PIPE_V>();
        RpSort16(groupSortLocal, sortLocal, repeatTimes);
        PipeBarrier<PIPE_V>();
        MergeSort4Queue(sortLocal, groupSortLocal);
        outQueueWorkspace.EnQue<float>(sortLocal);
        inQueueTopK.FreeTensor(topkLocalInt);
        // Free the tensor in the type it was allocated with (CopyIn allocated int32_t).
        inQueueIdxArr.FreeTensor(idxArrLocalInt);
    }

    /** Stores the sorted tile `processIndex` into this core's region of globalSortBlock. */
    __aicore__ inline void CopyOut(uint32_t processIndex, uint32_t processNum)
    {
        LocalTensor<float> sortLocal = outQueueWorkspace.DeQue<float>();
        DataCopy(globalSortBlock[offSet * SORT_STRUCT_MULTIPLE + processIndex * STRUCT_TILE_NUM], sortLocal,
                 STRUCT_TILE_NUM);
        outQueueWorkspace.FreeTensor(sortLocal);
    }

    /**
     * Extracts the payload (IDX_PROPOSAL_IDX) of every sorted proposal in `srcGlobal` and writes
     * the first topKNum values to `dstGlobal`. When cumSumNum > 0 it also writes
     * floor(payload / topkExpertNum) to tokenIndexGm.
     */
    __aicore__ inline void CopyGm2Gm(const GlobalTensor<float> &srcGlobal, const GlobalTensor<int32_t> &dstGlobal)
    {
        const uint32_t copyTimes = (topKNum / TILE_NUM) + (topKNum % TILE_NUM == 0 ? 0 : 1);
        LocalTensor<float> tmpLocal = structTileNumTempBuf.Get<float>();
        LocalTensor<float> originIndex = outQueueTopK.AllocTensor<float>();
        LocalTensor<float> divFloatLocal = tileNumTempBuf.Get<float>();
        float topkExpertNumFloat = topkExpertNum;
        Duplicate(divFloatLocal, topkExpertNumFloat, TILE_NUM);
        SetFlag<HardEvent::MTE3_MTE2>(EVENT_ID0);
        WaitFlag<HardEvent::MTE3_MTE2>(EVENT_ID0);
        for (uint32_t i = 0; i < copyTimes; i++) {
            DataCopy(tmpLocal, srcGlobal[i * STRUCT_TILE_NUM], STRUCT_TILE_NUM);
            PipeBarrier<PIPE_ALL>();
            ProposalExtract(originIndex, tmpLocal, TILE_NUM / (DOUBLE * SORT_STRUCT_MULTIPLE), IDX_PROPOSAL_IDX);
            // The payload was stored as raw int32 bits, so reinterpret rather than convert.
            LocalTensor<int32_t> originIndexLocalInt = originIndex.ReinterpretCast<int32_t>();
            SetFlag<HardEvent::V_MTE3>(EVENT_ID0);
            WaitFlag<HardEvent::V_MTE3>(EVENT_ID0);
            if (i == copyTimes - 1) {
                uint32_t tailNum = topKNum % TILE_NUM == 0 ? TILE_NUM : topKNum % TILE_NUM;
                CopyUb2GmPad(dstGlobal[i * TILE_NUM], originIndexLocalInt, tailNum);
            } else {
                DataCopy(dstGlobal[i * TILE_NUM], originIndexLocalInt, TILE_NUM);
            }
            PipeBarrier<PIPE_ALL>();
            if (cumSumNum > 0) {
                // tokenIndex = floor(originalIndex / topkExpertNum), computed in place in float.
                LocalTensor<float> originIndexLocalFloat = originIndexLocalInt.ReinterpretCast<float>();
                Cast(originIndexLocalFloat, originIndexLocalInt, RoundMode::CAST_NONE, TILE_NUM);
                PipeBarrier<PIPE_V>();
                Div(originIndexLocalFloat, originIndexLocalFloat, divFloatLocal, TILE_NUM);
                PipeBarrier<PIPE_V>();
                Cast(originIndexLocalInt, originIndexLocalFloat, RoundMode::CAST_FLOOR, TILE_NUM);
                SetFlag<HardEvent::V_MTE3>(EVENT_ID0);
                WaitFlag<HardEvent::V_MTE3>(EVENT_ID0);
                if (i == copyTimes - 1) {
                    uint32_t tailNum = topKNum % TILE_NUM == 0 ? TILE_NUM : topKNum % TILE_NUM;
                    CopyUb2GmPad(tokenIndexGm[i * TILE_NUM], originIndexLocalInt, tailNum);
                } else {
                    DataCopy(tokenIndexGm[i * TILE_NUM], originIndexLocalInt, TILE_NUM);
                }
            }
        }
        outQueueTopK.FreeTensor(originIndex);
    }

    /**
     * Merges the 16-element runs produced by RpSort16 (in tmpBuf) into one sorted run of TILE_NUM
     * proposals. It uses 4-way MrgSort4 passes that ping-pong between tmpBuf and sortBuf. The
     * result always ends up in sortBuf.
     *
     * NOTE: a pass is only correct when its run count is a multiple of 4 or smaller than 4. A
     * repeated MrgSort4 call cannot mix full and partial groups. For TILE_NUM == 512 the run
     * counts are 32, 8 and 2.
     */
    __aicore__ inline void MergeSort4Queue(LocalTensor<float> &sortBuf, LocalTensor<float> &tmpBuf)
    {
        const uint16_t mergeCount = MAX_SORT_QUEUE_NUM;
        LocalTensor<float> sortedQue[2] = {tmpBuf, sortBuf};
        int switchFlag = 0;
        const uint16_t proposalSize = SORT_STRUCT_MULTIPLE;
        uint16_t singleQueSize = 16;
        while (singleQueSize < TILE_NUM) {
            uint16_t QueNum = (TILE_NUM / singleQueSize) % MAX_SORT_QUEUE_NUM;
            uint16_t repeatTimes = TILE_NUM / singleQueSize / MAX_SORT_QUEUE_NUM;
            if (QueNum != 0) {
                repeatTimes++;
            }
            // One valid bit per active input queue: all four for full groups, else the
            // QueNum lowest bits. The previous formula (1 << (4 - QueNum)) - 1 was only
            // correct for QueNum == 2.
            const uint16_t activeQueues = (QueNum == 0) ? static_cast<uint16_t>(MAX_SORT_QUEUE_NUM) : QueNum;
            uint16_t validBit = static_cast<uint16_t>((1U << activeQueues) - 1U);
            struct MrgSortSrcList<float> srcList{
                sortedQue[switchFlag][0],
                sortedQue[switchFlag][singleQueSize * proposalSize],
                sortedQue[switchFlag][singleQueSize * proposalSize * DOUBLE],
                sortedQue[switchFlag][singleQueSize * proposalSize * TRIPLE]
            };
            uint16_t elementLengths[MAX_SORT_QUEUE_NUM]{
                singleQueSize,
                singleQueSize,
                singleQueSize,
                singleQueSize
            };
            struct MrgSort4Info srcInfo(elementLengths, false, validBit, repeatTimes);
            MrgSort4(sortedQue[switchFlag ^ 1], srcList, srcInfo);
            // The next pass (or the final copy) reads what this pass wrote.
            PipeBarrier<PIPE_V>();
            switchFlag ^= 1;
            singleQueSize *= mergeCount;
        }
        SetFlag<HardEvent::V_S>(EVENT_ID0);
        WaitFlag<HardEvent::V_S>(EVENT_ID0);
        if (!static_cast<bool>(switchFlag)) {
            DataCopy(sortBuf, tmpBuf, STRUCT_TILE_NUM);
        }
    }

    /** Writes this core's expert histogram to row GetBlockIdx() of cumSumBlock. */
    __aicore__ inline void ComputeCumSumPart()
    {
        LocalTensor<int32_t> cumSumPartLocalTensor = inQueueCumsumPart.AllocTensor<int32_t>();
        for (int i = 0; i < cumSumNum; i++) {
            cumSumPartLocalTensor.SetValue(i, selectedExpertCount[i]);
        }
        SetFlag<HardEvent::S_MTE3>(EVENT_ID0);
        WaitFlag<HardEvent::S_MTE3>(EVENT_ID0);
        CopyUb2GmPad(cumSumBlock[GetBlockIdx() * cumSumNum32BytesPadded], cumSumPartLocalTensor, cumSumNum);
        SetFlag<HardEvent::MTE3_S>(EVENT_ID0);
        WaitFlag<HardEvent::MTE3_S>(EVENT_ID0);
        inQueueCumsumPart.FreeTensor(cumSumPartLocalTensor);
    }

    /**
     * Aggregator-core step. It sums the actualCoreNum histogram rows of cumSumBlock and writes
     * their inclusive prefix sum, as CumSumNumType, to cumSumGm.
     */
    __aicore__ inline void calculateAndCopy2CumSumGm()
    {
        LocalTensor<int32_t> accumulator = inQueueCumsumPart.AllocTensor<int32_t>();
        LocalTensor<int32_t> src0 = outQueueCumsumPartAddSrc0.AllocTensor<int32_t>();
        LocalTensor<int32_t> src1 = outQueueCumsumPartAddSrc1.AllocTensor<int32_t>();
        DataCopy(accumulator, cumSumBlock, cumSumNum32BytesPadded);
        for (int i = 1; i < actualCoreNum; i++) {
            DataCopy(src0, cumSumBlock[i * cumSumNum32BytesPadded], cumSumNum32BytesPadded);
            SetFlag<HardEvent::MTE2_V>(EVENT_ID0);
            WaitFlag<HardEvent::MTE2_V>(EVENT_ID0);
            DataCopy(src1, accumulator, cumSumNum32BytesPadded);
            PipeBarrier<PIPE_ALL>();
            Add(accumulator, src0, src1, cumSumNum);
            SetFlag<HardEvent::V_MTE2>(EVENT_ID0);
            WaitFlag<HardEvent::V_MTE2>(EVENT_ID0);
        }
        PipeBarrier<PIPE_ALL>();
        LocalTensor<CumSumNumType> cumSumLocalTensor = expertNumTempBuf.Get<CumSumNumType>();
        CumSumNumType acc = 0;
        for (int i = 0; i < cumSumNum; i++) {
            acc = acc + static_cast<CumSumNumType>(accumulator.GetValue(i));
            cumSumLocalTensor.SetValue(i, acc);
        }
        PipeBarrier<PIPE_ALL>();
        CopyUb2GmPad(cumSumGm, cumSumLocalTensor, cumSumNum);
        inQueueCumsumPart.FreeTensor(accumulator);
        outQueueCumsumPartAddSrc0.FreeTensor(src0);
        outQueueCumsumPartAddSrc1.FreeTensor(src1);
    }

    /**
     * Copies `length` elements of T from UB to GM using only whole-chunk DataCopy calls of 8
     * elements (8 elements = 32 B for int32, 64 B for int64). The two cases are:
     *   - length <= 8: copies a full 8 elements. NOTE(review): this writes past `length` in GM,
     *     so it assumes the destination has room.
     *   - otherwise: copies `length` elements. If length is not a multiple of 8, it then re-stores
     *     the final 8 elements from the CopyUb2GmPadtemp staging buffer so the tail lands exactly.
     */
    template <typename T>
    __aicore__ inline void CopyUb2GmPad(const GlobalTensor<T> &dstGlobal, const LocalTensor<T> &srcLocal,
                                        uint32_t length)
    {
        LocalTensor<T> tempLocal = CopyUb2GmPadtemp.Get<T>();
        uint32_t countPreBlock = 8;
        if (length <= countPreBlock) {
            DataCopy(dstGlobal, srcLocal, countPreBlock);
        } else {
            DataCopy(dstGlobal, srcLocal, length);
            SetFlag<HardEvent::MTE3_S>(EVENT_ID0);
            WaitFlag<HardEvent::MTE3_S>(EVENT_ID0);
            if (length % countPreBlock != 0) {
                // Index with uint32_t, not T: T is the element type, not a counter type.
                for (uint32_t i = 0; i < countPreBlock; ++i) {
                    T t = srcLocal.GetValue(length - countPreBlock + i);
                    tempLocal.SetValue(i, t);
                }
                SetFlag<HardEvent::S_MTE3>(EVENT_ID0);
                WaitFlag<HardEvent::S_MTE3>(EVENT_ID0);
                DataCopy(dstGlobal[length - countPreBlock], tempLocal, countPreBlock);
            }
        }
    }

private:
    TPipe pipe;
    GlobalTensor<int32_t> topkGm;
    GlobalTensor<int32_t> idxArrGm;
    GlobalTensor<int32_t> tokenIndexGm;
    GlobalTensor<int32_t> originalGm;
    GlobalTensor<CumSumNumType> cumSumGm;
    GlobalTensor<float> globalSortBlock;   // merge ping-pong buffer A (also holds per-tile sorts)
    GlobalTensor<float> globalSortBlock2;  // merge ping-pong buffer B
    GlobalTensor<int32_t> cumSumBlock;     // per-core histogram rows
    GlobalTensor<int32_t> syncGm;
    TQue<QuePosition::VECIN, BUFFER_NUM> syncTQue;
    TQue<QuePosition::VECIN, BUFFER_NUM> inQueueTopK, inQueueIdxArr;
    TQue<QuePosition::VECIN, BUFFER_NUM> inQueueWorkspace;
    TQue<QuePosition::VECIN, BUFFER_NUM> inQueueCumsumPart;
    TQue<QuePosition::VECOUT, BUFFER_NUM> outQueueWorkspace;
    TQue<QuePosition::VECOUT, BUFFER_NUM> outQueueTopK;
    TQue<QuePosition::VECOUT, BUFFER_NUM> outQueueCumsumPartAddSrc0;
    TQue<QuePosition::VECOUT, BUFFER_NUM> outQueueCumsumPartAddSrc1;
    TBuf<AscendC::TPosition::VECCALC> tileNumTempBuf;
    TBuf<AscendC::TPosition::VECCALC> structTileNumTempBuf;
    TBuf<AscendC::TPosition::VECCALC> expertNumTempBuf;
    TBuf<AscendC::TPosition::VECCALC> CopyUb2GmPadtemp;
    // Fill value for padded lanes; as a negated score it sorts after every real expert id.
    float paddingValueFloat = static_cast<float>(0x0FFFFFFF);
    int32_t paddingValueInt = 0x0FFFFFFF;
    uint32_t paddingValueUint = 0x0FFFFFFF;
    int32_t topkExpertNum = 0;
    int64_t topKNum = -1;
    int32_t cumSumNum = -1;
    int32_t cumSumNum32BytesPadded = -1;
    int32_t actualCoreNum = 1;
    int32_t tailBlockDataSize = 0;
    int32_t syncSize = 0;
    int32_t blockNumPerCore = 0;
    uint32_t offSet = 0;
    int64_t topKNumPadded = 0;
    int32_t blockIdx = 0;
    // Capacity of the per-core expert histogram. NOTE(review): ComputeCumSumPart assumes
    // cumSumNum <= EXPERT_COUNT_CAPACITY; confirm the tiling enforces it.
    static constexpr uint32_t EXPERT_COUNT_CAPACITY = 1025;
    int32_t selectedExpertCount[EXPERT_COUNT_CAPACITY] = {0};
    // Max elements loaded per queue per MrgSort4 step (one UB slot of STRUCT_TILE_NUM floats).
    int maxSortLengthArr[MAX_SORT_QUEUE_NUM] = {512, 512, 512, 512};
    // MrgSort4 valid-queue mask indexed by active queue count (1 is handled by HandleSingleQueue).
    uint16_t validQueueArr[5] = {0, 0, 0b11, 0b111, 0b1111};
};
/**
 * Loads the GatingTilingData block from GM into `tilingData`.
 * The block is staged in UB with CopyGmTilingToUb, then copied field by field.
 *
 * @param tiling      GM address of the serialized tiling data.
 * @param tilingData  destination struct that receives every field.
 */
__aicore__ inline void InitGatingTilingData(const __gm__ uint8_t *tiling,
                                            AtbOps::GatingTilingData *tilingData)
{
    TPipe stagingPipe;
    __ubuf__ uint8_t *ubRaw = nullptr;
    CopyGmTilingToUb(ubRaw, tiling, sizeof(AtbOps::GatingTilingData), &stagingPipe);
    __ubuf__ AtbOps::GatingTilingData *ubTiling =
        reinterpret_cast<__ubuf__ AtbOps::GatingTilingData *>(ubRaw);

    // The scalar reads below must not start before the GM->UB transfer has landed.
    SetFlag<HardEvent::MTE2_S>(EVENT_ID0);
    WaitFlag<HardEvent::MTE2_S>(EVENT_ID0);

    // Scalar problem-size parameters.
    tilingData->topkExpertNum = ubTiling->topkExpertNum;
    tilingData->topKNum = ubTiling->topKNum;
    tilingData->topKNumPadded = ubTiling->topKNumPadded;
    tilingData->cumSumNum = ubTiling->cumSumNum;
    tilingData->cumSumNum32BytesPadded = ubTiling->cumSumNum32BytesPadded;

    // Work-split parameters.
    tilingData->actualCoreNum = ubTiling->actualCoreNum;
    tilingData->blockNum = ubTiling->blockNum;
    tilingData->tailBlockDataSize = ubTiling->tailBlockDataSize;
    tilingData->syncSize = ubTiling->syncSize;

    // Per-core tables.
    for (int32_t core = 0; core < MAX_CORE_NUM; core++) {
        tilingData->blockNumPerCore[core] = ubTiling->blockNumPerCore[core];
        tilingData->beginBlockIndexPerCore[core] = ubTiling->beginBlockIndexPerCore[core];
        tilingData->offSetPerCore[core] = ubTiling->offSetPerCore[core];
    }

    tilingData->cumSumInt64 = ubTiling->cumSumInt64;
}
extern "C" __global__ __aicore__ void gating(GM_ADDR topk, GM_ADDR idxArr,
GM_ADDR tokenIndex, GM_ADDR cumSum,
GM_ADDR originalIndex, GM_ADDR validIndex,
GM_ADDR globalSortWorkspace, GM_ADDR cumSumWorkspace,
GM_ADDR syncWorkspace, GM_ADDR tiling)
{
AtbOps::GatingTilingData tilingData;
InitGatingTilingData(tiling, &tilingData);
if (TILING_KEY_IS(2000000000)) {
Gating<int32_t> op;
op.Init(topk, idxArr, tokenIndex, cumSum, originalIndex, globalSortWorkspace,
cumSumWorkspace, syncWorkspace, &tilingData);
op.Process();
}
if (TILING_KEY_IS(2000000001)) {
Gating<int64_t> op;
op.Init(topk, idxArr, tokenIndex, cumSum, originalIndex, globalSortWorkspace,
cumSumWorkspace, syncWorkspace, &tilingData);
op.Process();
}
}