* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <vector>
#include <iostream>
#include <cmath>
#include "acl/acl.h"
#include "shmemi_host_common.h"
#include "gm2gm/shmemi_device_cc_kernel.h"
// Slot-allocation bitmap: bit i is set while team slot i holds a live team.
uint64_t g_team_mask{0};
// Host-side team descriptors; calloc'ed in aclshmemi_team_init and freed in aclshmemi_team_finalize.
aclshmemx_team_t *g_aclshmem_team_pool{nullptr};
// Guard placed at the top of team APIs. When the current instance id is non-zero it logs an
// error naming `func_name` and makes the enclosing function return ACLSHMEM_NOT_SUPPORTED;
// for instance 0 it does nothing.
#define ACLSHMEM_TEAM_CHECK_SINGLE_INSTANCE(func_name)                                                   \
    do {                                                                                                 \
        if (g_instance_ctx->id == 0) {                                                                   \
            break;                                                                                       \
        }                                                                                                \
        SHM_LOG_ERROR((func_name) << " is not supported in multi-instance mode (instance_id="            \
                                  << g_instance_ctx->id << "). "                                         \
                                  << "Only Instance 0 supports team operations. Please Create new instance instead."); \
        return ACLSHMEM_NOT_SUPPORTED;                                                                   \
    } while (0)
/**
 * Render a team descriptor as a one-line "[team:..,npes:..,mype:..,start:..,stride:..]" string
 * for log messages.
 * @param config team descriptor to describe; must be non-null.
 * @return the formatted description.
 */
inline std::string team_config2string(aclshmemx_team_t *config)
{
    std::ostringstream text;
    text << "[team:" << config->team_idx
         << ",npes:" << config->size
         << ",mype:" << config->mype
         << ",start:" << config->start
         << ",stride:" << config->stride
         << "]";
    return text.str();
}
/**
 * Check whether `team` names a live team slot on this PE.
 * A team is valid only after the runtime reports itself initialised, the host pool exists,
 * the index lies in [0, ACLSHMEM_MAX_TEAMS), and its bit is set in g_team_mask.
 */
inline bool is_valid_team(aclshmem_team_t &team)
{
    if (!g_state.is_aclshmem_initialized || g_aclshmem_team_pool == nullptr) {
        return false;
    }
    // Range-check before shifting so the mask lookup never uses an out-of-range shift count.
    if (team < 0 || team >= ACLSHMEM_MAX_TEAMS) {
        return false;
    }
    return ((g_team_mask >> team) & 1) != 0;
}
/**
 * Release the device-side copy of team `team_idx` (if any) and clear its slot in
 * g_state.team_pools. A failed aclrtFree is logged; the slot is cleared either way.
 */
inline void device_team_destroy(int32_t team_idx)
{
    auto &slot = g_state.team_pools[team_idx];
    if (slot == nullptr) {
        return;
    }
    if (aclrtFree((void *)slot) != ACL_SUCCESS) {
        SHM_LOG_ERROR("aclrtFree for device_team_ptr failed for team " << team_idx);
    }
    slot = nullptr;
}
inline int32_t device_team_update(int team_idx, aclshmemx_team_t *host_team_ptr)
{
void *team_ptr = nullptr;
ACLSHMEM_CHECK_RET(aclrtMalloc(&team_ptr, sizeof(aclshmemx_team_t), ACL_MEM_MALLOC_NORMAL_ONLY));
auto ret = aclrtMemcpy((aclshmemx_team_t *)team_ptr, sizeof(aclshmemx_team_t), host_team_ptr, sizeof(aclshmemx_team_t),
ACL_MEMCPY_HOST_TO_DEVICE);
if (ret != 0) {
SHM_LOG_ERROR("memcpy device team info failed, ret: " << ret);
ACLSHMEM_CHECK_RET(aclrtFree(team_ptr));
return ACLSHMEM_INNER_ERROR;
}
g_state.team_pools[team_idx] = (aclshmemx_team_t *)team_ptr;
return ACLSHMEM_SUCCESS;
}
/**
 * Allocate g_state.sync_pool (SYNC_POOL_SIZE bytes via aclshmem_malloc) and zero it.
 * On any failure the team state is torn down with aclshmemi_team_finalize().
 * @return ACLSHMEM_SUCCESS, or ACLSHMEM_INNER_ERROR if the allocation or memset fails.
 */
int32_t aclshmemi_team_init_sync_pool()
{
    g_state.sync_pool = (uint64_t)aclshmem_malloc(SYNC_POOL_SIZE);
    if (g_state.sync_pool != 0) {
        const auto status = aclrtMemset((void *)g_state.sync_pool, SYNC_POOL_SIZE, 0, SYNC_POOL_SIZE);
        if (status == 0) {
            return ACLSHMEM_SUCCESS;
        }
        aclshmemi_team_finalize();
        SHM_LOG_ERROR("memset sync pool failed. ret=" << status);
        return ACLSHMEM_INNER_ERROR;
    }
    aclshmemi_team_finalize();
    SHM_LOG_ERROR("malloc sync pool failed.");
    return ACLSHMEM_INNER_ERROR;
}
/**
 * Allocate g_state.sync_counter (SYNC_COUNTERS_SIZE bytes via aclshmem_malloc) and zero it.
 * On any failure the team state is torn down with aclshmemi_team_finalize().
 * @return ACLSHMEM_SUCCESS, or ACLSHMEM_INNER_ERROR if the allocation or memset fails.
 */
int32_t aclshmemi_team_init_sync_counter()
{
    g_state.sync_counter = (uint64_t)aclshmem_malloc(SYNC_COUNTERS_SIZE);
    if (g_state.sync_counter != 0) {
        const auto status =
            aclrtMemset((void *)g_state.sync_counter, SYNC_COUNTERS_SIZE, 0, SYNC_COUNTERS_SIZE);
        if (status == 0) {
            return ACLSHMEM_SUCCESS;
        }
        aclshmemi_team_finalize();
        SHM_LOG_ERROR("memset sync counter failed.");
        return ACLSHMEM_INNER_ERROR;
    }
    aclshmemi_team_finalize();
    SHM_LOG_ERROR("malloc sync counter failed.");
    return ACLSHMEM_INNER_ERROR;
}
/**
 * Allocate g_state.core_sync_pool (ACLSHMEM_CORE_SYNC_POOL_SIZE bytes via aclrtMalloc with
 * ACL_MEM_MALLOC_HUGE_FIRST) and zero it. On any failure the team state is torn down with
 * aclshmemi_team_finalize().
 * @return ACLSHMEM_SUCCESS, or ACLSHMEM_INNER_ERROR if the allocation or memset fails.
 */
int32_t aclshmemi_team_init_core_sync_pool()
{
    auto status = aclrtMalloc((void **)&(g_state.core_sync_pool), ACLSHMEM_CORE_SYNC_POOL_SIZE,
                              ACL_MEM_MALLOC_HUGE_FIRST);
    const bool got_buffer = (status == 0) && (g_state.core_sync_pool != 0);
    if (got_buffer) {
        status = aclrtMemset((void *)g_state.core_sync_pool, ACLSHMEM_CORE_SYNC_POOL_SIZE, 0,
                             ACLSHMEM_CORE_SYNC_POOL_SIZE);
        if (status == 0) {
            return ACLSHMEM_SUCCESS;
        }
        aclshmemi_team_finalize();
        SHM_LOG_ERROR("memset core sync pool failed.");
        return ACLSHMEM_INNER_ERROR;
    }
    aclshmemi_team_finalize();
    SHM_LOG_ERROR("malloc core sync pool failed.");
    return ACLSHMEM_INNER_ERROR;
}
/**
 * Allocate g_state.core_sync_counter (ACLSHMEM_CORE_SYNC_COUNTER_SIZE bytes via aclrtMalloc with
 * ACL_MEM_MALLOC_HUGE_FIRST) and zero it. On any failure the team state is torn down with
 * aclshmemi_team_finalize().
 * @return ACLSHMEM_SUCCESS, or ACLSHMEM_INNER_ERROR if the allocation or memset fails.
 */
int32_t aclshmemi_team_init_core_sync_counter()
{
    auto status = aclrtMalloc((void **)&(g_state.core_sync_counter), ACLSHMEM_CORE_SYNC_COUNTER_SIZE,
                              ACL_MEM_MALLOC_HUGE_FIRST);
    const bool got_buffer = (status == 0) && (g_state.core_sync_counter != 0);
    if (got_buffer) {
        status = aclrtMemset((void *)g_state.core_sync_counter, ACLSHMEM_CORE_SYNC_COUNTER_SIZE, 0,
                             ACLSHMEM_CORE_SYNC_COUNTER_SIZE);
        if (status == 0) {
            return ACLSHMEM_SUCCESS;
        }
        aclshmemi_team_finalize();
        SHM_LOG_ERROR("memset core sync counter failed.");
        return ACLSHMEM_INNER_ERROR;
    }
    aclshmemi_team_finalize();
    SHM_LOG_ERROR("malloc core sync counter failed.");
    return ACLSHMEM_INNER_ERROR;
}
int32_t aclshmemi_team_init(int32_t rank, int32_t size)
{
int team_size = g_instance_ctx->id == 0 ? ACLSHMEM_MAX_TEAMS : 1;
g_aclshmem_team_pool = (aclshmemx_team_t *)calloc(team_size, sizeof(aclshmemx_team_t));
if (g_aclshmem_team_pool == nullptr) {
SHM_LOG_ERROR("malloc host shmem team pool failed.");
return ACLSHMEM_INNER_ERROR;
}
for (int i = 0; i < team_size; i++) {
g_aclshmem_team_pool[i] = aclshmemx_team_t{-1, -1, -1, -1, -1, {0, 0, 0,{'0'}}, {-1}};
}
aclshmemx_team_t &aclshmem_team_world = g_aclshmem_team_pool[ACLSHMEM_TEAM_WORLD];
aclshmem_team_world.team_idx = ACLSHMEM_TEAM_WORLD;
aclshmem_team_world.start = 0;
aclshmem_team_world.stride = 1;
aclshmem_team_world.size = size;
aclshmem_team_world.mype = rank;
g_team_mask |= 1ULL << ACLSHMEM_TEAM_WORLD;
aclshmemi_team_populate_pe_mappings_from_constant_stride(&aclshmem_team_world);
ACLSHMEM_CHECK_RET(device_team_update(ACLSHMEM_TEAM_WORLD, &aclshmem_team_world));
auto ret = aclshmemi_team_init_sync_pool();
if (ret != 0) {
return ret;
}
ret = aclshmemi_team_init_sync_counter();
if (ret != 0) {
return ret;
}
ret = aclshmemi_team_init_core_sync_pool();
if (ret != 0) {
return ret;
}
return aclshmemi_team_init_core_sync_counter();
}
/**
 * Claim the lowest unused team slot.
 * Finds the first clear bit in g_team_mask below ACLSHMEM_MAX_TEAMS, sets it, and returns its
 * index; returns -1 when every slot is taken.
 */
int32_t first_free_idx_fetch()
{
    const int32_t slot_limit = ACLSHMEM_MAX_TEAMS;
    for (int32_t slot = 0; slot < slot_limit; ++slot) {
        const uint64_t bit = 1ULL << slot;
        if ((g_team_mask & bit) == 0) {
            g_team_mask |= bit;
            return slot;
        }
    }
    return -1;
}
int32_t aclshmemi_team_finalize()
{
int team_size = g_instance_ctx->id == 0 ? ACLSHMEM_MAX_TEAMS : 1;
for (int32_t i = 0; i < team_size; i++) {
aclshmem_team_t team = i;
if (is_valid_team(team)) {
aclshmem_team_destroy(team);
}
}
if (g_state.sync_counter != 0) {
aclshmem_free(reinterpret_cast<void *>(g_state.sync_counter));
g_state.sync_counter = 0;
}
if (g_state.sync_pool != 0) {
aclshmem_free(reinterpret_cast<void *>(g_state.sync_pool));
g_state.sync_pool = 0;
}
if (g_state.core_sync_counter != 0) {
ACLSHMEM_CHECK_RET(aclrtFree(reinterpret_cast<void *>(g_state.core_sync_counter)));
g_state.core_sync_counter = 0;
}
if (g_state.core_sync_pool != 0) {
ACLSHMEM_CHECK_RET(aclrtFree(reinterpret_cast<void *>(g_state.core_sync_pool)));
g_state.core_sync_pool = 0;
}
if (g_aclshmem_team_pool != nullptr) {
free(g_aclshmem_team_pool);
g_aclshmem_team_pool = nullptr;
}
return 0;
}
/**
 * Fill the reverse half of team->pe_mapping from the forward half.
 * For each team-relative index r in [0, team->size), with g = pe_mapping[r] (a global PE),
 * this stores pe_mapping[ACLSHMEM_MAX_PES + g] = r.
 */
void aclshmemi_team_populate_from_world_pe_mapping(aclshmemx_team_t *team)
{
    int rank_in_team = 0;
    while (rank_in_team < team->size) {
        const int global_pe = team->pe_mapping[rank_in_team];
        team->pe_mapping[ACLSHMEM_MAX_PES + global_pe] = rank_in_team;
        ++rank_in_team;
    }
}
/**
 * Build both halves of team->pe_mapping for a team described by start/stride/size.
 * Forward half: pe_mapping[r] = start + r * stride for r in [0, size). The reverse half is
 * then derived by aclshmemi_team_populate_from_world_pe_mapping().
 */
void aclshmemi_team_populate_pe_mappings_from_constant_stride(aclshmemx_team_t *team)
{
    int global_pe = team->start;
    for (int r = 0; r < team->size; ++r, global_pe += team->stride) {
        team->pe_mapping[r] = global_pe;
    }
    aclshmemi_team_populate_from_world_pe_mapping(team);
}
/**
 * Return entry `pe` of the team's pe_mapping table.
 *
 * @param team team handle; must be valid.
 * @param pe   index into pe_mapping; must lie inside the array.
 * @return the stored mapping value, or ACLSHMEM_INVALID_PARAM if the team is invalid or `pe`
 *         is outside the pe_mapping array. The original performed no bounds check, so an
 *         out-of-range `pe` read past the descriptor.
 */
int32_t aclshmemi_team_pe_mapping(aclshmem_team_t team, int pe)
{
    if (!is_valid_team(team)) {
        SHM_LOG_ERROR("input team is invalid!, team: " << team);
        return ACLSHMEM_INVALID_PARAM;
    }
    const aclshmemx_team_t &entry = g_aclshmem_team_pool[team];
    const int mapping_len = static_cast<int>(sizeof(entry.pe_mapping) / sizeof(entry.pe_mapping[0]));
    if (pe < 0 || pe >= mapping_len) {
        SHM_LOG_ERROR("input pe is out of range!, pe: " << pe << " team: " << team);
        return ACLSHMEM_INVALID_PARAM;
    }
    return entry.pe_mapping[pe];
}
/**
 * First-stage argument validation for aclshmem_team_split_strided.
 *
 * Rejects a null output pointer, otherwise sets *new_team to ACLSHMEM_TEAM_INVALID and then
 * checks that the parent team is valid and that pe_start / pe_stride are below ACLSHMEM_MAX_PES
 * and pe_size does not exceed it.
 * @return ACLSHMEM_SUCCESS or ACLSHMEM_INVALID_PARAM.
 */
int32_t aclshmem_team_split_strided_precheck(aclshmem_team_t parent_team, int32_t pe_start, int32_t pe_stride,
                                             int32_t pe_size, aclshmem_team_t *new_team)
{
    if (new_team == nullptr) {
        SHM_LOG_ERROR("output team is null.");
        return ACLSHMEM_INVALID_PARAM;
    }
    *new_team = ACLSHMEM_TEAM_INVALID;
    if (!is_valid_team(parent_team)) {
        SHM_LOG_ERROR("input parent team is invalid!, team: " << parent_team);
        return ACLSHMEM_INVALID_PARAM;
    }
    const bool exceeds_limits =
        pe_start >= ACLSHMEM_MAX_PES || pe_stride >= ACLSHMEM_MAX_PES || pe_size > ACLSHMEM_MAX_PES;
    if (!exceeds_limits) {
        return ACLSHMEM_SUCCESS;
    }
    aclshmemx_team_t *src_team = &g_aclshmem_team_pool[parent_team];
    SHM_LOG_ERROR("create team failed, input invalid, pe_start:" << pe_start << " pe_size:" << pe_size
                  << " pe_stride:" << pe_stride << " parent:"
                  << team_config2string(src_team));
    return ACLSHMEM_INVALID_PARAM;
}
/**
 * Create a team from `pe_size` members of `parent_team`, starting at parent-relative index
 * `pe_start` and taking every `pe_stride`-th parent member.
 *
 * A PE that is not selected gets *new_team == ACLSHMEM_TEAM_INVALID and ACLSHMEM_SUCCESS.
 * Otherwise a free team slot is claimed, the descriptor is copied to the device, and the
 * device state is refreshed.
 *
 * @return ACLSHMEM_SUCCESS (0); ACLSHMEM_INVALID_PARAM for bad arguments, including a range
 *         that does not fit inside the parent team; ACLSHMEM_INNER_ERROR if no slot is free or
 *         a device update fails; ACLSHMEM_NOT_SUPPORTED outside instance 0.
 */
int32_t aclshmem_team_split_strided(aclshmem_team_t parent_team, int32_t pe_start, int32_t pe_stride, int32_t pe_size,
                                    aclshmem_team_t *new_team)
{
    ACLSHMEM_TEAM_CHECK_SINGLE_INSTANCE(__func__);
    auto ret = aclshmem_team_split_strided_precheck(parent_team, pe_start, pe_stride, pe_size, new_team);
    if (ret != ACLSHMEM_SUCCESS) {
        return ret;
    }
    aclshmemx_team_t *src_team = &g_aclshmem_team_pool[parent_team];
    // Last parent-relative index the new team would cover; 64-bit so stride * size cannot overflow.
    const int64_t parent_pe_end =
        static_cast<int64_t>(pe_start) + static_cast<int64_t>(pe_stride) * (static_cast<int64_t>(pe_size) - 1);
    if (pe_start < 0 || pe_start >= src_team->size || pe_size <= 0 || pe_size > src_team->size || pe_stride < 1 ||
        parent_pe_end >= src_team->size) {
        SHM_LOG_ERROR("create team failed, input invalid, pe_start:" << pe_start << " pe_size:" << pe_size
                      << " pe_stride:" << pe_stride << " parent:"
                      << team_config2string(src_team));
        return ACLSHMEM_INVALID_PARAM;
    }
    // src_team->mype is this PE's index *within the parent*. Convert it to a global PE so it is
    // comparable with the global start/stride below; the two differ for any non-world parent.
    const int64_t global_pe = src_team->pe_mapping[src_team->mype];
    const int64_t global_pe_start = src_team->pe_mapping[pe_start];
    const int64_t global_pe_stride = static_cast<int64_t>(src_team->stride) * pe_stride;
    const int64_t global_pe_end = global_pe_start + global_pe_stride * (pe_size - 1);
    if (global_pe_start >= aclshmem_n_pes() || global_pe_end >= aclshmem_n_pes()) {
        SHM_LOG_ERROR("create team failed, large than world size, pe_start:"
                      << pe_start << " pe_size:" << pe_size << " pe_stride:" << pe_stride
                      << " world_size:" << aclshmem_n_pes() << " parent:" << team_config2string(src_team));
        return ACLSHMEM_INVALID_PARAM;
    }
    const int64_t offset = global_pe - global_pe_start;
    if (offset < 0 || offset % global_pe_stride != 0 || offset / global_pe_stride >= pe_size) {
        SHM_LOG_INFO("This PE is not a member of the new team.");
        return ACLSHMEM_SUCCESS;
    }
    // Fully initialise the descriptor, including config, which is copied to the device and
    // returned by the config getter. Then mark every pe_mapping entry -1 so reverse lookups for
    // PEs outside this team never read indeterminate stack memory.
    aclshmemx_team_t my_team{-1, -1, -1, -1, -1, {0, 0, 0, {'0'}}, {-1}};
    for (auto &mapped_pe : my_team.pe_mapping) {
        mapped_pe = -1;
    }
    my_team.mype = static_cast<int32_t>(offset / global_pe_stride);
    my_team.start = static_cast<int32_t>(global_pe_start);
    my_team.stride = static_cast<int32_t>(global_pe_stride);
    my_team.size = pe_size;
    for (int32_t i = 0; i < pe_size; i++) {
        const int32_t temp_global_pe = static_cast<int32_t>(global_pe_start + i * global_pe_stride);
        my_team.pe_mapping[i] = temp_global_pe;
        my_team.pe_mapping[temp_global_pe + ACLSHMEM_MAX_PES] = i;
    }
    my_team.team_idx = first_free_idx_fetch();
    if (my_team.team_idx == -1) {
        SHM_LOG_ERROR("create team failed, team num is full!");
        return ACLSHMEM_INNER_ERROR;
    }
    g_aclshmem_team_pool[my_team.team_idx] = my_team;
    if (device_team_update(my_team.team_idx, &g_aclshmem_team_pool[my_team.team_idx]) != 0) {
        aclshmem_team_destroy(my_team.team_idx);
        SHM_LOG_ERROR("create team failed, malloc device state failed!");
        return ACLSHMEM_INNER_ERROR;
    }
    if (update_device_state() != 0) {
        aclshmem_team_destroy(my_team.team_idx);
        SHM_LOG_ERROR("create team failed, update state failed!");
        return ACLSHMEM_INNER_ERROR;
    }
    SHM_LOG_INFO("create team success:" << team_config2string(&my_team));
    *new_team = my_team.team_idx;
    return 0;
}
/**
 * Validate aclshmem_team_split_2d arguments.
 * Both outputs must be non-null and x_range must be positive; only then are both outputs set to
 * ACLSHMEM_TEAM_INVALID and the parent team checked for validity.
 * @return ACLSHMEM_SUCCESS or ACLSHMEM_INVALID_PARAM.
 */
int aclshmemi_team_split_2d_precheck(aclshmem_team_t p_team, int x_range, aclshmem_team_t *x_team, aclshmem_team_t *y_team)
{
    const bool outputs_present = (x_team != nullptr) && (y_team != nullptr);
    if (!outputs_present) {
        SHM_LOG_ERROR("output team is null.");
        return ACLSHMEM_INVALID_PARAM;
    }
    if (x_range < 1) {
        SHM_LOG_ERROR("input x range must be larger than 0.");
        return ACLSHMEM_INVALID_PARAM;
    }
    *x_team = ACLSHMEM_TEAM_INVALID;
    *y_team = ACLSHMEM_TEAM_INVALID;
    if (is_valid_team(p_team)) {
        return ACLSHMEM_SUCCESS;
    }
    SHM_LOG_ERROR("input parent team is invalid!, team: " << p_team);
    return ACLSHMEM_INVALID_PARAM;
}
/**
 * Create the x-axis teams of a 2-D split. Team i covers parent indices
 * [i * x_range, i * x_range + x_size) with stride 1; the last team takes the remainder
 * (src_size % x_range) when it is non-zero.
 *
 * On return *x_team holds the single x-axis team this PE belongs to, or stays
 * ACLSHMEM_TEAM_INVALID.
 * @return ACLSHMEM_SUCCESS; the error from aclshmem_team_split_strided if a split fails (the
 *         original only logged at INFO and carried on); ACLSHMEM_INNER_ERROR if this PE lands in
 *         more than one x-axis team.
 */
int aclshmemi_team_split_2d_x(aclshmem_team_t &parent_team, int32_t &x_team_counts, int32_t &src_size,
                              int &x_range, aclshmem_team_t *&x_team)
{
    aclshmemx_team_t *src_team = &g_aclshmem_team_pool[parent_team];
    const int32_t tail = src_size % x_range;
    int start = 0;
    for (int i = 0; i < x_team_counts; ++i, start += x_range) {
        // Pre-set so the check below never reads an indeterminate handle if the split returns
        // before writing its output.
        aclshmem_team_t my_xteam = ACLSHMEM_TEAM_INVALID;
        const int x_stride = 1;
        const int x_size = ((i == x_team_counts - 1) && (tail != 0)) ? tail : x_range;
        auto ret = aclshmem_team_split_strided(parent_team, start, x_stride, x_size, &my_xteam);
        if (ret != ACLSHMEM_SUCCESS) {
            SHM_LOG_ERROR("create x-axis team " << (i + 1) << " of " << x_team_counts << " failed");
            return ret;
        }
        if (my_xteam == ACLSHMEM_TEAM_INVALID) {
            continue;
        }
        if (*x_team != ACLSHMEM_TEAM_INVALID) {
            SHM_LOG_ERROR("Current pe is " << src_team->mype << " , already in an x-axis team, split x-axis failed for x- "
                          << i);
            return ACLSHMEM_INNER_ERROR;
        }
        *x_team = my_xteam;
        SHM_LOG_INFO("Current pe is " << src_team->mype << " , split x-axis succeed for x- " << i);
    }
    return ACLSHMEM_SUCCESS;
}
/**
 * Create the y-axis teams of a 2-D split. Team i starts at parent index i with stride x_range.
 * It has src_size / x_range members, plus one more when i < src_size % x_range.
 *
 * On return *y_team holds the single y-axis team this PE belongs to, or stays
 * ACLSHMEM_TEAM_INVALID.
 * @return ACLSHMEM_SUCCESS; the error from aclshmem_team_split_strided if a split fails (the
 *         original only logged at INFO and carried on); ACLSHMEM_INNER_ERROR if this PE lands in
 *         more than one y-axis team.
 */
int aclshmemi_team_split_2d_y(aclshmem_team_t &parent_team, int32_t &y_team_counts, int32_t &src_size,
                              int &x_range, aclshmem_team_t *&y_team)
{
    aclshmemx_team_t *src_team = &g_aclshmem_team_pool[parent_team];
    // Loop-invariant shape of the grid: full rows and the number of columns with an extra row.
    const int y_stride = x_range;
    const int remainder = src_size % x_range;
    const int y_range = src_size / x_range;
    for (int i = 0; i < y_team_counts; ++i) {
        // Pre-set so the check below never reads an indeterminate handle if the split returns
        // before writing its output.
        aclshmem_team_t my_yteam = ACLSHMEM_TEAM_INVALID;
        const int start = i;
        const int y_size = (i < remainder) ? y_range + 1 : y_range;
        auto ret = aclshmem_team_split_strided(parent_team, start, y_stride, y_size, &my_yteam);
        if (ret != ACLSHMEM_SUCCESS) {
            SHM_LOG_ERROR("create y-axis team " << (i + 1) << " of " << y_team_counts << " failed");
            return ret;
        }
        if (my_yteam == ACLSHMEM_TEAM_INVALID) {
            continue;
        }
        if (*y_team != ACLSHMEM_TEAM_INVALID) {
            SHM_LOG_ERROR("Current pe is " << src_team->mype << " , already in a y-axis team, split y-axis failed for y- "
                          << i);
            return ACLSHMEM_INNER_ERROR;
        }
        *y_team = my_yteam;
        SHM_LOG_INFO("Current pe is " << src_team->mype << " , split y-axis succeed for y- " << i);
    }
    return ACLSHMEM_SUCCESS;
}
/**
 * Split `parent_team` into a 2-D grid with `x_range` PEs per row (clamped to the parent size).
 * *x_team receives this PE's row team and *y_team its column team.
 *
 * On failure, any team this call already created for this PE is destroyed and both outputs are
 * left as ACLSHMEM_TEAM_INVALID. The original leaked those team slots.
 * @return ACLSHMEM_SUCCESS or the first error from argument checks or from the x/y splits.
 */
int aclshmem_team_split_2d(aclshmem_team_t parent_team, int x_range, aclshmem_team_t *x_team, aclshmem_team_t *y_team)
{
    ACLSHMEM_TEAM_CHECK_SINGLE_INSTANCE(__func__);
    auto ret = aclshmemi_team_split_2d_precheck(parent_team, x_range, x_team, y_team);
    if (ret != ACLSHMEM_SUCCESS) {
        return ret;
    }
    int32_t src_size = g_aclshmem_team_pool[parent_team].size;
    if (x_range > src_size) {
        x_range = src_size;
    }
    // Integer ceiling division (x_range >= 1 after the precheck): number of row teams,
    // the last of which may be shorter.
    int32_t x_team_counts = (src_size + x_range - 1) / x_range;
    int32_t y_team_counts = x_range;
    ret = aclshmemi_team_split_2d_x(parent_team, x_team_counts, src_size, x_range, x_team);
    if (ret == ACLSHMEM_SUCCESS) {
        ret = aclshmemi_team_split_2d_y(parent_team, y_team_counts, src_size, x_range, y_team);
    }
    if (ret != ACLSHMEM_SUCCESS) {
        if (*x_team != ACLSHMEM_TEAM_INVALID) {
            aclshmem_team_destroy(*x_team);
            *x_team = ACLSHMEM_TEAM_INVALID;
        }
        if (*y_team != ACLSHMEM_TEAM_INVALID) {
            aclshmem_team_destroy(*y_team);
            *y_team = ACLSHMEM_TEAM_INVALID;
        }
    }
    return ret;
}
/**
 * Translate `src_pe` (an index within `src_team`) to the corresponding index within `dest_team`.
 *
 * @return the PE's index in dest_team; -1 if either team is invalid, src_pe is out of range, or
 *         the PE is not a member of dest_team; ACLSHMEM_NOT_SUPPORTED outside instance 0.
 */
int32_t aclshmem_team_translate_pe(aclshmem_team_t src_team, int32_t src_pe, aclshmem_team_t dest_team)
{
    ACLSHMEM_TEAM_CHECK_SINGLE_INSTANCE(__func__);
    if (!is_valid_team(src_team) || !is_valid_team(dest_team)) {
        return -1;
    }
    const aclshmemx_team_t &src = g_aclshmem_team_pool[src_team];
    const aclshmemx_team_t &dest = g_aclshmem_team_pool[dest_team];
    if (src_pe < 0 || src_pe >= src.size) {
        return -1;
    }
    const int32_t global_pe = src.pe_mapping[src_pe];
    if (global_pe < 0 || global_pe >= ACLSHMEM_MAX_PES) {
        return -1;
    }
    const int32_t dest_pe = dest.pe_mapping[global_pe + ACLSHMEM_MAX_PES];
    // Do not rely on the reverse map alone to signal non-membership. Accept a value only if it
    // is a valid dest index whose forward mapping points back at the same global PE.
    if (dest_pe < 0 || dest_pe >= dest.size || dest.pe_mapping[dest_pe] != global_pe) {
        return -1;
    }
    return dest_pe;
}
/**
 * Destroy a team on this PE: free its device descriptor, release its slot in g_team_mask, and
 * push the updated state to the device. Invalid handles are logged and ignored.
 */
void aclshmem_team_destroy(aclshmem_team_t team)
{
    if (!is_valid_team(team)) {
        SHM_LOG_INFO("input team is invalid!, team: " << team);
        return;
    }
    device_team_destroy(team);
    // Clear the slot bit explicitly rather than toggling it, so the release cannot re-set a bit.
    g_team_mask &= ~(1ULL << team);
    if (update_device_state() != ACLSHMEM_SUCCESS) {
        // A failed device-state refresh is an error, not informational.
        SHM_LOG_ERROR("update state failed when destroy team!");
    }
}
/** Return this PE's global index as recorded in g_state.mype. */
int32_t aclshmem_my_pe(void)
{
    return static_cast<int32_t>(g_state.mype);
}
/** Return the total number of PEs as recorded in g_state.npes. */
int32_t aclshmem_n_pes(void)
{
    return static_cast<int32_t>(g_state.npes);
}
/**
 * Return this PE's index within `team`, or -1 if the handle is invalid
 * (ACLSHMEM_NOT_SUPPORTED outside instance 0).
 */
int32_t aclshmem_team_my_pe(aclshmem_team_t team)
{
    ACLSHMEM_TEAM_CHECK_SINGLE_INSTANCE(__func__);
    return is_valid_team(team) ? g_aclshmem_team_pool[team].mype : -1;
}
/**
 * Return the number of PEs in `team`, or -1 if the handle is invalid
 * (ACLSHMEM_NOT_SUPPORTED outside instance 0).
 */
int32_t aclshmem_team_n_pes(aclshmem_team_t team)
{
    ACLSHMEM_TEAM_CHECK_SINGLE_INSTANCE(__func__);
    if (!is_valid_team(team)) {
        return -1;
    }
    return g_aclshmem_team_pool[team].size;
}
/**
 * Copy the configuration of `team` into `*config`.
 *
 * @return 0 on success; ACLSHMEM_INVALID_PARAM if `config` is null or the team is invalid;
 *         ACLSHMEM_NOT_SUPPORTED outside instance 0.
 * NOTE(review): the original passed the bool `config == nullptr` to ACLSHMEM_CHECK_RET, which
 * presumably returned 1 instead of a proper error code; the explicit check replaces it.
 */
int aclshmem_team_get_config(aclshmem_team_t team, aclshmem_team_config_t *config)
{
    ACLSHMEM_TEAM_CHECK_SINGLE_INSTANCE(__func__);
    if (config == nullptr) {
        SHM_LOG_ERROR("output config is null.");
        return ACLSHMEM_INVALID_PARAM;
    }
    if (!is_valid_team(team)) {
        return ACLSHMEM_INVALID_PARAM;
    }
    *config = g_aclshmem_team_pool[team].config;
    return 0;
}