/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/* !
* \file layernorm.asc
* \brief
*/
#include "acl/acl.h"
#include "data_utils.h"
#include "kernel_operator.h"
#include "tiling/tiling_api.h"
// Host-side tiling data for the LayernormCustom operator: the LayerNorm API tiling struct followed by
// an epsilon value. The kernel copies this blob word-by-word into MyCustomKernel::VecTiling, which declares
// the same two fields in the same order, so the two definitions must be kept in sync.
namespace optiling {
BEGIN_TILING_DATA_DEF(LayernormCustomTilingData)
TILING_DATA_FIELD_DEF_STRUCT(LayerNormTiling, layernormTilingData);
TILING_DATA_FIELD_DEF(float, epsilon);
END_TILING_DATA_DEF;
REGISTER_TILING_DATA_CLASS(LayernormCustom, LayernormCustomTilingData)
} // namespace optiling
void ComputeTiling(const uint32_t bLength, const uint32_t sLength, const uint32_t hLength,
optiling::LayernormCustomTilingData& tiling)
{
ge::Shape geShape({bLength, sLength, hLength, hLength});
uint32_t maxTmpSize = 0;
uint32_t minTmpSize = 0;
bool isReuseSource = false;
AscendC::GetLayerNormMaxMinTmpSize(geShape, sizeof(float), isReuseSource, maxTmpSize, minTmpSize);
uint32_t localWorkspaceSize = minTmpSize;
// get layernorm Tiling
AscendC::GetLayerNormNDTilingInfo(geShape, localWorkspaceSize, sizeof(float), isReuseSource,
tiling.layernormTilingData);
tiling.set_epsilon(0.0001);
}
/**
 * @brief Serialize tiling data into a newly allocated host buffer.
 * @param tilingData  Tiling structure to serialize.
 * @return A buffer of sizeof(optiling::LayernormCustomTilingData) bytes that the caller releases with free(),
 *         or nullptr if the allocation fails.
 */
uint8_t* GetTilingBuf(optiling::LayernormCustomTilingData* tilingData)
{
    const size_t tilingSize = sizeof(optiling::LayernormCustomTilingData);
    // calloc zero-fills, so any bytes SaveToBuffer does not write (if it writes fewer than tilingSize)
    // are deterministic instead of uninitialized heap contents that later get copied to the device.
    uint8_t* buf = static_cast<uint8_t*>(calloc(1, tilingSize));
    if (buf == nullptr) {
        // Previously SaveToBuffer would have written through a null pointer here.
        return nullptr;
    }
    tilingData->SaveToBuffer(buf, tilingSize);
    return buf;
}
/**
 * @brief Compute LayerNorm tiling for the given B/S/H sizes and return it serialized.
 * @return Host buffer produced by GetTilingBuf; the caller owns it and releases it with free().
 */
uint8_t* GenerateTiling(uint32_t bLength, uint32_t sLength, uint32_t hLength)
{
    optiling::LayernormCustomTilingData tilingData;
    ComputeTiling(bLength, sLength, hLength, tilingData);
    uint8_t* serialized = GetTilingBuf(&tilingData);
    return serialized;
}
namespace MyCustomKernel {
// Kernel-side view of the tiling blob. CopyTiling fills it word-by-word from global memory, so its layout is
// assumed to match the host serialization of optiling::LayernormCustomTilingData (layernormTilingData, then
// epsilon). NOTE(review): confirm the two layouts agree if either definition changes.
struct VecTiling {
LayerNormTiling layernormTilingData;
float epsilon = 0;
};
// Runs AscendC::LayerNorm over the whole B*S*H float input in a single CopyIn -> Compute -> CopyOut pass.
// Every local queue is sized for the complete tensor (no splitting along any axis), and nothing here reads
// the block index, so each launched block processes the full tensor and writes the same output addresses.
// isReuseSource is forwarded unchanged to AscendC::LayerNorm.
template <bool isReuseSource = false>
class KernelLayernorm {
public:
__aicore__ inline KernelLayernorm() {}
// Binds the global-memory tensors and sizes the local queues.
// gammGm/betaGm hold hLength floats, outputMeanGm/outputVarianceGm hold bLength * sLength floats, and
// inputXGm/outputGm hold bLength * sLength * hLength floats, with B/S/H taken from the tiling data.
// pipeIn must outlive this object; only the pointer is stored.
__aicore__ inline void Init(GM_ADDR inputXGm, GM_ADDR gammGm, GM_ADDR betaGm, GM_ADDR outputGm,
GM_ADDR outputMeanGm, GM_ADDR outputVarianceGm, VecTiling tilingData,
AscendC::TPipe* pipeIn)
{
pipe = pipeIn;
this->epsilon = tilingData.epsilon;
tiling_ = tilingData.layernormTilingData;
this->bLength = tiling_.bLength;
this->sLength = tiling_.sLength;
this->hLength = tiling_.hLength;
bshLength = bLength * sLength * hLength;
bsLength = bLength * sLength;
inputXGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ float*>(inputXGm), bshLength);
gammGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ float*>(gammGm), hLength);
betaGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ float*>(betaGm), hLength);
outputGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ float*>(outputGm), bshLength);
outputMeanGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ float*>(outputMeanGm), bsLength);
outputVarianceGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ float*>(outputVarianceGm), bsLength);
// One buffer per queue, each large enough for the entire tensor it carries.
pipe->InitBuffer(inQueueX, 1, sizeof(float) * bshLength);
pipe->InitBuffer(inQueueGamma, 1, sizeof(float) * hLength);
pipe->InitBuffer(inQueueBeta, 1, sizeof(float) * hLength);
pipe->InitBuffer(outQueue, 1, sizeof(float) * bshLength);
pipe->InitBuffer(outQueueMean, 1, sizeof(float) * bsLength);
pipe->InitBuffer(outQueueVariance, 1, sizeof(float) * bsLength);
}
// Runs one copy-in / compute / copy-out stage sequence over the whole tensor.
__aicore__ inline void Process()
{
CopyIn();
Compute();
CopyOut();
}
private:
// Copies x, gamma and beta from global memory into local buffers and hands them to the VECIN queues.
__aicore__ inline void CopyIn()
{
AscendC::LocalTensor<float> inputXLocal = inQueueX.AllocTensor<float>();
AscendC::LocalTensor<float> gammaLocal = inQueueGamma.AllocTensor<float>();
AscendC::LocalTensor<float> betaLocal = inQueueBeta.AllocTensor<float>();
AscendC::DataCopy(inputXLocal, inputXGlobal, bshLength);
AscendC::DataCopy(gammaLocal, gammGlobal, hLength);
AscendC::DataCopy(betaLocal, betaGlobal, hLength);
inQueueX.EnQue(inputXLocal);
inQueueGamma.EnQue(gammaLocal);
inQueueBeta.EnQue(betaLocal);
}
// Dequeues the inputs and runs AscendC::LayerNorm, which writes the normalized output plus the mean and
// variance tensors. The results are enqueued for CopyOut, and the input buffers are released afterwards.
__aicore__ inline void Compute()
{
AscendC::LocalTensor<float> inputXLocal = inQueueX.DeQue<float>();
AscendC::LocalTensor<float> gammaLocal = inQueueGamma.DeQue<float>();
AscendC::LocalTensor<float> betaLocal = inQueueBeta.DeQue<float>();
AscendC::LocalTensor<float> outputLocal = outQueue.AllocTensor<float>();
AscendC::LocalTensor<float> meanLocal = outQueueMean.AllocTensor<float>();
AscendC::LocalTensor<float> varianceLocal = outQueueVariance.AllocTensor<float>();
AscendC::LayerNorm<float, isReuseSource>(outputLocal, meanLocal, varianceLocal, inputXLocal, gammaLocal,
betaLocal, (float)epsilon, tiling_);
outQueue.EnQue<float>(outputLocal);
outQueueMean.EnQue<float>(meanLocal);
outQueueVariance.EnQue<float>(varianceLocal);
inQueueX.FreeTensor(inputXLocal);
inQueueGamma.FreeTensor(gammaLocal);
inQueueBeta.FreeTensor(betaLocal);
}
// Dequeues the output, mean and variance tensors, writes them to global memory, and releases the buffers.
__aicore__ inline void CopyOut()
{
AscendC::LocalTensor<float> outputLocal = outQueue.DeQue<float>();
AscendC::LocalTensor<float> meanLocal = outQueueMean.DeQue<float>();
AscendC::LocalTensor<float> varianceLocal = outQueueVariance.DeQue<float>();
AscendC::DataCopy(outputGlobal, outputLocal, bshLength);
AscendC::DataCopy(outputMeanGlobal, meanLocal, bsLength);
AscendC::DataCopy(outputVarianceGlobal, varianceLocal, bsLength);
outQueue.FreeTensor(outputLocal);
outQueueMean.FreeTensor(meanLocal);
outQueueVariance.FreeTensor(varianceLocal);
}
private:
// Global-memory views bound in Init.
AscendC::GlobalTensor<float> inputXGlobal;
AscendC::GlobalTensor<float> gammGlobal;
AscendC::GlobalTensor<float> betaGlobal;
AscendC::GlobalTensor<float> outputGlobal;
AscendC::GlobalTensor<float> outputMeanGlobal;
AscendC::GlobalTensor<float> outputVarianceGlobal;
// Non-owning pointer to the caller's pipe.
AscendC::TPipe* pipe;
// Single-buffer queues: inputs at VECIN, results at VECOUT.
AscendC::TQue<AscendC::TPosition::VECIN, 1> inQueueX;
AscendC::TQue<AscendC::TPosition::VECIN, 1> inQueueGamma;
AscendC::TQue<AscendC::TPosition::VECIN, 1> inQueueBeta;
AscendC::TQue<AscendC::TPosition::VECOUT, 1> outQueue;
AscendC::TQue<AscendC::TPosition::VECOUT, 1> outQueueMean;
AscendC::TQue<AscendC::TPosition::VECOUT, 1> outQueueVariance;
// Problem sizes copied from the tiling data.
uint32_t bLength;
uint32_t sLength;
uint32_t hLength;
float epsilon;
LayerNormTiling tiling_;
// Derived element counts: B*S*H for x/output, B*S for mean/variance.
uint32_t bshLength;
uint32_t bsLength;
};
} // namespace MyCustomKernel
/**
 * @brief Copy the serialized tiling blob from global memory into a local VecTiling.
 * @param tiling    Destination struct, overwritten 32 bits at a time.
 * @param tilingGM  Global-memory address of the host-serialized tiling data.
 */
__aicore__ inline void CopyTiling(MyCustomKernel::VecTiling* tiling, GM_ADDR tilingGM)
{
    // The copy works in whole 32-bit words; trailing bytes would be silently skipped otherwise.
    static_assert(sizeof(MyCustomKernel::VecTiling) % sizeof(uint32_t) == 0,
                  "VecTiling size must be a multiple of 4 bytes for the word-wise copy");
    constexpr uint32_t wordCount = sizeof(MyCustomKernel::VecTiling) / sizeof(uint32_t);
    uint32_t* dst = reinterpret_cast<uint32_t*>(tiling);
    auto src = reinterpret_cast<__gm__ uint32_t*>(tilingGM);
    // Unsigned index avoids the signed/unsigned comparison of the original int loop counter.
    for (uint32_t i = 0; i < wordCount; ++i) {
        dst[i] = src[i];
    }
}
/**
 * @brief Vector-core kernel entry: load tiling from global memory and run KernelLayernorm once.
 *
 * The workspace argument is accepted to match the launch signature but is not used by this kernel.
 */
extern "C" __global__ __aicore__ void layernorm_custom(GM_ADDR inputXGm, GM_ADDR gammaGm, GM_ADDR betaGm,
                                                       GM_ADDR outputGm, GM_ADDR outputMeanGm,
                                                       GM_ADDR outputVarianceGm, GM_ADDR workspace,
                                                       GM_ADDR tiling)
{
    KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIV_ONLY);
    AscendC::TPipe kernelPipe;

    MyCustomKernel::VecTiling vecTiling;
    CopyTiling(&vecTiling, tiling);

    MyCustomKernel::KernelLayernorm<false> layernormKernel;
    layernormKernel.Init(inputXGm, gammaGm, betaGm, outputGm, outputMeanGm, outputVarianceGm, vecTiling,
                         &kernelPipe);
    layernormKernel.Process();
}
// Problem size: x/output hold BLENGTH * SLENGTH * HLENGTH floats, gamma/beta hold HLENGTH floats,
// and mean/variance hold BLENGTH * SLENGTH floats.
constexpr uint32_t BLENGTH = 2;
constexpr uint32_t SLENGTH = 32;
constexpr uint32_t HLENGTH = 16;
// KernelLayernorm never reads its block index, so every launched block processes the whole tensor and
// writes the same output addresses. Launching more than one block only adds redundant, overlapping work.
constexpr uint32_t NUM_BLOCKS = 1;
// Number of 32-bit words copied to the device as tiling data (LayerNorm tiling fields plus epsilon).
// NOTE(review): must equal sizeof(MyCustomKernel::VecTiling) / sizeof(uint32_t); re-check if LayerNormTiling changes.
constexpr uint32_t TILINGDATA_SIZE = 27;
// Workspace size in float elements (main multiplies by sizeof(float)); the kernel does not use the workspace.
constexpr uint32_t WORKSPACE_SIZE = 1024;
static bool CompareResult(const void* outputData, int64_t outSize, std::string goldenName)
{
void* goldenData;
aclrtMallocHost((void**)(&goldenData), outSize);
size_t goldenSize = outSize;
bool ret = ReadFile("./output/golden_output_" + goldenName + ".bin", goldenSize, goldenData, goldenSize);
if (ret) {
printf("ReadFile golden_output_%s.bin success!\n", goldenName.c_str());
} else {
aclrtFreeHost(goldenData);
return false;
}
constexpr float EPS = 1e-5;
int64_t wrongNum = 0;
for (int i = 0; i < outSize / sizeof(float); i++) {
float a = (reinterpret_cast<const float*>(outputData))[i];
float b = (reinterpret_cast<const float*>(goldenData))[i];
float ae = std::abs(a - b);
float re = ae / std::abs(b);
if (ae > EPS && re > EPS) {
printf("CompareResult golden_output_%s.bin failed output is %lf, golden is %lf\n", goldenName.c_str(), a,
b);
wrongNum++;
}
}
aclrtFreeHost(goldenData);
if (wrongNum != 0) {
return false;
} else {
printf("CompareResult golden_output_%s.bin success!\n", goldenName.c_str());
return true;
}
}
/**
 * @brief Host test driver for layernorm_custom.
 *
 * Builds the tiling data, uploads the inputs from ./input, launches the kernel, downloads the output,
 * mean and variance, writes them to ./output, and compares each one against its golden file.
 *
 * @return 0 if every input loaded and every comparison passed, otherwise 1, so scripts can rely on the
 *         exit status instead of parsing "test pass!/test failed!".
 */
int32_t main(int32_t argc, char* argv[])
{
    uint32_t numBlocks = NUM_BLOCKS;
    size_t workspaceSize = WORKSPACE_SIZE * sizeof(float);
    size_t xSize = BLENGTH * SLENGTH * HLENGTH * sizeof(float);
    size_t gammaSize = HLENGTH * sizeof(float);
    size_t betaSize = HLENGTH * sizeof(float);
    size_t outputSize = BLENGTH * SLENGTH * HLENGTH * sizeof(float);
    size_t meanSize = BLENGTH * SLENGTH * sizeof(float);
    size_t varianceSize = BLENGTH * SLENGTH * sizeof(float);
    size_t tilingFileSize = TILINGDATA_SIZE * sizeof(uint32_t);
    bool goldenResult = true;

    uint8_t* tilingBuf = GenerateTiling(BLENGTH, SLENGTH, HLENGTH);
    if (tilingBuf == nullptr) {
        printf("GenerateTiling failed!\n");
        printf("test failed!\n");
        return 1;
    }

    aclInit(nullptr);
    aclrtContext context;
    int32_t deviceId = 0;
    aclrtSetDevice(deviceId);
    aclrtCreateContext(&context, deviceId);
    aclrtStream stream = nullptr;
    aclrtCreateStream(&stream);

    // Host and device buffers. The former tilingHost buffer was allocated and freed but never used, so it is gone.
    uint8_t *inputXHost, *gammaHost, *betaHost, *resultHost, *meanHost, *varianceHost, *workspaceHost;
    uint8_t *inputXDevice, *gammaDevice, *betaDevice, *resultDevice, *meanDevice, *varianceDevice, *workspaceDevice,
        *tilingDevice;
    aclrtMallocHost((void**)(&inputXHost), xSize);
    aclrtMallocHost((void**)(&gammaHost), gammaSize);
    aclrtMallocHost((void**)(&betaHost), betaSize);
    aclrtMallocHost((void**)(&resultHost), outputSize);
    aclrtMallocHost((void**)(&meanHost), meanSize);
    aclrtMallocHost((void**)(&varianceHost), varianceSize);
    aclrtMallocHost((void**)(&workspaceHost), workspaceSize);
    aclrtMalloc((void**)&inputXDevice, xSize, ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMalloc((void**)&gammaDevice, gammaSize, ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMalloc((void**)&betaDevice, betaSize, ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMalloc((void**)&resultDevice, outputSize, ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMalloc((void**)&meanDevice, meanSize, ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMalloc((void**)&varianceDevice, varianceSize, ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMalloc((void**)&workspaceDevice, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMalloc((void**)&tilingDevice, tilingFileSize, ACL_MEM_MALLOC_HUGE_FIRST);

    // Previously a missing input file went unnoticed and the kernel silently ran on uninitialized host memory.
    bool inputsLoaded = ReadFile("./input/input_inputX.bin", xSize, inputXHost, xSize);
    inputsLoaded = ReadFile("./input/input_gamma.bin", gammaSize, gammaHost, gammaSize) && inputsLoaded;
    inputsLoaded = ReadFile("./input/input_beta.bin", betaSize, betaHost, betaSize) && inputsLoaded;
    if (!inputsLoaded) {
        printf("ReadFile input failed!\n");
        goldenResult = false;
    }

    aclrtMemcpy(workspaceDevice, workspaceSize, workspaceHost, workspaceSize, ACL_MEMCPY_HOST_TO_DEVICE);
    aclrtMemcpy(tilingDevice, tilingFileSize, tilingBuf, tilingFileSize, ACL_MEMCPY_HOST_TO_DEVICE);
    aclrtMemcpy(inputXDevice, xSize, inputXHost, xSize, ACL_MEMCPY_HOST_TO_DEVICE);
    aclrtMemcpy(gammaDevice, gammaSize, gammaHost, gammaSize, ACL_MEMCPY_HOST_TO_DEVICE);
    aclrtMemcpy(betaDevice, betaSize, betaHost, betaSize, ACL_MEMCPY_HOST_TO_DEVICE);

    layernorm_custom<<<numBlocks, nullptr, stream>>>(inputXDevice, gammaDevice, betaDevice, resultDevice, meanDevice,
                                                     varianceDevice, workspaceDevice, tilingDevice);
    aclrtSynchronizeStream(stream);

    aclrtMemcpy(resultHost, outputSize, resultDevice, outputSize, ACL_MEMCPY_DEVICE_TO_HOST);
    aclrtMemcpy(meanHost, meanSize, meanDevice, meanSize, ACL_MEMCPY_DEVICE_TO_HOST);
    aclrtMemcpy(varianceHost, varianceSize, varianceDevice, varianceSize, ACL_MEMCPY_DEVICE_TO_HOST);
    WriteFile("./output/output_result.bin", resultHost, outputSize);
    WriteFile("./output/output_mean.bin", meanHost, meanSize);
    WriteFile("./output/output_variance.bin", varianceHost, varianceSize);
    goldenResult &= CompareResult(resultHost, outputSize, "result");
    goldenResult &= CompareResult(meanHost, meanSize, "mean");
    goldenResult &= CompareResult(varianceHost, varianceSize, "variance");

    aclrtFree(inputXDevice);
    aclrtFree(gammaDevice);
    aclrtFree(betaDevice);
    aclrtFree(resultDevice);
    aclrtFree(meanDevice);
    aclrtFree(varianceDevice);
    aclrtFree(workspaceDevice);
    aclrtFree(tilingDevice);
    aclrtFreeHost(inputXHost);
    aclrtFreeHost(gammaHost);
    aclrtFreeHost(betaHost);
    aclrtFreeHost(resultHost);
    aclrtFreeHost(meanHost);
    aclrtFreeHost(varianceHost);
    aclrtFreeHost(workspaceHost);
    aclrtDestroyStream(stream);
    aclrtDestroyContext(context);
    aclrtResetDevice(deviceId);
    aclFinalize();
    free(tilingBuf);

    if (goldenResult) {
        printf("test pass!\n");
    } else {
        printf("test failed!\n");
    }
    // Propagate the verdict through the exit status (previously always 0).
    return goldenResult ? 0 : 1;
}