* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "aclgraph_demo_kernel.h"
#include "kernel_operator.h"
#include "acl/acl.h"
#include "shmem.h"
#include "utils/prof/shmemi_prof.h"
#include "add_custom.h"
using namespace AscendC;
// Distance, counted in int32 elements, between the flag slots of two consecutive cores at the head of the
// GVA buffer (a core's slot lives at block index * SYNC_FLAG_INTERVAL).
constexpr int64_t SYNC_FLAG_INTERVAL{16};
// Byte size of one unified-buffer staging chunk handed to the MTE put/get helpers (190 KiB).
constexpr int64_t UB_DMA_MAX_SIZE{int64_t{190} * 1024};
// Byte budget used to derive how many elements are exchanged per round for large messages (100 MiB).
constexpr int64_t GVA_BUFF_MAX_SIZE{int64_t{100} * 1024 * 1024};
/**
 * @brief Runs one round of the large-message all-gather through the GVA staging buffer.
 *
 * The launched cores are split into two halves of `core_group_num = GetBlockNum() / 2` cores each:
 *  - Staging cores (block index < core_group_num) copy their slice of `input` into the staging area behind
 *    the flag slots of `gva`, passing `my_rank` as the target PE, in chunks of UB_DMA_MAX_SIZE bytes.
 *    After every chunk they set their flag slot to `magic + <number of chunks staged so far>`.
 *  - Gathering cores (the remaining blocks) each serve PE `x = (block index - core_group_num) / core_per_rank`.
 *    They repeatedly fetch the flag slots of PE `x` and copy every newly announced chunk from that PE's
 *    staging area into `output`, alternating between two UB buffers (ping/pong).
 *
 * @tparam T          Element type of the gathered data.
 * @param input       Source of the `len` elements contributed in this round.
 * @param output      Destination; data read from PE `x` is written starting at `output + x * elements`.
 * @param gva         Staging buffer. The first `GetBlockNum() * SYNC_FLAG_INTERVAL` int32 values are flag
 *                    slots, the staged data follows them.
 * @param max_gva_num Not used by this function.
 * @param elements    Element stride between the regions of two PEs inside `output`.
 * @param len         Number of elements handled in this round.
 * @param magic       Base value of the flags for this round; a fetched flag is only accepted when it equals
 *                    `magic` after both are shifted right by 10 bits.
 */
template <typename T>
ACLSHMEM_DEVICE void all_gather_origin(__gm__ T *input, __gm__ T *output, __gm__ T *gva, int64_t max_gva_num,
                                       int elements, int len, int64_t magic)
{
    const int64_t aiv_num = GetBlockNum();
    const int64_t aiv_index = GetBlockIdx();
    // One flag slot of SYNC_FLAG_INTERVAL int32 per core sits at the head of gva; data starts right after.
    const int64_t data_offset = aiv_num * SYNC_FLAG_INTERVAL;
    const int64_t flag_offset = aiv_index * SYNC_FLAG_INTERVAL;
    int32_t my_rank = aclshmem_my_pe();
    int32_t pe_size = aclshmem_n_pes();

    __gm__ T *input_gm = (__gm__ T *)input;
    __gm__ T *output_gm = (__gm__ T *)output;
    __gm__ T *gva_data_gm = (__gm__ T *)((__gm__ int32_t *)gva + data_offset);
    __gm__ int32_t *gva_sync_gm = (__gm__ int32_t *)gva;

    // Per-group scalar bookkeeping at fixed UB addresses:
    //   flags_ub1[g] - chunks of group g already gathered (INT32_MAX once the group is complete)
    //   flags_ub2[g] - landing slot for the flag value fetched from the remote PE
    // NOTE(review): with a stride of 16 int32 (64 bytes) flags_ub1[8] and flags_ub2[0] share byte address 544,
    // and flags_ub2[8] lands on byte 1056 where the data buffers below start. The layout is therefore only
    // alias-free for core_group_num <= 8, and the arrays hold at most 16 groups - confirm the supported
    // block_dim range with the caller.
    __ubuf__ int32_t *flags_ub1[16];
    __ubuf__ int32_t *flags_ub2[16];
    for (int i = 0; i * 8 < 128; i++) {
        flags_ub1[i] = (__ubuf__ int32_t *)(32) + i * 16;
        flags_ub2[i] = (__ubuf__ int32_t *)(544) + i * 16;
    }

    // NOTE(review): core_per_rank is 0 when core_group_num < pe_size (division by zero below), and when
    // core_group_num is not a multiple of pe_size the last gathering cores compute x >= pe_size - confirm
    // that callers only launch with compatible block_dim values.
    int core_group_num = aiv_num / 2;
    int core_per_rank = core_group_num / pe_size;
    int len_per_core = len / core_group_num;
    // The last staging core additionally takes the remainder of the division.
    int group_per_num = len_per_core;
    if (aiv_index == core_group_num - 1) {
        group_per_num = len - group_per_num * aiv_index;
    }

    // ---- Staging half: input -> staging area, one flag update per chunk ----
    if (aiv_index < core_group_num) {
        __ubuf__ T *tmp_buff = reinterpret_cast<__ubuf__ T *>(uint64_t(1024 + 32));
        uint32_t copy_ub_size = UB_DMA_MAX_SIZE;
        uint32_t copy_ub_num = copy_ub_size / sizeof(T);
        uint32_t copy_total_size = group_per_num * sizeof(T);
        int64_t times = 0;
        int64_t flag = 0;
        // Full chunks.
        while (copy_total_size >= copy_ub_size) {
            SHMEMI_PROF_START(0);
            aclshmemx_mte_put_nbi(gva_data_gm + aiv_index * len_per_core + times * copy_ub_num,
                                  input_gm + aiv_index * len_per_core + times * copy_ub_num, tmp_buff, copy_ub_size,
                                  copy_ub_num, my_rank, EVENT_ID0);
            // The flag is written by the scalar unit, so wait for the MTE3 write-out first.
            SetFlag<HardEvent::MTE3_S>(EVENT_ID0);
            WaitFlag<HardEvent::MTE3_S>(EVENT_ID0);
            times += 1;
            flag = times + magic;
            aclshmemx_signal_op(gva_sync_gm + flag_offset, flag, ACLSHMEM_SIGNAL_SET, my_rank);
            SHMEMI_PROF_END(0);
            // Order the next chunk's load after the flag write and after the previous write-out of tmp_buff.
            SetFlag<HardEvent::S_MTE2>(EVENT_ID0);
            WaitFlag<HardEvent::S_MTE2>(EVENT_ID0);
            SetFlag<HardEvent::MTE3_MTE2>(EVENT_ID0);
            WaitFlag<HardEvent::MTE3_MTE2>(EVENT_ID0);
            copy_total_size -= copy_ub_size;
        }
        // copy_total_size is unsigned, so "nothing left" can only mean exactly zero.
        if (copy_total_size == 0) {
            return;
        }
        // Trailing partial chunk; it is announced as one more chunk.
        aclshmemx_mte_put_nbi(gva_data_gm + aiv_index * len_per_core + times * copy_ub_num,
                              input_gm + aiv_index * len_per_core + times * copy_ub_num, tmp_buff, copy_ub_size,
                              copy_total_size / sizeof(T), my_rank, EVENT_ID0);
        SetFlag<HardEvent::MTE3_S>(EVENT_ID0);
        WaitFlag<HardEvent::MTE3_S>(EVENT_ID0);
        times += 1;
        flag = times + magic;
        aclshmemx_signal_op(gva_sync_gm + flag_offset, flag, ACLSHMEM_SIGNAL_SET, my_rank);
        return;
    }

    // ---- Gathering half: staging area of PE x -> output ----
    for (int64_t i = 0; i < core_group_num; i++) {
        *flags_ub1[i] = 0;
        *flags_ub2[i] = 0;
    }
    __ubuf__ T *ping_buff = reinterpret_cast<__ubuf__ T *>(uint64_t(1 * 1024 + 32));
    __ubuf__ T *pong_buff = reinterpret_cast<__ubuf__ T *>(uint64_t(96 * 1024 + 32));
    // Each of the two buffers carries half a staging chunk.
    uint32_t copy_ub_size = UB_DMA_MAX_SIZE / 2;
    uint32_t copy_ub_num = copy_ub_size / sizeof(T);
    int x = (aiv_index - core_group_num) / core_per_rank;
    int pingpong_id = 0;
    // Mark both buffers as free so that the first WaitFlag on each of them passes.
    SetFlag<HardEvent::MTE3_MTE2>(EVENT_ID0);
    SetFlag<HardEvent::MTE3_MTE2>(EVENT_ID1);
    while (true) {
        for (int group_idx = 0; group_idx < core_group_num; group_idx++) {
            if (*flags_ub1[group_idx] == INT32_MAX) {
                continue;
            }
            // Bytes staged by this group; the last group also owns the remainder.
            int64_t all_data_size = len_per_core * sizeof(T);
            if (group_idx == core_group_num - 1) {
                all_data_size = (len - group_idx * len_per_core) * sizeof(T);
            }
            if (*flags_ub1[group_idx] * UB_DMA_MAX_SIZE >= all_data_size) {
                *flags_ub1[group_idx] = INT32_MAX;
                continue;
            }
            // Fetch the group's flag from PE x and make it visible to the scalar unit.
            aclshmem_int32_get_nbi(flags_ub2[group_idx], gva_sync_gm + group_idx * SYNC_FLAG_INTERVAL, 1, x);
            PipeBarrier<PIPE_ALL>();
            // Ignore flag values that do not carry this round's magic in their upper bits.
            if ((*flags_ub2[group_idx] >> 10) != (magic >> 10)) {
                continue;
            }
            int64_t ready_num = *flags_ub2[group_idx] - magic;
            if (ready_num <= 0 || *flags_ub1[group_idx] >= ready_num) {
                continue;
            }
            // Offsets are kept in 64 bits: x * elements exceeds the int range once the gathered output holds
            // more than INT_MAX elements.
            const int64_t group_send_offset = static_cast<int64_t>(group_idx) * len_per_core;
            const int64_t group_recv_offset = static_cast<int64_t>(x) * elements + group_send_offset;
            int64_t send_offset = *flags_ub1[group_idx] * UB_DMA_MAX_SIZE / sizeof(T);
            int64_t recv_offset = send_offset;
            // Elements announced but not yet gathered, clipped to the end of the group's data.
            int num_total = (ready_num - *flags_ub1[group_idx]) * UB_DMA_MAX_SIZE / sizeof(T);
            if (ready_num * UB_DMA_MAX_SIZE > all_data_size) {
                num_total = (all_data_size - *flags_ub1[group_idx] * UB_DMA_MAX_SIZE) / sizeof(T);
            }
            PipeBarrier<PIPE_ALL>();
            while (num_total > 0) {
                TEventID event_id = pingpong_id == 0 ? EVENT_ID0 : EVENT_ID1;
                __ubuf__ T *buf = pingpong_id == 0 ? ping_buff : pong_buff;
                uint32_t copy_num = num_total > copy_ub_num ? copy_ub_num : num_total;
                // Wait until the previous transfer through this buffer has been written out.
                WaitFlag<HardEvent::MTE3_MTE2>(event_id);
                aclshmemx_mte_get_nbi(output_gm + group_recv_offset + recv_offset,
                                      gva_data_gm + group_send_offset + send_offset, buf, copy_ub_size, copy_num, x,
                                      event_id);
                SetFlag<HardEvent::MTE3_MTE2>(event_id);
                send_offset += copy_num;
                recv_offset += copy_num;
                num_total -= copy_num;
                pingpong_id = 1 - pingpong_id;
            }
            PipeBarrier<PIPE_ALL>();
            *flags_ub1[group_idx] = ready_num;
            PipeBarrier<PIPE_ALL>();
        }
        // Leave once every group has been marked complete.
        bool finished = true;
        for (int64_t group_idx = 0; group_idx < core_group_num; group_idx++) {
            if (*flags_ub1[group_idx] != INT32_MAX) {
                finished = false;
                break;
            }
        }
        if (finished) {
            break;
        }
    }
    // Consume the two outstanding buffer-free events so that set/wait stay balanced.
    WaitFlag<HardEvent::MTE3_MTE2>(EVENT_ID0);
    WaitFlag<HardEvent::MTE3_MTE2>(EVENT_ID1);
}
/**
 * @brief All-gather path for large messages: splits the data into rounds and runs all_gather_origin per round.
 *
 * A round covers at most `GVA_BUFF_MAX_SIZE / sizeof(T)` elements. Before every round the int64 at UB
 * address 0 is cleared and `aclshmemx_barrier_all_vec()` is called. Round `r` uses the flag base
 * `(*magic + r) * 1024`.
 *
 * @tparam T        Element type of the gathered data.
 * @param ffts_addr Value forwarded to util_set_ffts_config().
 * @param input     Local source buffer.
 * @param output    Destination buffer, forwarded to all_gather_origin together with `elements`.
 * @param gva       Staging buffer forwarded to all_gather_origin.
 * @param elements  Number of elements to gather from `input`.
 * @param magic     Pointer to the base magic value, read once on entry.
 */
template <typename T>
ACLSHMEM_DEVICE void all_gather_big_data(uint64_t ffts_addr, __gm__ T *input, __gm__ T *output, __gm__ T *gva, int elements,
                                         __gm__ int *magic)
{
    util_set_ffts_config(ffts_addr);
    int magic_base = *magic;
    const int64_t round_capacity = GVA_BUFF_MAX_SIZE / sizeof(T);
    int round_count = (elements + round_capacity - 1) / round_capacity;
    int remaining = elements;
    __ubuf__ int64_t *ctrl_ub = (__ubuf__ int64_t *)(0);

    for (int round = 0; round < round_count; round++) {
        *ctrl_ub = 0;
        PipeBarrier<PIPE_ALL>();

        // Every round but possibly the last one is completely filled.
        int32_t len;
        if (remaining > round_capacity) {
            len = round_capacity;
        } else {
            len = remaining;
        }

        aclshmemx_barrier_all_vec();
        all_gather_origin(input + round * round_capacity, output + round * round_capacity, gva, round_capacity,
                          elements, len, (magic_base + round) * 1024);
        remaining -= round_capacity;
        PipeBarrier<PIPE_ALL>();
    }
}
/**
 * @brief All-gather path for small messages: one staging pass followed by one gathering pass on every core.
 *
 * Pass 1: the cores split `elements` evenly (the last core also takes the remainder) and copy their part
 * of `input` into the staging area behind the flag slots of `gva`, passing `my_rank` as the target PE.
 * After aclshmem_quiet() each core sets its own flag slot to `*magic`.
 * Pass 2: core `c` serves PE `x = c / core_per_rank`. It waits until the flag slot with its own index,
 * addressed through `aclshmem_ptr(..., x)`, equals `*magic`, then copies its share of PE x's staged data
 * to `output + x * elements`.
 *
 * @tparam T        Element type of the gathered data.
 * @param ffts_addr Not used by this function.
 * @param input     Local source buffer holding `elements` elements.
 * @param output    Destination buffer; PE x's data starts at element offset `x * elements`.
 * @param gva       Staging buffer (flag slots first, staged data after them).
 * @param elements  Number of elements contributed by each PE.
 * @param magic     Pointer to the flag value for this call, read once on entry.
 */
template <typename T>
ACLSHMEM_DEVICE void all_gather_small_data(uint64_t ffts_addr, __gm__ T *input, __gm__ T *output, __gm__ T *gva,
                                           int elements, __gm__ int *magic)
{
    const int64_t block_count = GetBlockNum();
    const int64_t block_id = GetBlockIdx();
    const int64_t flag_offset = block_id * SYNC_FLAG_INTERVAL;
    int64_t my_rank = aclshmem_my_pe();
    int64_t pe_size = aclshmem_n_pes();
    int magic_data = *magic;

    // Flag slots occupy the first block_count * SYNC_FLAG_INTERVAL int32 of gva; staged data follows.
    __gm__ int32_t *gva_sync_gm = (__gm__ int32_t *)gva;
    __gm__ T *staging_gm = (__gm__ T *)(gva_sync_gm + block_count * SYNC_FLAG_INTERVAL);
    __ubuf__ T *ub_scratch = (__ubuf__ T *)(64);
    const uint32_t ub_size = UB_DMA_MAX_SIZE;

    // ---- Pass 1: stage this core's slice of the local input ----
    uint32_t stage_count = elements / block_count;
    uint32_t stage_offset = block_id * stage_count;
    if (block_id == block_count - 1) {
        stage_count = elements - stage_count * block_id;
    }
    aclshmemx_mte_put_nbi(staging_gm + stage_offset, input + stage_offset, ub_scratch, ub_size, stage_count, my_rank,
                          EVENT_ID0);

    const int64_t core_per_rank = block_count / pe_size;
    const int64_t core_rank_idx = block_id % core_per_rank;
    const int64_t x = block_id / core_per_rank;

    // Publish "my slice is staged", then wait for the matching slot addressed through PE x.
    aclshmem_quiet();
    aclshmemx_signal_op(gva_sync_gm + flag_offset, magic_data, ACLSHMEM_SIGNAL_SET, my_rank);
    aclshmem_signal_wait_until((__gm__ int32_t *)aclshmem_ptr(gva_sync_gm, x) + flag_offset, ACLSHMEM_CMP_EQ, magic_data);

    // ---- Pass 2: pull this core's share of PE x's staged data into the output ----
    uint32_t gather_count = elements / core_per_rank;
    uint32_t output_offset = x * elements + core_rank_idx * gather_count;
    uint32_t gather_offset = core_rank_idx * gather_count;
    if (core_rank_idx == core_per_rank - 1) {
        gather_count = elements - gather_count * core_rank_idx;
    }
    aclshmemx_mte_get_nbi(output + output_offset, staging_gm + gather_offset, ub_scratch, ub_size, gather_count, x,
                          EVENT_ID0);
}
// Generates the kernel entry point ShmemAllGather_<type>.
// Payloads of 2 MiB (2097152 bytes) or more per PE go through all_gather_big_data, smaller ones through
// all_gather_small_data. No comments are placed inside the macro body because of the line continuations.
#define ALLGATHER_FUNC_DEF(type) \
    extern "C" [[bisheng::core_ratio(0,1)]] __global__ __aicore__ void ShmemAllGather_##type( \
        uint64_t ffts_addr, GM_ADDR input, GM_ADDR output, GM_ADDR gva, int elements, __gm__ int *magic) \
    { \
        if (elements * sizeof(type) >= 2 * 1024 * 1024) { \
            all_gather_big_data<type>(ffts_addr, (__gm__ type *)input, (__gm__ type *)output, \
                                      (__gm__ type *)gva, elements, magic); \
        } else { \
            all_gather_small_data<type>(ffts_addr, (__gm__ type *)input, (__gm__ type *)output, \
                                        (__gm__ type *)gva, elements, magic); \
        } \
    }

// Applies a generator macro to every supported element type (currently only int).
#define TYPE_FUNC(fun) fun(int)

TYPE_FUNC(ALLGATHER_FUNC_DEF);
/**
 * @brief Host-side launcher of the all-gather kernel for element type T.
 *
 * Only `int` has a kernel to launch here; for any other T the call returns without launching anything.
 *
 * @tparam T        Element type of the data to gather.
 * @param block_dim Number of blocks to launch the kernel with.
 * @param stream    Stream handle passed to the kernel launch.
 * @param ffts_addr Forwarded to the kernel.
 * @param input     Forwarded to the kernel as its input buffer.
 * @param output    Forwarded to the kernel as its output buffer.
 * @param gva       Forwarded to the kernel as its staging buffer.
 * @param elements  Forwarded to the kernel as the element count.
 * @param magic     Forwarded to the kernel as the pointer to the magic value.
 */
template <class T>
void allgather_demo(uint32_t block_dim, void *stream, uint64_t ffts_addr, uint8_t *input, uint8_t *output, uint8_t *gva,
                    int elements, int *magic)
{
    const bool is_int_payload = std::is_same<T, int>::value;
    if (!is_int_payload) {
        return;
    }
    ShmemAllGather_int<<<block_dim, nullptr, stream>>>(ffts_addr, input, output, gva, elements, magic);
}
// Explicit instantiation of the host-side launcher for int, the only element type this file generates a
// kernel for (see TYPE_FUNC).
template void allgather_demo<int>(uint32_t block_dim, void *stream, uint64_t ffts_addr, uint8_t *input, uint8_t *output,
uint8_t *gva, int elements, int *magic);
// Explicit instantiation of run_vector_add for int. Its template definition is not visible in this file;
// presumably it comes from add_custom.h - confirm.
template void run_vector_add<int>(int64_t numElements, void *a, void *b, void *c, void *stream);