* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "kernel_operator.h"
#include "mixkernels/gating/tiling/tiling_data.h"
using namespace AscendC;
constexpr int32_t TOKEN_NUM = 1024;
constexpr int32_t BUFFER_NUM = 1;
constexpr int32_t FLOAT_DATA_SIZE = 4;
constexpr int32_t TAIL_COPY_FLAG = 2;
constexpr int64_t TILE_NUM = 2048;
constexpr int64_t DOUBLE_TILE_NUM = 2 * TILE_NUM;
constexpr int64_t DOUBLE = 2;
constexpr int64_t THREE = 3;
constexpr int64_t GATHER_MASK_BEFORE_TILE_COUNT = 4;
constexpr int64_t GATHER_MASK_AFTER_TILE_COUNT = 2;
constexpr int32_t INT32_SIZE = sizeof(int32_t);
constexpr int32_t BYTES32 = 32;
constexpr int32_t SYNC_UB_BYTES = 1 * 1024;
constexpr uint16_t GLOBAL_SORT_NUM = 2560;
constexpr uint16_t GLOBAL_SORT_UB_HALF_LENGTH = 20 * 1024;
constexpr int32_t GLOBAL_SORT_SIZE = TILE_NUM * 4 * 2;
constexpr uint32_t DEVIC_EEXPERT_NUM_OFFSET = 9;
constexpr uint32_t CUM_SUM_NUM_OFFSET = 3;
struct CopyInfo {
uint32_t copyTimes;
uint32_t tailNum;
int32_t startCopyLocation;
int32_t completedCopy;
};
__aicore__ inline void InitGatingTilingData(const __gm__ uint8_t *tiling, AtbOps::GatingTilingData *tilingData)
{
const __gm__ uint32_t *src = (const __gm__ uint32_t *)tiling;
uint32_t *dst = (uint32_t *)tilingData;
int32_t deviceExpertNum = *(src + DEVIC_EEXPERT_NUM_OFFSET);
int32_t cumSumGm = *(src + CUM_SUM_NUM_OFFSET);
if (deviceExpertNum == cumSumGm) {
deviceExpertNum = 0;
}
for (auto i = 0; i < sizeof(AtbOps::GatingTilingData) / sizeof(int32_t) - 1025 + deviceExpertNum; i++) {
*(dst + i) = *(src + i);
}
}
template <typename CumSumNumType>
class GatingEP {
public:
__aicore__ inline GatingEP() {}
__aicore__ inline void Init(GM_ADDR topk, GM_ADDR idxArr,
GM_ADDR tokenIndex, GM_ADDR CumSum,
GM_ADDR originalIndex, GM_ADDR validIndex,
GM_ADDR globalSortWorkspace, GM_ADDR cumSumWorkspace,
GM_ADDR syncWorkspace, AtbOps::GatingTilingData *tiling_data)
{
InitParams(tiling_data);
int32_t selectNum = deviceExpertNum > 0 ? deviceExpertNum : expertNum;
topkGm.SetGlobalBuffer((__gm__ int32_t *)topk, topKNum);
idxArrGm.SetGlobalBuffer((__gm__ int32_t *)idxArr, topKNum);
tokenIndexGm.SetGlobalBuffer((__gm__ int32_t *)tokenIndex, topKNum);
cumSumGm.SetGlobalBuffer((__gm__ CumSumNumType *)CumSum, selectNum);
originalGm.SetGlobalBuffer((__gm__ int32_t *)originalIndex, topKNum);
if (deviceExpertNum > 0) {
validIndexGm.SetGlobalBuffer((__gm__ int32_t *)validIndex, 1);
} else {
validIndexGm.SetGlobalBuffer((__gm__ int32_t *)validIndex, 0);
}
globalSortBlock.SetGlobalBuffer((__gm__ float *)globalSortWorkspace, 2 * topKNumPadded);
globalSortBlock2.SetGlobalBuffer((__gm__ float *)globalSortWorkspace + 2 * topKNumPadded, 2 * topKNumPadded);
cumSumBlock.SetGlobalBuffer((__gm__ int32_t *)cumSumWorkspace, actualCoreNum * expertNum);
syncGm.SetGlobalBuffer((__gm__ int32_t *)syncWorkspace, syncSize);
pipe.InitBuffer(inQueueTopK, BUFFER_NUM, TILE_NUM * INT32_SIZE);
pipe.InitBuffer(inQueueIdxArr, BUFFER_NUM, TILE_NUM * INT32_SIZE);
pipe.InitBuffer(outQueueGroupSortToken, BUFFER_NUM, GLOBAL_SORT_UB_HALF_LENGTH * INT32_SIZE);
pipe.InitBuffer(outQueueSortToken, BUFFER_NUM, GLOBAL_SORT_UB_HALF_LENGTH * INT32_SIZE);
pipe.InitBuffer(syncTQue, BUFFER_NUM, SYNC_UB_BYTES);
}
__aicore__ inline void Process()
{
if (blockIdx == actualCoreNum) {
if (expertNum > 0) {
for (int32_t i = 0; i < actualCoreNum; ++i) {
auto sync_buf = syncTQue.AllocTensor<int32_t>();
IBWait(syncGm, sync_buf, i, actualCoreNum);
syncTQue.FreeTensor(sync_buf);
}
calculateAndCopy2CumSumGm();
}
} else if (blockIdx > 0) {
PartSort();
auto sync_buf = syncTQue.AllocTensor<int32_t>();
IBSet(syncGm, sync_buf, blockIdx, 0);
syncTQue.FreeTensor(sync_buf);
if (expertNum > 0) {
PartCumSum();
auto sync_buf2 = syncTQue.AllocTensor<int32_t>();
IBSet(syncGm, sync_buf2, blockIdx, actualCoreNum);
syncTQue.FreeTensor(sync_buf2);
}
} else {
PartSort();
if (expertNum > 0) {
PartCumSum();
auto sync_buf2 = syncTQue.AllocTensor<int32_t>();
IBSet(syncGm, sync_buf2, blockIdx, actualCoreNum);
syncTQue.FreeTensor(sync_buf2);
}
if (actualCoreNum > 1) {
for (int32_t i = 1; i < actualCoreNum; ++i) {
auto sync_buf = syncTQue.AllocTensor<int32_t>();
IBWait(syncGm, sync_buf, i, 0);
syncTQue.FreeTensor(sync_buf);
}
}
PipeBarrier<PIPE_ALL>();
bool switchFlag = false;
GlobalSort(switchFlag);
PipeBarrier<PIPE_ALL>();
if (expertNum > 0) {
auto sync_buf_read = syncTQue.AllocTensor<int32_t>();
IBWait(syncGm, sync_buf_read, actualCoreNum, 1);
syncTQue.FreeTensor(sync_buf_read);
}
CopyGm2Gm(switchFlag ? globalSortBlock2 : globalSortBlock, originalGm);
}
}
private:
struct GmsParams {
int (&gmsLengths)[4];
int (&gmsCurrentHead)[4];
int &queueNum;
LocalTensor<float> &srcLocalTensor;
LocalTensor<float> &dstLocalTensor;
bool &gmTensorIndex;
GlobalTensor<float> (&buffLocal)[2];
};
__aicore__ inline void InitParams(AtbOps::GatingTilingData *tiling_data)
{
blockIdx = GetBlockIdx();
topkExpertNum = tiling_data->topkExpertNum;
topKNum = tiling_data->topKNum;
expertNum = tiling_data->cumSumNum;
cumSumNum32BytesPadded = tiling_data->cumSumNum32BytesPadded;
actualCoreNum = tiling_data->actualCoreNum;
tailBlockDataSize = tiling_data->tailBlockDataSize;
syncSize = tiling_data->syncSize;
blockNumPerCore = tiling_data->blockNumPerCore[blockIdx];
offSet = tiling_data->offSetPerCore[blockIdx];
topKNumPadded = tiling_data->topKNumPadded;
deviceExpertNum = tiling_data->deviceExpertNum;
for (int i = 0; i < deviceExpertNum; i++) {
deviceExpert[i] = tiling_data->deviceExpert[i];
}
}
__aicore__ inline void PartSort()
{
int32_t executeTimes = blockNumPerCore;
int32_t tailNum = blockIdx == (actualCoreNum - 1) ? tailBlockDataSize : TILE_NUM;
for (uint32_t i = 0; i < executeTimes; i++) {
uint32_t processNum = i == executeTimes - 1 ? tailNum : TILE_NUM;
CopyIn(i, processNum);
Compute(i, processNum);
CopyOut(i, processNum);
PipeBarrier<PIPE_ALL>();
}
PipeBarrier<PIPE_ALL>();
}
__aicore__ inline void PartCumSum()
{
if (expertNum > 0) {
ComputeCumSumPart();
PipeBarrier<PIPE_ALL>();
}
}
__aicore__ inline void ComputeCumSumPart()
{
LocalTensor<int32_t> cumSumPartLocalTensor = outQueueGroupSortToken.AllocTensor<int32_t>();
for (int i = 0; i < expertNum; i++) {
cumSumPartLocalTensor.SetValue(i, selectedExpertCount[i]);
}
DataCopyExtParams copyParams1{1, static_cast<uint32_t>(expertNum * INT32_SIZE), 0, 0, 0};
PipeBarrier<PIPE_ALL>();
DataCopyPad(cumSumBlock[GetBlockIdx() * expertNum], cumSumPartLocalTensor, copyParams1);
PipeBarrier<PIPE_ALL>();
outQueueGroupSortToken.FreeTensor(cumSumPartLocalTensor);
}
__aicore__ inline void calculateAndCopy2CumSumGm()
{
LocalTensor<int32_t> all = outQueueGroupSortToken.AllocTensor<int32_t>();
LocalTensor<int32_t> accumulator = all[0];
LocalTensor<int32_t> src0 = all[cumSumNum32BytesPadded];
LocalTensor<int32_t> src1 = all[cumSumNum32BytesPadded + cumSumNum32BytesPadded];
DataCopyPadExtParams<int32_t> padParams{false, 0, static_cast<uint8_t>(0), 0};
DataCopyExtParams copyParams{1, static_cast<uint32_t>(expertNum * INT32_SIZE), 0, 0, 0};
PipeBarrier<PIPE_ALL>();
DataCopyPad(accumulator, cumSumBlock, copyParams, padParams);
PipeBarrier<PIPE_ALL>();
for (int i = 1; i < actualCoreNum; i++) {
DataCopyExtParams copyParams0{1, static_cast<uint32_t>(expertNum * INT32_SIZE), 0, 0, 0};
DataCopyPad(src0, cumSumBlock[i * expertNum], copyParams0, padParams);
PipeBarrier<PIPE_ALL>();
DataCopy(src1, accumulator, cumSumNum32BytesPadded);
PipeBarrier<PIPE_ALL>();
Add(accumulator, src0, src1, expertNum);
}
PipeBarrier<PIPE_ALL>();
LocalTensor<CumSumNumType> cumSumLocalTensor = outQueueSortToken.AllocTensor<CumSumNumType>();
int32_t acc = 0;
int32_t selectNum = deviceExpertNum > 0 ? deviceExpertNum : expertNum;
for (int i = 0; i < selectNum; i++) {
int chosenIndex = deviceExpertNum > 0 && deviceExpertNum < expertNum
? deviceExpert[i] : i;
acc = acc + accumulator.GetValue(chosenIndex);
cumSumLocalTensor.SetValue(i, static_cast<CumSumNumType>(acc));
}
PipeBarrier<PIPE_ALL>();
DataCopyExtParams copyParams1{1, static_cast<uint32_t>(expertNum * INT32_SIZE), 0, 0, 0};
DataCopyPad(cumSumBlock, accumulator, copyParams1);
if (sizeof(CumSumNumType) == INT32_SIZE) {
DataCopyExtParams copyParams2{1, static_cast<uint32_t>(selectNum * INT32_SIZE), 0, 0, 0};
DataCopyPad(cumSumGm, cumSumLocalTensor, copyParams2);
} else {
DataCopyExtParams copyParams2{1, static_cast<uint32_t>(selectNum * sizeof(CumSumNumType)), 0, 0, 0};
DataCopyPadUB2GMImpl((__gm__ uint32_t*)cumSumGm.GetPhyAddr(),
(__ubuf__ uint32_t*)cumSumLocalTensor.GetPhyAddr(),
copyParams2);
}
PipeBarrier<PIPE_ALL>();
outQueueGroupSortToken.FreeTensor(all);
outQueueSortToken.FreeTensor(cumSumLocalTensor);
auto sync_buf_write = syncTQue.AllocTensor<int32_t>();
IBSet(syncGm, sync_buf_write, actualCoreNum, 1);
syncTQue.FreeTensor(sync_buf_write);
}
__aicore__ inline void GlobalSort(bool &switchFlag)
{
LocalTensor<float> srcLocalTensor = outQueueGroupSortToken.AllocTensor<float>();
LocalTensor<float> dstLocalTensor = outQueueSortToken.AllocTensor<float>();
GlobalTensor<float> sortedGlobal[2] = {globalSortBlock, globalSortBlock2};
PipeBarrier<PIPE_ALL>();
int orderBlock = topKNumPadded / TILE_NUM;
int globalSortBlockCount = 4;
int length[4];
int currentHead[4];
int blockSize = 1;
for (blockSize = 1; blockSize < orderBlock; blockSize *= globalSortBlockCount) {
for (int tileIndex = 0; tileIndex < orderBlock; tileIndex += blockSize * globalSortBlockCount) {
int mrgTileNum = orderBlock - tileIndex < blockSize * globalSortBlockCount
? (orderBlock - tileIndex) : (blockSize * globalSortBlockCount);
int queueNum = (mrgTileNum + blockSize - 1) / blockSize;
uint16_t lastQueTileNum = mrgTileNum % blockSize == 0 ? blockSize : mrgTileNum % blockSize;
for (int i = 0; i < queueNum; i++) {
currentHead[i] = TILE_NUM * (tileIndex + i * blockSize);
}
for (int i = 0; i < queueNum - 1; i++) {
length[i] = TILE_NUM * blockSize;
}
length[queueNum-1] = TILE_NUM * lastQueTileNum;
GmsParams params{length, currentHead, queueNum,
srcLocalTensor, dstLocalTensor, switchFlag, sortedGlobal};
PipeBarrier<PIPE_ALL>();
GlobalMrgSort(params);
PipeBarrier<PIPE_ALL>();
}
PipeBarrier<PIPE_ALL>();
switchFlag = !switchFlag;
}
PipeBarrier<PIPE_ALL>();
outQueueGroupSortToken.FreeTensor(srcLocalTensor);
outQueueSortToken.FreeTensor(dstLocalTensor);
}
__aicore__ inline void GlobalMrgSort(GmsParams & params)
{
int queueNum[4] {params.gmsLengths[0], params.gmsLengths[1],
params.gmsLengths[2], params.gmsLengths[3]};
int queueLength = params.queueNum;
int currentHead[4] {params.gmsCurrentHead[0], params.gmsCurrentHead[1],
params.gmsCurrentHead[2], params.gmsCurrentHead[3]};
int totalMrgLen = 0;
int originalPosition = currentHead[0];
LocalTensor<float> srcLocalTensor {params.srcLocalTensor};
LocalTensor<float> dstLocalTensor {params.dstLocalTensor};
GlobalTensor<float> srcGmTensor = params.buffLocal[params.gmTensorIndex];
GlobalTensor<float> dstGmTensor = params.buffLocal[!params.gmTensorIndex];
PipeBarrier<PIPE_ALL>();
if (queueLength == 1) {
int repeatTimes = (queueNum[0] + TILE_NUM - 1) / TILE_NUM;
int tailNum = queueNum[0] % TILE_NUM == 0 ? TILE_NUM : queueNum[0] % TILE_NUM;
for (int i = 0; i < repeatTimes; i++) {
int executeNum = i == repeatTimes - 1
? tailNum
: TILE_NUM;
DataCopyPad(dstLocalTensor, srcGmTensor[DOUBLE * currentHead[0] + DOUBLE * i * TILE_NUM],
{1, static_cast<uint32_t>(DOUBLE * executeNum * INT32_SIZE), 0, 0, 0},
{false, 0, 0, 0});
PipeBarrier<PIPE_ALL>();
DataCopyPad(dstGmTensor[DOUBLE * currentHead[0] + DOUBLE * i * TILE_NUM], dstLocalTensor,
{1, static_cast<uint32_t>(DOUBLE * executeNum * INT32_SIZE), 0, 0, 0});
PipeBarrier<PIPE_ALL>();
}
PipeBarrier<PIPE_ALL>();
return;
}
while (queueLength > 1) {
uint16_t tmpSortLen[4];
for (int i = 0; i < queueLength; i++) {
int sortLength = queueNum[i] < maxSortLengthArr[i] ? queueNum[i] : maxSortLengthArr[i];
tmpSortLen[i] = sortLength;
int gmStartPosition = currentHead[i];
DataCopyPad(srcLocalTensor[i * GLOBAL_SORT_NUM * DOUBLE], srcGmTensor[DOUBLE * gmStartPosition],
{1, static_cast<uint32_t>(DOUBLE * sortLength * INT32_SIZE), 0, 0, 0},
{false, 0, 0, 0});
PipeBarrier<PIPE_ALL>();
}
PipeBarrier<PIPE_ALL>();
MrgSort4Info params{tmpSortLen, true, validQueueArr[queueLength], 1};
PipeBarrier<PIPE_ALL>();
MrgSort<float>(dstLocalTensor,
{srcLocalTensor[0], srcLocalTensor[GLOBAL_SORT_NUM * DOUBLE],
srcLocalTensor[GLOBAL_SORT_NUM * DOUBLE * DOUBLE],
srcLocalTensor[GLOBAL_SORT_NUM * DOUBLE * 3]},
params);
PipeBarrier<PIPE_ALL>();
uint16_t sortedLen[4];
GetMrgSortResult(sortedLen[0], sortedLen[1], sortedLen[2], sortedLen[3]);
const uint16_t localMrgLen = sortedLen[0] + sortedLen[1] + sortedLen[2] + sortedLen[3];
PipeBarrier<PIPE_ALL>();
DataCopyPad(dstGmTensor[DOUBLE * originalPosition + DOUBLE * totalMrgLen], dstLocalTensor,
{1, static_cast<uint32_t>(DOUBLE * localMrgLen * INT32_SIZE), 0, 0, 0});
PipeBarrier<PIPE_ALL>();
totalMrgLen += localMrgLen;
PipeBarrier<PIPE_ALL>();
for (int i = 0; i < queueLength; i++) {
queueNum[i] -= sortedLen[i];
currentHead[i] += sortedLen[i];
}
PipeBarrier<PIPE_ALL>();
for (int i = 0; i < queueLength; i++) {
if (queueNum[i] == 0) {
for (int j = i; j < 3; j++) {
queueNum[j] = queueNum[j + 1];
currentHead[j] = currentHead[j + 1];
}
queueNum[queueLength - 1] = 0;
queueLength -= 1;
break;
}
}
PipeBarrier<PIPE_ALL>();
}
PipeBarrier<PIPE_ALL>();
if (queueNum[0] > 0) {
int repeatTimes = (queueNum[0] + TILE_NUM - 1) / TILE_NUM;
int tailNum = queueNum[0] % TILE_NUM == 0 ? TILE_NUM : queueNum[0] % TILE_NUM;
for (int i = 0; i < repeatTimes; i++) {
int executeNum = i == repeatTimes - 1
? tailNum
: TILE_NUM;
DataCopyPad(dstLocalTensor, srcGmTensor[DOUBLE * currentHead[0] + DOUBLE * i * TILE_NUM],
{1, static_cast<uint32_t>(DOUBLE * executeNum * INT32_SIZE), 0, 0, 0},
{false, 0, 0, 0});
PipeBarrier<PIPE_ALL>();
DataCopyPad(dstGmTensor[DOUBLE * originalPosition + DOUBLE * totalMrgLen + DOUBLE * i * TILE_NUM],
dstLocalTensor, {1, static_cast<uint32_t>(DOUBLE * executeNum * INT32_SIZE), 0, 0, 0});
PipeBarrier<PIPE_ALL>();
}
PipeBarrier<PIPE_ALL>();
}
}
__aicore__ inline void CopyTile2Gm(const CopyInfo& info, LocalTensor<float> &tmp, GlobalTensor<float> &src)
{
uint32_t copyTimes = info.copyTimes;
uint32_t tailNum = info.tailNum;
int32_t startCopyLocation = info.startCopyLocation;
int32_t completedCopy = info.completedCopy;
uint64_t rsvdCnt = 0;
LocalTensor<float> gatherMaskTmpLocal = tmp[DOUBLE_TILE_NUM];
LocalTensor<float> divFloatLocal = tmp[DOUBLE * DOUBLE_TILE_NUM];
float topkExpertNumFloat = topkExpertNum;
Duplicate(divFloatLocal, topkExpertNumFloat, TILE_NUM);
for (int i = 0; i < copyTimes; i++) {
PipeBarrier<PIPE_ALL>();
DataCopy(tmp, src[startCopyLocation * 2 + i * DOUBLE_TILE_NUM], DOUBLE_TILE_NUM);
PipeBarrier<PIPE_ALL>();
const uint8_t PATTERN_ODD = 2;
const bool REDUCE_MODE_NORMAL = false;
const uint32_t MASK_NORMAL = 0;
const uint8_t src0BlockStride = 1;
const uint32_t repeatTimes = static_cast<uint32_t>((DOUBLE_TILE_NUM + 63) / 64);
const uint32_t src0RepeatStride = 8;
const uint8_t src1RepeatStride = 0;
GatherMask(gatherMaskTmpLocal, tmp, PATTERN_ODD, REDUCE_MODE_NORMAL, MASK_NORMAL,
{src0BlockStride, repeatTimes, src0RepeatStride, src1RepeatStride}, rsvdCnt);
LocalTensor<int32_t> gatherMaskTmpLocalInt = gatherMaskTmpLocal.ReinterpretCast<int32_t>();
PipeBarrier<PIPE_ALL>();
if (i == copyTimes - 1) {
DataCopyExtParams copyParams{1, static_cast<uint32_t>(tailNum * INT32_SIZE), 0, 0, 0};
DataCopyPad(originalGm[completedCopy + i * TILE_NUM], gatherMaskTmpLocalInt, copyParams);
} else {
DataCopy(originalGm[completedCopy + i * TILE_NUM], gatherMaskTmpLocalInt, TILE_NUM);
}
PipeBarrier<PIPE_ALL>();
if (expertNum > 0) {
LocalTensor<float> gatherMaskTmpLocalFloat = gatherMaskTmpLocalInt.ReinterpretCast<float>();
Cast(gatherMaskTmpLocalFloat, gatherMaskTmpLocalInt, RoundMode::CAST_FLOOR, TILE_NUM);
PipeBarrier<PIPE_V>();
Div(gatherMaskTmpLocalFloat, gatherMaskTmpLocalFloat, divFloatLocal, TILE_NUM);
PipeBarrier<PIPE_V>();
Cast(gatherMaskTmpLocalInt, gatherMaskTmpLocalFloat, RoundMode::CAST_FLOOR, TILE_NUM);
PipeBarrier<PIPE_ALL>();
if (i == copyTimes - 1) {
DataCopyExtParams copyParams{1, static_cast<uint32_t>(tailNum * INT32_SIZE), 0, 0, 0};
DataCopyPad(tokenIndexGm[completedCopy + i * TILE_NUM], gatherMaskTmpLocalInt, copyParams);
PipeBarrier<PIPE_ALL>();
} else {
DataCopyExtParams copyParams{1, static_cast<uint32_t>(TILE_NUM * INT32_SIZE), 0, 0, 0};
DataCopyPad(tokenIndexGm[completedCopy + i * TILE_NUM], gatherMaskTmpLocalInt, copyParams);
PipeBarrier<PIPE_ALL>();
}
}
}
}
__aicore__ inline void CopyGm2Gm(GlobalTensor<float> &srcLocal, GlobalTensor<int32_t> &dstLocal)
{
uint32_t copyTimes = (topKNum / TILE_NUM) + (topKNum % TILE_NUM == 0 ? 0 : 1);
LocalTensor<float> tmpLocal = outQueueSortToken.AllocTensor<float>();
int32_t validIndexNum = topKNum;
if (deviceExpertNum > 0 && deviceExpertNum < expertNum) {
validIndexNum = cumSumBlock.GetValue(deviceExpert[0]);
int32_t previousExpert = deviceExpert[0];
int32_t currentExpert = previousExpert;
int32_t cumSumFull = 0;
for (int i = 0; i <= previousExpert; i++) {
cumSumFull += cumSumBlock.GetValue(i);
}
int32_t startCopyLocation = cumSumFull - cumSumBlock.GetValue(previousExpert);
int32_t endCopyLocation = cumSumFull;
int32_t completedCopy = 0;
for (int i = 1; i < deviceExpertNum + 1; i++) {
if (i < deviceExpertNum) {
validIndexNum += cumSumBlock.GetValue(deviceExpert[i]);
currentExpert = deviceExpert[i];
} else {
currentExpert = previousExpert + TAIL_COPY_FLAG;
}
for (int j = previousExpert + 1; j <= currentExpert; j++) {
cumSumFull += cumSumBlock.GetValue(j);
}
if (currentExpert == previousExpert + 1) {
endCopyLocation = cumSumFull;
} else {
uint32_t copyTimes = (endCopyLocation - startCopyLocation) / TILE_NUM +
((endCopyLocation - startCopyLocation) % TILE_NUM == 0 ? 0 : 1);
uint32_t tailNum = (endCopyLocation - startCopyLocation) % TILE_NUM == 0 ? TILE_NUM :
(endCopyLocation - startCopyLocation) % TILE_NUM;
struct CopyInfo info = {copyTimes, tailNum, startCopyLocation, completedCopy};
CopyTile2Gm(info, tmpLocal, srcLocal);
completedCopy += (endCopyLocation - startCopyLocation);
endCopyLocation = cumSumFull;
startCopyLocation = endCopyLocation - cumSumBlock.GetValue(currentExpert);
}
previousExpert = currentExpert;
}
} else {
uint32_t copyTimes = (topKNum / TILE_NUM) + (topKNum % TILE_NUM == 0 ? 0 : 1);
uint32_t tailNum = topKNum % TILE_NUM == 0 ? TILE_NUM : topKNum % TILE_NUM;
CopyTile2Gm({copyTimes, tailNum, 0, 0}, tmpLocal, srcLocal);
}
if (deviceExpertNum > 0) {
validIndexGm.SetValue(0, validIndexNum);
}
outQueueSortToken.FreeTensor(tmpLocal);
}
__aicore__ inline void CopyIn(uint32_t processIndex, uint32_t processNum)
{
uint32_t paddingNum = (processNum * INT32_SIZE) % 32 == 0 ?
0 : (32 - (processNum * INT32_SIZE) % 32) / INT32_SIZE;
LocalTensor<int32_t> topkLocal = inQueueTopK.AllocTensor<int32_t>();
DataCopyExtParams copyParams1{1, static_cast<uint32_t>(processNum * INT32_SIZE), 0, 0, 0};
DataCopyPadExtParams<int32_t> padParams1{true, 0, static_cast<uint8_t>(paddingNum), paddingValueInt};
DataCopyPad(topkLocal, topkGm[offSet + processIndex * TILE_NUM], copyParams1, padParams1);
PipeBarrier<PIPE_ALL>();
LocalTensor<int32_t> idxArrLocal = inQueueIdxArr.AllocTensor<int32_t>();
DataCopyExtParams copyParams2{1, static_cast<uint32_t>(processNum * sizeof(uint32_t)), 0, 0, 0};
DataCopyPadExtParams<int32_t> padParams2{true, 0, static_cast<uint8_t>(paddingNum), paddingValueInt};
DataCopyPad(idxArrLocal, idxArrGm[offSet + processIndex * TILE_NUM], copyParams2, padParams2);
PipeBarrier<PIPE_ALL>();
inQueueTopK.EnQue(topkLocal);
inQueueIdxArr.EnQue(idxArrLocal);
}
__aicore__ inline void Compute(uint32_t processIndex, uint32_t processNum)
{
uint32_t paddingNum = (processNum * INT32_SIZE) % 32 == 0 ?
0 : (32 - (processNum * INT32_SIZE) % 32) / INT32_SIZE;
LocalTensor<int32_t> topkLocalInt = inQueueTopK.DeQue<int32_t>();
LocalTensor<int32_t> idxArrLocalInt = inQueueIdxArr.DeQue<int32_t>();
LocalTensor<uint32_t> idxArrLocal = idxArrLocalInt.ReinterpretCast<uint32_t>();
if (expertNum > 0) {
for (uint32_t i = 0; i < processNum; i++) {
uint32_t expertIndex = topkLocalInt.GetValue(i);
selectedExpertCount[expertIndex]++;
}
}
PipeBarrier<PIPE_ALL>();
LocalTensor<float> groupSortLocal = outQueueGroupSortToken.AllocTensor<float>();
LocalTensor<float> sortLocal = outQueueSortToken.AllocTensor<float>();
Cast(sortLocal, topkLocalInt, RoundMode::CAST_FLOOR, processNum + paddingNum);
PipeBarrier<PIPE_ALL>();
Duplicate<float>(sortLocal[processNum + paddingNum], paddingValueFloat, TILE_NUM - (processNum + paddingNum));
Duplicate<uint32_t>(idxArrLocal[processNum + paddingNum], paddingValueUint,
TILE_NUM - (processNum + paddingNum));
float factor = -1;
Muls(sortLocal, sortLocal, factor, TILE_NUM);
PipeBarrier<PIPE_V>();
const int elementsPerSort = 32;
Sort32<float>(groupSortLocal, sortLocal, idxArrLocal, TILE_NUM / elementsPerSort);
PipeBarrier<PIPE_V>();
MergeSort4Queue(groupSortLocal, sortLocal);
PipeBarrier<PIPE_ALL>();
outQueueSortToken.EnQue<float>(sortLocal);
inQueueTopK.FreeTensor(topkLocalInt);
inQueueIdxArr.FreeTensor(idxArrLocal);
outQueueGroupSortToken.FreeTensor(groupSortLocal);
}
__aicore__ inline void MergeSort4Queue(LocalTensor<float> &groupSortLocal, LocalTensor<float> &sortLocal)
{
const uint16_t mergeCount = 4;
LocalTensor<float> sortedQue[2] = {groupSortLocal, sortLocal};
bool switchFlag = false;
uint16_t currentQueLength = 32;
uint16_t currentQueNum = TILE_NUM / currentQueLength;
while (currentQueLength < TILE_NUM) {
const uint16_t elementLengths[4] = {currentQueLength,
currentQueLength,
currentQueLength,
currentQueLength};
const uint16_t fullMrgSortTime = currentQueNum / mergeCount;
if (fullMrgSortTime > 0) {
MrgSort4Info params = {elementLengths, false, 0b1111, fullMrgSortTime};
MrgSort<float>(sortedQue[!switchFlag],
{sortedQue[switchFlag][0], sortedQue[switchFlag][currentQueLength * 1 * DOUBLE],
sortedQue[switchFlag][currentQueLength * 2 * DOUBLE],
sortedQue[switchFlag][currentQueLength * 3 * DOUBLE]},
params);
switchFlag = !switchFlag;
}
currentQueNum = fullMrgSortTime;
currentQueLength *= mergeCount;
}
PipeBarrier<PIPE_ALL>();
if (!switchFlag) {
DataCopy(sortLocal, groupSortLocal, DOUBLE_TILE_NUM);
}
}
__aicore__ inline void CopyOut(uint32_t processIndex, uint32_t processNum)
{
LocalTensor<float> sortLocal = outQueueSortToken.DeQue<float>();
PipeBarrier<PIPE_ALL>();
DataCopy(globalSortBlock[offSet * DOUBLE + processIndex * DOUBLE_TILE_NUM], sortLocal, DOUBLE_TILE_NUM);
PipeBarrier<PIPE_ALL>();
outQueueSortToken.FreeTensor(sortLocal);
}
private:
TPipe pipe;
GlobalTensor<int32_t> topkGm;
GlobalTensor<int32_t> idxArrGm;
GlobalTensor<int32_t> tokenIndexGm;
GlobalTensor<int32_t> originalGm;
GlobalTensor<CumSumNumType> cumSumGm;
GlobalTensor<int32_t> validIndexGm;
GlobalTensor<float> globalSortBlock;
GlobalTensor<float> globalSortBlock2;
GlobalTensor<int32_t> cumSumBlock;
GlobalTensor<int32_t> syncGm;
TQue<QuePosition::VECIN, BUFFER_NUM> inQueueTopK, inQueueIdxArr;
TQue<QuePosition::VECIN, BUFFER_NUM> syncTQue;
TQue<QuePosition::VECOUT, BUFFER_NUM> outQueueGroupSortToken;
TQue<QuePosition::VECOUT, BUFFER_NUM> outQueueSortToken;
float paddingValueFloat = 11300000.0;
int32_t paddingValueInt = 11300000;
uint32_t paddingValueUint = 11300000;
int32_t topkExpertNum = 0;
int64_t topKNum = -1;
int32_t expertNum = -1;
int32_t cumSumNum32BytesPadded = -1;
int32_t actualCoreNum = 1;
int32_t tailBlockDataSize = 0;
int32_t syncSize = 0;
int32_t blockNumPerCore = 0;
uint32_t offSet = 0;
int64_t topKNumPadded = 0;
int32_t blockIdx = 0;
int32_t deviceExpertNum = 0;
int32_t selectedExpertCount[1025] = {0};
int32_t deviceExpert[1025] = {0};
int maxSortLengthArr[4] = {2560, 2560, 2560, 2560};
uint16_t validQueueArr[5] = {0, 0, 0b11, 0b111, 0b1111};
};
extern "C" __global__ __aicore__ void gating(GM_ADDR topk, GM_ADDR idxArr,
GM_ADDR tokenIndex, GM_ADDR cumSum,
GM_ADDR originalIndex, GM_ADDR validIndex,
GM_ADDR globalSortWorkspace, GM_ADDR cumSumWorkspace,
GM_ADDR syncWorkspace, GM_ADDR tiling)
{
AtbOps::GatingTilingData tilingData;
InitGatingTilingData(tiling, &tilingData);
if (TILING_KEY_IS(2000000000)) {
GatingEP<int32_t> op;
op.Init(topk, idxArr, tokenIndex,
cumSum, originalIndex, validIndex,
globalSortWorkspace, cumSumWorkspace, syncWorkspace, &tilingData);
op.Process();
}
if (TILING_KEY_IS(2000000001)) {
GatingEP<int64_t> op;
op.Init(topk, idxArr, tokenIndex,
cumSum, originalIndex, validIndex,
globalSortWorkspace, cumSumWorkspace, syncWorkspace, &tilingData);
op.Process();
}
}