* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file pad_tiling_impl.cpp
* \brief
*/
#include "include/adv_api/pad/pad_tiling.h"
#include "../../detail/host_log.h"
namespace optiling {
REGISTER_TILING_DATA_CLASS(UnPadTilingOpApi, UnPadTiling);
REGISTER_TILING_DATA_CLASS(PadTilingOpApi, PadTiling);
}
namespace AscendC {
namespace {
const uint32_t PAD_ONE_BLK_SIZE = 32;
const uint32_t PAD_MAX_REPEAT_TIMES = 255;
const uint32_t PAD_NCHW_CONV_ADDR_LIST_SIZE = 16;
const uint32_t TMP_BUFFER_NUM = 2;
constexpr uint32_t PAD_DIM_TWO = 2;
constexpr uint32_t PAD_TYPE_TWO = 2;
constexpr uint32_t PAD_TYPE_FOUR = 4;
void CheckPadMaxMinTmpSizeParams(const ge::Shape& srcShape, const uint32_t typeSize, const char* funcName)
{
ASCENDC_HOST_ASSERT(
typeSize == PAD_TYPE_TWO || typeSize == PAD_TYPE_FOUR, continue,
"[Pad][%s] The value of typeSize is %u, should be 2 or 4.", funcName, typeSize);
ASCENDC_HOST_ASSERT(
srcShape.GetDimNum() == PAD_DIM_TWO, continue, "[Pad][%s] The dims of srcShape is %zu, should be 2.", funcName,
srcShape.GetDimNum());
}
void CheckPadTilingFuncParams(
const ge::Shape srcShape, const ge::Shape oriSrcShape, const uint32_t stackBufferSize, const uint32_t typeSize,
const char* funcName)
{
(void)stackBufferSize;
CheckPadMaxMinTmpSizeParams(srcShape, typeSize, funcName);
ASCENDC_HOST_ASSERT(
oriSrcShape.GetDimNum() == PAD_DIM_TWO, continue, "[Pad][%s] The dims of oriSrcShape is %zu, should be 2.",
funcName, oriSrcShape.GetDimNum());
ASCENDC_HOST_ASSERT(
oriSrcShape.GetShapeSize() <= srcShape.GetShapeSize(), continue,
"[Pad][%s] The size of oriSrcShape is %li, should be less than or equal to srcShape size %li.", funcName,
oriSrcShape.GetShapeSize(), srcShape.GetShapeSize());
}
void CheckUnPadParams(const ge::Shape& srcShape, const uint32_t typeSize, const char* funcName)
{
ASCENDC_HOST_ASSERT(
typeSize == PAD_TYPE_TWO || typeSize == PAD_TYPE_FOUR, continue,
"[UnPad][%s] The value of typeSize is %u, should be 2 or 4.", funcName, typeSize);
ASCENDC_HOST_ASSERT(
srcShape.GetDimNum() == PAD_DIM_TWO, continue, "[UnPad][%s] The dims of srcShape is %zu, should be 2.",
funcName, srcShape.GetDimNum());
}
void CheckGetUnPadMaxMinTmpSizeParams(const ge::Shape& srcShape, const uint32_t typeSize, const char* funcName)
{
CheckUnPadParams(srcShape, typeSize, funcName);
}
}
void GetPadMaxMinTmpSize(const ge::Shape& srcShape, const uint32_t typeSize, uint32_t& maxValue, uint32_t& minValue)
{
CheckPadMaxMinTmpSizeParams(srcShape, typeSize, "GetPadMaxMinTmpSize");
ASCENDC_HOST_ASSERT(typeSize > 0, return, "typeSize must be greater than 0.");
std::vector<int64_t> shapeDims = srcShape.GetDims();
uint32_t srcHeight = shapeDims[0];
uint32_t srcWidth = shapeDims[1];
maxValue = ((srcHeight * srcWidth - 1) / PAD_NCHW_CONV_ADDR_LIST_SIZE + 1) * PAD_NCHW_CONV_ADDR_LIST_SIZE *
(PAD_ONE_BLK_SIZE / typeSize) * TMP_BUFFER_NUM * typeSize / sizeof(uint8_t);
minValue =
PAD_NCHW_CONV_ADDR_LIST_SIZE * (PAD_ONE_BLK_SIZE / typeSize) * TMP_BUFFER_NUM * typeSize / sizeof(uint8_t);
}
void PadTilingFunc(
const ge::Shape srcShape, const ge::Shape oriSrcShape, const uint32_t stackBufferSize, const uint32_t typeSize,
optiling::PadTiling& tiling)
{
CheckPadTilingFuncParams(srcShape, oriSrcShape, stackBufferSize, typeSize, "PadTilingFunc");
ASCENDC_HOST_ASSERT(typeSize > 0, return, "typeSize must be greater than 0.");
std::vector<int64_t> shapeDims = srcShape.GetDims();
std::vector<int64_t> oriShapeDims = oriSrcShape.GetDims();
uint32_t srcHeight = shapeDims[0];
uint32_t srcWidth = shapeDims[1];
uint32_t srcOriWidth = oriShapeDims[1];
tiling.set_srcHeight(srcHeight);
tiling.set_srcWidth(srcWidth);
tiling.set_srcOriWidth(srcOriWidth);
uint32_t widthWithoutLastBlock = srcWidth - PAD_ONE_BLK_SIZE / typeSize;
tiling.set_widthWithoutLastBlock(widthWithoutLastBlock);
uint32_t blocksPerRow = srcWidth * typeSize / PAD_ONE_BLK_SIZE;
tiling.set_blocksPerRow(blocksPerRow);
uint32_t heightTiling = PAD_MAX_REPEAT_TIMES;
uint32_t heightFractal = srcHeight / heightTiling;
uint32_t heightFractalTail = srcHeight % heightTiling;
tiling.set_heightTiling(heightTiling);
tiling.set_heightFractal(heightFractal);
tiling.set_heightFractalTail(heightFractalTail);
uint32_t mainLoopOffset = heightTiling * srcWidth;
uint32_t tailBlockOffset = heightFractal * heightTiling * srcWidth + widthWithoutLastBlock;
tiling.set_mainLoopOffset(mainLoopOffset);
tiling.set_tailBlockOffset(tailBlockOffset);
if (stackBufferSize == 0) {
return;
}
uint32_t baseBlockLen = PAD_NCHW_CONV_ADDR_LIST_SIZE * PAD_ONE_BLK_SIZE;
uint32_t baseBlockSize = baseBlockLen / typeSize;
uint32_t tmpBuffer1BlockNum = stackBufferSize * sizeof(uint8_t) / typeSize / baseBlockSize / 2;
uint32_t tmpBuffer2Offset = tmpBuffer1BlockNum * baseBlockSize;
tiling.set_tmpBuffer1RowNum(PAD_NCHW_CONV_ADDR_LIST_SIZE * tmpBuffer1BlockNum);
tiling.set_tmpBuffer1BlockNum(tmpBuffer1BlockNum);
tiling.set_tmpBuffer2Offset(tmpBuffer2Offset);
uint32_t widthTiling = PAD_NCHW_CONV_ADDR_LIST_SIZE * tmpBuffer1BlockNum;
uint32_t widthFractalTail = srcWidth % widthTiling;
uint32_t widthFractal = srcWidth / widthTiling;
uint32_t widthFractalTailAlingned =
((widthFractalTail - 1) / (PAD_ONE_BLK_SIZE / typeSize) + 1) * (PAD_ONE_BLK_SIZE / typeSize);
tiling.set_widthTiling(widthTiling);
tiling.set_widthFractal(widthFractal);
tiling.set_widthFractalTail(widthFractalTail);
tiling.set_widthFractalTailAlingned(widthFractalTailAlingned);
uint32_t brcbTiling = PAD_NCHW_CONV_ADDR_LIST_SIZE * tmpBuffer1BlockNum;
uint32_t brcbFractal = srcHeight * srcWidth / brcbTiling;
uint32_t brcbFractalTail = srcHeight * srcWidth % brcbTiling;
tiling.set_brcbTiling(brcbTiling);
tiling.set_brcbFractal(brcbFractal);
tiling.set_brcbFractalTail(brcbFractalTail);
uint32_t maxRepeatTimes = 254;
uint32_t brcbTilingRepeatTimes = brcbTiling / 8 / maxRepeatTimes;
uint32_t brcbTilingRepeatTimesTail = brcbTiling / 8 % maxRepeatTimes;
uint32_t brcbFractalTailRepeatTimes = brcbFractalTail / 8 / maxRepeatTimes;
uint32_t brcbFractalTailRepeatTimesTail = brcbFractalTail / 8 % maxRepeatTimes;
tiling.set_maxRepeatTimes(maxRepeatTimes);
tiling.set_brcbTilingRepeatTimes(brcbTilingRepeatTimes);
tiling.set_brcbTilingRepeatTimesTail(brcbTilingRepeatTimesTail);
tiling.set_brcbFractalTailRepeatTimes(brcbFractalTailRepeatTimes);
tiling.set_brcbFractalTailRepeatTimesTail(brcbFractalTailRepeatTimesTail);
}
void PadTilingFunc(
const ge::Shape srcShape, const ge::Shape oriSrcShape, const uint32_t stackBufferSize, const uint32_t typeSize,
AscendC::tiling::PadTiling& tiling)
{
optiling::PadTiling tilingData;
PadTilingFunc(srcShape, oriSrcShape, stackBufferSize, typeSize, tilingData);
tilingData.SaveToBuffer(&tiling, sizeof(PadTiling));
}
void GetUnPadMaxMinTmpSize(
const platform_ascendc::PlatformAscendC& ascendcPlatform, const ge::Shape& srcShape, const uint32_t typeSize,
uint32_t& maxValue, uint32_t& minValue)
{
CheckGetUnPadMaxMinTmpSizeParams(srcShape, typeSize, "GetUnPadMaxMinTmpSize");
ASCENDC_HOST_ASSERT(typeSize > 0, return, "typeSize must be greater than 0.");
std::vector<int64_t> shapeDims = srcShape.GetDims();
uint32_t srcHeight = shapeDims[0];
uint32_t srcWidth = shapeDims[1];
auto npuArch = ascendcPlatform.GetCurNpuArch();
if (npuArch == NpuArch::DAV_2201 || npuArch == NpuArch::DAV_3510 || npuArch == NpuArch::DAV_5102) {
maxValue = 0;
minValue = 0;
} else {
maxValue = ((srcHeight * srcWidth - 1) / PAD_NCHW_CONV_ADDR_LIST_SIZE + 1) * PAD_NCHW_CONV_ADDR_LIST_SIZE *
(PAD_ONE_BLK_SIZE / typeSize) * TMP_BUFFER_NUM * typeSize / sizeof(uint8_t);
minValue =
PAD_NCHW_CONV_ADDR_LIST_SIZE * (PAD_ONE_BLK_SIZE / typeSize) * TMP_BUFFER_NUM * typeSize / sizeof(uint8_t);
}
}
void UnPadTilingFunc(
const ge::Shape srcShape, const uint32_t stackBufferSize, const uint32_t typeSize, optiling::UnPadTiling& tiling)
{
CheckGetUnPadMaxMinTmpSizeParams(srcShape, typeSize, "UnPadTilingFunc");
ASCENDC_HOST_ASSERT(typeSize > 0, return, "typeSize must be greater than 0.");
std::vector<int64_t> shapeDims = srcShape.GetDims();
uint32_t srcHeight = shapeDims[0];
uint32_t srcWidth = shapeDims[1];
tiling.set_srcHeight(srcHeight);
tiling.set_srcWidth(srcWidth);
if (stackBufferSize == 0) {
return;
}
uint32_t baseBlockLen = PAD_NCHW_CONV_ADDR_LIST_SIZE * PAD_ONE_BLK_SIZE;
uint32_t baseBlockSize = baseBlockLen / typeSize;
uint32_t tmpBuffer1BlockNum = stackBufferSize * sizeof(uint8_t) / typeSize / baseBlockSize / 2;
uint32_t tmpBuffer2Offset = tmpBuffer1BlockNum * baseBlockSize;
tiling.set_tmpBuffer1BlockNum(tmpBuffer1BlockNum);
tiling.set_tmpBuffer1RowNum(PAD_NCHW_CONV_ADDR_LIST_SIZE * tmpBuffer1BlockNum);
tiling.set_tmpBuffer2Offset(tmpBuffer2Offset);
uint32_t widthTiling = PAD_NCHW_CONV_ADDR_LIST_SIZE * tmpBuffer1BlockNum;
uint32_t widthFractal = srcWidth / widthTiling;
uint32_t widthFractalTail = srcWidth % widthTiling;
if (widthFractalTail != 0) {
widthFractal += 1;
} else {
widthFractalTail = widthTiling;
}
tiling.set_widthTiling(widthTiling);
tiling.set_widthFractal(widthFractal);
tiling.set_widthFractalTail(widthFractalTail);
}
void UnPadTilingFunc(
const ge::Shape srcShape, const uint32_t stackBufferSize, const uint32_t typeSize,
AscendC::tiling::UnPadTiling& tiling)
{
optiling::UnPadTiling tilingData;
UnPadTilingFunc(srcShape, stackBufferSize, typeSize, tilingData);
tilingData.SaveToBuffer(&tiling, sizeof(UnPadTiling));
}
}