* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "exe_graph/lowering/tiling_context_builder.h"
#include "exe_graph/lowering/bg_kernel_context_extend.h"
#include "lowering/data_dependent_interpreter.h"
#include "graph/compute_graph.h"
#include "graph/operator.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils_ex.h"
#include "graph_metadef/graph/debug/ge_util.h"
#include "graph/def_types.h"
#include "common/checker.h"
#include "graph_metadef/graph/debug/ge_util.h"
namespace gert {
namespace {
void GetStorageShape(const ge::GeTensorDesc &tensor_desc, gert::StorageShape &storage_shape) {
const auto &storage_dims = tensor_desc.GetShape().GetDims();
for (const auto &dim : storage_dims) {
(void)storage_shape.MutableStorageShape().AppendDim(dim);
}
const auto &origin_dims = tensor_desc.GetOriginShape().GetDims();
for (const auto &dim : origin_dims) {
(void)storage_shape.MutableOriginShape().AppendDim(dim);
}
}
}
TilingContextBuilder &TilingContextBuilder::CompileInfo(void *compile_info) {
compile_info_ = compile_info;
return *this;
}
TilingContextBuilder &TilingContextBuilder::PlatformInfo(void *platform_info) {
platform_info_ = platform_info;
return *this;
}
TilingContextBuilder &TilingContextBuilder::Deterministic(int32_t deterministic) {
deterministic_ = deterministic;
return *this;
}
TilingContextBuilder &TilingContextBuilder::DeterministicLevel(int32_t deterministic_level) {
deterministic_level_ = deterministic_level;
return *this;
}
TilingContextBuilder &TilingContextBuilder::TilingData(void *tiling_data) {
outputs_[TilingContext::kOutputTilingData] = tiling_data;
return *this;
}
TilingContextBuilder &TilingContextBuilder::Workspace(ContinuousVector *workspace) {
outputs_[TilingContext::kOutputWorkspace] = workspace;
return *this;
}
TilingContextBuilder &TilingContextBuilder::SetSpaceRegistryV2(const OpImplSpaceRegistryV2Ptr &space_registry, OppImplVersionTag version_tag) {
if (version_tag >= OppImplVersionTag::kVersionEnd) {
GELOGE(ge::PARAM_INVALID, "version_tag %d is invalid", static_cast<int>(version_tag));
return *this;
}
space_registries_v2_[static_cast<size_t>(version_tag)] = space_registry;
return *this;
}
ge::graphStatus TilingContextBuilder::GetDependInputTensorAddr(const ge::Operator &op, const size_t input_idx,
TensorAddress &address) {
auto op_desc = ge::OpDescUtils::GetOpDescFromOperator(op);
GE_ASSERT_NOTNULL(op_desc);
auto depend_tensor = ge::ComGraphMakeUnique<ge::Tensor>();
depend_ge_tensor_holders_.emplace_back(std::move(depend_tensor));
GE_ASSERT_NOTNULL(depend_ge_tensor_holders_.back());
auto input_name = op_desc->GetValidInputNameByIndex(static_cast<uint32_t>(input_idx));
if (op.GetInputConstData(input_name.c_str(), *(depend_ge_tensor_holders_.back().get())) == ge::GRAPH_SUCCESS) {
address = depend_ge_tensor_holders_.back()->GetData();
} else {
address = nullptr;
}
return ge::GRAPH_SUCCESS;
}
ge::graphStatus TilingContextBuilder::BuildRtTensor(const ge::GeTensorDesc &tensor_desc,
TensorAddress address,
std::unique_ptr<uint8_t[]> &rt_tensor_holder) const {
gert::StorageShape storage_shape;
GetStorageShape(tensor_desc, storage_shape);
rt_tensor_holder = ge::ComGraphMakeUnique<uint8_t[]>(sizeof(gert::Tensor));
GE_ASSERT_NOTNULL(rt_tensor_holder, "Create context holder inputs failed.");
if (address == nullptr) {
new (rt_tensor_holder.get())
gert::Tensor(storage_shape,
{tensor_desc.GetOriginFormat(), tensor_desc.GetFormat(), {}},
tensor_desc.GetDataType());
} else {
new (rt_tensor_holder.get())
gert::Tensor(storage_shape,
{tensor_desc.GetOriginFormat(), tensor_desc.GetFormat(), {}},
gert::kOnHost, tensor_desc.GetDataType(), address);
}
return ge::GRAPH_SUCCESS;
}
ge::graphStatus TilingContextBuilder::BuildRTInputTensors(const ge::Operator &op) {
const auto node = ge::NodeUtilsEx::GetNodeFromOperator(op);
auto shared_node = const_cast<ge::Node *>(node.get())->shared_from_this();
std::shared_ptr<DataDependentInterpreter> ddi = nullptr;
ddi = ge::ComGraphMakeShared<DataDependentInterpreter>(shared_node->GetOpDesc(), space_registries_v2_);
GE_ASSERT_NOTNULL(ddi);
const auto op_desc = node->GetOpDescBarePtr();
GE_ASSERT_NOTNULL(op_desc);
size_t valid_input_idx = 0U;
const auto &all_in_data_anchors = node->GetAllInDataAnchorsPtr();
for (size_t i = 0U; i < all_in_data_anchors.size(); ++i) {
GE_ASSERT_NOTNULL(all_in_data_anchors.at(i));
if (all_in_data_anchors.at(i)->GetPeerOutAnchor() == nullptr) {
continue;
}
TensorAddress address = nullptr;
bool is_data_dependent = false;
GE_ASSERT_SUCCESS(ddi->IsDataDependent(static_cast<int32_t>(valid_input_idx), is_data_dependent));
bool is_tiling_dependent = false;
if (!is_data_dependent) {
GE_ASSERT_SUCCESS(ddi->IsTilingInputDataDependent(static_cast<int32_t>(valid_input_idx), is_tiling_dependent));
}
GELOGD("Node: %s input: %zu data/tiling depend flag: %d/%d", node->GetNamePtr(), valid_input_idx, is_data_dependent,
is_tiling_dependent);
is_data_dependent = is_data_dependent || is_tiling_dependent;
if (is_data_dependent) {
GE_ASSERT_GRAPH_SUCCESS(GetDependInputTensorAddr(op, valid_input_idx, address));
}
std::unique_ptr<uint8_t[]> tensor_holder;
const auto &valid_op_desc = op_desc->GetInputDescPtr(i);
GE_ASSERT_NOTNULL(valid_op_desc);
GE_ASSERT_GRAPH_SUCCESS(BuildRtTensor(*valid_op_desc, address, tensor_holder));
rt_tensor_holders_.emplace_back(std::move(tensor_holder));
++valid_input_idx;
}
return ge::GRAPH_SUCCESS;
}
ge::graphStatus TilingContextBuilder::BuildRTOutputShapes(const ge::Operator &op) {
auto op_desc = ge::OpDescUtils::GetOpDescFromOperator(op);
GE_ASSERT_NOTNULL(op_desc);
for (size_t i = 0U; i < op_desc->GetOutputsSize(); ++i) {
gert::StorageShape storage_shape;
GetStorageShape(op_desc->GetOutputDesc(i), storage_shape);
std::unique_ptr<uint8_t[]> tensor_holder;
GE_ASSERT_GRAPH_SUCCESS(BuildRtTensor(op_desc->GetOutputDesc(i), nullptr, tensor_holder));
GE_ASSERT_NOTNULL(tensor_holder, "Create context holder outputs failed, op[%s]", op_desc->GetNamePtr());
rt_tensor_holders_.emplace_back(std::move(tensor_holder));
}
return ge::GRAPH_SUCCESS;
}
KernelContextHolder TilingContextBuilder::Build(const ge::Operator &op, ge::graphStatus &ret) {
ret = ge::GRAPH_FAILED;
KernelContextHolder holder{};
if (compile_info_ == nullptr) {
GELOGE(ge::GRAPH_PARAM_INVALID, "Please give tiling context builder compile info.");
return holder;
}
if (platform_info_ == nullptr) {
GELOGE(ge::GRAPH_PARAM_INVALID, "Please give tiling context builder platform info.");
return holder;
}
auto node = ge::NodeUtilsEx::GetNodeFromOperator(op);
std::vector<void *> context_inputs;
auto build_ret = BuildRTInputTensors(op);
if (build_ret != ge::GRAPH_SUCCESS) {
GELOGE(ge::GRAPH_PARAM_INVALID, "Fail to BuildRTInputTensors.");
return holder;
}
build_ret = BuildRTOutputShapes(op);
if (build_ret != ge::GRAPH_SUCCESS) {
GELOGE(ge::GRAPH_PARAM_INVALID, "Fail to BuildRTOutputShapes.");
return holder;
}
for (const auto &input_holder : rt_tensor_holders_) {
context_inputs.emplace_back(input_holder.get());
}
context_inputs.emplace_back(compile_info_);
context_inputs.emplace_back(platform_info_);
context_inputs.emplace_back(nullptr);
context_inputs.emplace_back(reinterpret_cast<void *>(deterministic_));
context_inputs.emplace_back(reinterpret_cast<void *>(deterministic_level_));
return base_builder_.Inputs(context_inputs).Outputs(outputs_).Build(node->GetOpDesc(), ret);
}
KernelContextHolder TilingContextBuilder::Build(const ge::Operator &op) {
ge::Status ret = ge::GRAPH_FAILED;
return Build(op, ret);
}
AtomicTilingContextBuilder &AtomicTilingContextBuilder::CompileInfo(void *compile_info) {
compile_info_ = compile_info;
return *this;
}
AtomicTilingContextBuilder &AtomicTilingContextBuilder::CleanWorkspaceSizes(ContinuousVector *workspace_sizes) {
worksapce_sizes_ = reinterpret_cast<void *>(workspace_sizes);
return *this;
}
AtomicTilingContextBuilder &AtomicTilingContextBuilder::CleanOutputSizes(const std::vector<int64_t> &output_sizes) {
clean_output_sizes_ = output_sizes;
return *this;
}
AtomicTilingContextBuilder &AtomicTilingContextBuilder::TilingData(void *tiling_data) {
outputs_[TilingContext::kOutputTilingData] = tiling_data;
return *this;
}
AtomicTilingContextBuilder &AtomicTilingContextBuilder::Workspace(ContinuousVector *workspace) {
outputs_[TilingContext::kOutputWorkspace] = workspace;
return *this;
}
KernelContextHolder AtomicTilingContextBuilder::Build(const ge::Operator &op, ge::graphStatus &ret) {
ret = ge::GRAPH_FAILED;
KernelContextHolder holder{};
if (compile_info_ == nullptr) {
GELOGE(ge::GRAPH_PARAM_INVALID, "Please give tiling context builder compile info.");
return holder;
}
std::vector<void *> context_inputs;
context_inputs.emplace_back(worksapce_sizes_);
for (const int64_t out_size : clean_output_sizes_) {
context_inputs.emplace_back(reinterpret_cast<void *>(out_size));
}
context_inputs.emplace_back(compile_info_);
auto op_desc = ge::OpDescUtils::GetOpDescFromOperator(op);
return base_builder_.Inputs(context_inputs).Outputs(outputs_).Build(op_desc, ret);
}
KernelContextHolder AtomicTilingContextBuilder::Build(const ge::Operator &op) {
ge::graphStatus ret = ge::GRAPH_FAILED;
return Build(op, ret);
}
}