* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "tiling_dfx.h"
#include "framework/common/debug/ge_log.h"
#include "graph/args_format_desc.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/attr_utils.h"
#include "graph/utils/math_util.h"
#include "common/checker.h"
#include "register/op_tiling/op_tiling_constants.h"
#include "exe_graph/runtime/tiling_data.h"
#include "graph/utils/op_desc_utils.h"
namespace optiling {
ge::Status TilingDfx::GetArgsSizeWithArgsFormat(const ge::OpDescPtr &op_desc,
const std::vector<ge::ArgDesc> &arg_descs,
std::vector<int64_t> &args_size_list,
std::vector<ArgsIndexToIoIndex> &args_index_to_io_index) {
std::map<size_t, std::pair<size_t, size_t>> ir_input_2_range;
(void)(ge::OpDescUtils::GetIrInputInstanceDescRange(op_desc, ir_input_2_range));
std::map<size_t, std::pair<size_t, size_t>> ir_output_2_range;
(void)(ge::OpDescUtils::GetIrOutputDescRange(op_desc, ir_output_2_range));
for (size_t idx = 0U; idx < arg_descs.size(); idx++) {
switch (arg_descs[idx].addr_type) {
case ge::AddrType::INPUT_DESC: {
const size_t ir_idx = static_cast<size_t>(arg_descs[idx].ir_idx);
const auto iter = ir_input_2_range.find(ir_idx);
GE_ASSERT(iter != ir_input_2_range.end(), "input ir idx [%zu] is not found", ir_idx);
size_t arg_size = 0UL;
GE_ASSERT_GRAPH_SUCCESS(ge::ArgsFormatDesc::GetArgSize(op_desc, arg_descs[idx], arg_size));
GE_ASSERT_TRUE(arg_size > 0UL);
(void)args_size_list.emplace_back(static_cast<int64_t>(arg_size));
break;
}
case ge::AddrType::OUTPUT_DESC: {
const size_t ir_idx = static_cast<size_t>(arg_descs[idx].ir_idx);
const auto iter = ir_output_2_range.find(ir_idx);
GE_ASSERT(iter != ir_output_2_range.end(), "output ir idx [%zu] is not found", ir_idx);
size_t arg_size = 0UL;
GE_ASSERT_GRAPH_SUCCESS(ge::ArgsFormatDesc::GetArgSize(op_desc, arg_descs[idx], arg_size));
GE_ASSERT_TRUE(arg_size > 0UL);
(void)args_size_list.emplace_back(static_cast<int64_t>(arg_size));
break;
}
case ge::AddrType::INPUT_INSTANCE: {
const size_t instance_index = static_cast<size_t>(arg_descs[idx].ir_idx);
ArgsIndexToIoIndex args_idx_to_io_idx = {ArgsRole::kInput, args_size_list.size(), instance_index};
(void)args_index_to_io_index.emplace_back(std::move(args_idx_to_io_idx));
(void)args_size_list.emplace_back(0);
break;
}
case ge::AddrType::OUTPUT_INSTANCE: {
const size_t instance_index = static_cast<size_t>(arg_descs[idx].ir_idx);
ArgsIndexToIoIndex args_idx_to_io_idx = {ArgsRole::kOutput, args_size_list.size(), instance_index};
(void)args_index_to_io_index.emplace_back(std::move(args_idx_to_io_idx));
(void)args_size_list.emplace_back(0);
break;
}
case ge::AddrType::INPUT: {
const size_t ir_idx = static_cast<size_t>(arg_descs[idx].ir_idx);
const auto iter = ir_input_2_range.find(ir_idx);
GE_ASSERT(iter != ir_input_2_range.end(), "input ir idx [%zu] is not found", ir_idx);
const auto &range_pair = iter->second;
if (range_pair.second == 0UL) {
(void)args_size_list.emplace_back(0);
break;
}
size_t begin_idx = range_pair.first;
size_t loop_times = range_pair.second;
while (loop_times-- > 0UL) {
ArgsIndexToIoIndex args_idx_to_io_idx = {ArgsRole::kInput, args_size_list.size(), begin_idx};
(void)args_index_to_io_index.emplace_back(std::move(args_idx_to_io_idx));
(void)args_size_list.emplace_back(0);
++begin_idx;
}
break;
}
case ge::AddrType::OUTPUT: {
const size_t ir_idx = static_cast<size_t>(arg_descs[idx].ir_idx);
const auto iter = ir_output_2_range.find(ir_idx);
GE_ASSERT(iter != ir_output_2_range.end(), "output ir idx [%zu] is not found", ir_idx);
const auto &range_pair = iter->second;
size_t begin_idx = range_pair.first;
size_t loop_times = range_pair.second;
while (loop_times-- > 0UL) {
ArgsIndexToIoIndex args_idx_to_io_idx = {ArgsRole::kOutput, args_size_list.size(), begin_idx};
(void)args_index_to_io_index.emplace_back(std::move(args_idx_to_io_idx));
(void)args_size_list.emplace_back(0);
++begin_idx;
}
break;
}
case ge::AddrType::FFTS_ADDR:
break;
case ge::AddrType::HIDDEN_INPUT:
case ge::AddrType::PLACEHOLDER:
(void)args_size_list.emplace_back(0);
break;
default:
return ge::SUCCESS;
}
}
return ge::SUCCESS;
}
ge::Status TilingDfx::GetArgsSizeWithoutArgsFormat(size_t input_size,
size_t output_size,
std::vector<int64_t> &args_size_list,
std::vector<ArgsIndexToIoIndex> &args_index_to_io_index) {
size_t io_size = 0U;
GE_ASSERT_TRUE(!ge::AddOverflow(input_size, output_size, io_size));
(void) args_size_list.insert(args_size_list.cend(), io_size, 0);
for (size_t index = 0U; index < input_size; index++) {
ArgsIndexToIoIndex args_idx_to_io_idx = {ArgsRole::kInput, index, index};
(void)args_index_to_io_index.emplace_back(std::move(args_idx_to_io_idx));
}
for (size_t index = 0U; index < output_size; index++) {
ArgsIndexToIoIndex args_idx_to_io_idx = {ArgsRole::kOutput, input_size + index, index};
(void)args_index_to_io_index.emplace_back(std::move(args_idx_to_io_idx));
}
return ge::SUCCESS;
}
}