* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "graph/build/task_generator.h"
#include <cinttypes>
#include <string>
#include "graph/passes/memory_conflict/atomic_addr_clean_pass.h"
#include "graph/ge_context.h"
#include "graph/manager/graph_var_manager.h"
#include "graph/model_serialize.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/tensor_utils.h"
#include "common/checker.h"
#include "api/gelib/gelib.h"
#include "engines/manager/opskernel_manager/ops_kernel_builder_manager.h"
#include "common/preload/model/pre_model_partition_utils.h"
#include "platform/platform_info.h"
#include "graph/utils/op_type_utils.h"
#include "graph/utils/graph_utils.h"
#include "task_generator_utils.h"
#include "common/preload/model/pre_model_utils.h"
#include "framework/common/tlv/pre_model_desc.h"
#include "common/compile_profiling/ge_trace_wrapper.h"
#include "base/err_mgr.h"
#include "acl/acl_rt.h"
using domi::LogTimeStampDef;
using domi::ModelTaskDef;
using domi::TaskDef;
// File-local constants used by the task generation code in this file.
namespace {
// Bool attribute names written onto an OpDesc by MarkFirstAndLastOps to flag the first / last node of a
// continuous run of nodes that share the same ops kernel lib.
const char *const kIsFirstNode = "is_first_node";
const char *const kIsLastNode = "is_last_node";
// List-bool attribute names written by UpdateOpIsVarAttr: one flag per valid input / output telling whether the
// corresponding offset is a variable address.
const char *const kIsInputVar = "INPUT_IS_VAR";
const char *const kIsOutputVar = "OUTPUT_IS_VAR";
// SoC versions for which GenerateTask additionally runs InitRuntimeParams and InitZeroCopyInfo.
const std::set<std::string> kNanoSocVersion{"Ascend035", "Ascend035A", "Ascend035B"};
// Multiplier applied to the stream id when a fusion group id and a stream id are combined into one group key
// (see SaveFusionNodes).
const int64_t kHashFactor = 100000;
// Initial group id for a node before any L1/L2 fusion group attribute has been read.
const int64_t kInvalidGroupId = -1;
// Stream id value treated as "no stream assigned".
const int64_t kInvalidStream = -1;
// Device id value meaning "do not set/reset a device in the worker thread".
const int32_t kInvalidDeviceId = -1;
// Thread pool size used when MAX_COMPILE_CORE_NUMBER is unset, empty or invalid.
const int64_t kDefaultThreadNum = 1;
// NOTE(review): not referenced in the visible part of this file; presumably an "invalid address type" marker
// used further down - confirm before touching.
const uint64_t invalidAddrType = 0xFFFFFFFFFFFFFFFFULL;
// Ops kernel lib names for which SetKernelInfo attaches the kernel info store pointer to the OpDesc.
const std::set<std::string> kHcclDvppKernelInfoNames = {"ops_kernel_info_hccl", "ops_kernel_info_hccl_gradtune",
"hvd_ops_kernel_info", "dvpp_ops_kernel_info_store"};
// AICPU kernel libs that are probed, in set order, for partially supported nodes.
const std::set<std::string> kAicpuKernelLibs = {"aicpu_ascend_kernel", "aicpu_tf_kernel"};
}
namespace ge {
namespace {
// Node predicate handed to ComputeGraph::GetNodes by FilterCandidatesNodes.
// Yields true only when the node's OpDesc carries neither the FFTS nor the FFTS-plus sub graph attribute.
auto ffts_filter = [](const Node &node, const char *, const ComputeGraphPtr &) {
  const auto op_desc = node.GetOpDesc();
  return (!op_desc->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH)) && (!op_desc->HasAttr(ATTR_NAME_FFTS_PLUS_SUB_GRAPH));
};
// Selects which task generation entry point handles a node (see GetKey and GenerateTaskForNodes).
enum class GenTaskCallKey {
kAtomicEngine = 0,
kFftsEngine,
};
// Tag string passed to the generation entry point for each key; it only appears in log messages.
const unordered_map<GenTaskCallKey, std::string> key_2_string = {{GenTaskCallKey::kAtomicEngine, "Normal engine"},
{GenTaskCallKey::kFftsEngine, "FFTS engine"}};
// Node types treated as graph data inputs when building the zero copy table (see InitZeroCopyInfo).
static const std::set<std::string> kDataOpType{DATA, AIPPDATA, ANN_DATA};
// Maps a node to its dispatch key: nodes carrying ATTR_NAME_FFTS_PLUS_SUB_GRAPH go to the FFTS engine path,
// every other node goes to the normal (atomic) engine path.
GenTaskCallKey GetKey(const Node *const node) {
  const bool is_ffts_plus = node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_PLUS_SUB_GRAPH);
  return is_ffts_plus ? GenTaskCallKey::kFftsEngine : GenTaskCallKey::kAtomicEngine;
}
// Common callable signature of TaskGenerator::GenerateTaskForNormalNode and
// TaskGenerator::GenerateTaskForFftsNode, so that both can be committed to a thread pool uniformly.
// Arguments: generator, node, log tag, per-node output task list, thread local context to install in the
// worker, error manager context to install in the worker, device id (kInvalidDeviceId means "do not bind").
using GenTaskCall = std::function<Status(
TaskGenerator *, Node *, const std::string &, std::vector<domi::TaskDef> &task_def_list_per_node,
const GEThreadLocalContext &, const error_message::ErrorManagerContext &, int32_t)>;
// Returns the number of worker threads used for task generation.
// The value is read from the MAX_COMPILE_CORE_NUMBER environment variable. When the variable is unset or
// empty, kDefaultThreadNum is used silently. When it parses to a value that is not positive or that does not
// fit into the uint32_t return type, a warning is logged and kDefaultThreadNum is used.
uint32_t GetThreadNum() {
  const char_t *value = nullptr;
  MM_SYS_GET_ENV(MM_ENV_MAX_COMPILE_CORE_NUMBER, value);
  int64_t thread_num = kDefaultThreadNum;
  if ((value != nullptr) && (value[0U] != '\0')) {
    thread_num = std::strtol(value, nullptr, 10);
  }
  // Reject values above UINT32_MAX as well: narrowing them on return would silently truncate and could even
  // produce 0 threads (e.g. 4294967296). `value` is always non-null here because the default never fails the check.
  if ((thread_num <= 0) || (thread_num > static_cast<int64_t>(UINT32_MAX))) {
    GELOGW("Get invalid MAX_COMPILE_CORE_NUMBER env value %s, use default thread number %ld", value, kDefaultThreadNum);
    thread_num = kDefaultThreadNum;
  }
  GELOGI("Thread num is %ld", thread_num);
  return static_cast<uint32_t>(thread_num);
}
// Overwrites the stream id of every task in task_defs[start_index, end) with the stream id of op_desc
// (0 when the op has kInvalidStream). Nodes with a valid attached stream id and "SuperKernel" nodes are
// left untouched.
void RefreshStreamId(const OpDescPtr &op_desc, const size_t start_index, std::vector<domi::TaskDef> &task_defs) {
  const bool keep_origin_stream = op_desc->HasValidAttachedStreamId() || (op_desc->GetType() == "SuperKernel");
  if (keep_origin_stream) {
    GELOGI("Node {%s %s} keep origin stream info, task size %zu.", op_desc->GetNamePtr(), op_desc->GetTypePtr(),
           task_defs.size());
    return;
  }
  const auto op_stream_id = op_desc->GetStreamId();
  const auto effective_stream_id = (op_stream_id == kInvalidStream) ? 0 : op_stream_id;
  for (size_t pos = start_index; pos < task_defs.size(); ++pos) {
    task_defs[pos].set_stream_id(static_cast<uint32_t>(effective_stream_id));
  }
}
// Tells whether fusion task generation is requested: true when the BUFFER_OPTIMIZE option is
// "l1_optimize" or "l2_optimize". A missing option keeps the "off_optimize" default.
bool NeedDoFusionTask() {
  std::string buffer_optimize("off_optimize");
  (void)ge::GetContext().GetOption(BUFFER_OPTIMIZE, buffer_optimize);
  GELOGI("Get buffer option %s with value %s", BUFFER_OPTIMIZE.c_str(), buffer_optimize.c_str());
  if (buffer_optimize == "l1_optimize") {
    return true;
  }
  return buffer_optimize == "l2_optimize";
}
}
// Stores the run context pointer and the variable memory base/size for later use; nothing is copied or validated.
TaskGenerator::TaskGenerator(uint8_t *var_mem_base, uint64_t var_mem_size, RunContext *run_context) {
  // Assigned in the body (not an initializer list) so the result does not depend on member declaration order.
  run_context_ = run_context;
  var_mem_size_ = var_mem_size;
  var_mem_base_ = var_mem_base;
}
// Nothing to release explicitly; members clean up after themselves.
TaskGenerator::~TaskGenerator() = default;
// Stores task generation results on the model as attributes.
// Written attributes: data / weight / variable base addresses, variable memory size, session id, the list of
// op names per task index (from op_names_), and the deterministically serialized model_task_def (MODEL_ATTR_TASKS).
// Returns SUCCESS when every attribute is set, FAILED (or the assert macro's failure status) otherwise.
Status TaskGenerator::AddModelTaskToModel(const ModelTaskDef &model_task_def, uint64_t session_id, ge::Model &model,
                                          RunContext &run_context) const {
  GE_CHK_BOOL_EXEC(
      AttrUtils::SetInt(model, MODEL_ATTR_TASK_GEN_BASE_ADDR, reinterpret_cast<uintptr_t>(run_context.dataMemBase)),
      REPORT_INNER_ERR_MSG("E19999", "Set Attr:%s fail for model:%s", MODEL_ATTR_TASK_GEN_BASE_ADDR.c_str(),
                           model.GetName().c_str());
      GELOGE(FAILED, "[Set][Attr] %s fail for model:%s", MODEL_ATTR_TASK_GEN_BASE_ADDR.c_str(),
             model.GetName().c_str());
      return FAILED);
  GE_CHK_BOOL_EXEC(
      AttrUtils::SetInt(model, MODEL_ATTR_TASK_GEN_WEIGHT_ADDR, reinterpret_cast<uintptr_t>(run_context.weightMemBase)),
      REPORT_INNER_ERR_MSG("E19999", "Set Attr:%s fail for model:%s", MODEL_ATTR_TASK_GEN_WEIGHT_ADDR.c_str(),
                           model.GetName().c_str());
      GELOGE(FAILED, "[Set][Attr] %s fail for model:%s", MODEL_ATTR_TASK_GEN_WEIGHT_ADDR.c_str(),
             model.GetName().c_str());
      return FAILED);
  GE_CHK_BOOL_EXEC(
      AttrUtils::SetInt(model, ATTR_MODEL_TASK_GEN_VAR_ADDR, reinterpret_cast<uintptr_t>(var_mem_base_)),
      REPORT_INNER_ERR_MSG("E19999", "Set Attr:%s fail for model:%s", ATTR_MODEL_TASK_GEN_VAR_ADDR.c_str(),
                           model.GetName().c_str());
      GELOGE(FAILED, "[Set][Attr] %s fail for model:%s", ATTR_MODEL_TASK_GEN_VAR_ADDR.c_str(), model.GetName().c_str());
      return FAILED);
  GE_CHK_BOOL_EXEC(
      AttrUtils::SetInt(model, ATTR_MODEL_VAR_SIZE, var_mem_size_),
      REPORT_INNER_ERR_MSG("E19999", "Set Attr:%s fail for model:%s", ATTR_MODEL_VAR_SIZE.c_str(),
                           model.GetName().c_str());
      GELOGE(FAILED, "[Set][Attr] %s fail for model:%s", ATTR_MODEL_VAR_SIZE.c_str(), model.GetName().c_str());
      return FAILED);
  GE_CHK_BOOL_EXEC(
      AttrUtils::SetInt(model, MODEL_ATTR_SESSION_ID, session_id),
      REPORT_INNER_ERR_MSG("E19999", "Set Attr:%s fail for model:%s", MODEL_ATTR_SESSION_ID.c_str(),
                           model.GetName().c_str());
      GELOGE(FAILED, "[Set][Attr] %s fail for model:%s", MODEL_ATTR_SESSION_ID.c_str(), model.GetName().c_str());
      return FAILED);

  // NOTE(review): this is a const member function, so unless op_names_ is declared mutable the move iterators
  // below bind to const strings and the names are copied, not moved - confirm the intent.
  std::vector<std::string> task_index_op_name{std::make_move_iterator(op_names_.begin()),
                                              std::make_move_iterator(op_names_.end())};
  GE_ASSERT_TRUE(AttrUtils::SetListStr(model, ATTR_MODEL_TASK_INDEX_OP_NAME, task_index_op_name));

  const size_t task_size = model_task_def.ByteSizeLong();
  // ArrayOutputStream takes an int size; reject anything that would not survive the narrowing cast below.
  GE_ASSERT_TRUE(task_size <= static_cast<size_t>(INT32_MAX),
                 "[Check][Param] Serialized task size %zu is too large for model:%s", task_size,
                 model.GetName().c_str());
  ge::Buffer serial_buff(task_size);
  google::protobuf::io::ArrayOutputStream array_stream(serial_buff.GetData(),
                                                       static_cast<int32_t>(serial_buff.GetSize()));
  google::protobuf::io::CodedOutputStream output_stream(&array_stream);
  // Deterministic serialization keeps the stored bytes stable across runs for identical task definitions.
  output_stream.SetSerializationDeterministic(true);
  GE_ASSERT_TRUE(model_task_def.SerializeToCodedStream(&output_stream));
  GE_ASSERT_TRUE(AttrUtils::SetZeroCopyBytes(model, MODEL_ATTR_TASKS, std::move(serial_buff)));
  return SUCCESS;
}
// Records, per valid input and per valid output of op_desc, whether its memory offset is a variable address.
// For each non-null tensor desc the offset (minus the tensor's ATTR_NAME_INNER_OFFSET, 0 when absent) is
// checked with the session's VarManager; the results are stored as the list-bool attributes kIsInputVar and
// kIsOutputVar. A side is skipped entirely when the op has no offsets on that side.
// Returns FAILED when an attribute cannot be set, SUCCESS otherwise.
Status TaskGenerator::UpdateOpIsVarAttr(const OpDescPtr &op_desc, uint64_t session_id) const {
  GELOGD("Update is var attr, node[name:%s(%s), id:%ld, stream_id:%ld].", op_desc->GetName().c_str(),
         op_desc->GetType().c_str(), op_desc->GetId(), op_desc->GetStreamId());
  std::vector<int64_t> input_offsets = op_desc->GetInputOffset();
  if (!(input_offsets.empty())) {
    std::vector<bool> input_var;
    // Offsets exist only for non-null input descs, so they are indexed separately from the desc index.
    size_t valid_input_index = 0;
    for (uint32_t i = 0; i < op_desc->GetAllInputsSize(); i++) {
      auto input_tensor_desc = op_desc->MutableInputDesc(i);
      if (input_tensor_desc == nullptr) {
        continue;
      }
      if (valid_input_index >= input_offsets.size()) {
        break;
      }
      int64_t inner_offset = 0;
      (void)ge::AttrUtils::GetInt(input_tensor_desc, ATTR_NAME_INNER_OFFSET, inner_offset);
      GELOGD("Node[%s] input[%u] has inner_offset[%ld]", op_desc->GetName().c_str(), i, inner_offset);
      input_var.push_back(VarManager::Instance(session_id)->IsVarAddr(input_offsets[valid_input_index] - inner_offset));
      valid_input_index++;
    }
    GE_CHK_BOOL_EXEC(AttrUtils::SetListBool(op_desc, kIsInputVar, input_var),
                     REPORT_INNER_ERR_MSG("E19999", "Set Attr:%s fail for op:%s(%s)", kIsInputVar,
                                          op_desc->GetName().c_str(), op_desc->GetType().c_str());
                     GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", kIsInputVar, op_desc->GetName().c_str(),
                            op_desc->GetType().c_str());
                     return FAILED);
  }
  std::vector<int64_t> output_offsets = op_desc->GetOutputOffset();
  if (!(output_offsets.empty())) {
    std::vector<bool> output_var;
    // Same indexing scheme as for the inputs above.
    size_t valid_output_index = 0;
    for (uint32_t i = 0; i < op_desc->GetAllOutputsDescSize(); i++) {
      auto output_tensor_desc = op_desc->MutableOutputDesc(i);
      if (output_tensor_desc == nullptr) {
        continue;
      }
      if (valid_output_index >= output_offsets.size()) {
        break;
      }
      int64_t inner_offset = 0;
      (void)ge::AttrUtils::GetInt(output_tensor_desc, ATTR_NAME_INNER_OFFSET, inner_offset);
      GELOGD("Node[%s] output[%u] has inner_offset[%ld]", op_desc->GetName().c_str(), i, inner_offset);
      output_var.push_back(
          VarManager::Instance(session_id)->IsVarAddr(output_offsets[valid_output_index] - inner_offset));
      valid_output_index++;
    }
    GE_CHK_BOOL_EXEC(AttrUtils::SetListBool(op_desc, kIsOutputVar, output_var),
                     REPORT_INNER_ERR_MSG("E19999", "Set Attr:%s fail for op:%s(%s)", kIsOutputVar,
                                          op_desc->GetName().c_str(), op_desc->GetType().c_str());
                     GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", kIsOutputVar, op_desc->GetName().c_str(),
                            op_desc->GetType().c_str());
                     return FAILED);
  }
  return SUCCESS;
}
// Generates extra AICPU tasks for a node flagged with kPartiallySupported.
// The AICPU kernel libs in kAicpuKernelLibs are probed in order; the first one that reports the op as
// supported compiles the op and is used to generate tasks, which are appended to `tasks`. The chosen lib name
// is stored on the OpDesc as kAICpuKernelLibName. Nodes without the flag, or for which no AICPU lib supports
// the op, return SUCCESS without generating anything.
Status TaskGenerator::GenTaskForPartiallySupportedNode(const NodePtr &node, RunContext &context,
                                                       std::vector<domi::TaskDef> &tasks) const {
  const auto &op_desc = node->GetOpDesc();
  GE_ASSERT_NOTNULL(op_desc);
  bool partially_supported = false;
  (void)AttrUtils::GetBool(op_desc, kPartiallySupported, partially_supported);
  if (!partially_supported) {
    return SUCCESS;
  }

  const auto &instance_ptr = ge::GELib::GetInstance();
  const bool gelib_ready = (instance_ptr != nullptr) && instance_ptr->InitFlag();
  if (!gelib_ready) {
    GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Get][GELib] GELib is not init before.");
    REPORT_INNER_ERR_MSG("E19999", "[Get][GELib] GELib is not init before, check invalid.");
    return GE_CLI_GE_NOT_INITIALIZED;
  }

  // Pick the first AICPU kernel lib that accepts the op.
  std::string op_kernel_lib_name;
  for (const auto &aicpu_kernel_name : kAicpuKernelLibs) {
    auto kernel_info = instance_ptr->OpsKernelManagerObj().GetOpsKernelInfoStore(aicpu_kernel_name);
    GE_ASSERT_NOTNULL(kernel_info);
    std::string unsupported_reason;
    if (!kernel_info->CheckSupported(op_desc, unsupported_reason)) {
      GELOGI("Aicpu engine not support, kernel_name is %s, op type is %s, op name is %s", aicpu_kernel_name.c_str(),
             op_desc->GetType().c_str(), op_desc->GetName().c_str());
      continue;
    }
    std::vector<NodePtr> nodes(1U, node);
    GE_CHK_STATUS_RET(kernel_info->CompileOp(nodes),
                      "Failed to invoke compile op, libName = %s, type = %s, node = %s.", aicpu_kernel_name.c_str(),
                      op_desc->GetType().c_str(), op_desc->GetName().c_str());
    op_kernel_lib_name = aicpu_kernel_name;
    break;
  }
  if (op_kernel_lib_name.empty()) {
    GELOGW("Partially supported task doesn't find aicpu ops kernel info.");
    return SUCCESS;
  }

  const auto &ops_kernel_builder = OpsKernelBuilderManager::Instance().GetOpsKernelBuilder(op_kernel_lib_name);
  GE_ASSERT_NOTNULL(ops_kernel_builder);
  GELOGD("To invoke GenerateTask, node = %s, lib name = %s", node->GetName().c_str(), op_kernel_lib_name.c_str());
  GE_CHK_STATUS_RET(ops_kernel_builder->GenerateTask(*node, context, tasks),
                    "Failed to invoke GenerateTask, libName = %s, node = %s", op_kernel_lib_name.c_str(),
                    node->GetName().c_str());
  (void)AttrUtils::SetStr(node->GetOpDesc(), kAICpuKernelLibName, op_kernel_lib_name);
  GELOGD("Done invoking GenerateTask successfully, kernel lib is %s.", op_kernel_lib_name.c_str());
  return SUCCESS;
}
// Generates additional tasks with the kernel builder named by the node's ATTR_NAME_ALIAS_ENGINE_NAME attribute.
// Nodes without the attribute, or with an empty value (a warning is logged), are a no-op returning SUCCESS.
// Generated tasks are appended to `tasks`.
Status TaskGenerator::GenTaskForNodeByAliasEngine(const NodePtr &node, RunContext &context,
                                                  std::vector<domi::TaskDef> &tasks) const {
  const auto &op_desc = node->GetOpDesc();
  GE_ASSERT_NOTNULL(op_desc);

  std::string op_kernel_lib_name;
  const bool has_alias_engine = AttrUtils::GetStr(op_desc, ATTR_NAME_ALIAS_ENGINE_NAME, op_kernel_lib_name);
  if (!has_alias_engine) {
    return SUCCESS;
  }
  if (op_kernel_lib_name.empty()) {
    GELOGW("node %s find empty kernel lib name.", op_desc->GetName().c_str());
    return SUCCESS;
  }

  const auto &ops_kernel_builder = OpsKernelBuilderManager::Instance().GetOpsKernelBuilder(op_kernel_lib_name);
  GE_ASSERT_NOTNULL(ops_kernel_builder);
  GE_CHK_STATUS_RET(ops_kernel_builder->GenerateTask(*node, context, tasks),
                    "Failed to invoke GenerateTask, libName = %s, node = %s", op_kernel_lib_name.c_str(),
                    node->GetName().c_str());
  GELOGD("Done invoking GenerateTask successfully, node = %s, lib name = %s", node->GetName().c_str(),
         op_kernel_lib_name.c_str());
  return SUCCESS;
}
// Sets the status of every input data anchor of `node`:
//   - ANCHOR_SUSPEND when the anchor has no peer output anchor;
//   - ANCHOR_CONST when the peer owner is reported as const of type CONSTANT and is either not a data node or
//     a data node whose parent input lives in a graph without the unknown flag;
//   - ANCHOR_DATA otherwise.
Status TaskGenerator::UpdateAnchorStatus(const NodePtr &node) const {
  GELOGD("Start UpdateAnchorStatus for %s.", node->GetName().c_str());
  GE_ASSERT_GRAPH_SUCCESS(NodeUtils::SetAllAnchorStatus(node), "SetAllAnchorStatus fail for op:%s(%s)",
                          node->GetName().c_str(), node->GetType().c_str());
  for (auto &anchor : node->GetAllInDataAnchors()) {
    auto peer_anchor = anchor->GetPeerOutAnchor();
    if (peer_anchor == nullptr) {
      GE_ASSERT_GRAPH_SUCCESS(AnchorUtils::SetStatus(anchor, ANCHOR_SUSPEND),
                              "Set in peer anchor status fail for op:%s(%s), anchor_index:%d", node->GetName().c_str(),
                              node->GetType().c_str(), anchor->GetIdx());
      continue;
    }

    std::string const_type;
    const bool is_const = NodeUtils::GetConstOpType(peer_anchor->GetOwnerNode(), const_type);
    bool mark_as_const = false;
    if (is_const && (const_type == CONSTANT)) {
      // The parent input is only looked up for data nodes, exactly as a short-circuit evaluation would do.
      if (!ge::OpTypeUtils::IsDataNode(peer_anchor->GetOwnerNodeBarePtr()->GetType())) {
        mark_as_const = true;
      } else {
        mark_as_const =
            !NodeUtils::GetParentInput(peer_anchor->GetOwnerNode())->GetOwnerComputeGraph()->GetGraphUnknownFlag();
      }
    }

    if (mark_as_const) {
      GE_ASSERT_GRAPH_SUCCESS(AnchorUtils::SetStatus(anchor, ANCHOR_CONST),
                              "Set in anchor CONST status fail for op:%s(%s), anchor_index:%d", node->GetName().c_str(),
                              node->GetType().c_str(), anchor->GetIdx());
    } else {
      GE_ASSERT_GRAPH_SUCCESS(AnchorUtils::SetStatus(anchor, ANCHOR_DATA),
                              "Set in anchor DATA status fail for op:%s(%s), anchor_index:%d", node->GetName().c_str(),
                              node->GetType().c_str(), anchor->GetIdx());
    }
  }
  return SUCCESS;
}
// Prepares every node of `graph` (including subgraph nodes) for task generation: refreshes input anchor
// statuses and the is-var attributes, then marks the first/last ops of the graph.
// Returns GE_CLI_GE_NOT_INITIALIZED when GELib is not available or not initialized.
Status TaskGenerator::MarkNodeAndSetIndex(const ComputeGraphPtr &graph) const {
  const auto ge_lib = GELib::GetInstance();
  const bool ge_ready = (ge_lib != nullptr) && ge_lib->InitFlag();
  if (!ge_ready) {
    REPORT_INNER_ERR_MSG("E19999", "Check GELib instance not init before");
    GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Check][Param] GE is not initialized or is finalized.");
    return GE_CLI_GE_NOT_INITIALIZED;
  }

  const auto session_id = graph->GetSessionID();
  for (const auto &node : graph->GetAllNodes()) {
    GE_CHK_STATUS_RET(UpdateAnchorStatus(node), "[Call][UpdateAnchorStatus] node:%s(%s) failed.",
                      node->GetName().c_str(), node->GetType().c_str());
    const auto &op_desc = node->GetOpDesc();
    GE_ASSERT_NOTNULL(op_desc);
    GE_CHK_STATUS_RET(UpdateOpIsVarAttr(op_desc, session_id));
  }

  // The collected set is not needed here; only the attributes set on the nodes matter.
  std::unordered_set<Node *> need_to_gen_task_nodes;
  GE_ASSERT_SUCCESS(MarkFirstAndLastOpsForGraph(graph, need_to_gen_task_nodes));
  return SUCCESS;
}
// Clears the kIsFirstNode / kIsLastNode attributes on the nodes of `graph`, groups the nodes that have a
// valid stream id by stream, and lets MarkFirstAndLastOps set the attributes again per stream.
// Nodes without an ops kernel lib name first get a DNN engine name lookup (result ignored).
// The nodes that end up marked are added to `target_nodes`. Fails when the graph yields no nodes, when a
// node has no OpDesc, or when marking fails for any stream.
Status TaskGenerator::MarkFirstAndLastOpsForGraph(const ComputeGraphPtr &graph, std::unordered_set<Node *> &target_nodes) const {
  auto ge_lib = GELib::GetInstance();
  GE_ASSERT_NOTNULL(ge_lib);
  const auto all_nodes = graph->GetNodes(graph->GetGraphUnknownFlag());
  GE_ASSERT(!all_nodes.empty(), "Check param all_nodes empty in graph:%s", graph->GetName().c_str());
  std::map<int64_t, std::vector<Node *>> all_stream_nodes;
  for (auto &node : all_nodes) {
    const auto &op_desc = node->GetOpDesc();
    // Fail cleanly instead of dereferencing a null OpDesc, like the sibling functions in this file do.
    GE_ASSERT_NOTNULL(op_desc);
    if (op_desc->GetOpKernelLibName().empty()) {
      (void)ge_lib->DNNEngineManagerObj().GetDNNEngineName(node);
    }
    // Drop stale marks from a previous run before they are recomputed.
    (void)op_desc->DelAttr(kIsFirstNode);
    (void)op_desc->DelAttr(kIsLastNode);
    if (op_desc->GetStreamId() != kInvalidStream) {
      all_stream_nodes[op_desc->GetStreamId()].emplace_back(node.get());
    }
  }
  const bool is_single_stream = all_stream_nodes.size() == 1;
  for (const auto &stream_nodes : all_stream_nodes) {
    const Status status = MarkFirstAndLastOps(stream_nodes.second, is_single_stream, target_nodes);
    if (status != SUCCESS) {
      GELOGE(status, "[Mark][FirstAndLastOps] failed, graph:%s.", graph->GetName().c_str());
      return status;
    }
  }
  return SUCCESS;
}
// Marks the first and last node per ops kernel lib inside each continuous run of `nodes`.
// `nodes` is split into runs at separator nodes (label / stream switch types, and, when the graph has more
// than one stream, nodes owning subgraphs); separators themselves belong to no run. Nodes with
// ATTR_NAME_NOTASK set to true are ignored. Within a run, for every kernel lib name the first node gets
// kIsFirstNode and the last node gets kIsLastNode; both are inserted into `target_nodes`.
// Returns INTERNAL_ERROR when a node inside a run has an empty kernel lib name.
Status TaskGenerator::MarkFirstAndLastOps(const std::vector<Node *> &nodes, bool is_single_stream,
                                          std::unordered_set<Node *> &target_nodes) const {
  static const std::set<std::string> separator_types({LABELSET, LABELGOTOEX, LABELSWITCHBYINDEX, STREAMSWITCH});
  std::vector<std::vector<Node *>> continuous_node_lists(1);

  // Pass 1: cut the node sequence into continuous runs.
  for (Node *const candidate : nodes) {
    auto op_desc = candidate->GetOpDescBarePtr();
    GE_ASSERT_NOTNULL(op_desc);
    bool attr_notask = false;
    if (ge::AttrUtils::GetBool(op_desc, ATTR_NAME_NOTASK, attr_notask) && attr_notask) {
      continue;
    }
    const bool is_separator = ((!is_single_stream) && (!op_desc->GetSubgraphInstanceNames().empty())) ||
                              (separator_types.count(op_desc->GetType()) != 0);
    if (is_separator) {
      continuous_node_lists.emplace_back();
      continue;
    }
    continuous_node_lists.back().emplace_back(candidate);
  }
  GELOGD("Number of continuous node lists is %zu.", continuous_node_lists.size());

  // Pass 2: per run, remember the first and the last node of every kernel lib, then flag them.
  for (const auto &run : continuous_node_lists) {
    std::map<std::string, std::pair<Node *, Node *>> first_and_last_nodes;
    for (Node *const member : run) {
      auto op_desc = member->GetOpDescBarePtr();
      const std::string op_kernel_lib_name = op_desc->GetOpKernelLibName();
      if (op_kernel_lib_name.empty()) {
        REPORT_INNER_ERR_MSG("E19999", "Get ops kernel info store failed for op:%s(%s), op_kernel_name:%s",
                             op_desc->GetName().c_str(), op_desc->GetType().c_str(), op_kernel_lib_name.c_str());
        GELOGE(INTERNAL_ERROR, "[Check][Param] node:%s(%s) get op kernel lib failed.", op_desc->GetName().c_str(),
               op_desc->GetType().c_str());
        return INTERNAL_ERROR;
      }
      // emplace leaves an existing entry untouched; in that case only the "last" slot moves forward.
      const auto inserted = first_and_last_nodes.emplace(op_kernel_lib_name, std::make_pair(member, member));
      if (!inserted.second) {
        inserted.first->second.second = member;
      }
    }
    for (auto &entry : first_and_last_nodes) {
      Node *const first_node = entry.second.first;
      Node *const last_node = entry.second.second;
      GE_ASSERT_TRUE(ge::AttrUtils::SetBool(first_node->GetOpDescBarePtr(), kIsFirstNode, true),
                     "[Set][Attr] %s fail for op:%s(%s)", kIsFirstNode, first_node->GetName().c_str(),
                     first_node->GetType().c_str());
      GE_ASSERT_TRUE(ge::AttrUtils::SetBool(last_node->GetOpDescBarePtr(), kIsLastNode, true),
                     "[Set][Attr] %s fail for op:%s(%s)", kIsLastNode, last_node->GetName().c_str(),
                     last_node->GetType().c_str());
      target_nodes.insert(first_node);
      target_nodes.insert(last_node);
    }
  }
  return SUCCESS;
}
// Returns a mutable reference to the map from node (OpDesc) id to the tasks generated for that node.
// The map is filled by GenerateTaskForNodes / GenTaskForFusionNodes.
std::unordered_map<int64_t , std::vector<domi::TaskDef>> &TaskGenerator::MutableNodeId2TaskDefs() {
return node_id_2_node_tasks_;
}
// For every candidate node whose ops kernel lib is listed in kHcclDvppKernelInfoNames, attaches the raw
// pointer of the matching ops kernel info store to the OpDesc as the "OpsKernelInfoStorePtr" ext attribute.
// Requires an initialized GELib.
Status TaskGenerator::SetKernelInfo() {
  const auto &ge_lib = GELib::GetInstance();
  GE_ASSERT_NOTNULL(ge_lib);
  GE_ASSERT_TRUE(ge_lib->InitFlag());
  for (const auto &candidate : nodes_) {
    GE_ASSERT_NOTNULL(candidate);
    const auto &op_desc = candidate->GetOpDesc();
    GE_ASSERT_NOTNULL(op_desc);
    const std::string &op_kernel_lib_name = op_desc->GetOpKernelLibName();
    if (kHcclDvppKernelInfoNames.count(op_kernel_lib_name) == 0) {
      continue;
    }
    auto kernel_info_store = ge_lib->OpsKernelManagerObj().GetOpsKernelInfoStore(op_kernel_lib_name);
    GE_ASSERT_NOTNULL(kernel_info_store);
    (void)op_desc->SetExtAttr("OpsKernelInfoStorePtr", kernel_info_store.get());
  }
  return SUCCESS;
}
// Appends to nodes_ every node returned by graph->GetNodes(..., ffts_filter) for which NoNeedGenTask is false.
// nodes_ is not cleared first, so repeated calls accumulate.
Status TaskGenerator::FilterCandidatesNodes(const ComputeGraphPtr &graph) {
  for (const auto &node : graph->GetNodes(graph->GetGraphUnknownFlag(), nullptr, ffts_filter)) {
    const auto &op_desc = node->GetOpDesc();
    if (!NoNeedGenTask(op_desc)) {
      nodes_.emplace_back(node.get());
    }
  }
  return SUCCESS;
}
// Runs every preparation step that must happen before tasks are generated for `graph`, stopping at the
// first failure: mark nodes, collect candidate nodes into nodes_, attach kernel info stores, and locate the
// profiling points (stored in profiling_point_).
Status TaskGenerator::PrepareForGenerateTask(const ComputeGraphPtr &graph) {
GE_CHK_STATUS_RET(MarkNodeAndSetIndex(graph), "[Call][MarkNodeAndSetIndex] failed, graph:%s.",
graph->GetName().c_str());
// Fills nodes_, which SetKernelInfo below iterates over.
GE_CHK_STATUS_RET(FilterCandidatesNodes(graph), "[Call][FilterCandidatesNodes] failed, graph:%s.",
graph->GetName().c_str());
GE_CHK_STATUS_RET(SetKernelInfo(), "[Set][KernelInfo] failed, graph:%s.", graph->GetName().c_str());
GE_CHK_STATUS_RET(FindProfilingNodeIndex(graph, profiling_point_), "[Call][FindProfilingNodeIndex] failed, graph:%s.",
graph->GetName().c_str());
return SUCCESS;
}
// Thread pool entry point for a normal node.
// Installs the caller's thread local context and error manager context in the current thread, binds the
// device when device_id is valid (and resets it again on scope exit), then delegates to GenTaskForNormalNode.
Status TaskGenerator::GenerateTaskForNormalNode(Node *const node, const std::string &tag,
                                                std::vector<domi::TaskDef> &task_def_list_per_node,
                                                const GEThreadLocalContext &ge_context,
                                                const error_message::ErrorManagerContext &error_context,
                                                int32_t device_id) {
  GetThreadLocalContext() = ge_context;
  error_message::SetErrMgrContext(error_context);

  const bool need_bind_device = (device_id != kInvalidDeviceId);
  if (need_bind_device) {
    GE_CHK_RT_RET(aclrtSetDevice(device_id));
  }
  // Created only after a successful bind, so a failed aclrtSetDevice is never followed by a reset.
  GE_MAKE_GUARD(reset_device, [device_id]() {
    if (device_id == kInvalidDeviceId) {
      return;
    }
    GE_CHK_RT(aclrtResetDevice(device_id));
  });

  return GenTaskForNormalNode(node, tag, task_def_list_per_node);
}
// Generates all tasks of a single node and appends them to task_def_list_per_node, in this order:
// profiling tasks before the node, the node's own tasks, tasks for a partially supported node, tasks from an
// alias engine, profiling tasks after the node. Finally the stream id of every newly appended task is
// refreshed from the node's OpDesc. Each step works on a private copy of *run_context_.
Status TaskGenerator::GenTaskForNormalNode(Node *const node, const std::string &tag,
                                           std::vector<domi::TaskDef> &task_def_list_per_node) {
  const auto &op_desc = node->GetOpDesc();
  // Tasks already present belong to someone else; only [first_new_index, end) is ours.
  const size_t first_new_index = task_def_list_per_node.size();
  GE_CHK_STATUS_RET(ProfilingTaskUtils::InsertProfilingTaskBefore(op_desc, profiling_point_, task_def_list_per_node),
                    "[Insert][profiling] task failed");

  GE_ASSERT_NOTNULL(run_context_);
  RunContext run_context = *run_context_;
  GELOGI("%s node %s %s, lib name %s, start to gen task, origin task size is %zu.", tag.c_str(),
         op_desc->GetName().c_str(), op_desc->GetType().c_str(), op_desc->GetOpKernelLibName().c_str(),
         first_new_index);
  GE_CHK_STATUS_RET(OpsKernelBuilderManager::Instance().GenerateTask(*node, run_context, task_def_list_per_node),
                    "[Generate][Task] fail for op:%s(%s)", node->GetName().c_str(), node->GetType().c_str());

  GE_CHK_STATUS_RET(GenTaskForPartiallySupportedNode(node->shared_from_this(), run_context, task_def_list_per_node),
                    "[Generate][PartiallySupportedTask] fail for op:%s(%s)", node->GetName().c_str(),
                    node->GetType().c_str());
  GE_CHK_STATUS_RET(GenTaskForNodeByAliasEngine(node->shared_from_this(), run_context, task_def_list_per_node),
                    "[Generate][AliasTask] fail for op:%s(%s)", node->GetName().c_str(), node->GetType().c_str());

  GE_CHK_STATUS_RET(ProfilingTaskUtils::InsertProfilingTaskAfter(op_desc, profiling_point_, task_def_list_per_node),
                    "[Insert][profiling] task failed");
  GE_ASSERT_TRUE((task_def_list_per_node.size() >= first_new_index));
  RefreshStreamId(op_desc, first_new_index, task_def_list_per_node);

  GELOGI("%s node %s %s, lib name %s, gen task finished, generate %zu task(s).", tag.c_str(),
         op_desc->GetName().c_str(), op_desc->GetType().c_str(), op_desc->GetOpKernelLibName().c_str(),
         task_def_list_per_node.size() - first_new_index);
  return SUCCESS;
}
// Thread pool entry point for an FFTS-plus node.
// Installs the caller's thread local / error contexts, binds the device when device_id is valid, and then,
// for nodes carrying ATTR_NAME_FFTS_PLUS_SUB_GRAPH:
//   1. inserts the profiling tasks before the node;
//   2. generates the tasks of every node in the direct subgraphs in parallel on ffts_inner_thread_pool_
//      (created lazily under ffts_mutex_), each into its own entry of a local map keyed by OpDesc id;
//   3. appends those tasks in ascending id order, then the FFTS node's own tasks and the profiling tasks after;
//   4. refreshes the stream id of all tasks in task_def_list_per_node.
// Nodes without the attribute return SUCCESS without generating anything.
Status TaskGenerator::GenerateTaskForFftsNode(Node *ffts_node, const std::string &tag,
                                              std::vector<domi::TaskDef> &task_def_list_per_node,
                                              const GEThreadLocalContext &ge_context,
                                              const error_message::ErrorManagerContext &error_context,
                                              int32_t device_id) {
  GetThreadLocalContext() = ge_context;
  error_message::SetErrMgrContext(error_context);
  if (device_id != kInvalidDeviceId) {
    GE_CHK_RT_RET(aclrtSetDevice(device_id));
  }
  GE_MAKE_GUARD(reset_device, [device_id]() {
    if (device_id != kInvalidDeviceId) {
      GE_CHK_RT(aclrtResetDevice(device_id));
    }
  });
  const auto &op_desc = ffts_node->GetOpDesc();
  if (!op_desc->HasAttr(ATTR_NAME_FFTS_PLUS_SUB_GRAPH)) {
    return SUCCESS;
  }
  // run_context_ is dereferenced below for the FFTS node itself; fail cleanly if it was never provided.
  GE_ASSERT_NOTNULL(run_context_);
  GE_CHK_STATUS_RET(ProfilingTaskUtils::InsertProfilingTaskBefore(op_desc, profiling_point_, task_def_list_per_node),
                    "[Insert][profiling] task failed");
  {
    // Several FFTS nodes may reach this point concurrently; create the shared inner pool only once.
    const std::lock_guard<std::mutex> lock(ffts_mutex_);
    if (ffts_inner_thread_pool_ == nullptr) {
      ffts_inner_thread_pool_ = MakeUnique<ThreadPool>("ge_fftsgtsk", GetThreadNum(), false);
    }
  }
  GE_ASSERT_NOTNULL(ffts_inner_thread_pool_);
  std::vector<ComputeGraphPtr> subgraphs;
  GE_CHK_STATUS_RET(NodeUtils::GetDirectSubgraphs(ffts_node->shared_from_this(), subgraphs),
                    "[Check][Param] Get subgraphs of node %s failed", op_desc->GetName().c_str());
  std::map<int64_t, std::vector<domi::TaskDef>> node_id_2_node_tasks;
  std::vector<std::future<Status>> vector_future;
  // The committed tasks write into entries of the local map above. Whatever path leaves this function
  // (including the early error returns below), wait for every still-pending task first so that no worker is
  // left holding a reference into a destroyed map. Declared after the map, hence run before it is destroyed.
  GE_MAKE_GUARD(wait_inner_tasks, [&vector_future]() {
    for (auto &pending : vector_future) {
      if (pending.valid()) {
        pending.wait();
      }
    }
  });
  GenTaskCall func = &TaskGenerator::GenerateTaskForNormalNode;
  for (const auto &subgraph : subgraphs) {
    for (const auto &node : subgraph->GetAllNodes()) {
      auto tmp_op_dec = node->GetOpDesc();
      GE_ASSERT_NOTNULL(tmp_op_dec);
      if (NoNeedGenTask(tmp_op_dec)) {
        continue;
      }
      std::vector<NodePtr> atomic_node_vec;
      atomic_node_vec.emplace_back(node);
      const auto &parent_graph = subgraph->GetParentGraph();
      if ((parent_graph != nullptr) && (parent_graph->GetGraphUnknownFlag()) &&
          (!subgraph->GetGraphUnknownFlag()) && (NodeUtils::IsLikeAtomicClean(node))) {
        // The node type is a string; it was previously printed with %d.
        GE_CHK_STATUS_RET(AtomicAddrCleanPass::CallCompileOp(atomic_node_vec),
                          "compile single op failed, parent_graph:%s, subgraph:%s, node:%s, tmp_type:%s.",
                          parent_graph->GetName().c_str(), subgraph->GetName().c_str(),
                          node->GetName().c_str(), node->GetType().c_str());
      }
      // std::map never relocates its elements, so the reference stays valid while later entries are added.
      auto &task_defs = node_id_2_node_tasks[tmp_op_dec->GetId()];
      std::future<Status> f =
          ffts_inner_thread_pool_->commit(func, this, node.get(), "FFTS INNER", std::ref(task_defs),
                                          ge_context, error_context, device_id);
      if (f.valid()) {
        vector_future.emplace_back(std::move(f));
      }
    }
  }
  for (auto &i : vector_future) {
    GE_CHK_STATUS_RET(i.get(), "[GenTask] gen inner ffts task ctx failed.");
  }
  for (const auto &iter : node_id_2_node_tasks) {
    task_def_list_per_node.insert(task_def_list_per_node.end(), iter.second.begin(), iter.second.end());
  }
  GE_CHK_STATUS_RET(
      OpsKernelBuilderManager::Instance().GenerateTask(*ffts_node, *run_context_, task_def_list_per_node, false),
      "[Generate][Task] fail for ffts op:%s(%s)", ffts_node->GetName().c_str(), ffts_node->GetType().c_str());
  GE_CHK_STATUS_RET(ProfilingTaskUtils::InsertProfilingTaskAfter(op_desc, profiling_point_, task_def_list_per_node),
                    "[Insert][profiling] task failed");
  RefreshStreamId(op_desc, 0U, task_def_list_per_node);
  GELOGI("%s node %s %s gen task finished, generate %zu task(s).", tag.c_str(), op_desc->GetName().c_str(),
         op_desc->GetType().c_str(), task_def_list_per_node.size());
  return SUCCESS;
}
// Groups the nodes that carry an L1 or L2 fusion group id.
// For such a node the group key is `group_id + stream_id * kHashFactor`; it is stored on the OpDesc as
// ATTR_NAME_FUSION_GROUP_KEY and the node is appended to fusion_nodes[group_key].
// In addition, a warning is logged when all inputs of a node belong to one and the same fusion group whose id
// differs from the node's own group id (kInvalidGroupId when the node has none).
Status TaskGenerator::SaveFusionNodes(std::map<int64_t, std::vector<NodePtr>> &fusion_nodes,
                                      const std::vector<Node *> nodes) const {
  // Group id of every node seen so far that has a fusion group attribute.
  std::map<NodePtr, int64_t> nodes_with_group_attr;
  for (auto node_bare_ptr : nodes) {
    GE_ASSERT_NOTNULL(node_bare_ptr);
    const auto &node = node_bare_ptr->shared_from_this();
    GE_ASSERT_NOTNULL(node);
    const auto &op_desc = node->GetOpDesc();
    GE_ASSERT_NOTNULL(op_desc);

    const std::string name = node->GetName();
    const std::string type = node->GetType();
    int64_t group_id = kInvalidGroupId;
    const bool has_group_id = ge::AttrUtils::GetInt(op_desc, ATTR_NAME_L1_FUSION_GROUP_ID, group_id) ||
                              ge::AttrUtils::GetInt(op_desc, ATTR_NAME_L2_FUSION_GROUP_ID, group_id);
    if (has_group_id) {
      const auto stream_id = op_desc->GetStreamId();
      const auto group_key = group_id + stream_id * kHashFactor;
      (void)ge::AttrUtils::SetInt(op_desc, ATTR_NAME_FUSION_GROUP_KEY, group_key);
      GELOGD("Fusion: store node[name:%s(%s), group id:%ld, group key:%ld, stream_id:%ld] task.", name.c_str(),
             type.c_str(), group_id, group_key, op_desc->GetStreamId());
      fusion_nodes[group_key].push_back(node);
      nodes_with_group_attr.insert({node, group_id});
    }

    // Collect the group ids of the inputs; give up as soon as one input is not part of any group.
    bool every_input_grouped = true;
    std::set<int64_t> input_group_ids;
    for (const auto &input_node : node->GetInNodes()) {
      const auto found = nodes_with_group_attr.find(input_node);
      if (found == nodes_with_group_attr.end()) {
        every_input_grouped = false;
        break;
      }
      input_group_ids.insert(found->second);
    }
    if (every_input_grouped && (input_group_ids.size() == 1)) {
      const auto input_group_id = *input_group_ids.cbegin();
      if (group_id != input_group_id) {
        GELOGW("Fusion: node[name:%s(%s) with group id:%ld and diff from it's input nodes's group id:%ld ",
               name.c_str(), type.c_str(), group_id, input_group_id);
      }
    }
  }
  GELOGD("Fusion: get fusion group numbers [%zu].", fusion_nodes.size());
  return SUCCESS;
}
// Generates the tasks of the whole fusion group that `node` belongs to, once per group.
// Nothing happens when the node has no ATTR_NAME_FUSION_GROUP_KEY or was already handled (is contained in
// fusion_nodes_seen). Otherwise tasks are generated for every node of the group, in group order, into
// node_id_2_node_tasks_; each member is recorded in fusion_ordered_node_list_ / fusion_task_node_name_list_
// and added to fusion_nodes_seen. Fails when the group key is not present in fusion_nodes.
Status TaskGenerator::GenerateTaskForFusionNode(Node *const node,
                                                const std::map<int64_t, std::vector<NodePtr>> &fusion_nodes,
                                                std::unordered_set<Node *> &fusion_nodes_seen) {
  int64_t group_key = 0;
  const auto &fusion_op_desc = node->GetOpDesc();
  GE_ASSERT_NOTNULL(fusion_op_desc);
  const bool has_group_key = ge::AttrUtils::GetInt(fusion_op_desc, ATTR_NAME_FUSION_GROUP_KEY, group_key);
  if ((!has_group_key) || (fusion_nodes_seen.count(node) != 0U)) {
    return SUCCESS;
  }
  // A single lookup that reports a missing key as an error status instead of throwing std::out_of_range.
  const auto group_iter = fusion_nodes.find(group_key);
  GE_ASSERT_TRUE(group_iter != fusion_nodes.cend(), "[Check][Param] Fusion group key %ld of node %s is not recorded.",
                 group_key, fusion_op_desc->GetName().c_str());
  const auto &group_members = group_iter->second;
  GELOGI("Fusion: start fusion group index[%ld], nodes size[%zu].", group_key, group_members.size());
  for (const auto &fusion_node : group_members) {
    GE_ASSERT_NOTNULL(fusion_node);
    const auto &member_op_desc = fusion_node->GetOpDesc();
    GE_ASSERT_NOTNULL(member_op_desc);
    auto &task_defs = node_id_2_node_tasks_[member_op_desc->GetId()];
    GE_ASSERT_SUCCESS(GenTaskForNormalNode(fusion_node.get(), "Fusion inner", task_defs));
    fusion_ordered_node_list_.emplace_back(member_op_desc->GetId());
    // NOTE(review): the recorded name is that of the node that triggered the group, not of the member whose
    // id was recorded on the line above - kept as is, confirm that this is intended.
    fusion_task_node_name_list_.emplace_back(fusion_op_desc->GetName());
    fusion_nodes_seen.insert(fusion_node.get());
  }
  return SUCCESS;
}
// Sequentially generates tasks for all candidate nodes (nodes_) when fusion groups exist.
// For a node with a fusion group key the whole group is generated the first time one of its members is
// visited; nodes without a key are generated individually. The visiting order is recorded in
// fusion_ordered_node_list_ and fusion_task_node_name_list_, which are reset at the start.
Status TaskGenerator::GenTaskForFusionNodes(const std::map<int64_t, std::vector<NodePtr>> &fusion_nodes) {
  fusion_nodes_seen_.clear();
  fusion_ordered_node_list_.clear();
  fusion_task_node_name_list_.clear();
  for (const auto &candidate : nodes_) {
    const auto &op_desc = candidate->GetOpDesc();
    GE_ASSERT_NOTNULL(op_desc);
    const auto &name = op_desc->GetName();
    const auto &type = op_desc->GetType();
    GE_ASSERT_SUCCESS(GenerateTaskForFusionNode(candidate, fusion_nodes, fusion_nodes_seen_),
                      "[Call][GenerateTaskForFusionNode] node:%s(%s) failed", name.c_str(), type.c_str());

    // Members of a fusion group were fully handled by the call above.
    int64_t group_key = 0;
    const bool belongs_to_group = ge::AttrUtils::GetInt(op_desc, ATTR_NAME_FUSION_GROUP_KEY, group_key);
    if (belongs_to_group) {
      GELOGI("Fusion node[name:%s, type:%s] do not need generate task again.", name.c_str(), type.c_str());
      continue;
    }

    const auto node_id = op_desc->GetId();
    auto &task_defs = node_id_2_node_tasks_[node_id];
    GE_ASSERT_SUCCESS(GenTaskForNormalNode(candidate, "Fusion outer", task_defs));
    fusion_ordered_node_list_.emplace_back(node_id);
    fusion_task_node_name_list_.emplace_back(name);
  }
  return SUCCESS;
}
// Generates the tasks of all given nodes into node_id_2_node_tasks_.
// When buffer optimization requests fusion and at least one fusion group exists, generation runs
// sequentially through GenTaskForFusionNodes (which works on nodes_). Otherwise every node is committed to a
// freshly created thread pool, dispatched by GetKey to the normal or the FFTS entry point, and the function
// waits for all of them. The thread pools are released before returning, also when a node failed; the status
// of the first failing node (in commit order) is returned.
Status TaskGenerator::GenerateTaskForNodes(const std::vector<Node *> nodes) {
  if (NeedDoFusionTask()) {
    std::map<int64_t, std::vector<NodePtr>> fusion_nodes;
    GE_RUN_PERF(TaskGenerator, SaveFusionNodes, fusion_nodes, nodes_);
    if (!fusion_nodes.empty()) {
      for (const auto &node : nodes_) {
        node_id_2_node_tasks_[node->GetOpDesc()->GetId()].clear();
      }
      return GenTaskForFusionNodes(fusion_nodes);
    }
  }
  // Create (and empty) every map entry up front, so the lookups in the commit loop below do not insert new
  // entries while workers are already writing through references to existing ones.
  for (const auto &node : nodes) {
    node_id_2_node_tasks_[node->GetOpDesc()->GetId()].clear();
  }
  static std::map<GenTaskCallKey, GenTaskCall> handles =
      {{GenTaskCallKey::kAtomicEngine, &TaskGenerator::GenerateTaskForNormalNode},
       {GenTaskCallKey::kFftsEngine, &TaskGenerator::GenerateTaskForFftsNode}};
  thread_pool_ = MakeUnique<ThreadPool>("ge_gentask", GetThreadNum(), false);
  GE_ASSERT_NOTNULL(thread_pool_);
  std::vector<std::future<Status>> vector_future;
  // Best effort: when no device is bound, device_id stays invalid and workers skip set/reset device.
  int32_t device_id = kInvalidDeviceId;
  (void)aclrtGetDevice(&device_id);
  GELOGI("Get device id %d", device_id);
  for (const auto node : nodes) {
    const auto key = GetKey(node);
    auto &task_defs = node_id_2_node_tasks_[node->GetOpDesc()->GetId()];
    const auto &func = handles.find(key)->second;
    std::future<Status> f = thread_pool_->commit(func, this, node, key_2_string.at(key), std::ref(task_defs),
                                                 GetThreadLocalContext(), error_message::GetErrMgrContext(), device_id);
    if (f.valid()) {
      vector_future.emplace_back(std::move(f));
    }
  }
  // Wait for every task before reporting a failure: returning at the first failed future would leave the
  // remaining workers running and both thread pools alive.
  Status first_failure = SUCCESS;
  for (auto &i : vector_future) {
    const Status ret = i.get();
    if ((ret != SUCCESS) && (first_failure == SUCCESS)) {
      first_failure = ret;
    }
  }
  thread_pool_.reset(nullptr);
  ffts_inner_thread_pool_.reset(nullptr);
  GE_CHK_STATUS_RET(first_failure, "[GenTask] Fail!");
  return SUCCESS;
}
// Top level entry: prepares `graph` and generates the tasks of all candidate nodes.
// For the SoC versions listed in kNanoSocVersion the runtime parameters are additionally read from a copy of
// `model` and the zero copy information is initialized before the tasks are generated.
Status TaskGenerator::GenerateTask(const ComputeGraphPtr &graph, Model &model) {
  GE_RUN_PERF(TaskGenerator, PrepareForGenerateTask, graph);
  std::string soc_version;
  (void)GetThreadLocalContext().GetOption(ge::SOC_VERSION, soc_version);
  if (kNanoSocVersion.count(soc_version) > 0) {
    PreRuntimeParam runtime_param;
    auto model_ptr = ge::MakeShared<ge::Model>(model);
    // Do not hand a failed allocation on to the attribute readers.
    GE_ASSERT_NOTNULL(model_ptr);
    GE_CHK_STATUS_RET(InitRuntimeParams(model_ptr, runtime_param), "[Init][RuntimeParams] failed, graph:%s.",
                      graph->GetName().c_str());
    GE_RUN_PERF(TaskGenerator, InitZeroCopyInfo, graph, runtime_param);
  }
  GE_RUN_PERF(TaskGenerator, GenerateTaskForNodes, nodes_);
  return SUCCESS;
}
// Copies the model's memory-size and weight-size attributes into `runtime_param`.
// Both reads are best effort: their results are discarded and SUCCESS is always returned.
Status TaskGenerator::InitRuntimeParams(const ModelPtr &model_ptr, PreRuntimeParam &runtime_param) {
  static_cast<void>(AttrUtils::GetInt(model_ptr, ATTR_MODEL_MEMORY_SIZE, runtime_param.mem_size));
  static_cast<void>(AttrUtils::GetInt(model_ptr, ATTR_MODEL_WEIGHT_SIZE, runtime_param.weight_size));
  return SUCCESS;
}
// Builds the zero-copy offset table for `graph` and rewrites the offsets of anchors attached to its
// data and NetOutput nodes.
//
// For every node owned directly by `graph`:
//   - a node whose type is in kDataOpType registers its offsets via GenZeroCopyTable(is_input = true) and
//     every consumer input anchor is updated through SetAnchorsOffset;
//   - a NETOUTPUT node registers its offsets via GenZeroCopyTable(is_input = false), then each input anchor
//     is processed by SetNetOutputNodeInAnchorAndPeerOffset and its producer by SetAnchorsOffset.
// Nodes owned by another graph (GetAllNodes also yields those) are skipped.
Status TaskGenerator::InitZeroCopyInfo(const ComputeGraphPtr &graph, const PreRuntimeParam &runtime_param) {
  GE_ASSERT_NOTNULL(graph);
  uint32_t search_id = 0U;
  SymbolToAnchors symbol_to_anchors;
  AnchorToSymbol anchor_to_symbol;
  GE_ASSERT_SUCCESS(GraphUtils::GetRefMapping(graph, symbol_to_anchors, anchor_to_symbol),
                    "[Call][GetRefMapping] for graph:%s failed.", graph->GetName().c_str());

  for (const auto &node : graph->GetAllNodes()) {
    GE_CHECK_NOTNULL(node);
    const auto &op_desc = node->GetOpDesc();
    GE_CHECK_NOTNULL(op_desc);
    if (node->GetOwnerComputeGraph() != graph) {
      // Skip only this node. Returning here would abandon every remaining node of `graph`
      // as soon as the first foreign node is met.
      continue;
    }

    if (kDataOpType.count(node->GetType()) > 0U) {
      GE_CHK_STATUS_RET(GenZeroCopyTable(op_desc, search_id, true),
                        "GenZeroCopyTable for node[%s] input failed", node->GetName().c_str());
      for (const auto &anchor : node->GetAllOutDataAnchors()) {
        for (const auto &peer_in_anchor : anchor->GetPeerInDataAnchors()) {
          if (peer_in_anchor == nullptr) {
            continue;
          }
          auto owner_node = peer_in_anchor->GetOwnerNode();
          auto index = peer_in_anchor->GetIdx();
          GE_CHK_STATUS_RET(SetAnchorsOffset(owner_node, true, index, runtime_param, op_desc));
        }
      }
    }

    if (node->GetType() == NETOUTPUT) {
      GE_CHK_STATUS_RET(GenZeroCopyTable(op_desc, search_id, false),
                        "GenZeroCopyTable for node[%s] input failed", node->GetName().c_str());
      for (const auto &anchor : node->GetAllInDataAnchors()) {
        GE_ASSERT_SUCCESS(
            SetNetOutputNodeInAnchorAndPeerOffset(anchor, runtime_param, symbol_to_anchors, anchor_to_symbol));
        // The peer of an input anchor is the producer's output anchor.
        auto peer_out_anchor = anchor->GetPeerOutAnchor();
        if (peer_out_anchor == nullptr) {
          continue;
        }
        auto owner_node = peer_out_anchor->GetOwnerNode();
        auto index = peer_out_anchor->GetIdx();
        GE_CHK_STATUS_RET(SetAnchorsOffset(owner_node, false, index, runtime_param, op_desc));
      }
    }
  }
  return SUCCESS;
}
// Maps the offset of one NetOutput input anchor to its zero-copy id and propagates that id to every
// anchor sharing the same symbol.
//
// in_anchor          input anchor of the NetOutput node.
// runtime_param      forwarded unchanged to SetNetOutputNodePeerNodeOffset.
// symbol_to_anchors  symbol -> anchors map (as filled by GraphUtils::GetRefMapping in the caller).
// anchor_to_symbol   anchor -> symbol map (same origin).
//
// Fails when the anchor, its owner node or op desc is null, when the anchor index is outside the input
// offset list, when the offset has no zero-copy id, or when updating a peer fails. An anchor without a
// symbol entry is not an error.
Status TaskGenerator::SetNetOutputNodeInAnchorAndPeerOffset(const InDataAnchorPtr &in_anchor,
                                                            const PreRuntimeParam &runtime_param,
                                                            SymbolToAnchors &symbol_to_anchors,
                                                            AnchorToSymbol &anchor_to_symbol) {
  GE_ASSERT_NOTNULL(in_anchor);
  auto out_node = in_anchor->GetOwnerNode();
  GE_ASSERT_NOTNULL(out_node);
  const auto out_node_inanchor_index = in_anchor->GetIdx();
  auto out_opdesc = out_node->GetOpDesc();
  GE_ASSERT_NOTNULL(out_opdesc);

  const auto ori_offset_list = out_opdesc->GetInputOffset();
  // Validate the index explicitly instead of relying on vector::at() throwing.
  GE_ASSERT_TRUE((out_node_inanchor_index >= 0) &&
                     (static_cast<size_t>(out_node_inanchor_index) < ori_offset_list.size()),
                 "NetOutput input no[%d] is out of range, input offset size[%zu].", out_node_inanchor_index,
                 ori_offset_list.size());
  const int64_t ori_offset = ori_offset_list[static_cast<size_t>(out_node_inanchor_index)];

  const auto &zero_copy_offset_to_ids = PreModelPartitionUtils::GetInstance().GetZeroCopyTable();
  const auto id_iter = zero_copy_offset_to_ids.find(ori_offset);
  GE_ASSERT_TRUE(id_iter != zero_copy_offset_to_ids.end(),
                 "NetOutput input no[%d] offset[%" PRId64 "] get zerocopy offset id failed.", out_node_inanchor_index,
                 ori_offset);
  const uint32_t offset_to_id = id_iter->second;

  NodeIndexIO out_node_in_anchor_node(out_node, out_node_inanchor_index, kIn);
  const auto anchor_symbol_iter = anchor_to_symbol.find(out_node_in_anchor_node.ToString());
  GELOGD("anchor[%s] origin offset[%ld] set to idx[%u]", out_node_in_anchor_node.ToString().c_str(), ori_offset, offset_to_id);
  if (anchor_symbol_iter == anchor_to_symbol.end()) {
    return SUCCESS;
  }

  // Look the symbol up with find() so that a missing symbol does not insert an empty entry into the map.
  const auto symbol_iter = symbol_to_anchors.find(anchor_symbol_iter->second);
  if (symbol_iter == symbol_to_anchors.end()) {
    return SUCCESS;
  }
  for (const auto &node_index : symbol_iter->second) {
    const bool is_input = (node_index.io_type_ == kIn);
    const auto peer_node = node_index.node_;
    const uint32_t peer_index = node_index.index_;
    if (peer_node == nullptr) {
      continue;
    }
    GE_ASSERT_SUCCESS(
        SetNetOutputNodePeerNodeOffset(peer_node, is_input, peer_index, ori_offset, offset_to_id, runtime_param),
        "NetOutput set peer offset fail, anchor symbol:%s", node_index.ToString().c_str());
  }
  return SUCCESS;
}
// Replaces one input/output offset of `peer_node` with the zero-copy id `offset_to_id`.
//
// peer_node     node whose op desc is updated.
// is_input      true: `index` addresses an input of peer_node; false: an output.
// index         anchor index; for inputs it is first translated through the index -> valid-index table
//               produced by PreModelUtils::GetInputDataAddrOffset.
// ori_offset    offset the NetOutput anchor had; the peer is only rewritten when its offset equals it.
// offset_to_id  zero-copy id written in place of the offset.
//
// Nothing is changed (and SUCCESS is returned) when the argument address type is neither
// KERNEL_ARG_UPADTE_ADDR_TYPE_ARGS nor KERNEL_ARG_UPADTE_ADDR_TYPE_WORKSPACE, or when the peer's offset
// differs from ori_offset (a warning is logged in that case). Out-of-range indices are reported as a
// failure instead of reading past the end of a vector.
Status TaskGenerator::SetNetOutputNodePeerNodeOffset(const NodePtr &peer_node, const bool is_input, uint32_t index,
                                                     const int64_t ori_offset, const uint32_t offset_to_id,
                                                     const PreRuntimeParam &runtime_param) {
  GE_ASSERT_NOTNULL(peer_node);
  auto peer_op_desc = peer_node->GetOpDesc();
  GE_ASSERT_NOTNULL(peer_op_desc);

  uint64_t args_ddr_type = invalidAddrType;
  std::vector<KernelArgsParam> args_param;
  std::vector<uint64_t> args_offset_vals;
  if (is_input) {
    std::vector<uint32_t> index_to_valid_idx;
    // Only the out-parameters are needed here; the returned value is not used.
    (void)PreModelUtils::GetInputDataAddrOffset(runtime_param, peer_op_desc, args_param, args_offset_vals,
                                                index_to_valid_idx);
    GE_ASSERT_TRUE(index < index_to_valid_idx.size(), "node[%s] input no.%u is out of range, valid index size[%zu].",
                   peer_node->GetName().c_str(), index, index_to_valid_idx.size());
    index = index_to_valid_idx[index];
  } else {
    (void)PreModelUtils::GetOutputDataAddrOffset(runtime_param, peer_op_desc, args_param, args_offset_vals);
  }
  if (!args_param.empty()) {
    GE_ASSERT_TRUE(index < args_param.size(), "node[%s] %s no.%u is out of range, args param size[%zu].",
                   peer_node->GetName().c_str(), is_input ? "input" : "output", index, args_param.size());
    args_ddr_type = args_param[index].para;
  }
  GELOGI("node[%s] %s no.%u, args_ddr_type[%lu]", peer_node->GetName().c_str(), is_input ? "input" : "output", index,
         args_ddr_type);
  if ((args_ddr_type != KERNEL_ARG_UPADTE_ADDR_TYPE_ARGS) && (args_ddr_type != KERNEL_ARG_UPADTE_ADDR_TYPE_WORKSPACE)) {
    return SUCCESS;
  }

  auto ori_offset_list = is_input ? peer_op_desc->GetInputOffset() : peer_op_desc->GetOutputOffset();
  GE_ASSERT_TRUE(index < ori_offset_list.size(), "node[%s] %s no.%u is out of range, offset list size[%zu].",
                 peer_node->GetName().c_str(), is_input ? "input" : "output", index, ori_offset_list.size());
  GELOGD("get origin offset[%ld] set offset from[%ld] to idx[%d]", ori_offset_list[index], ori_offset, offset_to_id);
  if (ori_offset == ori_offset_list[index]) {
    ori_offset_list[index] = offset_to_id;
    if (is_input) {
      peer_op_desc->SetInputOffset(ori_offset_list);
    } else {
      peer_op_desc->SetOutputOffset(ori_offset_list);
    }
    return SUCCESS;
  }
  GELOGW("node[%s] %s no.%u offset[%ld] should equal to %ld", peer_node->GetName().c_str(),
         is_input ? "input" : "output", index, ori_offset_list[index], ori_offset);
  return SUCCESS;
}
// Rewrites one input/output offset of `owner_node` to its zero-copy id and mirrors the change onto
// `op_desc` through SetOpOffset.
//
// owner_node  node connected to a data / NetOutput node.
// is_input    true: `index` addresses an input of owner_node; false: an output.
// index       anchor index; for inputs it is translated through the index -> valid-index table produced by
//             PreModelUtils::GetInputDataAddrOffset.
// op_desc     op desc of the data / NetOutput node on the other side of the edge.
//
// Nothing is changed (and SUCCESS is returned) when the argument address type is neither
// KERNEL_ARG_UPADTE_ADDR_TYPE_ARGS nor KERNEL_ARG_UPADTE_ADDR_TYPE_WORKSPACE. When the offset has no entry
// in the zero-copy table the offset list is written back unmodified. Out-of-range indices are reported as
// a failure instead of reading past the end of a vector.
Status TaskGenerator::SetAnchorsOffset(const NodePtr &owner_node, const bool is_input, const uint32_t index,
                                       const PreRuntimeParam &runtime_param, const OpDescPtr &op_desc) {
  GE_ASSERT_NOTNULL(owner_node);
  auto owner_node_op_desc = owner_node->GetOpDesc();
  GE_ASSERT_NOTNULL(owner_node_op_desc);

  std::vector<KernelArgsParam> args_param;
  std::vector<uint64_t> args_offset_vals;
  uint64_t args_ddr_type = invalidAddrType;
  uint32_t valid_idx = index;
  if (is_input) {
    std::vector<uint32_t> index_to_valid_idx;
    // Only the out-parameters are needed here; the returned value is not used.
    (void)PreModelUtils::GetInputDataAddrOffset(runtime_param, owner_node_op_desc, args_param, args_offset_vals,
                                                index_to_valid_idx);
    GE_ASSERT_TRUE(index < index_to_valid_idx.size(), "node[%s] input no.%u is out of range, valid index size[%zu].",
                   owner_node->GetName().c_str(), index, index_to_valid_idx.size());
    valid_idx = index_to_valid_idx[index];
  } else {
    (void)PreModelUtils::GetOutputDataAddrOffset(runtime_param, owner_node_op_desc, args_param, args_offset_vals);
  }
  if (!args_param.empty()) {
    GE_ASSERT_TRUE(valid_idx < args_param.size(), "node[%s] %s no.%u is out of range, args param size[%zu].",
                   owner_node->GetName().c_str(), is_input ? "input" : "output", valid_idx, args_param.size());
    args_ddr_type = args_param[valid_idx].para;
  }
  GELOGI("node[%s] %s no.%u valid_idx:%u, args_ddr_type[%lu]", owner_node->GetName().c_str(),
         is_input ? "input": "output", index, valid_idx, args_ddr_type);
  if ((args_ddr_type != KERNEL_ARG_UPADTE_ADDR_TYPE_ARGS) &&
      (args_ddr_type != KERNEL_ARG_UPADTE_ADDR_TYPE_WORKSPACE)) {
    return SUCCESS;
  }

  auto base_offset_list = is_input ? owner_node_op_desc->GetInputOffset() : owner_node_op_desc->GetOutputOffset();
  GE_ASSERT_TRUE(valid_idx < base_offset_list.size(), "node[%s] %s no.%u is out of range, offset list size[%zu].",
                 owner_node->GetName().c_str(), is_input ? "input" : "output", valid_idx, base_offset_list.size());
  const auto base_offset = base_offset_list[valid_idx];

  const auto &zero_copy_offset_to_ids = PreModelPartitionUtils::GetInstance().GetZeroCopyTable();
  const auto id_iter = zero_copy_offset_to_ids.find(base_offset);
  if (id_iter != zero_copy_offset_to_ids.end()) {
    const auto offset_to_id = id_iter->second;
    base_offset_list[valid_idx] = offset_to_id;
    GELOGD("base offset[%ld] set to idx[%d]", base_offset, offset_to_id);
    GE_CHK_STATUS_RET(SetOpOffset(op_desc, is_input, base_offset, offset_to_id));
  }

  if (is_input) {
    owner_node_op_desc->SetInputOffset(base_offset_list);
  } else {
    owner_node_op_desc->SetOutputOffset(base_offset_list);
  }
  return SUCCESS;
}
// Replaces every occurrence of `offset` in one of op_desc's offset lists with `offset_to_id`.
// Note the inversion: is_input == true edits the OUTPUT offsets of op_desc, is_input == false edits its
// INPUT offsets. The list is written back even when nothing matched.
Status TaskGenerator::SetOpOffset(const OpDescPtr &op_desc, const bool is_input, const int64_t offset,
                                  const uint32_t offset_to_id) {
  GE_ASSERT_NOTNULL(op_desc);

  auto offsets = is_input ? op_desc->GetOutputOffset() : op_desc->GetInputOffset();
  for (auto &entry : offsets) {
    if (entry == offset) {
      entry = offset_to_id;
    }
  }

  if (!is_input) {
    op_desc->SetInputOffset(offsets);
    return SUCCESS;
  }
  op_desc->SetOutputOffset(offsets);
  return SUCCESS;
}
// Registers the offsets of `op_desc` in the zero-copy table held by PreModelPartitionUtils.
//
// op_desc    op desc whose offsets are registered; its output offsets are used when is_input is true,
//            its input offsets otherwise.
// search_id  running slot counter shared across calls; each registered offset is mapped to
//            search_id * 8 and the counter is then advanced by one.
//
// In addition to each base offset, every (basic, relative) pair from ATTR_ZERO_COPY_BASIC_OFFSET /
// ATTR_ZERO_COPY_RELATIVE_OFFSET whose basic offset equals the base offset and whose relative offset is
// non-zero registers `base + relative` under its own id. Fails when the two attribute lists differ in size.
Status TaskGenerator::GenZeroCopyTable(const OpDescPtr &op_desc, uint32_t &search_id, const bool is_input) {
  GE_ASSERT_NOTNULL(op_desc);
  std::vector<int64_t> zero_copy_basic_offset;
  std::vector<int64_t> zero_copy_relative_offset;
  // Missing attributes simply leave the lists empty.
  (void)ge::AttrUtils::GetListInt(op_desc, ATTR_ZERO_COPY_BASIC_OFFSET, zero_copy_basic_offset);
  (void)ge::AttrUtils::GetListInt(op_desc, ATTR_ZERO_COPY_RELATIVE_OFFSET, zero_copy_relative_offset);
  GE_ASSERT_TRUE((zero_copy_basic_offset.size() == zero_copy_relative_offset.size()),
                 "[Check][Param] basic_offset_size:%zu should be equal to relative_offset_size:%zu",
                 zero_copy_basic_offset.size(), zero_copy_relative_offset.size());

  auto base_offset_list = is_input ? op_desc->GetOutputOffset() : op_desc->GetInputOffset();
  constexpr uint32_t offset_size = 8U;
  for (const int64_t base_offset : base_offset_list) {
    const uint32_t base_id = search_id * offset_size;
    PreModelPartitionUtils::GetInstance().SetZeroCopyTable(base_offset, base_id);
    GELOGI("zero_copy_offset_to_ids_ node name is %s, offset[%ld], ids[%u].",
           op_desc->GetName().c_str(), base_offset, base_id);
    search_id++;

    for (size_t position = 0U; position < zero_copy_basic_offset.size(); position++) {
      if ((base_offset == zero_copy_basic_offset[position]) && (zero_copy_relative_offset[position] != 0)) {
        const int64_t relative_offset = base_offset + zero_copy_relative_offset[position];
        // Capture the id before advancing the counter so the log reports the id actually stored.
        const uint32_t relative_id = search_id * offset_size;
        PreModelPartitionUtils::GetInstance().SetZeroCopyTable(relative_offset, relative_id);
        GELOGI("zero_copy_offset_to_ids_ offset[%ld], ids[%u].", relative_offset, relative_id);
        search_id++;
      }
    }
  }
  return SUCCESS;
}
// Entry point for task generation: validates `graph`, remembers the session id in session_id_ and runs
// GenerateTask for the graph/model pair.
Status TaskGenerator::GetTaskInfo(const ComputeGraphPtr &graph, uint64_t session_id, Model &model) {
  GE_ASSERT_NOTNULL(graph);

  // Stored on the generator before generation starts.
  session_id_ = session_id;
  GELOGD("Begin to gen task info with graph:%s session_id:%lu", graph->GetName().c_str(), session_id);

  GE_RUN_PERF(TaskGenerator, GenerateTask, graph, model);
  return SUCCESS;
}
// Generates tasks for nodes of `comp_graph` that need a task but are not yet tracked in nodes_.
//
// Such nodes are appended to nodes_. If there is at least one, tasks are (re)generated for them together
// with the nodes reported by MarkFirstAndLastOpsForGraph. The stream id of the resulting task defs is then
// refreshed for nodes without a valid attached stream, and UpdateTaskDef is run over all of nodes_.
// Returns SUCCESS without doing anything further when no new node is found.
Status TaskGenerator::ReGetTaskInfo(const ComputeGraphPtr &comp_graph) {
  GE_ASSERT_NOTNULL(comp_graph);
  std::unordered_set<Node *> first_last_nodes;
  GE_ASSERT_SUCCESS(MarkFirstAndLastOpsForGraph(comp_graph, first_last_nodes));

  const auto &all_nodes = comp_graph->GetNodes(comp_graph->GetGraphUnknownFlag(), nullptr, ffts_filter);
  // Hash set mirroring nodes_, so membership is O(1) instead of a linear scan of nodes_ per node.
  std::unordered_set<Node *> known_nodes(nodes_.begin(), nodes_.end());
  std::unordered_set<Node *> need_to_gen_task_nodes;
  for (const auto &node : all_nodes) {
    if (NoNeedGenTask(node->GetOpDesc())) {
      continue;
    }
    // insert().second is true only for nodes that were not tracked before.
    if (known_nodes.insert(node.get()).second) {
      need_to_gen_task_nodes.insert(node.get());
      nodes_.emplace_back(node.get());
    }
  }
  if (need_to_gen_task_nodes.empty()) {
    return SUCCESS;
  }

  need_to_gen_task_nodes.insert(first_last_nodes.begin(), first_last_nodes.end());
  const std::vector<Node *> second_gen_task_nodes(need_to_gen_task_nodes.begin(), need_to_gen_task_nodes.end());
  GE_ASSERT_SUCCESS(GenerateTaskForNodes(second_gen_task_nodes));

  for (const auto &node : need_to_gen_task_nodes) {
    auto stream_id = node->GetOpDesc()->GetStreamId();
    const auto iter = node_id_2_node_tasks_.find(node->GetOpDesc()->GetId());
    GE_ASSERT_TRUE(iter != node_id_2_node_tasks_.end(), "node: %s doesn't have taskdef", node->GetNamePtr());
    bool has_attached_stream = node->GetOpDesc()->HasValidAttachedStreamId();
    if (!has_attached_stream) {
      RefreshTaskDefStreamId(has_attached_stream, stream_id, stream_id, iter->second);
    }
  }
  GE_ASSERT_SUCCESS(UpdateTaskDef());
  return SUCCESS;
}
// Calls OpsKernelBuilderManager::UpdateTask for every node in nodes_ that has a non-empty task list in
// node_id_2_node_tasks_. Nodes without an entry, or with an empty one, are skipped.
Status TaskGenerator::UpdateTaskDef() {
  for (const auto &node : nodes_) {
    const auto tasks_iter = node_id_2_node_tasks_.find(node->GetOpDesc()->GetId());
    if (tasks_iter == node_id_2_node_tasks_.end()) {
      continue;
    }
    if (tasks_iter->second.empty()) {
      continue;
    }
    GE_ASSERT_SUCCESS(OpsKernelBuilderManager::Instance().UpdateTask(*node, tasks_iter->second),
                      "[Update][Task] fail for op:%s(%s)", node->GetName().c_str(), node->GetType().c_str());
  }
  return SUCCESS;
}
// Assembles the generated per-node task defs into a ModelTaskDef and attaches it to `model` through
// AddModelTaskToModel.
//
// Task order: when fusion_ordered_node_list_ is non-empty, tasks follow that list (op names come from
// fusion_task_node_name_list_); otherwise they follow the graph's node order, skipping nodes for which
// NoNeedGenTask is true. op_names_ receives one name per emitted task def.
Status TaskGenerator::GenModelTaskDef(const ComputeGraphPtr &graph, uint64_t session_id, Model &model) {
  ModelTaskDef model_task_def;
  model_task_def.set_memory_size(run_context_->dataMemSize);
  model_task_def.set_weight_size(run_context_->weightMemSize);

  std::vector<TaskDef> collected_tasks;
  collected_tasks.reserve(nodes_.size());

  const bool follow_fusion_order = !fusion_ordered_node_list_.empty();
  if (follow_fusion_order) {
    GE_ASSERT_TRUE(fusion_ordered_node_list_.size() == fusion_task_node_name_list_.size());
    for (size_t pos = 0U; pos < fusion_ordered_node_list_.size(); pos++) {
      const auto tasks_iter = node_id_2_node_tasks_.find(fusion_ordered_node_list_[pos]);
      GE_ASSERT_TRUE(tasks_iter != node_id_2_node_tasks_.end());
      auto &node_tasks = tasks_iter->second;
      op_names_.insert(op_names_.end(), node_tasks.size(), fusion_task_node_name_list_[pos]);
      collected_tasks.insert(collected_tasks.end(), node_tasks.begin(), node_tasks.end());
    }
  } else {
    for (const auto &node : graph->GetNodes(graph->GetGraphUnknownFlag(), nullptr, ffts_filter)) {
      if (NoNeedGenTask(node->GetOpDesc())) {
        continue;
      }
      const auto tasks_iter = node_id_2_node_tasks_.find(node->GetOpDesc()->GetId());
      GE_ASSERT_TRUE(tasks_iter != node_id_2_node_tasks_.end(), "node %s does not gen task", node->GetNamePtr());
      auto &node_tasks = tasks_iter->second;
      op_names_.insert(op_names_.end(), node_tasks.size(), node->GetOpDesc()->GetName());
      collected_tasks.insert(collected_tasks.end(), node_tasks.begin(), node_tasks.end());
    }
  }

  GELOGI("Graph %s gen task success, task list:%zu, logic mem base:%p, logic weight base:%p, logic var base:%p",
         graph->GetName().c_str(), collected_tasks.size(), run_context_->dataMemBase, run_context_->weightMemBase,
         var_mem_base_);

  // Copy each collected task def into a freshly added slot of the model task def.
  for (const auto &source_task : collected_tasks) {
    auto *added_task = model_task_def.add_task();
    GE_ASSERT_NOTNULL(added_task);
    *added_task = source_task;
  }

  GE_RUN_PERF(TaskGenerator, AddModelTaskToModel, model_task_def, session_id, model, *run_context_);
  return SUCCESS;
}
// Locates the profiling task indices for `graph` and stores them in `profiling_point`.
// When nodes_ has not been populated yet, FilterCandidatesNodes is run on the graph first; the search
// itself is delegated to ProfilingTaskUtils built over nodes_.
Status TaskGenerator::FindProfilingNodeIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point) {
  const bool candidates_missing = nodes_.empty();
  if (candidates_missing) {
    GE_CHK_STATUS_RET(FilterCandidatesNodes(graph), "[Call][FilterCandidatesNodes] failed, graph:%s.",
                      graph->GetName().c_str());
  }

  ProfilingTaskUtils index_finder(nodes_);
  return index_finder.FindProfilingTaskIndex(graph, profiling_point);
}
}