* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
*/
#ifndef SCATTER_ADD_V1_H_
#define SCATTER_ADD_V1_H_
#include "kernel_operator.h"
#include "kernel_tiling/kernel_tiling.h"
using namespace AscendC;
constexpr uint64_t MAX_MASK = 64;
constexpr uint64_t BLOCK_SIZE = 32;
constexpr uint64_t MASK_ALIGN_SIZE = 256;
constexpr uint64_t DIM_SIZE_THRESHOLD = 200;
constexpr uint64_t MAX_COPY_PAD = 4095;
constexpr uint64_t BUFFER_NUM_MAX = 8;
constexpr uint64_t UB_SIZE_COEFF = 2;
class ScatterAddBaseKernel {
public:
__aicore__ inline ScatterAddBaseKernel() = delete;
__aicore__ inline ScatterAddBaseKernel(
GM_ADDR src, GM_ADDR indices, GM_ADDR out, ScatterAddTilingDataV1* tiling_data, TPipe* pipe)
: _blockIdx(GetBlockIdx()), _pipe(pipe)
{
ASSERT(GetBlockNum() != 0 && "block dim can not be zero!");
initTiling(tiling_data);
initGM(src, indices, out);
}
protected:
template<typename T1, typename T2>
__aicore__ inline T1 ceilDiv(T1 a, T2 b)
{
return b == 0 ? 0 : (a + b - 1) / b;
};
protected:
TPipe* _pipe;
TBuf<TPosition::VECCALC> _srcBuf;
TBuf<TPosition::VECCALC> _outBuf;
TBuf<TPosition::VECCALC> _indicesBuf;
LocalTensor<DTYPE_SRC> _srcLocal;
LocalTensor<DTYPE_OUT> _outLocal;
LocalTensor<DTYPE_INDICES> _indicesLocal;
GlobalTensor<DTYPE_SRC> _srcGm;
GlobalTensor<DTYPE_OUT> _outGm;
GlobalTensor<DTYPE_INDICES> _indicesGm;
uint64_t _blockIdx;
uint64_t _totalHead;
uint64_t _tailLen;
uint64_t _dimSize;
uint64_t _srcDimSize;
uint64_t _ubSize;
uint64_t _totalSrcNum;
uint64_t _totalOutNum;
uint64_t _totalIndicesNum;
uint64_t _outNumPerHead;
uint64_t _indicesNumPerHead;
uint64_t _dataNumPerBlock;
uint64_t _loop;
private:
__aicore__ inline void initTiling(ScatterAddTilingDataV1* tiling_data)
{
_totalHead = tiling_data->totalHead;
_tailLen = tiling_data->tailLen;
_dimSize = tiling_data->dimSize;
_ubSize = tiling_data->ubSize;
_srcDimSize = tiling_data->srcDimSize;
_totalSrcNum = tiling_data->totalSrcNum;
_totalOutNum = tiling_data->totalOutNum;
_totalIndicesNum = tiling_data->totalIndicesNum;
_outNumPerHead = tiling_data->outNumPerHead;
_indicesNumPerHead = tiling_data->indicesNumPerHead;
_dataNumPerBlock = BLOCK_SIZE / sizeof(DTYPE_OUT);
}
__aicore__ inline void initGM(GM_ADDR src, GM_ADDR indices, GM_ADDR out)
{
_srcGm.SetGlobalBuffer((__gm__ DTYPE_SRC*)src, _totalSrcNum);
_outGm.SetGlobalBuffer((__gm__ DTYPE_OUT*)out, _totalOutNum);
_indicesGm.SetGlobalBuffer((__gm__ DTYPE_INDICES*)indices, _totalIndicesNum);
}
};
class ScatterAddFullyLoad final : public ScatterAddBaseKernel {
public:
__aicore__ inline ScatterAddFullyLoad() = delete;
__aicore__ inline ScatterAddFullyLoad(
GM_ADDR src, GM_ADDR indices, GM_ADDR out, ScatterAddTilingDataV1* tiling_data, TPipe* pipe)
: ScatterAddBaseKernel(src, indices, out, tiling_data, pipe)
{
initTiling(tiling_data);
initBuffer();
}
__aicore__ inline void Process()
{
_outLocal = _outBuf.Get<DTYPE_OUT>();
Duplicate(_outLocal, 0.0f, AlignUp(_totalOutNum, BLOCK_SIZE / sizeof(DTYPE_OUT)));
for (uint64_t i = 0; i < _loop; i++) {
if (_dimSize <= DIM_SIZE_THRESHOLD) {
computeWithSmallDimSize(i);
} else {
computeWithLargeDimSize(i);
}
}
SetFlag<HardEvent::V_MTE3>(EVENT_ID0);
WaitFlag<HardEvent::V_MTE3>(EVENT_ID0);
SetAtomicAdd<float>();
DataCopyPad(_outGm, _outLocal, {1, static_cast<uint32_t>(_totalOutNum * sizeof(DTYPE_OUT)), 0, 0, 0});
SetAtomicNone();
}
private:
__aicore__ inline void initTiling(ScatterAddTilingDataV1* tiling_data)
{
_indicesNumBigCore = tiling_data->indicesNumBigCore;
_indicesNumSmallCore = tiling_data->indicesNumSmallCore;
_bigCoreNum = tiling_data->bigCoreNum;
if (_blockIdx < _bigCoreNum) {
_indicesMaxLoadableNum = tiling_data->indicesMaxLoadableNumBigCore;
_loop = ceilDiv(_indicesNumBigCore, _indicesMaxLoadableNum);
_indicesNumLeftover = _indicesNumBigCore % _indicesMaxLoadableNum;
_baseOffset = _blockIdx * _indicesNumBigCore;
} else {
_indicesMaxLoadableNum = tiling_data->indicesMaxLoadableNumSmallCore;
_loop = ceilDiv(_indicesNumSmallCore, _indicesMaxLoadableNum);
_indicesNumLeftover = _indicesNumSmallCore % _indicesMaxLoadableNum;
_baseOffset = _bigCoreNum * _indicesNumBigCore + (_blockIdx - _bigCoreNum) * _indicesNumSmallCore;
}
}
__aicore__ inline void initBuffer()
{
_pipe->InitBuffer(_srcBuf, AlignUp(_indicesMaxLoadableNum * sizeof(DTYPE_SRC), MASK_ALIGN_SIZE));
_pipe->InitBuffer(_outBuf, AlignUp(_totalOutNum * sizeof(DTYPE_OUT), MASK_ALIGN_SIZE));
_pipe->InitBuffer(_indicesBuf, AlignUp(_indicesMaxLoadableNum * sizeof(DTYPE_INDICES), MASK_ALIGN_SIZE));
if (_dimSize <= DIM_SIZE_THRESHOLD) {
uint64_t srcMaskNum =
AlignUp(_indicesMaxLoadableNum * sizeof(DTYPE_SRC), MASK_ALIGN_SIZE) / sizeof(DTYPE_SRC);
uint64_t maskBitNum = AscendCUtils::GetBitSize(sizeof(uint8_t));
uint64_t maskBufSize = ceilDiv(srcMaskNum, maskBitNum) * sizeof(uint8_t);
_pipe->InitBuffer(_maskBuf, maskBufSize);
_pipe->InitBuffer(_selectedSrcBuf, AlignUp(_indicesMaxLoadableNum * sizeof(DTYPE_SRC), MASK_ALIGN_SIZE));
_pipe->InitBuffer(_reduceSumBuf, BLOCK_SIZE);
_pipe->InitBuffer(_sharedBuf, BLOCK_SIZE);
}
}
__aicore__ inline void computeWithSmallDimSize(uint64_t i)
{
_srcLocal = _srcBuf.Get<DTYPE_SRC>();
_indicesLocal = _indicesBuf.Get<DTYPE_INDICES>();
_maskLocal = _maskBuf.Get<uint8_t>();
_reduceSumLocal = _reduceSumBuf.Get<DTYPE_SRC>();
_sharedLocal = _sharedBuf.Get<DTYPE_SRC>();
_selectedSrcLocal = _selectedSrcBuf.Get<DTYPE_SRC>();
uint64_t indicesGlobalOffset = _baseOffset + i * _indicesMaxLoadableNum;
uint64_t loadableNum = calcIndicesLoadableNum(i);
uint64_t startHeadID = indicesGlobalOffset / _indicesNumPerHead;
uint64_t endHeadID = (indicesGlobalOffset + loadableNum - 1) / _indicesNumPerHead;
for (uint64_t headID = startHeadID; headID <= endHeadID; headID++) {
uint64_t loadOffset = max(indicesGlobalOffset, headID * _indicesNumPerHead);
uint64_t loadLen = min(indicesGlobalOffset + loadableNum, (headID + 1) * _indicesNumPerHead) - loadOffset;
uint64_t copyLenAlign32 = AlignUp(loadLen, BLOCK_SIZE / sizeof(DTYPE_SRC));
uint64_t cmpLenAlign256 = AlignUp(loadLen, MASK_ALIGN_SIZE / sizeof(DTYPE_INDICES));
SetFlag<HardEvent::V_MTE2>(EVENT_ID0);
WaitFlag<HardEvent::V_MTE2>(EVENT_ID0);
DataCopy(_srcLocal, _srcGm[loadOffset], copyLenAlign32);
DataCopy(_indicesLocal, _indicesGm[loadOffset], copyLenAlign32);
SetFlag<HardEvent::MTE2_V>(EVENT_ID0);
WaitFlag<HardEvent::MTE2_V>(EVENT_ID0);
for (uint64_t idxVal = 0; idxVal < _dimSize; idxVal++) {
CompareScalar(
_maskLocal, _indicesLocal, static_cast<DTYPE_INDICES>(idxVal), CMPMODE::EQ, cmpLenAlign256);
Select(_selectedSrcLocal, _maskLocal, _srcLocal, static_cast<DTYPE_SRC>(0),
SELMODE::VSEL_TENSOR_SCALAR_MODE, loadLen);
ReduceSum<float>(_reduceSumLocal, _selectedSrcLocal, _sharedLocal, loadLen);
uint64_t outOffset = headID * _outNumPerHead + idxVal;
_outLocal.SetValue(outOffset, _outLocal.GetValue(outOffset) + _reduceSumLocal.GetValue(0));
}
}
}
__aicore__ inline void computeWithLargeDimSize(uint64_t i)
{
uint64_t indicesGlobalOffset = _baseOffset + i * _indicesMaxLoadableNum;
uint64_t loadableNum = calcIndicesLoadableNum(i);
uint64_t copyLenAlign32 = AlignUp(loadableNum, BLOCK_SIZE / sizeof(DTYPE_SRC));
_srcLocal = _srcBuf.Get<DTYPE_SRC>();
_indicesLocal = _indicesBuf.Get<DTYPE_INDICES>();
DataCopy(_srcLocal, _srcGm[indicesGlobalOffset], copyLenAlign32);
DataCopy(_indicesLocal, _indicesGm[indicesGlobalOffset], copyLenAlign32);
SetFlag<HardEvent::MTE2_V>(EVENT_ID0);
WaitFlag<HardEvent::MTE2_V>(EVENT_ID0);
uint64_t startHeadID = indicesGlobalOffset / _indicesNumPerHead;
uint64_t endHeadID = (indicesGlobalOffset + loadableNum - 1) / _indicesNumPerHead;
for (uint64_t headID = startHeadID; headID <= endHeadID; headID++) {
uint64_t localStart = max(indicesGlobalOffset, headID * _indicesNumPerHead) - indicesGlobalOffset;
uint64_t localEnd =
min(indicesGlobalOffset + loadableNum, (headID + 1) * _indicesNumPerHead) - indicesGlobalOffset;
for (uint64_t pos = localStart; pos < localEnd; pos++) {
DTYPE_INDICES idxVal = _indicesLocal.GetValue(pos);
uint64_t outOffset = headID * _outNumPerHead + idxVal;
_outLocal.SetValue(outOffset, _outLocal.GetValue(outOffset) + _srcLocal.GetValue(pos));
}
}
}
__aicore__ inline uint64_t calcIndicesLoadableNum(uint64_t i)
{
if (_indicesNumLeftover == 0) {
return _indicesMaxLoadableNum;
}
return i == _loop - 1 ? _indicesNumLeftover : _indicesMaxLoadableNum;
}
private:
TBuf<TPosition::VECCALC> _maskBuf;
TBuf<TPosition::VECCALC> _selectedSrcBuf;
TBuf<TPosition::VECCALC> _reduceSumBuf;
TBuf<TPosition::VECCALC> _sharedBuf;
LocalTensor<uint8_t> _maskLocal;
LocalTensor<DTYPE_SRC> _selectedSrcLocal;
LocalTensor<DTYPE_SRC> _reduceSumLocal;
LocalTensor<DTYPE_SRC> _sharedLocal;
uint64_t _indicesNumBigCore;
uint64_t _indicesNumSmallCore;
uint64_t _bigCoreNum;
uint64_t _indicesMaxLoadableNum;
uint64_t _indicesNumLeftover;
uint64_t _baseOffset;
};
class ScatterAddMultiHeads final : public ScatterAddBaseKernel {
public:
__aicore__ inline ScatterAddMultiHeads() = delete;
__aicore__ inline ScatterAddMultiHeads(
GM_ADDR src, GM_ADDR indices, GM_ADDR out, ScatterAddTilingDataV1* tiling_data, TPipe* pipe)
: ScatterAddBaseKernel(src, indices, out, tiling_data, pipe)
{
initTiling(tiling_data);
initBuffer();
}
__aicore__ inline void Process()
{
for (uint64_t i = 0; i < _loop; i++) {
compute(i);
}
}
private:
__aicore__ inline void initTiling(ScatterAddTilingDataV1* tiling_data)
{
_headNumBigCore = tiling_data->headNumBigCore;
_headNumSmallCore = tiling_data->headNumSmallCore;
_bigCoreNum = tiling_data->bigCoreNum;
_headNumPerTask = tiling_data->headNumPerTask;
if (_blockIdx < _bigCoreNum) {
_loop = ceilDiv(_headNumBigCore, _headNumPerTask);
_headNumLeftover = _headNumBigCore % _headNumPerTask;
_headIdOffset = _blockIdx * _headNumBigCore;
} else {
_loop = ceilDiv(_headNumSmallCore, _headNumPerTask);
_headNumLeftover = _headNumSmallCore % _headNumPerTask;
_headIdOffset = _bigCoreNum * _headNumBigCore + (_blockIdx - _bigCoreNum) * _headNumSmallCore;
}
}
__aicore__ inline void initBuffer()
{
uint64_t outMaxLoadNum = AlignUp(_headNumPerTask * _outNumPerHead, _dataNumPerBlock);
uint64_t indicesMaxLoadNum = AlignUp(_headNumPerTask * _indicesNumPerHead, _dataNumPerBlock);
_pipe->InitBuffer(_srcBuf, indicesMaxLoadNum * sizeof(DTYPE_SRC));
_pipe->InitBuffer(_outBuf, outMaxLoadNum * sizeof(DTYPE_OUT));
_pipe->InitBuffer(_indicesBuf, indicesMaxLoadNum * sizeof(DTYPE_INDICES));
}
__aicore__ inline void compute(uint64_t i)
{
_srcLocal = _srcBuf.Get<DTYPE_SRC>();
_outLocal = _outBuf.Get<DTYPE_OUT>();
_indicesLocal = _indicesBuf.Get<DTYPE_INDICES>();
uint64_t startHeadID = _headIdOffset + i * _headNumPerTask;
uint64_t headNum = calcHeadNum(i);
uint64_t outLoadOffset = startHeadID * _outNumPerHead;
uint64_t indicesLoadOffset = startHeadID * _indicesNumPerHead;
uint64_t outLoadNum = headNum * _outNumPerHead;
uint64_t indicesLoadNum = headNum * _indicesNumPerHead;
PipeBarrier<PIPE_ALL>();
DataCopy(_outLocal, _outGm[outLoadOffset], AlignUp(outLoadNum, _dataNumPerBlock));
DataCopy(_srcLocal, _srcGm[indicesLoadOffset], AlignUp(indicesLoadNum, _dataNumPerBlock));
DataCopy(_indicesLocal, _indicesGm[indicesLoadOffset], AlignUp(indicesLoadNum, _dataNumPerBlock));
for (uint64_t h = 0; h < headNum; h++) {
uint64_t indicesBaseOffset = h * _indicesNumPerHead;
uint64_t outBaseOffset = h * _outNumPerHead;
for (uint64_t k = 0; k < _indicesNumPerHead; k++) {
uint64_t indicesOffset = indicesBaseOffset + k;
DTYPE_INDICES idxVal = _indicesLocal.GetValue(indicesOffset);
int64_t outOffset = outBaseOffset + idxVal;
_outLocal.SetValue(outOffset, _outLocal.GetValue(outOffset) + _srcLocal.GetValue(indicesOffset));
}
}
DataCopyPad(
_outGm[outLoadOffset], _outLocal, {1, static_cast<uint32_t>(outLoadNum * sizeof(DTYPE_OUT)), 0, 0, 0});
}
__aicore__ inline uint64_t calcHeadNum(uint64_t i)
{
if (_headNumLeftover == 0) {
return _headNumPerTask;
}
return i == _loop - 1 ? _headNumLeftover : _headNumPerTask;
}
private:
uint64_t _headNumBigCore;
uint64_t _headNumSmallCore;
uint64_t _bigCoreNum;
uint64_t _headNumPerTask;
uint64_t _headNumLeftover;
uint64_t _headIdOffset;
};
template<bool largeHead>
class ScatterAddHeadInBatch final : public ScatterAddBaseKernel {
public:
__aicore__ inline ScatterAddHeadInBatch() = delete;
__aicore__ inline ScatterAddHeadInBatch(
GM_ADDR src, GM_ADDR indices, GM_ADDR out, ScatterAddTilingDataV1* tiling_data, TPipe* pipe)
: ScatterAddBaseKernel(src, indices, out, tiling_data, pipe)
{
if constexpr (largeHead) {
initTilingLargeHead(tiling_data);
} else {
initTilingFewHeads(tiling_data);
}
initBuffer();
}
__aicore__ inline void Process()
{
for (uint64_t i = 0; i < _loop; i++) {
compute(i);
}
}
private:
__aicore__ inline void initTilingLargeHead(ScatterAddTilingDataV1* tiling_data)
{
_headNumBigCore = tiling_data->headNumBigCore;
_headNumSmallCore = tiling_data->headNumSmallCore;
_bigCoreNum = tiling_data->bigCoreNum;
_indicesNumPerBatch = tiling_data->indicesNumPerBatch;
_maxOutNumPerBatch = tiling_data->maxOutNumPerBatch;
if (_blockIdx < _bigCoreNum) {
_loop = _headNumBigCore;
_headIdOffset = _blockIdx * _headNumBigCore;
} else {
_loop = _headNumSmallCore;
_headIdOffset = _bigCoreNum * _headNumBigCore + (_blockIdx - _bigCoreNum) * _headNumSmallCore;
}
if (_dimSize <= _maxOutNumPerBatch) {
_outLoop = 1;
_outNumPerBatch = _dimSize;
_outNumLeftover = 0;
} else {
_outLoop = ceilDiv(_dimSize, _maxOutNumPerBatch);
_outNumPerBatch = _maxOutNumPerBatch;
_outNumLeftover = _dimSize % _maxOutNumPerBatch;
}
_indicesLoop = ceilDiv(_indicesNumPerHead, _indicesNumPerBatch);
_indicesNumLeftover = _indicesNumPerHead % _indicesNumPerBatch;
}
__aicore__ inline void initTilingFewHeads(ScatterAddTilingDataV1* tiling_data)
{
_indicesNumPerBatch = tiling_data->indicesNumPerBatch;
_maxOutNumPerBatch = tiling_data->maxOutNumPerBatch;
_coreNumPerHead = tiling_data->coreNumPerHead;
_outNumPerCore = tiling_data->outNumPerCore;
_outNumPartOffset = _outNumPerCore;
_loop = 1;
_headIdOffset = _blockIdx / _coreNumPerHead;
_headPartId = _blockIdx % _coreNumPerHead;
if (_headPartId == _coreNumPerHead - 1) {
_outNumPerCore = _dimSize - _outNumPerCore * (_coreNumPerHead - 1);
}
if (_outNumPerCore <= _maxOutNumPerBatch) {
_outLoop = 1;
_outNumPerBatch = _outNumPerCore;
_outNumLeftover = _outNumPerCore;
} else {
_outLoop = ceilDiv(_outNumPerCore, _maxOutNumPerBatch);
_outNumPerBatch = _maxOutNumPerBatch;
_outNumLeftover = _outNumPerCore % _maxOutNumPerBatch;
}
_indicesLoop = ceilDiv(_indicesNumPerHead, _indicesNumPerBatch);
_indicesNumLeftover = _indicesNumPerHead % _indicesNumPerBatch;
}
__aicore__ inline void initBuffer()
{
_pipe->InitBuffer(_srcBuf, AlignUp(_indicesNumPerBatch, _dataNumPerBlock) * sizeof(DTYPE_SRC));
_pipe->InitBuffer(_outBuf, AlignUp(_outNumPerBatch, MAX_MASK) * sizeof(DTYPE_OUT));
_pipe->InitBuffer(_indicesBuf, AlignUp(_indicesNumPerBatch, _dataNumPerBlock) * sizeof(DTYPE_INDICES));
uint64_t maskBitNum = AscendCUtils::GetBitSize(sizeof(uint8_t));
uint64_t indicesMaskNum =
AlignUp(_indicesNumPerBatch * sizeof(DTYPE_INDICES), MASK_ALIGN_SIZE) / sizeof(DTYPE_INDICES);
_maskLen = AlignUp(ceilDiv(indicesMaskNum, maskBitNum), BLOCK_SIZE);
_pipe->InitBuffer(_maskBuf, _maskLen * sizeof(uint8_t) * UB_SIZE_COEFF);
}
__aicore__ inline void compute(uint64_t i)
{
SetFlag<HardEvent::MTE3_MTE2>(EVENT_ID1);
uint64_t headID = _headIdOffset + i;
for (uint64_t k = 0; k < _outLoop; k++) {
uint64_t outNum = calcBatchOutNum(k);
computeHeadInBatches(k, headID, outNum);
}
WaitFlag<HardEvent::MTE3_MTE2>(EVENT_ID1);
}
__aicore__ inline void computeHeadInBatches(uint64_t k, uint64_t headID, uint64_t outNum)
{
_srcLocal = _srcBuf.Get<DTYPE_SRC>();
_outLocal = _outBuf.Get<DTYPE_OUT>();
_indicesLocal = _indicesBuf.Get<DTYPE_INDICES>();
uint64_t outHeadOffset = headID * _outNumPerHead;
uint64_t outLoadOffset;
if constexpr (largeHead) {
outLoadOffset = outHeadOffset + k * _outNumPerBatch;
} else {
outLoadOffset = outHeadOffset + _headPartId * _outNumPartOffset + k * _outNumPerBatch;
}
WaitFlag<HardEvent::MTE3_MTE2>(EVENT_ID1);
DataCopy(_outLocal, _outGm[outLoadOffset], AlignUp(_outNumPerBatch, _dataNumPerBlock));
int32_t idxValOffset = outHeadOffset - outLoadOffset;
computeBatch(headID, outNum, outLoadOffset, idxValOffset);
DataCopyPad(_outGm[outLoadOffset], _outLocal, {1, static_cast<uint32_t>(outNum * sizeof(DTYPE_OUT)), 0, 0, 0});
SetFlag<HardEvent::MTE3_MTE2>(EVENT_ID1);
}
__aicore__ inline void computeBatch(uint64_t headID, uint64_t outNum, uint64_t outLoadOffset, int32_t idxValOffset)
{
_indicesMask = _maskBuf.Get<uint8_t>();
uint64_t indicesHeadOffset = headID * _indicesNumPerHead;
for (uint64_t n = 0; n < _indicesLoop; n++) {
uint64_t indicesLoadOffset = indicesHeadOffset + n * _indicesNumPerBatch;
uint64_t indicesLoadNum = calcBatchIndicesNum(n);
uint64_t indicesLoadAlign = AlignUp(indicesLoadNum, _dataNumPerBlock);
uint64_t cmpLenAlign256 = AlignUp(indicesLoadNum, MASK_ALIGN_SIZE / sizeof(DTYPE_INDICES));
uint64_t mask64BitNum = ceilDiv(indicesLoadNum, 64);
DataCopy(_indicesLocal, _indicesGm[indicesLoadOffset], indicesLoadAlign);
SetFlag<HardEvent::MTE2_V>(EVENT_ID0);
DataCopy(_srcLocal, _srcGm[indicesLoadOffset], indicesLoadAlign);
WaitFlag<HardEvent::MTE2_V>(EVENT_ID0);
Adds(_indicesLocal, _indicesLocal, idxValOffset, indicesLoadAlign);
Cast(_indicesLocal.ReinterpretCast<float>(), _indicesLocal, RoundMode::CAST_NONE, indicesLoadNum);
CompareScalar(_indicesMask, _indicesLocal.ReinterpretCast<float>(), 0.0f, CMPMODE::GE, cmpLenAlign256);
CompareScalar(_indicesMask[_maskLen], _indicesLocal.ReinterpretCast<float>(), outNum * 1.0f, CMPMODE::LT,
cmpLenAlign256);
Cast(_indicesLocal, _indicesLocal.ReinterpretCast<float>(), RoundMode::CAST_RINT, indicesLoadNum);
And(_indicesMask.ReinterpretCast<uint16_t>(), _indicesMask.ReinterpretCast<uint16_t>(),
_indicesMask[_maskLen].ReinterpretCast<uint16_t>(), _maskLen);
for (uint64_t m = 0; m < mask64BitNum; m++) {
uint64_t mask = _indicesMask.ReinterpretCast<uint64_t>().GetValue(m);
uint64_t maxPos = (m == mask64BitNum - 1) ? indicesLoadNum - m * 64 : 64;
for (int32_t p = ScalarGetSFFValue<1>(mask); p >= 0 && p < maxPos; p = ScalarGetSFFValue<1>(mask)) {
mask = sbitset0(mask, p);
uint64_t pos = p + m * 64;
DTYPE_INDICES idxVal = _indicesLocal.GetValue(pos);
_outLocal.SetValue(idxVal, _outLocal.GetValue(idxVal) + _srcLocal.GetValue(pos));
}
}
}
}
__aicore__ inline uint64_t calcBatchOutNum(uint64_t k)
{
if (_outNumLeftover == 0) {
return _outNumPerBatch;
}
return k == _outLoop - 1 ? _outNumLeftover : _outNumPerBatch;
}
__aicore__ inline uint64_t calcBatchIndicesNum(uint64_t n)
{
if (_indicesNumLeftover == 0) {
return _indicesNumPerBatch;
}
return n == _indicesLoop - 1 ? _indicesNumLeftover : _indicesNumPerBatch;
}
private:
uint64_t _headIdOffset;
uint64_t _bigCoreNum;
uint64_t _outLoop;
uint64_t _outNumPerBatch;
uint64_t _outNumLeftover;
uint64_t _maxOutNumPerBatch;
uint64_t _outNumPartOffset;
uint64_t _indicesLoop;
uint64_t _indicesNumPerBatch;
uint64_t _indicesNumLeftover;
uint64_t _headNumBigCore;
uint64_t _headNumSmallCore;
uint64_t _headPartId;
uint64_t _coreNumPerHead;
uint64_t _outNumPerCore;
TBuf<TPosition::VECCALC> _maskBuf;
LocalTensor<uint8_t> _indicesMask;
uint64_t _maskLen;
};
template<bool smallTail>
class ScatterAddWithTail final : public ScatterAddBaseKernel {
public:
__aicore__ inline ScatterAddWithTail() = delete;
__aicore__ inline ScatterAddWithTail(
GM_ADDR src, GM_ADDR indices, GM_ADDR out, ScatterAddTilingDataV1* tiling_data, TPipe* pipe)
: ScatterAddBaseKernel(src, indices, out, tiling_data, pipe)
{
if constexpr (smallTail) {
initTilingSmallTail(tiling_data);
} else {
initTilingLargeTail(tiling_data);
}
initBuffer();
}
__aicore__ inline void Process()
{
for (uint64_t i = 0; i < _loop; i++) {
if constexpr (smallTail) {
computeWithSmallTail(i);
} else {
computeWithLargeTail(i);
}
}
}
private:
__aicore__ inline void initTilingSmallTail(ScatterAddTilingDataV1* tiling_data)
{
_srcTailBigCore = tiling_data->srcTailBigCore;
_srcTailSmallCore = tiling_data->srcTailSmallCore;
_bigCoreNum = tiling_data->bigCoreNum;
_dbTimes = 1;
_tailElemNum = AlignUp(_tailLen, _dataNumPerBlock);
_tailNumPerBatch = min(_srcTailBigCore, _ubSize / sizeof(DTYPE_SRC) / (_tailElemNum + 1));
_tailNumPerBatch = min(_tailNumPerBatch, MAX_COPY_PAD);
_tailNumPerBatch = (_tailNumPerBatch == 0) ? 1 : _tailNumPerBatch;
_indicesLoadLen = AlignUp(_tailNumPerBatch, _dataNumPerBlock);
_srcLoadLen = _tailElemNum * _tailNumPerBatch;
if (_blockIdx < _bigCoreNum) {
_indicesBaseOffset = _blockIdx * _srcTailBigCore;
_loop = ceilDiv(_srcTailBigCore, _tailNumPerBatch);
_tailNumLeftover = _srcTailBigCore % _tailNumPerBatch;
} else {
_indicesBaseOffset = _bigCoreNum * _srcTailBigCore + (_blockIdx - _bigCoreNum) * _srcTailSmallCore;
_loop = ceilDiv(_srcTailSmallCore, _tailNumPerBatch);
_tailNumLeftover = _srcTailSmallCore % _tailNumPerBatch;
}
}
__aicore__ inline void initTilingLargeTail(ScatterAddTilingDataV1* tiling_data)
{
_srcTailBigCore = tiling_data->srcTailBigCore;
_srcTailSmallCore = tiling_data->srcTailSmallCore;
_bigCoreNum = tiling_data->bigCoreNum;
_tailLenThreshold = tiling_data->tailLenThreshold;
_dbTimes = min(_ubSize / sizeof(DTYPE_SRC) / min(_tailLen, _tailLenThreshold), BUFFER_NUM_MAX);
_dbTimes = (_dbTimes == 0) ? 1 : _dbTimes;
uint64_t srcElemNumTmp =
min(static_cast<uint64_t>(AlignUp(_tailLen, _dataNumPerBlock)), _tailLenThreshold) * _dbTimes;
uint64_t availIndicesSize = _ubSize - srcElemNumTmp * sizeof(DTYPE_SRC);
uint64_t srcTailNum = (_blockIdx < _bigCoreNum) ? _srcTailBigCore : _srcTailSmallCore;
_indicesLoadLen = min(availIndicesSize / BLOCK_SIZE * BLOCK_SIZE / sizeof(DTYPE_INDICES), srcTailNum);
_indicesLoadLen = (_indicesLoadLen == 0) ? 1 : _indicesLoadLen;
_loop = ceilDiv(srcTailNum, _indicesLoadLen);
_indicesNumPerBatch = _indicesLoadLen;
_indicesNumLeftover = srcTailNum % _indicesLoadLen;
_tailElemNum = (_ubSize - _indicesLoadLen * sizeof(DTYPE_INDICES)) / _dbTimes / BLOCK_SIZE * _dataNumPerBlock;
_tailElemNum = min(_tailElemNum, static_cast<uint64_t>(AlignUp(_tailLen, _dataNumPerBlock)));
_tailElemNum = (_tailElemNum == 0) ? 1 : _tailElemNum;
_srcLoadLen = _tailElemNum * _dbTimes;
_tailElemLoop = _tailLen / _tailElemNum;
_tailElemNumLeftover = _tailLen % _tailElemNum;
if (_blockIdx < _bigCoreNum) {
_indicesBaseOffset = _blockIdx * _srcTailBigCore;
} else {
_indicesBaseOffset = _bigCoreNum * _srcTailBigCore + (_blockIdx - _bigCoreNum) * _srcTailSmallCore;
}
_countDB = 0;
_eventID = EVENT_ID0;
}
__aicore__ inline void initBuffer()
{
_pipe->InitBuffer(_srcBuf, AlignUp(_srcLoadLen, _dataNumPerBlock) * sizeof(DTYPE_SRC));
_pipe->InitBuffer(_indicesBuf, AlignUp(_indicesLoadLen, _dataNumPerBlock) * sizeof(DTYPE_INDICES));
}
__aicore__ inline void computeWithSmallTail(uint64_t i)
{
_srcLocal = _srcBuf.Get<DTYPE_SRC>();
_indicesLocal = _indicesBuf.Get<DTYPE_INDICES>();
uint64_t indicesOffset = _indicesBaseOffset + i * _tailNumPerBatch;
uint64_t srcOffset = indicesOffset * _tailLen;
uint16_t loadTailNum = calcTailLoadNum(i);
PipeBarrier<PIPE_ALL>();
DataCopy(_indicesLocal, _indicesGm[indicesOffset], AlignUp(loadTailNum, _dataNumPerBlock));
copyParamsIn = {loadTailNum, static_cast<uint32_t>(_tailLen * sizeof(DTYPE_SRC)), 0, 0, 0};
DataCopyPadExtParams<DTYPE_OUT> queryCopyInPadParams {false, 0, 0, 0};
DataCopyPad(_srcLocal, _srcGm[srcOffset], copyParamsIn, queryCopyInPadParams);
SetAtomicAdd<DTYPE_SRC>();
for (uint64_t p = 0; p < loadTailNum; p++) {
DTYPE_INDICES idxVal = _indicesLocal.GetValue(p);
uint64_t idxGlobal = indicesOffset + p;
uint64_t headID = idxGlobal / _srcDimSize;
uint64_t srcLocalOffset = p * AlignUp(_tailLen, _dataNumPerBlock);
uint64_t outOffset = (idxVal + headID * _dimSize) * _tailLen;
copyParamsOut = {1, static_cast<uint32_t>(_tailLen * sizeof(DTYPE_OUT)), 0, 0, 0};
DataCopyPad(_outGm[outOffset], _srcLocal[srcLocalOffset], copyParamsOut);
}
SetAtomicNone();
}
__aicore__ inline void computeWithLargeTail(uint64_t i)
{
_srcLocal = _srcBuf.Get<DTYPE_SRC>();
_indicesLocal = _indicesBuf.Get<DTYPE_INDICES>();
uint64_t indicesOffset = _indicesBaseOffset + i * _indicesNumPerBatch;
uint64_t indicesLoadNum = calcIndicesLoadNum(i);
DataCopy(_indicesLocal, _indicesGm[indicesOffset], AlignUp(indicesLoadNum, _dataNumPerBlock));
setFlagMET3_2();
for (uint64_t p = 0; p < indicesLoadNum; p++) {
DTYPE_INDICES idxVal = _indicesLocal.GetValue(p);
uint64_t idxGlobal = indicesOffset + p;
auto srcOffset = idxGlobal * _tailLen;
SetAtomicAdd<DTYPE_SRC>();
computeLargeTailAdd(idxGlobal, idxVal, srcOffset);
SetAtomicNone();
}
waitFlagMET3_2();
}
__aicore__ inline void computeLargeTailAdd(uint64_t idxGlobal, uint64_t idxVal, uint64_t srcOffset)
{
uint64_t headID = idxGlobal / _srcDimSize;
uint64_t outGlobalOffset = (idxVal + headID * _dimSize) * _tailLen;
for (uint64_t k = 0; k < _tailElemLoop; k++) {
uint64_t offset = k * _tailElemNum;
uint64_t localOffset = getEventIdforDoublebuffer();
WaitFlag<HardEvent::MTE3_MTE2>(_eventID);
DataCopy(_srcLocal[localOffset], _srcGm[srcOffset + offset], _tailElemNum);
SetFlag<HardEvent::MTE2_MTE3>(_eventID);
WaitFlag<HardEvent::MTE2_MTE3>(_eventID);
DataCopy(_outGm[outGlobalOffset + offset], _srcLocal[localOffset], _tailElemNum);
SetFlag<HardEvent::MTE3_MTE2>(_eventID);
}
if (_tailElemNumLeftover != 0) {
uint64_t offset = _tailElemLoop * _tailElemNum;
uint64_t localOffset = getEventIdforDoublebuffer();
WaitFlag<HardEvent::MTE3_MTE2>(_eventID);
DataCopy(
_srcLocal[localOffset], _srcGm[srcOffset + offset], AlignUp(_tailElemNumLeftover, _dataNumPerBlock));
SetFlag<HardEvent::MTE2_MTE3>(_eventID);
WaitFlag<HardEvent::MTE2_MTE3>(_eventID);
copyParamsOut = {1, static_cast<uint32_t>(_tailElemNumLeftover * sizeof(DTYPE_SRC)), 0, 0, 0};
DataCopyPad(_outGm[outGlobalOffset + offset], _srcLocal[localOffset], copyParamsOut);
SetFlag<HardEvent::MTE3_MTE2>(_eventID);
}
}
__aicore__ inline uint64_t getEventIdforDoublebuffer()
{
_eventID = _countDB % _dbTimes;
uint64_t localOffset = _tailElemNum * _eventID;
_countDB++;
return localOffset;
}
__aicore__ inline void setFlagMET3_2()
{
for (uint64_t i = 0; i < _dbTimes; i++) {
SetFlag<HardEvent::MTE3_MTE2>(i);
}
}
__aicore__ inline void waitFlagMET3_2()
{
for (uint64_t i = 0; i < _dbTimes; i++) {
WaitFlag<HardEvent::MTE3_MTE2>(i);
}
}
__aicore__ inline uint64_t calcTailLoadNum(uint64_t i)
{
if (_tailNumLeftover == 0) {
return _tailNumPerBatch;
}
return (i == _loop - 1) ? _tailNumLeftover : _tailNumPerBatch;
}
__aicore__ inline uint64_t calcIndicesLoadNum(uint64_t i)
{
if (_indicesNumLeftover == 0) {
return _indicesNumPerBatch;
}
return (i == _loop - 1) ? _indicesNumLeftover : _indicesNumPerBatch;
}
private:
uint64_t _srcTailBigCore;
uint64_t _srcTailSmallCore;
uint64_t _bigCoreNum;
uint64_t _tailLenThreshold;
uint64_t _srcLoadLen;
uint64_t _indicesLoadLen;
uint64_t _indicesBaseOffset;
uint64_t _tailElemNum;
uint64_t _tailNumPerBatch;
uint64_t _tailNumLeftover;
uint64_t _indicesNumPerBatch;
uint64_t _indicesNumLeftover;
uint64_t _tailElemLoop;
uint64_t _tailElemNumLeftover;
uint64_t _dbTimes;
uint64_t _countDB;
uint64_t _eventID;
DataCopyExtParams copyParamsIn;
DataCopyExtParams copyParamsOut;
};
#endif