* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
*
*/
#include "kernel_operator.h"
using namespace AscendC;
constexpr uint32_t BLOCK_SIZE = 32;
constexpr uint32_t MAX_DEAL_NUM = 2048;
constexpr uint32_t MAX_MASK = 64;
class KernelScatterMeanDiv {
public:
__aicore__ inline KernelScatterMeanDiv() {}
__aicore__ inline void Init(GM_ADDR src, GM_ADDR count, GM_ADDR out, ScatterMeanDivTilingData *tiling_data, TPipe* tmpPipe)
{
pipe = tmpPipe;
ASSERT(GetBlockNum() != 0 && "block dim can not be zero!");
TilingDataInit(tiling_data);
countGm.SetGlobalBuffer((__gm__ DTYPE_COUNT*)count, countNum);
srcGm.SetGlobalBuffer((__gm__ DTYPE_SRC*)src, srcNum);
outGm.SetGlobalBuffer((__gm__ DTYPE_OUT*)out, outNum);
eventIdMte2ToV_0 = static_cast<event_t>(pipe->AllocEventID<HardEvent::MTE2_V>());
pipe->InitBuffer(inQueueCount, AlignUp(ubCountNum, countEachBlock) * sizeof(DTYPE_COUNT));
pipe->InitBuffer(inQueueSrc, AlignUp(ubTailNum, countEachBlock) * sizeof(DTYPE_SRC));
}
__aicore__ inline void TilingDataInit(ScatterMeanDivTilingData *tiling_data)
{
curBlockIdx = GetBlockIdx();
usedCoreNum = tiling_data->usedCoreNum;
tail = tiling_data->tail;
taskNum = tiling_data->taskNum;
taskEachLine = tiling_data->taskEachLine;
taskLastLine = tiling_data->taskLastLine;
bigCoreNum = tiling_data->bigCoreNum;
srcNum = tiling_data->srcNum;
countNum = tiling_data->countNum;
outNum = tiling_data->outNum;
ubCountNum = tiling_data->ubCountNum;
ubTailNum = tiling_data->ubTailNum;
uint64_t coreDataLine = tiling_data->coreSmallLine;
if (curBlockIdx < bigCoreNum) {
coreDataLine = coreDataLine + 1;
countBaseOffset = curBlockIdx * coreDataLine;
} else {
taskNum = tiling_data->taskNumSmall;
taskEachLine = tiling_data->taskEachLineSmall;
taskLastLine = tiling_data->taskLastLineSmall;
countBaseOffset = bigCoreNum * (coreDataLine + 1) + (curBlockIdx - bigCoreNum) * coreDataLine;
}
countEachBlock = BLOCK_SIZE / sizeof(DTYPE_COUNT);
dataEachBlock = BLOCK_SIZE / sizeof(DTYPE_SRC);
}
__aicore__ inline void Process()
{
for (int32_t i = 0; i < taskNum - 1; i++) {
ComputeEachTask(i, taskEachLine);
}
if (taskLastLine != 0) {
ComputeEachTask(taskNum - 1, taskLastLine);
}
}
private:
__aicore__ inline void CopyParamasInit(const uint32_t calCount)
{
copyParamsOut.blockCount = 1;
copyParamsOut.blockLen = static_cast<uint32_t>(calCount * sizeof(float));
copyParamsOut.srcStride = 0;
copyParamsOut.dstStride = 0;
copyParamsOut.rsv = 0;
}
__aicore__ inline void ComputeTailDiv(uint64_t tail, uint64_t baseOffset, float mulValue)
{
uint64_t tailLoop = tail / ubTailNum;
uint64_t offset = 0;
PipeBarrier<PIPE_ALL>();
for (uint64_t loop = 0; loop < tailLoop; loop++) {
PipeBarrier<PIPE_ALL>();
offset = loop * ubTailNum;
DataCopy(srcLocal, srcGm[baseOffset + offset], ubTailNum);
PipeBarrier<PIPE_ALL>();
Muls(srcLocal, srcLocal, mulValue, ubTailNum);
PipeBarrier<PIPE_ALL>();
DataCopy(outGm[baseOffset + offset], srcLocal, ubTailNum);
}
offset = tailLoop * ubTailNum;
uint64_t tailLast = tail - offset;
PipeBarrier<PIPE_ALL>();
if (tailLast != 0) {
CopyParamasInit(tailLast);
PipeBarrier<PIPE_ALL>();
DataCopy(srcLocal, srcGm[baseOffset + offset], AlignUp(tailLast, dataEachBlock));
SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV_0);
WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV_0);
Muls(srcLocal, srcLocal, mulValue, AlignUp(tailLast, dataEachBlock));
PipeBarrier<PIPE_ALL>();
DataCopyPad(outGm[baseOffset + offset], srcLocal, copyParamsOut);
}
}
__aicore__ inline void ComputeEachTask(int32_t taskId, uint64_t taskLine)
{
LocalTensor<DTYPE_COUNT>countLocal = inQueueCount.Get<DTYPE_COUNT>();
srcLocal = inQueueSrc.Get<DTYPE_SRC>();
auto count_offset = countBaseOffset + taskEachLine * taskId;
DataCopy(countLocal, countGm[count_offset], AlignUp(taskLine, countEachBlock));
CopyParamasInit(tail);
PipeBarrier<PIPE_ALL>();
for (uint64_t idx = 0; idx < taskLine; idx++) {
DTYPE_COUNT countValue = countLocal.GetValue(idx);
PipeBarrier<PIPE_ALL>();
if (countValue != 0) {
float mulValue = 1 / countValue;
ComputeTailDiv(tail, count_offset * tail + idx * tail, mulValue);
}
}
}
private:
TPipe* pipe;
TBuf<TPosition::VECCALC> inQueueSrc, inQueueCount;
GlobalTensor<DTYPE_COUNT> countGm;
GlobalTensor<DTYPE_SRC> srcGm;
GlobalTensor<DTYPE_OUT> outGm;
LocalTensor<DTYPE_SRC> srcLocal;
DataCopyExtParams copyParamsOut;
uint64_t curBlockIdx;
uint64_t usedCoreNum, bigCoreNum;
uint64_t tail;
uint64_t taskNum;
uint64_t taskEachLine, taskLastLine;
uint64_t countEachBlock, dataEachBlock;
uint64_t srcNum, countNum, outNum;
uint64_t ubCountNum;
uint64_t countBaseOffset;
uint64_t ubTailNum;
event_t eventIdMte2ToV_0;
};
extern "C" __global__ __aicore__ void scatter_mean_div(GM_ADDR src, GM_ADDR count, GM_ADDR out, GM_ADDR workspace, GM_ADDR tiling)
{
TPipe pipe;
GET_TILING_DATA(tiling_data, tiling);
KernelScatterMeanDiv op;
op.Init(src, count, out, &tiling_data, &pipe);
op.Process();
}
#ifndef __CCE_KT_TEST__
void scatter_mean_div_do(uint32_t blockDim, void* l2ctrl, void* stream, uint8_t* src, uint8_t* count, uint8_t* out,
uint8_t* workspace, uint8_t* tiling)
{
scatter_mean_div<<<blockDim, l2ctrl, stream>>>(var, indices, updates, out, argmax, workspace, tiling);
}
#endif