* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <iostream>
#include <cstdlib>
#include <string>
#include <vector>
#include "utils.h"
#include "param.h"
#include "acl/acl.h"
#include "acl/acl_mdl.h"
#include "shmem.h"
#include "aclgraph_demo_kernel.h"
// Runtime configuration. The initialisers below are defaults only: main()
// overwrites every one of these from the command line before they are used.

// Divisor applied to pe_id when main() derives the device id.
int g_npus{8};
// Endpoint string forwarded to test_set_attr() in main().
const char *ipport{"tcp://127.0.0.1:8998"};
// Parsed from the command line in main(); not read anywhere else in this section.
int f_pe{0};
// Offset added to (pe_id % g_npus) when main() derives the device id.
int f_npu{0};
// Element type name; main() only accepts "int".
const char *data_type{"int"};
// Number of times each test case executes the captured graph.
int loop_times{3};
// Per-block multiplier used when sizing the symmetric (aclshmem_malloc) buffer.
constexpr int64_t SYNC_FLAG_INTERVAL{16};
// 190 KiB. Not referenced in this section of the file.
constexpr int64_t UB_DMA_MAX_SIZE{190 * 1024};
// 100 MiB. Second term of the symmetric buffer size computation.
constexpr int64_t GVA_BUFF_MAX_SIZE{100 * 1024 * 1024};
// Factor applied to the per-iteration magic counter before it is copied to the device.
constexpr uint32_t MAGIC_MULTIPLIER{1024U};
// Payload size in bytes (2 MiB) at which the test switches to the larger block count.
constexpr uint32_t DATA_SIZE_THRESHOLD{2U * 1024U * 1024U};
// Block count used while the payload is below DATA_SIZE_THRESHOLD.
constexpr uint32_t BLOCK_NUM_SMALL_DATA{8U};
// Block count used once the payload reaches DATA_SIZE_THRESHOLD.
constexpr uint32_t BLOCK_NUM_LARGE_DATA{16U};
// Unique-id object passed by value-or-reference (per test_set_attr's signature) to
// test_set_attr() in main() when the shmem init attributes are filled in.
// It is never assigned in this section of the file.
aclshmemx_uniqueid_t default_flag_uid;
/**
 * Declaration only; the definition is not visible in this section of the file.
 *
 * test_aclshmem_all_gather() enqueues it on a stream as
 * run_vector_add<T>(count, a, b, c, stream). The result check there is
 * consistent with an element-wise c = a + b over `numElements` elements of T,
 * but that is inferred from the caller -- TODO confirm against the definition.
 *
 * @tparam T           Element type of the buffers.
 * @param numElements  Number of elements to process.
 * @param a            First input buffer (device memory at the call sites in this file).
 * @param b            Second input buffer.
 * @param c            Output buffer.
 * @param stream       Stream handle the work is issued on (an aclrtStream at the call sites).
 */
template <class T>
void run_vector_add(int64_t numElements, void *a, void *b, void *c, void *stream);
/**
 * @brief Folds the status of one runtime call into a sticky result.
 *
 * A zero status is ignored. A non-zero status is reported on stdout and stored
 * in @p result, unless @p result already holds an earlier failure.
 *
 * @param call_status Value returned by the runtime call; 0 is treated as success.
 * @param what        Short label of the call, used in the diagnostic line.
 * @param result      Sticky result; only written while it is still 0.
 */
static void allgather_demo_track(int call_status, const char *what, int &result)
{
    if (call_status == 0) {
        return;
    }
    std::cout << what << " failed, ret = " << call_status << std::endl;
    if (result == 0) {
        result = call_status;
    }
}

/**
 * @brief Runs the graph-captured vector-add + all-gather pipeline for each
 *        test size and compares the output with the golden file.
 *
 * For every case the function allocates device/host/symmetric buffers, loads
 * the per-PE input and the golden data from
 * ../../examples/aclgraph_demo/golden/, and then executes the pipeline
 * `loop_times` times. The pipeline is captured into a model on the first
 * iteration and replayed on later ones. After every execution the result is
 * copied back and each element is compared with golden * (iteration + 3).
 *
 * Failures are sticky: once a runtime call fails or a mismatch is found, the
 * function returns non-zero even if later calls succeed.
 *
 * @tparam T     Element type of the buffers.
 * @param pe_id  Index of this processing element; selects the input file and
 *               enables progress messages on PE 0.
 * @param n_pes  Number of processing elements; scales the output buffers.
 * @return 0 when every checked runtime call succeeded and every comparison
 *         matched; otherwise the first non-zero status (-1 for a data mismatch).
 */
template <class T>
int test_aclshmem_all_gather(int pe_id, int n_pes)
{
    int result = 0;

    aclrtStream stream = nullptr;
    allgather_demo_track(aclrtCreateStream(&stream), "aclrtCreateStream", result);
    if (result != 0) {
        // Every later step is issued on this stream, so there is nothing to run.
        return result;
    }
    const uint64_t ffts_addr = util_get_ffts_config();

    // Per-PE element counts under test: 262144 and 524288.
    constexpr int case_num = 2;
    constexpr uint32_t base_len = 262144;
    std::vector<uint32_t> test_cases;
    test_cases.reserve(case_num);
    for (int i = 0; i < case_num; i++) {
        test_cases.push_back(base_len << i);
    }

    // Counter that is advanced before every execution and scaled into the
    // values written to the two device-side magic words.
    int magic = 1;

    for (const uint32_t trans_size : test_cases) {
        int case_result = 0;
        aclmdlRI model = nullptr;
        if (pe_id == 0) {
            std::cout << "Case: " << trans_size << " Started." << std::endl;
        }

        const size_t input_size = static_cast<size_t>(trans_size) * sizeof(T);
        const size_t total_elems = static_cast<size_t>(n_pes) * trans_size;
        const size_t output_size = total_elems * sizeof(T);
        const uint32_t block_num =
            (input_size < DATA_SIZE_THRESHOLD) ? BLOCK_NUM_SMALL_DATA : BLOCK_NUM_LARGE_DATA;

        const std::string case_dir = "../../examples/aclgraph_demo/golden/allgather_" +
                                     std::to_string(trans_size) + "_" + std::to_string(n_pes);

        // ---- Setup: device inputs, seeded from this PE's input file. ----
        // All pointers start as nullptr so that cleanup never sees garbage
        // when an allocation fails.
        void *input_ptr = nullptr;
        allgather_demo_track(aclrtMalloc(&input_ptr, input_size, ACL_MEM_MALLOC_HUGE_FIRST),
                             "aclrtMalloc(input)", case_result);
        uint8_t *input_host = nullptr;
        allgather_demo_track(aclrtMallocHost(reinterpret_cast<void **>(&input_host), input_size),
                             "aclrtMallocHost(input)", case_result);
        if (input_host != nullptr) {
            const std::string inputFile = case_dir + "/input_gm_" + std::to_string(pe_id) + ".bin";
            ReadFile(inputFile, input_host, input_size);
        }
        allgather_demo_track(aclrtMemcpy(input_ptr, input_size, input_host, input_size, ACL_MEMCPY_HOST_TO_DEVICE),
                             "aclrtMemcpy(input)", case_result);

        void *input_b_ptr = nullptr;
        allgather_demo_track(aclrtMalloc(&input_b_ptr, input_size, ACL_MEM_MALLOC_HUGE_FIRST),
                             "aclrtMalloc(input_b)", case_result);
        allgather_demo_track(aclrtMemcpy(input_b_ptr, input_size, input_host, input_size, ACL_MEMCPY_HOST_TO_DEVICE),
                             "aclrtMemcpy(input_b)", case_result);

        void *input_c_ptr = nullptr;
        allgather_demo_track(aclrtMalloc(&input_c_ptr, input_size, ACL_MEM_MALLOC_HUGE_FIRST),
                             "aclrtMalloc(input_c)", case_result);

        // ---- Setup: device outputs, n_pes times the input size. ----
        void *output_ptr = nullptr;
        allgather_demo_track(aclrtMalloc(&output_ptr, output_size, ACL_MEM_MALLOC_HUGE_FIRST),
                             "aclrtMalloc(output)", case_result);
        void *output_b_ptr = nullptr;
        allgather_demo_track(aclrtMalloc(&output_b_ptr, output_size, ACL_MEM_MALLOC_HUGE_FIRST),
                             "aclrtMalloc(output_b)", case_result);
        void *output_c_ptr = nullptr;
        allgather_demo_track(aclrtMalloc(&output_c_ptr, output_size, ACL_MEM_MALLOC_HUGE_FIRST),
                             "aclrtMalloc(output_c)", case_result);

        // ---- Setup: symmetric buffer handed to both all-gather launches. ----
        // The size expression is kept exactly as it was; it is always issued,
        // in the same position, regardless of earlier failures.
        void *ptr = aclshmem_malloc(block_num * SYNC_FLAG_INTERVAL * sizeof(T) + GVA_BUFF_MAX_SIZE / sizeof(T));
        if (ptr == nullptr) {
            allgather_demo_track(-1, "aclshmem_malloc", case_result);
        }

        // ---- Setup: host staging word and the two device magic words. ----
        int *magic_host_ptr = nullptr;
        allgather_demo_track(aclrtMallocHost(reinterpret_cast<void **>(&magic_host_ptr), sizeof(int)),
                             "aclrtMallocHost(magic)", case_result);
        int *magic_ptr = nullptr;
        allgather_demo_track(aclrtMalloc(reinterpret_cast<void **>(&magic_ptr), sizeof(int), ACL_MEM_MALLOC_HUGE_FIRST),
                             "aclrtMalloc(magic)", case_result);
        int *magic_ptr2 = nullptr;
        allgather_demo_track(
            aclrtMalloc(reinterpret_cast<void **>(&magic_ptr2), sizeof(int), ACL_MEM_MALLOC_HUGE_FIRST),
            "aclrtMalloc(magic2)", case_result);

        // ---- Setup: host result buffer and golden reference. ----
        T *output_host = nullptr;
        allgather_demo_track(aclrtMallocHost(reinterpret_cast<void **>(&output_host), output_size),
                             "aclrtMallocHost(output)", case_result);
        T *golden_host = nullptr;
        allgather_demo_track(aclrtMallocHost(reinterpret_cast<void **>(&golden_host), output_size),
                             "aclrtMallocHost(golden)", case_result);
        if (golden_host != nullptr) {
            const std::string goldenFile = case_dir + "/golden.bin";
            ReadFile(goldenFile, golden_host, output_size);
        }

        // ---- Execution: only when setup produced valid buffers. ----
        // Once running, the loop never stops early on a data mismatch: every
        // iteration is still executed, as before, and the failure is recorded.
        const bool setup_ok = (case_result == 0);
        for (int zz = 0; setup_ok && zz < loop_times; zz++) {
            magic++;
            const int magic_value = magic * MAGIC_MULTIPLIER;
            const int magic_value_2 = (magic + loop_times + 10) * MAGIC_MULTIPLIER;

            // Refresh both device magic words through the single host staging word.
            *magic_host_ptr = magic_value;
            allgather_demo_track(
                aclrtMemcpy(magic_ptr, sizeof(int), magic_host_ptr, sizeof(int), ACL_MEMCPY_HOST_TO_DEVICE),
                "aclrtMemcpy(magic)", case_result);
            *magic_host_ptr = magic_value_2;
            allgather_demo_track(
                aclrtMemcpy(magic_ptr2, sizeof(int), magic_host_ptr, sizeof(int), ACL_MEMCPY_HOST_TO_DEVICE),
                "aclrtMemcpy(magic2)", case_result);

            if (zz == 0) {
                // First iteration: record the whole pipeline into `model`.
                allgather_demo_track(aclmdlRICaptureBegin(stream, ACL_MODEL_RI_CAPTURE_MODE_RELAXED),
                                     "aclmdlRICaptureBegin", case_result);
                run_vector_add<T>(trans_size, input_ptr, input_b_ptr, input_c_ptr, stream);
                allgather_demo_track(aclrtMemcpyAsync(input_ptr, input_size, input_c_ptr, input_size,
                                                      ACL_MEMCPY_DEVICE_TO_DEVICE, stream),
                                     "aclrtMemcpyAsync(input_c -> input)", case_result);
                allgather_demo<T>(block_num, stream, ffts_addr, static_cast<uint8_t *>(input_ptr),
                                  static_cast<uint8_t *>(output_ptr), static_cast<uint8_t *>(ptr), trans_size,
                                  magic_ptr);
                allgather_demo<T>(block_num, stream, ffts_addr, static_cast<uint8_t *>(input_b_ptr),
                                  static_cast<uint8_t *>(output_b_ptr), static_cast<uint8_t *>(ptr), trans_size,
                                  magic_ptr2);
                run_vector_add<T>(total_elems, output_ptr, output_b_ptr, output_c_ptr, stream);
                allgather_demo_track(aclmdlRICaptureEnd(stream, &model), "aclmdlRICaptureEnd", case_result);
            }
            // Both the first and the replay iterations execute the captured model.
            allgather_demo_track(aclmdlRIExecuteAsync(model, stream), "aclmdlRIExecuteAsync", case_result);

            const int sync_status = aclrtSynchronizeStream(stream);
            if (sync_status != 0) {
                std::cout << "sync stream failed\n";
                if (case_result == 0) {
                    case_result = sync_status;
                }
            }

            allgather_demo_track(
                aclrtMemcpy(output_host, output_size, output_c_ptr, output_size, ACL_MEMCPY_DEVICE_TO_HOST),
                "aclrtMemcpy(output)", case_result);

            // Expected value grows by one golden multiple per iteration, starting at 3x.
            const int scale = zz + 3;
            for (size_t idx = 0; idx < total_elems; ++idx) {
                const int actual = static_cast<int>(output_host[idx]);
                const int expected = static_cast<int>(golden_host[idx]) * scale;
                if (actual != expected) {
                    std::cout << "ERROR each loop: " << actual << " != " << expected
                              << ", trans_size: " << trans_size << ", idx is: " << idx
                              << ", pe_id is: " << pe_id << std::endl;
                    if (case_result == 0) {
                        case_result = -1;
                    }
                    break;  // report only the first mismatch of this iteration
                }
            }
        }

        // ---- Cleanup (best effort; return values intentionally not checked). ----
        aclshmemx_show_prof(nullptr, true);
        aclrtFreeHost(input_host);
        aclrtFreeHost(magic_host_ptr);
        aclrtFreeHost(output_host);
        aclrtFreeHost(golden_host);
        aclshmem_free(ptr);
        aclrtFree(input_ptr);
        aclrtFree(input_b_ptr);
        aclrtFree(input_c_ptr);
        aclrtFree(output_ptr);
        aclrtFree(output_b_ptr);
        aclrtFree(output_c_ptr);
        aclrtFree(magic_ptr);
        aclrtFree(magic_ptr2);
        if (model != nullptr) {
            // No model exists when loop_times is 0 or capture never completed.
            allgather_demo_track(aclmdlRIDestroy(model), "aclmdlRIDestroy", case_result);
        }

        if (pe_id == 0) {
            if (case_result == 0) {
                std::cout << "Case: " << trans_size << " Finished !! Result Correct !!" << std::endl;
            } else {
                std::cout << "Case: " << trans_size << " Failed !!" << std::endl;
            }
        }
        if (result == 0) {
            result = case_result;
        }
    }

    allgather_demo_track(aclrtDestroyStream(stream), "aclrtDestroyStream", result);
    return result;
}
/**
 * @brief Entry point: parses the launch arguments, initialises ACL and shmem,
 *        runs the all-gather graph test and tears everything down.
 *
 * Expected arguments:
 *   <n_pes> <pe_id> <ipport> <g_npus> <f_pe> <f_npu> <data_type> <loop_times>
 *
 * @return 0 when initialisation, the test and teardown all succeeded;
 *         EXIT_FAILURE otherwise.
 */
int main(int argc, char *argv[])
{
    if (argc != INDEX9) {
        std::cerr << "Usage: " << argv[0]
                  << " <n_pes> <pe_id> <ipport> <g_npus> <f_pe> <f_npu> <data_type> <loop_times>" << std::endl;
        return EXIT_FAILURE;
    }

    int n_pes = atoi(argv[INDEX1]);
    int pe_id = atoi(argv[INDEX2]);
    ipport = argv[INDEX3];
    g_npus = atoi(argv[INDEX4]);
    f_pe = atoi(argv[INDEX5]);
    f_npu = atoi(argv[INDEX6]);
    data_type = argv[INDEX7];
    loop_times = atoi(argv[INDEX8]);

    // Reject values that cannot work before touching the runtime. g_npus is
    // used as a divisor below, so 0 (also what atoi returns for junk) must not pass.
    if (n_pes <= 0 || pe_id < 0 || pe_id >= n_pes || g_npus <= 0) {
        std::cerr << "invalid arguments: n_pes = " << n_pes << ", pe_id = " << pe_id
                  << ", g_npus = " << g_npus << std::endl;
        return EXIT_FAILURE;
    }
    // Checked before initialisation so that this exit path has nothing to tear down.
    if (std::string(data_type) != "int") {
        std::cerr << "only support data_type == int" << std::endl;
        return EXIT_FAILURE;
    }

    const int32_t device_id = pe_id % g_npus + f_npu;

    int status = aclInit(nullptr);
    if (status != 0) {
        std::cerr << "aclInit failed, ret = " << status << std::endl;
        return EXIT_FAILURE;
    }
    status = aclrtSetDevice(device_id);
    if (status != 0) {
        std::cerr << "aclrtSetDevice(" << device_id << ") failed, ret = " << status << std::endl;
        aclFinalize();
        return EXIT_FAILURE;
    }

    uint64_t local_mem_size = 1024UL * 1024UL * 1024;
    aclshmemx_init_attr_t attributes;
    test_set_attr(pe_id, n_pes, local_mem_size, ipport, default_flag_uid, &attributes);
    status = aclshmemx_init_attr(ACLSHMEMX_INIT_WITH_DEFAULT, &attributes);
    if (status != 0) {
        std::cerr << "aclshmemx_init_attr failed, ret = " << status << std::endl;
        aclrtResetDevice(device_id);
        aclFinalize();
        return EXIT_FAILURE;
    }

    // Keep the test result in its own variable so teardown cannot overwrite it.
    const int test_status = test_aclshmem_all_gather<int>(pe_id, n_pes);
    if (test_status != 0) {
        std::cerr << "all_gather test failed in pe " << pe_id << ", ret = " << test_status << std::endl;
    }

    // Always run the full teardown; remember whether any step failed.
    int teardown_status = 0;
    auto note_teardown = [&teardown_status](int ret, const char *what) {
        if (ret != 0) {
            std::cerr << what << " failed, ret = " << ret << std::endl;
            if (teardown_status == 0) {
                teardown_status = ret;
            }
        }
    };
    note_teardown(aclshmem_finalize(), "aclshmem_finalize");
    note_teardown(aclrtResetDevice(device_id), "aclrtResetDevice");
    note_teardown(aclFinalize(), "aclFinalize");

    if (test_status != 0 || teardown_status != 0) {
        return EXIT_FAILURE;
    }
    std::cout << "[SUCCESS] demo run success in pe " << pe_id << std::endl;
    return 0;
}