* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "dflow/compiler/data_flow_graph/inner_pp_loader.h"
#include "dflow/compiler/data_flow_graph/compile_config_json.h"
#include "dflow/base/model/model_relation.h"
#include "common/file_constant_utils/file_constant_utils.h"
#include "common/plugin/ge_make_unique_util.h"
#include "dflow/base/model/flow_model_om_loader.h"
#include "dflow/flow_graph/data_flow_attr_define.h"
#include "dflow/inc/data_flow/model/flow_model_helper.h"
#include "common/checker.h"
namespace ge {
namespace {
const std::string kModelPpFusionInputs = "INVOKED_MODEL_FUSION_INPUTS";
}
Status InnerPpLoader::LoadProcessPoint(const dataflow::ProcessPoint &process_point, DataFlowGraph &data_flow_graph,
const NodePtr &node) {
const auto &pp_extend_attrs = process_point.pp_extend_attrs();
const auto &find_ret = pp_extend_attrs.find(dflow::INNER_PP_CUSTOM_ATTR_INNER_TYPE);
if (find_ret == pp_extend_attrs.cend()) {
GELOGE(FAILED, "custom inner pp type attr[%s] does not exist.", dflow::INNER_PP_CUSTOM_ATTR_INNER_TYPE);
return FAILED;
}
const auto &inner_type = find_ret->second;
if (inner_type == dflow::INNER_PP_TYPE_MODEL_PP) {
return LoadModelPp(process_point, data_flow_graph, node);
} else {
GELOGE(FAILED, "load inner pp failed, inner pp type[%s] is unknown, node=%s.", inner_type.c_str(),
node->GetName().c_str());
return FAILED;
}
return SUCCESS;
}
Status InnerPpLoader::LoadModelPp(const dataflow::ProcessPoint &process_point, DataFlowGraph &data_flow_graph,
const NodePtr &node) {
const auto &pp_extend_attrs = process_point.pp_extend_attrs();
const auto &find_ret = pp_extend_attrs.find(dflow::INNER_PP_CUSTOM_ATTR_MODEL_PP_MODEL_PATH);
if (find_ret == pp_extend_attrs.cend()) {
GELOGE(FAILED, "cannot find extend attr[%s], node=%s, pp name=%s.", dflow::INNER_PP_CUSTOM_ATTR_MODEL_PP_MODEL_PATH,
node->GetName().c_str(), process_point.name().c_str());
return FAILED;
}
const std::string &model_path = find_ret->second;
const std::string &pp_name = process_point.name();
const std::string &compile_file = process_point.compile_cfg_file();
std::function<Status()> load_task = [&data_flow_graph, model_path, pp_name, node, compile_file]() {
FlowModelPtr flow_model = nullptr;
GE_CHK_STATUS_RET(LoadModel(data_flow_graph, model_path, flow_model), "load model failed, model_path=%s.",
model_path.c_str());
flow_model->SetModelName(pp_name);
const auto &root_graph = flow_model->GetRootGraph();
if (flow_model->GetModelRelation() == nullptr) {
auto model_relation = MakeUnique<ModelRelation>();
GE_CHECK_NOTNULL(model_relation);
GE_CHK_STATUS_RET(ModelRelationBuilder().BuildForSingleModel(*root_graph, *model_relation),
"Failed to build model relation for graph[%s].", pp_name.c_str());
flow_model->SetModelRelation(std::shared_ptr<ModelRelation>(model_relation.release()));
}
if (!compile_file.empty()) {
CompileConfigJson::ModelPpConfig model_pp_cfg = {};
GE_CHK_STATUS_RET(CompileConfigJson::ReadModelPpConfigFromJsonFile(compile_file, model_pp_cfg),
"Failed to read process point[%s] compile config.", compile_file.c_str());
if (!model_pp_cfg.invoke_model_fusion_inputs.empty()) {
(void)AttrUtils::SetStr(root_graph, kModelPpFusionInputs, model_pp_cfg.invoke_model_fusion_inputs);
}
GELOGI("Parse compile config file [%s] success. fusion inputs config [%s]", compile_file.c_str(),
model_pp_cfg.invoke_model_fusion_inputs.c_str());
}
GE_CHK_STATUS_RET(data_flow_graph.AddLoadedModel(node->GetName(), pp_name, flow_model),
"Failed to add loaded model, pp_name:%s", pp_name.c_str());
return SUCCESS;
};
GE_CHK_STATUS_RET(data_flow_graph.CommitPreprocessTask(pp_name, load_task),
"Failed to commit load model pp task[%s].", pp_name.c_str());
return SUCCESS;
}
Status InnerPpLoader::LoadModel(const DataFlowGraph &data_flow_graph, const std::string &model_path,
FlowModelPtr &flow_model) {
flow_model = nullptr;
GE_CHK_STATUS_RET(FlowModelHelper::LoadToFlowModel(model_path, flow_model),
"load to flow model failed, model_path=%s", model_path.c_str());
const auto &root_graph = data_flow_graph.GetRootGraph();
uint64_t session_id = root_graph->GetSessionID();
uint32_t graph_id = root_graph->GetGraphID();
std::string file_constant_weight_dir;
GE_ASSERT_SUCCESS(FileConstantUtils::GetExternalWeightDirFromOmPath(model_path, file_constant_weight_dir));
GE_CHK_STATUS_RET(
FlowModelOmLoader::RefreshModel(flow_model, file_constant_weight_dir, session_id, graph_id),
"Failed to assign constant mem for model");
const std::string offline_session_graph_id = "-1_" + std::to_string(graph_id);
GE_CHK_STATUS_RET(FlowModelHelper::UpdateSessionGraphId(flow_model, offline_session_graph_id),
"Failed to update flow model session graph id, session_graph_id:%s",
offline_session_graph_id.c_str());
return SUCCESS;
}
}