* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef AIR_CXX_BASE_COMMON_OM2_CODEGEN_HANDLER_TASK_CODE_BUILDER_FACTORY_H_
#define AIR_CXX_BASE_COMMON_OM2_CODEGEN_HANDLER_TASK_CODE_BUILDER_FACTORY_H_
#include <functional>
#include <map>
#include <memory>
#include <string>
#include "framework/common/debug/ge_log.h"
#include "common/plugin/ge_make_unique_util.h"
#include "common/opskernel/ops_kernel_info_types.h"
namespace ge {
class TaskCodeBuilder;
class AstBuildContext;
using TaskCodeBuilderPtr = std::shared_ptr<TaskCodeBuilder>;
class TaskCodeBuilderFactory {
public:
using TaskCodeBuilderCreatorFun = std::function<TaskCodeBuilderPtr(AstBuildContext &)>;
static TaskCodeBuilderFactory &Instance() {
if (g_user_defined_instance_ != nullptr) {
return *g_user_defined_instance_;
}
static TaskCodeBuilderFactory instance;
return instance;
}
static void Replace(std::shared_ptr<TaskCodeBuilderFactory> ins) {
g_user_defined_instance_ = std::move(ins);
}
TaskCodeBuilderPtr Create(const ModelTaskType builder_type, AstBuildContext &ast) {
const auto iter = creator_map_.find(builder_type);
if (iter == creator_map_.end()) {
GELOGW("Cannot find builder type %d in inner map.", static_cast<int32_t>(builder_type));
return nullptr;
}
return iter->second(ast);
}
class Registerar {
public:
Registerar(const ModelTaskType type, const TaskCodeBuilderCreatorFun &func) noexcept {
TaskCodeBuilderFactory::Instance().RegisterCreator(type, func);
}
~Registerar() = default;
};
private:
TaskCodeBuilderFactory() = default;
~TaskCodeBuilderFactory() = default;
void RegisterCreator(const ModelTaskType type, const TaskCodeBuilderCreatorFun &func) {
const auto iter = creator_map_.find(type);
if (iter != creator_map_.end()) {
GELOGD("TaskCodeBuilderFactory::RegisterCreator: %d creator already exist", static_cast<int32_t>(type));
return;
}
creator_map_[type] = func;
}
std::map<ModelTaskType, TaskCodeBuilderCreatorFun> creator_map_;
inline static std::shared_ptr<TaskCodeBuilderFactory> g_user_defined_instance_ = nullptr;
};
/*
 * REGISTER_TASK_CODE_BUILDER(type, clazz)
 *   Registers `clazz` as the TaskCodeBuilder for ModelTaskType::type.
 *
 *   Expands to an anonymous-namespace creator function that builds `clazz` from the AstBuildContext via
 *   MakeShared, plus a TaskCodeBuilderFactory::Registerar object whose constructor hands that function to
 *   TaskCodeBuilderFactory::Instance() during static initialization.
 *
 *   The expansion refers to TaskCodeBuilderPtr, AstBuildContext, TaskCodeBuilderFactory, ModelTaskType and
 *   MakeShared without qualification, so it must be expanded where those names are visible (e.g. inside
 *   namespace ge). Using the same `type` twice in one translation unit redefines the generated symbols.
 */
#define REGISTER_TASK_CODE_BUILDER(type, clazz)                                                        \
  namespace {                                                                                          \
  TaskCodeBuilderPtr Creator_Task_Code_Builder_##type(AstBuildContext &ast) {                          \
    return MakeShared<clazz>(ast);                                                                     \
  }                                                                                                    \
  TaskCodeBuilderFactory::Registerar g_Task_Code_Builder_Creator_##type(                               \
      ModelTaskType::type, &Creator_Task_Code_Builder_##type);                                         \
  }
}
#endif