* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "tiling.h"
#include "ge/ge_error_codes.h"
#include "register/kernel_registry.h"
#include "base/registry/op_impl_space_registry_v2.h"
#include "graph/utils/math_util.h"
#include "graph/args_format_desc.h"
#include "common/checker.h"
#include "engine/node_converter_utils.h"
#include "exe_graph/lowering/shape_utils.h"
#include "common/op_tiling/tiling_dfx.h"
#include "adump_pub.h"
#include "mmpa/mmpa_api.h"
#include "compatible_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "register/op_tiling/op_compile_info_manager.h"
#include "register/op_tiling_registry.h"
#include "register/op_tiling/op_tiling_utils.h"
namespace gert {
namespace kernel {
namespace {
// Signature of a legacy v4 tiling function: receives the operator, a CompileInfoPtr and the run info object;
// the callers in this file treat a `true` return value as success.
using TilingFuncV4 = bool (*)(const ge::Operator &, const optiling::CompileInfoPtr, optiling::OpRunInfoV2 &);
// Signature of a legacy v3 tiling function: same as v4 except that the compile info is passed as an untyped
// `const void *`.
using TilingFuncV3 = bool (*)(const ge::Operator &, const void *, optiling::OpRunInfoV2 &);
// Runs the legacy v4 tiling function stored in the kernel inputs.
//
// context      kernel context; the tiling function is read from input kTilingFwkData and the compile info
//              from input kCompileInfo.
// op           operator forwarded to the tiling function.
// op_run_info  run info object forwarded to the tiling function.
//
// Returns ge::GRAPH_SUCCESS when the tiling function returns true. Any failed GE_ASSERT_* check returns early
// with the failure value defined by those macros.
ge::graphStatus TilingV4Legacy(const KernelContext *const context, const ge::Operator &op,
                               optiling::OpRunInfoV2 &op_run_info) {
  const auto func_v4 =
      context->GetInputValue<TilingFuncV4>(static_cast<size_t>(CompatibleTilingInputIndex::kTilingFwkData));
  GE_ASSERT_NOTNULL(func_v4);

  const auto compile_info_slot = context->GetInputPointer<optiling::CompileInfoBase *>(
      static_cast<size_t>(CompatibleTilingInputIndex::kCompileInfo));
  GE_ASSERT_NOTNULL(compile_info_slot);

  const auto desc = ge::OpDescUtils::GetOpDescFromOperator(op);
  GE_ASSERT_NOTNULL(desc);

  // The input slot is reinterpreted in place as a shared_ptr<CompileInfoBase>; no new shared_ptr is built from
  // the raw pointer. NOTE(review): this relies on how the slot was populated by the producer - confirm.
  const auto &shared_compile_info =
      reinterpret_cast<const std::shared_ptr<optiling::CompileInfoBase> &>(*compile_info_slot);
  const bool tiling_ok = func_v4(op, shared_compile_info, op_run_info);
  GE_ASSERT_TRUE(tiling_ok, "Fail to call op tiling function v4 of op[%s, %s].", desc->GetName().c_str(),
                 desc->GetType().c_str());

  GELOGI("Do optiling v4 succeed. op_name:%s, op_type:%s.", desc->GetNamePtr(), desc->GetTypePtr());
  return ge::GRAPH_SUCCESS;
}
// Runs the legacy v3 tiling function stored in the kernel inputs.
//
// context      kernel context; the tiling function is read from input kTilingFwkData and the compile info
//              (an untyped pointer) from input kCompileInfo.
// op           operator forwarded to the tiling function.
// op_run_info  run info object forwarded to the tiling function.
//
// Returns ge::GRAPH_SUCCESS when the tiling function returns true. Any failed GE_ASSERT_* check returns early
// with the failure value defined by those macros.
ge::graphStatus TilingV3Legacy(const KernelContext *const context, const ge::Operator &op,
                               optiling::OpRunInfoV2 &op_run_info) {
  const auto tiling_func_index = static_cast<size_t>(CompatibleTilingInputIndex::kTilingFwkData);
  const auto tiling_func = context->GetInputValue<TilingFuncV3>(tiling_func_index);
  GE_ASSERT_NOTNULL(tiling_func);

  const auto compile_info_index = static_cast<size_t>(CompatibleTilingInputIndex::kCompileInfo);
  const auto compile_info = context->GetInputValue<void *>(compile_info_index);
  GE_ASSERT_NOTNULL(compile_info);

  // Validate the op desc before invoking the tiling function (same order as TilingV4Legacy), so that an
  // operator without an op desc is rejected up front instead of after tiling has already run on it.
  const auto op_desc = ge::OpDescUtils::GetOpDescFromOperator(op);
  GE_ASSERT_NOTNULL(op_desc);

  const bool ret = tiling_func(op, compile_info, op_run_info);
  GE_ASSERT_TRUE(ret, "Fail to call op tiling function v3 of op[%s, %s].", op_desc->GetName().c_str(),
                 op_desc->GetType().c_str());

  GELOGI("Do optiling v3 succeed. op_name:%s, op_type:%s.", op_desc->GetNamePtr(), op_desc->GetTypePtr());
  return ge::GRAPH_SUCCESS;
}
}
// Runs the tiling function stored in the last kernel input and reports its status separately from the status
// of this wrapper.
//
// context             kernel context (also accessed as a TilingContext to reach the raw tiling data).
// tiling_func_result  out: status returned by the tiling function; only written once the function was called.
//
// Returns the status of AlignWorkspaceSizes when the tiling function succeeded, ge::GRAPH_SUCCESS when the
// tiling function failed (the failure is carried in tiling_func_result), or the GE_ASSERT_* failure value when
// a precondition check fails.
ge::graphStatus TilingProcLegacy(KernelContext *context, ge::graphStatus &tiling_func_result) {
  const size_t input_num = context->GetInputNum();
  // input_num is a size_t, so it is logged with %zu rather than a signed 64-bit specifier.
  GE_ASSERT_TRUE(input_num >= 1, "[Tiling] para check failed input_num %zu", input_num);

  // Reset the tiling data size before the tiling function runs, if a tiling data buffer is present.
  auto tiling_data = reinterpret_cast<TilingContext *>(context)->GetRawTilingData();
  if (tiling_data != nullptr) {
    tiling_data->SetDataSize(0);
  }

  // The tiling function is stored in the last input slot.
  const auto tiling_func = context->GetInputValue<KernelRegistry::KernelFunc>(input_num - 1);
  GE_ASSERT_NOTNULL(tiling_func);

  tiling_func_result = tiling_func(context);
  if (tiling_func_result == ge::GRAPH_SUCCESS) {
    return AlignWorkspaceSizes(context);
  }
  // A tiling failure is not an error of this wrapper; the caller inspects tiling_func_result.
  return ge::GRAPH_SUCCESS;
}
// Kernel entry point for legacy tiling: delegates to TilingProcLegacy and returns the status produced by the
// tiling function. If TilingProcLegacy itself fails, GE_ASSERT_SUCCESS returns early.
ge::graphStatus TilingLegacy(KernelContext *context) {
  // Initialized defensively so an indeterminate value can never be returned, even if a future change adds a
  // success path in TilingProcLegacy that does not assign the out-parameter.
  ge::graphStatus tiling_func_result = ge::GRAPH_FAILED;
  GE_ASSERT_SUCCESS(TilingProcLegacy(context, tiling_func_result));
  return tiling_func_result;
}
// Registers the kernel named TilingLegacy: TilingLegacy is its run function, BuildTilingOutputs creates its
// outputs, FillTilingInfo fills exception dump info and PrintTilingData is its trace printer.
REGISTER_KERNEL(TilingLegacy)
.RunFunc(TilingLegacy)
.OutputsCreator(BuildTilingOutputs)
.ExceptionDumpInfoFiller(FillTilingInfo)
.TracePrinter(PrintTilingData);
// Runs a legacy (v3 or v4) operator tiling function on behalf of the compatible tiling kernel.
//
// context             kernel context; must provide at least kTilingFuncInputNum inputs (operator, tiling
//                     version, tiling function and compile info are read from fixed input slots).
// tiling_func_result  out: status of the v3/v4 tiling call; only written when a supported version was found.
//
// Returns ge::GRAPH_SUCCESS once the tiling call has been attempted (its outcome is in tiling_func_result),
// ge::GRAPH_FAILED for too few inputs or an unsupported tiling version, or the failure value of the
// GE_ASSERT_* / GE_CHK_STATUS_RET_NOLOG macros when one of their checks fails.
ge::graphStatus InnerCompatibleTilingLegacy(KernelContext *context, ge::graphStatus &tiling_func_result) {
  const size_t input_num = context->GetInputNum();
  if (input_num < static_cast<size_t>(CompatibleTilingInputIndex::kTilingFuncInputNum)) {
    // input_num is a size_t, so it is logged with %zu.
    GELOGE(ge::GRAPH_PARAM_INVALID, "[Tiling] para check failed input_num %zu", input_num);
    return ge::GRAPH_FAILED;
  }

  // Reset the tiling data size before tiling runs, if a tiling data buffer is present.
  auto tiling_context = reinterpret_cast<TilingContext *>(context);
  auto tiling_data = tiling_context->GetRawTilingData();
  if (tiling_data != nullptr) {
    tiling_data->SetDataSize(0);
  }

  auto op = context->GetInputValue<ge::Operator *>(static_cast<size_t>(CompatibleTilingInputIndex::kOp));
  GE_ASSERT_NOTNULL(op);
  ge::OpDescPtr op_desc = ge::OpDescUtils::GetOpDescFromOperator(*op);
  GE_ASSERT_NOTNULL(op_desc);

  // Handed to ConvertRTTensorToGeTensor by the callback below.
  std::vector<ge::GeTensorPtr> tensor_holder;
  if (!op_desc->GetOpInferDepends().empty()) {
    // The context pointer is captured by value; tensor_holder must be captured by reference because it is a
    // local of this call.
    // NOTE(review): the callback is installed on *op and refers to this local, so it must not be invoked after
    // this function returns - confirm that it is only used by the tiling call below.
    auto callback = [tiling_context, &tensor_holder](const ge::ConstNodePtr &node, const size_t index,
                                                     ge::GeTensorPtr &tensor) {
      (void)node;
      // Operator inputs start after the fixed tiling inputs.
      const auto input_start_pos = static_cast<size_t>(CompatibleTilingInputIndex::kTilingFuncInputNum);
      auto shape_tensor = tiling_context->GetInputTensor(index + input_start_pos);
      return KernelCompatibleUtils::ConvertRTTensorToGeTensor(shape_tensor, tensor, tensor_holder);
    };
    ge::OpDescUtils::SetCallbackGetConstInputFuncToOperator(*op, callback);
  }

  GE_CHK_STATUS_RET_NOLOG(UpdateIOShapeToOp(context, *op));

  // Replace/Recovery bracket the tiling call; `indexes` is passed to both calls.
  std::vector<int32_t> indexes;
  optiling::ReplaceEmptyShapeOfTensorDesc(op_desc, indexes);

  const auto tiling_version_index = static_cast<size_t>(CompatibleTilingInputIndex::kTilingVersion);
  const auto tiling_version = context->GetInputValue<size_t>(tiling_version_index);
  optiling::OpRunInfoV2 op_run_info;
  if (tiling_version == static_cast<uint64_t>(TilingVersion::kV4)) {
    tiling_func_result = TilingV4Legacy(context, *op, op_run_info);
  } else if (tiling_version == static_cast<uint64_t>(TilingVersion::kV3)) {
    tiling_func_result = TilingV3Legacy(context, *op, op_run_info);
  } else {
    // Pair the Replace call above with its Recovery on this error path as well, so the op desc is not left
    // in the replaced state when the version is unsupported.
    optiling::RecoveryEmptyShapeOfTensorDesc(op_desc, indexes);
    GELOGE(ge::GRAPH_FAILED, "Failed to tiling parse, not support v3 or v4.");
    return ge::GRAPH_FAILED;
  }
  optiling::RecoveryEmptyShapeOfTensorDesc(op_desc, indexes);

  if (tiling_func_result == ge::GRAPH_SUCCESS) {
    UpdateTilingOutputsToContext(op_run_info, context);
  }
  // A tiling failure is not an error of this wrapper; the caller inspects tiling_func_result.
  return ge::GRAPH_SUCCESS;
}
// Kernel entry point for compatible legacy tiling: delegates to InnerCompatibleTilingLegacy and returns the
// status produced by the v3/v4 tiling call. If InnerCompatibleTilingLegacy itself fails, GE_ASSERT_SUCCESS
// returns early.
ge::graphStatus CompatibleTilingLegacy(KernelContext *context) {
  // Initialized defensively so an indeterminate value can never be returned, even if a future change adds a
  // success path in InnerCompatibleTilingLegacy that does not assign the out-parameter.
  ge::graphStatus tiling_func_result = ge::GRAPH_FAILED;
  GE_ASSERT_SUCCESS(InnerCompatibleTilingLegacy(context, tiling_func_result));
  return tiling_func_result;
}
// Registers the kernel named CompatibleTilingLegacy: CompatibleTilingLegacy is its run function,
// BuildTilingOutputs creates its outputs and FillTilingInfo fills exception dump info.
// NOTE(review): unlike the TilingLegacy registration, no TracePrinter is attached here - confirm this is
// intentional.
REGISTER_KERNEL(CompatibleTilingLegacy)
.RunFunc(CompatibleTilingLegacy)
.OutputsCreator(BuildTilingOutputs)
.ExceptionDumpInfoFiller(FillTilingInfo);
}
}