* Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
*/
#ifndef MODEL_INSTANCE_STATE_H
#define MODEL_INSTANCE_STATE_H
#include <iostream>
#include <string>
#include <cstring>
#include <vector>
#include <memory>
#include <map>
#include <thread>
#include <mutex>
#include <atomic>
#include <regex>
#include <unistd.h>
#include "triton/backend/backend_common.h"
#include "triton/backend/backend_model_instance.h"
#include "triton/backend/backend_model.h"
#include "triton/core/tritonbackend.h"
#include "nlohmann/json.hpp"
#include "graph.h"
#include "types.h"
#include "tensor.h"
#include "ge_error_codes.h"
#include "ge_api_types.h"
#include "ge_api.h"
#include "acl/acl.h"
#include "onnx_parser.h"
#include "tensorflow_parser.h"
#include "model_state.h"
#include "scheduler.h"
#include "inference.h"
// NOTE(review): these directives sit at global scope inside a header (this file
// is wrapped in the MODEL_INSTANCE_STATE_H include guard), so every file that
// includes it also gets all of `ge` and `std` pulled into its global namespace,
// plus the `json` alias. The declarations visible in this header already spell
// out `std::` and `ge::` and do not appear to need the directives; other files
// may rely on them, so check the includers before removing anything.
using namespace ge;
using json = nlohmann::json;
using namespace std;
namespace triton::backend::npu_ge {
// Counters shared by every translation unit that includes this header.
//
// They are `inline` variables (C++17) rather than `static`: a `static`
// namespace-scope variable defined in a header has internal linkage, so each
// including .cpp file would get its own private copy and the counters would
// silently diverge between translation units. `inline` gives exactly one
// instance per program while still allowing the definition to live here.
//
// All three start at 0. NOTE(review): what each counter is used for is not
// visible in this header; the names suggest id allocation, lock ordering and
// notification ordering respectively - confirm against the code that updates them.
inline std::atomic<int> next_id{0};
inline std::atomic<int> lock_id{0};
inline std::atomic<int> notify_id{0};
// Per-instance state of the npu_ge backend, derived from BackendModelInstance
// (presumably the Triton class from triton/backend/backend_model_instance.h).
//
// This header only declares the interface; every member function is defined
// elsewhere. The comments below therefore describe what the signatures show and
// mark anything inferred from a name as an assumption to be confirmed.
//
// The class stores a std::mutex and a std::atomic<bool> by value, so it is
// neither copyable nor movable.
class ModelInstanceState : public BackendModelInstance {
public:
// Factory. Takes the owning ModelState and the Triton instance handle and hands
// the result back through `state`; the outcome is reported as a
// TRITONSERVER_Error*. NOTE(review): presumably nullptr means success, as is
// usual for Triton APIs - confirm in the definition.
static TRITONSERVER_Error *Create(ModelState *model_state, TRITONBACKEND_ModelInstance *triton_model_instance,
ModelInstanceState **state);
// Constructs the state from the owning ModelState and the Triton instance handle.
ModelInstanceState(ModelState *model_state, TRITONBACKEND_ModelInstance *triton_model_instance);
virtual ~ModelInstanceState();
// Handles an array of `request_count` Triton requests. The meaning of the int
// result is not visible here.
int ProcessRequests(TRITONBACKEND_Request **requests, const uint32_t request_count);
// Initialisation entry point; the meaning of the int result is not visible here.
int Init();
// Static helper taking a name and returning a string. NOTE(review): presumably
// looks up the environment variable `name`; behaviour for an unset variable is
// not visible here.
static std::string GetEnvVar(const std::string &name);
// Static helper that receives a string and a map by non-const reference, so it
// may modify either. NOTE(review): presumably parses `env_value` into
// `configMap` - confirm the expected format in the definition.
static void LoadGeConfig(std::string &env_value, std::map<ge::AscendString, ge::AscendString> &configMap);

// --- Graph set-up steps ------------------------------------------------------
// NOTE(review): the names suggest these are the individual stages of bringing
// up one GE graph (parse -> add/compile -> stream creation); the actual call
// order is not visible in this header.

// Receives the parser option map by non-const reference (may modify it).
void ConfigureParserOptions(std::map<ge::AscendString, ge::AscendString> &parser_options);
// Takes parser options, an output graph and a caller-supplied mutex. The bool
// result presumably signals success - confirm.
bool ParseGraph(std::map<ge::AscendString, ge::AscendString> &parser_options, ge::Graph &graph1, std::mutex &mu);
// Operates on `graph1` under id `graph_id` in the given session; `ret` is passed
// by reference, so the GE status can be reported back to the caller.
bool AddAndCompileGraph(ge::Session *session_, int graph_id, ge::Graph &graph1, ge::Status &ret);
// `stream_`, `ret` and `aclRet` are all passed by reference, so the stream and
// both status codes can be reported back to the caller.
bool CreateAndLoadStream(ge::Session *session_, int graph_id, aclrtStream &stream_, ge::Status &ret,
aclError &aclRet);
// Receives the option map by non-const reference (may modify it).
void ConfigureDumpOptions(std::map<ge::AscendString, ge::AscendString> &options);
// `context_` is passed by reference, so it can be set for the caller. The bool
// result presumably signals success - confirm.
bool SetDeviceAndContext(int dev_id, aclrtContext &context_);

// --- Session / thread fan-out --------------------------------------------------
// `options` is read-only here; `sessions` is passed by non-const reference.
// NOTE(review): who owns and releases the raw ge::Session pointers is not
// visible in this header.
void CreateSessions(int dev_index, int device_exec_block,
const std::map<ge::AscendString, ge::AscendString> &options,
std::vector<ge::Session *> &sessions);
// `sessions` is read-only here; `threads` is passed by non-const reference and
// `mu` is a caller-supplied mutex. The aclrtContext is taken by value.
void CreateThreads(int dev_id, int dev_index, int device_exec_block, std::mutex &mu, aclrtContext context_,
const std::vector<ge::Session *> &sessions, std::vector<std::thread> &threads);

// --- Failure bookkeeping -------------------------------------------------------
// NOTE(review): presumably these feed init_failed_ / init_exception_ declared
// below (the latter guarded by exception_mutex_) - confirm in the definitions.
void RecordInitFailure(int graph_id);
void RecordInitException();
int InitializeGlobalResources();
// Takes the device id list read-only plus an int `lock`, whose meaning is not
// visible here.
int InitializeDeviceThreads(const std::vector<int> &dev_ids_, int lock);
void HandleInitFailure();
// Same parameter set as the private InitGraphSession below, in a different
// order and with a void result.
void InitializeGraphSession(int graph_id, int dev_id, aclrtContext context_, std::mutex &mu, ge::Session *session);
// Receives the option map by non-const reference (may modify it).
void ConfigureGeOptions(std::map<ge::AscendString, ge::AscendString> &options);
private:
int InitGEEnvironment();
// `dev_ids_` is passed by non-const reference, so it can be filled or modified
// for the caller.
int InitDevices(std::vector<int> &dev_ids_);
int InitGraphSession(int dev_id, int graph_id, aclrtContext context, std::mutex &mu, ge::Session *session_);
// Both receive the parser option map by non-const reference; how the two
// variants differ is not visible here.
void StaticModeConfig(std::map<ge::AscendString, ge::AscendString> &parser_options);
void StaticModeConfigOne(std::map<ge::AscendString, ge::AscendString> &parser_options);
// Maps a device count to an int. NOTE(review): presumably the result is the
// `device_exec_block` value passed to the thread/session helpers - confirm.
int CalculateDeviceExecBlock(int device_count);
void CreateDeviceThreads(int dev_id, int dev_index, int device_exec_block, std::vector<std::thread> &threads,
std::mutex &mu);
// Takes the thread list by non-const reference. NOTE(review): presumably joins
// each thread - confirm.
void JoinAllThreads(std::vector<std::thread> &threads);
// Raw pointer to the owning model state. NOTE(review): ownership and lifetime
// are not visible in this header.
ModelState *model_state_;
std::string model_config_path_;
std::string model_path_;
std::string acl_path_;
// Defaults to -1.
int thread_id_ = -1;
// Defaults to false; atomic, so it can be read and written from several threads
// without extra locking.
std::atomic<bool> init_failed_{false};
// Defaults to null. std::exception_ptr is not itself synchronised.
// NOTE(review): presumably guarded by exception_mutex_ - confirm.
std::exception_ptr init_exception_{nullptr};
std::mutex exception_mutex_;
// Raw pointer, null by default. NOTE(review): where it is allocated and
// released is not visible in this header.
Inference *inference_ = nullptr;
};
}
#endif