* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file groupnorm_tiling_impl.cpp
* \brief
*/
#include "include/adv_api/normalization/groupnorm_tiling.h"

#include <cstdint>
#include <limits>
#include <vector>

#include "../../detail/host_log.h"
// Invokes REGISTER_TILING_DATA_CLASS for the optiling GroupNormTiling data class with the name
// GroupNormTilingOpApi. Presumably this registers the class so it can be looked up by that name;
// NOTE(review): the macro is defined outside this file -- confirm its exact effect there.
namespace optiling {
REGISTER_TILING_DATA_CLASS(GroupNormTilingOpApi, GroupNormTiling);
}  // namespace optiling
namespace AscendC {
namespace {
constexpr uint32_t GROUPNORM_SRC_DIM_NUM = 4;
constexpr uint32_t GROUPNORM_SIZEOF_FLOAT = 4;
constexpr uint32_t GROUPNORM_SIZEOF_HALF = 2;
constexpr uint32_t GROUPNORM_ONE_BLK_SIZE = 32;
constexpr uint32_t GROUPNORM_THREE_TIMES = 3;
constexpr uint32_t GROUPNORM_TWO_TIMES = 2;
constexpr uint32_t GROUPNORM_ONE_NUMBER = 1;
constexpr uint32_t GROUPNORM_ZERO_NUMBER = 0;
constexpr float GROUPNORM_ONE_FLOAT_VALUE = 1.0f;
constexpr uint32_t GROUPNORM_MAX_MASK_VAL = 64;
constexpr uint32_t GROUPNORM_STEP_MASK_VAL = 8;
constexpr uint32_t GROUPNORM_MAX_REPEAT_VAL = 255;
constexpr uint32_t GROUPNORM_MIN_BSCURLENGTH_IN_ITERATION = 8;
constexpr uint32_t GROUPNORM_REDUCESUM_MAX_FLOAT_NUM = 64;
constexpr uint32_t GROUPNORM_REDUCESUM_MAX_REPEAT_SMALLSHAPE = 8;

// Rounds a byte count up to a whole number of GROUPNORM_ONE_BLK_SIZE-byte blocks.
constexpr uint64_t GroupNormAlignUpToBlock(const uint64_t byteLen)
{
    return (byteLen + GROUPNORM_ONE_BLK_SIZE - GROUPNORM_ONE_NUMBER) / GROUPNORM_ONE_BLK_SIZE *
        GROUPNORM_ONE_BLK_SIZE;
}

// Narrows a 64-bit byte count to uint32_t, saturating at the uint32_t maximum instead of silently wrapping.
constexpr uint32_t GroupNormSaturateToUint32(const uint64_t value)
{
    return (value > static_cast<uint64_t>(std::numeric_limits<uint32_t>::max())) ?
        std::numeric_limits<uint32_t>::max() :
        static_cast<uint32_t>(value);
}

/*
 * Computes the temporary buffer size, in bytes, that GroupNorm needs for srcShape.
 *
 * isMaxValue == true sizes the per-pass buffers for the whole [N, C, H, W] input.
 * isMaxValue == false sizes them for a single group of c / groupNum channels.
 * Each H*W plane is padded to whole GROUPNORM_ONE_BLK_SIZE-byte blocks in the input type and then
 * counted as float. Two such buffers are counted when isReuseSource is set and the input is float
 * (typeSize 4); otherwise three. Two block-aligned scratch areas of n * groupNum floats are added on top.
 *
 * Precondition: srcShape must already have passed CheckGroupNormHostCommon, because dims 0..3 are indexed.
 * Returns 0 when typeSize, groupNum or c is zero. Results that do not fit in uint32_t saturate
 * instead of wrapping to a misleadingly small value.
 */
uint32_t GetGroupNormTmpSize(
    const ge::Shape& srcShape, const uint32_t typeSize, const bool isReuseSource, uint32_t groupNum,
    const bool isMaxValue)
{
    ASCENDC_HOST_ASSERT(typeSize > 0, return 0, "typeSize must be greater than 0.");
    ASCENDC_HOST_ASSERT(groupNum > 0, return 0, "groupNum must be greater than 0.");
    const std::vector<int64_t> shapeDims = srcShape.GetDims();
    // Keep the original uint32_t narrowing of each dim, but do all further arithmetic in 64 bits.
    const uint64_t n = static_cast<uint32_t>(shapeDims[0]);
    const uint64_t c = static_cast<uint32_t>(shapeDims[1]);
    const uint64_t h = static_cast<uint32_t>(shapeDims[2]);
    const uint64_t w = static_cast<uint32_t>(shapeDims[3]);
    ASCENDC_HOST_ASSERT(c != 0, return 0, "the value of c should not be zero!");

    // Block-aligned scratch of n * groupNum floats; two of these are counted in the result.
    const uint64_t mvTmpLen = GroupNormAlignUpToBlock(n * groupNum * sizeof(float));
    // One H*W plane: block-aligned in the input type, then re-expressed as float bytes.
    const uint64_t hwFloatLen = GroupNormAlignUpToBlock(h * w * typeSize) / typeSize * sizeof(float);
    const uint64_t dhwLen = isMaxValue ? (n * c * hwFloatLen) : (c / groupNum * hwFloatLen);
    const uint64_t bufferCount =
        (isReuseSource && (typeSize == GROUPNORM_SIZEOF_FLOAT)) ? GROUPNORM_TWO_TIMES : GROUPNORM_THREE_TIMES;
    return GroupNormSaturateToUint32(bufferCount * dhwLen + GROUPNORM_TWO_TIMES * mvTmpLen);
}

/*
 * Validates the inputs shared by the GroupNorm host APIs and logs the first violation found.
 * Checks: a positive shape size; exactly GROUPNORM_SRC_DIM_NUM dims ([N, C, H, W]); every dim positive;
 * and a typeSize equal to the half or float size.
 * Returns true when srcShape and typeSize are usable. Callers must stop on false, because the size
 * computations index dims 0..3 and divide by typeSize.
 */
bool CheckGroupNormHostCommon(
    const char* apiName, const char* hostFuncName, const ge::Shape& srcShape, const uint32_t typeSize)
{
    ASCENDC_HOST_ASSERT(
        srcShape.GetShapeSize() > 0, return false, "[%s][%s] Input Shape size must be greater than 0.", apiName,
        hostFuncName);
    ASCENDC_HOST_ASSERT(
        srcShape.GetDimNum() == GROUPNORM_SRC_DIM_NUM, return false,
        "[%s][%s] The dims of srcShape is %zu, should be 4 (e.g. [N, C, H, W])!", apiName, hostFuncName,
        srcShape.GetDimNum());
    // A positive total size can still hide pairs of negative (e.g. unknown) dims, which would turn into
    // huge uint32_t values once narrowed.
    const std::vector<int64_t> shapeDims = srcShape.GetDims();
    for (size_t i = 0; i < shapeDims.size(); ++i) {
        ASCENDC_HOST_ASSERT(
            shapeDims[i] > 0, return false, "[%s][%s] Dim %zu of srcShape is %lld, should be greater than 0!",
            apiName, hostFuncName, i, static_cast<long long>(shapeDims[i]));
    }
    ASCENDC_HOST_ASSERT(
        typeSize == GROUPNORM_SIZEOF_HALF || typeSize == GROUPNORM_SIZEOF_FLOAT, return false,
        "[%s][%s] Type size %u is unsupported!", apiName, hostFuncName, typeSize);
    return true;
}
}  // namespace

/**
 * @brief Reports the largest and smallest temporary buffer size, in bytes, that GroupNorm can use for srcShape.
 * @param srcShape      input shape, expected as [N, C, H, W]
 * @param typeSize      element size of the input in bytes (2 or 4)
 * @param isReuseSource whether the source buffer may be reused (only reduces the size for float input)
 * @param groupNum      number of groups the C dimension is split into
 * @param maxValue      out: size for processing the whole input in one pass
 * @param minValue      out: size for processing a single group
 * Both outputs are set to 0 when the input fails validation.
 */
void GetGroupNormMaxMinTmpSize(
    const ge::Shape& srcShape, const uint32_t typeSize, const bool isReuseSource, const uint32_t groupNum,
    uint32_t& maxValue, uint32_t& minValue)
{
    maxValue = 0;
    minValue = 0;
    if (!CheckGroupNormHostCommon("GroupNorm", "GetGroupNormMaxMinTmpSize", srcShape, typeSize)) {
        return;
    }
    maxValue = GetGroupNormTmpSize(srcShape, typeSize, isReuseSource, groupNum, true);
    minValue = GetGroupNormTmpSize(srcShape, typeSize, isReuseSource, groupNum, false);
}

/**
 * @brief Fills the GroupNorm tiling for an ND [N, C, H, W] input and a temporary stack buffer.
 * @param srcShape        input shape, expected as [N, C, H, W]
 * @param stackBufferSize temporary buffer size in bytes (rounded down to whole blocks)
 * @param typeSize        element size of the input in bytes (2 or 4)
 * @param isReuseSource   whether the source buffer may be reused (only effective for float input)
 * @param groupNum        number of groups the C dimension is split into
 * @param tiling          out: tiling data; left untouched when any check fails
 * Sizes and positions written to the tiling are in elements (the buffer is split as floats).
 */
void GetGroupNormNDTilingInfo(
    const ge::Shape& srcShape, const uint32_t stackBufferSize, const uint32_t typeSize, const bool isReuseSource,
    const uint32_t groupNum, optiling::GroupNormTiling& tiling)
{
    // Stop on invalid input: everything below indexes shapeDims[0..3] and divides by typeSize.
    if (!CheckGroupNormHostCommon("GroupNorm", "GetGroupNormNDTilingInfo", srcShape, typeSize)) {
        return;
    }
    ASCENDC_HOST_ASSERT(groupNum > 0, return, "groupNum must be greater than 0.");
    const std::vector<int64_t> shapeDims = srcShape.GetDims();
    const uint32_t n = static_cast<uint32_t>(shapeDims[0]);
    const uint32_t c = static_cast<uint32_t>(shapeDims[1]);
    const uint32_t h = static_cast<uint32_t>(shapeDims[2]);
    const uint32_t w = static_cast<uint32_t>(shapeDims[3]);
    ASCENDC_HOST_ASSERT(c != 0, return, "the value of c should not be zero!");
    const uint32_t g = groupNum;
    const uint32_t d = c / groupNum;  // channels per group

    // Elements per H*W plane, padded so the plane occupies whole GROUPNORM_ONE_BLK_SIZE-byte blocks.
    const uint32_t hwAlignSize = (typeSize * h * w + GROUPNORM_ONE_BLK_SIZE - GROUPNORM_ONE_NUMBER) /
        GROUPNORM_ONE_BLK_SIZE * GROUPNORM_ONE_BLK_SIZE / typeSize;
    const uint32_t dhwAlignSize = d * hwAlignSize;
    const uint32_t inputXSize = n * c * hwAlignSize;
    const uint32_t meanVarSize = n * g;
    const uint32_t oneBlockNum = GROUPNORM_ONE_BLK_SIZE / GROUPNORM_SIZEOF_FLOAT;
    const uint32_t meanTmpTensorSize =
        (meanVarSize + oneBlockNum - GROUPNORM_ONE_NUMBER) / oneBlockNum * oneBlockNum;
    // Two dedicated mean/variance areas are reserved at the front of the buffer only for non-float input.
    const uint32_t meanVarTotalSize =
        (typeSize == GROUPNORM_SIZEOF_FLOAT) ? GROUPNORM_ZERO_NUMBER : GROUPNORM_TWO_TIMES * meanTmpTensorSize;
    const uint32_t numberOfTmpBuf =
        (isReuseSource && (typeSize == GROUPNORM_SIZEOF_FLOAT)) ? GROUPNORM_TWO_TIMES : GROUPNORM_THREE_TIMES;
    // Stack buffer capacity in float elements, rounded down to whole blocks.
    const uint32_t tmpBufSize =
        stackBufferSize / GROUPNORM_ONE_BLK_SIZE * GROUPNORM_ONE_BLK_SIZE / GROUPNORM_SIZEOF_FLOAT;
    // Without this guard the unsigned subtraction below wraps for a small stack buffer. The resulting huge
    // oneTmpSize would later be clamped to inputXSize, i.e. a tiling that overruns the real buffer.
    ASCENDC_HOST_ASSERT(
        tmpBufSize > meanVarTotalSize, return,
        "stackBufferSize %u is too small, it must hold more than %u float elements for mean/variance!",
        stackBufferSize, meanVarTotalSize);
    uint32_t oneTmpSize = (tmpBufSize - meanVarTotalSize) / numberOfTmpBuf;
    ASCENDC_HOST_ASSERT(dhwAlignSize != 0, return, "the value of dhwAlignSize should not be zero!");
    // Number of (n, group) slices of dhwAlignSize elements that fit into one temporary buffer.
    uint32_t bsCurLength = oneTmpSize / dhwAlignSize;

    // Largest k <= GROUPNORM_REDUCESUM_MAX_REPEAT_SMALLSHAPE that divides the slice size in blocks.
    // The loop always terminates because k == 1 divides everything.
    uint32_t k = GROUPNORM_REDUCESUM_MAX_REPEAT_SMALLSHAPE;
    while ((dhwAlignSize / oneBlockNum) % k != 0) {
        k--;
    }
    const bool smallShape = (hwAlignSize <= GROUPNORM_REDUCESUM_MAX_FLOAT_NUM) &&
        (hwAlignSize * d <= GROUPNORM_REDUCESUM_MAX_FLOAT_NUM * k);
    if (smallShape) {
        // Largest mask (a multiple of GROUPNORM_STEP_MASK_VAL, at most GROUPNORM_MAX_MASK_VAL) dividing the slice.
        uint32_t mask1{GROUPNORM_MAX_MASK_VAL};
        if (dhwAlignSize > GROUPNORM_MAX_MASK_VAL) {
            while (mask1 != 0 && dhwAlignSize % mask1 != 0) {
                mask1 -= GROUPNORM_STEP_MASK_VAL;
            }
        } else {
            mask1 = dhwAlignSize;
        }
        ASCENDC_HOST_ASSERT(mask1 > 0, return, "mask1 must be greater than 0.");
        // Cap slices per round so the repeat count stays within GROUPNORM_MAX_REPEAT_VAL, rounded down to a
        // multiple of GROUPNORM_MIN_BSCURLENGTH_IN_ITERATION.
        const uint32_t maxBsCurLength =
            (GROUPNORM_MAX_REPEAT_VAL / (dhwAlignSize / mask1) / GROUPNORM_MIN_BSCURLENGTH_IN_ITERATION) *
            GROUPNORM_MIN_BSCURLENGTH_IN_ITERATION;
        if (maxBsCurLength < bsCurLength) {
            bsCurLength = maxBsCurLength;
        }
    }
    // NOTE(review): for half input this exits without logging and leaves `tiling` untouched. The comparison
    // mixes an element count with the channel count -- confirm the intended condition.
    if (typeSize == GROUPNORM_SIZEOF_HALF && bsCurLength * dhwAlignSize < c) {
        return;
    }
    oneTmpSize = bsCurLength * dhwAlignSize;
    if (oneTmpSize > inputXSize) {
        // Everything fits in one round: process all n * g slices at once.
        bsCurLength = meanVarSize;
        oneTmpSize = inputXSize;
    }
    ASCENDC_HOST_ASSERT((oneTmpSize != GROUPNORM_ZERO_NUMBER), return, "the oneTmpSize should not be zero!");
    // Explicit guard for the division and modulo below, independent of how ASCENDC_HOST_ASSERT is configured.
    if (oneTmpSize == GROUPNORM_ZERO_NUMBER) {
        return;
    }
    const uint32_t inputRoundSize = oneTmpSize;
    const uint32_t inputTailSize = inputXSize % oneTmpSize;
    const uint32_t meanVarRoundSize = inputRoundSize / dhwAlignSize;
    const uint32_t meanVarTailSize = inputTailSize / dhwAlignSize;
    tiling.set_n(n);
    tiling.set_c(c);
    tiling.set_hw(h * w);
    tiling.set_g(g);
    tiling.set_d(d);
    tiling.set_hwAlignSize(hwAlignSize);
    tiling.set_dhwAlignSize(dhwAlignSize);
    tiling.set_inputXSize(inputXSize);
    tiling.set_meanVarSize(meanVarSize);
    tiling.set_numberOfTmpBuf(numberOfTmpBuf);
    tiling.set_meanTmpTensorPos(GROUPNORM_ZERO_NUMBER);
    tiling.set_meanTmpTensorSize(meanTmpTensorSize);
    tiling.set_varianceTmpTensorPos(meanTmpTensorSize);
    tiling.set_varianceTmpTensorSize(meanTmpTensorSize);
    tiling.set_tmpBufSize(tmpBufSize);
    tiling.set_oneTmpSize(oneTmpSize);
    tiling.set_firstTmpStartPos(meanVarTotalSize);
    tiling.set_secondTmpStartPos(meanVarTotalSize + oneTmpSize);
    tiling.set_thirdTmpStartPos(meanVarTotalSize + GROUPNORM_TWO_TIMES * oneTmpSize);
    tiling.set_loopRound(inputXSize / oneTmpSize);
    tiling.set_inputRoundSize(inputRoundSize);
    tiling.set_inputTailSize(inputTailSize);
    tiling.set_inputTailPos(inputXSize - inputTailSize);
    tiling.set_meanVarRoundSize(meanVarRoundSize);
    tiling.set_meanVarTailSize(meanVarTailSize);
    tiling.set_meanVarTailPos(meanVarSize - meanVarTailSize);
    tiling.set_bshCurLength(inputRoundSize);
    tiling.set_bsCurLength(bsCurLength);
    // Widen before multiplying so d * h * w cannot wrap. The result is unchanged whenever it fits in uint32_t.
    tiling.set_factor(GROUPNORM_ONE_FLOAT_VALUE / static_cast<float>(static_cast<uint64_t>(d) * h * w));
    tiling.set_smallShape(smallShape);
}

/**
 * @brief Same as the optiling overload, but serializes the result into a plain AscendC::tiling::GroupNormTiling.
 */
void GetGroupNormNDTilingInfo(
    const ge::Shape& srcShape, const uint32_t stackBufferSize, const uint32_t typeSize, const bool isReuseSource,
    const uint32_t groupNum, AscendC::tiling::GroupNormTiling& tiling)
{
    optiling::GroupNormTiling tilingData;
    GetGroupNormNDTilingInfo(srcShape, stackBufferSize, typeSize, isReuseSource, groupNum, tilingData);
    // Pass the capacity of the actual destination object rather than an unqualified type name whose
    // resolution depends on declarations outside this file.
    tilingData.SaveToBuffer(&tiling, sizeof(tiling));
}
}  // namespace AscendC