* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "graph/ge_error_codes.h"
#include "exe_graph/runtime/storage_shape.h"
#include "exe_graph/runtime/tensor.h"
#include "register/node_converter_registry.h"
#include "register/kernel_registry.h"
#include "core/utils/tensor_utils.h"
#include "exe_graph/runtime/gert_tensor_data.h"
namespace gert {
namespace {
// Prefix prepended to a node's type to form the name of its stub kernel
// (the stub lowering function and the stub kernel registration both build "Lower_" + type).
const std::string kLowerPrefix = "Lower_";
// Name under which the stub node converter is registered; also the value written to the
// "_ge_attr_lowering_func" attribute of nodes that should be lowered by the stub.
// Both the pointee and the pointer are const so the name cannot be re-pointed at runtime.
const char *const kStubLowerFunc = "_stub_lower_func";
// Stub lowering function registered under kStubLowerFunc.
//
// The generated kernel is named kLowerPrefix + <node type>. Its inputs are every input
// shape holder followed by every input address holder and finally a constant holding
// the node's logical stream id.
// - A node without output data anchors yields a single void holder stored in order_holders.
// - Otherwise 2 * N data outputs are created for N output anchors: the first N become
//   out_shapes and the remaining N become out_addrs.
REGISTER_NODE_CONVERTER(kStubLowerFunc, [](const ge::NodePtr &node, const LowerInput &lower_input) {
  const std::string kernel_type = kLowerPrefix + node->GetType();
  const auto output_count = node->GetAllOutDataAnchorsSize();
  int64_t logic_stream_id = node->GetOpDesc()->GetStreamId();

  // Assemble kernel inputs: shapes first and then addresses with the stream id constant last.
  auto kernel_inputs = lower_input.input_shapes;
  kernel_inputs.insert(kernel_inputs.end(), lower_input.input_addrs.begin(), lower_input.input_addrs.end());
  kernel_inputs.emplace_back(bg::ValueHolder::CreateConst(&logic_stream_id, sizeof(logic_stream_id)));

  LowerResult result;
  if (output_count == 0U) {
    // Nothing to produce: only an ordering holder is reported back.
    result.order_holders.push_back(bg::ValueHolder::CreateVoid<bg::ValueHolder>(kernel_type.c_str(), kernel_inputs));
    return result;
  }

  auto outputs = bg::DevMemValueHolder::CreateDataOutput(kernel_type.c_str(), kernel_inputs,
                                                         output_count << 1U,
                                                         logic_stream_id);
  // The first half of the outputs are shapes and the second half are addresses.
  const auto addrs_begin = outputs.begin() + output_count;
  result.out_shapes.insert(result.out_shapes.end(), outputs.begin(), addrs_begin);
  result.out_addrs.insert(result.out_addrs.end(), addrs_begin, outputs.end());
  return result;
});
}
// Marks `node` to be lowered by the stub converter by setting its
// "_ge_attr_lowering_func" attribute to kStubLowerFunc.
//
// The attribute is only written when the node does not already carry it and either no
// converter is registered for the node's type or `override` is true.
// Returns the result of setting the attribute, or false when nothing was written.
bool MockLessImportantNodeConverter(const ge::NodePtr &node, bool override) {
  auto op_desc = node->GetOpDesc();
  if (op_desc->HasAttr("_ge_attr_lowering_func")) {
    return false;
  }
  const bool has_registered_converter =
      NodeConverterRegistry::GetInstance().FindNodeConverter(node->GetType()) != nullptr;
  if (has_registered_converter && !override) {
    return false;
  }
  return ge::AttrUtils::SetStr(op_desc, "_ge_attr_lowering_func", kStubLowerFunc);
}
// Applies the per-node MockLessImportantNodeConverter to every node returned by
// graph->GetAllNodes(), skipping nodes whose type is in `exclude_ops` or whose name is
// in `exclude_nodes`. `override` is forwarded unchanged; per-node results are ignored.
void MockLessImportantNodeConverter(const ge::ComputeGraphPtr &graph, const std::set<std::string> &exclude_ops,
                                    const std::set<std::string> &exclude_nodes, bool override) {
  for (const auto &node : graph->GetAllNodes()) {
    const bool excluded =
        (exclude_ops.count(node->GetType()) != 0U) || (exclude_nodes.count(node->GetName()) != 0U);
    if (!excluded) {
      (void)MockLessImportantNodeConverter(node, override);
    }
  }
}
// Installs the stub converter on `node` (without override) and, if that succeeded,
// registers a stub kernel named kLowerPrefix + <node type>.
//
// The stub kernel treats its outputs as two halves: the first half holds StorageShape
// objects and the second half holds GertTensorData objects.
void MockLessImportantNodeKernel(const ge::NodePtr &node) {
  const bool stub_installed = MockLessImportantNodeConverter(node, false);
  if (!stub_installed) {
    return;
  }

  // Run function: points every address output at a shared, zero-initialised 1024-byte
  // buffer, tagged with kOnDeviceHbm and the stream id read from the last kernel input.
  const auto fill_stub_addrs = [](KernelContext *context) {
    static const char kBuffer[1024] = {};
    const auto last_input_idx = context->GetInputNum() - 1;
    auto stream_id = context->GetInputValue<int64_t>(last_input_idx);
    const size_t total_outputs = context->GetOutputNum();
    for (size_t idx = (total_outputs >> 1U); idx < total_outputs; ++idx) {
      auto tensor_data = context->GetOutputPointer<GertTensorData>(idx);
      *tensor_data = {const_cast<char *>(kBuffer), sizeof(kBuffer), kOnDeviceHbm, stream_id};
    }
    return ge::GRAPH_SUCCESS;
  };

  // Output creator: allocates a StorageShape for each slot in the first half of the
  // outputs and a GertTensorData for the matching slot in the second half.
  const auto create_stub_outputs = [](const ge::FastNode *, KernelContext *context) {
    const size_t pair_count = (context->GetOutputNum() >> 1U);
    for (size_t idx = 0U; idx < pair_count; ++idx) {
      context->GetOutput(idx)->SetWithDefaultDeleter(new StorageShape());
      context->GetOutput(idx + pair_count)->SetWithDefaultDeleter(new GertTensorData());
    }
    return ge::GRAPH_SUCCESS;
  };

  KernelRegistry::KernelFuncs funcs = {};
  funcs.run_func = fill_stub_addrs;
  funcs.outputs_creator = create_stub_outputs;
  KernelRegistry::GetInstance().RegisterKernel(kLowerPrefix + node->GetType(), {funcs, ""});
}
// Applies the per-node MockLessImportantNodeKernel to every node returned by
// graph->GetAllNodes(), skipping nodes whose type is in `exclude_ops` or whose name is
// in `exclude_nodes`.
void MockLessImportantNodeKernel(const ge::ComputeGraphPtr &graph, const std::set<std::string> &exclude_ops,
                                 const std::set<std::string> &exclude_nodes) {
  for (const auto &node : graph->GetAllNodes()) {
    const bool excluded =
        (exclude_ops.count(node->GetType()) != 0U) || (exclude_nodes.count(node->GetName()) != 0U);
    if (!excluded) {
      MockLessImportantNodeKernel(node);
    }
  }
}
}