* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "graph/ge_context.h"
#include <stdexcept>
#include "graph/ge_global_options.h"
#include "graph/ge_local_context.h"
#include "graph/types.h"
#include "framework/common/debug/ge_log.h"
#include "utils/extern_math_util.h"
#include "external/ge_common/ge_api_types.h"
namespace ge {
namespace {
// Radix passed to std::strtol when numeric option strings are parsed below.
const int32_t kDecimal = 10;
// Value of "ge.exec.placement" that GetHostExecFlag() treats as host execution.
const char_t *const kHostExecPlacement = "HOST";
// Value that flag-style options (e.g. "ge.exec.overflow", "ge.graphLevelSat") must equal to count as enabled.
const char_t *const kEnabled = "1";

// Reads option `option_name` from ge::GetContext(), parses it as a decimal integer and, if
// IntegerChecker<T>::Compat accepts the parsed value (presumably a range check for T — confirm),
// stores it in `var`.
//
// Returns ge::SUCCESS on success. Returns ge::FAILED and leaves `var` untouched when the option is
// absent, is not a number, overflows int64_t, or is rejected by IntegerChecker<T>::Compat.
// NOTE: like the std::sto* family in general, leading whitespace is skipped and trailing
// non-digit characters are ignored ("12abc" parses as 12).
template<class T>
ge::Status GetOptionValue(const std::string &option_name, T &var) {
  std::string option;
  if (ge::GetContext().GetOption(option_name, option) != GRAPH_SUCCESS) {
    return ge::FAILED;
  }
  int64_t value = 0;
  try {
    // std::stoll rather than std::stoi: the intermediate is int64_t and some targets are 64-bit
    // (e.g. the uint64_t session id), so parsing must not be capped at the int range.
    value = static_cast<int64_t>(std::stoll(option));
  } catch (const std::invalid_argument &) {
    GELOGW("[Init] Transform option %s %s to int failed, as catching invalid_argument exception", option_name.c_str(),
           option.c_str());
    return ge::FAILED;
  } catch (const std::out_of_range &) {
    GELOGW("[Init] Transform option %s %s to int failed, as catching out_of_range exception", option_name.c_str(),
           option.c_str());
    return ge::FAILED;
  }
  if (!IntegerChecker<T>::Compat(value)) {
    GELOGW("[Init] Transform option %s %s to int failed, value is invalid_argument", option_name.c_str(),
           option.c_str());
    return ge::FAILED;
  }
  // Explicit conversion: the value has just been validated for T by IntegerChecker.
  var = static_cast<T>(value);
  return ge::SUCCESS;
}
}
// Returns the single process-wide GEContext instance, constructed on first call
// (function-local static).
GEContext &GetContext() {
  static GEContext global_context{};
  return global_context;
}
// Per-thread identifiers, zero until written. SetSessionId()/SetContextId() write them, and
// Init() may overwrite session_id_ from "ge.exec.sessionId".
thread_local uint64_t GEContext::session_id_{0UL};
thread_local uint64_t GEContext::context_id_{0UL};
// Looks up `key` in the thread-local context and writes its value into `option`;
// the status is whatever the thread-local context's GetOption reports.
graphStatus GEContext::GetOption(const std::string &key, std::string &option) {
  auto &local_context = GetThreadLocalContext();
  return local_context.GetOption(key, option);
}
// Forwards to the thread-local context's GetReadableName(); the returned reference is whatever
// that context hands back (its lifetime is owned there).
const std::string &GEContext::GetReadableName(const std::string &key) {
  auto &local_context = GetThreadLocalContext();
  return local_context.GetReadableName(key);
}
// True only when option "ge.exec.overflow" exists in the thread-local context and equals kEnabled ("1").
bool GEContext::IsOverflowDetectionOpen() const {
  std::string overflow_option;
  const auto lookup_status = GetThreadLocalContext().GetOption("ge.exec.overflow", overflow_option);
  if (lookup_status == GRAPH_SUCCESS) {
    GELOGD("Option ge.exec.overflow is %s.", overflow_option.c_str());
    return overflow_option == kEnabled;
  }
  return false;
}
// True only when option "ge.graphLevelSat" exists in the thread-local context and equals kEnabled ("1").
bool GEContext::IsGraphLevelSat() const {
  std::string sat_option;
  if (GetThreadLocalContext().GetOption("ge.graphLevelSat", sat_option) == GRAPH_SUCCESS) {
    GELOGD("Option ge.graphLevelSat is %s.", sat_option.c_str());
    return sat_option == kEnabled;
  }
  return false;
}
// True only when option "ge.exec.placement" exists in the thread-local context and equals
// kHostExecPlacement ("HOST").
bool GEContext::GetHostExecFlag() const {
  std::string placement;
  const bool has_option =
      (GetThreadLocalContext().GetOption("ge.exec.placement", placement) == GRAPH_SUCCESS);
  if (!has_option) {
    return false;
  }
  GELOGD("Option ge.exec.placement is %s.", placement.c_str());
  return placement == kHostExecPlacement;
}
// Reports whether ge::OPTION_GRAPH_RUN_MODE selects training.
// Returns false when the option is missing or empty. Otherwise the value is parsed with
// std::strtol (non-numeric text parses as 0), and the function returns true when the result
// is >= ge::TRAIN.
bool GEContext::GetTrainGraphFlag() const {
  std::string run_mode;
  if ((GetThreadLocalContext().GetOption(ge::OPTION_GRAPH_RUN_MODE, run_mode) != ge::GRAPH_SUCCESS) ||
      run_mode.empty()) {
    return false;
  }
  // Compare as integers instead of casting the parsed value to ge::GraphRunMode first:
  // static_cast of a value outside an enumeration's range is undefined behavior, and the
  // option string is user-controlled. Result is unchanged for in-range values.
  const long parsed_mode = std::strtol(run_mode.c_str(), nullptr, kDecimal);
  return parsed_mode >= static_cast<long>(ge::TRAIN);
}
// Returns the input fusion size from OPTION_EXEC_INPUT_FUSION_SIZE.
//  - option absent            -> 128 KiB default
//  - parsed value negative    -> 0 (logged at info level)
//  - parsed value above 32 MiB -> clamped to 32 MiB (logged as warning)
// The value is parsed with std::strtol, so non-numeric text yields 0.
uint64_t GEContext::GetInputFusionSize() const {
  const uint64_t kDefaultFusionSize = 128U * 1024U;
  const uint64_t kMaxFusionSize = 32U * 1024U * 1024U;
  std::string fusion_size;
  const bool option_found =
      (GetThreadLocalContext().GetOption(OPTION_EXEC_INPUT_FUSION_SIZE, fusion_size) == GRAPH_SUCCESS);
  if (!option_found) {
    return kDefaultFusionSize;
  }
  const long parsed = std::strtol(fusion_size.c_str(), nullptr, kDecimal);
  if (parsed < 0) {
    GELOGI("%s is %s which is less than 0, return 0", OPTION_EXEC_INPUT_FUSION_SIZE, fusion_size.c_str());
    return 0U;
  }
  const auto requested = static_cast<uint64_t>(parsed);
  if (requested <= kMaxFusionSize) {
    return requested;
  }
  GELOGW("option [%s] is %s which is bigger than max(%" PRIu64 "), return max", OPTION_EXEC_INPUT_FUSION_SIZE,
         fusion_size.c_str(), kMaxFusionSize);
  return kMaxFusionSize;
}
// Returns a single process-wide mutex, created on first call. Presumably meant to serialize
// access to GetMutableGlobalOptions() — confirm at call sites.
std::mutex &GetGlobalOptionsMutex() {
  static std::mutex options_mutex;
  return options_mutex;
}
// Returns a mutable reference to the single process-wide option map (initially empty,
// created on first call). This function itself takes no lock.
std::map<std::string, std::string> &GetMutableGlobalOptions() {
  static std::map<std::string, std::string> global_options{};
  return global_options;
}
// Returns a mutable reference to the single process-wide set of option keys (initially empty,
// created on first call). This function itself takes no lock.
std::unordered_set<std::string> &GetMutableUserGlobalOptionKeys() {
  static std::unordered_set<std::string> user_option_keys{};
  return user_option_keys;
}
void GEContext::Init() {
(void) GetOptionValue("ge.exec.sessionId", session_id_);
(void) GetOptionValue("ge.exec.deviceId", device_id_);
int32_t stream_sync_timeout = -1;
(void) GetOptionValue("stream_sync_timeout", stream_sync_timeout);
SetStreamSyncTimeout(stream_sync_timeout);
int32_t event_sync_timeout = -1;
(void) GetOptionValue("event_sync_timeout", event_sync_timeout);
SetEventSyncTimeout(event_sync_timeout);
}
// Returns the calling thread's session id (session_id_ is a thread_local static member).
uint64_t GEContext::SessionId() const {
  return GEContext::session_id_;
}
uint32_t GEContext::DeviceId() const {
uint32_t device_id = 0U;
auto status = GetOptionValue("ge.session_device_id", device_id);
return (status == ge::SUCCESS) ? device_id : device_id_;
}
// Sync timeouts are stored in the thread-local context; these accessors only forward to it.
int32_t GEContext::StreamSyncTimeout() const {
  return GetThreadLocalContext().StreamSyncTimeout();
}

int32_t GEContext::EventSyncTimeout() const {
  return GetThreadLocalContext().EventSyncTimeout();
}

// Writes the thread_local session id of the calling thread.
void GEContext::SetSessionId(const uint64_t session_id) {
  session_id_ = session_id;
}

// Writes the thread_local context id of the calling thread.
void GEContext::SetContextId(const uint64_t context_id) {
  context_id_ = context_id;
}

// Sets the fallback device id returned by DeviceId() when "ge.session_device_id" is unavailable.
void GEContext::SetCtxDeviceId(const uint32_t device_id) {
  device_id_ = device_id;
}

void GEContext::SetStreamSyncTimeout(const int32_t timeout) {
  GetThreadLocalContext().SetStreamSyncTimeout(timeout);
}

void GEContext::SetEventSyncTimeout(const int32_t timeout) {
  GetThreadLocalContext().SetEventSyncTimeout(timeout);
}
// Hands the option-name mapping (given as a JSON string) to the thread-local context and
// returns its status unchanged.
graphStatus GEContext::SetOptionNameMap(const std::string &option_name_map_json) {
  auto &local_context = GetThreadLocalContext();
  return local_context.SetOptionNameMap(option_name_map_json);
}

// Returns the thread-local context's OptimizationOption by reference (owned by that context).
OptimizationOption &GEContext::GetOo() const {
  auto &local_context = GetThreadLocalContext();
  return local_context.GetOo();
}
}