/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
 */
#ifndef SCATTER_ADD_GRAD_BASE_H_
#define SCATTER_ADD_GRAD_BASE_H_
#include "kernel_operator.h"
#include "kernel_tiling/kernel_tiling.h"
namespace ScatterAddGradNS {
using namespace AscendC;
// Byte size of one data block; InitTiling divides it by an element size to
// obtain the number of elements that fit in a single block.
constexpr uint32_t BLOCK_BYTES = 32U;
// Byte width used to derive MASK below.
constexpr uint32_t MASK_BYTES = 256U;
// Number of int32_t elements that fit in MASK_BYTES bytes.
constexpr uint32_t MASK = static_cast<uint32_t>(MASK_BYTES / sizeof(int32_t));
// Buffer count constant; not referenced elsewhere in this header section.
constexpr uint32_t BUFFER_NUM = 4U;
/**
 * Base class that caches ScatterAddGrad tiling parameters for one core.
 *
 * All state is protected, so this class is presumably meant to be inherited
 * by the concrete kernel implementations (not visible here - confirm).
 * The template parameter T is not used by any code in this class itself.
 */
template <typename T>
class ScatterAddGradBase {
public:
    __aicore__ inline ScatterAddGradBase() {}

    /**
     * Copies the tiling fields into members and derives the per-block
     * element counts and the `body` quotient.
     *
     * @param tilingData tiling structure to read from; it is dereferenced
     *                   unconditionally, so it must not be null.
     */
    __aicore__ inline void InitTiling(const ScatterAddGradTilingData* tilingData)
    {
        ASSERT(GetBlockNum() != 0 && "block dim can not be zero!");
        this->curBlockIdx = GetBlockIdx();

        // Plain copies of the host-side tiling results.
        this->tilingMode = tilingData->tilingMode;
        this->dimRange = tilingData->dimRange;
        this->dimRangeOut = tilingData->dimRangeOut;
        this->paramsPro = tilingData->paramsPro;
        this->tail = tilingData->tail;
        this->bigCoreNum = tilingData->bigCoreNum;
        this->indexUbSize = tilingData->indexUbSize;
        this->gradOutUbSize = tilingData->gradOutUbSize;
        this->gradInNum = tilingData->gradInNum;
        this->indexNum = tilingData->indexNum;
        this->gradOutNum = tilingData->gradOutNum;

        // body = paramsPro / tail. Integer division by zero is undefined, so
        // a zero tail is reported through ASSERT and falls back to body = 0
        // instead of executing the division.
        ASSERT(this->tail != 0 && "tail can not be zero!");
        if (this->tail != 0) {
            // The 64-bit quotient is stored in a 32-bit member, as before;
            // the cast only makes the narrowing explicit.
            this->body = static_cast<uint32_t>(this->paramsPro / this->tail);
        } else {
            this->body = 0;
        }

        // Element counts per BLOCK_BYTES-sized block. Note that the params
        // count is computed with sizeof(float), not sizeof(T).
        this->indicesEachBlock = BLOCK_BYTES / sizeof(DTYPE_INDEX);
        this->paramsEachBlock = BLOCK_BYTES / sizeof(float);
    }

protected:
    // Index of the current core, as returned by GetBlockIdx().
    uint32_t curBlockIdx = 0;

    // Values copied verbatim from the tiling data in InitTiling.
    uint64_t indexUbSize = 0;
    uint64_t gradOutUbSize = 0;
    uint64_t dimRange = 0;
    uint64_t dimRangeOut = 0;
    uint64_t paramsPro = 0;
    uint64_t gradInNum = 0;
    uint64_t indexNum = 0;
    uint64_t gradOutNum = 0;

    // Not assigned by InitTiling; zero-initialized so it never holds an
    // indeterminate value. Presumably set by a derived class - confirm.
    int32_t dim = 0;

    uint32_t tilingMode = 0;
    uint32_t tail = 0;
    // paramsPro / tail, or 0 when tail is 0.
    uint32_t body = 0;
    uint32_t bigCoreNum = 0;

    // Number of DTYPE_INDEX elements in one BLOCK_BYTES block.
    uint32_t indicesEachBlock = 0;
    // Number of float elements in one BLOCK_BYTES block.
    uint32_t paramsEachBlock = 0;

    // NOTE(review): assuming the Ascend C field order (blockCount, blockLen,
    // srcStride, dstStride, rsv) this describes one block of 8 bytes with no
    // strides - confirm against the DataCopyExtParams definition.
    DataCopyExtParams copyParamsOut = {1, 8, 0, 0, 0};
};
}
#endif