* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "exe_graph/lowering/bg_kernel_context_extend.h"
#include "framework/common/debug/ge_log.h"
#include "common/checker.h"
#include "exe_graph/lowering/bg_ir_attrs.h"
#include "exe_graph/runtime/context_extend.h"
#include "graph/debug/ge_attr_define.h"
#include "graph_metadef/graph/debug/ge_util.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/math_util.h"
namespace gert {
namespace bg {
namespace {
// Fills the per-IR-input and per-IR-output instance ranges of `compute_node_info`.
//
// For every IR input (then every IR output) the instantiation count is read from the
// `.second` member of the corresponding entry of the IR-index -> instance-index pair
// map, and the instance start is the running sum of the counts of the preceding IR
// entries. An empty input map only skips the input part; an empty output map ends
// the function successfully.
//
// Returns ge::GRAPH_SUCCESS on success; the GE_ASSERT_* macros return early on a
// null op desc or a null instance-info slot.
ge::graphStatus InitIOInstanceInfo(const ge::NodePtr &node, ComputeNodeInfo &compute_node_info) {
  const auto op_desc = node->GetOpDescBarePtr();
  GE_ASSERT_NOTNULL(op_desc);

  // ---- IR inputs ----
  auto input_ranges = ge::OpDescUtils::GetInputIrIndexes2InstanceIndexesPairMap(node->GetOpDesc());
  if (!input_ranges.empty()) {
    const size_t ir_input_count = op_desc->GetIrInputs().size();
    size_t next_input_start = 0;
    for (size_t ir_idx = 0; ir_idx < ir_input_count; ++ir_idx) {
      const auto instance_info = compute_node_info.MutableInputInstanceInfo(ir_idx);
      GE_ASSERT_NOTNULL(instance_info);
      const size_t instance_count = input_ranges[ir_idx].second;
      instance_info->SetInstantiationNum(instance_count);
      instance_info->SetInstanceStart(next_input_start);
      next_input_start += instance_count;
    }
  } else {
    GELOGI("node [%s(%s)] ir_index_to_instance_index_pair_map is empty",
           node->GetNamePtr(), node->GetTypePtr());
  }

  // ---- IR outputs ----
  auto output_ranges = ge::OpDescUtils::GetOutputIrIndexes2InstanceIndexesPairMap(node->GetOpDesc());
  if (output_ranges.empty()) {
    GELOGI("node [%s(%s)] output ir_index_to_instance_index_pair_map is empty",
           node->GetNamePtr(), node->GetTypePtr());
    return ge::GRAPH_SUCCESS;
  }
  const size_t ir_output_count = op_desc->GetIrOutputs().size();
  size_t next_output_start = 0;
  for (size_t ir_idx = 0; ir_idx < ir_output_count; ++ir_idx) {
    const auto instance_info = compute_node_info.MutableOutputInstanceInfo(ir_idx);
    GE_ASSERT_NOTNULL(instance_info);
    const size_t instance_count = output_ranges[ir_idx].second;
    instance_info->SetInstantiationNum(instance_count);
    instance_info->SetInstanceStart(next_output_start);
    next_output_start += instance_count;
  }
  return ge::GRAPH_SUCCESS;
}
// Copies the compile-time properties of `desc` into `td`:
//  - data type, origin format and storage format are always copied;
//  - the expand-dims type is set only when the ATTR_NAME_RESHAPE_TYPE_MASK attribute is present;
//  - the tensor is marked as non-existent only when ATTR_NAME_IS_NULL_OUTPUT is present and true.
void SetCompileTimeTd(const ge::ConstGeTensorDescPtr &desc, CompileTimeTensorDesc &td) {
  td.SetDataType(desc->GetDataType());
  td.SetOriginFormat(desc->GetOriginFormat());
  td.SetStorageFormat(desc->GetFormat());

  int64_t expand_dims_mask = 0;
  const bool has_mask = ge::AttrUtils::GetInt(desc, ge::ATTR_NAME_RESHAPE_TYPE_MASK, expand_dims_mask);
  if (has_mask) {
    td.SetExpandDimsType(ExpandDimsType(expand_dims_mask));
  }

  bool null_output = false;
  if (ge::AttrUtils::GetBool(desc, ge::ATTR_NAME_IS_NULL_OUTPUT, null_output) && null_output) {
    GELOGI("op name: %s set exist false", desc->GetName().c_str());
    td.SetExist(false);
  }
}
// Builds a map from "ordinal among the connected input data anchors" to the anchor's
// own index (`GetIdx()`).
//
// In-data anchors without a peer out anchor are skipped and do not consume an ordinal.
// Entries are written into `connected_edge_indexes_to_anchor_index`; the map is not
// cleared first.
//
// Returns ge::GRAPH_SUCCESS; the GE_ASSERT_* macros return early when an anchor or the
// owner node of a peer anchor is null.
ge::graphStatus GetConnectedEdgeIndexesToAnchorIndexMap(const ge::NodePtr &node,
                                                        std::map<size_t, size_t> &connected_edge_indexes_to_anchor_index) {
  size_t connected_ordinal = 0U;
  for (const auto in_anchor : node->GetAllInDataAnchorsPtr()) {
    GE_ASSERT_NOTNULL(in_anchor);
    const auto peer_out_anchor = in_anchor->GetPeerOutAnchor();
    if (peer_out_anchor != nullptr) {
      GE_ASSERT_NOTNULL(peer_out_anchor->GetOwnerNodeBarePtr());
      connected_edge_indexes_to_anchor_index[connected_ordinal] = static_cast<size_t>(in_anchor->GetIdx());
      ++connected_ordinal;
    }
  }
  return ge::GRAPH_SUCCESS;
}
// Fills the compile-time tensor descriptions (inputs, then outputs) of `compute_node_info`
// from the op desc of `node`.
//
// Inputs: the i-th input TD of `compute_node_info` corresponds to the i-th *connected*
// input data anchor, so the input desc is looked up through the anchor index obtained
// from GetConnectedEdgeIndexesToAnchorIndexMap. Input descs that report themselves as
// invalid are skipped and their TD slot is left untouched.
// Outputs: one TD per out data anchor, taken from the output desc with the same index.
//
// Returns ge::GRAPH_SUCCESS on success, the status of the map construction if that
// fails; the GE_ASSERT_* macros return early on any failed check.
ge::graphStatus InitCompileTimeTD(const ge::NodePtr &node, ComputeNodeInfo &compute_node_info) {
  std::map<size_t, size_t> connected_edge_indexes_to_anchor_index;
  const auto ret = GetConnectedEdgeIndexesToAnchorIndexMap(node, connected_edge_indexes_to_anchor_index);
  if (ret != ge::GRAPH_SUCCESS) {
    GELOGE(ret, "get connected edge indexes to anchor index map failed. node:%s(%s)",
           node->GetName().c_str(), node->GetType().c_str());
    return ret;
  }
  const auto op_desc = node->GetOpDescBarePtr();
  GE_ASSERT_NOTNULL(op_desc);
  GE_ASSERT_TRUE(connected_edge_indexes_to_anchor_index.size() == compute_node_info.GetInputsNum());
  for (size_t i = 0; i < compute_node_info.GetInputsNum(); ++i) {
    // Resolve the anchor index once; it (not `i`) is the index used on the op desc,
    // so it is the value that has to be range-checked.
    const size_t anchor_index = connected_edge_indexes_to_anchor_index[i];
    GE_ASSERT_TRUE(anchor_index < op_desc->GetAllInputsSize());
    const auto &desc_need_check = op_desc->GetInputDesc(anchor_index);
    if (desc_need_check.IsValid() != ge::GRAPH_SUCCESS) {
      continue;
    }
    const auto &desc = op_desc->GetInputDescPtr(anchor_index);
    GE_ASSERT_NOTNULL(desc);
    auto td = compute_node_info.MutableInputTdInfo(i);
    GE_ASSERT_NOTNULL(td);
    SetCompileTimeTd(desc, *td);
  }
  for (size_t i = 0; i < node->GetAllOutDataAnchorsSize(); ++i) {
    const auto &desc = op_desc->GetOutputDescPtr(i);
    GE_ASSERT_NOTNULL(desc);
    auto td = compute_node_info.MutableOutputTdInfo(i);
    GE_ASSERT_NOTNULL(td);
    SetCompileTimeTd(desc, *td);
  }
  return ge::GRAPH_SUCCESS;
}
// Resolves every entry of `private_attrs` to a value and appends it to `runtime_attrs_list`
// in the same order.
//
// For each (name, default) pair the value stored on the node's op desc under `name` is
// used when present; otherwise the non-empty default from the pair is used. When neither
// is available an error is logged and false is returned; values appended before the
// failure stay in `runtime_attrs_list`.
//
// Returns true when all private attrs were resolved.
bool GetPrivateAttrsList(const ge::NodePtr &node, const gert::OpImplRegisterV2::PrivateAttrList &private_attrs,
                         std::vector<ge::AnyValue> &runtime_attrs_list) {
  const auto op = node->GetOpDescBarePtr();
  GE_ASSERT_NOTNULL(op);
  const auto &node_attrs = op->GetAllAttrs();
  for (const auto &private_attr : private_attrs) {
    const auto &attr_name = private_attr.first;
    const auto &default_value = private_attr.second;
    const auto found = node_attrs.find(attr_name.GetString());
    if (found != node_attrs.end()) {
      runtime_attrs_list.push_back(found->second);
      continue;
    }
    if (default_value.IsEmpty()) {
      GELOGE(ge::FAILED, "Cannot find the private attr %s from node %s",
             attr_name.GetString(), node->GetName().c_str());
      return false;
    }
    runtime_attrs_list.push_back(default_value);
  }
  return true;
}
// Allocates and fills a ComputeNodeInfo blob for `node`.
//
// Parameters:
//   attr_buf    - serialized runtime attrs to be copied into the blob.
//   attr_size   - number of bytes in `attr_buf`.
//   node        - node whose IR/instance counts and tensor descs are recorded.
//   buffer_pool - receives the node name and type strings; the values returned by AddStr
//                 are stored in the blob in place of the name/type pointers
//                 (NOTE(review): presumably pool ids resolved later - confirm with BufferPool).
//   total_size  - output: size computed by ComputeNodeInfo::CalcSize plus `attr_size`.
//
// Returns the owning buffer, or a null holder when any step fails (the GE_ASSERT_*
// macros return early on failure).
std::unique_ptr<uint8_t[]> CreateComputeNodeInfoImpl(const std::unique_ptr<uint8_t[]> &attr_buf,
                                                     const size_t attr_size,
                                                     const ge::NodePtr &node,
                                                     BufferPool &buffer_pool,
                                                     size_t &total_size) {
  const auto op_desc = node->GetOpDescBarePtr();
  GE_ASSERT_NOTNULL(op_desc);
  const size_t ir_input_num = op_desc->GetIrInputs().size();
  const size_t ir_output_num = op_desc->GetIrOutputs().size();
  const size_t input_num = node->GetInDataNodesAndAnchors().size();
  const uint32_t output_num = node->GetAllOutDataAnchorsSize();
  GELOGD("node: %s(%s), ir_input_num:%zu, ir_output_num:%zu, input_num:%zu, output_num:%u.", node->GetNamePtr(),
         node->GetTypePtr(), ir_input_num, ir_output_num, input_num, output_num);
  GE_ASSERT_SUCCESS(ComputeNodeInfo::CalcSize(ir_input_num, ir_output_num, input_num, output_num, total_size));
  GE_ASSERT_TRUE(!ge::AddOverflow(total_size, attr_size, total_size));
  auto compute_node_info_holder = ge::ComGraphMakeUnique<uint8_t[]>(total_size);
  GE_ASSERT_NOTNULL(compute_node_info_holder, "Create compute node info holder failed");
  auto node_name = buffer_pool.AddStr(node->GetNamePtr());
  auto node_type = buffer_pool.AddStr(node->GetTypePtr());
  auto compute_node_info = ge::PtrToPtr<uint8_t, ComputeNodeInfo>(compute_node_info_holder.get());
  compute_node_info->Init(ir_input_num, ir_output_num, input_num, output_num, attr_size,
                          ge::PtrToPtr<void, ge::char_t>(ge::ValueToPtr(node_name)),
                          ge::PtrToPtr<void, ge::char_t>(ge::ValueToPtr(node_type)));
  auto ret = InitIOInstanceInfo(node, *compute_node_info);
  GE_ASSERT_SUCCESS(ret, "Init input instance info for node:%s failed.", node->GetNamePtr());
  ret = InitCompileTimeTD(node, *compute_node_info);
  GE_ASSERT_SUCCESS(ret, "Init compile time tensor desc for node:%s failed.", node->GetNamePtr());

  auto attr = compute_node_info->MutableAttrs();
  // Byte offset of the attr region inside the blob. Kept as size_t so that it matches
  // the %zu conversions used in the log calls below.
  const auto offset =
      static_cast<size_t>(ge::PtrToPtr<RuntimeAttrs, uint8_t>(attr) - compute_node_info_holder.get());
  if (offset > total_size) {
    GELOGE(
        ge::FAILED,
        "Failed to create kernel context extend info, the offset of attr %zu beyond the total size of ExtendInfo %zu",
        offset, total_size);
    return nullptr;
  }
  // The destination capacity for the attrs excludes the bytes reserved for the IR output
  // instance infos (NOTE(review): this assumes they are laid out after the attrs - confirm
  // against the ComputeNodeInfo layout). Reject the case where that region does not fit,
  // otherwise the unsigned subtraction below would wrap to a huge capacity.
  const size_t outputs_ins_info_size = compute_node_info->GetIrOutputsNum() * sizeof(AnchorInstanceInfo);
  if (outputs_ins_info_size > (total_size - offset)) {
    GELOGE(ge::FAILED,
           "Failed to create kernel context extend info, the offset of attr %zu plus outputs instance info size %zu "
           "beyond the total size of ExtendInfo %zu",
           offset, outputs_ins_info_size, total_size);
    return nullptr;
  }
  const size_t attr_capacity = total_size - offset - outputs_ins_info_size;
  ret = ge::GeMemcpy(ge::PtrToPtr<RuntimeAttrs, uint8_t>(attr), attr_capacity, attr_buf.get(), attr_size);
  GE_ASSERT_SUCCESS(ret, "memcpy_s failed, copy size is %zu, dst size is %zu", attr_size, attr_capacity);
  GELOGI("Node %s, compute_node_info attr_size %zu, outputs_ins_info_size:%zu, offset:%zu, total_size:%zu.",
         node->GetNamePtr(), attr_size, outputs_ins_info_size, offset, total_size);
  return compute_node_info_holder;
}
}
// Creates the ComputeNodeInfo blob for `node` using the attr buffer produced by
// CreateAttrBuffer(node, size).
//
// `total_size` receives the size reported by CreateComputeNodeInfoImpl.
// Returns a null holder when the attr buffer cannot be created or the blob construction fails.
std::unique_ptr<uint8_t[]> CreateComputeNodeInfo(const ge::NodePtr &node, BufferPool &buffer_pool, size_t &total_size) {
  size_t attr_bytes = 0U;
  const auto attr_holder = CreateAttrBuffer(node, attr_bytes);
  GE_ASSERT_NOTNULL(attr_holder, "Create attr buffer for node: %s failed", node->GetNamePtr());
  return CreateComputeNodeInfoImpl(attr_holder, attr_bytes, node, buffer_pool, total_size);
}
// Creates the ComputeNodeInfo blob for `node`, passing the resolved private attrs to
// CreateAttrBuffer together with the node.
//
// The private attrs are resolved by GetPrivateAttrsList (node value first, then the
// default carried in `private_attrs`). `total_size` receives the size reported by
// CreateComputeNodeInfoImpl.
// Returns a null holder when a private attr cannot be resolved, the attr buffer cannot
// be created, or the blob construction fails.
std::unique_ptr<uint8_t[]> CreateComputeNodeInfo(const ge::NodePtr &node,
                                                 BufferPool &buffer_pool,
                                                 const gert::OpImplRegisterV2::PrivateAttrList &private_attrs,
                                                 size_t &total_size) {
  std::vector<ge::AnyValue> resolved_private_attrs;
  GE_ASSERT_TRUE(GetPrivateAttrsList(node, private_attrs, resolved_private_attrs));

  size_t attr_bytes = 0U;
  const auto attr_holder = CreateAttrBuffer(node, resolved_private_attrs, attr_bytes);
  GE_ASSERT_NOTNULL(attr_holder, "Create attr buffer for node: %s failed", node->GetNamePtr());
  return CreateComputeNodeInfoImpl(attr_holder, attr_bytes, node, buffer_pool, total_size);
}
// Same flow as the private-attr overload of CreateComputeNodeInfo, except that the attr
// buffer is produced by CreateAttrBufferWithoutIr.
//
// `total_size` receives the size reported by CreateComputeNodeInfoImpl.
// Returns a null holder when a private attr cannot be resolved, the attr buffer cannot
// be created, or the blob construction fails.
std::unique_ptr<uint8_t[]> CreateComputeNodeInfoWithoutIrAttr(const ge::NodePtr &node, BufferPool &buffer_pool,
    const gert::OpImplRegisterV2::PrivateAttrList &private_attrs, size_t &total_size) {
  std::vector<ge::AnyValue> resolved_private_attrs;
  GE_ASSERT_TRUE(GetPrivateAttrsList(node, private_attrs, resolved_private_attrs));

  size_t attr_bytes = 0U;
  const auto attr_holder = CreateAttrBufferWithoutIr(node, resolved_private_attrs, attr_bytes);
  GE_ASSERT_NOTNULL(attr_holder, "Create attr buffer without ir for node: %s failed", node->GetNamePtr());
  return CreateComputeNodeInfoImpl(attr_holder, attr_bytes, node, buffer_pool, total_size);
}
// Convenience overload for callers that do not need the blob size: forwards to
// CreateComputeNodeInfo(node, buffer_pool, total_size) and discards the reported size.
std::unique_ptr<uint8_t[]> CreateComputeNodeInfo(const ge::NodePtr &node, BufferPool &buffer_pool) {
  size_t discarded_total_size = 0U;
  return CreateComputeNodeInfo(node, buffer_pool, discarded_total_size);
}
}
}