* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef AIR_CXX_RUNTIME_V2_LOWERING_STATIC_MODEL_OUTPUT_ALLOCATOR_H_
#define AIR_CXX_RUNTIME_V2_LOWERING_STATIC_MODEL_OUTPUT_ALLOCATOR_H_
#include "common/ge_inner_attrs.h"
#include "common/model/ge_model.h"
#include "graph_builder/converter_checker.h"
#include "graph/load/model_manager/davinci_model.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/ge_types.h"
#include "kernel/known_subgraph/davinci_model_kernel.h"
namespace gert {
// Tags how the memory of one model output is obtained.
// NOTE(review): the per-enumerator remarks are inferred from the names only — confirm
// against the code that produces and consumes these values.
enum class OutputReuseType {
  kReuseInput,   // presumably: output reuses a model input
  kReuseOutput,  // presumably: output reuses another output
  kRefOutput,    // presumably: output refers to memory described by a base-type/offset pair
  kRefVariable,  // presumably: output refers to a variable
  kNoReuse,      // presumably: output gets its own memory
  kEnd           // last enumerator; looks like a sentinel — confirm
};
// Per-output record describing whether and how the output's memory is reused.
// NOTE(review): field meanings below are inferred from the names — confirm against the
// code that fills and reads this struct.
struct OutputReuseInfo {
  // NOTE(review): the three scalar members carry no default initializer, so a
  // default-initialized instance leaves them indeterminate until they are assigned.
  bool is_reuse;
  OutputReuseType reuse_type;
  int32_t reuse_index;  // presumably an index into the reused inputs/outputs — confirm

  kernel::MemoryBaseTypeOffset mem_base_type_offset;
  ge::ConstGeTensorDescPtr ge_tensor_desc_ptr;
  std::string var_name;  // presumably only meaningful for kRefVariable — confirm
};
// Bundle of state passed by reference to the parse steps (see ParseFunc users).
// NOTE(review): the role of each member is inferred from its name — confirm.
struct ParseParam {
  // Stores the operator description pointer and binds the two caller-provided containers.
  // Both containers are kept as references, so they must outlive this object.
  // Every other member starts from its in-class default.
  ParseParam(const ge::OpDesc *op_desc,
             std::vector<OutputReuseInfo> &output_reuse_infos,
             std::map<int64_t, int32_t> &offset_to_index_map)
      : op_desc(op_desc),
        output_reuse_infos(output_reuse_infos),
        offset_to_index_map(offset_to_index_map) {}

  // Non-owning pointer supplied at construction.
  const ge::OpDesc *op_desc;

  // Lookup tables keyed by a 64-bit address value; both start empty.
  std::map<int64_t, int32_t> data_address_2_index{};
  std::map<int64_t, ge::NodePtr> var_address_2_nodes{};

  // State describing the input currently being examined; defaults are null / zero.
  ge::NodePtr src_node{nullptr};
  int32_t input_index{0};
  int64_t input_address{0};
  ge::ConstGeTensorDescPtr ge_tensor_desc_ptr{nullptr};

  // References to caller-owned containers bound in the constructor.
  std::vector<OutputReuseInfo> &output_reuse_infos;
  std::map<int64_t, int32_t> &offset_to_index_map;
};
// Callable signature for a parse step: receives the mutable ParseParam and returns a ge::Status.
using ParseFunc = std::function<ge::Status(ParseParam &)>;
// Produces the output addresses of a static model during lowering, based on a list of
// OutputReuseInfo records (one per output).
// NOTE(review): this summary is derived from the declarations only; the definitions live
// outside this header.
class StaticModelOutputAllocator {
 public:
  // Builds an allocator from the davinci-model holder, the input address holders and the
  // update-workspaces holder.
  StaticModelOutputAllocator(const bg::ValueHolderPtr &davinci_model_holder,
                             const std::vector<bg::DevMemValueHolderPtr> &input_addrs,
                             const bg::ValueHolderPtr &update_workspaces_holder);
  // Overload without an update-workspaces holder.
  StaticModelOutputAllocator(const bg::ValueHolderPtr &davinci_model_holder,
                             const std::vector<bg::DevMemValueHolderPtr> &input_addrs);
  ~StaticModelOutputAllocator();

  // Returns the lowering result covering all outputs described by `output_reuse_infos`.
  // Does not modify this object (const), but may use/modify `global_data`.
  LowerResult AllocAllOutputs(const std::vector<OutputReuseInfo> &output_reuse_infos,
                              LoweringGlobalData &global_data) const;

  // Fills `output_reuse_infos` from `graph` and reports success/failure as a ge::Status.
  static ge::Status GenerateOutputsReuseInfos(const ge::ComputeGraphPtr &graph,
                                              std::vector<OutputReuseInfo> &output_reuse_infos);

 private:
  // Helpers that produce address holders for the individual reuse categories.
  std::vector<bg::DevMemValueHolderPtr> GetRefOutputsAddress(
      const std::vector<kernel::MemoryBaseTypeOffset> &mem_base_types_offsets) const;
  std::vector<bg::DevMemValueHolderPtr> AllocAllOutputsForRefOutputType(
      const std::vector<OutputReuseInfo> &output_reuse_infos) const;
  std::vector<bg::DevMemValueHolderPtr> AllocAllOutputsForRefVariableType(
      const std::vector<OutputReuseInfo> &output_reuse_infos, LoweringGlobalData &global_data) const;
  static std::vector<bg::DevMemValueHolderPtr> AllocAllOutputsForNoReuse(
      const std::vector<OutputReuseInfo> &output_reuse_infos, LoweringGlobalData &global_data);
  static std::vector<bg::ValueHolderPtr> GetNoReuseOutputsSize(const std::vector<OutputReuseInfo> &output_reuse_infos);

  // Parse steps operating on a shared ParseParam.
  // The parameter declarations previously read `ParseParam ¶m` (a mis-decoded `&param`),
  // which is not a valid C++ declarator; they are restored to a reference named `param`.
  static ge::Status ParseModelOutputReuseInfo(ParseParam &param);
  static ge::Status ParseReuseInputs(ParseParam &param);
  static ge::Status ParseRefOutputs(ParseParam &param);
  static ge::Status ParseRefVariable(ParseParam &param);
  // NOTE(review): returns ge::graphStatus while the sibling parse steps return ge::Status;
  // left unchanged here so the declaration keeps matching its out-of-view definition.
  static ge::graphStatus ParseReuseOutputs(ParseParam &param);

 private:
  const bg::ValueHolderPtr davinci_model_holder_;
  std::vector<bg::DevMemValueHolderPtr> input_addrs_;
  const bg::ValueHolderPtr update_workspaces_holder_;
};
}
#endif