* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <memory>
#include "framework/common/ge_types.h"
#include "graph/ge_context.h"
#include "hybrid/model/hybrid_model.h"
#include "hybrid/executor/hybrid_model_async_executor.h"
#include "hybrid/node_executor/node_executor.h"
#include "graph/manager/graph_manager_utils.h"
#include "hybrid/hybrid_davinci_model.h"
namespace ge {
namespace hybrid {
class HybridDavinciModel::Impl {
public:
explicit Impl(GeRootModelPtr ge_model) : model_(std::move(ge_model)), executor_(&model_), load_stream_(nullptr) {}
~Impl() {
NodeExecutorManager::GetInstance().FinalizeExecutors();
}
Status ResolveStreamPolicy() {
constexpr const char_t *kParallelModeMultiStreams = "0";
constexpr const char_t *kParallelModeSerial = "1";
constexpr const char_t *kParallelModeSingleStream = "2";
const std::set<std::string>
kValidValues = {"", kParallelModeMultiStreams, kParallelModeSerial, kParallelModeSingleStream};
std::string parallel_mode;
(void) GetContext().GetOption(OPTION_EXEC_DYNAMIC_GRAPH_PARALLEL_MODE, parallel_mode);
GE_CHK_BOOL_RET_STATUS(kValidValues.count(parallel_mode) > 0, PARAM_INVALID,
"Option %s is invalid, value = [%s]",
OPTION_EXEC_DYNAMIC_GRAPH_PARALLEL_MODE, parallel_mode.c_str());
use_default_stream_ = (parallel_mode == kParallelModeSingleStream);
GELOGI("Option %s = [%s]", OPTION_EXEC_DYNAMIC_GRAPH_PARALLEL_MODE, parallel_mode.c_str());
return ge::SUCCESS;
}
Status Init() {
GE_CHK_STATUS_RET(ResolveStreamPolicy(), "[Initialize][ResolveStreamPolicy] failed");
GE_CHK_STATUS_RET(NodeExecutorManager::GetInstance().EnsureInitialized(),
"[Initialize][NodeExecutorManager] failed");
GE_CHK_STATUS_RET(model_.Init(), "[Init][HybridModel] failed.");
GE_CHK_STATUS_RET(executor_.Init(load_stream_), "[Init][HybridModelAsyncExecutor] failed.");
return SUCCESS;
}
Status Execute(const std::vector<DataBuffer> &inputs,
const std::vector<GeTensorDesc> &input_desc,
std::vector<DataBuffer> &outputs,
std::vector<GeTensorDesc> &output_desc,
const aclrtStream stream) {
const auto main_stream = use_default_stream_ ? nullptr : stream;
return executor_.Execute(inputs, input_desc, outputs, output_desc, main_stream);
}
Status Execute(const std::vector<gert::Tensor> &inputs, std::vector<gert::Tensor> &outputs) {
return executor_.Execute(inputs, outputs);
}
Status ExecuteWithStreamAsync(const std::vector<gert::Tensor> &inputs,
std::vector<gert::Tensor> &outputs,
const aclrtStream stream) {
return executor_.ExecuteWithStreamAsync(inputs, outputs, stream);
}
Status ExecuteWithStreamAsync(const std::vector<GeTensor> &inputs, std::vector<GeTensor> &outputs,
const aclrtStream stream) {
return executor_.ExecuteWithStreamAsync(inputs, outputs, stream);
}
Status ModelRunStart() {
return executor_.Start(listener_);
}
Status ModelRunStop() {
return executor_.Stop();
}
Status EnqueueData(const std::shared_ptr<RunArgs> &args) {
return executor_.EnqueueData(args);
}
void SetListener(const shared_ptr<ModelListener> &listener) {
listener_ = listener;
}
void SetModelId(const uint32_t model_id) {
executor_.SetModelId(model_id);
model_.SetModelId(model_id);
}
void SetDeviceId(const uint32_t device_id) {
model_.SetDeviceId(device_id);
executor_.SetDeviceId(device_id);
}
void SetOmName(const std::string &model_name) {
model_.SetOmName(model_name);
}
void SetLoadStream(const aclrtStream stream) {
load_stream_ = stream;
}
void SetFileConstantWeightDir(const std::string &file_constant_weight_dir) {
model_.SetFileConstantWeightDir(file_constant_weight_dir);
}
uint32_t GetDeviceId() const {
return model_.GetDeviceId();
}
uint64_t GetGlobalStepAddr() const {
return PtrToValue(model_.GetGlobalStep());
}
const GraphExecutionContext *GeContext() const { return executor_.GeContext(); }
GraphExecutionContext *GeContext() { return executor_.GeContext(); }
uint64_t GetSessionId() const {
return model_.GetSessionId();
}
Status GetDynamicBatchInfo(std::vector<std::vector<int64_t>> &batch_info, int32_t &dynamic_type) const {
return model_.GetDynamicBatchInfo(batch_info, dynamic_type);
}
void GetUserDesignateShapeOrder(std::vector<std::string> &user_input_shape_order) const {
model_.GetUserDesignateShapeOrder(user_input_shape_order);
}
void GetModelAttr(std::vector<std::string> &dynamic_output_shape_info) const {
model_.GetModelAttr(dynamic_output_shape_info);
}
Status GetInputOutputDescInfo(std::vector<InputOutputDescInfo> &input_desc,
std::vector<InputOutputDescInfo> &output_desc,
std::vector<uint32_t> &input_formats,
std::vector<uint32_t> &output_formats) {
return model_.GetInputOutputDescInfo(input_desc, output_desc, input_formats, output_formats);
}
void SetModelDescVersion(const bool is_new_model_desc) {
model_.SetModelDescVersion(is_new_model_desc);
}
uint32_t GetDataInputerSize() const { return executor_.GetDataInputerSize(); }
bool GetRunningFlag() const { return executor_.GetRunningFlag(); }
Status SetRunAsyncListenerCallback(const RunAsyncCallbackV2 &callback) const {
const auto listener = dynamic_cast<RunAsyncListener *>(listener_.get());
GE_CHECK_NOTNULL(listener);
listener->SetCallback(callback);
return SUCCESS;
}
Status GetOpAttr(const std::string &op_name, const std::string &attr_name, std::string &attr_value) const {
return model_.GetOpAttr(op_name, attr_name, attr_value);
}
Status GetAippInfo(const uint32_t index, AippConfigInfo &aipp_info) const {
return model_.GetAippInfo(index, aipp_info);
}
Status GetAippType(const uint32_t index, InputAippType &aipp_type, size_t &aipp_index) const {
return model_.GetAippType(index, aipp_type, aipp_index);
}
Status ReportProfilingData() const {
return model_.ReportProfilingData();
}
private:
std::shared_ptr<ModelListener> listener_;
HybridModel model_;
HybridModelAsyncExecutor executor_;
aclrtStream load_stream_ = nullptr;
bool use_default_stream_ = false;
};
HybridDavinciModel::~HybridDavinciModel() {
delete impl_;
}
std::unique_ptr<HybridDavinciModel> HybridDavinciModel::Create(const GeRootModelPtr &ge_root_model) {
auto instance = std::unique_ptr<HybridDavinciModel>(new (std::nothrow)HybridDavinciModel());
if (instance != nullptr) {
instance->impl_ = new (std::nothrow) HybridDavinciModel::Impl(ge_root_model);
if (instance->impl_ != nullptr) {
return instance;
}
}
return nullptr;
}
Status HybridDavinciModel::Init() {
GE_CHECK_NOTNULL(impl_);
return impl_->Init();
}
Status HybridDavinciModel::Execute(const std::vector<DataBuffer> &inputs,
const std::vector<GeTensorDesc> &input_desc,
std::vector<DataBuffer> &outputs,
std::vector<GeTensorDesc> &output_desc,
const aclrtStream stream) {
GE_CHECK_NOTNULL(impl_);
return impl_->Execute(inputs, input_desc, outputs, output_desc, stream);
}
Status HybridDavinciModel::Execute(const std::vector<gert::Tensor> &inputs, std::vector<gert::Tensor> &outputs) {
GE_CHECK_NOTNULL(impl_);
return impl_->Execute(inputs, outputs);
}
Status HybridDavinciModel::ExecuteWithStreamAsync(const std::vector<GeTensor> &inputs, std::vector<GeTensor> &outputs,
const aclrtStream stream) {
GE_CHECK_NOTNULL(impl_);
return impl_->ExecuteWithStreamAsync(inputs, outputs, stream);
}
Status HybridDavinciModel::ExecuteWithStreamAsync(const std::vector<gert::Tensor> &inputs,
std::vector<gert::Tensor> &outputs,
const aclrtStream stream) {
GE_CHECK_NOTNULL(impl_);
return impl_->ExecuteWithStreamAsync(inputs, outputs, stream);
}
Status HybridDavinciModel::ModelRunStart() {
GE_CHECK_NOTNULL(impl_);
return impl_->ModelRunStart();
}
Status HybridDavinciModel::ModelRunStop() {
GE_CHECK_NOTNULL(impl_);
return impl_->ModelRunStop();
}
Status HybridDavinciModel::EnqueueData(const shared_ptr<RunArgs> &args) {
GE_CHECK_NOTNULL(impl_);
return impl_->EnqueueData(args);
}
void HybridDavinciModel::SetListener(const shared_ptr<ModelListener> &listener) {
if (impl_ != nullptr) {
impl_->SetListener(listener);
}
}
void HybridDavinciModel::SetModelId(const uint32_t model_id) {
if (impl_ != nullptr) {
impl_->SetModelId(model_id);
}
}
void HybridDavinciModel::SetDeviceId(const uint32_t device_id) {
if (impl_ != nullptr) {
impl_->SetDeviceId(device_id);
}
}
void HybridDavinciModel::SetOmName(const std::string &om_name) {
if (impl_ != nullptr) {
impl_->SetOmName(om_name);
}
}
void HybridDavinciModel::SetFileConstantWeightDir(const std::string &file_constant_weight_dir) {
if (impl_ != nullptr) {
impl_->SetFileConstantWeightDir(file_constant_weight_dir);
}
}
void HybridDavinciModel::SetLoadStream(const aclrtStream stream) {
if (impl_ != nullptr) {
impl_->SetLoadStream(stream);
}
}
uint32_t HybridDavinciModel::GetDeviceId() const {
GE_CHECK_NOTNULL(impl_);
return impl_->GetDeviceId();
}
uint64_t HybridDavinciModel::GetGlobalStepAddr() const {
if (impl_ == nullptr) {
REPORT_INNER_ERR_MSG("E19999", "Param: impl_ is nullptr, check invalid");
GELOGE(ge::FAILED, "[Check][Param: impl_]null is invalid");
return 0UL;
}
return impl_->GetGlobalStepAddr();
}
Status HybridDavinciModel::GetDynamicBatchInfo(std::vector<std::vector<int64_t>> &batch_info,
int32_t &dynamic_type) const {
GE_CHECK_NOTNULL(impl_);
return impl_->GetDynamicBatchInfo(batch_info, dynamic_type);
}
void HybridDavinciModel::GetUserDesignateShapeOrder(std::vector<std::string> &user_input_shape_order) const {
if (impl_ != nullptr) {
impl_->GetUserDesignateShapeOrder(user_input_shape_order);
}
}
void HybridDavinciModel::GetOutputShapeInfo(std::vector<std::string> &dynamic_output_shape_info) const {
if (impl_ != nullptr) {
impl_->GetModelAttr(dynamic_output_shape_info);
}
}
Status HybridDavinciModel::GetInputOutputDescInfo(std::vector<InputOutputDescInfo> &input_desc,
std::vector<InputOutputDescInfo> &output_desc,
std::vector<uint32_t> &input_formats,
std::vector<uint32_t> &output_formats) {
GE_CHECK_NOTNULL(impl_);
return impl_->GetInputOutputDescInfo(input_desc, output_desc, input_formats, output_formats);
}
void HybridDavinciModel::SetModelDescVersion(const bool is_new_model_desc) {
if (impl_ != nullptr) {
impl_->SetModelDescVersion(is_new_model_desc);
}
}
uint64_t HybridDavinciModel::GetSessionId() {
GE_CHECK_NOTNULL(impl_);
return impl_->GetSessionId();
}
uint32_t HybridDavinciModel::GetDataInputerSize() const {
GE_CHECK_NOTNULL(impl_);
return impl_->GetDataInputerSize();
}
bool HybridDavinciModel::GetRunningFlag() const {
if (impl_ == nullptr) {
return false;
}
return impl_->GetRunningFlag();
}
Status HybridDavinciModel::SetRunAsyncListenerCallback(
const RunAsyncCallbackV2 &callback) {
GE_CHECK_NOTNULL(impl_);
return impl_->SetRunAsyncListenerCallback(callback);
}
bool HybridDavinciModel::GetOpDescInfo(const uint32_t stream_id, const uint32_t task_id,
OpDescInfo &op_desc_info) const {
if (impl_ == nullptr) {
return false;
}
auto context = impl_->GeContext();
GE_RT_FALSE_CHECK_NOTNULL(context);
const bool ret =
context->exception_dumper.GetOpDescInfo(OpDescInfoId(stream_id, task_id, GetDeviceId()), op_desc_info);
if (!ret) {
for (const auto &iter : context->davinci_model) {
if (iter->GetOpDescInfo(stream_id, task_id, op_desc_info)) {
return true;
}
}
}
return ret;
}
Status HybridDavinciModel::GetOpAttr(const std::string &op_name, const std::string &attr_name,
std::string &attr_value) const {
GE_CHECK_NOTNULL(impl_);
return impl_->GetOpAttr(op_name, attr_name, attr_value);
}
Status HybridDavinciModel::GetAippInfo(const uint32_t index, AippConfigInfo &aipp_info) const {
GE_CHECK_NOTNULL(impl_);
return impl_->GetAippInfo(index, aipp_info);
}
Status HybridDavinciModel::GetAippType(const uint32_t index, InputAippType &aipp_type, size_t &aipp_data_index) const {
GE_CHECK_NOTNULL(impl_);
return impl_->GetAippType(index, aipp_type, aipp_data_index);
}
Status HybridDavinciModel::ReportProfilingData() const {
GE_CHECK_NOTNULL(impl_);
return impl_->ReportProfilingData();
}
}
}