/*
Copyright (c) 2025 Huawei Technologies Co., Ltd.
This program is free software, you can redistribute it and/or modify it under the terms and conditions of
CANN Open Software License Agreement Version 2.0 (the "License").
Please refer to the License for details. You may not use this file except in compliance with the License.
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
See LICENSE in the root of the software repository for the full text of the License.
*/
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <iostream>
#include <vector>
#include <numeric>
#include <random>
#include <algorithm>
#include "acl/acl.h"
#include "securec.h"
#include "moe_dispatch_config.h"
#include "moe_dispatch_launch.h"
// Selects which kernel launcher RunMoeDispatch() invokes. main() picks the value from the
// DISPATCH_MODE environment variable ("viagm" or "sync"); any other value keeps the default.
enum class DispatchMode
{
// Default mode: RunMoeDispatch() calls LaunchMoeDispatchK128.
Direct,
// DISPATCH_MODE=viagm: RunMoeDispatch() allocates an extra temporary device buffer and calls
// LaunchMoeDispatchViaGM_K128.
ViaGM,
// DISPATCH_MODE=sync: RunMoeDispatch() places a token-per-expert area after the token rows in the
// shmem window, allocates a routing workspace, and calls LaunchMoeDispatchWithSync_K128.
WithSync
};
// File-local mode switch; stays Direct unless main() overrides it before the ranks are launched.
static DispatchMode g_dispatchMode = DispatchMode::Direct;
// Define AICORE and __gm__ as empty macros when nothing has defined them yet.
// NOTE(review): presumably these are device-side qualifiers that must expand to nothing so the
// shared test header included right after this block compiles in a host build — confirm.
#ifndef AICORE
#define AICORE
#endif
#ifndef __gm__
#define __gm__
#endif
#include "../../../../tests/npu/a5/comm/st/testcase/common.hpp"
/**
 * Host-side routing tables used to drive and verify the MoE dispatch test.
 *
 * Layout of the tables (all flattened, row-major):
 *  - tokenPerExpert[src][dst][g]  : tokens that rank `src` sends to expert `g` of rank `dst`
 *                                   (size EP * EP * expertPerRank).
 *  - cumsumMM[src][g]             : for the local rank, running sum over source ranks 0..src of the
 *                                   tokens addressed to local expert `g` (size EP * expertPerRank).
 *  - preSumBeforeRank[src][g]     : number of tokens rank `src` emits before the ones addressed to
 *                                   local expert `g`, i.e. the row offset inside `src`'s send buffer
 *                                   (size EP * expertPerRank).
 *
 * The scalar members default to zero so a default-constructed object is in a defined state.
 */
struct RoutingTables {
    int32_t EP = 0;
    int32_t expertPerRank = 0;
    int32_t maxTokensPerRank = 0;
    std::vector<int32_t> tokenPerExpert;
    std::vector<int32_t> cumsumMM;
    std::vector<int32_t> preSumBeforeRank;

    /**
     * Fills all three tables.
     *
     * @param myRank rank whose receive-side tables (cumsumMM, preSumBeforeRank) are derived.
     * @param rng    random engine; exactly EP * EP * expertPerRank values are drawn, in
     *               (src, dst, g) order, so every rank using the same seed builds the same table.
     *
     * A non-positive EP / expertPerRank or a negative maxTokensPerRank leaves every table empty
     * (instead of dividing by zero or building an invalid distribution). An out-of-range myRank
     * leaves the two derived tables zero-filled instead of reading out of bounds.
     */
    void Generate(int myRank, std::mt19937 &rng)
    {
        tokenPerExpert.clear();
        cumsumMM.clear();
        preSumBeforeRank.clear();
        if (EP <= 0 || expertPerRank <= 0 || maxTokensPerRank < 0) {
            return;
        }

        const int32_t expertsTotal = EP * expertPerRank;
        tokenPerExpert.assign(static_cast<size_t>(EP) * static_cast<size_t>(expertsTotal), 0);

        std::uniform_int_distribution<int32_t> dist(0, maxTokensPerRank / expertsTotal + 1);
        for (int src = 0; src < EP; ++src) {
            int32_t totalFromSrc = 0;
            for (int dst = 0; dst < EP; ++dst) {
                for (int g = 0; g < expertPerRank; ++g) {
                    int32_t count = dist(rng);
                    // Clamp so that one source rank never emits more than maxTokensPerRank rows.
                    if (totalFromSrc + count > maxTokensPerRank) {
                        count = maxTokensPerRank - totalFromSrc;
                    }
                    tokenPerExpert[TokenIndex(src, dst, g)] = count;
                    totalFromSrc += count;
                }
            }
        }

        cumsumMM.assign(static_cast<size_t>(expertsTotal), 0);
        preSumBeforeRank.assign(static_cast<size_t>(expertsTotal), 0);
        if (myRank < 0 || myRank >= EP) {
            return;
        }

        for (int g = 0; g < expertPerRank; ++g) {
            int32_t cumSum = 0;
            for (int srcRank = 0; srcRank < EP; ++srcRank) {
                cumSum += tokenPerExpert[TokenIndex(srcRank, myRank, g)];
                cumsumMM[srcRank * expertPerRank + g] = cumSum;
            }
        }

        for (int srcRank = 0; srcRank < EP; ++srcRank) {
            int32_t offset = 0;
            for (int dst = 0; dst < EP; ++dst) {
                for (int g = 0; g < expertPerRank; ++g) {
                    if (dst == myRank) {
                        preSumBeforeRank[srcRank * expertPerRank + g] = offset;
                    }
                    offset += tokenPerExpert[TokenIndex(srcRank, dst, g)];
                }
            }
        }
    }

    /** Total number of tokens that all source ranks send to `myRank`; 0 if the rank or table is invalid. */
    int32_t TotalOutputTokens(int myRank) const
    {
        if (!CanQuery(myRank)) {
            return 0;
        }
        int32_t total = 0;
        for (int src = 0; src < EP; ++src) {
            for (int g = 0; g < expertPerRank; ++g) {
                total += tokenPerExpert[TokenIndex(src, myRank, g)];
            }
        }
        return total;
    }

    /** Total number of tokens that `srcRank` sends to all ranks; 0 if the rank or table is invalid. */
    int32_t TotalSrcTokens(int srcRank) const
    {
        if (!CanQuery(srcRank)) {
            return 0;
        }
        int32_t total = 0;
        for (int dst = 0; dst < EP; ++dst) {
            for (int g = 0; g < expertPerRank; ++g) {
                total += tokenPerExpert[TokenIndex(srcRank, dst, g)];
            }
        }
        return total;
    }

    /** Flattened index of tokenPerExpert[src][dst][g]. */
    size_t TokenIndex(int src, int dst, int g) const
    {
        return (static_cast<size_t>(src) * static_cast<size_t>(EP) + static_cast<size_t>(dst)) *
                   static_cast<size_t>(expertPerRank) +
               static_cast<size_t>(g);
    }

    /** True when `rank` is a valid rank index and tokenPerExpert is large enough to be indexed by it. */
    bool CanQuery(int rank) const
    {
        if (EP <= 0 || expertPerRank <= 0 || rank < 0 || rank >= EP) {
            return false;
        }
        const size_t required =
            static_cast<size_t>(EP) * static_cast<size_t>(EP) * static_cast<size_t>(expertPerRank);
        return tokenPerExpert.size() >= required;
    }
};
/**
 * Builds the host image of one rank's send buffer.
 *
 * Every row is `hiddenSize` random int8 values in [-127, 127] followed by a UB_ALIGN-byte trailer
 * whose first sizeof(float) bytes hold a random scale in [0.001, 1.0) and whose remainder is zero.
 *
 * @param shmemData      resized to totalRows * (hiddenSize + UB_ALIGN) bytes and filled.
 * @param expectedScales resized to totalRows; receives the scale written into each row.
 * @param totalRows      number of rows to generate.
 * @param hiddenSize     payload bytes per row.
 * @param rankId         not used by the body.
 * @param rng            random engine; per row, hiddenSize payload draws precede the scale draw.
 */
void GenerateShmemData(std::vector<int8_t> &shmemData, std::vector<float> &expectedScales, int32_t totalRows,
                       int32_t hiddenSize, int32_t rankId, std::mt19937 &rng)
{
    const int32_t bytesPerRow = hiddenSize + UB_ALIGN;
    shmemData.resize(static_cast<size_t>(totalRows) * bytesPerRow, 0);
    expectedScales.resize(totalRows, 0.0f);

    std::uniform_int_distribution<int32_t> payloadDist(-127, 127);
    std::uniform_real_distribution<float> scaleDist(0.001f, 1.0f);
    // Only clear the trailer tail when the trailer is wider than the scale itself.
    const bool hasTrailerTail = UB_ALIGN > static_cast<int32_t>(sizeof(float));

    for (int32_t rowIdx = 0; rowIdx < totalRows; ++rowIdx) {
        int8_t *const payload = shmemData.data() + static_cast<size_t>(rowIdx) * bytesPerRow;
        std::generate_n(payload, hiddenSize, [&]() { return static_cast<int8_t>(payloadDist(rng)); });

        const float rowScale = scaleDist(rng);
        expectedScales[rowIdx] = rowScale;

        int8_t *const trailer = payload + hiddenSize;
        memcpy_s(trailer, UB_ALIGN, &rowScale, sizeof(float));
        if (hasTrailerTail) {
            memset_s(trailer + sizeof(float), UB_ALIGN - sizeof(float), 0, UB_ALIGN - sizeof(float));
        }
    }
}
/**
 * Computes the expected dispatch output for rank `myRank` on the host.
 *
 * Output rows are ordered by local expert group first and by source rank second. For each
 * (group, source rank) pair the rows are copied out of that source rank's send image, starting at
 * the source-side cursor seeded from preSumBeforeRank, and the scale stored after each row's
 * payload is extracted into expectedGmScale. Rows that would land at or beyond maxOutputSize are
 * dropped.
 *
 * @param routing         routing tables generated for `myRank`.
 * @param allShmemData    send image of every rank, rows of (hiddenSize + UB_ALIGN) bytes.
 * @param myRank          receiving rank.
 * @param hiddenSize      payload bytes per row.
 * @param maxOutputSize   capacity of the output, in rows.
 * @param expectedGmA     resized to maxOutputSize * hiddenSize bytes; receives the payloads.
 * @param expectedGmScale resized to maxOutputSize; receives the per-row scales.
 */
void ComputeGolden(const RoutingTables &routing, const std::vector<std::vector<int8_t>> &allShmemData, int32_t myRank,
                   int32_t hiddenSize, int32_t maxOutputSize, std::vector<int8_t> &expectedGmA,
                   std::vector<float> &expectedGmScale)
{
    const int32_t epSize = routing.EP;
    const int32_t groupsPerRank = routing.expertPerRank;
    const int32_t srcRowBytes = hiddenSize + UB_ALIGN;
    const uint32_t outputCap = static_cast<uint32_t>(maxOutputSize);

    expectedGmA.resize(static_cast<size_t>(maxOutputSize) * hiddenSize, 0);
    expectedGmScale.resize(maxOutputSize, 0.0f);

    // Per-source read cursor: first row in each source image that is addressed to this rank.
    std::vector<int32_t> srcCursor(epSize, 0);
    for (int rank = 0; rank < epSize; ++rank) {
        srcCursor[rank] = routing.preSumBeforeRank[rank * groupsPerRank];
    }

    uint32_t groupBase = 0; // first output row of the current expert group
    for (int32_t group = 0; group < groupsPerRank; ++group) {
        for (int32_t srcRank = 0; srcRank < epSize; ++srcRank) {
            const uint32_t dstBegin =
                (srcRank == 0) ?
                    groupBase :
                    static_cast<uint32_t>(routing.cumsumMM[(srcRank - 1) * groupsPerRank + group]) + groupBase;
            if (dstBegin >= outputCap) {
                continue;
            }

            const int32_t countIdx = srcRank * epSize * groupsPerRank + myRank * groupsPerRank + group;
            uint32_t rowCount = static_cast<uint32_t>(routing.tokenPerExpert[countIdx]);
            if (dstBegin + rowCount > outputCap) {
                rowCount = outputCap - dstBegin; // truncate at the output capacity
            }
            if (rowCount == 0) {
                continue;
            }

            const uint32_t srcBegin = static_cast<uint32_t>(srcCursor[srcRank]);
            srcCursor[srcRank] += static_cast<int32_t>(rowCount);

            const int8_t *srcImage = allShmemData[srcRank].data();
            for (uint32_t i = 0; i < rowCount; ++i) {
                const int8_t *from = srcImage + static_cast<size_t>(srcBegin + i) * srcRowBytes;
                int8_t *to = expectedGmA.data() + static_cast<size_t>(dstBegin + i) * hiddenSize;
                memcpy_s(to, hiddenSize, from, hiddenSize);

                float rowScale;
                memcpy_s(&rowScale, sizeof(float), from + hiddenSize, sizeof(float));
                expectedGmScale[dstBegin + i] = rowScale;
            }
        }
        // The last source rank's cumulative count is the total row count of this group.
        groupBase += static_cast<uint32_t>(routing.cumsumMM[(epSize - 1) * groupsPerRank + group]);
    }
}
template <int HIDDEN_SIZE>
bool RunMoeDispatch(int rankId, int nRanks, int nDevices, int firstDeviceId, const HcclRootInfo *rootInfo,
const MoeDispatchParams ¶ms, uint32_t seed)
{
TestContext ctx;
if (!ctx.Init(rankId, nRanks, nDevices, firstDeviceId, rootInfo)) {
std::cerr << "[ERROR] Rank " << rankId << ": TestContext init failed\n";
return false;
}
int32_t EP = params.EP;
int32_t expertPerRank = params.expertPerRank;
int32_t hiddenSize = params.hiddenSize;
int32_t maxOutputSize = params.maxOutputSize;
int32_t rowStride = hiddenSize + UB_ALIGN;
std::mt19937 rng(seed);
RoutingTables routing;
routing.EP = EP;
routing.expertPerRank = expertPerRank;
routing.maxTokensPerRank = params.maxTokensPerRank;
routing.Generate(rankId, rng);
int32_t totalSrcTokens = routing.TotalSrcTokens(rankId);
int32_t totalDstTokens = routing.TotalOutputTokens(rankId);
std::cout << "[INFO] Rank " << rankId << ": totalSrcTokens=" << totalSrcTokens
<< " totalDstTokens=" << totalDstTokens << std::endl;
std::mt19937 dataRng(seed + rankId * 1000);
std::vector<int8_t> localShmemData;
std::vector<float> localScales;
GenerateShmemData(localShmemData, localScales, totalSrcTokens, hiddenSize, rankId, dataRng);
std::vector<std::vector<int8_t>> allShmemData(EP);
for (int r = 0; r < EP; ++r) {
std::mt19937 peerRng(seed + r * 1000);
int32_t peerTokens = routing.TotalSrcTokens(r);
std::vector<float> peerScales;
GenerateShmemData(allShmemData[r], peerScales, peerTokens, hiddenSize, r, peerRng);
}
std::vector<int8_t> expectedGmA;
std::vector<float> expectedGmScale;
ComputeGolden(routing, allShmemData, rankId, hiddenSize, maxOutputSize, expectedGmA, expectedGmScale);
size_t shmemSize = static_cast<size_t>(totalSrcTokens) * rowStride;
size_t gmASize = static_cast<size_t>(maxOutputSize) * hiddenSize;
size_t gmScaleSize = static_cast<size_t>(maxOutputSize) * UB_ALIGN;
size_t cumsumSize = EP * expertPerRank * sizeof(int32_t);
size_t tpeSize = EP * EP * expertPerRank * sizeof(int32_t);
size_t psbSize = EP * expertPerRank * sizeof(int32_t);
int32_t blockNum = EP;
size_t syncWsSize = static_cast<size_t>(blockNum) * SYNCALL_SOFT_SLOT_INT32 * sizeof(int32_t);
size_t tpeAreaSize = 0;
int64_t offsetTPE = 0;
size_t routingWsSize = 0;
size_t uniformShmemBase = static_cast<size_t>(params.maxTokensPerRank) * rowStride;
if (g_dispatchMode == DispatchMode::WithSync) {
tpeAreaSize = static_cast<size_t>(TPEAreaBytes(EP, expertPerRank));
offsetTPE = static_cast<int64_t>(uniformShmemBase);
routingWsSize = static_cast<size_t>(SyncWorkspaceBytes(EP, expertPerRank));
}
size_t winOffset = 256;
size_t totalShmemAlloc = (g_dispatchMode == DispatchMode::WithSync) ? uniformShmemBase + tpeAreaSize : shmemSize;
void *devShmem = WindowAlloc(ctx.hostCtx.windowsIn[rankId], winOffset, totalShmemAlloc);
void *devGmA = nullptr, *devGmScale = nullptr;
void *devCumsumMM = nullptr, *devTPE = nullptr, *devPSBR = nullptr;
void *devSyncWs = nullptr;
void *devRoutingWs = nullptr;
aclrtMalloc(&devGmA, gmASize, ACL_MEM_MALLOC_HUGE_FIRST);
aclrtMalloc(&devGmScale, gmScaleSize, ACL_MEM_MALLOC_HUGE_FIRST);
aclrtMalloc(&devSyncWs, syncWsSize, ACL_MEM_MALLOC_HUGE_FIRST);
if (g_dispatchMode != DispatchMode::WithSync) {
aclrtMalloc(&devCumsumMM, cumsumSize, ACL_MEM_MALLOC_HUGE_FIRST);
aclrtMalloc(&devTPE, tpeSize, ACL_MEM_MALLOC_HUGE_FIRST);
aclrtMalloc(&devPSBR, psbSize, ACL_MEM_MALLOC_HUGE_FIRST);
} else {
aclrtMalloc(&devRoutingWs, routingWsSize, ACL_MEM_MALLOC_HUGE_FIRST);
}
aclrtMemcpy(devShmem, shmemSize, localShmemData.data(), shmemSize, ACL_MEMCPY_HOST_TO_DEVICE);
aclrtMemset(devGmA, gmASize, 0, gmASize);
aclrtMemset(devGmScale, gmScaleSize, 0, gmScaleSize);
aclrtMemset(devSyncWs, syncWsSize, 0, syncWsSize);
if (g_dispatchMode == DispatchMode::WithSync) {
int32_t paddedExpNum = PaddedExpertNum(EP, expertPerRank);
size_t tpeRowBytes = static_cast<size_t>(paddedExpNum) * sizeof(int32_t);
std::vector<int32_t> localTPEPadded(paddedExpNum, 0);
for (int dst = 0; dst < EP; ++dst) {
for (int g = 0; g < expertPerRank; ++g) {
localTPEPadded[dst * expertPerRank + g] =
routing.tokenPerExpert[rankId * EP * expertPerRank + dst * expertPerRank + g];
}
}
uint8_t *tpeAreaPtr = reinterpret_cast<uint8_t *>(devShmem) + uniformShmemBase;
aclrtMemset(tpeAreaPtr, tpeAreaSize, 0, tpeAreaSize);
uint8_t *myRowPtr = tpeAreaPtr + static_cast<size_t>(rankId) * tpeRowBytes;
aclrtMemcpy(myRowPtr, tpeRowBytes, localTPEPadded.data(), tpeRowBytes, ACL_MEMCPY_HOST_TO_DEVICE);
aclrtMemset(devRoutingWs, routingWsSize, 0, routingWsSize);
} else {
aclrtMemcpy(devCumsumMM, cumsumSize, routing.cumsumMM.data(), cumsumSize, ACL_MEMCPY_HOST_TO_DEVICE);
aclrtMemcpy(devTPE, tpeSize, routing.tokenPerExpert.data(), tpeSize, ACL_MEMCPY_HOST_TO_DEVICE);
aclrtMemcpy(devPSBR, psbSize, routing.preSumBeforeRank.data(), psbSize, ACL_MEMCPY_HOST_TO_DEVICE);
}
HcclHostBarrier(ctx.comm, ctx.stream);
int64_t offsetA = 0;
void *devTempGm = nullptr;
if (g_dispatchMode == DispatchMode::ViaGM) {
size_t tempGmSize = static_cast<size_t>(maxOutputSize) * rowStride;
aclrtMalloc(&devTempGm, tempGmSize, ACL_MEM_MALLOC_HUGE_FIRST);
aclrtMemset(devTempGm, tempGmSize, 0, tempGmSize);
}
bool kernelOk = false;
if (g_dispatchMode == DispatchMode::Direct) {
kernelOk = LaunchMoeDispatchK128(blockNum, ctx.stream, devGmA, devGmScale, devCumsumMM, devTPE, devPSBR,
devShmem, ctx.deviceCtx, devSyncWs, EP, expertPerRank, maxOutputSize, offsetA);
} else if (g_dispatchMode == DispatchMode::ViaGM) {
kernelOk = LaunchMoeDispatchViaGM_K128(blockNum, ctx.stream, devGmA, devGmScale, devTempGm, devCumsumMM, devTPE,
devPSBR, devShmem, ctx.deviceCtx, devSyncWs, EP, expertPerRank,
maxOutputSize, offsetA);
} else {
kernelOk = LaunchMoeDispatchWithSync_K128(blockNum, ctx.stream, devGmA, devGmScale, devShmem, ctx.deviceCtx,
devRoutingWs, devSyncWs, EP, expertPerRank, maxOutputSize, offsetA,
offsetTPE);
}
if (!kernelOk) {
std::cerr << "[ERROR] Rank " << rankId << ": kernel execution failed\n";
aclrtFree(devGmA);
aclrtFree(devGmScale);
if (devCumsumMM)
aclrtFree(devCumsumMM);
if (devTPE)
aclrtFree(devTPE);
if (devPSBR)
aclrtFree(devPSBR);
aclrtFree(devSyncWs);
if (devRoutingWs)
aclrtFree(devRoutingWs);
if (devTempGm)
aclrtFree(devTempGm);
ctx.Finalize();
return false;
}
std::cout << "[INFO] Rank " << rankId << ": kernel execution completed\n";
size_t compactTokenSize = static_cast<size_t>(maxOutputSize) * hiddenSize;
std::vector<int8_t> actualGmA(compactTokenSize);
aclrtMemcpy(actualGmA.data(), compactTokenSize, devGmA, compactTokenSize, ACL_MEMCPY_DEVICE_TO_HOST);
std::vector<uint8_t> actualScaleRaw(gmScaleSize);
aclrtMemcpy(actualScaleRaw.data(), gmScaleSize, devGmScale, gmScaleSize, ACL_MEMCPY_DEVICE_TO_HOST);
bool pass = true;
int32_t checkedRows = std::min(totalDstTokens, maxOutputSize);
for (int32_t i = 0; i < checkedRows && pass; ++i) {
for (int32_t j = 0; j < hiddenSize && pass; ++j) {
int8_t actual = actualGmA[static_cast<size_t>(i) * hiddenSize + j];
int8_t expected = expectedGmA[static_cast<size_t>(i) * hiddenSize + j];
if (actual != expected) {
std::cerr << "[FAIL] Rank " << rankId << ": token mismatch at row " << i << " col " << j
<< " actual=" << (int)actual << " expected=" << (int)expected << "\n";
pass = false;
}
}
}
for (int32_t i = 0; i < checkedRows && pass; ++i) {
float actual;
memcpy_s(&actual, sizeof(float), actualScaleRaw.data() + static_cast<size_t>(i) * UB_ALIGN, sizeof(float));
if (actual != expectedGmScale[i]) {
std::cerr << "[FAIL] Rank " << rankId << ": scale mismatch at row " << i << " actual=" << actual
<< " expected=" << expectedGmScale[i] << "\n";
pass = false;
}
}
if (pass) {
std::cout << "[PASS] Rank " << rankId << ": verified " << checkedRows << " rows (compact token " << hiddenSize
<< "B/row + scale)\n";
}
aclrtFree(devGmA);
aclrtFree(devGmScale);
if (devCumsumMM)
aclrtFree(devCumsumMM);
if (devTPE)
aclrtFree(devTPE);
if (devPSBR)
aclrtFree(devPSBR);
aclrtFree(devSyncWs);
if (devRoutingWs)
aclrtFree(devRoutingWs);
if (devTempGm)
aclrtFree(devTempGm);
ctx.Finalize();
return pass;
}
int main(int argc, char *argv[])
{
CommMpiInit(&argc, &argv);
const char *modeEnv = std::getenv("DISPATCH_MODE");
if (modeEnv && std::string(modeEnv) == "viagm") {
g_dispatchMode = DispatchMode::ViaGM;
} else if (modeEnv && std::string(modeEnv) == "sync") {
g_dispatchMode = DispatchMode::WithSync;
}
MoeDispatchParams params;
params.EP = CONFIG_EP;
params.expertPerRank = CONFIG_EXPERT_PER_RANK;
params.hiddenSize = CONFIG_HIDDEN_SIZE;
params.maxOutputSize = CONFIG_MAX_OUTPUT_SIZE;
params.maxTokensPerRank = CONFIG_MAX_TOKENS_PER_RANK;
uint32_t seed = 42;
int nRanks = params.EP;
int nDevices = nRanks;
int firstDeviceId = CONFIG_FIRST_DEVICE_ID;
std::cout << "=== MoE Dispatch PTO-ISA Test ===" << std::endl;
std::cout << "EP=" << params.EP << " expertPerRank=" << params.expertPerRank << " hiddenSize=" << params.hiddenSize
<< " maxOutput=" << params.maxOutputSize << " maxTokens/rank=" << params.maxTokensPerRank << " mode="
<< (g_dispatchMode == DispatchMode::ViaGM ? "viagm" :
g_dispatchMode == DispatchMode::WithSync ? "sync" :
"direct")
<< std::endl;
bool success =
ForkAndRunWithHcclRootInfo(nRanks, 0, firstDeviceId, [&](int rankId, const HcclRootInfo *rootInfo) -> bool {
return RunMoeDispatch<CONFIG_HIDDEN_SIZE>(rankId, nRanks, nDevices, firstDeviceId, rootInfo, params, seed);
});
if (success) {
std::cout << "[PASS] MoE Dispatch test completed successfully\n";
} else {
std::cerr << "[FAIL] MoE Dispatch test failed\n";
}
CommMpiFinalize();
return success ? 0 : 1;
}