#include "kernel_operator.h"
using namespace AscendC;
#define BUFFER_NUM 2
#define BUFFER_ALIGN_BYTE 32
#define COMPARE_ALGIN_BYTE 256
/**
 * Gradient kernel for roiaware_maxpool3d (per the kernel name).
 *
 * Each AI core owns a contiguous range of "tasks": rows of gradOut / argmax, each `channels_`
 * elements wide. For every output point p in [0, nPoints_) and channel c, the kernel sums the
 * gradOut[t, c] of all local tasks t whose argmax[t, c] == p, then adds that partial sum into
 * gradIn[p, c] in global memory with atomic add.
 * NOTE(review): because results are atomically added, gradIn presumably has to be zero-filled
 * before launch — confirm with the caller.
 *
 * @tparam dataType element type of gradOut / gradIn.
 * @tparam aicpu    not used by this class.
 */
template<typename dataType, bool aicpu>
class KernelRoiawareMaxpool3dGrad {
public:
    __aicore__ inline KernelRoiawareMaxpool3dGrad() {}

    /**
     * Reads the tiling for this core, binds the global tensors and reserves the UB buffers.
     * @param argmax   GM address of the int32 argmax rows.
     * @param gradOut  GM address of the gradOut rows.
     * @param gradIn   GM address of gradIn (nPoints x channels), accumulated with atomic add.
     * @param pipe     pipe used to allocate the on-chip buffers.
     * @param tiling_data tiling parameters shared by all cores.
     */
    __aicore__ void Init(const GM_ADDR argmax, const GM_ADDR gradOut, GM_ADDR gradIn, TPipe* pipe,
        const RoiawareMaxpool3dGradTilingData* tiling_data)
    {
        this->pipe_ = pipe;
        this->blkIdx_ = GetBlockIdx();
        this->InitTiling(tiling_data);
        this->InitGlobalBuffer(argmax, gradOut, gradIn, tiling_data);
        this->InitUBBuffer();
    }

    /** Processes this core's tasks in chunks of at most singleLoopTask_ rows. */
    __aicore__ void Process()
    {
        // Degenerate tiling guard: a zero step would loop forever, and with nPoints_ == 0 the
        // inputs enqueued by CopyInNTask would never be dequeued/freed (ProcessNTask would not
        // call Compute), exhausting the BUFFER_NUM-deep queues. There is nothing to write either.
        if (nPoints_ == 0 || singleLoopTask_ == 0 || singleLoopOutput_ == 0) {
            return;
        }
        uint32_t globalIdx = goutCoreStartIdx_;
        for (uint32_t curLoopTaskIdx = 0; curLoopTaskIdx < coreTask_;
             curLoopTaskIdx += singleLoopTask_, globalIdx += singleLoopTask_) {
            uint32_t curLoopTaskCount = min(singleLoopTask_, coreTask_ - curLoopTaskIdx);
            CopyInNTask(globalIdx, curLoopTaskCount);
            ProcessNTask(curLoopTaskCount);
        }
    }

protected:
    TPipe *pipe_;
    uint32_t blkIdx_;
    uint32_t channels_, nPoints_, channelAligned_, alignedArgmaxByteLength_;
    uint32_t coreTask_, singleLoopTask_, singleLoopOutput_;
    uint32_t argmaxCoreStartIdx_, goutCoreStartIdx_;
    TQue<QuePosition::VECIN, BUFFER_NUM> argmaxQue_, goutQue_;
    TBuf<TPosition::VECCALC> tmpMaskBuf_;     // comparison bit mask (CompareScalar output)
    TBuf<TPosition::VECCALC> tmpSelectedBuf_; // gout values kept by the mask, zero elsewhere
    TBuf<TPosition::VECCALC> ginBuf_;         // per-chunk gradIn partial sums
    GlobalTensor<int32_t> argmaxGm_;
    GlobalTensor<dataType> goutGm_, ginGm_;
    DataCopyExtParams goutCopyParams_;
    DataCopyExtParams argmaxCopyParams_;
    DataCopyPadExtParams<dataType> goutPadParams_{true, 0, 0, 0};
    DataCopyPadExtParams<int32_t> argmaxPadParams_{true, 0, 0, 0};

private:
    /**
     * Copies the tiling fields and computes this core's task range. Cores with
     * blkIdx_ < firstSmallCoreIdx get coreTask tasks, the rest get coreTask - 1; the start
     * index is the sum of the task counts of all preceding cores.
     */
    __aicore__ inline void InitTiling(const RoiawareMaxpool3dGradTilingData* tiling_data)
    {
        this->channels_ = tiling_data->channels;
        this->nPoints_ = tiling_data->npoints;
        this->singleLoopTask_ = tiling_data->singleLoopTask;
        this->singleLoopOutput_ = tiling_data->singleLoopOutput;
        this->channelAligned_ = tiling_data->channelAligned;
        if (blkIdx_ >= tiling_data->firstSmallCoreIdx) {
            this->coreTask_ = tiling_data->coreTask - 1;
            argmaxCoreStartIdx_ = tiling_data->coreTask * blkIdx_ - (blkIdx_ - tiling_data->firstSmallCoreIdx);
            goutCoreStartIdx_ = argmaxCoreStartIdx_;
        } else {
            this->coreTask_ = tiling_data->coreTask;
            argmaxCoreStartIdx_ = tiling_data->coreTask * blkIdx_;
            goutCoreStartIdx_ = argmaxCoreStartIdx_;
        }
        // CompareScalar runs over the whole argmax buffer, so its byte size is padded to 256 bytes.
        this->alignedArgmaxByteLength_ =
            Ceil(singleLoopTask_ * channelAligned_ * sizeof(int32_t), COMPARE_ALGIN_BYTE) * COMPARE_ALGIN_BYTE;
    }

    /** Binds the global tensors. SetGlobalBuffer takes the number of elements, not bytes. */
    __aicore__ inline void InitGlobalBuffer(const GM_ADDR argmax, const GM_ADDR gradOut, const GM_ADDR gradIn,
        const RoiawareMaxpool3dGradTilingData* tiling_data)
    {
        uint64_t argmaxCount = static_cast<uint64_t>(tiling_data->totalTask) * channels_;
        uint64_t gradOutCount = static_cast<uint64_t>(tiling_data->totalTask) * channels_;
        uint64_t gradInCount = static_cast<uint64_t>(nPoints_) * channels_;
        this->argmaxGm_.SetGlobalBuffer((__gm__ int32_t*) argmax, argmaxCount);
        this->goutGm_.SetGlobalBuffer((__gm__ dataType*) gradOut, gradOutCount);
        this->ginGm_.SetGlobalBuffer((__gm__ dataType*) gradIn, gradInCount);
    }

    /** Reserves the input queues and the scratch buffers used by Compute. */
    __aicore__ inline void InitUBBuffer()
    {
        this->pipe_->InitBuffer(argmaxQue_, BUFFER_NUM, alignedArgmaxByteLength_);
        this->pipe_->InitBuffer(goutQue_, BUFFER_NUM, singleLoopTask_ * channelAligned_ * sizeof(dataType));
        this->pipe_->InitBuffer(ginBuf_, singleLoopOutput_ * channelAligned_ * sizeof(dataType));
        // One mask bit per element, rounded up to a multiple of 32 bytes.
        uint32_t tmpMaskBufferByteLength = BUFFER_ALIGN_BYTE *
            Ceil(Ceil(channelAligned_ * singleLoopTask_, AscendCUtils::GetBitSize(sizeof(uint8_t))), BUFFER_ALIGN_BYTE);
        this->pipe_->InitBuffer(this->tmpMaskBuf_, tmpMaskBufferByteLength);
        this->pipe_->InitBuffer(this->tmpSelectedBuf_, singleLoopTask_ * channelAligned_ * sizeof(dataType));
    }

    /**
     * Copies curLoopTaskCount rows of gradOut and argmax starting at row inputIdx into UB.
     * Each row holds channels_ elements and is padded in UB to the 32-byte aligned width.
     */
    __aicore__ inline void CopyInNTask(int64_t inputIdx, int32_t curLoopTaskCount)
    {
        LocalTensor<dataType> goutLocal = goutQue_.AllocTensor<dataType>();
        LocalTensor<int32_t> argmaxLocal = argmaxQue_.AllocTensor<int32_t>();
        int64_t goutLocalOffset = inputIdx * channels_;
        goutCopyParams_ = {static_cast<uint16_t>(curLoopTaskCount), static_cast<uint32_t>(channels_ * sizeof(dataType)), 0, 0, 0};
        DataCopyPad(goutLocal, goutGm_[goutLocalOffset], goutCopyParams_, goutPadParams_);
        int64_t argmaxLocalOffset = inputIdx * channels_;
        argmaxCopyParams_ = {static_cast<uint16_t>(curLoopTaskCount), static_cast<uint32_t>(channels_ * sizeof(int32_t)), 0, 0, 0};
        DataCopyPad(argmaxLocal, argmaxGm_[argmaxLocalOffset], argmaxCopyParams_, argmaxPadParams_);
        goutQue_.EnQue<dataType>(goutLocal);
        argmaxQue_.EnQue<int32_t>(argmaxLocal);
    }

    /**
     * Sweeps all output points in chunks of singleLoopOutput_. The loaded inputs are re-queued
     * between chunks and released only by the last chunk.
     */
    __aicore__ inline void ProcessNTask(const uint32_t &taskCount)
    {
        for (uint32_t outputIdx = 0; outputIdx < nPoints_; outputIdx += singleLoopOutput_) {
            uint32_t remaining = nPoints_ - outputIdx;
            uint32_t curOutputTask = min(singleLoopOutput_, remaining);
            // Written without outputIdx + singleLoopOutput_ so it cannot overflow.
            bool isLastChunk = remaining <= singleLoopOutput_;
            Compute(taskCount, curOutputTask, outputIdx, isLastChunk);
        }
    }

    /**
     * Adds the curLoopTaskCount rows of selectLocal (row stride channelAligned_) into row `oi`
     * of ginLocal.
     */
    __aicore__ inline void GinLocalReduceSum(const uint32_t &oi, const uint32_t &curLoopTaskCount, LocalTensor<dataType> &ginLocal, LocalTensor<dataType> &selectLocal)
    {
        // Row stride of selectLocal in 32-byte blocks (the unit of a repeat stride).
        uint32_t selectRepStride = (channelAligned_ * sizeof(dataType)) / (BUFFER_ALIGN_BYTE);
        // Elements covered by one 256-byte vector repeat.
        uint32_t maxElementsCount = 256 / sizeof(dataType);
        // The repeat count of the masked Add is a uint8_t; larger task counts are split.
        constexpr uint32_t maxRepeatTimes = 255;
        uint32_t baseIdx = oi * channelAligned_;
        if (selectRepStride >= 64) {
            for (uint32_t taskIdx = 0; taskIdx < curLoopTaskCount; taskIdx++) {
                if (taskIdx > 0) {
                    // Each Add reads the previous Add's result: order them on the vector pipe.
                    PipeBarrier<PIPE_V>();
                }
                Add(ginLocal[baseIdx], selectLocal[taskIdx * channelAligned_], ginLocal[baseIdx], channels_);
            }
            return;
        }
        for (uint32_t offsetIdx = 0; offsetIdx < channelAligned_; offsetIdx += maxElementsCount) {
            uint32_t remainingElements = channelAligned_ - offsetIdx;
            uint64_t mask = maxElementsCount < remainingElements ? maxElementsCount : remainingElements;
            for (uint32_t taskBase = 0; taskBase < curLoopTaskCount; taskBase += maxRepeatTimes) {
                uint32_t repeatTimes = min(maxRepeatTimes, curLoopTaskCount - taskBase);
                if (taskBase > 0) {
                    // The next chunk accumulates onto the previous chunk's result.
                    PipeBarrier<PIPE_V>();
                }
                // dst and src1 alias with repeat stride 0 so successive repeats accumulate; src0
                // steps one selectLocal row per repeat.
                // NOTE(review): repeat k+1 reads repeat k's output inside one instruction —
                // confirm the target's address-overlap rules permit this dependency.
                Add(ginLocal[offsetIdx + baseIdx], selectLocal[taskBase * channelAligned_ + offsetIdx],
                    ginLocal[offsetIdx + baseIdx], mask, static_cast<uint8_t>(repeatTimes),
                    {1, 1, 1, 0, static_cast<uint8_t>(selectRepStride), 0});
            }
        }
    }

    /**
     * For output points [outputIdx, outputIdx + curLoopOutputCount): selects the gout values
     * whose argmax equals the point index, sums them per channel and atomically adds the result
     * to gradIn. The dequeued inputs are re-queued unless freeTensor is set.
     */
    __aicore__ inline void Compute(const uint32_t &curLoopTaskCount, const uint32_t &curLoopOutputCount,
        const uint64_t &outputIdx, const bool &freeTensor)
    {
        LocalTensor<dataType> goutLocal = goutQue_.DeQue<dataType>();
        LocalTensor<int32_t> argmaxLocal = argmaxQue_.DeQue<int32_t>();
        LocalTensor<dataType> ginLocal = ginBuf_.AllocTensor<dataType>();
        LocalTensor<uint8_t> maskLocal = tmpMaskBuf_.AllocTensor<uint8_t>();
        LocalTensor<dataType> selectLocal = tmpSelectedBuf_.AllocTensor<dataType>();
        // Wait for the previous chunk's copy-out of ginLocal before overwriting it.
        int32_t eventId1 = static_cast<int32_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_V));
        SetFlag<HardEvent::MTE3_V>(eventId1);
        WaitFlag<HardEvent::MTE3_V>(eventId1);
        Duplicate<dataType>(ginLocal, 0, singleLoopOutput_ * channelAligned_);
        PipeBarrier<PIPE_ALL>();
        for (uint32_t oi = 0; oi < curLoopOutputCount; oi++) {
            int32_t curOutIdx = static_cast<int32_t>(outputIdx + oi);
            // Mask bit set where argmax equals the current output point index.
            CompareScalar(maskLocal, argmaxLocal, curOutIdx, CMPMODE::EQ,
                alignedArgmaxByteLength_ / sizeof(int32_t));
            PipeBarrier<PIPE_V>();
            // Keep gout where the mask is set, 0 elsewhere.
            Select(selectLocal, maskLocal, goutLocal, static_cast<dataType>(0.0), SELMODE::VSEL_TENSOR_SCALAR_MODE,
                curLoopTaskCount * channelAligned_);
            PipeBarrier<PIPE_V>();
            GinLocalReduceSum(oi, curLoopTaskCount, ginLocal, selectLocal);
        }
        // Vector results must be complete before MTE3 copies them out.
        int32_t eventId2 = static_cast<int32_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3));
        SetFlag<HardEvent::V_MTE3>(eventId2);
        WaitFlag<HardEvent::V_MTE3>(eventId2);
        SetAtomicAdd<dataType>();
        DataCopyPad(ginGm_[outputIdx * channels_], ginLocal, {static_cast<uint16_t>(curLoopOutputCount), static_cast<uint32_t>(channels_ * sizeof(dataType)), 0, 0, 0});
        SetAtomicNone();
        if (!freeTensor) {
            goutQue_.EnQue<dataType>(goutLocal);
            argmaxQue_.EnQue<int32_t>(argmaxLocal);
        } else {
            goutQue_.FreeTensor(goutLocal);
            argmaxQue_.FreeTensor(argmaxLocal);
        }
    }
};
/**
 * Kernel entry point for roiaware_maxpool3d_grad (float instantiation).
 *
 * Loads the tiling data and registers the system workspace. If the workspace pointer is
 * available, it builds a pipe and runs KernelRoiawareMaxpool3dGrad on this core; otherwise
 * it returns without doing any work.
 */
extern "C" __global__ __aicore__ void roiaware_maxpool3d_grad(GM_ADDR argmax, GM_ADDR grad_out, GM_ADDR grad_in, GM_ADDR workspace,
    GM_ADDR tiling)
{
    GET_TILING_DATA(tiling_data, tiling);
    SetSysWorkspace(workspace);
    if (GetSysWorkSpacePtr() != nullptr) {
        TPipe kernelPipe;
        KernelRoiawareMaxpool3dGrad<float, false> kernel;
        kernel.Init(argmax, grad_out, grad_in, &kernelPipe, &tiling_data);
        kernel.Process();
    }
}