/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
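/*!
 * \file batchnorm.asc
 * \brief BatchNorm sample: tiling helper functions, the KernelBatchnorm kernel class, the
 *        batchnorm_custom kernel entry, and a host driver that launches the kernel and
 *        compares its outputs against golden files.
 */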
#include "acl/acl.h"
#include "data_utils.h"
#include "kernel_operator.h"
#include "tiling/tiling_api.h"
// File-level constants for the BatchNorm sample.
// NOTE(review): within this file only BASIC_FLOAT_BLK_SHLENGTH is referenced; the helpers below
// spell the remaining values (4, 2, 8, 16, 32, 3, 255) as literals. Confirm no other translation
// unit relies on these names before removing any of them.

// Element sizes in bytes.
constexpr uint32_t BATCHNORM_SIZEOF_FLOAT = 4;
constexpr uint32_t BATCHNORM_SIZEOF_HALF = 2;
// Presumably the number of float / half elements in one 32-byte block (32 / 4 and 32 / 2) — TODO confirm.
constexpr uint32_t FLOAT_BLOCK_NUMBER = 8;
constexpr uint32_t HALF_BLOCK_NUMBER = 16;
constexpr uint32_t BATCHNORM_HALF_ELE = 16;
// Small named multipliers / literals.
constexpr uint32_t BATCHNORM_THREE_TIMES = 3;
constexpr uint32_t BATCHNORM_TWO_TIMES = 2;
constexpr uint32_t BATCHNORM_ONE_BLK_SIZE = 32;
constexpr uint32_t BATCHNORM_ONE_NUMBER = 1;
constexpr uint32_t BATCHNORM_ZERO_NUMBER = 0;
constexpr float BATCHNORM_LAST_DIM_INIT_VALUE = 1.0;
// Granularity of S*H in basic-block mode: CheckBasicBlockShape requires S*H to be a multiple of
// this value and GetBatchNormNDTilingInfo rounds the per-iteration S*H length down to it.
constexpr uint32_t BASIC_FLOAT_BLK_SHLENGTH = 64;
constexpr int32_t MAX_REPEAT_TIMES = 255;
const uint8_t DEFAULT_REPEAT_STRIDE = 8;
/**
 * Rounds inputX up to the nearest multiple of inputY.
 *
 * @param inputX value to align.
 * @param inputY alignment unit; 0 is treated as "no alignment" and inputX is returned unchanged
 *               (the previous formula divided by zero in that case).
 * @return the smallest multiple of inputY that is >= inputX. If that multiple does not fit in
 *         32 bits the result wraps, as any uint32_t arithmetic would.
 */
__aicore__ inline uint32_t AlignToBlockSize(const uint32_t inputX, const uint32_t inputY)
{
    if (inputY == 0) {
        return inputX;
    }
    // Work from the remainder instead of (inputX + inputY - 1) / inputY * inputY: the latter wraps
    // in the intermediate sum for values near UINT32_MAX even when the aligned result is
    // representable (e.g. an inputX that is already a multiple of inputY).
    const uint32_t remainder = inputX % inputY;
    if (remainder == 0) {
        return inputX;
    }
    return inputX + (inputY - remainder);
}
/**
 * Tells whether a (B, S, H) shape satisfies the basic-block constraints.
 *
 * @param originalBLength original B length; must be a multiple of 8.
 * @param sLength         S length.
 * @param hLength         H length; S * H must be a multiple of BASIC_FLOAT_BLK_SHLENGTH.
 * @return true when both constraints hold, false otherwise.
 */
__aicore__ inline bool CheckBasicBlockShape(const uint32_t originalBLength, const uint32_t sLength,
                                            const uint32_t hLength)
{
    const bool shIsBasicMultiple = ((sLength * hLength) % BASIC_FLOAT_BLK_SHLENGTH) == 0;
    const bool bIsMultipleOfEight = (originalBLength % 8) == 0;
    return shIsBasicMultiple && bIsMultipleOfEight;
}
/**
 * Validates an input shape for BatchNorm.
 *
 * @param srcShape       shape whose entries [1] and [2] are read as S and H.
 * @param originSrcShape shape whose entry [0] is read as the original B length.
 * @param typeSize       element size in bytes.
 * @param isBasicBlock   when true, the basic-block constraints are checked as well.
 * @return false when S * H * typeSize is not a multiple of 32, or when basic-block mode is
 *         requested and CheckBasicBlockShape rejects the shape; true otherwise.
 */
__aicore__ inline bool CheckShape(const AscendC::ShapeInfo srcShape, const AscendC::ShapeInfo originSrcShape,
                                  const uint32_t typeSize, const bool isBasicBlock = false)
{
    const uint32_t sDim = srcShape.shape[1];
    const uint32_t hDim = srcShape.shape[2];

    // One S*H slice has to occupy a whole number of 32-byte units.
    const bool sliceIsAligned = ((sDim * hDim * typeSize) % 32) == 0;
    if (!sliceIsAligned) {
        return false;
    }
    if (!isBasicBlock) {
        return true;
    }
    return CheckBasicBlockShape(originSrcShape.shape[0], sDim, hDim);
}
/**
 * Computes the maximum temporary-buffer size, in bytes, reported for BatchNorm.
 *
 * The result is 3 * align32(B * S * H * 4) + 2 * align32(S * H * 4), where B comes from
 * originSrcShape.shape[0] and S, H from srcShape.shape[1], srcShape.shape[2].
 *
 * @param srcShape       supplies S and H.
 * @param originSrcShape supplies the original B length.
 * @param typeSize       not used by this function.
 * @param isReuseSource  not used by this function.
 * @param isBasicBlock   not used by this function.
 * @return the size in bytes.
 */
__aicore__ inline uint32_t GetBatchNormMaxTmpSize(const AscendC::ShapeInfo srcShape,
                                                  const AscendC::ShapeInfo originSrcShape, const uint32_t typeSize,
                                                  const bool isReuseSource, const bool isBasicBlock)
{
    const uint32_t shElements = srcShape.shape[1] * srcShape.shape[2];
    const uint32_t meanVarBytes = AlignToBlockSize(shElements * 4, 32);
    const uint32_t inputBytes = AlignToBlockSize(originSrcShape.shape[0] * shElements * 4, 32);
    return 3 * inputBytes + 2 * meanVarBytes;
}
/**
 * Computes the minimum temporary-buffer size, in bytes, reported for BatchNorm.
 *
 * The result is 3 * perBuffer + 2 * align32(S * H * 4), where perBuffer is
 *   - B * 64 * 4 in basic-block mode,
 *   - B * 64     when typeSize == 2 (and not basic-block),
 *   - B * 32     otherwise.
 *
 * @param srcShape       supplies S and H (entries [1] and [2]).
 * @param originSrcShape supplies the original B length (entry [0]).
 * @param typeSize       element size in bytes.
 * @param isReuseSource  not used by this function.
 * @param isBasicBlock   selects the basic-block per-buffer size.
 * @return the size in bytes.
 */
__aicore__ inline uint32_t GetBatchNormMinTmpSize(const AscendC::ShapeInfo srcShape,
                                                  const AscendC::ShapeInfo originSrcShape, const uint32_t typeSize,
                                                  const bool isReuseSource, const bool isBasicBlock)
{
    const uint32_t batch = originSrcShape.shape[0];
    const uint32_t meanVarBytes = AlignToBlockSize(srcShape.shape[1] * srcShape.shape[2] * 4, 32);

    // Basic-block mode takes precedence over the element-size distinction.
    uint32_t perBufferBytes = 0;
    if (isBasicBlock) {
        perBufferBytes = batch * 64 * 4;
    } else if (typeSize == 2) {
        perBufferBytes = batch * 64;
    } else {
        perBufferBytes = batch * 32;
    }
    return 3 * perBufferBytes + 2 * meanVarBytes;
}
/**
 * Reports both the maximum and the minimum temporary-buffer size for BatchNorm.
 *
 * @param srcShape       supplies S and H (entries [1] and [2]).
 * @param originSrcShape supplies the original B length (entry [0]).
 * @param typeSize       element size in bytes.
 * @param isReuseSource  forwarded to the size helpers.
 * @param maxValue       [out] result of GetBatchNormMaxTmpSize; untouched on failure.
 * @param minValue       [out] result of GetBatchNormMinTmpSize; untouched on failure.
 * @param isBasicBlock   when true the shape must pass CheckBasicBlockShape.
 * @return false if basic-block mode is requested for a shape that does not qualify, else true.
 */
__aicore__ inline bool GetBatchNormMaxMinTmpSize(const AscendC::ShapeInfo srcShape,
                                                 const AscendC::ShapeInfo originSrcShape, const uint32_t typeSize,
                                                 const bool isReuseSource, uint32_t& maxValue, uint32_t& minValue,
                                                 const bool isBasicBlock)
{
    if (isBasicBlock) {
        const bool basicShapeOk =
            CheckBasicBlockShape(originSrcShape.shape[0], srcShape.shape[1], srcShape.shape[2]);
        if (!basicShapeOk) {
            return false;
        }
    }
    maxValue = GetBatchNormMaxTmpSize(srcShape, originSrcShape, typeSize, isReuseSource, isBasicBlock);
    minValue = GetBatchNormMinTmpSize(srcShape, originSrcShape, typeSize, isReuseSource, isBasicBlock);
    return true;
}
/**
 * Fills a BatchNormTiling structure for an ND input of shape (B, S, H).
 *
 * The temporary buffer (stackBufferByteSize bytes, viewed as 4-byte elements) is split into an
 * optional mean/variance area followed by three equally sized work buffers. The size of one work
 * buffer (oneTmpSize) determines how many S*H positions are processed per loop iteration.
 *
 * @param srcShape            supplies S and H (entries [1] and [2]).
 * @param originSrcShape      supplies the original B length (entry [0]).
 * @param stackBufferByteSize size of the temporary buffer in bytes.
 * @param typeSize            element size in bytes; 4 means no separate mean/variance area is reserved.
 * @param isReuseSource       forwarded to GetBatchNormMaxMinTmpSize.
 * @param tiling              [out] written only when the function returns true.
 * @param isBasicBlock        when true, the per-iteration S*H length is rounded down to a multiple of
 *                            BASIC_FLOAT_BLK_SHLENGTH.
 * @return false when B is 0, the shape is rejected, the buffer is smaller than the minimum size,
 *         or no work buffer of non-zero size fits; true otherwise.
 */
__aicore__ inline bool GetBatchNormNDTilingInfo(const AscendC::ShapeInfo srcShape,
                                                const AscendC::ShapeInfo originSrcShape,
                                                const uint32_t stackBufferByteSize, const uint32_t typeSize,
                                                const bool isReuseSource, BatchNormTiling& tiling,
                                                const bool isBasicBlock)
{
    const uint32_t sLength = srcShape.shape[1];
    const uint32_t hLength = srcShape.shape[2];
    const uint32_t originalBLength = originSrcShape.shape[0];
    // Every division below uses B (or a multiple of it) as divisor, so an empty batch cannot be tiled.
    if (originalBLength == 0) {
        return false;
    }

    uint32_t minSize = 0;
    uint32_t maxSize = 0;
    const bool sizesOk =
        GetBatchNormMaxMinTmpSize(srcShape, originSrcShape, typeSize, isReuseSource, maxSize, minSize, isBasicBlock);
    if (!sizesOk || stackBufferByteSize < minSize) {
        return false;
    }

    const uint32_t originalInputXSize = originalBLength * sLength * hLength;
    const uint32_t meanVarSize = sLength * hLength;
    constexpr uint32_t numberOfTmpBuf = 3;

    // Mean sits at the start of the buffer, variance right behind it; each is S*H rounded up to 8 elements.
    constexpr uint32_t meanTmpTensorPos = 0;
    const uint32_t meanTmpTensorSize = AlignToBlockSize(meanVarSize, 8);
    const uint32_t varianceTmpTensorPos = meanTmpTensorSize;
    // For 4-byte elements no mean/variance area is carved out of the temporary buffer.
    const uint32_t meanVarTotalSize = (typeSize == 4) ? 0 : 2 * meanTmpTensorSize;

    // Sizes from here on are counted in 4-byte elements.
    const uint32_t tmpBufSize = stackBufferByteSize / 4;
    // One work buffer is rounded down to a multiple of B * 16 (non-4-byte types) or B * 8 (4-byte types).
    const uint32_t alignUnit = originalBLength * ((typeSize != 4) ? 16 : 8);
    if (alignUnit == 0) {
        return false; // B * 16 wrapped to zero
    }
    uint32_t oneTmpSize = (tmpBufSize - meanVarTotalSize) / numberOfTmpBuf;
    oneTmpSize = oneTmpSize / alignUnit * alignUnit;
    if (oneTmpSize > originalInputXSize) {
        oneTmpSize = originalInputXSize;
    }
    if (oneTmpSize == 0) {
        return false;
    }

    uint32_t shCurLength = oneTmpSize / originalBLength;
    if (isBasicBlock && (shCurLength % BASIC_FLOAT_BLK_SHLENGTH != 0)) {
        shCurLength = shCurLength / BASIC_FLOAT_BLK_SHLENGTH * BASIC_FLOAT_BLK_SHLENGTH;
        oneTmpSize = shCurLength * originalBLength;
        // Rounding down may leave nothing; oneTmpSize is used as a divisor below.
        if (oneTmpSize == 0) {
            return false;
        }
    }

    // Layout of the three work buffers behind the mean/variance area.
    const uint32_t firstTmpStartPos = meanVarTotalSize;
    const uint32_t secondTmpStartPos = firstTmpStartPos + oneTmpSize;
    const uint32_t thirdTmpStartPos = secondTmpStartPos + oneTmpSize;

    // Full iterations plus a tail that covers whatever does not divide evenly.
    const uint32_t loopRound = originalInputXSize / oneTmpSize;
    const uint32_t inputTailSize = originalInputXSize % oneTmpSize;
    const uint32_t inputTailPos = (originalInputXSize - inputTailSize) / originalBLength;
    const uint32_t meanVarTailSize = inputTailSize / originalBLength;
    const uint32_t meanVarTailPos = meanVarSize - meanVarTailSize;

    tiling.originalBLength = originalBLength;
    tiling.meanVarSize = meanVarSize;
    tiling.meanTmpTensorPos = meanTmpTensorPos;
    tiling.varianceTmpTensorPos = varianceTmpTensorPos;
    tiling.tmpBufSize = tmpBufSize;
    tiling.oneTmpSize = oneTmpSize;
    tiling.firstTmpStartPos = firstTmpStartPos;
    tiling.secondTmpStartPos = secondTmpStartPos;
    tiling.thirdTmpStartPos = thirdTmpStartPos;
    tiling.loopRound = loopRound;
    tiling.inputTailSize = inputTailSize;
    tiling.inputTailPos = inputTailPos;
    tiling.meanVarTailSize = meanVarTailSize;
    tiling.meanVarTailPos = meanVarTailPos;
    tiling.bshCurLength = oneTmpSize;
    tiling.shCurLength = shCurLength;
    tiling.firstDimValueBack = 1.0f / static_cast<float>(originalBLength);
    tiling.castHalfRepStride = 8 / 2;
    tiling.shCurLengthBlockNum = shCurLength / 8;
    tiling.castHalfOutRepStride = meanVarSize / 16;
    return true;
}
/**
 * Single-shot BatchNorm kernel wrapper: copies the whole input plus gamma/beta from global memory
 * into local queues, runs AscendC::BatchNorm once, and copies output, mean and variance back.
 *
 * @tparam T             element type of all tensors.
 * @tparam isReuseSource forwarded to the tiling helpers and to AscendC::BatchNorm.
 * @tparam isBasicBlock  forwarded to the tiling helpers and to AscendC::BatchNorm.
 */
template <typename T, bool isReuseSource = false, bool isBasicBlock = false>
class KernelBatchnorm {
public:
    __aicore__ inline KernelBatchnorm() {}

    /**
     * Binds the global buffers and allocates one local buffer per queue.
     *
     * Element counts used: input/output = originalBLength * sLength * hLength,
     * gamma/beta = bLength, mean/variance = sLength * hLength.
     *
     * @param inputX_gm .. outputVariance_gm global-memory addresses of the six tensors.
     * @param bLength         B length used for gamma/beta and for the source shape.
     * @param sLength         S length.
     * @param hLength         H length.
     * @param originalBLength B length used for the input/output element count and the original shape.
     * @param epsilon         value passed through to AscendC::BatchNorm.
     * @param dataFormat      format stored in the ShapeInfo objects built in Compute().
     * @param pipeIn          pipe that owns the queue buffers; must outlive this object's use.
     */
    __aicore__ inline void Init(GM_ADDR inputX_gm, GM_ADDR gamm_gm, GM_ADDR beta_gm, GM_ADDR output_gm,
                                GM_ADDR outputMean_gm, GM_ADDR outputVariance_gm, uint32_t bLength, uint32_t sLength,
                                uint32_t hLength, uint32_t originalBLength, T epsilon, AscendC::DataFormat dataFormat,
                                AscendC::TPipe* pipeIn)
    {
        pipe = pipeIn;
        this->bLength = bLength;
        this->sLength = sLength;
        this->hLength = hLength;
        this->originalBLength = originalBLength;
        this->epsilon = epsilon;
        this->dataFormat = dataFormat;
        bshLength = originalBLength * sLength * hLength;
        shLength = sLength * hLength;

        inputX_global.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(inputX_gm), bshLength);
        gamm_global.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(gamm_gm), bLength);
        beta_global.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(beta_gm), bLength);
        output_global.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(output_gm), bshLength);
        outputMean_global.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(outputMean_gm), shLength);
        outputVariance_global.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(outputVariance_gm), shLength);

        pipe->InitBuffer(inQueueX, 1, sizeof(T) * bshLength);
        pipe->InitBuffer(inQueueGamma, 1, sizeof(T) * bLength);
        pipe->InitBuffer(inQueueBeta, 1, sizeof(T) * bLength);
        pipe->InitBuffer(outQueue, 1, sizeof(T) * bshLength);
        pipe->InitBuffer(outQueueMean, 1, sizeof(T) * shLength);
        pipe->InitBuffer(outQueueVariance, 1, sizeof(T) * shLength);
    }

    /** Runs copy-in, compute and copy-out once, bracketed by SetOverflow(1) / SetOverflow(0). */
    __aicore__ inline void Process()
    {
        AscendC::AscendCUtils::SetOverflow(1);
        CopyIn();
        Compute();
        CopyOut();
        AscendC::AscendCUtils::SetOverflow(0);
    }

private:
    /** Copies input, gamma and beta from global memory into freshly allocated local tensors. */
    __aicore__ inline void CopyIn()
    {
        AscendC::LocalTensor<T> inputXLocal = inQueueX.AllocTensor<T>();
        AscendC::LocalTensor<T> gammaLocal = inQueueGamma.AllocTensor<T>();
        AscendC::LocalTensor<T> betaLocal = inQueueBeta.AllocTensor<T>();
        AscendC::DataCopy(inputXLocal, inputX_global, bshLength);
        AscendC::DataCopy(gammaLocal, gamm_global, bLength);
        AscendC::DataCopy(betaLocal, beta_global, bLength);
        inQueueX.EnQue(inputXLocal);
        inQueueGamma.EnQue(gammaLocal);
        inQueueBeta.EnQue(betaLocal);
    }

    /**
     * Builds the tiling on the device and invokes AscendC::BatchNorm.
     *
     * BatchNorm is only called when the stack buffer was obtained and both tiling helpers
     * succeeded; otherwise `tiling` would be uninitialised. In the failure case the output
     * tensors are still enqueued (so CopyOut does not block) but their contents are unspecified.
     */
    __aicore__ inline void Compute()
    {
        AscendC::LocalTensor<T> inputXLocal = inQueueX.DeQue<T>();
        AscendC::LocalTensor<T> gammaLocal = inQueueGamma.DeQue<T>();
        AscendC::LocalTensor<T> betaLocal = inQueueBeta.DeQue<T>();
        AscendC::LocalTensor<T> outputLocal = outQueue.AllocTensor<T>();
        AscendC::LocalTensor<T> meanLocal = outQueueMean.AllocTensor<T>();
        AscendC::LocalTensor<T> varianceLocal = outQueueVariance.AllocTensor<T>();

        BatchNormTiling tiling;
        uint32_t inputShape[3] = {bLength, sLength, hLength};
        uint32_t originInputShape[3] = {originalBLength, sLength, hLength};
        AscendC::ShapeInfo shapeInfo{3, inputShape, 3, inputShape, dataFormat};
        AscendC::ShapeInfo oriShapeInfo{3, originInputShape, 3, originInputShape, dataFormat};
        uint32_t maxValue = 0;
        uint32_t minValue = 0;

        AscendC::LocalTensor<uint8_t> stackBuffer;
        const bool hasStackBuffer = AscendC::PopStackBuffer<uint8_t, AscendC::TPosition::LCM>(stackBuffer);
        stackBufferSize = stackBuffer.GetSize();

        // NOTE(review): the tiling is computed for minValue bytes rather than for the actual
        // stack-buffer size; this matches the original sample — confirm it is intentional.
        const bool tilingReady =
            hasStackBuffer &&
            GetBatchNormMaxMinTmpSize(shapeInfo, oriShapeInfo, sizeof(T), isReuseSource, maxValue, minValue,
                                      isBasicBlock) &&
            GetBatchNormNDTilingInfo(shapeInfo, oriShapeInfo, minValue, sizeof(T), isReuseSource, tiling,
                                     isBasicBlock);
        stackBuffer.SetSize(stackBufferSize);

        if (tilingReady) {
            AscendC::BatchNorm<T, isReuseSource, isBasicBlock>(outputLocal, meanLocal, varianceLocal, inputXLocal,
                                                               gammaLocal, betaLocal, stackBuffer, epsilon, tiling);
        }

        outQueue.EnQue<T>(outputLocal);
        outQueueMean.EnQue<T>(meanLocal);
        outQueueVariance.EnQue<T>(varianceLocal);
        inQueueX.FreeTensor(inputXLocal);
        inQueueGamma.FreeTensor(gammaLocal);
        inQueueBeta.FreeTensor(betaLocal);
    }

    /** Copies output, mean and variance from the local queues back to global memory. */
    __aicore__ inline void CopyOut()
    {
        AscendC::LocalTensor<T> outputLocal = outQueue.DeQue<T>();
        AscendC::LocalTensor<T> meanLocal = outQueueMean.DeQue<T>();
        AscendC::LocalTensor<T> varianceLocal = outQueueVariance.DeQue<T>();
        AscendC::DataCopy(output_global, outputLocal, bshLength);
        AscendC::DataCopy(outputMean_global, meanLocal, shLength);
        AscendC::DataCopy(outputVariance_global, varianceLocal, shLength);
        outQueue.FreeTensor(outputLocal);
        outQueueMean.FreeTensor(meanLocal);
        outQueueVariance.FreeTensor(varianceLocal);
    }

private:
    // Global-memory views of the six tensors.
    AscendC::GlobalTensor<T> inputX_global;
    AscendC::GlobalTensor<T> gamm_global;
    AscendC::GlobalTensor<T> beta_global;
    AscendC::GlobalTensor<T> output_global;
    AscendC::GlobalTensor<T> outputMean_global;
    AscendC::GlobalTensor<T> outputVariance_global;
    // Pipe supplied by the caller of Init(); not owned.
    AscendC::TPipe* pipe;
    // One single-buffer queue per tensor.
    AscendC::TQue<AscendC::QuePosition::VECIN, 1> inQueueX;
    AscendC::TQue<AscendC::QuePosition::VECIN, 1> inQueueGamma;
    AscendC::TQue<AscendC::QuePosition::VECIN, 1> inQueueBeta;
    AscendC::TQue<AscendC::QuePosition::VECOUT, 1> outQueue;
    AscendC::TQue<AscendC::QuePosition::VECOUT, 1> outQueueMean;
    AscendC::TQue<AscendC::QuePosition::VECOUT, 1> outQueueVariance;
    // Shape parameters captured in Init().
    uint32_t bLength;
    uint32_t sLength;
    uint32_t hLength;
    uint32_t originalBLength;
    T epsilon;
    AscendC::DataFormat dataFormat;
    uint32_t bshLength; // originalBLength * sLength * hLength
    uint32_t shLength;  // sLength * hLength
    uint32_t stackBufferSize = 0; // size reported by the popped stack buffer in Compute()
};
/**
 * Kernel entry point: runs KernelBatchnorm<float> once on a fixed 8 x 8 x 8 problem
 * (original B length 8, epsilon 0.01, ND data format).
 *
 * @param inputX_gm         global address of the input tensor.
 * @param gamm_gm           global address of gamma.
 * @param beta_gm           global address of beta.
 * @param output_gm         global address of the normalized output.
 * @param outputMean_gm     global address of the mean output.
 * @param outputVariance_gm global address of the variance output.
 */
__global__ __aicore__ void batchnorm_custom(GM_ADDR inputX_gm, GM_ADDR gamm_gm, GM_ADDR beta_gm, GM_ADDR output_gm,
                                            GM_ADDR outputMean_gm, GM_ADDR outputVariance_gm)
{
    KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIV_ONLY);

    // Fixed problem size of this sample.
    constexpr uint32_t kB = 8;
    constexpr uint32_t kS = 8;
    constexpr uint32_t kH = 8;
    constexpr uint32_t kOriginalB = 8;
    constexpr float kEpsilon = 0.01;

    AscendC::TPipe pipe;
    KernelBatchnorm<float, false, false> op;
    op.Init(inputX_gm, gamm_gm, beta_gm, output_gm, outputMean_gm, outputVariance_gm, kB, kS, kH, kOriginalB,
            kEpsilon, AscendC::DataFormat::ND, &pipe);
    op.Process();
}
/**
 * Compares a float output buffer against ./output/golden_output_<goldenName>.bin.
 *
 * An element counts as wrong when neither its absolute error nor its relative error is within
 * 1e-5. NaN never compares as close, so a NaN in the output is reported as a mismatch (the
 * previous `ae > EPS && re > EPS` test let NaN through because every comparison with NaN is false).
 *
 * @param outputData pointer to outSize bytes of float data; must still be valid (not freed).
 * @param outSize    size of the buffer in bytes.
 * @param goldenName suffix of the golden file name.
 * @return true when the golden file was read and every element matches; false otherwise,
 *         including when outputData is null, outSize is negative or the host allocation fails.
 */
static bool CompareResult(const void* outputData, int64_t outSize, std::string goldenName)
{
    if (outputData == nullptr || outSize < 0) {
        return false;
    }

    // Initialise so a failed allocation is detectable instead of leaving an indeterminate pointer.
    void* goldenData = nullptr;
    aclrtMallocHost(&goldenData, static_cast<size_t>(outSize));
    if (goldenData == nullptr) {
        return false;
    }

    size_t goldenSize = static_cast<size_t>(outSize);
    const bool ret = ReadFile("./output/golden_output_" + goldenName + ".bin", goldenSize, goldenData, goldenSize);
    if (!ret) {
        aclrtFreeHost(goldenData);
        return false;
    }
    printf("ReadFile golden_output_%s.bin success!\n", goldenName.c_str());

    constexpr float EPS = 1e-5;
    const float* actual = reinterpret_cast<const float*>(outputData);
    const float* golden = reinterpret_cast<const float*>(goldenData);
    const size_t elementCount = static_cast<size_t>(outSize) / sizeof(float);

    int64_t wrongNum = 0;
    for (size_t i = 0; i < elementCount; ++i) {
        const float a = actual[i];
        const float b = golden[i];
        const float ae = std::abs(a - b);
        const float re = ae / std::abs(b);
        // Written as "not close" so that NaN (for which both comparisons are false) is a mismatch.
        const bool isClose = (ae <= EPS) || (re <= EPS);
        if (!isClose) {
            printf("CompareResult golden_output_%s.bin failed output is %lf, golden is %lf\n", goldenName.c_str(), a,
                   b);
            wrongNum++;
        }
    }
    aclrtFreeHost(goldenData);

    if (wrongNum != 0) {
        return false;
    }
    printf("CompareResult golden_output_%s.bin success!\n", goldenName.c_str());
    return true;
}
int32_t main(int32_t argc, char* argv[])
{
size_t param1FileSize = 512 * sizeof(float);
size_t param2FileSize = 8 * sizeof(float);
size_t param3FileSize = 8 * sizeof(float);
size_t param4FileSize = 512 * sizeof(float);
size_t param5FileSize = 64 * sizeof(float);
size_t param6FileSize = 64 * sizeof(float);
uint32_t numBlocks = 1;
aclInit(nullptr);
aclrtContext context;
int32_t deviceId = 0;
aclrtSetDevice(deviceId);
aclrtCreateContext(&context, deviceId);
aclrtStream stream = nullptr;
aclrtCreateStream(&stream);
uint8_t* param1Host;
uint8_t* param1Device;
aclrtMallocHost((void**)(¶m1Host), param1FileSize);
aclrtMalloc((void**)¶m1Device, param1FileSize, ACL_MEM_MALLOC_HUGE_FIRST);
ReadFile("./input/input_inputX.bin", param1FileSize, param1Host, param1FileSize);
aclrtMemcpy(param1Device, param1FileSize, param1Host, param1FileSize, ACL_MEMCPY_HOST_TO_DEVICE);
uint8_t* param2Host;
uint8_t* param2Device;
aclrtMallocHost((void**)(¶m2Host), param2FileSize);
aclrtMalloc((void**)¶m2Device, param2FileSize, ACL_MEM_MALLOC_HUGE_FIRST);
ReadFile("./input/input_gamma.bin", param2FileSize, param2Host, param2FileSize);
aclrtMemcpy(param2Device, param2FileSize, param2Host, param2FileSize, ACL_MEMCPY_HOST_TO_DEVICE);
uint8_t* param3Host;
uint8_t* param3Device;
aclrtMallocHost((void**)(¶m3Host), param3FileSize);
aclrtMalloc((void**)¶m3Device, param3FileSize, ACL_MEM_MALLOC_HUGE_FIRST);
ReadFile("./input/input_beta.bin", param3FileSize, param3Host, param3FileSize);
aclrtMemcpy(param3Device, param3FileSize, param3Host, param3FileSize, ACL_MEMCPY_HOST_TO_DEVICE);
uint8_t* param4Host;
uint8_t* param4Device;
aclrtMallocHost((void**)(¶m4Host), param4FileSize);
aclrtMalloc((void**)¶m4Device, param4FileSize, ACL_MEM_MALLOC_HUGE_FIRST);
uint8_t* param5Host;
uint8_t* param5Device;
aclrtMallocHost((void**)(¶m5Host), param5FileSize);
aclrtMalloc((void**)¶m5Device, param5FileSize, ACL_MEM_MALLOC_HUGE_FIRST);
uint8_t* param6Host;
uint8_t* param6Device;
aclrtMallocHost((void**)(¶m6Host), param6FileSize);
aclrtMalloc((void**)¶m6Device, param6FileSize, ACL_MEM_MALLOC_HUGE_FIRST);
batchnorm_custom<<<numBlocks, nullptr, stream>>>(param1Device, param2Device, param3Device, param4Device,
param5Device, param6Device);
aclrtSynchronizeStream(stream);
aclrtFree(param1Device);
aclrtFreeHost(param1Host);
aclrtFree(param2Device);
aclrtFreeHost(param2Host);
aclrtFree(param3Device);
aclrtFreeHost(param3Host);
aclrtMemcpy(param4Host, param4FileSize, param4Device, param4FileSize, ACL_MEMCPY_DEVICE_TO_HOST);
WriteFile("./output/output_result.bin", param4Host, param4FileSize);
aclrtFree(param4Device);
aclrtFreeHost(param4Host);
aclrtMemcpy(param5Host, param5FileSize, param5Device, param5FileSize, ACL_MEMCPY_DEVICE_TO_HOST);
WriteFile("./output/output_mean.bin", param5Host, param5FileSize);
aclrtFree(param5Device);
aclrtFreeHost(param5Host);
aclrtMemcpy(param6Host, param6FileSize, param6Device, param6FileSize, ACL_MEMCPY_DEVICE_TO_HOST);
WriteFile("./output/output_variance.bin", param6Host, param6FileSize);
aclrtFree(param6Device);
aclrtFreeHost(param6Host);
bool goldenResult = true;
goldenResult &= CompareResult(param4Host, param4FileSize, "result");
goldenResult &= CompareResult(param5Host, param5FileSize, "mean");
goldenResult &= CompareResult(param6Host, param6FileSize, "variance");
if (goldenResult) {
printf("test pass!\n");
} else {
printf("test failed!\n");
}
aclrtDestroyStream(stream);
aclrtDestroyContext(context);
aclrtResetDevice(deviceId);
aclFinalize();
return 0;
}