* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <arpa/inet.h>
#include <ifaddrs.h>
#include <netdb.h>
#include <net/if.h>
#include "shmemi_host_def.h"
#include "host/shmem_host_def.h"
#include "utils/shmemi_logger.h"
#include "utils/shmemi_host_types.h"
#include "init/bootstrap/shmemi_bootstrap.h"
#include "store_op.h"
#include "store_net_group_engine.h"
#include "store_factory.h"
#include "sotre_net.h"
#include "store_net_common.h"
#include "socket/uid_utils.h"
#include "config_store/store_net_utils.h"
constexpr int DEFAULT_ID = 0;
static std::mutex g_store_mutex;
static int g_store_ref = 0;
struct ConfigStoreState {
shm::store::StorePtr store_ = nullptr;
shm::store::SmemGroupEnginePtr group_engine_ = nullptr;
};
int config_store_get_unique_id(void* uid) {
char pta_env_ip[shm::utils::MAX_IP];
uint16_t pta_env_port;
sa_family_t sockType;
const char *ipPort = std::getenv("SHMEM_UID_SESSION_ID");
const char *ipInfo = std::getenv("SHMEM_UID_SOCK_IFNAME");
bool is_from_ifa = false;
if (ipPort != nullptr) {
if (shm::utils::aclshmemi_get_ip_from_env(pta_env_ip, pta_env_port, sockType, ipPort) != ACLSHMEM_SUCCESS) {
SHM_LOG_ERROR("cant get pta master addr.");
return ACLSHMEM_INVALID_PARAM;
}
} else {
is_from_ifa = true;
if (shm::utils::aclshmemi_get_ip_from_ifa(pta_env_ip, sockType, ipInfo) != ACLSHMEM_SUCCESS) {
SHM_LOG_ERROR("cant get available ip port.");
return ACLSHMEM_INVALID_PARAM;
}
}
SHM_LOG_INFO("get master IP value:" << pta_env_ip);
return shm::utils::aclshmemi_set_ip_info((aclshmemx_uniqueid_t *)uid, sockType, pta_env_ip, pta_env_port, is_from_ifa);
}
int aclshmemi_bootstrap_plugin_pre_init(aclshmemi_bootstrap_handle_t* handle)
{
if (!handle) {
SHM_LOG_ERROR("bootstrap pre_init: invalid arguments (nullptr).");
return ACLSHMEM_BOOTSTRAP_ERROR;
}
if (handle->pre_init_ops == nullptr) {
SHM_LOG_DEBUG(" bootstrap plugin pre init start.");
ACLSHMEM_CHECK_RET(ACLSHMEM_BOOTSTRAP_CALLOC(&handle->pre_init_ops, 1));
handle->pre_init_ops->get_unique_id = config_store_get_unique_id;
handle->pre_init_ops->get_unique_id_static_magic = nullptr;
handle->pre_init_ops->cookie = nullptr;
SHM_LOG_DEBUG(" bootstrap plugin pre init end.");
} else {
SHM_LOG_DEBUG(" pre_init_ops had already prepared.");
}
return ACLSHMEM_SUCCESS;
}
void config_store_bootstrap_global_exit(int status, aclshmemi_bootstrap_handle_t* handle)
{
if (!handle) {
SHM_LOG_ERROR("bootstrap global_exit: invalid arguments (nullptr).");
return;
}
auto state = static_cast<ConfigStoreState*>(handle->bootstrap_state);
if (!state) {
SHM_LOG_ERROR("handle->state is NULL");
return;
}
if (state->group_engine_ == nullptr) {
SHM_LOG_ERROR("Group is NULL");
return;
}
state->group_engine_->GroupBroadcastExit(status);
}
int32_t config_store_set_tls_info(bool enable, const char *tls_info, const uint32_t tls_info_len)
{
return shm::store::StoreFactory::SetTlsInfo(enable, tls_info, tls_info_len);
}
int32_t config_store_bootstrap_set_tls_key(
const char *tls_pk, const uint32_t tls_pk_len,
const char *tls_pk_pw, const uint32_t tls_pk_pw_len, const aclshmem_decrypt_handler decrypt_handler)
{
return shm::store::StoreFactory::SetTlsPkInfo(tls_pk, tls_pk_len, tls_pk_pw, tls_pk_pw_len, decrypt_handler);
}
static int config_store_bootstrap_finalize(aclshmemi_bootstrap_handle_t *handle) {
std::lock_guard<std::mutex> lock(g_store_mutex);
g_store_ref--;
if (!handle) {
return ACLSHMEM_SUCCESS;
}
if (handle->pre_init_ops) {
ACLSHMEM_BOOTSTRAP_PTR_FREE(handle->pre_init_ops);
handle->pre_init_ops = nullptr;
}
auto state = static_cast<ConfigStoreState*>(handle->bootstrap_state);
if (!state) return ACLSHMEM_INNER_ERROR;
if (state->store_ != nullptr) {
if (g_store_ref == 0) {
shm::store::StoreFactory::DestroyStore();
SHM_LOG_INFO("Global Store Destroyed. ");
} else {
SHM_LOG_INFO("Store ref count is " << g_store_ref << ", Skip Global DestroyStore.");
}
}
state->group_engine_ = nullptr;
state->store_ = nullptr;
delete state;
return ACLSHMEM_SUCCESS;
}
static int config_store_bootstrap_allgather(const void *in, void *out, int len, aclshmemi_bootstrap_handle_t *handle) {
if (!in || !out || !handle) {
SHM_LOG_ERROR("bootstrap allgather: invalid arguments (nullptr).");
return ACLSHMEM_BOOTSTRAP_ERROR;
}
auto state = static_cast<ConfigStoreState*>(handle->bootstrap_state);
if (!state) return ACLSHMEM_INNER_ERROR;
int rank = handle->mype;
int nranks = handle->npes;
auto ret = (shm::store::SmemGroupEnginePtr(state->group_engine_))->GroupAllGather((char *)in, len, (char *)out, len * nranks);
if (ret != ACLSHMEM_SUCCESS) {
SHM_LOG_ERROR("Group AllGather timeout or store failure");
return ACLSHMEM_SMEM_ERROR;
}
return ACLSHMEM_SUCCESS;
}
static int config_store_bootstrap_barrier(aclshmemi_bootstrap_handle_t *handle) {
if (!handle) {
SHM_LOG_ERROR("bootstrap barrier: invalid arguments (nullptr).");
return ACLSHMEM_BOOTSTRAP_ERROR;
}
auto state = static_cast<ConfigStoreState*>(handle->bootstrap_state);
if (!state) return ACLSHMEM_INNER_ERROR;
int rank = handle->mype;
int tag = 0;
int nranks = handle->npes;
if (nranks == 1) {
SHM_LOG_DEBUG("Single rank, skip barrier");
return ACLSHMEM_SUCCESS;
}
SHM_LOG_INFO("group_engine_bootstrap_barrier");
SHM_LOG_DEBUG("Barrier start. rank: " << rank << " nranks: " << nranks <<" tag: "<< tag);
if (state->group_engine_ == nullptr) {
SHM_LOG_ERROR("Group is NULL");
return ACLSHMEM_INNER_ERROR;
}
auto ret = (shm::store::SmemGroupEnginePtr(state->group_engine_))->GroupBarrier();
if (ret != ACLSHMEM_SUCCESS) {
SHM_LOG_ERROR("Group barrier timeout or store failure");
return ACLSHMEM_SMEM_ERROR;
}
SHM_LOG_DEBUG("Barrier end. rank: " << rank << " nranks: " << nranks <<" tag: "<< tag);
return ACLSHMEM_SUCCESS;
}
int32_t init_config_store(aclshmemi_bootstrap_handle_t* handle)
{
if (!handle) {
SHM_LOG_ERROR("bootstrap init_config_store: invalid arguments (nullptr).");
return ACLSHMEM_BOOTSTRAP_ERROR;
}
auto state = static_cast<ConfigStoreState*>(handle->bootstrap_state);
if (!state) return ACLSHMEM_INNER_ERROR;
shm::store::StoreFactory::SetTlsInfo(false, nullptr, 0);
int32_t sock_fd = handle->sockFd;
shm::store::UrlExtraction option;
std::string url(handle->ipport);
SHM_ASSERT_RETURN(option.ExtractIpPortFromUrl(url) == ACLSHMEM_SUCCESS, ACLSHMEM_INVALID_PARAM);
if (handle->mype == 0) {
state->store_ = shm::store::StoreFactory::CreateStore(option.ip, option.port, true, 0, -1, sock_fd);
}
else {
state->store_ = shm::store::StoreFactory::CreateStore(option.ip, option.port, false, handle->mype, handle->timeOut);
}
return ACLSHMEM_SUCCESS;
}
int32_t init_group_engine(aclshmemi_bootstrap_handle_t* handle)
{
if (!handle) {
SHM_LOG_ERROR("bootstrap init_group_engine: invalid arguments (nullptr).");
return ACLSHMEM_BOOTSTRAP_ERROR;
}
auto state = static_cast<ConfigStoreState*>(handle->bootstrap_state);
if (!state) return ACLSHMEM_INNER_ERROR;
std::string prefix = "SHM_(" + std::to_string(DEFAULT_ID) + ")_";
shm::store::StorePtr store_ptr = shm::store::StoreFactory::PrefixStore(state->store_, prefix);
shm::store::SmemGroupOption opt = {
(uint32_t)handle->npes, (uint32_t)handle->mype, handle->timeControlOut * 1000U, false, nullptr, nullptr};
shm::store::SmemGroupEnginePtr group = shm::store::SmemNetGroupEngine::Create(store_ptr, opt);
SHM_ASSERT_RETURN(group != nullptr, ACLSHMEM_SMEM_ERROR);
state->group_engine_ = group;
return ACLSHMEM_SUCCESS;
}
int aclshmemi_bootstrap_plugin_init(void* comm, aclshmemi_bootstrap_handle_t* handle)
{
std::lock_guard<std::mutex> lock(g_store_mutex);
g_store_ref++;
int status = ACLSHMEM_SUCCESS;
if (handle == nullptr) {
SHM_LOG_ERROR(" aclshmemi_bootstrap_plugin_init: invalid arguments (nullptr)");
return ACLSHMEM_BOOTSTRAP_ERROR;
}
ConfigStoreState *state = new (std::nothrow) ConfigStoreState();
if (state == nullptr) {
SHM_LOG_ERROR("Failed to allocate memory for bootstrap state. ");
return ACLSHMEM_BOOTSTRAP_ERROR;
}
handle->bootstrap_state = static_cast<void *>(state);
std::string ipPort;
if (handle->use_attr_ipport == false) {
aclshmemi_bootstrap_uid_state_t* uid_args = (aclshmemi_bootstrap_uid_state_t*)comm;
if (uid_args->addr.type == ADDR_IPv6) {
char ipStr[INET6_ADDRSTRLEN] = {0};
if (inet_ntop(AF_INET6, &(uid_args->addr.addr.addr6.sin6_addr), ipStr, sizeof(ipStr)) == nullptr) {
SHM_LOG_ERROR("inet_ntop failed for IPv6");
return ACLSHMEM_INNER_ERROR;
}
uint16_t port = ntohs(uid_args->addr.addr.addr6.sin6_port);
ipPort = "tcp6://[" + std::string(ipStr) + "]:" + std::to_string(port);
} else {
char ipStr[INET_ADDRSTRLEN] = {0};
if (inet_ntop(AF_INET, &(uid_args->addr.addr.addr4.sin_addr), ipStr, sizeof(ipStr)) == nullptr) {
SHM_LOG_ERROR("inet_ntop failed for IPv4");
return ACLSHMEM_INNER_ERROR;
}
uint16_t port = ntohs(uid_args->addr.addr.addr4.sin_port);
ipPort = "tcp://" + std::string(ipStr) + ":" + std::to_string(port);
}
handle->sockFd = uid_args->inner_sockFd;
handle->timeOut = DEFAULT_TIMEOUT;
handle->timeControlOut = DEFAULT_TIMEOUT;
size_t ip_len = sizeof(handle->ipport);
if (ipPort.length() > ip_len) {
SHM_LOG_ERROR("Generated IP String Illegal!!");
return ACLSHMEM_INNER_ERROR;
}
strncpy(handle->ipport, ipPort.c_str(), ip_len - 1);
handle->ipport[ip_len - 1] = '\0';
SHM_LOG_INFO("extract ip port:" << ipPort);
}
status = config_store_set_tls_info(handle->tls_enable, handle->tls_info, handle->tls_info_len);
if (handle->tls_enable == true) {
status = config_store_bootstrap_set_tls_key(
handle->tls_pk, handle->tls_pk_len,
handle->tls_pk_pw, handle->tls_pk_pw_len, handle->decrypt_handler);
}
status = init_config_store(handle);
if (status != ACLSHMEM_SUCCESS) {
SHM_LOG_ERROR("shmem init config store failed, error: " << status);
return status;
}
status = init_group_engine(handle);
if (status != ACLSHMEM_SUCCESS) {
SHM_LOG_ERROR("shmem init group engine failed, error: " << status);
return status;
}
handle->allgather = config_store_bootstrap_allgather;
handle->barrier = config_store_bootstrap_barrier;
handle->finalize = config_store_bootstrap_finalize;
handle->alltoall = nullptr;
handle->global_exit = config_store_bootstrap_global_exit;
SHM_LOG_INFO("pe " << handle->mype << ": bootstrap plugin initialized successfully");
return ACLSHMEM_SUCCESS;
}