* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "program_generator.h"
#include "common/om2/codegen/ast/ast_nodes.h"
#include "common/om2/codegen/emitter/cpp_emitter.h"
#include "common/om2/codegen/file_code_generator/args_manager_file_code_generator.h"
#include "common/om2/codegen/file_code_generator/interface_file_code_generator.h"
#include "common/om2/codegen/file_code_generator/kernel_reg_file_code_generator.h"
#include "common/om2/codegen/file_code_generator/load_and_run_file_code_generator.h"
#include "common/om2/codegen/file_code_generator/resources_file_code_generator.h"
#include "common/om2/codegen/om2_codegen_utils.h"
#include "framework/common/debug/ge_log.h"
namespace ge {
namespace {
// Renders one generated translation unit to C++ text with CppEmitter and stores the text in
// `code_printer` under `file_index`.
// Fails when `unit` is null or when the emitter does not report SUCCESS; in both cases the
// numeric value of `file_index` is included in the log message.
Status EmitFile(const GeneratedFileIndex file_index, const AstNode *unit, Om2CodePrinter &code_printer) {
  const size_t program_id = static_cast<size_t>(file_index);
  GE_ASSERT_TRUE(unit != nullptr, "[OM2] Program %zu AST is null.", program_id);

  CppEmitter cpp_emitter;
  std::string rendered_source;
  GE_ASSERT_SUCCESS(unit->Accept(cpp_emitter, rendered_source), "[OM2] Program %zu code generation failed.",
                    program_id);

  code_printer.AddContent(file_index, rendered_source);
  return SUCCESS;
}
}
// Generates every OM2 output file into `code_printer`, in this fixed order:
// kernel registry source, interface header, resources source, load-and-run source,
// args-manager source, and finally the Makefile.
// Each step is wrapped in GE_ASSERT_SUCCESS. NOTE(review): that macro presumably returns the
// failure to the caller on the first non-SUCCESS step, so later files are not generated — confirm
// against the macro definition.
Status ProgramGenerator::GenerateProgram(Om2CodePrinter &code_printer) {
GE_ASSERT_SUCCESS(GenerateKernelRegSource(code_printer));
GE_ASSERT_SUCCESS(GenerateInterfaceHeader(code_printer));
GE_ASSERT_SUCCESS(GenerateResourcesSource(code_printer));
GE_ASSERT_SUCCESS(GenerateLoadAndRunSource(code_printer));
GE_ASSERT_SUCCESS(GenerateArgsManagerSource(code_printer));
// The build script lists the generated sources by name, so it is produced last.
GE_ASSERT_SUCCESS(GenerateMakeFile(code_printer));
return SUCCESS;
}
/**
 * Generates the model interface header (GeneratedFileIndex::kInterfaceHeaderFile).
 *
 * The translation unit contains, in order:
 *   - standard-library, securec, ACL, exe_graph tensor and runtime includes;
 *   - the interface macro and pointer-helper stable parts;
 *   - namespace om2 with INPUT_NUM / OUTPUT_NUM constants (taken from codegen_model_.model_io),
 *     the helper structs, the scope-guard stable part, and the args-table and model classes built
 *     by InterfaceFileCodeGenerator;
 *   - an extern "C" block with the dump-API stable part followed by the external API declarations.
 *
 * @param code_printer Receives the emitted header text.
 * @return SUCCESS, or the failure reported by EmitFile.
 */
Status ProgramGenerator::GenerateInterfaceHeader(Om2CodePrinter &code_printer) {
  InterfaceFileCodeGenerator interface_handler(ast_);
  auto external_api_decls = interface_handler.BuildExternalApiDecls();
  // The dump APIs are placed ahead of the generated external API declarations.
  external_api_decls.insert(external_api_decls.begin(), ast_.StablePart(StablePartId::kInterfaceDumpApis));
  auto *translation_unit = ast_.File({
      ast_.Include("iostream", IncludeDecl::Kind::kAngle),
      ast_.Include("cstddef", IncludeDecl::Kind::kAngle),
      ast_.Include("ctime", IncludeDecl::Kind::kAngle),
      ast_.Include("chrono", IncludeDecl::Kind::kAngle),
      ast_.Include("fstream", IncludeDecl::Kind::kAngle),
      ast_.Include("iomanip", IncludeDecl::Kind::kAngle),
      ast_.Include("sstream", IncludeDecl::Kind::kAngle),
      ast_.Include("vector", IncludeDecl::Kind::kAngle),
      ast_.Include("map", IncludeDecl::Kind::kAngle),
      ast_.Include("unordered_map", IncludeDecl::Kind::kAngle),
      ast_.Include("functional", IncludeDecl::Kind::kAngle),
      ast_.Include("cstdint", IncludeDecl::Kind::kAngle),
      ast_.Include("type_traits", IncludeDecl::Kind::kAngle),
      ast_.Include("array", IncludeDecl::Kind::kAngle),
      ast_.Include("securec.h"),
      ast_.Include("acl/acl.h"),
      ast_.Include("acl/acl_base.h"),
      ast_.Include("exe_graph/runtime/tensor.h"),
      ast_.Include("rt.h"),
      ast_.Space(),
      ast_.StablePart(StablePartId::kInterfaceMacros),
      ast_.StablePart(StablePartId::kInterfacePointerHelpers),
      ast_.Namespace("om2", {
          ast_.Field("constexpr int32_t", "INPUT_NUM", static_cast<int>(codegen_model_.model_io.input_count)),
          ast_.Field("constexpr int32_t", "OUTPUT_NUM", static_cast<int>(codegen_model_.model_io.output_count)),
          interface_handler.BuildOm2ModelHandleAlias(),
          interface_handler.BuildBinDataInfoStruct(),
          interface_handler.BuildAicpuParamHeadStruct(),
          interface_handler.BuildAicpuSessionInfoStruct(),
          interface_handler.BuildArgsInfoStruct(),
          interface_handler.BuildTfAiCpuExInfoStruct(),
          ast_.StablePart(StablePartId::kScopeGuard, StablePartPlacement::kNamespace),
          interface_handler.BuildOm2ArgsTableClass(),
          interface_handler.BuildOm2ModelClass(codegen_model_),
      }),
      ast_.ExternBlock("C", external_api_decls),
  });
  GE_ASSERT_SUCCESS(EmitFile(GeneratedFileIndex::kInterfaceHeaderFile, translation_unit, code_printer));
  // Every other generator logs on success; this one previously did not.
  GELOGD("[OM2] Interface header file code is generated.");
  return SUCCESS;
}
/**
 * Generates the resources source file (GeneratedFileIndex::kResourcesFile).
 *
 * The file includes "<model_name>_interface.h" and defines, inside namespace om2, the
 * constructor, destructor, init-resources and release-resources methods produced by
 * ResourcesFileCodeGenerator. The label-list helper stable parts are appended only when
 * codegen_model_.runtime reports label-switch or label-goto usage.
 *
 * @param code_printer Receives the emitted source text.
 * @return SUCCESS, or the failure reported by EmitFile.
 */
Status ProgramGenerator::GenerateResourcesSource(Om2CodePrinter &code_printer) {
  ResourcesFileCodeGenerator resources_handler(ast_);
  std::vector<DeclNode *> resources_items = {
      resources_handler.BuildOm2ModelConstructor(codegen_model_),
      resources_handler.BuildOm2ModelDestructor(),
      resources_handler.BuildInitResourcesMethod(codegen_model_, task_code_builder_list_),
      resources_handler.BuildReleaseResourcesMethod(codegen_model_),
  };
  // Optional helpers: only emitted when the model actually contains the matching label tasks.
  if (codegen_model_.runtime.has_label_switch) {
    resources_items.push_back(ast_.StablePart(StablePartId::kCreateLabelListForLabelSwitch));
  }
  if (codegen_model_.runtime.has_label_goto) {
    resources_items.push_back(ast_.StablePart(StablePartId::kCreateLabelListForLabelGotoEx));
  }
  auto *translation_unit = ast_.File({
      ast_.Include(codegen_model_.model_name + "_interface.h"),
      ast_.Space(),
      ast_.Namespace("om2", resources_items),
  });
  GE_ASSERT_SUCCESS(EmitFile(GeneratedFileIndex::kResourcesFile, translation_unit, code_printer));
  // Fixed: this previously logged "Interface header file code is generated." for the resources file.
  GELOGD("[OM2] Resources source file code is generated.");
  return SUCCESS;
}
/**
 * Generates the args-manager source file (GeneratedFileIndex::kArgsManagerFile): an include of
 * "<model_name>_interface.h", a blank line, then namespace om2 holding the init, destructor,
 * args-info / dev-addr / host-addr getters, host-args update and copy-to-device methods built by
 * ArgsManagerFileCodeGenerator, in that order.
 *
 * @param code_printer Receives the emitted source text.
 * @return SUCCESS, or the failure reported by EmitFile.
 */
Status ProgramGenerator::GenerateArgsManagerSource(Om2CodePrinter &code_printer) {
  ArgsManagerFileCodeGenerator args_manager_handler(ast_);

  // Definitions placed inside namespace om2; vector order is output order.
  std::vector<DeclNode *> om2_members{
      args_manager_handler.BuildInitMethod(codegen_model_),
      args_manager_handler.BuildDestructor(),
      args_manager_handler.BuildGetArgsInfoMethod(),
      args_manager_handler.BuildGetDevArgAddrMethod(),
      args_manager_handler.BuildGetHostArgAddrMethod(),
      args_manager_handler.BuildUpdateHostArgsMethod(),
      args_manager_handler.BuildCopyArgsToDeviceMethod(),
  };

  auto *const translation_unit = ast_.File({
      ast_.Include(codegen_model_.model_name + "_interface.h"),
      ast_.Space(),
      ast_.Namespace("om2", om2_members),
  });
  GE_ASSERT_SUCCESS(EmitFile(GeneratedFileIndex::kArgsManagerFile, translation_unit, code_printer));
  GELOGD("[OM2] Args Manager source file code is generated.");
  return SUCCESS;
}
/**
 * Generates the kernel-registry source file (GeneratedFileIndex::kKernelRegistryFile).
 *
 * Layout: an include of "<model_name>_interface.h", then namespace om2 containing an anonymous
 * namespace (the kMaxJsonFileLen constant, register-info structs, file/JSON helper stable parts and
 * the per-kind register functions) followed by the kernel-registration entry built from
 * codegen_model_.
 *
 * @param code_printer Receives the emitted source text.
 * @return SUCCESS, or the failure reported by EmitFile.
 */
Status ProgramGenerator::GenerateKernelRegSource(Om2CodePrinter &code_printer) {
  KernelRegFileCodeGenerator kernel_reg_handler(ast_);

  // Items kept private to the generated translation unit (emitted in an unnamed namespace).
  std::vector<DeclNode *> file_local_decls{
      ast_.Field("constexpr uint32_t", "kMaxJsonFileLen", ast_.UInt(512)),
      kernel_reg_handler.BuildBinaryBufferStruct(),
      kernel_reg_handler.BuildAicoreRegisterInfoStruct(),
      kernel_reg_handler.BuildAicpuRegisterInfoStruct(),
      kernel_reg_handler.BuildCustAicpuRegisterInfoStruct(),
      ast_.StablePart(StablePartId::kReadBinaryFileToBuffer, StablePartPlacement::kNamespace),
      ast_.StablePart(StablePartId::kGenerateJsonFile, StablePartPlacement::kNamespace),
      kernel_reg_handler.BuildAssembleAicpuLoadOptions(),
      kernel_reg_handler.BuildRegisterAicoreKernel(),
      kernel_reg_handler.BuildRegisterAicpuKernel(),
      kernel_reg_handler.BuildRegisterCustAicpuKernel(),
  };

  std::vector<DeclNode *> om2_members{
      ast_.Namespace("", file_local_decls),
      kernel_reg_handler.BuildRegisterKernels(codegen_model_),
  };

  // Unlike the other sources, no blank line separates the include from namespace om2.
  auto *const translation_unit = ast_.File({
      ast_.Include(codegen_model_.model_name + "_interface.h"),
      ast_.Namespace("om2", om2_members),
  });
  GE_ASSERT_SUCCESS(EmitFile(GeneratedFileIndex::kKernelRegistryFile, translation_unit, code_printer));
  GELOGD("[OM2] Kernel Reg source file code is generated.");
  return SUCCESS;
}
/**
 * Generates the load-and-run source file (GeneratedFileIndex::kLoadingAndRunningFile).
 *
 * Layout: an include of "<model_name>_interface.h", a blank line, namespace om2 (an anonymous
 * namespace with the dump-helper stable part followed by the items from
 * BuildAnonymousNamespaceItems, then the get-rt-model-handle, load, run-async and run methods),
 * and finally the load-and-run external API stable part outside namespace om2.
 *
 * @param code_printer Receives the emitted source text.
 * @return SUCCESS, or the failure reported by EmitFile.
 */
Status ProgramGenerator::GenerateLoadAndRunSource(Om2CodePrinter &code_printer) {
  LoadAndRunFileCodeGenerator load_and_run_handler(ast_);

  auto file_local_decls = load_and_run_handler.BuildAnonymousNamespaceItems(codegen_model_, task_code_builder_list_);
  // The dump helpers are placed ahead of every generated file-local item.
  auto *const dump_helpers = ast_.StablePart(StablePartId::kLoadAndRunDumpHelpers, StablePartPlacement::kNamespace);
  file_local_decls.insert(file_local_decls.begin(), dump_helpers);

  std::vector<DeclNode *> om2_members{
      ast_.Namespace("", file_local_decls),
      load_and_run_handler.BuildGetRtModelHandleMethod(),
      load_and_run_handler.BuildLoadMethod(codegen_model_, task_code_builder_list_),
      load_and_run_handler.BuildRunAsyncMethod(codegen_model_),
      load_and_run_handler.BuildRunMethod(codegen_model_),
  };

  auto *const translation_unit = ast_.File({
      ast_.Include(codegen_model_.model_name + "_interface.h"),
      ast_.Space(),
      ast_.Namespace("om2", om2_members),
      ast_.StablePart(StablePartId::kLoadAndRunExternalApis),
  });
  GE_ASSERT_SUCCESS(EmitFile(GeneratedFileIndex::kLoadingAndRunningFile, translation_unit, code_printer));
  GELOGD("[OM2] Load and run source file code is generated.");
  return SUCCESS;
}
/**
 * Generates the GNU Makefile that compiles the generated OM2 sources into lib<model_name>_om2.so.
 *
 * Knobs exposed by the generated Makefile:
 *   - CANN_ROOT    : CANN install root, defaulting to $(ASCEND_HOME_PATH);
 *   - USE_STUB_LIB : when 1, link against $(CANN_ROOT)/devlib instead of $(CANN_ROOT)/lib64.
 *
 * @param code_printer Receives the Makefile text under GeneratedFileIndex::kCMakeListsFile
 *                     (despite the index name, the content is a Makefile, not CMake).
 * @return SUCCESS.
 */
Status ProgramGenerator::GenerateMakeFile(Om2CodePrinter &code_printer) {
  const std::string &model_name = codegen_model_.model_name;
  const std::string lib_name = model_name + "_om2";

  std::string makefile_content;
  makefile_content += "CANN_ROOT ?= $(ASCEND_HOME_PATH)\n";
  makefile_content += "USE_STUB_LIB ?= 0\n";
  makefile_content += "CXX := g++\n";
  makefile_content += "TARGET := lib" + lib_name + ".so\n";
  makefile_content += "SRC_FILES := " + model_name + "_resources.cpp " + model_name + "_kernel_reg.cpp " +
                      model_name + "_load_and_run.cpp " + model_name + "_args_manager.cpp\n";
  // make only recognises a recipe line when it begins with a literal TAB character. The TABs are
  // spelled "\t" below instead of being embedded invisibly in a raw string literal, where editors
  // and formatters silently turn them into spaces and make fails with "missing separator".
  makefile_content +=
      "CXXFLAGS := -std=c++17 -O2 -fPIC \\\n"
      "    -I$(CANN_ROOT)/include \\\n"
      "    -I$(CANN_ROOT)/pkg_inc \\\n"
      "    -I$(CANN_ROOT)/pkg_inc/runtime \\\n"
      "    -I$(CANN_ROOT)/pkg_inc/runtime/runtime \\\n"
      "    -I$(CANN_ROOT)/pkg_inc/profiling \\\n"
      "    -I$(CURDIR)/include\n"
      "ifeq ($(USE_STUB_LIB),1)\n"
      "    LIB_PATH := $(CANN_ROOT)/devlib\n"
      "else\n"
      "    LIB_PATH := $(CANN_ROOT)/lib64\n"
      "endif\n"
      "LDFLAGS := -shared -L$(LIB_PATH) -Wl,--no-as-needed -lacl_rt -Wl,--as-needed\n"
      ".PHONY: all clean\n"
      "all: $(TARGET)\n"
      "$(TARGET): $(SRC_FILES)\n"
      "\t@$(CXX) $(CXXFLAGS) -o $@ $^ $(LDFLAGS)\n"
      "clean:\n"
      "\t@rm -f $(TARGET)\n";

  code_printer.AddContent(GeneratedFileIndex::kCMakeListsFile, makefile_content);
  GELOGD("[OM2] Makefile code is generated.");
  return SUCCESS;
}
}