* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <cstdlib>
#include <iostream>
#include <string>
#include <unistd.h>
#include "acl/acl.h"
#include "opdev/bfloat16.h"
#include "opdev/fp16_t.h"
#include "param.h"
#include "shmem.h"
#include "utils.h"
#include "moe_perf_host.h"
#include "dispatch_kernel.h"
using fp16_t = op::fp16_t;
using bfloat16 = op::bfloat16;
namespace {
// Forwarded unchanged to dispatch_demo; its meaning is defined by the kernel (not visible here).
constexpr int MAGIC_MULTIPLIER = 1024;
// Number of int32_t values stored per received token in the assist buffer
// (the buffer is sized as max_recv_tokens * ASSIST_FIELDS * sizeof(int32_t)).
constexpr int ASSIST_FIELDS = 3;
// Byte granularity used by the Dispatch*Bytes/Offset helpers when sizing the shmem window.
constexpr size_t ALIGN_BYTES = 32;
// Frame ids forwarded to dispatch_demo; presumably profiling frame selectors — TODO confirm.
constexpr int FULL_FRAME_ID = 0;
constexpr int COMM_FRAME_ID = 1;
// Forwarded to MoeAppendPerfCsvRows when perf rows are written.
constexpr int DISPATCH_UB_SIZE_KB = 190;
// Overwritten from argv[INDEX4] in main; used as the modulus when deriving the device id.
int g_npus = 8;
// Overwritten from argv[INDEX3] in main; forwarded to test_set_attr.
const char *ipport = "tcp://127.0.0.1:8766";
// Overwritten from argv[INDEX6] in main; added as an offset when deriving the device id.
int f_npu = 0;
// Overwritten from argv[INDEX7] in main; selects which RunDispatchCase<T> instantiation runs.
const char *data_type = "int32_t";
// Forwarded to test_set_attr during shmem initialisation.
aclshmemx_uniqueid_t default_flag_uid;
// Its address is handed to MoeAppendPerfCsvRows; never assigned in this file.
aclshmem_prof_pe_t *out_profs = nullptr;
// Generic equality check used for element types that have no dedicated
// tolerance-based specialisation (int32_t in this file).
//
// The values are compared directly with operator== rather than after a
// conversion to float: float has a 24-bit significand, so routing integers
// through it would make distinct values above 2^24 (for example 16777216 and
// 16777217) compare equal and hide real mismatches.
//
// Returns true when lhs and rhs are equal.
template <typename T>
bool AlmostEqual(T lhs, T rhs)
{
    return lhs == rhs;
}
// fp16_t specialisation: two values match when the absolute difference of
// their float conversions does not exceed a fixed tolerance of 1e-2.
template <>
bool AlmostEqual<fp16_t>(fp16_t lhs, fp16_t rhs)
{
    constexpr float kTolerance = 1.0e-2F;
    const float lhs_value = static_cast<float>(lhs);
    const float rhs_value = static_cast<float>(rhs);
    return std::abs(lhs_value - rhs_value) <= kTolerance;
}
// bfloat16 specialisation: two values match when the absolute difference of
// their float conversions does not exceed a fixed tolerance of 1e-2.
template <>
bool AlmostEqual<bfloat16>(bfloat16 lhs, bfloat16 rhs)
{
    constexpr float kTolerance = 1.0e-2F;
    const float lhs_value = static_cast<float>(lhs);
    const float rhs_value = static_cast<float>(rhs);
    return std::abs(lhs_value - rhs_value) <= kTolerance;
}
// Compares `count` elements of `actual` against `expected` using AlmostEqual.
//
// On the first mismatching element a diagnostic line (array name, pe id,
// index and both values converted to float) is written to std::cerr and the
// function returns false. Returns true when every element matches, including
// when count is zero.
template <typename T>
bool CheckArray(const std::string &name, const T *actual, const T *expected, size_t count, int pe_id)
{
    // Advance to the first element that fails the comparison (or to the end).
    size_t idx = 0;
    while (idx < count && AlmostEqual(actual[idx], expected[idx])) {
        ++idx;
    }
    if (idx == count) {
        return true;
    }
    std::cerr << "[Dispatch] " << name << " mismatch, pe_id=" << pe_id << ", idx=" << idx
              << ", actual=" << static_cast<float>(actual[idx])
              << ", expected=" << static_cast<float>(expected[idx]) << std::endl;
    return false;
}
// Compares `count` int32_t elements of `actual` against `expected` exactly.
//
// On the first differing element a diagnostic line (array name, pe id, index
// and both values) is written to std::cerr and the function returns false.
// Returns true when all elements are equal, including when count is zero.
bool CheckIntArray(const std::string &name, const int32_t *actual, const int32_t *expected, size_t count, int pe_id)
{
    // Advance to the first differing element (or to the end).
    size_t pos = 0;
    while (pos < count && actual[pos] == expected[pos]) {
        ++pos;
    }
    if (pos == count) {
        return true;
    }
    std::cerr << "[Dispatch] " << name << " mismatch, pe_id=" << pe_id << ", idx=" << pos
              << ", actual=" << actual[pos] << ", expected=" << expected[pos] << std::endl;
    return false;
}
// Rounds `value` up to the nearest multiple of `alignment`.
//
// value     - the size to round.
// alignment - the multiple to round up to. An alignment of 0 is treated as
//             "no alignment requirement" and `value` is returned unchanged,
//             instead of dividing by zero.
//
// Returns the smallest multiple of `alignment` that is >= `value` (results
// that do not fit in size_t wrap, as with any unsigned arithmetic).
size_t AlignUp(size_t value, size_t alignment)
{
    if (alignment == 0) {
        return value;
    }
    return (value + alignment - 1) / alignment * alignment;
}
// Total number of bytes requested for the dispatch shmem window.
//
// With segments = pe_size * local_expert_num and slots = segments * bs * topk,
// the size is the sum of:
//   - per slot:    one payload stride (h * elem_size rounded up to
//                  ALIGN_BYTES) plus two ALIGN_BYTES-sized entries;
//   - per segment: two ALIGN_BYTES-sized entries.
// What each entry holds is defined by the kernel and is not visible here.
size_t DispatchWindowBytes(size_t bs, size_t h, size_t topk, size_t pe_size, size_t local_expert_num, size_t elem_size)
{
    const size_t segments = pe_size * local_expert_num;
    const size_t slots = segments * (bs * topk);
    const size_t per_slot_bytes = AlignUp(h * elem_size, ALIGN_BYTES) + ALIGN_BYTES + ALIGN_BYTES;
    const size_t per_segment_bytes = ALIGN_BYTES + ALIGN_BYTES;
    return slots * per_slot_bytes + segments * per_segment_bytes;
}
// Byte offset, from the start of the dispatch shmem window, of the region
// that RunDispatchCase fills with 0xFF before launching the kernel.
//
// The offset equals the combined size of the per-slot part of the window:
// for each of the pe_size * local_expert_num * bs * topk slots, one payload
// stride (h * elem_size rounded up to ALIGN_BYTES) plus two ALIGN_BYTES-sized
// entries.
size_t DispatchCountOffset(size_t bs, size_t h, size_t topk, size_t pe_size, size_t local_expert_num,
                           size_t elem_size)
{
    const size_t segments = pe_size * local_expert_num;
    const size_t slots = segments * (bs * topk);
    const size_t per_slot_bytes = AlignUp(h * elem_size, ALIGN_BYTES) + ALIGN_BYTES + ALIGN_BYTES;
    return slots * per_slot_bytes;
}
// Size in bytes of the window region that starts at DispatchCountOffset():
// one ALIGN_BYTES-sized entry for each of the pe_size * local_expert_num
// segments.
size_t DispatchCountBytes(size_t pe_size, size_t local_expert_num)
{
    const size_t segments = pe_size * local_expert_num;
    return segments * ALIGN_BYTES;
}
// Runs one dispatch test case with element type T on the calling PE.
//
// Visible steps, in order:
//   1. validate the shape arguments and g_npus;
//   2. create a stream and load x.bin / expert_ids.bin from
//      golden/shape_<bs>_<h>_<topk>_<moe_expert_num>_<pe_size>/rank_<pe_id>;
//   3. allocate device buffers and a window from aclshmem_malloc, upload the
//      inputs and clear the outputs;
//   4. launch dispatch_demo<T> (warmup_count + loop_count times in perf mode,
//      once otherwise), synchronising the stream and calling
//      aclshmem_barrier_all after every launch;
//   5. copy the outputs back, dump them under output/, and compare them with
//      the golden_*.bin files of the same rank directory;
//   6. in perf mode, on the PE returned by MoeGetProfPe(), append perf rows
//      through MoeAppendPerfCsvRows when the comparison succeeded.
//
// args      - shape and PE description (bs, h, topk, expert_per_pe, pe_size, pe_id).
// perf_args - perf settings (perf_mode, warmup_count, loop_count, csv_path).
//
// Returns 0 when every output matches its golden data, 1 on any validation,
// allocation, I/O, runtime or comparison failure. All resources acquired so
// far are released through the local `cleanup` lambda on every return path
// after stream creation is attempted.
template <class T>
int RunDispatchCase(const MoeShapeArgs &args, const MoePerfArgs &perf_args)
{
// Reject non-positive sizes and a pe_id outside [0, pe_size).
if (args.expert_per_pe <= 0 || args.pe_size <= 0 || args.pe_id < 0 || args.pe_id >= args.pe_size ||
args.bs <= 0 || args.h <= 0 || args.topk <= 0 || g_npus <= 0) {
std::cerr << "[Dispatch] invalid arguments." << std::endl;
return 1;
}
// DISPATCH_MAX_LOCAL_EXPERT_NUM is declared outside this file.
if (args.expert_per_pe > DISPATCH_MAX_LOCAL_EXPERT_NUM) {
std::cerr << "[Dispatch] invalid expert_per_pe=" << args.expert_per_pe
<< ", max supported value is " << DISPATCH_MAX_LOCAL_EXPERT_NUM << "." << std::endl;
return 1;
}
// Every resource starts out null so that `cleanup` can be called at any point
// and only releases what has actually been acquired.
aclrtStream stream = nullptr;
T *x_host = nullptr;
int32_t *expert_ids_host = nullptr;
T *expand_x_host = nullptr;
int32_t *assist_host = nullptr;
int32_t *ep_recv_count_host = nullptr;
int32_t *expert_token_nums_host = nullptr;
T *golden_expand_x = nullptr;
int32_t *golden_assist = nullptr;
int32_t *golden_ep_recv_count = nullptr;
int32_t *golden_expert_token_nums = nullptr;
void *x_device = nullptr;
void *expert_ids_device = nullptr;
void *expand_x_device = nullptr;
void *assist_device = nullptr;
void *ep_recv_count_device = nullptr;
void *expert_token_nums_device = nullptr;
void *shmem_window = nullptr;
// Releases the shmem window, device buffers, host buffers and finally the
// stream, resetting each handle to null afterwards.
// NOTE(review): ACL_CHECK is a project macro; its behaviour on failure is not
// visible here — confirm it is acceptable inside a cleanup path.
auto cleanup = [&]() {
if (shmem_window != nullptr) {
aclshmem_free(shmem_window);
shmem_window = nullptr;
}
if (x_device != nullptr) {
ACL_CHECK(aclrtFree(x_device));
x_device = nullptr;
}
if (expert_ids_device != nullptr) {
ACL_CHECK(aclrtFree(expert_ids_device));
expert_ids_device = nullptr;
}
if (expand_x_device != nullptr) {
ACL_CHECK(aclrtFree(expand_x_device));
expand_x_device = nullptr;
}
if (assist_device != nullptr) {
ACL_CHECK(aclrtFree(assist_device));
assist_device = nullptr;
}
if (ep_recv_count_device != nullptr) {
ACL_CHECK(aclrtFree(ep_recv_count_device));
ep_recv_count_device = nullptr;
}
if (expert_token_nums_device != nullptr) {
ACL_CHECK(aclrtFree(expert_token_nums_device));
expert_token_nums_device = nullptr;
}
if (x_host != nullptr) {
ACL_CHECK(aclrtFreeHost(x_host));
x_host = nullptr;
}
if (expert_ids_host != nullptr) {
ACL_CHECK(aclrtFreeHost(expert_ids_host));
expert_ids_host = nullptr;
}
if (expand_x_host != nullptr) {
ACL_CHECK(aclrtFreeHost(expand_x_host));
expand_x_host = nullptr;
}
if (assist_host != nullptr) {
ACL_CHECK(aclrtFreeHost(assist_host));
assist_host = nullptr;
}
if (ep_recv_count_host != nullptr) {
ACL_CHECK(aclrtFreeHost(ep_recv_count_host));
ep_recv_count_host = nullptr;
}
if (expert_token_nums_host != nullptr) {
ACL_CHECK(aclrtFreeHost(expert_token_nums_host));
expert_token_nums_host = nullptr;
}
if (golden_expand_x != nullptr) {
ACL_CHECK(aclrtFreeHost(golden_expand_x));
golden_expand_x = nullptr;
}
if (golden_assist != nullptr) {
ACL_CHECK(aclrtFreeHost(golden_assist));
golden_assist = nullptr;
}
if (golden_ep_recv_count != nullptr) {
ACL_CHECK(aclrtFreeHost(golden_ep_recv_count));
golden_ep_recv_count = nullptr;
}
if (golden_expert_token_nums != nullptr) {
ACL_CHECK(aclrtFreeHost(golden_expert_token_nums));
golden_expert_token_nums = nullptr;
}
if (stream != nullptr) {
ACL_CHECK(aclrtDestroyStream(stream));
stream = nullptr;
}
};
// Logs and returns false when an ACL call did not return ACL_ERROR_NONE.
auto check_acl = [&](aclError error, const char *op_name) -> bool {
if (error != ACL_ERROR_NONE) {
std::cerr << "[Dispatch] " << op_name << " failed, aclError=" << error << std::endl;
return false;
}
return true;
};
// Logs and returns false when an allocation/creation call produced a null handle.
auto check_not_null = [&](const void *ptr, const char *name) -> bool {
if (ptr == nullptr) {
std::cerr << "[Dispatch] " << name << " returned nullptr." << std::endl;
return false;
}
return true;
};
// Host allocation via aclrtMallocHost with both error-code and null checks.
auto alloc_host = [&](void **ptr, size_t bytes, const char *name) -> bool {
if (!check_acl(aclrtMallocHost(ptr, bytes), name)) {
return false;
}
return check_not_null(*ptr, name);
};
// Device allocation via aclrtMalloc (ACL_MEM_MALLOC_HUGE_FIRST policy) with
// both error-code and null checks.
auto alloc_device = [&](void **ptr, size_t bytes, const char *name) -> bool {
if (!check_acl(aclrtMalloc(ptr, bytes, ACL_MEM_MALLOC_HUGE_FIRST), name)) {
return false;
}
return check_not_null(*ptr, name);
};
if (!check_acl(aclrtCreateStream(&stream), "aclrtCreateStream") || !check_not_null(stream, "aclrtCreateStream")) {
cleanup();
return 1;
}
const uint64_t fftsAddr = util_get_ffts_config();
// Derived sizes. moe_expert_num and segment_num are computed from the same
// expression (expert_per_pe * pe_size); max_recv_tokens is the element count
// used to size the expand_x and assist output buffers.
const int64_t moe_expert_num = static_cast<int64_t>(args.expert_per_pe) * args.pe_size;
const int64_t max_recv_tokens = static_cast<int64_t>(args.pe_size) * args.bs * args.topk;
const int64_t segment_num = static_cast<int64_t>(args.expert_per_pe) * args.pe_size;
// Golden data lives in golden/shape_<bs>_<h>_<topk>_<moe_expert_num>_<pe_size>/rank_<pe_id>.
const std::string case_dir = "golden/shape_" + std::to_string(args.bs) + "_" +
std::to_string(args.h) + "_" + std::to_string(args.topk) + "_" +
std::to_string(moe_expert_num) + "_" + std::to_string(args.pe_size);
const std::string rank_dir = case_dir + "/rank_" + std::to_string(args.pe_id);
// Byte sizes of every input/output buffer and of the shmem window.
const size_t x_bytes = static_cast<size_t>(args.bs) * args.h * sizeof(T);
const size_t expert_ids_bytes = static_cast<size_t>(args.bs) * args.topk * sizeof(int32_t);
const size_t expand_x_bytes = static_cast<size_t>(max_recv_tokens) * args.h * sizeof(T);
const size_t assist_bytes = static_cast<size_t>(max_recv_tokens) * ASSIST_FIELDS * sizeof(int32_t);
const size_t ep_recv_count_bytes = static_cast<size_t>(segment_num) * sizeof(int32_t);
const size_t expert_token_nums_bytes = static_cast<size_t>(args.expert_per_pe) * sizeof(int32_t);
const size_t shmem_window_bytes =
DispatchWindowBytes(args.bs, args.h, args.topk, args.pe_size, args.expert_per_pe, sizeof(T));
const size_t shmem_count_offset =
DispatchCountOffset(args.bs, args.h, args.topk, args.pe_size, args.expert_per_pe, sizeof(T));
const size_t shmem_count_bytes = DispatchCountBytes(args.pe_size, args.expert_per_pe);
// Load the two kernel inputs from disk into host buffers.
if (!alloc_host(reinterpret_cast<void **>(&x_host), x_bytes, "aclrtMallocHost(x_host)") ||
!alloc_host(reinterpret_cast<void **>(&expert_ids_host), expert_ids_bytes, "aclrtMallocHost(expert_ids_host)")) {
cleanup();
return 1;
}
if (!ReadFile(rank_dir + "/x.bin", x_host, x_bytes) ||
!ReadFile(rank_dir + "/expert_ids.bin", expert_ids_host, expert_ids_bytes)) {
std::cerr << "[Dispatch] failed to read input files from " << rank_dir << std::endl;
cleanup();
return 1;
}
// Device-side buffers for the inputs and the four kernel outputs.
if (!alloc_device(&x_device, x_bytes, "aclrtMalloc(x_device)") ||
!alloc_device(&expert_ids_device, expert_ids_bytes, "aclrtMalloc(expert_ids_device)") ||
!alloc_device(&expand_x_device, expand_x_bytes, "aclrtMalloc(expand_x_device)") ||
!alloc_device(&assist_device, assist_bytes, "aclrtMalloc(assist_device)") ||
!alloc_device(&ep_recv_count_device, ep_recv_count_bytes, "aclrtMalloc(ep_recv_count_device)") ||
!alloc_device(&expert_token_nums_device, expert_token_nums_bytes, "aclrtMalloc(expert_token_nums_device)")) {
cleanup();
return 1;
}
shmem_window = aclshmem_malloc(shmem_window_bytes);
if (!check_not_null(shmem_window, "aclshmem_malloc(shmem_window)")) {
cleanup();
return 1;
}
// Upload inputs, zero the outputs and the whole window, then overwrite the
// region [shmem_count_offset, shmem_count_offset + shmem_count_bytes) with
// 0xFF bytes. Presumably 0xFF marks "not yet written" for the kernel — TODO
// confirm against the kernel source.
if (!check_acl(aclrtMemcpy(x_device, x_bytes, x_host, x_bytes, ACL_MEMCPY_HOST_TO_DEVICE),
"aclrtMemcpy(x_device)") ||
!check_acl(aclrtMemcpy(expert_ids_device, expert_ids_bytes, expert_ids_host, expert_ids_bytes,
ACL_MEMCPY_HOST_TO_DEVICE),
"aclrtMemcpy(expert_ids_device)") ||
!check_acl(aclrtMemset(expand_x_device, expand_x_bytes, 0, expand_x_bytes), "aclrtMemset(expand_x_device)") ||
!check_acl(aclrtMemset(assist_device, assist_bytes, 0, assist_bytes), "aclrtMemset(assist_device)") ||
!check_acl(aclrtMemset(ep_recv_count_device, ep_recv_count_bytes, 0, ep_recv_count_bytes),
"aclrtMemset(ep_recv_count_device)") ||
!check_acl(aclrtMemset(expert_token_nums_device, expert_token_nums_bytes, 0, expert_token_nums_bytes),
"aclrtMemset(expert_token_nums_device)") ||
!check_acl(aclrtMemset(shmem_window, shmem_window_bytes, 0, shmem_window_bytes), "aclrtMemset(shmem_window)") ||
!check_acl(aclrtMemset(reinterpret_cast<uint8_t *>(shmem_window) + shmem_count_offset, shmem_count_bytes,
0xFF, shmem_count_bytes),
"aclrtMemset(shmem_count_window)") ||
!check_acl(aclrtSynchronizeStream(stream), "aclrtSynchronizeStream")) {
cleanup();
return 1;
}
// Barrier call placed after initialisation and before the first launch.
aclshmem_barrier_all();
const uint32_t block_num = static_cast<uint32_t>(args.pe_size);
// Outside perf mode the kernel runs exactly once with no warmup.
const int warmup_count = perf_args.perf_mode ? perf_args.warmup_count : 0;
const int loop_count = perf_args.perf_mode ? perf_args.loop_count : 1;
for (int iter = 0; iter < warmup_count + loop_count; ++iter) {
// The perf flag passed to the kernel is 1 only for non-warmup iterations in perf mode.
const int launch_perf_mode = (perf_args.perf_mode && iter >= warmup_count) ? 1 : 0;
dispatch_demo<T>(block_num, stream, fftsAddr, reinterpret_cast<uint8_t *>(x_device),
reinterpret_cast<int32_t *>(expert_ids_device), reinterpret_cast<uint8_t *>(expand_x_device),
reinterpret_cast<int32_t *>(assist_device), reinterpret_cast<int32_t *>(ep_recv_count_device),
reinterpret_cast<int32_t *>(expert_token_nums_device),
reinterpret_cast<uint8_t *>(shmem_window), args.bs, args.h, args.topk, moe_expert_num,
MAGIC_MULTIPLIER, launch_perf_mode, FULL_FRAME_ID, COMM_FRAME_ID, 0, 1);
if (!check_acl(aclrtSynchronizeStream(stream), "aclrtSynchronizeStream")) {
cleanup();
return 1;
}
aclshmem_barrier_all();
}
// Host buffers that receive the four kernel outputs.
if (!alloc_host(reinterpret_cast<void **>(&expand_x_host), expand_x_bytes, "aclrtMallocHost(expand_x_host)") ||
!alloc_host(reinterpret_cast<void **>(&assist_host), assist_bytes, "aclrtMallocHost(assist_host)") ||
!alloc_host(reinterpret_cast<void **>(&ep_recv_count_host), ep_recv_count_bytes,
"aclrtMallocHost(ep_recv_count_host)") ||
!alloc_host(reinterpret_cast<void **>(&expert_token_nums_host), expert_token_nums_bytes,
"aclrtMallocHost(expert_token_nums_host)")) {
cleanup();
return 1;
}
if (!check_acl(aclrtMemcpy(expand_x_host, expand_x_bytes, expand_x_device, expand_x_bytes, ACL_MEMCPY_DEVICE_TO_HOST),
"aclrtMemcpy(expand_x_host)") ||
!check_acl(aclrtMemcpy(assist_host, assist_bytes, assist_device, assist_bytes, ACL_MEMCPY_DEVICE_TO_HOST),
"aclrtMemcpy(assist_host)") ||
!check_acl(aclrtMemcpy(ep_recv_count_host, ep_recv_count_bytes, ep_recv_count_device, ep_recv_count_bytes,
ACL_MEMCPY_DEVICE_TO_HOST),
"aclrtMemcpy(ep_recv_count_host)") ||
!check_acl(aclrtMemcpy(expert_token_nums_host, expert_token_nums_bytes, expert_token_nums_device,
expert_token_nums_bytes, ACL_MEMCPY_DEVICE_TO_HOST),
"aclrtMemcpy(expert_token_nums_host)")) {
cleanup();
return 1;
}
// Dump the outputs to output/<name>_<pe_id>.bin. Any previous file with the
// same name is unlinked first.
// NOTE(review): the results of make_dir, unlink and WriteFile are not checked.
make_dir("output");
const std::string out_prefix = "output/";
const std::string pe_suffix = std::to_string(args.pe_id) + ".bin";
const std::string expand_out = out_prefix + "expand_x_" + pe_suffix;
const std::string assist_out = out_prefix + "assist_info_" + pe_suffix;
const std::string recv_out = out_prefix + "ep_recv_count_" + pe_suffix;
const std::string expert_out = out_prefix + "expert_token_nums_" + pe_suffix;
unlink(expand_out.c_str());
unlink(assist_out.c_str());
unlink(recv_out.c_str());
unlink(expert_out.c_str());
WriteFile(expand_out, expand_x_host, expand_x_bytes);
WriteFile(assist_out, assist_host, assist_bytes);
WriteFile(recv_out, ep_recv_count_host, ep_recv_count_bytes);
WriteFile(expert_out, expert_token_nums_host, expert_token_nums_bytes);
// Load the golden reference for each output from the rank directory.
if (!alloc_host(reinterpret_cast<void **>(&golden_expand_x), expand_x_bytes, "aclrtMallocHost(golden_expand_x)") ||
!alloc_host(reinterpret_cast<void **>(&golden_assist), assist_bytes, "aclrtMallocHost(golden_assist)") ||
!alloc_host(reinterpret_cast<void **>(&golden_ep_recv_count), ep_recv_count_bytes,
"aclrtMallocHost(golden_ep_recv_count)") ||
!alloc_host(reinterpret_cast<void **>(&golden_expert_token_nums), expert_token_nums_bytes,
"aclrtMallocHost(golden_expert_token_nums)")) {
cleanup();
return 1;
}
if (!ReadFile(rank_dir + "/golden_expand_x.bin", golden_expand_x, expand_x_bytes) ||
!ReadFile(rank_dir + "/golden_assist_info.bin", golden_assist, assist_bytes) ||
!ReadFile(rank_dir + "/golden_ep_recv_count.bin", golden_ep_recv_count, ep_recv_count_bytes) ||
!ReadFile(rank_dir + "/golden_expert_token_nums.bin", golden_expert_token_nums, expert_token_nums_bytes)) {
std::cerr << "[Dispatch] failed to read golden files from " << rank_dir << std::endl;
cleanup();
return 1;
}
// The comparisons short-circuit: the first failing array stops the chain, so
// only one mismatch is reported per run.
const bool ok = CheckArray("expand_x", expand_x_host, golden_expand_x, max_recv_tokens * args.h, args.pe_id) &&
CheckIntArray("assist_info", assist_host, golden_assist, max_recv_tokens * ASSIST_FIELDS,
args.pe_id) &&
CheckIntArray("ep_recv_count", ep_recv_count_host, golden_ep_recv_count, segment_num,
args.pe_id) &&
CheckIntArray("expert_token_nums", expert_token_nums_host, golden_expert_token_nums,
args.expert_per_pe, args.pe_id);
// Perf rows are appended only for a passing run in perf mode, and only on the
// PE returned by MoeGetProfPe().
if (ok && perf_args.perf_mode && args.pe_id == MoeGetProfPe()) {
MoeAppendPerfCsvRows("Dispatch", &out_profs, perf_args.csv_path, args, perf_args, data_type, sizeof(T),
static_cast<int>(block_num), DISPATCH_UB_SIZE_KB);
}
cleanup();
if (!ok) {
return 1;
}
std::cout << "[Dispatch] pe " << args.pe_id << " result correct." << std::endl;
return 0;
}
}
// Entry point of the dispatch demo.
//
// Expected positional arguments (argv[INDEX1] .. argv[INDEX11]):
//   pe_size pe_id ipport g_npus first_pe first_npu data_type bs h topk expert_per_pe
// argc and argv are also forwarded to MoeInitPerfArgsFromArgv together with
// INDEX12, so arguments from that position on are handled by the perf helper.
//
// The function initialises ACL and aclshmem for the device derived from
// pe_id, g_npus and first_npu, runs RunDispatchCase for the requested data
// type ("int"/"int32_t", "float16_t" or "bfloat16_t"), then finalises the
// runtime.
//
// Returns 0 on success and non-zero when the arguments are invalid, the data
// type is unsupported, or the test case fails.
int main(int argc, char *argv[])
{
    if (argc < INDEX12) {
        std::cerr << "Usage: dispatch pe_size pe_id ipport g_npus first_pe first_npu data_type bs h topk "
                     "expert_per_pe"
                  << std::endl;
        return 1;
    }
    MoeShapeArgs args;
    args.pe_size = std::atoi(argv[INDEX1]);
    args.pe_id = std::atoi(argv[INDEX2]);
    ipport = argv[INDEX3];
    g_npus = std::atoi(argv[INDEX4]);
    // first_pe is parsed for command-line compatibility but is not used.
    const int first_pe = std::atoi(argv[INDEX5]);
    (void)first_pe;
    f_npu = std::atoi(argv[INDEX6]);
    data_type = argv[INDEX7];
    args.bs = std::atoi(argv[INDEX8]);
    args.h = std::atoi(argv[INDEX9]);
    args.topk = std::atoi(argv[INDEX10]);
    args.expert_per_pe = std::atoi(argv[INDEX11]);
    MoePerfArgs perf_args;
    MoeInitPerfArgsFromArgv(perf_args, argc, argv, args, INDEX12);

    // g_npus is used as a modulus below; std::atoi yields 0 for non-numeric
    // input, and a modulo by zero is undefined behaviour. Reject it before any
    // runtime initialisation happens.
    if (g_npus <= 0) {
        std::cerr << "[Dispatch] invalid g_npus=" << g_npus << ", must be positive." << std::endl;
        return 1;
    }
    const int32_t device_id = args.pe_id % g_npus + f_npu;

    ACL_CHECK(aclInit(nullptr));
    ACL_CHECK(aclrtSetDevice(device_id));
    aclshmemx_init_attr_t attributes;
    test_set_attr(args.pe_id, args.pe_size, 1024UL * 1024UL * 1024, ipport, default_flag_uid, &attributes);
    ACL_CHECK(aclshmemx_init_attr(ACLSHMEMX_INIT_WITH_DEFAULT, &attributes));

    // Select the template instantiation that matches the requested element type.
    const std::string type_name(data_type);
    int status = 0;
    if (type_name == "int" || type_name == "int32_t") {
        status = RunDispatchCase<int32_t>(args, perf_args);
    } else if (type_name == "float16_t") {
        status = RunDispatchCase<fp16_t>(args, perf_args);
    } else if (type_name == "bfloat16_t") {
        status = RunDispatchCase<bfloat16>(args, perf_args);
    } else {
        std::cerr << "[Dispatch] unsupported data type: " << data_type << std::endl;
        status = 1;
    }

    // The runtime is finalised regardless of the test outcome.
    ACL_CHECK(aclshmem_finalize());
    ACL_CHECK(aclrtResetDevice(device_id));
    ACL_CHECK(aclFinalize());
    if (status != 0) {
        return status;
    }
    std::cout << "[SUCCESS] dispatch demo run success in pe " << args.pe_id << std::endl;
    return 0;
}