* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "op_tiling_kernel.h"
#include "graph/ge_error_codes.h"
#include "graph/node.h"
#include "register/kernel_registry.h"
#include "exe_graph/runtime/continuous_vector.h"
#include "common/checker.h"
#include "tiling.h"
#include "exe_graph/runtime/gert_tensor_data.h"
#include "core/executor/multi_thread_topological/executor/schedule/producer/producers/kernel_tags/critical_section_config.h"
#include "exe_graph/runtime/gert_mem_allocator.h"
namespace gert {
namespace kernel {
// Outputs creator that fills each of the kOpTilingOutputSize output slots of `context` with a
// newly allocated GertTensorData constructed from (0U, kOnHost, -1, nullptr).
//
// node:    only used to print the node name when an allocation fails.
// context: kernel context whose outputs are populated through SetWithDefaultDeleter.
//
// Returns ge::GRAPH_SUCCESS once every slot has been filled and ge::GRAPH_FAILED when an
// allocation fails. A null output slot is handled by GE_ASSERT_NOTNULL (behaviour defined by
// that macro).
ge::graphStatus CreateTensorDataOutputs(const ge::FastNode *node, KernelContext *context) {
  uint32_t i = 0U;
  while (i < kOpTilingOutputSize) {
    auto av = context->GetOutput(i);
    GE_ASSERT_NOTNULL(av);
    // Non-throwing allocation: a failure is reported through the return status below.
    auto *placeholder = new (std::nothrow) GertTensorData(0U, kOnHost, -1, nullptr);
    if (placeholder == nullptr) {
      GELOGE(ge::FAILED, "Failed create addr outputs at index:%u for node %s", i, node->GetName().c_str());
      return ge::GRAPH_FAILED;
    }
    // The slot receives the object together with the default deleter.
    av->SetWithDefaultDeleter(placeholder);
    ++i;
  }
  return ge::GRAPH_SUCCESS;
}
// Outputs creator that fills each of the kOpTilingOutputSize output slots of `context` with a
// newly allocated StorageShape.
//
// For every output whose mapped index (kTilingOutputIndexToInputIndex[i]) is not
// TilingContext::kOutputTilingData, both the storage shape and the origin shape get their
// dimension count set to 0. The tiling-data output keeps the default-constructed shape.
//
// node:    only used to print the node name when an allocation fails.
// context: kernel context whose outputs are populated through SetWithDefaultDeleter.
//
// Returns ge::GRAPH_SUCCESS once every slot has been filled and ge::GRAPH_FAILED when an
// allocation fails.
ge::graphStatus CreateOpTilingShapeOutputs(const ge::FastNode *node, KernelContext *context) {
  for (uint32_t i = 0U; i < kOpTilingOutputSize; ++i) {
    auto av = context->GetOutput(i);
    GE_ASSERT_NOTNULL(av);

    auto *shape_holder = new (std::nothrow) StorageShape();
    if (shape_holder == nullptr) {
      GELOGE(ge::FAILED, "Failed create shape outputs at index:%u for node %s", i, node->GetName().c_str());
      return ge::GRAPH_FAILED;
    }

    const bool is_tiling_data = (kTilingOutputIndexToInputIndex[i] == TilingContext::kOutputTilingData);
    if (!is_tiling_data) {
      shape_holder->MutableStorageShape().SetDimNum(0UL);
      shape_holder->MutableOriginShape().SetDimNum(0UL);
    }

    // The slot receives the object together with the default deleter.
    av->SetWithDefaultDeleter(shape_holder);
  }
  return ge::GRAPH_SUCCESS;
}
// Rebuilds the GertTensorData stored in output `index` of `context` so that it refers to
// `input_addr`, while keeping the size and placement the output already reports and taking the
// stream id from the OpTilingExtendInputs::kStreamId input.
//
// input_addr: address the output tensor data should refer to; must not be null.
// index:      output slot to rewrite (the log text lists 0-data 1-key 2-blockdim 3-cond).
// context:    kernel context providing the stream-id input and the output slot.
//
// Returns ge::GRAPH_SUCCESS on success and ge::GRAPH_PARAM_INVALID when `input_addr` or the
// output pointer is null. A null stream-id input is handled by GE_ASSERT_NOTNULL.
ge::graphStatus BuildOpTilingTensorDataInner(const void *input_addr, const uint32_t index, KernelContext *context) {
  auto stream_id = context->GetInputPointer<int64_t>(static_cast<size_t>(OpTilingExtendInputs::kStreamId));
  GE_ASSERT_NOTNULL(stream_id);

  if (input_addr == nullptr) {
    GELOGE(ge::GRAPH_PARAM_INVALID,
           "Index:%u of input is a invalid nullptr of Optiling(0-data 1-key 2-blockdim 3-cond).", index);
    return ge::GRAPH_PARAM_INVALID;
  }

  auto *target = context->GetOutputPointer<GertTensorData>(index);
  if (target == nullptr) {
    GELOGE(ge::GRAPH_PARAM_INVALID,
           "Index:%u of output is a invalid nullptr of Optiling(0-data 1-key 2-blockdim 3-cond).", index);
    return ge::GRAPH_PARAM_INVALID;
  }

  // Capture the current size and placement first; they are carried over into the new value.
  const auto kept_size = target->GetSize();
  const auto kept_placement = target->GetPlacement();
  *target = GertTensorData{const_cast<void *>(input_addr), kept_size, kept_placement, *stream_id};
  return ge::GRAPH_SUCCESS;
}
// Kernel run function: points every one of the kOpTilingOutputSize tensor-data outputs at the
// corresponding kernel input, as selected by kTilingOutputIndexToInputIndex.
//
// The inputs mapped to TilingContext::kOutputTilingCond and TilingContext::kOutputBlockDim are
// addressed through GetInputPointer (the address of the input slot's value); every other input
// is read with GetInputValue<void *> (the slot's value is used as the address).
//
// Returns ge::GRAPH_PARAM_INVALID when the tiling-cond input value is negative, otherwise
// ge::GRAPH_SUCCESS. A failing BuildOpTilingTensorDataInner call is handled by GE_ASSERT_SUCCESS.
ge::graphStatus BuildOpTilingUnmanagedTensorData(KernelContext *context) {
  const int32_t cond_value = context->GetInputValue<int32_t>(TilingContext::kOutputTilingCond);
  if (cond_value < 0) {
    GELOGE(ge::GRAPH_PARAM_INVALID, "kOutputTilingCond value:%d is invalid which should be >= 0.", cond_value);
    return ge::GRAPH_PARAM_INVALID;
  }

  for (uint32_t i = 0U; i < kOpTilingOutputSize; ++i) {
    const auto mapped_index = kTilingOutputIndexToInputIndex[i];
    const bool use_slot_address = (mapped_index == TilingContext::kOutputTilingCond) ||
                                  (mapped_index == TilingContext::kOutputBlockDim);
    const void *input_addr = nullptr;
    if (use_slot_address) {
      input_addr = context->GetInputPointer<void *>(mapped_index);
    } else {
      input_addr = context->GetInputValue<void *>(mapped_index);
    }
    GE_ASSERT_SUCCESS(BuildOpTilingTensorDataInner(input_addr, i, context));
  }
  return ge::GRAPH_SUCCESS;
}
// Registers the kernel named BuildOpTilingUnmanagedTensorData: the function of the same name is
// its run function and CreateTensorDataOutputs is its outputs creator.
REGISTER_KERNEL(BuildOpTilingUnmanagedTensorData)
.RunFunc(BuildOpTilingUnmanagedTensorData)
.OutputsCreator(CreateTensorDataOutputs);
// Kernel run function: publishes the size of the tiling data as a one-dimensional shape.
//
// Reads the TilingData pointer from input TilingContext::kOutputTilingData, takes its
// GetDataSize() value and writes it as the single dimension of both the storage shape and the
// origin shape held in output kTilingDataIndexOfOpTiling.
//
// Returns ge::GRAPH_SUCCESS on success and ge::GRAPH_PARAM_INVALID when the tiling data pointer
// is null, the size does not fit into int64_t, or the output shape pointer is null. Every
// failure is logged with the same status code that is returned.
ge::graphStatus BuildOpTilingOutputShape(KernelContext *context) {
  auto tiling_data = context->GetInputValue<TilingData *>(TilingContext::kOutputTilingData);
  if (tiling_data == nullptr) {
    GELOGE(ge::GRAPH_PARAM_INVALID, "Get tilling data nullptr result for build output shape failed.");
    return ge::GRAPH_PARAM_INVALID;
  }

  const size_t data_size = tiling_data->GetDataSize();
  if (data_size > static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
    GELOGE(ge::GRAPH_PARAM_INVALID, "Data size:%zu is overflow to construct storage shape.", data_size);
    return ge::GRAPH_PARAM_INVALID;
  }

  auto storage_shape = context->GetOutputPointer<StorageShape>(kTilingDataIndexOfOpTiling);
  if (storage_shape == nullptr) {
    GELOGE(ge::GRAPH_PARAM_INVALID, "Shape of tiling data output is nullptr");
    return ge::GRAPH_PARAM_INVALID;
  }

  // Range-checked above, so the conversion to int64_t is lossless.
  const auto dim_value = static_cast<int64_t>(data_size);
  storage_shape->MutableStorageShape().SetDimNum(1UL);
  storage_shape->MutableOriginShape().SetDimNum(1UL);
  storage_shape->MutableStorageShape().SetDim(0UL, dim_value);
  storage_shape->MutableOriginShape().SetDim(0UL, dim_value);
  return ge::GRAPH_SUCCESS;
}
// Registers the kernel named BuildOpTilingOutputShape: the function of the same name is its run
// function and CreateOpTilingShapeOutputs is its outputs creator.
REGISTER_KERNEL(BuildOpTilingOutputShape)
.RunFunc(BuildOpTilingOutputShape)
.OutputsCreator(CreateOpTilingShapeOutputs);
}
}