* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef AIR_CXX_BASE_COMMON_OM2_CODEGEN_TASK_CODE_BUILDER_TASK_CODE_BUILDER_H_
#define AIR_CXX_BASE_COMMON_OM2_CODEGEN_TASK_CODE_BUILDER_TASK_CODE_BUILDER_H_
#include "common/om2/codegen/ast/ast_build_context.h"
#include "common/om2/codegen/ast/ast_nodes.h"
#include "common/om2/codegen/om2_model_class_generator_base.h"
#include "common/om2/codegen/om2_codegen_types.h"
#include "graph/op_desc.h"
#include "proto/task.pb.h"
namespace ge {
struct IoAddrRefreshRecord {
uint64_t compile_state_io_addr_offset{0U};
uint64_t host_offset{0U};
};
struct TaskSemanticContributeContext {
ModelTaskType task_type;
const domi::TaskDef &task_def;
int64_t op_index{kInvalidOpIndex};
OpDescPtr op_desc;
const RuntimeResourceSemantic *runtime{nullptr};
ModelIoSemantic *model_io{nullptr};
const std::unordered_map<std::string, uint32_t> *func_handle_indices{nullptr};
const std::unordered_map<int64_t, std::string> *weight_offset_to_varname{nullptr};
const std::unordered_map<int64_t, std::string> *fileconst_output_offset_to_varname{nullptr};
std::unordered_map<int64_t, OpInputEdges> *op_id_to_input_edges{nullptr};
std::unordered_map<uint32_t, uint32_t> *op_index_to_count_map{nullptr};
uint64_t *next_args_table_index{nullptr};
uint64_t *next_host_args_offset{nullptr};
uint32_t *aicpu_task_count{nullptr};
};
class TaskCodeBuilder : public Om2ModelClassGeneratorBase {
public:
explicit TaskCodeBuilder(AstBuildContext &ast) : Om2ModelClassGeneratorBase(ast) {}
~TaskCodeBuilder() override = default;
virtual int64_t ParseOpIndex(const domi::TaskDef &task_def) {
(void)task_def;
return -1L;
}
virtual Status Contribute(TaskSemanticContributeContext &context) {
FillTaskSemanticHeader(context, header_);
return SUCCESS;
}
virtual Status RenderInitResource(std::vector<BodyItem> &items) {
(void)items;
return SUCCESS;
}
virtual Status RenderDistribution(std::vector<BodyItem> &items) = 0;
virtual Status RenderDistHelper(std::vector<DeclNode *> &items) = 0;
const ArgsTableEntrySemantic *GetArgsTableEntry() const {
return args_table_entry_;
}
const std::vector<IoAddrRefreshRecord> &GetIoAddrRefreshRecords() const {
return io_addr_refresh_records_;
}
protected:
static void FillTaskSemanticHeader(const TaskSemanticContributeContext &context, TaskSemanticHeader &header) {
header.op_index = context.op_index;
header.stream_id = context.task_def.stream_id();
if (context.op_desc != nullptr) {
header.op_desc_id = context.op_desc->GetId();
header.op_name = context.op_desc->GetName();
header.op_type = context.op_desc->GetType();
}
}
static std::vector<Arg> ConvertToArgs(const std::vector<int64_t> &values) {
std::vector<Arg> args;
args.reserve(values.size());
for (const auto value : values) {
args.emplace_back(value);
}
return args;
}
TaskSemanticHeader header_;
const ArgsTableEntrySemantic *args_table_entry_{nullptr};
std::vector<IoAddrRefreshRecord> io_addr_refresh_records_;
};
}
#endif