* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file circular_pad_tiling.cpp
* \brief Host-side tiling for the CircularPad operator: validates the padding values and computes the tiling data.
*/
#include "circular_pad_tiling.h"
#include "log/log.h"
#include "register/op_def_registry.h"
#include "op_host/math_tiling_templates_registry.h"
namespace optiling {
// Alignment granularity in bytes: row widths are multiplied by the element size (tSize),
// rounded up to a multiple of ALIGN and converted back to an element count.
constexpr int32_t ALIGN = 32;
// Divisor applied to ubSize when computing how many elements may be processed at once
// (presumably the number of buffers sharing UB — TODO confirm against the kernel).
constexpr int64_t BUFFER_NUM = 2;
// Values stored in shapeType by DoOpTiling. DIM_1 is also used as the literal 1 in the
// round-up-to-ALIGN expressions.
constexpr int64_t DIM_1 = 1;
constexpr int64_t DIM_2 = 2;
// Values stored in dataType by CheckDtype, one per supported input dtype:
// 1 = DT_INT8, 2 = DT_FLOAT16, 3 = DT_BF16, 4 = DT_FLOAT, 5 = DT_INT32.
constexpr int64_t TYPE_MODE1 = 1;
constexpr int64_t TYPE_MODE2 = 2;
constexpr int64_t TYPE_MODE3 = 3;
constexpr int64_t TYPE_MODE4 = 4;
constexpr int64_t TYPE_MODE5 = 5;
/**
 * \brief Validate the W-axis padding pair (left, right) against inputW.
 *
 * Rules, by sign of the two pads:
 *  - both positive: neither pad may exceed inputW;
 *  - one positive, the other not positive: the non-positive pad must crop fewer than
 *    inputW elements and the positive pad may not exceed what is left after cropping;
 *  - none positive, at least one negative: the total crop must be smaller than inputW;
 *  - both zero: nothing to validate.
 *
 * A pad that is exactly 0 is handled by the "not positive" rules, so a single-sided
 * pad or crop is validated as well (the error messages already describe that case).
 *
 * \return ge::GRAPH_SUCCESS when the pair is acceptable, ge::GRAPH_FAILED (after logging) otherwise.
 */
ge::graphStatus CircularPadTiling::CheckLeftAndRight()
{
    const char* reason = nullptr;
    if (left > 0 && right > 0) {
        if (left > inputW || right > inputW) {
            reason = "left/right should not be greater than inputW,"
                     "when left/right is greater than 0.";
        }
    } else if (left > 0) {
        // Here right <= 0.
        if (left > inputW + right || -right >= inputW) {
            reason = "left should not be greater than inputW + right,"
                     "and abs(right) should not be greater than inputW,"
                     "when left is greater than 0 and right is not greater than 0.";
        }
    } else if (right > 0) {
        // Here left <= 0.
        if (-left >= inputW || right > inputW + left) {
            reason = "right should not be greater than inputW + left,"
                     "and abs(left) should not be greater than inputW,"
                     "when right is greater than 0 and left is not greater than 0.";
        }
    } else if ((left < 0 || right < 0) && (-left - right >= inputW)) {
        // No positive pad and at least one crop: the crop must leave something behind.
        reason = "abs(left + right) should not be greater than inputW.";
    }

    if (reason != nullptr) {
        OP_LOGE(context_->GetNodeName(), "%s", reason);
        return ge::GRAPH_FAILED;
    }
    return ge::GRAPH_SUCCESS;
}
/**
 * \brief Validate the H-axis padding pair (top, bottom) against inputH.
 *
 * Rules, by sign of the two pads:
 *  - both positive: neither pad may exceed inputH;
 *  - one positive, the other not positive: the non-positive pad must crop fewer than
 *    inputH elements and the positive pad may not exceed what is left after cropping;
 *  - none positive, at least one negative: the total crop must be smaller than inputH;
 *  - both zero: nothing to validate.
 *
 * A pad that is exactly 0 is handled by the "not positive" rules, so a single-sided
 * pad or crop is validated as well (the error messages already describe that case).
 *
 * \return ge::GRAPH_SUCCESS when the pair is acceptable, ge::GRAPH_FAILED (after logging) otherwise.
 */
ge::graphStatus CircularPadTiling::CheckTopAndBottom()
{
    const char* reason = nullptr;
    if (top > 0 && bottom > 0) {
        if (top > inputH || bottom > inputH) {
            reason = "top/bottom should not be greater than inputH,"
                     "when top/bottom is greater than 0.";
        }
    } else if (top > 0) {
        // Here bottom <= 0.
        if (top > inputH + bottom || -bottom >= inputH) {
            reason = "top should not be greater than inputH + bottom,"
                     "and abs(bottom) should not be greater than inputH,"
                     "when top is greater than 0 and bottom is not greater than 0.";
        }
    } else if (bottom > 0) {
        // Here top <= 0.
        if (-top >= inputH || bottom > inputH + top) {
            reason = "bottom should not be greater than inputH + top,"
                     "and abs(top) should not be greater than inputH,"
                     "when bottom is greater than 0 and top is not greater than 0.";
        }
    } else if ((top < 0 || bottom < 0) && (-top - bottom >= inputH)) {
        // No positive pad and at least one crop: the crop must leave something behind.
        reason = "abs(top + bottom) should not be greater than inputH.";
    }

    if (reason != nullptr) {
        OP_LOGE(context_->GetNodeName(), "%s", reason);
        return ge::GRAPH_FAILED;
    }
    return ge::GRAPH_SUCCESS;
}
/**
 * \brief Validate the L-axis padding pair (front, back) against inputL.
 *
 * Rules, by sign of the two pads:
 *  - both positive: neither pad may exceed inputL;
 *  - one positive, the other not positive: the non-positive pad must crop fewer than
 *    inputL elements and the positive pad may not exceed what is left after cropping;
 *  - none positive, at least one negative: the total crop must be smaller than inputL;
 *  - both zero: nothing to validate.
 *
 * A pad that is exactly 0 is handled by the "not positive" rules, so a single-sided
 * pad or crop is validated as well (the error messages already describe that case).
 *
 * \return ge::GRAPH_SUCCESS when the pair is acceptable, ge::GRAPH_FAILED (after logging) otherwise.
 */
ge::graphStatus CircularPadTiling::CheckFrontAndBack()
{
    const char* reason = nullptr;
    if (front > 0 && back > 0) {
        if (front > inputL || back > inputL) {
            reason = "front/back should not be greater than inputL,"
                     "when front/back is greater than 0.";
        }
    } else if (front > 0) {
        // Here back <= 0.
        if (front > inputL + back || -back >= inputL) {
            reason = "front should not be greater than inputL + back,"
                     "and abs(back) should not be greater than inputL,"
                     "when front is greater than 0 and back is not greater than 0.";
        }
    } else if (back > 0) {
        // Here front <= 0.
        if (-front >= inputL || back > inputL + front) {
            reason = "back should not be greater than inputL + front,"
                     "and abs(front) should not be greater than inputL,"
                     "when back is greater than 0 and front is not greater than 0.";
        }
    } else if ((front < 0 || back < 0) && (-front - back >= inputL)) {
        // No positive pad and at least one crop: the crop must leave something behind.
        reason = "abs(front + back) should not be greater than inputL.";
    }

    if (reason != nullptr) {
        OP_LOGE(context_->GetNodeName(), "%s", reason);
        return ge::GRAPH_FAILED;
    }
    return ge::GRAPH_SUCCESS;
}
/**
 * \brief Validate the input/output shapes and the padding values.
 *
 * Checks, in order:
 *  1. inputW must be smaller than ubSize / BUFFER_NUM / tSize (element count);
 *  2. every output extent must equal the input extent plus its two pads;
 *  3. the per-axis pad checks (W, then H, then L).
 *
 * tSize is used as a divisor, so it must already have been set to a non-zero value
 * (DoOpTiling calls CheckDtype before this function).
 *
 * \return ge::GRAPH_SUCCESS when all checks pass, otherwise the first failing status.
 */
ge::graphStatus CircularPadTiling::CheckInput()
{
    if (inputW >= static_cast<int64_t>(ubSize) / BUFFER_NUM / tSize) {
        // The original message said "valid" although this is the rejection path.
        OP_LOGE(context_->GetNodeName(), "The last DIM of x is invalid: it is too large to fit in UB.");
        return ge::GRAPH_FAILED;
    }

    const bool shapeMismatch = (outputH != inputH + top + bottom) || (outputW != inputW + left + right) ||
                               (outputL != inputL + front + back);
    if (shapeMismatch) {
        OP_LOGE(context_->GetNodeName(), "output's shape error.");
        return ge::GRAPH_FAILED;
    }

    auto ret = CheckLeftAndRight();
    if (ret != ge::GRAPH_SUCCESS) {
        return ret;
    }
    ret = CheckTopAndBottom();
    if (ret != ge::GRAPH_SUCCESS) {
        return ret;
    }
    return CheckFrontAndBack();
}
/**
 * \brief Map the dtype of input 0 to the dataType code and element size (tSize).
 *
 * Supported dtypes: DT_INT8, DT_FLOAT16, DT_BF16, DT_FLOAT, DT_INT32.
 *
 * \return ge::GRAPH_SUCCESS when the dtype is supported; ge::GRAPH_FAILED when the input
 *         descriptor is unavailable or the dtype is not supported.
 */
ge::graphStatus CircularPadTiling::CheckDtype()
{
    auto xDesc = context_->GetInputDesc(0);
    if (xDesc == nullptr) {
        // Guard against dereferencing a missing descriptor.
        OP_LOGE(context_->GetNodeName(), "Get input desc of x failed.");
        return ge::GRAPH_FAILED;
    }

    switch (xDesc->GetDataType()) {
        case ge::DataType::DT_INT8:
            dataType = TYPE_MODE1;
            tSize = sizeof(int8_t);
            break;
        case ge::DataType::DT_FLOAT16:
            dataType = TYPE_MODE2;
            tSize = sizeof(uint16_t);
            break;
        case ge::DataType::DT_BF16:
            dataType = TYPE_MODE3;
            tSize = sizeof(uint16_t);
            break;
        case ge::DataType::DT_FLOAT:
            dataType = TYPE_MODE4;
            tSize = sizeof(float);
            break;
        case ge::DataType::DT_INT32:
            dataType = TYPE_MODE5;
            tSize = sizeof(int32_t);
            break;
        default:
            OP_LOGE(context_->GetNodeName(), "Unsupport type.");
            return ge::GRAPH_FAILED;
    }
    return ge::GRAPH_SUCCESS;
}
/**
 * \brief Compute the tiling parameters for the op.
 *
 * Steps: resolve the dtype, derive the intermediate extents and ALIGN-byte aligned row
 * widths, validate the inputs, pick shapeType / workspaceLen depending on whether
 * inOutputH * outputWAlign is below ubSize / BUFFER_NUM / tSize, then split across
 * cores and publish the tiling key and data.
 *
 * \return ge::GRAPH_SUCCESS on success, otherwise the status of the failing check.
 */
ge::graphStatus CircularPadTiling::DoOpTiling()
{
    const auto dtypeStatus = CheckDtype();
    if (dtypeStatus != ge::GRAPH_SUCCESS) {
        return dtypeStatus;
    }

    CalculateParams();

    // Intermediate extents.
    inOutputH = inputH + nTop + nBottom;
    inOutputW = inputW + nLeft + nRight;
    inOutputL = inputL + nFront + nBack;

    // Row widths rounded up so that a row occupies a multiple of ALIGN bytes.
    outputWAlign = (outputW * tSize + ALIGN - DIM_1) / ALIGN * ALIGN / tSize;
    inOutputWAlign = (inOutputW * tSize + ALIGN - DIM_1) / ALIGN * ALIGN / tSize;

    const auto inputStatus = CheckInput();
    if (inputStatus != ge::GRAPH_SUCCESS) {
        return inputStatus;
    }

    const auto ubElemLimit = static_cast<int64_t>(ubSize) / BUFFER_NUM / tSize;
    const bool belowUbLimit = inOutputH * outputWAlign < ubElemLimit;
    if (!belowUbLimit) {
        shapeType = DIM_2;
        // A non-positive pad falls back to one ALIGN-byte block worth of elements.
        if (!(left > 0)) {
            leftAlign = ALIGN / tSize;
        }
        if (!(right > 0)) {
            rightAlign = ALIGN / tSize;
        }
        // An aligned width wider than the row is replaced by the raw pad value.
        if (leftAlign > inputW) {
            leftAlign = left;
        }
        if (rightAlign > inputW) {
            rightAlign = right;
        }
        workspaceLen = inOutputH * (leftAlign + rightAlign);
    } else {
        shapeType = DIM_1;
        workspaceLen = inOutputH * (leftAlign + inOutputWAlign + rightAlign);
    }

    DivCore();
    SetTilingKey();
    SetTilingData();
    return ge::GRAPH_SUCCESS;
}
/**
 * \brief Tiling entry point registered for the CircularPad op.
 *
 * Runs the tiling stages in order (shape/attr info, platform info, op tiling, workspace
 * size, post tiling) and stops at the first stage that does not return GRAPH_SUCCESS.
 * On success the tiling key is written to the context and the tiling info is dumped.
 *
 * \param context tiling context supplied by the framework.
 * \return ge::GRAPH_SUCCESS on success, otherwise the status of the first failing stage.
 */
static ge::graphStatus Tiling4CircularPad(gert::TilingContext* context)
{
    CircularPadTiling tiler(context);

    // Each stage only runs while every previous one has succeeded.
    auto status = tiler.GetShapeAttrsInfo();
    if (status == ge::GRAPH_SUCCESS) {
        status = tiler.GetPlatformInfo();
    }
    if (status == ge::GRAPH_SUCCESS) {
        status = tiler.DoOpTiling();
    }
    if (status == ge::GRAPH_SUCCESS) {
        status = tiler.GetWorkspaceSize();
    }
    if (status == ge::GRAPH_SUCCESS) {
        status = tiler.PostTiling();
    }
    if (status != ge::GRAPH_SUCCESS) {
        return status;
    }

    context->SetTilingKey(tiler.GetTilingKey());
    tiler.DumpTilingInfo();
    return ge::GRAPH_SUCCESS;
}
// Tiling-parse entry point: fills the Tiling4CircularPadCommonCompileInfo structure with
// platform data (AIV core count, UB size, library API workspace size).
// Returns ge::GRAPH_SUCCESS on success and ge::GRAPH_FAILED when the core count or UB size
// is not positive; the null-check macros presumably also return a failure status when the
// compile info or platform info pointer is null (macro definition not visible here).
static ge::graphStatus TilingPrepare4CircularPad(gert::TilingParseContext* context)
{
OP_LOGD(context->GetNodeName(), "TilingPrepare4CircularPad enter.");
// Destination structure for the values gathered below.
auto compileInfo = context->GetCompiledInfo<Tiling4CircularPadCommonCompileInfo>();
OP_CHECK_NULL_WITH_CONTEXT(context, compileInfo);
auto platformInfo = context->GetPlatformInfo();
OP_CHECK_NULL_WITH_CONTEXT(context, platformInfo);
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);
// Number of AIV cores reported by the platform; must be positive.
compileInfo->coreNum = ascendcPlatform.GetCoreNumAiv();
OP_CHECK_IF(
(compileInfo->coreNum <= 0),
OP_LOGE(
context->GetNodeName(), "Get core num failed, core num: %u", static_cast<uint32_t>(compileInfo->coreNum)),
return ge::GRAPH_FAILED);
// UB memory size is written into compileInfo->ubSize; must be positive.
ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, compileInfo->ubSize);
// NOTE(review): the value is cast to uint32_t for logging only; the stored ubSize is not modified.
OP_CHECK_IF(
(compileInfo->ubSize <= 0),
OP_LOGE(context->GetNodeName(), "Get ub size failed, ub size: %u", static_cast<uint32_t>(compileInfo->ubSize)),
return ge::GRAPH_FAILED);
compileInfo->sysWorkspaceSize = ascendcPlatform.GetLibApiWorkSpaceSize();
OP_LOGD(context->GetNodeName(), "TilingPrepare4CircularPad exit.");
return ge::GRAPH_SUCCESS;
}
// Register the tiling and tiling-parse entry points for the CircularPad op.
// Input index 1 is declared as a data dependency of tiling (presumably the paddings
// tensor — TODO confirm against the op definition).
IMPL_OP_OPTILING(CircularPad)
.Tiling(Tiling4CircularPad)
.TilingParse<Tiling4CircularPadCommonCompileInfo>(TilingPrepare4CircularPad)
.TilingInputsDataDependency({1});
}