* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
*/
#ifndef _SCATTER_MEAN_NORAML_H_
#define _SCATTER_MEAN_NORAML_H_
#include "kernel_operator.h"
#include "kernel_tiling/kernel_tiling.h"
using namespace AscendC;
// Size in bytes of one data block; divided by an element size to get
// per-block element counts that align copy lengths and local buffers.
constexpr uint32_t BLOCK_SIZE{32};
// Not referenced in this header (presumably a vector mask width used elsewhere - TODO confirm).
constexpr uint32_t MAX_MASK{64};
/**
 * Scatter-mean accumulation kernel ("fix" variant).
 *
 * For every index element assigned to this core, the matching row of `tail`
 * contiguous elements of `src` is atomically added into the selected row of
 * `out`, and 1 is atomically added into that row's slot of `count`.
 * The first `bigCoreNum` cores process one extra index row each.
 * Dividing `out` by `count` is not done in this class (presumably a later
 * step - TODO confirm).
 */
class KernelScatterMeanFix {
public:
    __aicore__ inline KernelScatterMeanFix() {}

    /**
     * Binds the global buffers, reads tiling data and allocates local buffers/events.
     * @param src         source values (srcNum elements of DTYPE_SRC)
     * @param indices     per-row target positions along the scatter dim (indicesNum elements)
     * @param var         bound to varGm; not read by this class
     * @param out         accumulation target (outNum elements)
     * @param count       per-output-row hit counters (outNum / tail elements)
     * @param tiling_data host-computed tiling parameters
     * @param tmpPipe     pipe used for buffer and event allocation
     */
    __aicore__ inline void Init(GM_ADDR src, GM_ADDR indices, GM_ADDR var, GM_ADDR out, GM_ADDR count,
                                ScatterMeanTilingData *tiling_data, TPipe *tmpPipe)
    {
        pipe = tmpPipe;
        ASSERT(GetBlockNum() != 0 && "block dim can not be zero!");
        TilingDataInit(tiling_data);
        varGm.SetGlobalBuffer((__gm__ DTYPE_VAR *)var, outNum);
        indicesGm.SetGlobalBuffer((__gm__ DTYPE_INDICES *)indices, indicesNum);
        srcGm.SetGlobalBuffer((__gm__ DTYPE_SRC *)src, srcNum);
        outGm.SetGlobalBuffer((__gm__ DTYPE_OUT *)out, outNum);
        // count stores one DTYPE_COUNT per output row; cast with its own element type
        // (countGm is GlobalTensor<DTYPE_COUNT>) and avoid dividing by zero when tail == 0.
        uint64_t countNum = (tail == 0) ? 0 : outNum / tail;
        countGm.SetGlobalBuffer((__gm__ DTYPE_COUNT *)count, countNum);
        // NOTE(review): eventIdMte3ToMte2_0 is allocated but not used; PipeBarrier<PIPE_ALL>
        // currently orders the MTE3 store before the next MTE2 load.
        eventIdMte3ToMte2_0 = static_cast<event_t>(pipe->AllocEventID<HardEvent::MTE3_MTE2>());
        eventIdMte2ToMte3_0 = static_cast<event_t>(pipe->AllocEventID<HardEvent::MTE2_MTE3>());
        pipe->InitBuffer(inQueueIndices, AlignUp(ubIndicesNum, indicesEachBlock) * sizeof(DTYPE_INDICES));
        // The src staging buffer holds DTYPE_SRC elements and the tail copy in ComputeTailAdd
        // loads AlignUp(tailLast, dataEachBlock) of them, so align to the *src* block size.
        pipe->InitBuffer(inQueueSrc, AlignUp(ubTailNum, dataEachBlock) * sizeof(DTYPE_SRC));
        pipe->InitBuffer(onesTensorBuff, dataEachBlock * sizeof(DTYPE_COUNT));
    }

    /**
     * Copies tiling parameters into members and computes this core's share of index rows.
     * Cores with index < bigCoreNum take bacthSmallCore + 1 rows; the rest take
     * bacthSmallCore rows and use the "*Last" task layout.
     */
    __aicore__ inline void TilingDataInit(ScatterMeanTilingData *tiling_data)
    {
        curBlockIdx = GetBlockIdx();
        usedCoreNum = tiling_data->usedCoreNum;
        tail = tiling_data->tail;
        body = tiling_data->body;
        taskNum = tiling_data->taskNum;
        taskEachLine = tiling_data->taskEachLine;
        taskLastLine = tiling_data->taskLastLine;
        bigCoreNum = tiling_data->bigCoreNum;
        outDimSize = tiling_data->outDimSize;
        dimSize = tiling_data->dimSize;
        srcNum = tiling_data->srcNum;
        indicesNum = tiling_data->indicesNum;
        outNum = tiling_data->outNum;
        ubIndicesNum = tiling_data->ubIndicesNum;
        ubTailNum = tiling_data->ubTailNum;
        uint64_t coreDataLine = tiling_data->bacthSmallCore;
        if (curBlockIdx < bigCoreNum) {
            coreDataLine = coreDataLine + 1;
            indicesBaseOffset = curBlockIdx * coreDataLine;
        } else {
            taskNum = tiling_data->taskNumLast;
            taskEachLine = tiling_data->taskEachLineLast;
            taskLastLine = tiling_data->taskLastLineLast;
            // Skip the rows owned by all big cores, then the preceding small cores.
            indicesBaseOffset = bigCoreNum * (coreDataLine + 1) + (curBlockIdx - bigCoreNum) * coreDataLine;
        }
        indicesEachBlock = BLOCK_SIZE / sizeof(DTYPE_INDICES);
        dataEachBlock = BLOCK_SIZE / sizeof(DTYPE_SRC);
        // Each src row of `tail` elements is moved in tailLoop full chunks of ubTailNum plus a remainder.
        // NOTE(review): assumes tiling guarantees ubTailNum > 0 whenever this runs.
        tailLoop = tail / ubTailNum;
        tailLast = tail - tailLoop * ubTailNum;
        // A single count element is written per index row.
        copyParamsCount.blockCount = 1;
        copyParamsCount.blockLen = static_cast<uint32_t>(1 * sizeof(DTYPE_COUNT));
        copyParamsCount.srcStride = 0;
        copyParamsCount.dstStride = 0;
        copyParamsCount.rsv = 0;
    }

    /**
     * Runs all tasks assigned to this core: taskNum - 1 tasks of taskEachLine rows,
     * then a final task of taskLastLine rows when that is non-zero.
     */
    __aicore__ inline void Process()
    {
        if (taskNum == 0) {
            // Nothing assigned to this core; also prevents the unsigned underflow of taskNum - 1.
            return;
        }
        for (uint64_t taskId = 0; taskId + 1 < taskNum; taskId++) {
            ComputeEachTask(taskId, taskEachLine);
        }
        if (taskLastLine != 0) {
            ComputeEachTask(taskNum - 1, taskLastLine);
        }
    }

private:
    /** Prepares a single-burst copy of `calCount` DTYPE_SRC elements (byte length). */
    __aicore__ inline void CopyParamsInit(const uint32_t calCount)
    {
        copyParamsOut.blockCount = 1;
        copyParamsOut.blockLen = static_cast<uint32_t>(calCount * sizeof(DTYPE_SRC));
        copyParamsOut.srcStride = 0;
        copyParamsOut.dstStride = 0;
        copyParamsOut.rsv = 0;
    }

    /**
     * Adds one src row into its destination out row and bumps the row's counter.
     * Must be called with atomic add enabled (see ComputeEachTask).
     *
     * @param srcRow        flat position of the index element; decomposed as
     *                      (head, pos, inner) over an [*, dimSize, body] layout
     * @param dataInIndices target position along the output dim (< outDimSize assumed).
     *                      NOTE(review): negative indices are not rejected; they wrap as uint64_t.
     * @param srcOffset     element offset of this row in srcGm (srcRow * tail)
     */
    __aicore__ inline void ComputeTailAdd(uint64_t srcRow, uint64_t dataInIndices, uint64_t srcOffset)
    {
        uint64_t offset = 0;
        uint64_t srcLineEachHead = dimSize * body;
        auto head = srcRow / srcLineEachHead;
        auto pos = (srcRow - head * srcLineEachHead) / body;
        auto inner = srcRow - head * srcLineEachHead - pos * body;
        // Destination row in an [*, outDimSize, body] layout; each row spans `tail` elements.
        uint64_t outLineOffset = inner + dataInIndices * body + head * (outDimSize * body);
        PipeBarrier<PIPE_ALL>();
        // NOTE(review): DataCopy moves whole 32-byte blocks, so ubTailNum is assumed to be a
        // multiple of dataEachBlock - TODO confirm the tiling guarantees this.
        for (uint64_t loop = 0; loop < static_cast<uint64_t>(tailLoop); loop++) {
            // Ensure the previous chunk's store finished before srcLocal is overwritten.
            PipeBarrier<PIPE_ALL>();
            offset = loop * ubTailNum;
            DataCopy(srcLocal, srcGm[srcOffset + offset], ubTailNum);
            SetFlag<HardEvent::MTE2_MTE3>(eventIdMte2ToMte3_0);
            WaitFlag<HardEvent::MTE2_MTE3>(eventIdMte2ToMte3_0);
            DataCopy(outGm[outLineOffset * tail + offset], srcLocal, ubTailNum);
        }
        offset = tailLoop * ubTailNum;
        if (tailLast != 0) {
            PipeBarrier<PIPE_ALL>();
            CopyParamsInit(tailLast);
            // The load is rounded up to whole blocks; only tailLast elements are written back.
            DataCopy(srcLocal, srcGm[srcOffset + offset], AlignUp(tailLast, dataEachBlock));
            SetFlag<HardEvent::MTE2_MTE3>(eventIdMte2ToMte3_0);
            WaitFlag<HardEvent::MTE2_MTE3>(eventIdMte2ToMte3_0);
            DataCopyPad(outGm[outLineOffset * tail + offset], srcLocal, copyParamsOut);
        }
        // Accumulate 1 into the destination row's counter.
        DataCopyPad(countGm[outLineOffset], onesTensor, copyParamsCount);
    }

    /**
     * Loads `taskLine` indices of task `taskId` and accumulates each corresponding src row.
     * Atomic add is enabled around each row so concurrent cores can target the same out row.
     */
    __aicore__ inline void ComputeEachTask(uint64_t taskId, uint64_t taskLine)
    {
        LocalTensor<DTYPE_INDICES> indicesLocal = inQueueIndices.Get<DTYPE_INDICES>();
        onesTensor = onesTensorBuff.Get<DTYPE_COUNT>();
        srcLocal = inQueueSrc.Get<DTYPE_SRC>();
        Duplicate(onesTensor, static_cast<DTYPE_COUNT>(1), dataEachBlock);
        uint64_t indicesOffset = indicesBaseOffset + taskEachLine * taskId;
        DataCopy(indicesLocal, indicesGm[indicesOffset], AlignUp(taskLine, indicesEachBlock));
        // Indices must be resident before the scalar GetValue reads below.
        PipeBarrier<PIPE_ALL>();
        for (uint64_t idx = 0; idx < taskLine; idx++) {
            DTYPE_INDICES dataInIndices = indicesLocal.GetValue(idx);
            uint64_t srcRow = indicesOffset + idx;
            uint64_t srcOffset = srcRow * tail;
            SetAtomicAdd<DTYPE_SRC>();
            ComputeTailAdd(srcRow, dataInIndices, srcOffset);
            SetAtomicNone();
        }
    }

private:
    TPipe *pipe;
    TBuf<TPosition::VECCALC> inQueueIndices, inQueueSrc;
    TBuf<TPosition::VECCALC> onesTensorBuff;
    GlobalTensor<DTYPE_VAR> varGm;
    GlobalTensor<DTYPE_INDICES> indicesGm;
    GlobalTensor<DTYPE_SRC> srcGm;
    GlobalTensor<DTYPE_OUT> outGm;
    GlobalTensor<DTYPE_COUNT> countGm;
    LocalTensor<DTYPE_SRC> srcLocal;
    // Filled with ones and copied into countGm, so it shares countGm's element type.
    LocalTensor<DTYPE_COUNT> onesTensor;
    DataCopyExtParams copyParamsOut;
    DataCopyExtParams copyParamsCount;
    uint64_t curBlockIdx;
    uint64_t usedCoreNum, bigCoreNum;
    uint64_t tail, body;
    uint64_t taskNum;
    uint64_t taskEachLine, taskLastLine;
    uint64_t indicesEachBlock, dataEachBlock;
    uint64_t srcNum, indicesNum, outNum;
    uint64_t ubIndicesNum;
    uint64_t outDimSize, dimSize;
    uint64_t ubTailNum;
    uint64_t indicesBaseOffset;
    int64_t tailLoop;
    uint64_t tailLast;
    event_t eventIdMte3ToMte2_0, eventIdMte2ToMte3_0;
};
#endif