* Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
*/
#include "utils.h"
#include "model_instance_state.h"
#define TYPE_32_BYTE 4
#define TYPE_64_BYTE 8
#define TYPE_16_BYTE 2
#define TYPE_8_BYTE 1
#define BYTE_PTR (char *)
using namespace ge;
using namespace std;
using json = nlohmann::json;
namespace triton::backend::npu_ge {
// Reads the environment variable `name` with std::getenv.
// Returns the variable's value, or an empty string when it is not set.
std::string ModelInstanceState::GetEnvVar(const std::string &name)
{
    if (const char *raw = std::getenv(name.c_str())) {
        return std::string(raw);
    }
    return std::string();
}
// Parses `env_value` as a JSON object and copies every top-level member into
// `configMap` as an AscendString key/value pair.
//   - String members are stored verbatim (without JSON quotes).
//   - Every other member (number, boolean, null, array, object) is stored as
//     its JSON text via dump(). In particular an integer such as 0 is stored
//     as "0" rather than the "0.000000" that std::to_string(double) yields.
// When the text is not valid JSON, or its root is not an object, an error is
// logged and the function returns; no exception leaves this function.
void ModelInstanceState::LoadGeConfig(std::string &env_value, std::map<ge::AscendString, ge::AscendString> &configMap)
{
    try {
        json jsonData = json::parse(env_value);
        if (!jsonData.is_object()) {
            LOG_MESSAGE(TRITONSERVER_LOG_ERROR, (std::string("Error: JSON file content is not object format")).c_str());
            return;
        }
        for (auto &element : jsonData.items()) {
            const std::string key = element.key();
            // dump() of a boolean is "true"/"false" and of null is "null", so
            // only strings need special handling (to drop the quotes).
            const std::string value =
                element.value().is_string() ? element.value().get<std::string>() : element.value().dump();
            configMap[ge::AscendString(key.c_str())] = ge::AscendString(value.c_str());
            LOG_MESSAGE(TRITONSERVER_LOG_INFO, (std::string("key: ") + key + " value: " + value).c_str());
        }
    } catch (const json::parse_error &e) {
        // A malformed configuration is an error, not routine information.
        LOG_MESSAGE(TRITONSERVER_LOG_ERROR, (std::string("JSON parsing error: ") + e.what()).c_str());
    } catch (const std::exception &e) {
        LOG_MESSAGE(TRITONSERVER_LOG_ERROR, (std::string("Error: ") + e.what()).c_str());
    }
}
// Builds the INPUT_SHAPE parser option from the model's input client tensors
// and stores it in `parser_options`.
// The generated text has the form "name:d0,d1,...;name2:d0,..." where the
// dims are taken from each tensor's dims_ unchanged.
// When there are no input tensors nothing is stored; the original code called
// pop_back() on an empty string in that case, which is undefined behaviour.
void ModelInstanceState::StaticModeConfig(std::map<ge::AscendString, ge::AscendString> &parser_options)
{
    std::string ir_option;
    for (const auto &clientTensor : model_state_->GetInputClientTensors()) {
        ir_option += clientTensor.name_;
        ir_option += ":";
        for (const int64_t dim : clientTensor.dims_) {
            ir_option += std::to_string(dim);
            ir_option += ",";
        }
        // Turn the trailing ',' into the entry separator.
        // NOTE(review): for a tensor with no dims this overwrites the ':'
        // instead, yielding "name;" - confirm that is the intended format.
        ir_option.back() = ';';
    }
    if (ir_option.empty()) {
        LOG_MESSAGE(TRITONSERVER_LOG_INFO, "ir_option is empty, INPUT_SHAPE is not set");
        return;
    }
    // Drop the ';' that follows the last entry.
    ir_option.pop_back();
    parser_options[ge::AscendString(ge::ir_option::INPUT_SHAPE)] = ge::AscendString(ir_option.c_str());
    LOG_MESSAGE(TRITONSERVER_LOG_INFO, (std::string("ir_option: ") + ir_option).c_str());
}
// Builds the INPUT_SHAPE parser option with a leading dimension of 1 in front
// of every tensor's dims_, i.e. "name:1,d0,d1,...;name2:1,d0,...", and stores
// it in `parser_options`.
// When there are no input tensors nothing is stored; the original code called
// pop_back() on an empty string in that case, which is undefined behaviour.
void ModelInstanceState::StaticModeConfigOne(std::map<ge::AscendString, ge::AscendString> &parser_options)
{
    std::string ir_option;
    for (const auto &clientTensor : model_state_->GetInputClientTensors()) {
        ir_option += clientTensor.name_;
        ir_option += ":1,";
        for (const int64_t dim : clientTensor.dims_) {
            ir_option += std::to_string(dim);
            ir_option += ",";
        }
        // The last character is always a ',' here; make it the entry separator.
        ir_option.back() = ';';
    }
    if (ir_option.empty()) {
        LOG_MESSAGE(TRITONSERVER_LOG_INFO, "ir_option is empty, INPUT_SHAPE is not set");
        return;
    }
    // Drop the ';' that follows the last entry.
    ir_option.pop_back();
    parser_options[ge::AscendString(ge::ir_option::INPUT_SHAPE)] = ge::AscendString(ir_option.c_str());
    LOG_MESSAGE(TRITONSERVER_LOG_INFO, (std::string("ir_option: ") + ir_option).c_str());
}
int ModelInstanceState::InitGEEnvironment()
{
aclError retInit = aclInit(nullptr);
if (retInit != ACL_ERROR_NONE) {
LOG_MESSAGE(TRITONSERVER_LOG_ERROR,
(std::string("aclInit failed, ret is: ") + std::to_string(retInit)).c_str());
return RET_ERR;
}
std::map<ge::AscendString, ge::AscendString> ge_init_options;
std::string ge_json(model_state_->GetRuntimeModelPath());
ge_json += "/../ge.json";
LOG_MESSAGE(TRITONSERVER_LOG_INFO, ge_json.c_str());
std::string env_value = GetEnvVar("GE_NPU_CONFIG");
if (!env_value.empty()) {
LOG_MESSAGE(TRITONSERVER_LOG_INFO, (std::string("GE_NPU_CONFIG = ") + env_value).c_str());
LoadGeConfig(env_value, ge_init_options);
} else {
LOG_MESSAGE(TRITONSERVER_LOG_INFO, "No GE_NPU_CONFIG parameters");
if (!model_state_->GetGeConfig().empty()) {
for (auto &pair : model_state_->GetGeConfig()) {
ge_init_options[ge::AscendString((pair.first).c_str())] = ge::AscendString((pair.second).c_str());
LOG_MESSAGE(
TRITONSERVER_LOG_INFO,
(std::string("ge_init_options: ") + "key: " + pair.first + " value: " + pair.second).c_str());
}
}
}
LOG_MESSAGE(TRITONSERVER_LOG_INFO,
(std::string("ge_init_options.size(): ") + std::to_string(ge_init_options.size())).c_str());
ge::Status ret = ge::GEInitialize(ge_init_options);
if (ret != RET_OK) {
LOG_MESSAGE(TRITONSERVER_LOG_ERROR, (std::string(" GEInitialize failed.") + std::to_string(ret)).c_str());
aclFinalize();
return RET_ERR;
}
return RET_OK;
}
int ModelInstanceState::InitDevices(std::vector<int> &dev_ids_)
{
uint32_t dev_count;
aclError retInit = aclrtGetDeviceCount(&dev_count);
if (retInit != ACL_ERROR_NONE) {
LOG_MESSAGE(TRITONSERVER_LOG_ERROR,
(std::string("aclrtGetDeviceCount failed, ret is: ") + std::to_string(retInit)).c_str());
return RET_ERR;
}
if (dev_ids_.size() == 0) {
LOG_MESSAGE(TRITONSERVER_LOG_INFO,
(std::string("aclrtGetDeviceCount getDevice count: ") + std::to_string(dev_count)).c_str());
LOG_MESSAGE(TRITONSERVER_LOG_INFO, "user not set dev_ids ,use all npu card");
for (uint32_t i = 0; i < dev_count; i++) {
dev_ids_.push_back(i);
}
}
if (dev_ids_.size() == 0) {
LOG_MESSAGE(TRITONSERVER_LOG_ERROR, "no npu card detected!");
return RET_ERR;
}
return RET_OK;
}
int ModelInstanceState::InitGraphSession(int dev_id, int graph_id, aclrtContext context, std::mutex &mu,
ge::Session *session_)
{
ge::Status ret;
aclrtStream stream_ = nullptr;
aclError aclRet;
aclrtSetCurrentContext(context);
LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE,
(std::string("graph_id: ") + std::to_string(graph_id) + std::string("start init")).c_str());
ge::Graph graph1;
std::map<ge::AscendString, ge::AscendString> parser_options;
ConfigureParserOptions(parser_options);
if (!ParseGraph(parser_options, graph1, mu)) {
return RET_ERR;
}
if (!AddAndCompileGraph(session_, graph_id, graph1, ret)) {
return RET_ERR;
}
if (!CreateAndLoadStream(session_, graph_id, stream_, ret, aclRet)) {
return RET_ERR;
}
model_state_->GetScheduler()->AddInstance(graph_id, context, stream_, session_, dev_id);
LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE,
(std::string("graph_id: ") + std::to_string(graph_id) + std::string("finished")).c_str());
return RET_OK;
}
// Chooses how the parser's input shape is configured:
//   - GE static mode off: `parser_options` is left untouched.
//   - GE static mode on and (infer mode is DYNAMICMODEL or dynamic batch is
//     supported): StaticModeConfigOne.
//   - GE static mode on otherwise: StaticModeConfig.
void ModelInstanceState::ConfigureParserOptions(std::map<ge::AscendString, ge::AscendString> &parser_options)
{
    if (!model_state_->GetGeStaticMode()) {
        LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, "generate dynamic batch graph");
        return;
    }

    const bool singleBatch = model_state_->GetInferMode() == ModelState::InferMode::DYNAMICMODEL ||
                             model_state_->GetDynamicBatchSupport();
    if (singleBatch) {
        LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, "generate one batch graph");
        StaticModeConfigOne(parser_options);
        return;
    }

    LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, "generate fixed batch graph");
    StaticModeConfig(parser_options);
}
// Parses the model file at model_state_->GetModelPath() into `graph1` using
// the parser that matches model_state_->GetModelType() (ONNX or TensorFlow).
// The whole parse runs while holding `mu`.
// Returns true when the parser succeeds. Returns false when the parser
// reports an error or when the model type is neither ONNX nor TensorFlow.
// The original returned true for an unsupported type, leaving `graph1` empty.
bool ModelInstanceState::ParseGraph(std::map<ge::AscendString, ge::AscendString> &parser_options, ge::Graph &graph1,
                                    std::mutex &mu)
{
    std::lock_guard<std::mutex> lock(mu);
    switch (model_state_->GetModelType()) {
        case ModelState::ModelType::ONNX: {
            const ge::graphStatus graph_ret =
                ge::aclgrphParseONNX(model_state_->GetModelPath().c_str(), parser_options, graph1);
            if (graph_ret != RET_OK) {
                LOG_MESSAGE(
                    TRITONSERVER_LOG_ERROR,
                    (std::string("aclgrphParseONNX execute failed, ret is: ") + std::to_string(graph_ret)).c_str());
                return false;
            }
            return true;
        }
        case ModelState::ModelType::TENSORFLOW: {
            const ge::graphStatus graph_ret =
                ge::aclgrphParseTensorFlow(model_state_->GetModelPath().c_str(), parser_options, graph1);
            if (graph_ret != RET_OK) {
                LOG_MESSAGE(
                    TRITONSERVER_LOG_ERROR,
                    (std::string("aclgrphParseTensorFlow execute failed, ret is: ") + std::to_string(graph_ret))
                        .c_str());
                return false;
            }
            return true;
        }
        default:
            // No parser exists for this model type, so there is nothing to add or compile.
            LOG_MESSAGE(TRITONSERVER_LOG_ERROR, "can't find model file in model path!");
            return false;
    }
}
// Adds `graph1` to `session_` under `graph_id` and compiles it.
// The options passed to AddGraph are the "graph" entry of
// model_state_->GetGeOption(), when that entry exists.
// `ret` receives the status of the last GE call made.
// Returns true when both AddGraph and CompileGraph succeed, false otherwise.
bool ModelInstanceState::AddAndCompileGraph(ge::Session *session_, int graph_id, ge::Graph &graph1, ge::Status &ret)
{
    std::map<std::string, std::string> graphSection;
    auto &&geOption = model_state_->GetGeOption();
    const auto sectionIt = geOption.find("graph");
    if (sectionIt != geOption.end()) {
        graphSection = sectionIt->second;
    }

    // GE expects AscendString keys and values.
    std::map<AscendString, AscendString> graphOptions;
    for (const auto &kv : graphSection) {
        graphOptions.emplace(AscendString(kv.first.c_str()), AscendString(kv.second.c_str()));
    }

    ret = session_->AddGraph(graph_id, graph1, graphOptions);
    if (ret != RET_OK) {
        const std::string message = std::string("session_->AddGraph failed, ret is: ") + std::to_string(ret);
        LOG_MESSAGE(TRITONSERVER_LOG_ERROR, message.c_str());
        return false;
    }

    ret = session_->CompileGraph(graph_id);
    if (ret != RET_OK) {
        const std::string message = std::string("session_->CompileGraph failed, ret is: ") + std::to_string(ret);
        LOG_MESSAGE(TRITONSERVER_LOG_ERROR, message.c_str());
        return false;
    }
    return true;
}
// Creates an ACL stream into `stream_` and loads graph `graph_id` from
// `session_` onto it with empty load options.
// `aclRet` receives the aclrtCreateStream status and `ret` the LoadGraph status.
// Returns true on success. On failure returns false; when the stream was
// already created it is destroyed and `stream_` is reset to nullptr so the
// stream is not leaked.
bool ModelInstanceState::CreateAndLoadStream(ge::Session *session_, int graph_id, aclrtStream &stream_, ge::Status &ret,
                                             aclError &aclRet)
{
    aclRet = aclrtCreateStream(&stream_);
    if (aclRet != ACL_SUCCESS) {
        LOG_MESSAGE(TRITONSERVER_LOG_ERROR,
                    (std::string("aclrtCreateStream execute failed, code is:") + std::to_string(aclRet)).c_str());
        return false;
    }
    std::map<AscendString, AscendString> load_graph_options = {};
    ret = session_->LoadGraph(graph_id, load_graph_options, stream_);
    if (ret != RET_OK) {
        LOG_MESSAGE(TRITONSERVER_LOG_ERROR,
                    (std::string("session_->LoadGraph failed, ret is: ") + std::to_string(ret)).c_str());
        // The stream is of no use without a loaded graph; release it here.
        (void)aclrtDestroyStream(stream_);
        stream_ = nullptr;
        return false;
    }
    return true;
}
// Returns the number of exec blocks per device.
// Default: ceil(GetTritonThreadCount() * GetInstanceExecBlock() / device_count).
// When model_state_->GetDeviceExecBlock() is not -1 that value is returned instead.
// `device_count` is used as a divisor, so it must not be zero.
int ModelInstanceState::CalculateDeviceExecBlock(int device_count)
{
    const auto totalBlocks = model_state_->GetTritonThreadCount() * model_state_->GetInstanceExecBlock();
    // Adding (device_count - 1) before dividing rounds the quotient up.
    const int perDevice = (totalBlocks + device_count - 1) / device_count;
    const int configured = model_state_->GetDeviceExecBlock();
    return (configured != -1) ? configured : perDevice;
}
// Copies the "session" entry of model_state_->GetGeOption() (when present)
// into `options`, then sets "ge.graphRunMode" to "0" unless that entry
// already provided a value for it.
void ModelInstanceState::ConfigureGeOptions(std::map<ge::AscendString, ge::AscendString> &options)
{
    std::map<std::string, std::string> sessionSection;
    auto &&geOption = model_state_->GetGeOption();
    const auto sectionIt = geOption.find("session");
    if (sectionIt != geOption.end()) {
        sessionSection = sectionIt->second;
    }

    for (auto &kv : sessionSection) {
        LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, ("Pair: " + kv.first + " " + kv.second).c_str());
        options[ge::AscendString(kv.first.c_str())] = ge::AscendString(kv.second.c_str());
    }

    const bool runModeGiven = sessionSection.count("ge.graphRunMode") != 0;
    if (!runModeGiven) {
        options[ge::AscendString("ge.graphRunMode")] = ge::AscendString("0");
    }
}
void ModelInstanceState::CreateDeviceThreads(int dev_id, int dev_index, int device_exec_block,
std::vector<std::thread> &threads, std::mutex &mu)
{
std::map<ge::AscendString, ge::AscendString> options;
std::string str = std::to_string(dev_id);
options[ge::AscendString("ge.session_device_id")] = ge::AscendString(str.c_str());
options[ge::AscendString("ge.constLifecycle")] = ge::AscendString("session");
ConfigureGeOptions(options);
ConfigureDumpOptions(options);
aclrtContext context_ = nullptr;
if (!SetDeviceAndContext(dev_id, context_)) {
return;
}
std::vector<ge::Session *> sessions = {};
CreateSessions(dev_index, device_exec_block, options, sessions);
CreateThreads(dev_id, dev_index, device_exec_block, mu, context_, sessions, threads);
}
// When model_state_->GetDumpData() is true, adds the three GE dump options
// (enableDump, dumpPath, dumpMode) to `options`; otherwise does nothing.
void ModelInstanceState::ConfigureDumpOptions(std::map<ge::AscendString, ge::AscendString> &options)
{
    if (!model_state_->GetDumpData()) {
        return;
    }
    options[ge::AscendString("ge.exec.enableDump")] = ge::AscendString("1");
    options[ge::AscendString("ge.exec.dumpPath")] = ge::AscendString("./dump_data");
    options[ge::AscendString("ge.exec.dumpMode")] = ge::AscendString("all");
}
bool ModelInstanceState::SetDeviceAndContext(int dev_id, aclrtContext &context_)
{
aclError retInit = aclrtSetDevice(dev_id);
if (retInit != ACL_SUCCESS) {
LOG_MESSAGE(TRITONSERVER_LOG_ERROR,
(std::string("aclrtSetDevice execute failed, code is: ") + std::to_string(retInit)).c_str());
return false;
}
aclrtGetCurrentContext(&context_);
return true;
}
// Appends `device_exec_block` session pointers to `sessions`.
// Outside GE static mode every entry is a newly created Session. In GE static
// mode only the first entry is newly created and the remaining entries are
// the same pointer as sessions[0].
// Stops early, after logging, if a session pointer turns out to be null.
// `dev_index` is not used.
void ModelInstanceState::CreateSessions(int dev_index, int device_exec_block,
                                        const std::map<ge::AscendString, ge::AscendString> &options,
                                        std::vector<ge::Session *> &sessions)
{
    for (int idx = 0; idx < device_exec_block; ++idx) {
        const bool createNew = !model_state_->GetGeStaticMode() || idx == 0;
        ge::Session *current = createNew ? new Session(options) : sessions[0];
        if (current == nullptr) {
            LOG_MESSAGE(TRITONSERVER_LOG_ERROR, (std::string("Create session failed.")).c_str());
            return;
        }
        sessions.push_back(current);
    }
}
// Starts one worker thread per exec block; each worker runs
// InitializeGraphSession for its own graph id and session.
// Graph ids are j + device_exec_block * dev_index + GetInitStatus() * 1000.
// Each worker captures only its own session pointer (the original copied the
// whole `sessions` vector into every thread). When `sessions` holds fewer
// entries than `device_exec_block` (CreateSessions can stop early), the
// shortfall is logged, init_failed_ is set and no further threads are started,
// instead of indexing past the end of the vector.
// `mu` is captured by reference, so it must outlive the started threads.
void ModelInstanceState::CreateThreads(int dev_id, int dev_index, int device_exec_block, std::mutex &mu,
                                       aclrtContext context_, const std::vector<ge::Session *> &sessions,
                                       std::vector<std::thread> &threads)
{
    for (int j = 0; j < device_exec_block; j++) {
        if (static_cast<size_t>(j) >= sessions.size()) {
            LOG_MESSAGE(TRITONSERVER_LOG_ERROR,
                        (std::string("missing session for exec block ") + std::to_string(j) +
                         std::string(" on device ") + std::to_string(dev_id))
                            .c_str());
            init_failed_.store(true);
            break;
        }
        const int graph_id = j + device_exec_block * dev_index + model_state_->GetInitStatus() * 1000;
        ge::Session *session = sessions[j];
        threads.emplace_back([this, &mu, graph_id, context_, dev_id, session]() {
            InitializeGraphSession(graph_id, dev_id, context_, mu, session);
        });
    }
}
// Thread entry point: runs InitGraphSession and records any failure.
// A RET_ERR result sets init_failed_ and records a failure for `graph_id`;
// any exception sets init_failed_ and records the exception. No exception
// leaves this function.
void ModelInstanceState::InitializeGraphSession(int graph_id, int dev_id, aclrtContext context_, std::mutex &mu,
                                                ge::Session *session)
{
    try {
        const int status = InitGraphSession(dev_id, graph_id, context_, mu, session);
        if (status != RET_ERR) {
            return;
        }
        init_failed_.store(true);
        RecordInitFailure(graph_id);
    } catch (...) {
        init_failed_.store(true);
        RecordInitException();
    }
}
// Stores a std::runtime_error naming `graph_id` in init_exception_, unless an
// exception has already been stored. Guarded by exception_mutex_.
void ModelInstanceState::RecordInitFailure(int graph_id)
{
    std::lock_guard<std::mutex> guard(exception_mutex_);
    if (init_exception_) {
        return;  // keep the first recorded failure
    }
    const std::string reason = "InitGraphSession failed for graph_id: " + std::to_string(graph_id);
    init_exception_ = std::make_exception_ptr(std::runtime_error(reason));
}
// Stores the exception currently being handled (std::current_exception) in
// init_exception_, unless one has already been stored. Guarded by
// exception_mutex_. Intended to be called from inside a catch handler.
void ModelInstanceState::RecordInitException()
{
    std::lock_guard<std::mutex> guard(exception_mutex_);
    if (init_exception_) {
        return;  // keep the first recorded failure
    }
    init_exception_ = std::current_exception();
}
// Joins every joinable thread in `threads`; non-joinable entries are skipped.
void ModelInstanceState::JoinAllThreads(std::vector<std::thread> &threads)
{
    for (std::thread &worker : threads) {
        if (!worker.joinable()) {
            continue;
        }
        worker.join();
    }
}
// Assigns this instance its thread id from next_id and, when the model's init
// status is still -1, runs InitializeGlobalResources and returns its result.
// Returns RET_OK without further work when the init status is anything else.
int ModelInstanceState::Init()
{
    thread_id_ = next_id.fetch_add(1);

    const std::string nameLine = std::string("ModelName: ") + model_state_->GetModelName();
    LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, nameLine.c_str());
    const std::string threadLine = std::string("Thread initializing... : thread id ") + std::to_string(thread_id_);
    LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, threadLine.c_str());

    if (model_state_->GetInitStatus() != -1) {
        return RET_OK;
    }
    return InitializeGlobalResources();
}
// Runs the one-time initialization for this model, serialized across callers
// through the lock_id / notify_id counters:
//   - the caller that draws ticket 0 initializes the GE environment,
//   - every other caller polls once per second until notify_id equals its ticket,
//   - then devices are resolved and the per-device init threads are run.
// Returns RET_OK on success and RET_ERR on any failure.
// On failure notify_id is advanced to the next ticket. Originally notify_id
// only moved on success, so one failed initialization left every later caller
// polling forever.
int ModelInstanceState::InitializeGlobalResources()
{
    const int lock = lock_id.fetch_add(1);
    model_state_->SetInitStatus(lock);

    bool envReady = true;
    if (lock == 0) {
        LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, (std::string("ModelName: ") + model_state_->GetModelName() +
                                               std::string(" start init env , lock_id : ") + std::to_string(lock))
                                                  .c_str());
        envReady = (InitGEEnvironment() == RET_OK);
    } else {
        while (notify_id.load() != lock) {
            LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, (std::string("ModelName: ") + model_state_->GetModelName() +
                                                   std::string(" wait , lock_id : ") + std::to_string(lock))
                                                      .c_str());
            sleep(1);
        }
    }

    int result = RET_ERR;
    if (envReady) {
        std::vector<int> dev_ids_ = model_state_->GetDeviceIds();
        if (InitDevices(dev_ids_) == RET_OK) {
            // On success InitializeDeviceThreads advances notify_id itself.
            result = InitializeDeviceThreads(dev_ids_, lock);
        }
    }

    if (result != RET_OK) {
        // Release the next waiter so that it does not spin forever.
        notify_id.store(lock + 1);
    }
    return result;
}
// Starts the initialization threads for every device in `dev_ids_`, waits
// for all of them, and then either starts inference and publishes
// notify_id = lock + 1 (returning RET_OK), or reports the recorded failure
// through HandleInitFailure (returning RET_ERR).
int ModelInstanceState::InitializeDeviceThreads(const std::vector<int> &dev_ids_, int lock)
{
    const int blocksPerDevice = CalculateDeviceExecBlock(dev_ids_.size());
    LOG_MESSAGE(TRITONSERVER_LOG_INFO,
                (std::string("final device exec block count: ") + std::to_string(blocksPerDevice)).c_str());

    std::vector<std::thread> workers;
    std::mutex graphMutex;  // shared by all workers; must outlive them
    for (size_t idx = 0; idx < dev_ids_.size(); ++idx) {
        CreateDeviceThreads(dev_ids_[idx], idx, blocksPerDevice, workers, graphMutex);
    }
    JoinAllThreads(workers);

    if (!init_failed_.load()) {
        inference_->Start();
        notify_id.store(lock + 1);
        return RET_OK;
    }
    HandleInitFailure();
    return RET_ERR;
}
void ModelInstanceState::HandleInitFailure()
{
if (init_exception_) {
try {
std::rethrow_exception(init_exception_);
} catch (const std::exception &e) {
std::cerr << "Init failed with exception: " << e.what() << std::endl;
}
}
}
// Allocates a ModelInstanceState, runs Init() on it and hands it back through
// `state`.
// Returns nullptr on success. When Init() fails, the new object is deleted,
// `*state` is set to nullptr and a TRITONSERVER_ERROR_UNKNOWN error is
// returned; the original left the failed object allocated.
TRITONSERVER_Error *ModelInstanceState::Create(ModelState *model_state,
                                               TRITONBACKEND_ModelInstance *triton_model_instance,
                                               ModelInstanceState **state)
{
    *state = new ModelInstanceState(model_state, triton_model_instance);
    if ((*state)->Init() != RET_OK) {
        delete *state;
        *state = nullptr;
        return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_UNKNOWN,
                                     "init instance error, please check log for more info.");
    }
    return nullptr;
}
ModelInstanceState::ModelInstanceState(ModelState *model_state, TRITONBACKEND_ModelInstance *triton_model_instance)
: BackendModelInstance(model_state, triton_model_instance),
model_state_(model_state)
{
inference_ = new Inference(model_state);
}
// Releases the Inference object allocated by the constructor.
// The release message is logged at INFO level: destruction is a normal
// lifecycle event, and logging it as an error made it look like a fault.
ModelInstanceState::~ModelInstanceState()
{
    LOG_MESSAGE(TRITONSERVER_LOG_INFO, (std::string("ModelInstanceState Release Start ")).c_str());
    // delete on a null pointer is a no-op, so no explicit check is needed.
    delete inference_;
    inference_ = nullptr;
}
// Logs the number of requests and passes them to inference_->HandleRequest,
// whose result is returned unchanged.
int ModelInstanceState::ProcessRequests(TRITONBACKEND_Request **requests, const uint32_t request_count)
{
    const std::string countLine = std::string("request_count ") + std::to_string(request_count);
    LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, countLine.c_str());
    return inference_->HandleRequest(requests, request_count);
}
}