* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "custom_graph_optimizer.h"
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <unordered_map>
#include <vector>
#include "custom_op_factory.h"
#include "common/ge_common/ge_types.h"
#include "common/checker.h"
#include "lowering/kernel_run_context_builder.h"
#include "common/compile_profiling/ge_trace_wrapper.h"
#include "common/thread_pool/thread_pool.h"
#include "common/context/local_context.h"
#include "debug/ge_util.h"
#include "platform/platform_info.h"
#include "mmpa/mmpa_api.h"
namespace {
// Returns the number of worker threads used to compile custom ops.
//
// The value is read through MM_SYS_GET_ENV(MM_ENV_MAX_COMPILE_CORE_NUMBER, ...) and parsed as a base-10 integer.
//  - Unset or empty value: the default (16) is returned without any log.
//  - Parsed value that is not positive, or that does not fit into uint32_t: a warning is logged and the
//    default (16) is returned.
//  - Otherwise the parsed value is returned.
uint32_t GetThreadNum() {
  constexpr uint32_t kDefaultThreadNum = 16U;
  const char_t *value = nullptr;
  MM_SYS_GET_ENV(MM_ENV_MAX_COMPILE_CORE_NUMBER, value);
  if ((value == nullptr) || (value[0U] == '\0')) {
    return kDefaultThreadNum;
  }
  const int64_t thread_num = std::strtol(value, nullptr, 10);
  // Values above the uint32_t range must be rejected too: the narrowing cast below would otherwise wrap around
  // (for example 4294967296 would silently become 0 threads).
  if ((thread_num <= 0) || (thread_num > static_cast<int64_t>(UINT32_MAX))) {
    GELOGW("Get invalid MAX_COMPILE_CORE_NUMBER env value %s, use default thread number 16", value);
    return kDefaultThreadNum;
  }
  return static_cast<uint32_t>(thread_num);
}
// Fills `storage_shape` from `tensor_desc`:
//  - every dim of tensor_desc.GetShape() is appended to the storage shape;
//  - every dim of tensor_desc.GetOriginShape() is appended to the origin shape.
// Dims are appended, so anything already stored in `storage_shape` is kept in front of the new dims.
void GetStorageShape(const ge::GeTensorDesc &tensor_desc, gert::StorageShape &storage_shape) {
  // Appends all `dims` to `shape`, ignoring the result of AppendDim.
  const auto append_dims = [](const auto &dims, auto &shape) {
    for (const auto &dim : dims) {
      (void)shape.AppendDim(dim);
    }
  };
  append_dims(tensor_desc.GetShape().GetDims(), storage_shape.MutableStorageShape());
  append_dims(tensor_desc.GetOriginShape().GetDims(), storage_shape.MutableOriginShape());
}
// Returns the raw (non-owning) pointer of every holder, in the same order as `holders`.
// Ownership stays with `holders`; an empty holder yields a nullptr entry.
std::vector<void *> GetHoldersRawPtr(const std::vector<std::unique_ptr<uint8_t[]>> &holders) {
  std::vector<void *> raw_ptrs;
  raw_ptrs.reserve(holders.size());
  for (size_t idx = 0U; idx < holders.size(); ++idx) {
    raw_ptrs.push_back(static_cast<void *>(holders[idx].get()));
  }
  return raw_ptrs;
}
// Builds one gert::Tensor per valid input of `op_desc` and appends its holder to `inputs`.
//
// Inputs whose descriptor reports IsValid() != GRAPH_SUCCESS are skipped, so `inputs` may contain fewer entries
// than op_desc->GetAllInputsSize(). Each tensor is constructed with placement new inside a byte buffer of
// sizeof(gert::Tensor); because the holder is a plain uint8_t[] buffer, releasing it frees the memory without
// running the Tensor destructor.
//
// Returns ge::SUCCESS on success. The GE_ASSERT_* checks fail the call when `op_desc` is null, when a mutable
// input descriptor cannot be obtained, or when a holder cannot be allocated.
ge::Status ConstructCustomCompileContextInputs(const ge::OpDescPtr &op_desc,
                                               std::vector<std::unique_ptr<uint8_t[]>> &inputs) {
  // Guard against a node without an op desc instead of dereferencing a null pointer below.
  GE_ASSERT_NOTNULL(op_desc);
  for (size_t i = 0UL; i < op_desc->GetAllInputsSize(); ++i) {
    // Use one explicitly converted index for both descriptor lookups.
    const auto input_index = static_cast<uint32_t>(i);
    if (op_desc->GetInputDesc(input_index).IsValid() != ge::GRAPH_SUCCESS) {
      GELOGD("input desc is not valid, skip add input[%zu] into compile context inputs.", i);
      continue;
    }
    auto input_desc = op_desc->MutableInputDesc(input_index);
    GE_ASSERT_NOTNULL(input_desc);
    gert::StorageShape storage_shape;
    GetStorageShape(*input_desc, storage_shape);
    std::unique_ptr<uint8_t[]> tensor_holder = ge::ComGraphMakeUnique<uint8_t[]>(sizeof(gert::Tensor));
    GE_ASSERT_NOTNULL(tensor_holder, "Create compile context input tensor holder failed.");
    // Construct the tensor in place inside the byte buffer owned by tensor_holder.
    new (tensor_holder.get())
        gert::Tensor(storage_shape, {input_desc->GetOriginFormat(), input_desc->GetFormat(), {}},
                     input_desc->GetDataType());
    (void)inputs.emplace_back(std::move(tensor_holder));
  }
  return ge::SUCCESS;
}
// Builds one gert::Tensor per output descriptor of `op_desc` and appends its holder to `outputs`.
//
// Each tensor is constructed with placement new inside a byte buffer of sizeof(gert::Tensor); because the holder
// is a plain uint8_t[] buffer, releasing it frees the memory without running the Tensor destructor.
//
// Returns ge::SUCCESS on success. The GE_ASSERT_* checks fail the call when `op_desc` is null, when a mutable
// output descriptor cannot be obtained, or when a holder cannot be allocated.
ge::Status ConstructCustomCompileContextOutputs(const ge::OpDescPtr &op_desc,
                                                std::vector<std::unique_ptr<uint8_t[]>> &outputs) {
  // Guard against a node without an op desc instead of dereferencing a null pointer below.
  GE_ASSERT_NOTNULL(op_desc);
  for (size_t i = 0UL; i < op_desc->GetAllOutputsDescSize(); ++i) {
    auto output_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(i));
    GE_ASSERT_NOTNULL(output_desc);
    gert::StorageShape storage_shape;
    GetStorageShape(*output_desc, storage_shape);
    auto tensor_holder = ge::ComGraphMakeUnique<uint8_t[]>(sizeof(gert::Tensor));
    GE_ASSERT_NOTNULL(tensor_holder, "Create compile context output tensor holder failed.");
    // Construct the tensor in place inside the byte buffer owned by tensor_holder.
    new (tensor_holder.get())
        gert::Tensor(storage_shape, {output_desc->GetOriginFormat(), output_desc->GetFormat(), {}},
                     output_desc->GetDataType());
    (void)outputs.emplace_back(std::move(tensor_holder));
  }
  return ge::SUCCESS;
}
// Bundles everything needed to compile one custom-op node: the op to call, the buffers backing its input and
// output tensors, the compile context built over those buffers, and the node name/type used in logs.
// The type is move-only.
struct CompileTask {
  using TensorHolders = std::vector<std::unique_ptr<uint8_t[]>>;

  CompileTask(ge::CompilableOp *op_ptr, TensorHolders &&inputs_holder, TensorHolders &&outputs_holder,
              gert::KernelContextHolder &&holder, std::string name, std::string type)
      : compilable_op_ptr(op_ptr),
        op_compile_inputs_holder(std::move(inputs_holder)),
        op_compile_outputs_holder(std::move(outputs_holder)),
        op_compile_context_holder(std::move(holder)),
        op_name(std::move(name)),
        op_type(std::move(type)) {}

  // Copying is disabled; tasks may only be moved.
  CompileTask(const CompileTask &) = delete;
  CompileTask &operator=(const CompileTask &) = delete;
  CompileTask(CompileTask &&) = default;
  CompileTask &operator=(CompileTask &&) = default;

  // Raw pointer to the op whose Compile() is invoked; this struct never deletes it.
  ge::CompilableOp *compilable_op_ptr;
  // Byte buffers holding the input / output tensors. They are declared before the context holder, so the context
  // holder is destroyed first and the buffers last.
  TensorHolders op_compile_inputs_holder;
  TensorHolders op_compile_outputs_holder;
  // Holder of the kernel context handed to Compile().
  gert::KernelContextHolder op_compile_context_holder;
  // Node name and type, used for logging.
  std::string op_name;
  std::string op_type;
};
// Compiles a single task by calling Compile() on its op with the task's compile context.
// Returns ge::FAILED (after logging) when the task holds no kernel context; otherwise returns whatever
// Compile() returns. `task` must not be null.
ge::Status CompileCustomOp(CompileTask *task) {
  const char *const name = task->op_name.c_str();
  const char *const type = task->op_type.c_str();
  GELOGI("[Compile][CustomOp] call compile, op_name:%s, op_type:%s", name, type);

  auto *const kernel_context = task->op_compile_context_holder.GetKernelContext();
  auto *const compile_context = reinterpret_cast<gert::OpCompileContext *>(kernel_context);
  if (compile_context != nullptr) {
    return task->compilable_op_ptr->Compile(compile_context);
  }

  GELOGE(ge::FAILED, "[Compile][CustomOp] compile context is null, op_type:%s", type);
  return ge::FAILED;
}
// Compiles the given tasks one after another, in vector order.
// Stops at the first task whose compilation does not return ge::SUCCESS, logs it and returns that status;
// returns ge::SUCCESS when every task compiled. `tasks` must not be null.
ge::Status CompileCustomOpSerially(const std::vector<CompileTask *> *tasks) {
  const auto &task_list = *tasks;
  for (size_t idx = 0U; idx < task_list.size(); ++idx) {
    CompileTask *const current = task_list[idx];
    const ge::Status status = CompileCustomOp(current);
    if (status == ge::SUCCESS) {
      continue;
    }
    GELOGE(status, "[Compile][CustomOp] compile failed, op_name:%s, op_type:%s",
           current->op_name.c_str(), current->op_type.c_str());
    return status;
  }
  return ge::SUCCESS;
}
// Appends a CompileTask for `node` to `compile_tasks` when the node is a compilable custom op.
//
// Nothing is appended and ge::SUCCESS is returned when:
//  - the node type is not known to CustomOpFactory (IsExistOp returns false), or
//  - the custom op object cannot be cast to ge::CompilableOp.
// ge::FAILED is returned when CustomOpFactory::CreateOrGetCustomOp returns null. Failures while building the
// input / output tensor holders are propagated through GE_ASSERT_SUCCESS.
ge::Status AppendCompileTaskIfNeeded(const ge::NodePtr &node, std::vector<CompileTask> &compile_tasks) {
  const auto node_type = node->GetType();
  const ge::AscendString factory_key(node_type.c_str());
  const bool is_custom_op = ge::CustomOpFactory::IsExistOp(factory_key);
  if (!is_custom_op) {
    return ge::SUCCESS;
  }
  GELOGI("during optimize whole graph, %s is custom op", factory_key.GetString());

  auto *const custom_op = ge::CustomOpFactory::CreateOrGetCustomOp(factory_key);
  if (custom_op == nullptr) {
    GELOGE(ge::FAILED, "[Compile][CustomOp] create custom op failed, op_name:%s, op_type:%s",
           node->GetName().c_str(), node_type.c_str());
    return ge::FAILED;
  }

  // Only ops that implement the CompilableOp interface take part in compilation.
  auto *const compilable_op = dynamic_cast<ge::CompilableOp *>(custom_op);
  if (compilable_op == nullptr) {
    GELOGI("[Compile][CustomOp] custom op did not implement CompilableOp, op_name:%s, op_type:%s",
           node->GetName().c_str(), node_type.c_str());
    return ge::SUCCESS;
  }

  std::vector<std::unique_ptr<uint8_t[]>> input_holders;
  GE_ASSERT_SUCCESS(ConstructCustomCompileContextInputs(node->GetOpDesc(), input_holders));
  std::vector<std::unique_ptr<uint8_t[]>> output_holders;
  GE_ASSERT_SUCCESS(ConstructCustomCompileContextOutputs(node->GetOpDesc(), output_holders));

  // The context is built over the raw tensor pointers; the holders themselves are moved into the task below.
  auto context_holder = gert::KernelRunContextBuilder()
                            .Inputs(GetHoldersRawPtr(input_holders))
                            .Outputs(GetHoldersRawPtr(output_holders))
                            .Build(node->GetOpDesc());
  compile_tasks.emplace_back(compilable_op, std::move(input_holders), std::move(output_holders),
                             std::move(context_holder), node->GetName(), node_type);
  return ge::SUCCESS;
}
// Visits every node returned by graph.GetAllNodes() and appends a compile task for each compilable custom op.
// Stops at the first node for which AppendCompileTaskIfNeeded fails (via GE_ASSERT_SUCCESS); returns
// ge::SUCCESS otherwise.
ge::Status CollectCompileTasks(ge::ComputeGraph &graph, std::vector<CompileTask> &compile_tasks) {
  auto &&all_nodes = graph.GetAllNodes();
  for (const auto &graph_node : all_nodes) {
    GE_ASSERT_SUCCESS(AppendCompileTaskIfNeeded(graph_node, compile_tasks));
  }
  return ge::SUCCESS;
}
std::unordered_map<ge::CompilableOp *, std::vector<CompileTask *>> GroupCompileTasksByOp(
std::vector<CompileTask> &compile_tasks) {
std::unordered_map<ge::CompilableOp *, std::vector<CompileTask *>> compile_tasks_by_op;
compile_tasks_by_op.reserve(compile_tasks.size());
for (auto &task : compile_tasks) {
compile_tasks_by_op[task.compilable_op_ptr].push_back(&task);
}
return compile_tasks_by_op;
}
// Compiles all tasks using a thread pool: tasks that share the same op object are compiled serially inside one
// pool job, different ops are submitted as separate jobs.
//
// Returns ge::SUCCESS when `compile_tasks` is empty or every job succeeded. Returns ge::FAILED when a job could
// not be committed to the pool; otherwise returns the status of the first failing job (in submission order).
//
// Every committed job is waited for before returning, even after a failure: the jobs hold pointers into the
// local grouping map and into `compile_tasks`, so returning while a job is still running would leave it working
// on destroyed data.
ge::Status CompileCustomOpsInParallel(std::vector<CompileTask> &compile_tasks) {
  if (compile_tasks.empty()) {
    return ge::SUCCESS;
  }
  // Built before the pool so that it is destroyed after the pool (locals are destroyed in reverse order).
  auto compile_tasks_by_op = GroupCompileTasksByOp(compile_tasks);

  const auto actual_thread_num = GetThreadNum();
  GELOGI("Custom op compile thread num is %u", actual_thread_num);
  ge::ThreadPool executor("custom_opt_compile", actual_thread_num, true);
  std::vector<std::future<ge::Status>> vector_future;
  vector_future.reserve(compile_tasks_by_op.size());

  ge::Status result = ge::SUCCESS;
  for (auto &compile_task_group : compile_tasks_by_op) {
    std::future<ge::Status> f = executor.commit(CompileCustomOpSerially, &compile_task_group.second);
    if (!f.valid()) {
      GELOGE(ge::FAILED, "[Call][Commit] failed, Future is invalid");
      result = ge::FAILED;
      // Stop submitting, but still wait below for the jobs that were already committed.
      break;
    }
    vector_future.emplace_back(std::move(f));
  }

  // Drain every future; remember only the first error so the reported status matches submission order.
  for (auto &f : vector_future) {
    const ge::Status ret = f.get();
    if (ret != ge::SUCCESS) {
      GELOGE(ret, "[Compile][CustomOp] failed");
      if (result == ge::SUCCESS) {
        result = ret;
      }
    }
  }
  return result;
}
}
namespace ge {
// Destructor: defaulted, defined out of line in this translation unit.
CustomGraphOptimizer::~CustomGraphOptimizer() = default;
// Initializes the optimizer. Neither `options` nor `optimize_utility` is used; always returns SUCCESS.
Status CustomGraphOptimizer::Initialize(const std::map<std::string, std::string> &options,
                                        ge::OptimizeUtility *const optimize_utility) {
  static_cast<void>(options);
  static_cast<void>(optimize_utility);
  return ge::SUCCESS;
}
// Finalizes the optimizer. There is nothing to release; always returns SUCCESS.
ge::Status CustomGraphOptimizer::Finalize() {
  return ge::SUCCESS;
}
// No-op for the original-graph stage: `graph` is not touched and SUCCESS is always returned.
ge::Status CustomGraphOptimizer::OptimizeOriginalGraph(ge::ComputeGraph &graph) {
  static_cast<void>(graph);
  return ge::SUCCESS;
}
// No-op for the fused-graph stage: `graph` is not touched and SUCCESS is always returned.
ge::Status CustomGraphOptimizer::OptimizeFusedGraph(ge::ComputeGraph &graph) {
  static_cast<void>(graph);
  return ge::SUCCESS;
}
// Whole-graph stage: collects a compile task for every compilable custom op in `graph` and compiles them.
// A failure in either step is propagated through GE_ASSERT_SUCCESS; in that case the closing timestamp macro
// is not reached. Returns SUCCESS when both steps succeed.
ge::Status CustomGraphOptimizer::OptimizeWholeGraph(ge::ComputeGraph &graph) {
  GELOGI("entering optimize whole graph");
  GE_TIMESTAMP_START(CustomGraphOptimizer);

  // Step 1: gather the tasks; step 2: run them.
  std::vector<CompileTask> pending_tasks;
  GE_ASSERT_SUCCESS(CollectCompileTasks(graph, pending_tasks));
  GE_ASSERT_SUCCESS(CompileCustomOpsInParallel(pending_tasks));

  GE_TIMESTAMP_END(CustomGraphOptimizer, "CustomOptimizeWholeGraph");
  return ge::SUCCESS;
}
// Reports the optimizer attributes: only the engine name is set (to kEngineNameCustom). Always returns SUCCESS.
ge::Status CustomGraphOptimizer::GetAttributes(ge::GraphOptimizerAttribute &attrs) const {
  attrs.engineName = kEngineNameCustom;
  return ge::SUCCESS;
}
}