* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "../../detail/host_log.h"
#include "include/adv_api/normalization/deepnorm_tiling.h"
namespace optiling {
// Associates the DeepNormTiling tiling-data type with the name DeepNormTilingOpApi.
// NOTE(review): the exact effect of REGISTER_TILING_DATA_CLASS is defined outside this file —
// presumably it registers the class with the tiling framework; confirm against the macro definition.
// optiling::DeepNormTiling is the type filled in by GetDeepNormTilingInfo further down in this file.
REGISTER_TILING_DATA_CLASS(DeepNormTilingOpApi, DeepNormTiling);
}
namespace AscendC {
namespace {
// ---- Block geometry --------------------------------------------------------
// Size of one block in bytes; temporary-buffer byte sizes are rounded up to a multiple of it.
constexpr uint32_t ONE_BLK_SIZE = 32U;
// Number of float elements that fit into one block.
constexpr uint32_t ONE_BLOCK_NUM = ONE_BLK_SIZE / sizeof(float);
// Required byte alignment of the last axis (H * typeSize).
constexpr uint32_t DEEPNORM_ALIGN_LAST_DIM = 32U;

// ---- Temporary-buffer multipliers -----------------------------------------
// Buffer count used when the source cannot be reused (or the type is not float).
constexpr uint32_t DEEPNORM_THREE_TIMES = 3U;
// Buffer count used when the source is reused with float input; also the mean/variance pair count.
constexpr uint32_t DEEPNORM_TWO_TIMES = 2U;

// ---- Limits and numeric helpers -------------------------------------------
// Combined with a factor of 8 to bound the last-axis length (255 * 8 = 2040).
constexpr uint32_t DEEPNORM_MAX_REPEAT = 255U;
// Numerator of the reciprocal of the original last-axis length stored in the tiling data.
constexpr float DEEPNORM_LAST_DIM_INIT_VALUE = 1.0F;

// ---- Basic-block requirements ---------------------------------------------
// In basic-block mode H must be a multiple of this value.
constexpr uint32_t BASIC_BLK_HLENGTH = 64U;
// In basic-block mode B * S must be a multiple of this value.
constexpr uint32_t BASIC_BLK_BSLENGTH = 8U;

// ---- Input contract --------------------------------------------------------
// Supported element sizes in bytes.
constexpr uint32_t DEEPNORM_HALF_SIZE = 2U;
constexpr uint32_t DEEPNORM_FLOAT_SIZE = 4U;
// Required number of dimensions of the source shape ([B, S, H]).
constexpr uint32_t DEEPNORM_SRC_DIM_NUM = 3U;
/*!
 * \brief Round dataAmount up to the nearest multiple of blockSize.
 * \param dataAmount value to round up
 * \param blockSize  rounding granularity; must be non-zero (it is used as a divisor)
 * \return the smallest multiple of blockSize that is >= dataAmount (32-bit unsigned arithmetic)
 */
uint32_t GetFinalBlockSize(uint32_t dataAmount, uint32_t blockSize)
{
    // Ceil-divide to get the number of blocks, then scale back to the original unit.
    const uint32_t blockCount = (dataAmount + blockSize - 1) / blockSize;
    return blockCount * blockSize;
}
/*!
 * \brief Compute the largest temporary buffer size, in bytes, for the given source shape.
 * \param srcShape      source shape; its first three dims are read as B, S and H
 * \param typeSize      element size of the input type in bytes
 * \param isReuseSource when true and typeSize equals sizeof(float), one fewer B*S*H-sized buffer is counted
 * \return 2 or 3 block-aligned B*S*H float buffers plus 2 block-aligned B*S float buffers, in bytes
 */
uint32_t GetDeepNormMaxTmpSize(const ge::Shape& srcShape, const uint32_t typeSize, const bool isReuseSource)
{
    const std::vector<int64_t> dims = srcShape.GetDims();
    const uint32_t bLen = static_cast<uint32_t>(dims[0]);
    const uint32_t sLen = static_cast<uint32_t>(dims[1]);
    const uint32_t hLen = static_cast<uint32_t>(dims[2]);

    // Byte sizes of one B*S and one B*S*H float buffer, each rounded up to a whole block.
    const uint32_t bsBytes = GetFinalBlockSize(static_cast<uint32_t>(bLen * sLen * sizeof(float)), ONE_BLK_SIZE);
    const uint32_t bshBytes =
        GetFinalBlockSize(static_cast<uint32_t>(bLen * sLen * hLen * sizeof(float)), ONE_BLK_SIZE);

    // Only float input with source reuse gets away with two large buffers instead of three.
    const bool reuseFloatSrc = isReuseSource && (typeSize == sizeof(float));
    const uint32_t bshBufCount = reuseFloatSrc ? DEEPNORM_TWO_TIMES : DEEPNORM_THREE_TIMES;
    return bshBufCount * bshBytes + DEEPNORM_TWO_TIMES * bsBytes;
}
/*!
 * \brief Compute the smallest temporary buffer size, in bytes, for the given source shape.
 * \param srcShape      source shape; its first three dims are read as B, S and H
 * \param typeSize      element size of the input type in bytes
 * \param isReuseSource when true and typeSize equals sizeof(float), one fewer H-sized buffer is counted
 * \param isBasicBlock  unused by this computation
 * \return 2 or 3 block-aligned H float buffers plus 2 block-aligned B*S float buffers, in bytes
 */
uint32_t GetDeepNormMinTmpSize(
    const ge::Shape& srcShape, const uint32_t typeSize, const bool isReuseSource, const bool isBasicBlock)
{
    (void)isBasicBlock;
    const std::vector<int64_t> dims = srcShape.GetDims();
    const uint32_t bLen = static_cast<uint32_t>(dims[0]);
    const uint32_t sLen = static_cast<uint32_t>(dims[1]);
    const uint32_t hLen = static_cast<uint32_t>(dims[2]);

    // Byte sizes of one B*S and one H float buffer, each rounded up to a whole block.
    const uint32_t bsBytes = GetFinalBlockSize(static_cast<uint32_t>(bLen * sLen * sizeof(float)), ONE_BLK_SIZE);
    const uint32_t hBytes = GetFinalBlockSize(static_cast<uint32_t>(hLen * sizeof(float)), ONE_BLK_SIZE);

    // Only float input with source reuse gets away with two row buffers instead of three.
    const bool reuseFloatSrc = isReuseSource && (typeSize == sizeof(float));
    const uint32_t hBufCount = reuseFloatSrc ? DEEPNORM_TWO_TIMES : DEEPNORM_THREE_TIMES;
    return hBufCount * hBytes + DEEPNORM_TWO_TIMES * bsBytes;
}
/*!
 * \brief Run the parameter checks shared by the DeepNorm host-side functions.
 *
 * Each check is an ASCENDC_HOST_ASSERT whose failure action is `return`, so the function
 * stops at the first failing check. The function returns void: callers cannot observe
 * whether a check failed and must apply their own guards.
 * NOTE(review): ASCENDC_HOST_ASSERT is defined outside this file — presumably it logs the
 * formatted message when the condition is false; confirm against host_log.h.
 *
 * \param apiName      API name inserted into the messages
 * \param hostFuncName calling host function name inserted into the messages
 * \param srcShape     source shape, expected to be [B, S, H]
 * \param typeSize     element size in bytes; 2 and 4 are accepted
 * \param isBasicBlock when true, the basic-block alignment checks are applied as well
 */
void CheckDeepNormHostCommon(
const char* apiName, const char* hostFuncName, const ge::Shape& srcShape, const uint32_t typeSize,
const bool isBasicBlock)
{
// The shape must describe at least one element.
ASCENDC_HOST_ASSERT(
srcShape.GetShapeSize() > 0, return, "[%s][%s] Input Shape size must be greater than 0.", apiName,
hostFuncName);
// Only 2-byte and 4-byte element types are accepted.
ASCENDC_HOST_ASSERT(
typeSize == DEEPNORM_HALF_SIZE || typeSize == DEEPNORM_FLOAT_SIZE, return,
"[%s][%s] Type size %u is unsupported!", apiName, hostFuncName, typeSize);
// The rank must be exactly 3; the GetDim(0..2) calls below are only reached when this holds.
ASCENDC_HOST_ASSERT(
srcShape.GetDimNum() == DEEPNORM_SRC_DIM_NUM, return,
"[%s][%s] The dims of srcShape is %zu, should be 3 (e.g. [B, S, H])!", apiName, hostFuncName,
srcShape.GetDimNum());
// B * S is computed on the values returned by GetDim and then narrowed to 32 bits.
uint32_t bsLength = static_cast<uint32_t>(srcShape.GetDim(0) * srcShape.GetDim(1));
uint32_t hLength = static_cast<uint32_t>(srcShape.GetDim(2));
ASCENDC_HOST_ASSERT(hLength > 0, return, "[%s][%s] The last axis H must be greater than 0!", apiName, hostFuncName);
ASCENDC_HOST_ASSERT(
bsLength > 0, return, "[%s][%s] The non-last axis B*S must be greater than 0!", apiName, hostFuncName);
// The last axis, measured in bytes of the input type, must be a multiple of 32.
ASCENDC_HOST_ASSERT(
hLength * typeSize % DEEPNORM_ALIGN_LAST_DIM == 0, return, "[%s][%s] The last axis H must be 32B aligned!",
apiName, hostFuncName);
// 2040 equals DEEPNORM_MAX_REPEAT * 8, the same bound GetDeepNormTilingInfo enforces.
ASCENDC_HOST_ASSERT(
hLength <= 2040, return, "[%s][%s] The value of last axis can not be greater than 2040!", apiName,
hostFuncName);
// Basic-block mode additionally requires H % 64 == 0 and (B * S) % 8 == 0.
if (isBasicBlock) {
ASCENDC_HOST_ASSERT(
hLength % BASIC_BLK_HLENGTH == 0, return, "[%s][%s] The last axis H must be 64 aligned!", apiName,
hostFuncName);
ASCENDC_HOST_ASSERT(
bsLength % BASIC_BLK_BSLENGTH == 0, return, "[%s][%s] The non-last axis B*S must be 8 aligned!", apiName,
hostFuncName);
}
return;
}
/*!
 * \brief Check the relation between the last axis of srcShape and of originSrcShape.
 *
 * Each check is an ASCENDC_HOST_ASSERT whose failure action is `return`; the function returns
 * void, so callers cannot observe whether a check failed.
 * NOTE(review): dimension index 2 of both shapes is read without checking their rank here;
 * what GetDim does for an out-of-range index is not visible in this file — confirm.
 *
 * \param apiName        API name inserted into the messages
 * \param hostFuncName   calling host function name inserted into the messages
 * \param srcShape       source shape; dimension 2 is read as H
 * \param originSrcShape original source shape; dimension 2 is read as originH
 * \param isBasicBlock   when true, originH must equal H
 */
void CheckDeepNormTilingInfo(
const char* apiName, const char* hostFuncName, const ge::Shape& srcShape, const ge::Shape& originSrcShape,
const bool isBasicBlock)
{
uint32_t hLength = static_cast<uint32_t>(srcShape.GetDim(2));
uint32_t hOriginLength = static_cast<uint32_t>(originSrcShape.GetDim(2));
// originH must lie in the half-open range (0, H].
ASCENDC_HOST_ASSERT(
hOriginLength > 0 && hOriginLength <= hLength, return, "[%s][%s] The range of originH is %u, must be (0, %u]!",
apiName, hostFuncName, hOriginLength, hLength);
// Basic-block mode requires originH to be exactly H.
if (isBasicBlock) {
ASCENDC_HOST_ASSERT(
hOriginLength == hLength, return,
"[%s][%s] When isBasicBlock is true, originH is %u, H is %u, originH must be equal to H!", apiName,
hostFuncName, hOriginLength, hLength);
}
return;
}
}
/*!
 * \brief Compute the maximum and minimum temporary buffer sizes, in bytes, needed by DeepNorm.
 * \param srcShape      source shape, expected to be [B, S, H]
 * \param typeSize      element size of the input type in bytes
 * \param isReuseSource when true and typeSize equals sizeof(float), one fewer large buffer is counted
 * \param isBasicBlock  when true, H must be a non-zero multiple of 64 and B * S a multiple of 8
 * \param maxValue      [out] maximum temporary buffer size in bytes; written only on success
 * \param minValue      [out] minimum temporary buffer size in bytes; written only on success
 * \return false when srcShape has fewer than 3 dimensions or the basic-block requirements
 *         are not met; true otherwise
 */
bool GetDeepNormMaxMinTmpSize(
    const ge::Shape& srcShape, const uint32_t typeSize, const bool isReuseSource, const bool isBasicBlock,
    uint32_t& maxValue, uint32_t& minValue)
{
    // This call cannot stop us (it returns void), so the hard guards below are still required.
    CheckDeepNormHostCommon("DeepNorm", "GetDeepNormMaxMinTmpSize", srcShape, typeSize, isBasicBlock);

    const std::vector<int64_t> shapeDims = srcShape.GetDims();
    // Indexing shapeDims[0..2] here (and inside the Max/Min helpers) on a shorter vector would
    // read out of bounds, so reject such shapes. Shapes with more than 3 dims are still
    // processed on their first three dims, as before.
    if (shapeDims.size() < DEEPNORM_SRC_DIM_NUM) {
        return false;
    }

    const uint32_t bLength = static_cast<uint32_t>(shapeDims[0]);
    const uint32_t sLength = static_cast<uint32_t>(shapeDims[1]);
    const uint32_t hLength = static_cast<uint32_t>(shapeDims[2]);

    const bool hDivBy64 = (hLength % BASIC_BLK_HLENGTH == 0) && (hLength > 0);
    const bool bsDivBy8 = (bLength * sLength) % BASIC_BLK_BSLENGTH == 0;
    if (isBasicBlock && !(hDivBy64 && bsDivBy8)) {
        return false;
    }

    maxValue = GetDeepNormMaxTmpSize(srcShape, typeSize, isReuseSource);
    minValue = GetDeepNormMinTmpSize(srcShape, typeSize, isReuseSource, isBasicBlock);
    return true;
}
/*!
 * \brief Compute the DeepNorm tiling parameters and store them in an optiling::DeepNormTiling.
 *
 * The temporary buffer is treated as an array of floats (stackBufferSize / sizeof(float)).
 * When typeSize is not sizeof(float), the front of that array is reserved for a mean and a
 * variance area; the rest is split into numberOfTmpBuf equally sized working buffers whose
 * size is a multiple of H (and of 8 * H in basic-block mode when at least one such chunk fits).
 *
 * \param srcShape        source shape, expected to be [B, S, H]
 * \param originSrcShape  original source shape; B and S must match srcShape, originH must be in (0, H]
 * \param stackBufferSize size in bytes of the temporary buffer available
 * \param typeSize        element size of the input type in bytes
 * \param isReuseSource   when true and typeSize equals sizeof(float), two working buffers are used
 *                        instead of three
 * \param isBasicBlock    when true, originH must equal H and the basic-block alignment must hold
 * \param tiling          [out] tiling data; written only when the function returns true
 * \return false when either shape has fewer than 3 dimensions, the buffer is smaller than the
 *         minimum size, or any shape/alignment requirement is violated; true otherwise
 */
bool GetDeepNormTilingInfo(
    const ge::Shape& srcShape, const ge::Shape& originSrcShape, const uint32_t stackBufferSize, const uint32_t typeSize,
    const bool isReuseSource, const bool isBasicBlock, optiling::DeepNormTiling& tiling)
{
    // The check helpers return void and cannot stop this function; hard guards follow.
    CheckDeepNormHostCommon("DeepNorm", "GetDeepNormTilingInfo", srcShape, typeSize, isBasicBlock);

    const std::vector<int64_t> shapeDims = srcShape.GetDims();
    const std::vector<int64_t> shapeDimsOri = originSrcShape.GetDims();
    // Everything below reads dimensions 0..2 of both shapes. Reject shapes that do not have
    // them before any index-based access, which would otherwise read out of bounds.
    if (shapeDims.size() < DEEPNORM_SRC_DIM_NUM || shapeDimsOri.size() < DEEPNORM_SRC_DIM_NUM) {
        return false;
    }

    CheckDeepNormTilingInfo("DeepNorm", "GetDeepNormTilingInfo", srcShape, originSrcShape, isBasicBlock);

    uint32_t maxSize = 0;
    uint32_t minSize = 0;
    const bool res = GetDeepNormMaxMinTmpSize(srcShape, typeSize, isReuseSource, isBasicBlock, maxSize, minSize);
    if (!res || stackBufferSize < minSize) {
        return false;
    }

    const uint32_t bLength = static_cast<uint32_t>(shapeDims[0]);
    const uint32_t sLength = static_cast<uint32_t>(shapeDims[1]);
    const uint32_t hLength = static_cast<uint32_t>(shapeDims[2]);

    const uint32_t bLengthOri = static_cast<uint32_t>(shapeDimsOri[0]);
    const uint32_t sLengthOri = static_cast<uint32_t>(shapeDimsOri[1]);
    const uint32_t originalHLength = static_cast<uint32_t>(shapeDimsOri[2]);

    // Hard validation. hCheckOri also rules out hLength == 0, which is used as a divisor below.
    const bool bsCheck = (bLength == bLengthOri) && (sLength == sLengthOri);
    const bool hAlign = (hLength * typeSize % ONE_BLK_SIZE) == 0;
    const bool hCheckOri = (0u < originalHLength && originalHLength <= hLength);
    const bool basicblkOriH = !(isBasicBlock && originalHLength != hLength);
    // 255 * 8 = 2040, the same bound reported by CheckDeepNormHostCommon.
    const bool hLengthCheck = (hLength <= DEEPNORM_MAX_REPEAT * 8);
    if (!(bsCheck && hAlign && hCheckOri && basicblkOriH && hLengthCheck)) {
        return false;
    }

    const uint32_t inputXSize = bLength * sLength * hLength;
    const uint32_t meanVarSize = bLength * sLength;

    const uint32_t numberOfTmpBuf =
        (isReuseSource && (typeSize == sizeof(float))) ? DEEPNORM_TWO_TIMES : DEEPNORM_THREE_TIMES;

    // Mean / variance areas, in float elements, each rounded up to a whole block of floats.
    constexpr uint32_t meanTmpTensorPos = 0;
    const uint32_t meanTmpTensorSize = GetFinalBlockSize(meanVarSize, ONE_BLOCK_NUM);
    const uint32_t varianceTmpTensorPos = meanTmpTensorSize;
    const uint32_t varianceTmpTensorSize = meanTmpTensorSize;

    // No space is reserved for mean / variance when the input type is float.
    const uint32_t meanVarTotalSize = (typeSize == sizeof(float)) ? 0 : meanTmpTensorSize + varianceTmpTensorSize;

    // Size of one working buffer, in float elements: an equal share of what is left,
    // truncated to whole rows of H and capped at the full input size.
    const uint32_t tmpBufSize = stackBufferSize / sizeof(float);
    uint32_t oneTmpSize = (tmpBufSize - meanVarTotalSize) / numberOfTmpBuf;
    oneTmpSize = oneTmpSize / hLength * hLength;
    oneTmpSize = (oneTmpSize > inputXSize) ? inputXSize : oneTmpSize;

    // In basic-block mode, truncate to whole groups of 8 rows when at least one group fits.
    if (isBasicBlock && (oneTmpSize / (BASIC_BLK_BSLENGTH * hLength) > 0)) {
        oneTmpSize = oneTmpSize / (BASIC_BLK_BSLENGTH * hLength) * (BASIC_BLK_BSLENGTH * hLength);
    }

    // oneTmpSize is used as a divisor below.
    if (oneTmpSize == 0) {
        return false;
    }

    // Working buffers are laid out back to back after the mean / variance areas.
    const uint32_t firstTmpStartPos = meanVarTotalSize;
    const uint32_t secondTmpStartPos = firstTmpStartPos + oneTmpSize;
    const uint32_t thirdTmpStartPos = secondTmpStartPos + oneTmpSize;

    // Split the input into full rounds of oneTmpSize elements plus a tail.
    const uint32_t loopRound = inputXSize / oneTmpSize;
    const uint32_t inputTailSize = inputXSize % oneTmpSize;
    const uint32_t inputTailPos = inputXSize - inputTailSize;

    // Matching per-round and tail sizes expressed in rows (mean / variance entries).
    const uint32_t inputRoundSize = oneTmpSize;
    const uint32_t meanVarRoundSize = inputRoundSize / hLength;
    const uint32_t meanVarTailSize = inputTailSize / hLength;
    const uint32_t meanVarTailPos = meanVarSize - meanVarTailSize;

    const uint32_t bshCurLength = inputRoundSize;
    const uint32_t bsCurLength = meanVarRoundSize;

    // Reciprocal of the original last-axis length.
    const float lastDimValueBack = DEEPNORM_LAST_DIM_INIT_VALUE / static_cast<float>(originalHLength);

    tiling.set_bLength(bLength);
    tiling.set_sLength(sLength);
    tiling.set_hLength(hLength);
    tiling.set_originalHLength(originalHLength);
    tiling.set_inputXSize(inputXSize);
    tiling.set_meanVarSize(meanVarSize);
    tiling.set_numberOfTmpBuf(numberOfTmpBuf);
    tiling.set_meanTmpTensorPos(meanTmpTensorPos);
    tiling.set_meanTmpTensorSize(meanTmpTensorSize);
    tiling.set_varianceTmpTensorPos(varianceTmpTensorPos);
    tiling.set_varianceTmpTensorSize(varianceTmpTensorSize);
    tiling.set_tmpBufSize(tmpBufSize);
    tiling.set_oneTmpSize(oneTmpSize);
    tiling.set_firstTmpStartPos(firstTmpStartPos);
    tiling.set_secondTmpStartPos(secondTmpStartPos);
    tiling.set_thirdTmpStartPos(thirdTmpStartPos);
    tiling.set_loopRound(loopRound);
    tiling.set_inputRoundSize(inputRoundSize);
    tiling.set_inputTailSize(inputTailSize);
    tiling.set_inputTailPos(inputTailPos);
    tiling.set_meanVarRoundSize(meanVarRoundSize);
    tiling.set_meanVarTailSize(meanVarTailSize);
    tiling.set_meanVarTailPos(meanVarTailPos);
    tiling.set_bshCurLength(bshCurLength);
    tiling.set_bsCurLength(bsCurLength);
    tiling.set_lastDimValueBack(lastDimValueBack);
    return true;
}
/*!
 * \brief Overload that delivers the DeepNorm tiling parameters as an AscendC::tiling::DeepNormTiling.
 *
 * The parameters are computed by the optiling::DeepNormTiling overload into a local object,
 * which is then saved into `tiling` via SaveToBuffer. The save happens whether or not the
 * computation succeeded.
 *
 * \param tiling [out] destination written through SaveToBuffer
 * \return the result of the optiling::DeepNormTiling overload
 * (the remaining parameters are forwarded unchanged to that overload)
 */
bool GetDeepNormTilingInfo(
    const ge::Shape& srcShape, const ge::Shape& originSrcShape, const uint32_t stackBufferSize, const uint32_t typeSize,
    const bool isReuseSource, const bool isBasicBlock, AscendC::tiling::DeepNormTiling& tiling)
{
    optiling::DeepNormTiling hostTiling;
    const bool succeeded = GetDeepNormTilingInfo(
        srcShape, originSrcShape, stackBufferSize, typeSize, isReuseSource, isBasicBlock, hostTiling);

    // Copy out unconditionally; the caller decides what to do based on the returned flag.
    hostTiling.SaveToBuffer(&tiling, sizeof(DeepNormTiling));
    return succeeded;
}
}