* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <iostream>
#include <vector>
#include <sstream>
#include "host/shmem_host_def.h"
#include "shmemi_logger.h"
#include "shmemi_string_util.h"
#include "store_tcp_config.h"
#include "store_prefix.h"
#include "store_factory.h"
namespace shm {
namespace store {
// Thread-local code of the last failed Startup() in CreateStore(); returned by GetFailedReason().
static __thread int failedReason_ = 0;
// Upper bound (10 KiB) on the lengths accepted by SetTlsInfo() and SetTlsPkInfo().
static constexpr size_t MAX_TLS_INFO_LEN = 10 * 1024U;
// TLS switch written by SetTlsInfo() and copied into tlsOption_ by InitTlsOption().
bool StoreFactory::enableTls = true;
// Raw TLS description string stored by SetTlsInfo() and parsed by InitTlsOption().
std::string StoreFactory::tlsInfo;
// Private key text stored by SetTlsPkInfo(); emptied by TlsCleanUp().
std::string StoreFactory::tlsPkInfo;
// Private key password stored by SetTlsPkInfo(); emptied by TlsCleanUp().
std::string StoreFactory::tlsPkPwdInfo;
// Held by CreateStore(), by the cleaner thread and by ShutDownCleanupThread() when setting stop_.
std::mutex StoreFactory::storesMutex_;
// TLS options handed to every store's Startup(); filled in by InitTlsOption().
AcclinkTlsOption StoreFactory::tlsOption_;
// Set once InitTlsOption() has completed with TLS enabled; lets CreateStore() skip re-initialisation.
bool StoreFactory::isTlsInitialized_ = false;
// Thread started by SetTlsPkInfo() that calls TlsCleanUp() after one hour unless it is stopped first.
std::thread StoreFactory::cleanerThread_;
// Latched by SetTlsPkInfo(), cleared by ShutDownCleanupThread().
std::atomic<bool> StoreFactory::timerRunning_{false};
// The cleaner thread waits on this; ShutDownCleanupThread() notifies it.
std::condition_variable StoreFactory::cv_;
// Wake-up predicate for the cleaner thread; set to true by ShutDownCleanupThread().
std::atomic<bool> StoreFactory::stop_{false};
/**
 * @brief Create a TcpConfigStore and start it up with the factory-wide TLS options.
 *
 * The whole sequence runs while holding storesMutex_. The TLS options are initialised on the
 * first call (and on every call until InitTlsOption() marks them initialised).
 *
 * @param ip            address forwarded to the TcpConfigStore constructor
 * @param port          port forwarded to the TcpConfigStore constructor
 * @param isServer      role flag forwarded to the TcpConfigStore constructor
 * @param rankId        rank id forwarded to the TcpConfigStore constructor
 * @param connMaxRetry  retry count forwarded to Startup()
 * @param sockFd        socket descriptor forwarded to the TcpConfigStore constructor
 * @return the started store, or nullptr on failure. Only a failing Startup() records its return
 *         code for GetFailedReason(); allocation or TLS-init failures leave the recorded reason
 *         untouched.
 */
StorePtr StoreFactory::CreateStore(const std::string &ip, uint16_t port, bool isServer, int32_t rankId,
                                   int32_t connMaxRetry, int32_t sockFd) noexcept
{
    std::unique_lock<std::mutex> lockGuard{storesMutex_};
    auto store = SmMakeRef<TcpConfigStore>(ip, port, isServer, rankId, sockFd);
    SHM_ASSERT_RETURN(store != nullptr, nullptr);

    if (!isTlsInitialized_ && InitTlsOption() != StoreErrorCode::SUCCESS) {
        SHM_LOG_ERROR("init tls option failed. ");
        return nullptr;
    }

    auto ret = store->Startup(tlsOption_, connMaxRetry);
    if (ret == SM_RESOURCE_IN_USE) {
        // Address-in-use is reported at info level only; the caller can query the reason.
        SHM_LOG_INFO("Startup for store(isSever=" << isServer << ", rank=" << rankId << ") address in use");
        failedReason_ = SM_RESOURCE_IN_USE;
        return nullptr;
    }
    if (ret != 0) {
        SHM_LOG_ERROR("Startup for store(isSever=" << isServer << ", rank=" << rankId << ") failed:" << ret);
        failedReason_ = ret;
        return nullptr;
    }

    lockGuard.unlock();
    return store.Get();
}
void StoreFactory::DestroyStore() noexcept
{
TlsCleanUp();
ShutDownCleanupThread();
}
void StoreFactory::DestroyStoreAll(bool afterFork) noexcept
{
TlsCleanUp();
ShutDownCleanupThread();
}
/**
 * @brief Overwrite the current contents of a secret string with zero bytes, then empty it.
 *
 * Plain assignment of "" only resets the length and leaves the old characters in the buffer.
 * Only the bytes of the current contents are overwritten.
 */
static void ZeroizeSecret(std::string &secret) noexcept
{
    // Write through a volatile pointer so the stores are not dropped as dead before clear().
    volatile char *bytes = &secret[0];
    for (size_t i = 0; i < secret.size(); ++i) {
        bytes[i] = '\0';
    }
    secret.clear();
}

/**
 * @brief Erase the TLS private key and its password from the factory and from the TLS options.
 *
 * Called from DestroyStore()/DestroyStoreAll() and from the delayed cleaner thread.
 */
void StoreFactory::TlsCleanUp() noexcept
{
    ZeroizeSecret(StoreFactory::tlsPkInfo);
    ZeroizeSecret(StoreFactory::tlsPkPwdInfo);
    ZeroizeSecret(tlsOption_.tlsPk);
    ZeroizeSecret(tlsOption_.tlsPkPwd);
}
/**
 * @brief Build a PrefixConfigStore from an existing store and a prefix string.
 *
 * @param base    store to wrap; a null value is rejected through SHM_VALIDATE_RETURN
 * @param prefix  prefix forwarded to the PrefixConfigStore constructor
 * @return the new store; presumably nullptr when either macro check fails (the macros are
 *         defined outside this file — confirm)
 */
StorePtr StoreFactory::PrefixStore(const StorePtr &base, const std::string &prefix) noexcept
{
SHM_VALIDATE_RETURN(base != nullptr, "invalid param, base is nullptr", nullptr);
auto store = SmMakeRef<PrefixConfigStore>(base, prefix);
SHM_ASSERT_RETURN(store != nullptr, nullptr);
return store.Get();
}
/**
 * @brief Return the calling thread's recorded failure code.
 *
 * The value is thread-local, starts at 0 and is written by CreateStore() when Startup() fails;
 * nothing in this file resets it after a later success.
 */
int StoreFactory::GetFailedReason() noexcept
{
return failedReason_;
}
/**
 * @brief Split a string on a separator and collect the trimmed, non-empty pieces into a set.
 *
 * @param token     text to split
 * @param splitter  separator character
 * @param parts     receives the pieces; existing entries are kept
 * @return SUCCESS when parts is non-empty afterwards, ERROR otherwise
 */
Result ParseStr2Array(const std::string &token, char splitter, std::set<std::string> &parts)
{
    std::istringstream input(token);
    for (std::string piece; std::getline(input, piece, splitter);) {
        piece = utils::StringUtil::TrimString(piece);
        if (piece.empty()) {
            continue;
        }
        parts.insert(piece);
    }

    if (!parts.empty()) {
        return StoreErrorCode::SUCCESS;
    }
    SHM_LOG_INFO("parse token to array failed");
    return StoreErrorCode::ERROR;
}
/**
 * @brief Split "key<splitter>value" into a trimmed key/value pair.
 *
 * The value is the text between the first and the second separator (or the end of the input);
 * anything after a second separator is ignored.
 *
 * @param token     text to split
 * @param splitter  separator character
 * @param pair      receives key and value on success; left untouched on failure
 * @return SUCCESS when both key and value are non-empty after trimming, ERROR otherwise
 */
Result ParseStr2KV(const std::string &token, char splitter, std::pair<std::string, std::string> &pair)
{
    std::istringstream input(token);
    std::string name;
    std::string content;

    const bool hasBoth = std::getline(input, name, splitter) && std::getline(input, content, splitter);
    if (hasBoth) {
        name = utils::StringUtil::TrimString(name);
        content = utils::StringUtil::TrimString(content);
    }

    if (!hasBoth || name.empty() || content.empty()) {
        SHM_LOG_INFO("parse token to kv failed");
        return StoreErrorCode::ERROR;
    }

    pair.first = name;
    pair.second = content;
    return StoreErrorCode::SUCCESS;
}
/**
 * @brief Store a single-valued TLS setting into the option structure.
 *
 * Recognised keys: "tlsCaPath", "tlsCert", "tlsCrlPath", "packagePath".
 *
 * @param tlsOption  option structure to update
 * @param key        setting name
 * @param value      setting value
 * @return true when the key was recognised and stored, false otherwise
 */
bool SetTlsOptionValue(AcclinkTlsOption &tlsOption, const std::string &key, const std::string &value)
{
    if (key == "tlsCaPath") {
        tlsOption.tlsCaPath = value;
        return true;
    }
    if (key == "tlsCert") {
        tlsOption.tlsCert = value;
        return true;
    }
    if (key == "tlsCrlPath") {
        tlsOption.tlsCrlPath = value;
        return true;
    }
    if (key == "packagePath") {
        tlsOption.packagePath = value;
        return true;
    }
    return false;
}
/**
 * @brief Store a multi-valued TLS setting into the option structure.
 *
 * Recognised keys: "tlsCrlFile", "tlsCaFile".
 *
 * @param tlsOption  option structure to update
 * @param key        setting name
 * @param values     set of values copied into the matching field
 * @return true when the key was recognised and stored, false otherwise
 */
bool SetTlsOptionValues(AcclinkTlsOption &tlsOption, const std::string &key, std::set<std::string> &values)
{
    if (key == "tlsCrlFile") {
        tlsOption.tlsCrlFile = values;
        return true;
    }
    if (key == "tlsCaFile") {
        tlsOption.tlsCaFile = values;
        return true;
    }
    return false;
}
/**
 * @brief Parse a "key:value;key:value;..." description into TLS options.
 *
 * Entries are separated by ';' and each entry is split on ':'. The keys "tlsCrlFile" and
 * "tlsCaFile" take a ','-separated list; every other key is handled by SetTlsOptionValue().
 * Entries that cannot be parsed are skipped, and unknown keys are only logged.
 *
 * @param inputStr   description text
 * @param tlsOption  option structure to fill
 * @return always SUCCESS; malformed entries do not fail the call
 */
Result ParseTlsInfo(const std::string &inputStr, AcclinkTlsOption &tlsOption)
{
    std::istringstream tokenStream(inputStr);
    std::string token;
    while (std::getline(tokenStream, token, ';')) {
        if (utils::StringUtil::TrimString(token).empty()) {
            continue;
        }

        std::pair<std::string, std::string> pair;
        if (ParseStr2KV(token, ':', pair) != StoreErrorCode::SUCCESS) {
            continue;
        }

        bool matched = false;
        if (pair.first == "tlsCrlFile" || pair.first == "tlsCaFile") {
            // The path set is only needed for list-valued keys, so build it here.
            std::set<std::string> paths;
            if (ParseStr2Array(pair.second, ',', paths) != StoreErrorCode::SUCCESS) {
                continue;
            }
            matched = SetTlsOptionValues(tlsOption, pair.first, paths);
        } else {
            matched = SetTlsOptionValue(tlsOption, pair.first, pair.second);
        }

        if (!matched) {
            SHM_LOG_INFO("un-match tls info key " << pair.first);
        }
    }
    return StoreErrorCode::SUCCESS;
}
/**
 * @brief Fill the factory-wide TLS options from the values stored by SetTlsInfo()/SetTlsPkInfo().
 *
 * When TLS is disabled only the enable flag is copied and the options are not marked as
 * initialised, so the next CreateStore() runs this again.
 *
 * @return SUCCESS, or ERROR when the TLS description cannot be parsed
 */
Result StoreFactory::InitTlsOption() noexcept
{
    tlsOption_.enableTls = StoreFactory::enableTls;
    if (tlsOption_.enableTls) {
        tlsOption_.tlsPk = StoreFactory::tlsPkInfo;
        tlsOption_.tlsPkPwd = StoreFactory::tlsPkPwdInfo;

        const auto parseRet = ParseTlsInfo(StoreFactory::tlsInfo, tlsOption_);
        if (parseRet != StoreErrorCode::SUCCESS) {
            SHM_LOG_ERROR("extract ssl info from input failed.");
            return StoreErrorCode::ERROR;
        }

        isTlsInitialized_ = true;
        return StoreErrorCode::SUCCESS;
    }

    SHM_LOG_INFO("tls is not enabled.");
    return StoreErrorCode::SUCCESS;
}
/**
 * @brief Adapt a C-style decrypt callback to a std::function taking the cipher text as std::string.
 *
 * @param rawFunc  callback taking (cipher pointer, cipher length, output buffer, output length)
 * @return a callable that forwards the string's c_str() and size() to rawFunc and returns its result
 */
std::function<int(const std::string&, char*, size_t&)> StoreFactory::ConvertFunc(int (*rawFunc)(const char*,
    size_t, char*, size_t&)) noexcept
{
    return [rawFunc](const std::string &cipherText, char *plainText, size_t &plainTextLen) -> int {
        return rawFunc(cipherText.c_str(), cipherText.size(), plainText, plainTextLen);
    };
}
/**
 * @brief Record whether TLS is enabled and, if so, the TLS description text.
 *
 * The enable flag is stored before the data is validated.
 *
 * @param enable      TLS switch
 * @param tlsData     description text; must be non-null when enable is true
 * @param tlsDataLen  number of bytes in tlsData; at most MAX_TLS_INFO_LEN
 * @return SUCCESS, or ERROR when TLS is enabled and the data is null or too long
 */
int32_t StoreFactory::SetTlsInfo(bool enable, const char *tlsData, const size_t tlsDataLen) noexcept
{
    enableTls = enable;
    if (enable) {
        const bool usable = (tlsData != nullptr) && (tlsDataLen <= MAX_TLS_INFO_LEN);
        if (!usable) {
            SHM_LOG_ERROR("tls info null or len invalid.");
            return StoreErrorCode::ERROR;
        }
        StoreFactory::tlsInfo.assign(tlsData, tlsDataLen);
    }
    return StoreErrorCode::SUCCESS;
}
/**
 * @brief Store the TLS private key (and optional password) and start the delayed cleaner thread.
 *
 * Only the first successful call takes effect; later calls return SUCCESS without changing
 * anything until ShutDownCleanupThread() clears the latch. The cleaner thread waits up to one
 * hour and then calls TlsCleanUp(), unless it is stopped earlier.
 *
 * @param tlsPk       private key text; must be non-null
 * @param tlsPkLen    number of bytes in tlsPk; at most MAX_TLS_INFO_LEN
 * @param tlsPkPwd    private key password, or nullptr for none
 * @param tlsPkPwLen  number of bytes in tlsPkPwd; at most MAX_TLS_INFO_LEN
 * @param h           optional decrypt callback installed into the TLS options
 * @return SUCCESS, or ERROR when an argument is rejected
 */
int32_t StoreFactory::SetTlsPkInfo(const char *tlsPk, const uint32_t tlsPkLen, const char *tlsPkPwd,
                                   const uint32_t tlsPkPwLen, const smem_decrypt_handler &h) noexcept
{
    if (timerRunning_.exchange(true)) {
        SHM_LOG_INFO("TLS private key has been set multiple times");
        return StoreErrorCode::SUCCESS;
    }

    if (tlsPk == nullptr || tlsPkLen > MAX_TLS_INFO_LEN) {
        SHM_LOG_ERROR("tls private key is null or len invalid.");
        // Nothing was stored and no cleaner thread exists: release the latch so that a later
        // call with valid arguments is not mistaken for a repeated set.
        timerRunning_ = false;
        return StoreErrorCode::ERROR;
    }

    if (tlsPkPwd == nullptr) {
        SHM_LOG_INFO("tls private key password is null.");
        StoreFactory::tlsPkPwdInfo = "";
    } else {
        if (tlsPkPwLen > MAX_TLS_INFO_LEN) {
            SHM_LOG_ERROR("tls private key password len invalid.");
            timerRunning_ = false;  // same reason as above: allow a retry
            return StoreErrorCode::ERROR;
        }
        StoreFactory::tlsPkPwdInfo = std::string(tlsPkPwd, tlsPkPwLen);
    }
    StoreFactory::tlsPkInfo = std::string(tlsPk, tlsPkLen);

    if (h != nullptr) {
        tlsOption_.decryptHandler_ = ConvertFunc(h);
    }

    stop_ = false;
    cleanerThread_ = std::thread([]() {
        std::unique_lock<std::mutex> lockGuard{storesMutex_};
        // wait_for() returns false when the hour elapsed without stop_ being set.
        if (!cv_.wait_for(lockGuard, std::chrono::hours(1), [] { return stop_.load(); })) {
            TlsCleanUp();
            SHM_LOG_INFO("TlsCleanUp successfully");
        }
    });
    return StoreErrorCode::SUCCESS;
}
/**
 * @brief Stop and join the delayed cleaner thread, then release the SetTlsPkInfo() latch.
 *
 * Does nothing when the latch is not set.
 */
void StoreFactory::ShutDownCleanupThread() noexcept
{
    if (!timerRunning_.load()) {
        return;
    }

    {
        // Set the stop flag under the mutex the waiter uses, so the wake-up cannot be missed.
        std::lock_guard<std::mutex> lockGuard{storesMutex_};
        stop_.store(true);
    }
    cv_.notify_one();

    if (cleanerThread_.joinable()) {
        cleanerThread_.join();
    }
    timerRunning_.store(false);
}
}
}