* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <cmath>
#include <algorithm>
#include <cerrno>
#include <cstdlib>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
#include <sys/stat.h>
#include <unistd.h>
#include "acl/acl.h"
#include "opdev/fp16_t.h"
#include "param.h"
#include "shmem.h"
#include "utils.h"
#include "combine_kernel.h"
using fp16_t = op::fp16_t;
namespace {
constexpr int MAGIC_MULTIPLIER = 1024;
constexpr int ASSIST_FIELDS = 3;
constexpr size_t ALIGN_BYTES = 32;
constexpr int FULL_FRAME_ID = 0;
constexpr int COMM_FRAME_ID = 1;
constexpr int COMBINE_UB_SIZE_KB = 190;
int g_npus = 8;
const char *ipport = "tcp://127.0.0.1:8766";
int f_npu = 0;
const char *data_type = "int32_t";
aclshmemx_uniqueid_t default_flag_uid;
aclshmem_prof_pe_t *out_profs = nullptr;
struct CombineArgs {
int pe_size = 2;
int pe_id = 0;
int bs = 8;
int h = 7168;
int topk = 8;
int expert_per_pe = 2;
};
struct PerfArgs {
int perf_mode = 0;
int warmup_count = 0;
int loop_count = 1;
std::string csv_path;
std::string case_id;
};
template <typename T>
bool AlmostEqual(T lhs, T rhs)
{
return static_cast<float>(lhs) == static_cast<float>(rhs);
}
template <>
bool AlmostEqual<fp16_t>(fp16_t lhs, fp16_t rhs)
{
const float diff = std::abs(static_cast<float>(lhs) - static_cast<float>(rhs));
return diff <= 1.0e-2F;
}
template <typename T>
bool CheckArray(const T *actual, const T *expected, size_t count, int pe_id)
{
for (size_t i = 0; i < count; ++i) {
if (!AlmostEqual(actual[i], expected[i])) {
std::cerr << "[Combine] x_out mismatch, pe_id=" << pe_id << ", idx=" << i
<< ", actual=" << static_cast<float>(actual[i])
<< ", expected=" << static_cast<float>(expected[i]) << std::endl;
return false;
}
}
return true;
}
size_t AlignUp(size_t value, size_t alignment)
{
return (value + alignment - 1) / alignment * alignment;
}
size_t CombineWindowBytes(size_t bs, size_t h, size_t topk, size_t elem_size)
{
const size_t payload_stride = AlignUp(h * elem_size, ALIGN_BYTES);
return bs * topk * payload_stride + bs * topk * ALIGN_BYTES;
}
bool MakeDirs(const std::string &path)
{
if (path.empty()) {
return true;
}
std::string current;
for (size_t i = 0; i < path.size(); ++i) {
current.push_back(path[i]);
if (path[i] != '/' && i + 1 != path.size()) {
continue;
}
if (current.empty() || current == "/") {
continue;
}
if (mkdir(current.c_str(), 0755) != 0 && errno != EEXIST) {
return false;
}
}
return true;
}
int GetCycleToUs()
{
const char *soc_name = aclrtGetSocName();
if (soc_name != nullptr && std::string(soc_name).find("Ascend950") != std::string::npos) {
return 1000;
}
return 50;
}
double BytesToGBps(size_t bytes, double us)
{
if (us <= 0.0) {
return 0.0;
}
return static_cast<double>(bytes) / us * 1000000.0 / 1024.0 / 1024.0 / 1024.0;
}
std::string DoubleToString(double value)
{
std::ostringstream oss;
oss << std::fixed << std::setprecision(4) << value;
return oss.str();
}
int GetProfPe()
{
const char *prof_pe_env = std::getenv("SHMEM_CYCLE_PROF_PE");
if (prof_pe_env == nullptr) {
return -1;
}
return std::atoi(prof_pe_env);
}
std::vector<double> GetCoreTimesUs(aclshmem_prof_pe_t *profs, int frame_id, int block_num, double *max_time)
{
std::vector<double> core_times;
*max_time = 0.0;
if (profs == nullptr || frame_id < 0 || frame_id >= ACLSHMEM_CYCLE_PROF_FRAME_CNT) {
return core_times;
}
const int actual_blocks = std::min(block_num, ACLSHMEM_CYCLE_PROF_MAX_BLOCK);
const int cycle_to_us = GetCycleToUs();
for (int block_id = 0; block_id < actual_blocks; ++block_id) {
aclshmem_prof_block_t *prof = &profs->block_prof[block_id];
if (prof->ccount[frame_id] == 0) {
continue;
}
const double avg_us = static_cast<double>(prof->cycles[frame_id]) /
static_cast<double>(prof->ccount[frame_id]) / static_cast<double>(cycle_to_us);
*max_time = std::max(*max_time, avg_us);
core_times.push_back(avg_us);
}
return core_times;
}
void AppendPerfCsvRows(const std::string &csv_path, const CombineArgs &args, const PerfArgs &perf_args,
size_t elem_size, int block_num)
{
if (csv_path.empty()) {
return;
}
aclshmemx_show_prof(&out_profs, false);
if (out_profs == nullptr) {
return;
}
const bool file_exists = (access(csv_path.c_str(), F_OK) == 0);
const size_t per_pe_bytes = static_cast<size_t>(args.bs) * args.topk * args.h * elem_size;
const size_t global_bytes = per_pe_bytes * args.pe_size;
const int prof_pe = GetProfPe();
std::string dir = get_dir(csv_path);
if (!dir.empty()) {
MakeDirs(dir);
}
std::ofstream out_file(csv_path, std::ios::app);
if (!out_file.is_open()) {
std::cerr << "[Combine] failed to open perf csv: " << csv_path << std::endl;
return;
}
if (!file_exists) {
out_file << "DataSize/B,Npus,Blocks,UBsize/KB,Bandwidth/GB/s,CoreMaxTime/us,Metric,GlobalDataSize/B,"
"PerPeBandwidth/GB/s,BS,H,TopK,ExpertPerPe,Dtype,Warmup,Loops,ProfPe,CaseId,"
"SingleCoreTime/us\n";
}
const auto write_row = [&](const std::string &metric, int frame_id) {
double max_time = 0.0;
std::vector<double> core_times = GetCoreTimesUs(out_profs, frame_id, block_num, &max_time);
out_file << per_pe_bytes << "," << args.pe_size << "," << block_num << "," << COMBINE_UB_SIZE_KB << ","
<< DoubleToString(BytesToGBps(global_bytes, max_time)) << "," << DoubleToString(max_time) << ","
<< metric << "," << global_bytes << "," << DoubleToString(BytesToGBps(per_pe_bytes, max_time))
<< "," << args.bs << "," << args.h << "," << args.topk << "," << args.expert_per_pe << ","
<< data_type << "," << perf_args.warmup_count << "," << perf_args.loop_count << "," << prof_pe
<< "," << perf_args.case_id;
for (double core_time : core_times) {
out_file << "," << DoubleToString(core_time);
}
out_file << "\n";
};
write_row("full_op", FULL_FRAME_ID);
write_row("comm_only", COMM_FRAME_ID);
}
template <class T>
int RunCombineCase(const CombineArgs &args, const PerfArgs &perf_args)
{
if (args.expert_per_pe <= 0 || args.pe_size <= 0 || args.pe_id < 0 || args.pe_id >= args.pe_size ||
args.bs <= 0 || args.h <= 0 || args.topk <= 0 || g_npus <= 0) {
std::cerr << "[Combine] invalid arguments." << std::endl;
return 1;
}
aclrtStream stream = nullptr;
T *expand_x_host = nullptr;
int32_t *assist_host = nullptr;
int32_t *ep_recv_count_host = nullptr;
int32_t *expert_ids_host = nullptr;
float *expert_scales_host = nullptr;
T *x_out_host = nullptr;
T *golden_x_out = nullptr;
void *expand_x_device = nullptr;
void *assist_device = nullptr;
void *ep_recv_count_device = nullptr;
void *expert_ids_device = nullptr;
void *expert_scales_device = nullptr;
void *x_out_device = nullptr;
void *shmem_window = nullptr;
auto cleanup = [&]() {
if (shmem_window != nullptr) {
aclshmem_free(shmem_window);
shmem_window = nullptr;
}
if (expand_x_device != nullptr) {
ACL_CHECK(aclrtFree(expand_x_device));
expand_x_device = nullptr;
}
if (assist_device != nullptr) {
ACL_CHECK(aclrtFree(assist_device));
assist_device = nullptr;
}
if (ep_recv_count_device != nullptr) {
ACL_CHECK(aclrtFree(ep_recv_count_device));
ep_recv_count_device = nullptr;
}
if (expert_ids_device != nullptr) {
ACL_CHECK(aclrtFree(expert_ids_device));
expert_ids_device = nullptr;
}
if (expert_scales_device != nullptr) {
ACL_CHECK(aclrtFree(expert_scales_device));
expert_scales_device = nullptr;
}
if (x_out_device != nullptr) {
ACL_CHECK(aclrtFree(x_out_device));
x_out_device = nullptr;
}
if (expand_x_host != nullptr) {
ACL_CHECK(aclrtFreeHost(expand_x_host));
expand_x_host = nullptr;
}
if (assist_host != nullptr) {
ACL_CHECK(aclrtFreeHost(assist_host));
assist_host = nullptr;
}
if (ep_recv_count_host != nullptr) {
ACL_CHECK(aclrtFreeHost(ep_recv_count_host));
ep_recv_count_host = nullptr;
}
if (expert_ids_host != nullptr) {
ACL_CHECK(aclrtFreeHost(expert_ids_host));
expert_ids_host = nullptr;
}
if (expert_scales_host != nullptr) {
ACL_CHECK(aclrtFreeHost(expert_scales_host));
expert_scales_host = nullptr;
}
if (x_out_host != nullptr) {
ACL_CHECK(aclrtFreeHost(x_out_host));
x_out_host = nullptr;
}
if (golden_x_out != nullptr) {
ACL_CHECK(aclrtFreeHost(golden_x_out));
golden_x_out = nullptr;
}
if (stream != nullptr) {
ACL_CHECK(aclrtDestroyStream(stream));
stream = nullptr;
}
};
auto check_acl = [&](aclError error, const char *op_name) -> bool {
if (error != ACL_ERROR_NONE) {
std::cerr << "[Combine] " << op_name << " failed, aclError=" << error << std::endl;
return false;
}
return true;
};
auto check_not_null = [&](const void *ptr, const char *name) -> bool {
if (ptr == nullptr) {
std::cerr << "[Combine] " << name << " returned nullptr." << std::endl;
return false;
}
return true;
};
auto alloc_host = [&](void **ptr, size_t bytes, const char *name) -> bool {
if (!check_acl(aclrtMallocHost(ptr, bytes), name)) {
return false;
}
return check_not_null(*ptr, name);
};
auto alloc_device = [&](void **ptr, size_t bytes, const char *name) -> bool {
if (!check_acl(aclrtMalloc(ptr, bytes, ACL_MEM_MALLOC_HUGE_FIRST), name)) {
return false;
}
return check_not_null(*ptr, name);
};
if (!check_acl(aclrtCreateStream(&stream), "aclrtCreateStream") || !check_not_null(stream, "aclrtCreateStream")) {
cleanup();
return 1;
}
const uint64_t fftsAddr = util_get_ffts_config();
const int64_t moe_expert_num = static_cast<int64_t>(args.pe_size) * args.expert_per_pe;
const int64_t max_recv_tokens = static_cast<int64_t>(args.pe_size) * args.bs * args.topk;
const int64_t segment_num = static_cast<int64_t>(args.expert_per_pe) * args.pe_size;
const std::string case_dir = "golden/shape_" + std::to_string(args.bs) + "_" +
std::to_string(args.h) + "_" + std::to_string(args.topk) + "_" +
std::to_string(moe_expert_num) + "_" + std::to_string(args.pe_size);
const std::string rank_dir = case_dir + "/rank_" + std::to_string(args.pe_id);
const size_t expand_x_bytes = static_cast<size_t>(max_recv_tokens) * args.h * sizeof(T);
const size_t assist_bytes = static_cast<size_t>(max_recv_tokens) * ASSIST_FIELDS * sizeof(int32_t);
const size_t ep_recv_count_bytes = static_cast<size_t>(segment_num) * sizeof(int32_t);
const size_t expert_ids_bytes = static_cast<size_t>(args.bs) * args.topk * sizeof(int32_t);
const size_t expert_scales_bytes = static_cast<size_t>(args.bs) * args.topk * sizeof(float);
const size_t x_out_bytes = static_cast<size_t>(args.bs) * args.h * sizeof(T);
const size_t shmem_window_bytes = CombineWindowBytes(args.bs, args.h, args.topk, sizeof(T));
if (!alloc_host(reinterpret_cast<void **>(&expand_x_host), expand_x_bytes, "aclrtMallocHost(expand_x_host)") ||
!alloc_host(reinterpret_cast<void **>(&assist_host), assist_bytes, "aclrtMallocHost(assist_host)") ||
!alloc_host(reinterpret_cast<void **>(&ep_recv_count_host), ep_recv_count_bytes,
"aclrtMallocHost(ep_recv_count_host)") ||
!alloc_host(reinterpret_cast<void **>(&expert_ids_host), expert_ids_bytes, "aclrtMallocHost(expert_ids_host)") ||
!alloc_host(reinterpret_cast<void **>(&expert_scales_host), expert_scales_bytes,
"aclrtMallocHost(expert_scales_host)")) {
cleanup();
return 1;
}
if (!ReadFile(rank_dir + "/expand_x.bin", expand_x_host, expand_x_bytes) ||
!ReadFile(rank_dir + "/assist_info.bin", assist_host, assist_bytes) ||
!ReadFile(rank_dir + "/ep_recv_count.bin", ep_recv_count_host, ep_recv_count_bytes) ||
!ReadFile(rank_dir + "/expert_ids.bin", expert_ids_host, expert_ids_bytes) ||
!ReadFile(rank_dir + "/expert_scales.bin", expert_scales_host, expert_scales_bytes)) {
std::cerr << "[Combine] failed to read input files from " << rank_dir << std::endl;
cleanup();
return 1;
}
if (!alloc_device(&expand_x_device, expand_x_bytes, "aclrtMalloc(expand_x_device)") ||
!alloc_device(&assist_device, assist_bytes, "aclrtMalloc(assist_device)") ||
!alloc_device(&ep_recv_count_device, ep_recv_count_bytes, "aclrtMalloc(ep_recv_count_device)") ||
!alloc_device(&expert_ids_device, expert_ids_bytes, "aclrtMalloc(expert_ids_device)") ||
!alloc_device(&expert_scales_device, expert_scales_bytes, "aclrtMalloc(expert_scales_device)") ||
!alloc_device(&x_out_device, x_out_bytes, "aclrtMalloc(x_out_device)")) {
cleanup();
return 1;
}
shmem_window = aclshmem_malloc(shmem_window_bytes);
if (!check_not_null(shmem_window, "aclshmem_malloc(shmem_window)")) {
cleanup();
return 1;
}
if (!check_acl(aclrtMemcpy(expand_x_device, expand_x_bytes, expand_x_host, expand_x_bytes, ACL_MEMCPY_HOST_TO_DEVICE),
"aclrtMemcpy(expand_x_device)") ||
!check_acl(aclrtMemcpy(assist_device, assist_bytes, assist_host, assist_bytes, ACL_MEMCPY_HOST_TO_DEVICE),
"aclrtMemcpy(assist_device)") ||
!check_acl(aclrtMemcpy(ep_recv_count_device, ep_recv_count_bytes, ep_recv_count_host, ep_recv_count_bytes,
ACL_MEMCPY_HOST_TO_DEVICE),
"aclrtMemcpy(ep_recv_count_device)") ||
!check_acl(aclrtMemcpy(expert_ids_device, expert_ids_bytes, expert_ids_host, expert_ids_bytes,
ACL_MEMCPY_HOST_TO_DEVICE),
"aclrtMemcpy(expert_ids_device)") ||
!check_acl(aclrtMemcpy(expert_scales_device, expert_scales_bytes, expert_scales_host, expert_scales_bytes,
ACL_MEMCPY_HOST_TO_DEVICE),
"aclrtMemcpy(expert_scales_device)") ||
!check_acl(aclrtMemset(x_out_device, x_out_bytes, 0, x_out_bytes), "aclrtMemset(x_out_device)") ||
!check_acl(aclrtMemset(shmem_window, shmem_window_bytes, 0, shmem_window_bytes), "aclrtMemset(shmem_window)") ||
!check_acl(aclrtSynchronizeStream(stream), "aclrtSynchronizeStream")) {
cleanup();
return 1;
}
aclshmem_barrier_all();
const uint32_t block_num = static_cast<uint32_t>(args.pe_size);
const int warmup_count = perf_args.perf_mode ? perf_args.warmup_count : 0;
const int loop_count = perf_args.perf_mode ? perf_args.loop_count : 1;
for (int iter = 0; iter < warmup_count + loop_count; ++iter) {
const int launch_perf_mode = (perf_args.perf_mode && iter >= warmup_count) ? 1 : 0;
combine_demo<T>(block_num, stream, fftsAddr, reinterpret_cast<uint8_t *>(expand_x_device),
reinterpret_cast<int32_t *>(assist_device), reinterpret_cast<int32_t *>(ep_recv_count_device),
reinterpret_cast<int32_t *>(expert_ids_device), reinterpret_cast<float *>(expert_scales_device),
reinterpret_cast<uint8_t *>(x_out_device), reinterpret_cast<uint8_t *>(shmem_window), args.bs,
args.h, args.topk, moe_expert_num, MAGIC_MULTIPLIER, launch_perf_mode, FULL_FRAME_ID,
COMM_FRAME_ID, 0, 1);
if (!check_acl(aclrtSynchronizeStream(stream), "aclrtSynchronizeStream")) {
cleanup();
return 1;
}
aclshmem_barrier_all();
}
if (!alloc_host(reinterpret_cast<void **>(&x_out_host), x_out_bytes, "aclrtMallocHost(x_out_host)") ||
!alloc_host(reinterpret_cast<void **>(&golden_x_out), x_out_bytes, "aclrtMallocHost(golden_x_out)")) {
cleanup();
return 1;
}
if (!check_acl(aclrtMemcpy(x_out_host, x_out_bytes, x_out_device, x_out_bytes, ACL_MEMCPY_DEVICE_TO_HOST),
"aclrtMemcpy(x_out_host)")) {
cleanup();
return 1;
}
if (!ReadFile(rank_dir + "/golden_x_out.bin", golden_x_out, x_out_bytes)) {
std::cerr << "[Combine] failed to read golden output from " << rank_dir << std::endl;
cleanup();
return 1;
}
make_dir("output");
const std::string output_file = "output/x_out_" + std::to_string(args.pe_id) + ".bin";
unlink(output_file.c_str());
WriteFile(output_file, x_out_host, x_out_bytes);
const bool ok = CheckArray(x_out_host, golden_x_out, static_cast<size_t>(args.bs) * args.h, args.pe_id);
if (ok && perf_args.perf_mode && args.pe_id == GetProfPe()) {
AppendPerfCsvRows(perf_args.csv_path, args, perf_args, sizeof(T), static_cast<int>(block_num));
}
cleanup();
if (!ok) {
return 1;
}
std::cout << "[Combine] pe " << args.pe_id << " result correct." << std::endl;
return 0;
}
}
int main(int argc, char *argv[])
{
if (argc < 12) {
std::cerr << "Usage: combine pe_size pe_id ipport g_npus first_pe first_npu data_type bs h topk "
"expert_per_pe"
<< std::endl;
return 1;
}
CombineArgs args;
args.pe_size = std::atoi(argv[INDEX1]);
args.pe_id = std::atoi(argv[INDEX2]);
ipport = argv[INDEX3];
g_npus = std::atoi(argv[INDEX4]);
f_npu = std::atoi(argv[INDEX6]);
data_type = argv[INDEX7];
args.bs = std::atoi(argv[INDEX8]);
args.h = std::atoi(argv[INDEX9]);
args.topk = std::atoi(argv[10]);
args.expert_per_pe = std::atoi(argv[11]);
PerfArgs perf_args;
if (argc > 12) {
perf_args.perf_mode = std::atoi(argv[12]);
}
if (argc > 13) {
perf_args.warmup_count = std::atoi(argv[13]);
}
if (argc > 14) {
perf_args.loop_count = std::atoi(argv[14]);
}
if (argc > 15) {
perf_args.csv_path = argv[15];
}
if (argc > 16) {
perf_args.case_id = argv[16];
}
if (perf_args.case_id.empty()) {
perf_args.case_id = "shape_" + std::to_string(args.bs) + "_" + std::to_string(args.h) + "_" +
std::to_string(args.topk) + "_" +
std::to_string(args.pe_size * args.expert_per_pe) + "_" +
std::to_string(args.pe_size);
}
if (perf_args.warmup_count < 0) {
perf_args.warmup_count = 0;
}
if (perf_args.loop_count <= 0) {
perf_args.loop_count = 1;
}
if (perf_args.perf_mode && std::getenv("SHMEM_CYCLE_PROF_PE") == nullptr) {
setenv("SHMEM_CYCLE_PROF_PE", "0", 0);
}
const int32_t device_id = args.pe_id % g_npus + f_npu;
ACL_CHECK(aclInit(nullptr));
ACL_CHECK(aclrtSetDevice(device_id));
aclshmemx_init_attr_t attributes;
test_set_attr(args.pe_id, args.pe_size, 1024UL * 1024UL * 1024, ipport, default_flag_uid, &attributes);
ACL_CHECK(aclshmemx_init_attr(ACLSHMEMX_INIT_WITH_DEFAULT, &attributes));
int status = 0;
if (std::string(data_type) == "int" || std::string(data_type) == "int32_t") {
status = RunCombineCase<int32_t>(args, perf_args);
} else if (std::string(data_type) == "float16_t") {
status = RunCombineCase<fp16_t>(args, perf_args);
} else {
std::cerr << "[Combine] unsupported data type: " << data_type << std::endl;
status = 1;
}
ACL_CHECK(aclshmem_finalize());
ACL_CHECK(aclrtResetDevice(device_id));
ACL_CHECK(aclFinalize());
if (status != 0) {
return status;
}
std::cout << "[SUCCESS] combine demo run success in pe " << args.pe_id << std::endl;
return 0;
}