/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
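/*!
 * \file moe_dispatch.h
 * \brief Tile operators for the MoE dispatch stage: sending tokens to expert ranks, writing and polling
 *        flags, and copying received tokens and combine info out of the shared-memory windows.
 */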
#ifndef __DISTRIBUTED_DISPATCH__
#define __DISTRIBUTED_DISPATCH__
#include "common.h"
#include <type_traits>
namespace TileOp::Distributed {
/**
 * \brief Sends the tokens of one tile to the shared-memory window of each expert they are routed to.
 *
 * The expert table (tableRawShape0 * tableRawShape1 int32 entries) is loaded into expertTableUb once. For each
 * row in [tRowOffset, tRowOffset + tRowShape) the token row is loaded into tokenBuffer; for each column in
 * [tColOffset, tColOffset + tColShape) three int32 words {local rank id, row, col} are stored behind the token
 * payload and shmemDataLength elements of tokenBuffer are copied to the address mapped for the target rank.
 *
 * Target rank = expertId / shmemDataRawShape2, expert slot on that rank = expertId % shmemDataRawShape2, and the
 * token slot is the value returned by CalcOccurrencesVector(expertTableUb, expertId, tableIndex, expertBuffer).
 *
 * param, syncTensor, tableOffset0/1, shmemDataOffset0..3 and shmemDataRawShape0/1 are not read by this function.
 *
 * NOTE(review): the token load uses hOutSize / 32 bursts (rounded down) while the info words are placed at the
 * 32-byte-aligned payload size plus 32 int32 words; presumably axisH * sizeof(T) is a multiple of 32 - confirm.
 */
template <
    typename T, int32_t axisH, int32_t tRowOffset, int32_t tColOffset, int32_t tRowShape, int32_t tColShape,
    int32_t groupIndex>
TILEOP void SendToRoutingExpert(
    CoreFuncParam* param, __gm__ int32_t* syncTensor, __ubuf__ T* tokenBuffer, __ubuf__ int32_t* expertTableUb,
    __ubuf__ int32_t* expertBuffer, __gm__ T* token, __gm__ T* shmemDataBaseAddr, __gm__ int32_t* expertTable,
    uint32_t tableOffset0, uint32_t tableOffset1, uint32_t tableRawShape0, uint32_t tableRawShape1,
    uint32_t shmemDataOffset0, uint32_t shmemDataOffset1, uint32_t shmemDataOffset2, uint32_t shmemDataOffset3,
    uint32_t shmemDataRawShape0, uint32_t shmemDataRawShape1, uint32_t shmemDataRawShape2, uint32_t shmemDataRawShape3,
    __gm__ int64_t* hcclContext)
{
    // Bring the complete routing table into UB before any scalar lookup.
    const int32_t expertsPerToken = static_cast<int32_t>(tableRawShape1);
    const int32_t tableElems = static_cast<int32_t>(tableRawShape0) * static_cast<int32_t>(tableRawShape1);
    const int32_t tableBursts = AlignUp<int32_t>(tableElems * sizeof(int32_t), 32) / 32;
    copy_gm_to_ubuf(expertTableUb, expertTable, 0, 1, tableBursts, 0, 0);
    PIPE_SYNC_EVENT(PIPE_MTE2, PIPE_S, EVENT_ID0);

    __gm__ CommContext* commCtx = (__gm__ CommContext*)(hcclContext[groupIndex]);
    const uint64_t selfRank = static_cast<uint64_t>(commCtx->rankId);

    const int32_t hOutSize = axisH * sizeof(T);
    const int32_t tokenBursts = hOutSize / 32;
    // Number of T elements written to the remote window per token.
    const int32_t shmemDataLength = AlignUp<int32_t>(axisH, 512) + 512;
    const auto sendBursts = shmemDataLength * sizeof(T) / 32;

    // The three info words start 32 int32 words after the 32-byte-aligned payload.
    const int32_t payloadWords = AlignUp<int32_t>(hOutSize, 32) / sizeof(int32_t);
    const int32_t combineInfoOffset = 32;
    __ubuf__ int32_t* infoWords = reinterpret_cast<__ubuf__ int32_t*>(tokenBuffer) + (payloadWords + combineInfoOffset);

    // Element offset of the region reserved for this source rank inside every remote window.
    const uint64_t expertRegionElems = static_cast<uint64_t>(shmemDataRawShape3);
    const uint64_t selfRegionElems = selfRank * static_cast<uint64_t>(shmemDataRawShape2) * expertRegionElems;

    const int32_t rowEnd = tRowOffset + tRowShape;
    const int32_t colEnd = tColOffset + tColShape;
    for (int32_t row = tRowOffset; row < rowEnd; ++row) {
        copy_gm_to_ubuf(tokenBuffer, token + row * axisH, 0, 1, tokenBursts, 0, 0);
        PIPE_SYNC_EVENT(PIPE_MTE2, PIPE_S, EVENT_ID0);
        for (int32_t col = tColOffset; col < colEnd; ++col) {
            infoWords[0] = static_cast<int32_t>(selfRank);
            infoWords[1] = row;
            infoWords[2] = col;

            const int32_t tableIndex = row * expertsPerToken + col;
            const int32_t expertId = expertTableUb[tableIndex];
            const int32_t dstRank = expertId / static_cast<int32_t>(shmemDataRawShape2);
            const int32_t dstExpertSlot = expertId % static_cast<int32_t>(shmemDataRawShape2);
            const int32_t tokenSlot = CalcOccurrencesVector(expertTableUb, expertId, tableIndex, expertBuffer);

            __gm__ T* dstWindow = MapVirtualAddr<T>(hcclContext, shmemDataBaseAddr, static_cast<uint32_t>(dstRank));
            const uint64_t dstElems = selfRegionElems + static_cast<uint64_t>(dstExpertSlot) * expertRegionElems +
                                      static_cast<uint64_t>(tokenSlot) * static_cast<uint64_t>(shmemDataLength);
            __gm__ T* dstAddr = dstWindow + dstElems;

            PIPE_SYNC_EVENT(PIPE_S, PIPE_MTE3, EVENT_ID0);
            copy_ubuf_to_gm(dstAddr, tokenBuffer, 0, 1, sendBursts, 0, 0);
            PIPE_SYNC_EVENT(PIPE_MTE3, PIPE_S, EVENT_ID0);
        }
        // tokenBuffer is reloaded by the next row.
        PIPE_SYNC_EVENT(PIPE_MTE3, PIPE_MTE2, EVENT_ID0);
    }
}
/**
 * \brief Sends the token rows [tokenOffset0, tokenOffset0 + tileRowShape) to the window of the shared-expert rank.
 *
 * Each row is loaded into tokenBuffer, three int32 words {local rank id, row, 1} are stored directly behind the
 * 32-byte-aligned payload, and shmemDataRawShape3 elements are copied to the remote window at
 * (localRank * shmemDataRawShape2 + row) * shmemDataRawShape3 elements from the mapped base address.
 *
 * The byte offsets are computed in 64-bit arithmetic so that the product of rank, shmemDataRawShape2,
 * shmemDataRawShape3 and sizeof(T) cannot wrap at 32 bits before it is added to the pointer.
 *
 * param, syncTensor, tokenOffset1, tokenShape0/1, shmemDataOffset0..3 and shmemDataRawShape0/1 are not read.
 *
 * NOTE(review): shareRankSize is fixed to 1 here, so the divisor below is rankNum - 1; a communicator with a
 * single rank would divide by zero - confirm callers never launch this operator in that configuration.
 */
template <typename T, int32_t bs, int32_t axisH, int32_t tileRowShape, int32_t groupIndex>
TILEOP void SendToSharedExpert(
    CoreFuncParam* param, __gm__ int32_t* syncTensor, __ubuf__ T* tokenBuffer, __gm__ T* token,
    __gm__ T* shmemDataBaseAddr, uint32_t tokenOffset0, uint32_t tokenOffset1, uint32_t tokenShape0,
    uint32_t tokenShape1, uint32_t shmemDataOffset0, uint32_t shmemDataOffset1, uint32_t shmemDataOffset2,
    uint32_t shmemDataOffset3, uint32_t shmemDataRawShape0, uint32_t shmemDataRawShape1, uint32_t shmemDataRawShape2,
    uint32_t shmemDataRawShape3, __gm__ int64_t* hcclContext)
{
    __gm__ CommContext* winContext = (__gm__ CommContext*)(hcclContext[groupIndex]);
    int32_t localUsrRankId = static_cast<int32_t>(winContext->rankId);
    int32_t rankSize = static_cast<int32_t>(winContext->rankNum);
    constexpr int32_t hOutSize = axisH * sizeof(T);
    constexpr int32_t tokenQuantAlign32 = AlignUp<int32_t>(hOutSize, 32) / sizeof(int32_t);
    int32_t shareRankSize = 1;
    int32_t moeRankSize = rankSize - shareRankSize;
    int32_t shareOpProcessRankSize = moeRankSize / shareRankSize;
    __ubuf__ int32_t* tmpTokenBuffer = reinterpret_cast<__ubuf__ int32_t*>(tokenBuffer);

    // Widen before multiplying: these products are byte offsets into the remote window.
    const uint64_t slotBytes = static_cast<uint64_t>(shmemDataRawShape3) * sizeof(T);
    const uint64_t rankRegionBytes =
        static_cast<uint64_t>(localUsrRankId) * static_cast<uint64_t>(shmemDataRawShape2) * slotBytes;

    for (int32_t row = tokenOffset0; row < tokenOffset0 + tileRowShape; ++row) {
        copy_gm_to_ubuf(tokenBuffer, token + row * axisH, 0, 1, hOutSize / 32, 0, 0);
        PIPE_SYNC_EVENT(PIPE_MTE2, PIPE_S, EVENT_ID0);
        // Info words stored right behind the aligned payload.
        tmpTokenBuffer[tokenQuantAlign32] = localUsrRankId;
        tmpTokenBuffer[tokenQuantAlign32 + 1] = row;
        tmpTokenBuffer[tokenQuantAlign32 + 2] = 1;
        uint32_t remoteRankId = (localUsrRankId - shareRankSize) / shareOpProcessRankSize;
        GM_ADDR remoteShmemBaseAddr = (GM_ADDR)MapVirtualAddr<T>(hcclContext, shmemDataBaseAddr, remoteRankId);
        GM_ADDR remoteShmemDataAddr =
            remoteShmemBaseAddr + rankRegionBytes + static_cast<uint64_t>(row) * slotBytes;
        // The copy-out must see both the loaded payload (MTE2) and the scalar info words (S).
        PIPE_SYNC_EVENT(PIPE_MTE2, PIPE_MTE3, EVENT_ID0);
        PIPE_SYNC_EVENT(PIPE_S, PIPE_MTE3, EVENT_ID0);
        copy_ubuf_to_gm(remoteShmemDataAddr, tokenBuffer, 0, 1, shmemDataRawShape3 * sizeof(T) / 32, 0, 0);
        PIPE_SYNC_EVENT(PIPE_MTE3, PIPE_MTE2, EVENT_ID0);
    }
}
/**
 * \brief Copies the token rows [tokenOffset0, tokenOffset0 + tileRowShape) from token to expandX through UB.
 *
 * Every row is axisH elements of T and is moved with axisH * sizeof(T) / 32 bursts; the row keeps its index in
 * the destination. param, syncTensor, tokenOffset1, tokenShape0/1 and hcclContext are not read.
 */
template <typename T, int32_t bs, int32_t axisH, int32_t tileRowShape>
TILEOP void CopyToLocalExpert(
    CoreFuncParam* param, __gm__ T* expandX, __gm__ int32_t* syncTensor, __ubuf__ T* tokenBuffer, __gm__ T* token,
    uint32_t tokenOffset0, uint32_t tokenOffset1, uint32_t tokenShape0, uint32_t tokenShape1,
    __gm__ int64_t* hcclContext)
{
    constexpr int32_t rowBytes = axisH * sizeof(T);
    constexpr int32_t rowBursts = rowBytes / 32;

    int32_t row = tokenOffset0;
    while (row < tokenOffset0 + tileRowShape) {
        const auto rowStart = row * axisH;
        copy_gm_to_ubuf(tokenBuffer, token + rowStart, 0, 1, rowBursts, 0, 0);
        // The store may only start once the load into tokenBuffer has finished.
        PIPE_SYNC_EVENT(PIPE_MTE2, PIPE_MTE3, EVENT_ID0);
        copy_ubuf_to_gm(expandX + rowStart, tokenBuffer, 0, 1, rowBursts, 0, 0);
        // Hold the scalar pipe until the store is done before the next row is issued.
        PIPE_SYNC_EVENT(PIPE_MTE3, PIPE_S, EVENT_ID0);
        ++row;
    }
}
/**
 * \brief Writes a {1, token count} flag for every (rank, expert) pair of the tile into the remote flag windows.
 *
 * The bs * topK entry expert table is loaded into expertTableUb. For every rank in
 * [shmemFlagOffset0, shmemFlagOffset0 + rankShape) and every expert slot in
 * [shmemFlagOffset1, shmemFlagOffset1 + expertShape), word 0 of the status entry is set to 1 and word 1 to
 * CalcOccurrencesVector(expertTableUb, globalExpertId, bs * topK, expertBuffer), where
 * globalExpertId = expertSlot + rank * shmemFlagRawShape1. One burst of the entry is then copied to
 * flagWindow(rank) + expertSlot * shmemFlagRawShape2 * shmemFlagRawShape3 + localRank * shmemFlagRawShape3.
 *
 * The flag window is an int32_t buffer, so it is mapped with MapVirtualAddr<int32_t> regardless of the
 * element type T of expertTable.
 *
 * param, syncDummy and syncTensor are not read.
 *
 * NOTE(review): the status words are written with scalar stores, yet the event before the copy-out is
 * PIPE_V -> PIPE_MTE3, while SendToRoutingExpert and SendToSharedExpert use PIPE_S -> PIPE_MTE3 after their
 * scalar stores - confirm this is intended.
 */
template <typename T, int32_t bs, int32_t topK, int32_t groupIndex, int32_t expertShape, int32_t rankShape>
TILEOP void DispatchSetFlag(
    CoreFuncParam* param, __gm__ int32_t* syncDummy, __ubuf__ int32_t* statusTensor, __ubuf__ int32_t* expertTableUb,
    __ubuf__ int32_t* expertBuffer, __gm__ T* expertTable, __gm__ int32_t* shmemFlagBaseAddr,
    __gm__ int32_t* syncTensor, uint32_t shmemFlagOffset0, uint32_t shmemFlagOffset1, uint32_t shmemFlagOffset2,
    uint32_t shmemFlagOffset3, uint32_t shmemFlagRawShape0, uint32_t shmemFlagRawShape1, uint32_t shmemFlagRawShape2,
    uint32_t shmemFlagRawShape3, __gm__ int64_t* hcclContext)
{
    (void)shmemFlagOffset2;
    (void)shmemFlagOffset3;
    (void)shmemFlagRawShape0;
    __gm__ CommContext* winContext = (__gm__ CommContext*)(hcclContext[groupIndex]);
    int32_t localUsrRankId = static_cast<int32_t>(winContext->rankId);
    constexpr int32_t expertTblSize = bs * topK;
    constexpr int32_t lenBurst = AlignUp<int32_t>(expertTblSize * sizeof(int32_t), 32) / 32;
    copy_gm_to_ubuf(expertTableUb, expertTable, 0, 1, lenBurst, 0, 0);
    PIPE_SYNC_EVENT(PIPE_MTE2, PIPE_S, EVENT_ID0);
    for (int32_t rankId = shmemFlagOffset0; rankId < shmemFlagOffset0 + rankShape; ++rankId) {
        for (int32_t dstExpertId = shmemFlagOffset1; dstExpertId < shmemFlagOffset1 + expertShape; ++dstExpertId) {
            // Map with the flag element type, not with the expert-table element type.
            __gm__ int32_t* remoteFlagBaseAddr = MapVirtualAddr<int32_t>(hcclContext, shmemFlagBaseAddr, rankId);
            int32_t remoteExpertId = dstExpertId + rankId * static_cast<int32_t>(shmemFlagRawShape1);
            __gm__ int32_t* shmemFlagWriteAddr =
                remoteFlagBaseAddr +
                dstExpertId * static_cast<int32_t>(shmemFlagRawShape2) * static_cast<int32_t>(shmemFlagRawShape3) +
                localUsrRankId * static_cast<int32_t>(shmemFlagRawShape3);
            // Each expert slot owns 8 int32 words (one 32-byte burst) of statusTensor.
            statusTensor[dstExpertId * 8] = 1;
            statusTensor[dstExpertId * 8 + 1] =
                CalcOccurrencesVector(expertTableUb, remoteExpertId, expertTblSize, expertBuffer);
            PIPE_SYNC_EVENT(PIPE_V, PIPE_MTE3, EVENT_ID0);
            copy_ubuf_to_gm(shmemFlagWriteAddr, statusTensor + dstExpertId * 8, 0, 1, 1, 0, 0);
            PIPE_SYNC_EVENT(PIPE_MTE3, PIPE_S, EVENT_ID0);
        }
    }
}
/**
 * \brief Copies one burst from recvTokenCntAddr (UB) to the slot of tile tileIndex in outRecvTokenCntAddr (GM).
 *
 * Slots of consecutive tiles are 512 address units apart. totalTileNum is not read.
 */
TILEOP void CopyOutRecvTokenCnt(
    GM_ADDR outRecvTokenCntAddr, UB_ADDR recvTokenCntAddr, uint32_t tileIndex, uint32_t totalTileNum)
{
    GM_ADDR tileSlot = outRecvTokenCntAddr + tileIndex * 512;
    // A single burst of length 1 without strides.
    DataCopyParams oneBurst = {0, 1, 1, 0, 0};
    copy_ubuf_to_gm(
        tileSlot, recvTokenCntAddr, oneBurst.sid, oneBurst.nBurst, oneBurst.lenBurst, oneBurst.srcStride,
        oneBurst.dstStride);
    PIPE_SYNC_EVENT(PIPE_MTE3, PIPE_S, EVENT_ID0);
}
/**
 * \brief Runs GatherMaskAndSum with MASK_SELECT_SEND_COUNT and copies the first burst of src1 out to GM.
 *
 * After the call to GatherMaskAndSum, src1 is written to the slot of dispatchInfo.tileIndex in out via
 * CopyOutRecvTokenCnt.
 */
template <typename T>
TILEOP void ConstructOutRecvTokenCnt(
    __gm__ T* out, __ubuf__ uint32_t* src0, __ubuf__ uint32_t* src1, __ubuf__ uint32_t* dst, uint32_t cnt,
    __gm__ int64_t* hcclContext, DispatchInfo& dispatchInfo)
{
    GatherMaskAndSum(out, src0, src1, dst, MASK_SELECT_SEND_COUNT, cnt, hcclContext);
    CopyOutRecvTokenCnt(
        reinterpret_cast<GM_ADDR>(out), reinterpret_cast<UB_ADDR>(src1), dispatchInfo.tileIndex,
        dispatchInfo.totalTileNum);
}
/**
 * \brief Polls the flags of processRankSize ranks until src1[0] equals processRankSize, then publishes the
 *        received-token count of this tile.
 *
 * The flags are read starting at offset (expertIndex * rankNum + rankOffset) * 512. Each polling round calls
 * ReadFlagV2 followed by WaitFlagV2 and takes the new sum from src1[0]. Once the sum matches, the flags are read
 * one more time and handed to ConstructOutRecvTokenCnt.
 */
template <typename T>
TILEOP void MoeRankWaitFlag(
    __gm__ T* out, __ubuf__ uint32_t* src0, __ubuf__ uint32_t* src1, __ubuf__ uint32_t* dst,
    __gm__ int64_t* hcclContext, uint32_t processRankSize, __gm__ int32_t* shmemFlagBaseAddr,
    DispatchInfo& dispatchInfo)
{
    uint32_t expectedFlags = processRankSize;
    uint32_t flagOffset = dispatchInfo.expertIndex * dispatchInfo.rankNum * 512 + dispatchInfo.rankOffset * 512;

    for (uint32_t arrived = 0; arrived != expectedFlags; arrived = src1[0]) {
        ReadFlagV2<T>(src0, flagOffset, expectedFlags, hcclContext, shmemFlagBaseAddr, dispatchInfo);
        WaitFlagV2<T>(out, src0, src1, dst, expectedFlags, hcclContext);
    }

    // Re-read the flags so that src0 holds the final values used for the token count.
    ReadFlagV2<T>(src0, flagOffset, expectedFlags, hcclContext, shmemFlagBaseAddr, dispatchInfo);
    ConstructOutRecvTokenCnt<T>(out, src0, src1, dst, expectedFlags, hcclContext, dispatchInfo);
}
/**
 * \brief On tile 0 only, polls the flags of processRankSize ranks until src1[0] equals processRankSize and then
 *        clears them.
 *
 * Tiles with a non-zero dispatchInfo.tileIndex return immediately. The flags are read starting at offset
 * (tileIndex * processRankSize + shareRankCnt) * 512 and are reset with ClearFlagV2 once all have arrived.
 */
template <typename T>
TILEOP void ShareRankWaitFlag(
    __gm__ T* out, __ubuf__ uint32_t* src0, __ubuf__ uint32_t* src1, __ubuf__ uint32_t* dst,
    __gm__ int64_t* hcclContext, uint32_t processRankSize, __gm__ int32_t* shmemFlagBaseAddr,
    DispatchInfo& dispatchInfo)
{
    if (dispatchInfo.tileIndex != 0) {
        return;
    }

    uint32_t firstMoeRank = dispatchInfo.tileIndex * processRankSize + dispatchInfo.shareRankCnt;
    uint32_t flagOffset = firstMoeRank * 512;
    uint32_t expectedFlags = processRankSize;

    for (uint32_t arrived = 0; arrived != expectedFlags; arrived = src1[0]) {
        ReadFlagV2<T>(src0, flagOffset, expectedFlags, hcclContext, shmemFlagBaseAddr, dispatchInfo);
        WaitFlagV2<T>(out, src0, src1, dst, expectedFlags, hcclContext);
    }

    ClearFlagV2(
        reinterpret_cast<__ubuf__ int32_t*>(src0), flagOffset, expectedFlags, hcclContext, dispatchInfo,
        shmemFlagBaseAddr);
}
/**
 * \brief Entry point that waits for the dispatch flags of one tile and publishes its received-token count.
 *
 * Builds a DispatchInfo from the template parameters, flagShmemOffset1/2 and the rank count of the
 * communication context, carves three consecutive regions out of buffer and calls MoeRankWaitFlag:
 *   - flag region:       rankShape * 32 bytes
 *   - sum result region: 256 bytes
 *   - gather/sum scratch: starts right after the sum result region
 *
 * param, dummy, flagShmemOffset0, flagShmemOffset3 and flagShmemShape0..3 are not read.
 */
template <
    typename T, uint32_t tileIndex, uint32_t groupIndex, uint32_t shareRankCnt, uint32_t totalTileNum,
    int32_t rankShape, int32_t expertNum>
TILEOP void FFNSched(
    CoreFuncParam* param, __gm__ T* out, __ubuf__ int32_t* buffer, __gm__ int32_t* dummy,
    __gm__ int32_t* shmemFlagBaseAddr, uint32_t flagShmemOffset0, uint32_t flagShmemOffset1, uint32_t flagShmemOffset2,
    uint32_t flagShmemOffset3, uint32_t flagShmemShape0, uint32_t flagShmemShape1, uint32_t flagShmemShape2,
    uint32_t flagShmemShape3, __gm__ int64_t* hcclContext)
{
    __gm__ CommContext* commCtx = (__gm__ CommContext*)(hcclContext[groupIndex]);
    DispatchInfo info = {tileIndex, groupIndex, 0, 0, rankShape, static_cast<int32_t>(flagShmemOffset2), 0, 0, 0, 0,
        totalTileNum, shareRankCnt, expertNum, static_cast<int32_t>(commCtx->rankNum),
        static_cast<int32_t>(flagShmemOffset1)};

    uint32_t ranksToWaitFor = info.rankShape;

    // Byte layout of the scratch buffer.
    const uint32_t flagRegionBytes = ranksToWaitFor * 32;
    const uint32_t sumRegionBytes = 256;
    __ubuf__ uint8_t* scratch = reinterpret_cast<__ubuf__ uint8_t*>(buffer);
    __ubuf__ uint32_t* flags = reinterpret_cast<__ubuf__ uint32_t*>(scratch);
    __ubuf__ uint32_t* sumResult = reinterpret_cast<__ubuf__ uint32_t*>(scratch + flagRegionBytes);
    __ubuf__ uint32_t* sumDst = reinterpret_cast<__ubuf__ uint32_t*>(scratch + flagRegionBytes + sumRegionBytes);

    MoeRankWaitFlag<T>(out, flags, sumResult, sumDst, hcclContext, ranksToWaitFor, shmemFlagBaseAddr, info);
}
/**
 * \brief Loads the first burst of dispatchInfo.tileIndex consecutive 512-unit slots from src into recvTokenCnt.
 *
 * The copy uses nBurst = tileIndex, lenBurst = 1 and a source stride of 512 / 32 - 1 blocks, so the bursts land
 * back to back in UB. hcclContext and tileCnt are not read.
 *
 * NOTE(review): for tileIndex == 0 the copy is issued with nBurst == 0 - confirm the intrinsic accepts that.
 */
TILEOP void ReadRecvTokenCnt(
    __ubuf__ uint32_t* recvTokenCnt, __gm__ uint32_t* src, DispatchInfo& dispatchInfo, __gm__ int64_t* hcclContext,
    uint32_t tileCnt)
{
    GM_ADDR slotsBegin = reinterpret_cast<GM_ADDR>(src);
    // Skip the rest of each slot after the one block that is read.
    uint16_t blocksBetweenBursts = 512 / 32 - 1;
    DataCopyParams gatherParams = {
        0, static_cast<uint16_t>(dispatchInfo.tileIndex), 1, static_cast<uint16_t>(blocksBetweenBursts), 0};

    PIPE_SYNC_EVENT(PIPE_S, PIPE_MTE2, EVENT_ID0);
    copy_gm_to_ubuf(
        recvTokenCnt, slotsBegin, gatherParams.sid, gatherParams.nBurst, gatherParams.lenBurst,
        gatherParams.srcStride, gatherParams.dstStride);
    PIPE_SYNC_EVENT(PIPE_MTE2, PIPE_S, EVENT_ID0);
}
/**
 * \brief Sums, per expert, word 1 of the flags of all source ranks and writes the sums to validCnt.
 *
 * The flag window is the one mapped for rank shmemFlagOffset0. For every expert in
 * [shmemFlagOffset1, shmemFlagOffset1 + expertShape) and every rank in [0, shmemFlagRawShape0) one burst is read
 * from offset (expert * shmemFlagRawShape0 + rank) * 512 and flag[1] is accumulated. The expertShape sums are
 * staged in UB (32 int32 words behind the flag) and copied to validCnt + shmemFlagOffset1.
 *
 * param, gmRecvTokenCnt, shmemFlagOffset2/3 and shmemFlagRawShape1..3 are not read.
 */
template <typename T, int32_t expertShape>
TILEOP void FFNValidCnt(
    CoreFuncParam* param, __gm__ int32_t* validCnt, __ubuf__ int32_t* buffer, __gm__ int32_t* gmRecvTokenCnt,
    __gm__ int32_t* shmemFlagBaseAddr, uint32_t shmemFlagOffset0, uint32_t shmemFlagOffset1, uint32_t shmemFlagOffset2,
    uint32_t shmemFlagOffset3, uint32_t shmemFlagRawShape0, uint32_t shmemFlagRawShape1, uint32_t shmemFlagRawShape2,
    uint32_t shmemFlagRawShape3, __gm__ int64_t* hcclContext)
{
    int32_t selfRank = static_cast<int32_t>(shmemFlagOffset0);
    __gm__ int32_t* flagWindow = MapVirtualAddr<int32_t>(hcclContext, shmemFlagBaseAddr, selfRank);

    // UB layout: one flag burst first, the per-expert sums 32 int32 words later.
    __ubuf__ uint32_t* flag = reinterpret_cast<__ubuf__ uint32_t*>(buffer);
    uint32_t flagWords = 32;
    __ubuf__ int32_t* receiveCnt = reinterpret_cast<__ubuf__ int32_t*>(buffer + flagWords);
    __ubuf__ int32_t* nextSum = receiveCnt;

    DataCopyParams flagRead = {0, 1, 1, 15, 0};
    for (int32_t expertId = shmemFlagOffset1; expertId < shmemFlagOffset1 + expertShape; ++expertId) {
        uint32_t tokensForExpert = 0;
        for (int32_t rankId = 0; rankId < shmemFlagRawShape0; ++rankId) {
            uint32_t flagOffset = expertId * shmemFlagRawShape0 * 512 + rankId * 512;
            GM_ADDR flagAddr = (GM_ADDR)flagWindow + flagOffset;
            PIPE_SYNC_EVENT(PIPE_S, PIPE_MTE2, EVENT_ID0);
            copy_gm_to_ubuf(
                flag, flagAddr, flagRead.sid, flagRead.nBurst, flagRead.lenBurst, flagRead.srcStride,
                flagRead.dstStride);
            PIPE_SYNC_EVENT(PIPE_MTE2, PIPE_S, EVENT_ID0);
            tokensForExpert += flag[1];
        }
        pipe_barrier(PIPE_ALL);
        *nextSum++ = tokensForExpert;
    }

    PIPE_SYNC_EVENT(PIPE_S, PIPE_MTE3, EVENT_ID0);
    TileOp::UBCopyOut<int32_t, 1, expertShape, expertShape, expertShape>(validCnt + shmemFlagOffset1, receiveCnt);
    PIPE_SYNC_EVENT(PIPE_MTE3, PIPE_S, EVENT_ID0);
}
/**
 * \brief Gathers the combine-info words of all tokens received for one expert into combineInfo.
 *
 * For every source rank in [rankOffset, rankOffset + rankShape) the flag at offset
 * (expertIndex * rankSize + rank) * 512 is read; flag[1] is used as the number of tokens sent by that rank.
 * For each of those tokens MOE_COMBINE_INFO_NUM int32 words are copied, through UB, from the token's window slot
 * (AlignUp(colShape, 512) elements of T plus 32 int32 words behind the slot start) to consecutive
 * MOE_COMBINE_INFO_NUM-word records in combineInfo.
 *
 * Window layout used here: rank region = expertNumPerRank * bs * shmemLength elements, expert region =
 * bs * shmemLength elements, token slot = shmemLength elements. All byte and element offsets are computed in
 * 64-bit arithmetic so that the products cannot wrap at 32 bits before being added to a pointer.
 */
template <typename T>
TILEOP void CombineInfoCopyOut(
    __gm__ int32_t* combineInfo, __ubuf__ uint8_t* combineInfoBuffer, DispatchInfo& dispatchInfo,
    __gm__ int64_t* hcclContext, __gm__ T* shmemDataBaseAddr, __gm__ int32_t* shmemFlagBaseAddr, uint32_t rankSize,
    uint32_t bs, uint32_t shmemLength)
{
    __gm__ CommContext* winContext = (__gm__ CommContext*)(hcclContext[dispatchInfo.groupIndex]);
    uint32_t localUsrRankId = winContext->rankId;
    GM_ADDR localShmemBaseAddr = (GM_ADDR)MapVirtualAddr<T>(hcclContext, shmemDataBaseAddr, localUsrRankId);
    // UB layout: flag burst first, staging area for one info record 32 int32 words later.
    __ubuf__ int32_t* combineBuffer = reinterpret_cast<__ubuf__ int32_t*>(combineInfoBuffer);
    __ubuf__ uint32_t* flag = reinterpret_cast<__ubuf__ uint32_t*>(combineBuffer);
    __ubuf__ int32_t* buffer = reinterpret_cast<__ubuf__ int32_t*>(combineBuffer + 32);
    uint32_t tokenCnt = 0;
    int32_t combineInfoOffset = 32;

    // Widen before multiplying: these products are byte offsets into the local window.
    const uint64_t slotBytes = static_cast<uint64_t>(shmemLength) * sizeof(T);
    const uint64_t expertRegionBytes = static_cast<uint64_t>(bs) * slotBytes;
    const uint64_t rankRegionBytes = static_cast<uint64_t>(dispatchInfo.expertNumPerRank) * expertRegionBytes;
    const uint64_t expertOffsetBytes = static_cast<uint64_t>(dispatchInfo.expertIndex) * expertRegionBytes;
    const uint64_t infoOffsetBytes =
        AlignUp<uint64_t>(dispatchInfo.colShape, 512) * sizeof(T) + combineInfoOffset * sizeof(int32_t);
    // Distance between two token slots, in int32 words.
    const uint32_t shmemColBurst = shmemLength * sizeof(T) / sizeof(int32_t);

    for (int32_t rankId = dispatchInfo.rankOffset; rankId < dispatchInfo.rankOffset + dispatchInfo.rankShape;
         rankId++) {
        GM_ADDR thisRankExpertAddrBase = localShmemBaseAddr + static_cast<uint64_t>(rankId) * rankRegionBytes;
        GM_ADDR thisRankExpertTokenAddr = thisRankExpertAddrBase + expertOffsetBytes;
        __gm__ int32_t* thisExpertCombineAddr =
            combineInfo + static_cast<uint64_t>(tokenCnt) * MOE_COMBINE_INFO_NUM;
        GM_ADDR thisRankExpertCombineAddr = thisRankExpertTokenAddr + infoOffsetBytes;
        uint32_t thisRankFlagOffset = dispatchInfo.expertIndex * rankSize * 512 + rankId * 512;
        ReadFlagV2(flag, thisRankFlagOffset, 1, hcclContext, shmemFlagBaseAddr, dispatchInfo);
        pipe_barrier(PIPE_ALL);
        uint32_t thisRankSendTokenCnt = flag[1];
        __gm__ int32_t* combineExpertAddr = reinterpret_cast<__gm__ int32_t*>(thisRankExpertCombineAddr);
        for (uint32_t i = 0; i < thisRankSendTokenCnt; i++) {
            PIPE_SYNC_EVENT(PIPE_S, PIPE_MTE2, EVENT_ID0);
            TileOp::UBCopyIn<int32_t, 1, MOE_COMBINE_INFO_NUM, 8, MOE_COMBINE_INFO_NUM>(
                buffer, combineExpertAddr + static_cast<uint64_t>(i) * shmemColBurst);
            PIPE_SYNC_EVENT(PIPE_MTE2, PIPE_MTE3, EVENT_ID0);
            TileOp::UBCopyOut<int32_t, 1, MOE_COMBINE_INFO_NUM, MOE_COMBINE_INFO_NUM, 8>(
                thisExpertCombineAddr + static_cast<uint64_t>(i) * MOE_COMBINE_INFO_NUM, buffer);
            PIPE_SYNC_EVENT(PIPE_MTE3, PIPE_MTE2, EVENT_ID0);
        }
        PIPE_SYNC_EVENT(PIPE_MTE3, PIPE_S, EVENT_ID0);
        tokenCnt += thisRankSendTokenCnt;
    }
}
/**
 * \brief Copies the tokens received for one expert from the local window to consecutive rows of expandX.
 *
 * For every source rank in [rankOffset, rankOffset + rankShape) the flag at offset
 * (expertIndex * rankSize + rank) * 512 is read; flag[1] is used as the number of tokens sent by that rank.
 * Each token (colShape elements of T) is moved with CopyGmToGm from its window slot to the next output row.
 *
 * Window layout used here: rank region = expertNumPerRank * bs * shmemLength elements, expert region =
 * bs * shmemLength elements, token slot = shmemLength elements. All byte offsets are computed in 64-bit
 * arithmetic so that the products cannot wrap at 32 bits before being added to a pointer.
 *
 * UB layout of buffer: colShape * sizeof(T) bytes of token staging followed by the flag burst.
 * validCnt is not read.
 *
 * NOTE(review): flag[1] is read directly after ReadFlagV2 without a pipe_barrier, unlike CombineInfoCopyOut;
 * presumably ReadFlagV2 synchronises internally - confirm.
 */
template <typename T>
TILEOP void MoeRankWinCopyOut(
    __gm__ T* expandX, __gm__ uint32_t* validCnt, __ubuf__ uint8_t* buffer, DispatchInfo& dispatchInfo,
    __gm__ int64_t* hcclContext, __gm__ T* shmemDataBaseAddr, __gm__ int32_t* shmemFlagBaseAddr, uint32_t rankSize,
    uint32_t bs, uint32_t shmemLength)
{
    __ubuf__ T* token = reinterpret_cast<__ubuf__ T*>(buffer);
    uint32_t tokenSize = dispatchInfo.colShape * sizeof(T);
    __ubuf__ uint32_t* flag = reinterpret_cast<__ubuf__ uint32_t*>(buffer + tokenSize);

    __gm__ CommContext* winContext = (__gm__ CommContext*)(hcclContext[dispatchInfo.groupIndex]);
    uint32_t localUsrRankId = winContext->rankId;
    GM_ADDR localShmemBaseAddr = (GM_ADDR)MapVirtualAddr<T>(hcclContext, shmemDataBaseAddr, localUsrRankId);
    uint32_t tokenCnt = 0;
    uint16_t lenBurst = static_cast<uint16_t>(dispatchInfo.colShape * sizeof(T) / 32);
    DataCopyParams gmToUbParams = {0, 1, lenBurst, 0, 0};
    DataCopyParams ubToGmParams = {0, 1, lenBurst, 0, 0};

    // Widen before multiplying: these products are byte offsets into the window and the output.
    const uint64_t slotBytes = static_cast<uint64_t>(shmemLength) * sizeof(T);
    const uint64_t expertRegionBytes = static_cast<uint64_t>(bs) * slotBytes;
    const uint64_t rankRegionBytes = static_cast<uint64_t>(dispatchInfo.expertNumPerRank) * expertRegionBytes;
    const uint64_t expertOffsetBytes = static_cast<uint64_t>(dispatchInfo.expertIndex) * expertRegionBytes;
    const uint64_t outRowBytes = static_cast<uint64_t>(dispatchInfo.colShape) * sizeof(T);

    for (int32_t rankId = dispatchInfo.rankOffset; rankId < dispatchInfo.rankOffset + dispatchInfo.rankShape;
         rankId++) {
        GM_ADDR thisRankExpertAddrBase = localShmemBaseAddr + static_cast<uint64_t>(rankId) * rankRegionBytes;
        // Output rows of this rank start after the tokens of all previously handled ranks.
        GM_ADDR thisRankOutAddr = reinterpret_cast<GM_ADDR>(expandX) + static_cast<uint64_t>(tokenCnt) * outRowBytes;
        GM_ADDR thisRankExpertTokenAddr = thisRankExpertAddrBase + expertOffsetBytes;
        uint32_t thisRankFlagOffset = dispatchInfo.expertIndex * rankSize * 512 + rankId * 512;
        ReadFlagV2(flag, thisRankFlagOffset, 1, hcclContext, shmemFlagBaseAddr, dispatchInfo);
        uint32_t thisRankSendTokenCnt = flag[1];
        tokenCnt += thisRankSendTokenCnt;
        for (uint32_t i = 0; i < thisRankSendTokenCnt; i++) {
            CopyGmToGm(
                thisRankOutAddr + static_cast<uint64_t>(i) * outRowBytes,
                thisRankExpertTokenAddr + static_cast<uint64_t>(i) * slotBytes, token, gmToUbParams, ubToGmParams);
        }
    }
}
/**
 * \brief Determines the output offset of this tile and copies either the received tokens or their combine info.
 *
 * The per-tile receive counts are loaded with ReadRecvTokenCnt and reduced with CumSum; cumSumDst[0] is taken as
 * the number of tokens that precede this tile in the output. With copyOutData == true, out is treated as the
 * token output (element type T2) and MoeRankWinCopyOut is called; otherwise out is treated as an int32 combine
 * info buffer and CombineInfoCopyOut is called.
 *
 * UB layout of buffer while counting: AlignUp(totalTileNum * 32, 256) bytes of receive counts, 512 bytes of
 * cumulative-sum output, then the gather-mask scratch. The callee reuses buffer from its start.
 */
template <typename T1, typename T2>
TILEOP void MoeRankCopyOut(
    __gm__ T1* out, __gm__ uint32_t* validCnt, __ubuf__ uint8_t* buffer, __gm__ uint32_t* gmRecvTokenCnt,
    DispatchInfo& dispatchInfo, __gm__ int64_t* hcclContext, __gm__ T2* shmemDataBaseAddr,
    __gm__ int32_t* shmemFlagBaseAddr, uint32_t rankSize, uint32_t bs, uint32_t shmemLength, bool copyOutData)
{
    int tileCnt = dispatchInfo.totalTileNum;

    const uint32_t recvCntBytes = AlignUp<uint32_t>(tileCnt * 32, 256);
    const uint32_t cumSumBytes = 512;
    __ubuf__ uint32_t* recvTokenCnt = reinterpret_cast<__ubuf__ uint32_t*>(buffer);
    __ubuf__ uint32_t* cumSumDst = reinterpret_cast<__ubuf__ uint32_t*>(buffer + recvCntBytes);
    __ubuf__ uint32_t* gatherMaskDst = reinterpret_cast<__ubuf__ uint32_t*>(buffer + recvCntBytes + cumSumBytes);

    ReadRecvTokenCnt(recvTokenCnt, gmRecvTokenCnt, dispatchInfo, hcclContext, tileCnt);
    CumSum(cumSumDst, recvTokenCnt, gatherMaskDst, MASK_SELECT_RECV_TOKEN_CNT, dispatchInfo.tileIndex);
    uint32_t precedingTokens = cumSumDst[0];

    if (!copyOutData) {
        // Combine-info path: every token contributes MOE_COMBINE_INFO_NUM int32 words.
        int32_t infoOffset = static_cast<int32_t>(precedingTokens * MOE_COMBINE_INFO_NUM);
        __gm__ int32_t* combineInfo = reinterpret_cast<__gm__ int32_t*>(out);
        CombineInfoCopyOut<T2>(
            combineInfo + infoOffset, buffer, dispatchInfo, hcclContext, shmemDataBaseAddr, shmemFlagBaseAddr,
            rankSize, bs, shmemLength);
        return;
    }

    // Token path: every token contributes colShape elements of T2.
    uint32_t expandXOffset = precedingTokens * dispatchInfo.colShape;
    __gm__ T2* expandX = reinterpret_cast<__gm__ T2*>(out);
    MoeRankWinCopyOut<T2>(
        expandX + expandXOffset, validCnt, buffer, dispatchInfo, hcclContext, shmemDataBaseAddr, shmemFlagBaseAddr,
        rankSize, bs, shmemLength);
}
/**
 * \brief Copies processTokenCnt tokens of this tile from the local window of a shared-expert rank to expandX.
 *
 * The first window row read is (localRank * processMoeRankCnt + shareRankCnt) * shmemDataRawShape2 +
 * recvTokenOffset, where one window row is shmemDataRawShape3 elements of T. Each token (colShape elements) is
 * moved with CopyGmToGm and written to consecutive colShape-element rows starting at expandX. Returns without
 * any access when processTokenCnt is 0.
 *
 * All byte offsets are computed in 64-bit arithmetic so that the products cannot wrap at 32 bits before being
 * added to a pointer.
 */
template <typename T>
TILEOP void ShareRankWinCopyOut(
    __gm__ T* expandX, __ubuf__ uint8_t* buffer, uint32_t processTokenCnt, uint32_t recvTokenOffset,
    DispatchInfo& dispatchInfo, __gm__ int64_t* hcclContext, uint32_t processMoeRankCnt, __gm__ T* shmemDataAddr,
    uint32_t shmemDataRawShape2, uint32_t shmemDataRawShape3)
{
    if (processTokenCnt == 0) {
        return;
    }
    __gm__ CommContext* winContext = (__gm__ CommContext*)(hcclContext[dispatchInfo.groupIndex]);
    uint32_t localUsrRankId = winContext->rankId;
    GM_ADDR winTokenAddr = (GM_ADDR)MapVirtualAddr<T>(hcclContext, shmemDataAddr, localUsrRankId);
    uint32_t shareRankCnt = dispatchInfo.shareRankCnt;
    uint32_t processMoeRankStartIdx = localUsrRankId * processMoeRankCnt + shareRankCnt;

    // Widen before multiplying: these products are byte offsets into the window and the output.
    const uint64_t winRowBytes = static_cast<uint64_t>(shmemDataRawShape3) * sizeof(T);
    const uint64_t outRowBytes = static_cast<uint64_t>(dispatchInfo.colShape) * sizeof(T);
    GM_ADDR thisRankWinTokenAddr =
        winTokenAddr + static_cast<uint64_t>(processMoeRankStartIdx) * shmemDataRawShape2 * winRowBytes;
    GM_ADDR thisTileWinTokenAddr = thisRankWinTokenAddr + static_cast<uint64_t>(recvTokenOffset) * winRowBytes;

    // The start of buffer is used as the staging area for one token.
    __ubuf__ T* token = reinterpret_cast<__ubuf__ T*>(buffer);

    DataCopyParams gmToUbParams;
    gmToUbParams.sid = 0;
    gmToUbParams.nBurst = 1;
    gmToUbParams.lenBurst = dispatchInfo.colShape * sizeof(T) / 32;
    gmToUbParams.srcStride = 0;
    gmToUbParams.dstStride = 0;
    DataCopyParams ubToGmParams;
    ubToGmParams.sid = 0;
    ubToGmParams.nBurst = 1;
    ubToGmParams.lenBurst = dispatchInfo.colShape * sizeof(T) / 32;
    ubToGmParams.srcStride = 0;
    ubToGmParams.dstStride = 0;

    GM_ADDR thisTileOutAddr = reinterpret_cast<GM_ADDR>(expandX);
    for (uint32_t i = 0; i < processTokenCnt; i++) {
        CopyGmToGm(thisTileOutAddr, thisTileWinTokenAddr, token, gmToUbParams, ubToGmParams);
        thisTileWinTokenAddr += winRowBytes;
        thisTileOutAddr += outRowBytes;
    }
}
/**
 * \brief Splits the tokens received by a shared-expert rank across the tiles and copies this tile's share out.
 *
 * The token total is ((rankNum - shareRankCnt) / shareRankCnt) * rowShape. It is divided over totalTileNum
 * tiles: the first (total % totalTileNum) tiles take one token more than the remaining ones. The output of
 * this tile starts rowShape + recvTokenOffset rows (of colShape elements) into expandX.
 *
 * NOTE(review): shareRankCnt and totalTileNum are used as divisors without a check - presumably both are
 * non-zero whenever this operator is scheduled; confirm.
 */
template <typename T>
TILEOP void ShareRankCopyOut(
    __gm__ T* expandX, __ubuf__ uint8_t* buffer, DispatchInfo& dispatchInfo, __gm__ int64_t* hcclContext,
    __gm__ T* shmemDataBaseAddr, uint32_t shmemDataRawShape2, uint32_t shmemDataRawShape3)
{
    __gm__ CommContext* winContext = (__gm__ CommContext*)(hcclContext[dispatchInfo.groupIndex]);
    uint32_t tileTotal = dispatchInfo.totalTileNum;
    uint32_t processMoeRankCnt = (winContext->rankNum - dispatchInfo.shareRankCnt) / dispatchInfo.shareRankCnt;
    uint32_t tokensToSplit = processMoeRankCnt * dispatchInfo.rowShape;

    // Balanced split: "large" tiles come first and carry one extra token each.
    uint32_t smallShare = tokensToSplit / tileTotal;
    uint32_t largeTileCnt = tokensToSplit % tileTotal;
    uint32_t largeShare = smallShare + 1;

    const bool isLargeTile = dispatchInfo.tileIndex < largeTileCnt;
    uint32_t processTokenCnt = isLargeTile ? largeShare : smallShare;
    uint32_t recvTokenOffset = isLargeTile
        ? dispatchInfo.tileIndex * largeShare
        : largeTileCnt * largeShare + (dispatchInfo.tileIndex - largeTileCnt) * smallShare;

    // The first rowShape rows of expandX are skipped before this tile's tokens.
    uint32_t outOffset = dispatchInfo.rowShape * dispatchInfo.colShape + recvTokenOffset * dispatchInfo.colShape;
    ShareRankWinCopyOut<T>(
        expandX + outOffset, buffer, processTokenCnt, recvTokenOffset, dispatchInfo, hcclContext, processMoeRankCnt,
        shmemDataBaseAddr, shmemDataRawShape2, shmemDataRawShape3);
}
/**
 * \brief Entry point that copies the tokens received for one expert tile from the window into expandX.
 *
 * Builds a DispatchInfo positionally from the template parameters and from shmemDataOffset1, shmemDataShape2,
 * shmemDataShape0 and shmemDataOffset2, then calls MoeRankCopyOut in token mode (copyOutData = true) with
 * shmemDataShape0 as the rank count and AlignUp(axisH, 512) + 512 as the window slot length in elements.
 *
 * param, expandXRow, shmemDataOffset0, shmemDataOffset3, shmemDataShape1 and shmemDataShape3 are not read.
 */
template <
    typename T, uint32_t tileIndex, uint32_t groupIndex, uint32_t shareRankCnt, uint32_t totalTileNum,
    uint32_t rankShape, uint32_t axisH, uint32_t bs, uint32_t expandXRow>
TILEOP void FFNBatching(
    CoreFuncParam* param, __gm__ T* expandX, __gm__ int32_t* validCnt, __ubuf__ int32_t* buffer,
    __gm__ T* shmemDataBaseAddr, __gm__ int32_t* shmemFlagBaseAddr, __gm__ int32_t* gmRecvTokenCnt,
    uint32_t shmemDataOffset0, uint32_t shmemDataOffset1, uint32_t shmemDataOffset2, uint32_t shmemDataOffset3,
    uint32_t shmemDataShape0, uint32_t shmemDataShape1, uint32_t shmemDataShape2, uint32_t shmemDataShape3,
    __gm__ int64_t* hcclContext)
{
    int32_t slotElems = AlignUp<int32_t>(axisH, 512) + 512;
    DispatchInfo info = {tileIndex, groupIndex, 0, 0, rankShape, static_cast<int32_t>(shmemDataOffset1),
        bs, 0, axisH, 0, totalTileNum, shareRankCnt, static_cast<int32_t>(shmemDataShape2),
        static_cast<int32_t>(shmemDataShape0), static_cast<int32_t>(shmemDataOffset2)};

    __gm__ uint32_t* validCntU32 = reinterpret_cast<__gm__ uint32_t*>(validCnt);
    __gm__ uint32_t* recvCntU32 = reinterpret_cast<__gm__ uint32_t*>(gmRecvTokenCnt);
    __ubuf__ uint8_t* scratch = reinterpret_cast<__ubuf__ uint8_t*>(buffer);

    MoeRankCopyOut<T, T>(
        expandX, validCntU32, scratch, recvCntU32, info, hcclContext, shmemDataBaseAddr, shmemFlagBaseAddr,
        shmemDataShape0, bs, slotElems, true);
}
/**
 * \brief Entry point that copies the combine info of the tokens received for one expert tile into combineInfo.
 *
 * Builds a DispatchInfo positionally from the template parameters and from shmemDataOffset1, shmemDataShape2,
 * shmemDataShape0 and shmemDataOffset2, then calls MoeRankCopyOut in combine-info mode (copyOutData = false,
 * no validCnt buffer) with shmemDataShape0 as the rank count and AlignUp(axisH, 512) + 512 as the window slot
 * length in elements.
 *
 * param, expandXRow, shmemDataOffset0, shmemDataOffset3, shmemDataShape1 and shmemDataShape3 are not read.
 */
template <
    typename T, uint32_t tileIndex, uint32_t groupIndex, uint32_t shareRankCnt, uint32_t totalTileNum,
    uint32_t rankShape, uint32_t axisH, uint32_t bs, uint32_t expandXRow>
TILEOP void FFNCombineInfo(
    CoreFuncParam* param, __gm__ int32_t* combineInfo, __ubuf__ int32_t* buffer, __gm__ T* shmemDataBaseAddr,
    __gm__ int32_t* shmemFlagBaseAddr, __gm__ int32_t* gmRecvTokenCnt, uint32_t shmemDataOffset0,
    uint32_t shmemDataOffset1, uint32_t shmemDataOffset2, uint32_t shmemDataOffset3, uint32_t shmemDataShape0,
    uint32_t shmemDataShape1, uint32_t shmemDataShape2, uint32_t shmemDataShape3, __gm__ int64_t* hcclContext)
{
    int32_t slotElems = AlignUp<int32_t>(axisH, 512) + 512;
    DispatchInfo info = {tileIndex, groupIndex, 0, 0, rankShape, static_cast<int32_t>(shmemDataOffset1),
        bs, 0, axisH, 0, totalTileNum, shareRankCnt, static_cast<int32_t>(shmemDataShape2),
        static_cast<int32_t>(shmemDataShape0), static_cast<int32_t>(shmemDataOffset2)};

    __gm__ uint32_t* recvCntU32 = reinterpret_cast<__gm__ uint32_t*>(gmRecvTokenCnt);
    __ubuf__ uint8_t* scratch = reinterpret_cast<__ubuf__ uint8_t*>(buffer);

    MoeRankCopyOut<int32_t, T>(
        combineInfo, nullptr, scratch, recvCntU32, info, hcclContext, shmemDataBaseAddr, shmemFlagBaseAddr,
        shmemDataShape0, bs, slotElems, false);
}
}
#endif