#include "bev_pool_v2.h"
using namespace AscendC;
namespace BEVPoolV2 {
/**
 * Backward pass of BEV pool v2 for one interval of points: indices [start_, start_ + length_)
 * into the rank arrays.
 *
 * For every point i of the interval:
 *   - gradDepth[ranksDepth[i]] = sum over the first stride0_ channels of
 *                                gradOut[ranksBev[i] * stride0_ + c] * feat[ranksFeat[i] * stride0_ + c]
 *   - gradFeat accumulator    += gradOut row of ranksBev[i] scaled by depth[ranksDepth[i]]
 * The accumulated feature gradient is written once, to the feature row of the FIRST point of
 * the interval (ranksFeat[start_]).
 * NOTE(review): writing the sum to ranksFeat[start_] only makes sense if every point of an
 * interval shares the same ranksFeat value (intervals grouped by feature rank) — confirm against
 * the host-side interval construction.
 */
template<typename T, bool Align32B>
__aicore__ inline void BEVPoolV2GradKernel<T, Align32B>::DoProcess()
{
    // Zero-initialised accumulator for this interval's feature gradient.
    LocalTensor<T> gradFeatT = gradFeatQue_.AllocTensor<T>();
    Duplicate(gradFeatT, T(0.f), this->alignUpCCount_);
    for (int32_t i = 0; i < this->length_; ++i) {
        this->depthOffset_ = this->rDGm_.GetValue(this->start_ + i);
        this->outOffset_ = this->rBGm_.GetValue(this->start_ + i) * this->stride0_;
        this->featOffset_ = this->rFGm_.GetValue(this->start_ + i) * this->stride0_;

        // Load the output-gradient row this point was pooled into.
        LocalTensor<T> gradOutT = gradOutQue_.AllocTensor<T>();
        DataCopy(gradOutT, this->gOGm_[this->outOffset_], this->cpFeatParams_);
        gradOutQue_.EnQue(gradOutT);
        gradOutT = gradOutQue_.DeQue<T>();

        // Depth gradient. Each point owns its own depth entry (depthOffset_) and its own gradOut
        // row, so this must be produced on EVERY iteration, not only for the last point.
        LocalTensor<T> featT = this->featQue_.template AllocTensor<T>();
        DataCopy(featT, this->fGm_[this->featOffset_], this->cpFeatParams_);
        this->featQue_.EnQue(featT);
        featT = this->featQue_.template DeQue<T>();
        // featT <- gradOut * feat (in place). Only the first stride0_ lanes are reduced, so any
        // alignment padding up to alignUpCCount_ does not contribute to the sum.
        Mul(featT, gradOutT, featT, this->alignUpCCount_);
        LocalTensor<T> gradDepthT = gradDepthQue_.AllocTensor<T>();
        ReduceSum(gradDepthT, featT, workT_, this->stride0_);
        this->featQue_.FreeTensor(featT);
        gradDepthQue_.EnQue(gradDepthT);
        gradDepthT = gradDepthQue_.DeQue<T>();
        DataCopyPad(gDGm_[this->depthOffset_], gradDepthT, this->cpDepthParams_);
        gradDepthQue_.FreeTensor(gradDepthT);

        // Feature gradient: gradFeat += gradOut * depth (gradOutT is scaled in place).
        T depth = this->dGm_.GetValue(this->depthOffset_);
        Muls(gradOutT, gradOutT, depth, this->alignUpCCount_);
        Add(gradFeatT, gradFeatT, gradOutT, this->alignUpCCount_);
        // Return the buffer to the queue it was allocated from (gradOutQue_, not featQue_).
        gradOutQue_.FreeTensor(gradOutT);
    }
    gradFeatQue_.EnQue(gradFeatT);
    gradFeatT = gradFeatQue_.DeQue<T>();
    // Write the accumulated feature gradient to the feature row of the interval's first point.
    int32_t featOffset = this->rFGm_.GetValue(this->start_) * this->stride0_;
    if (Align32B) {
        // Row size is a whole number of 32-byte blocks: plain block copy.
        DataCopy(this->gFGm_[featOffset], gradFeatT, this->cpFeatParams_);
    } else {
        // Unaligned row size: padded copy so bytes beyond the row are not written.
        DataCopyPad(this->gFGm_[featOffset], gradFeatT, this->cpPadParams_);
    }
    gradFeatQue_.FreeTensor(gradFeatT);
}
}
/**
 * Kernel entry point for the BEV pool v2 backward pass.
 *
 * Reads the tiling data, derives the channel-row byte sizes from stride0 (the channel count),
 * and dispatches on the tiling key:
 *   key 3 -> BEVPoolV2GradKernel<float, true>  (32-byte-aligned rows)
 *   key 2 -> BEVPoolV2GradKernel<float, false> (padded copies)
 * Any other key runs no kernel work. A full pipe barrier is issued before returning.
 * NOTE(review): presumably the host selects key 3 when stride0 * sizeof(float) is a multiple of
 * ONE_BLK_SIZE — confirm against the tiling code.
 */
extern "C" __global__ __aicore__ void bev_pool_v2_grad(GM_ADDR gradOut, GM_ADDR depth, GM_ADDR feat, GM_ADDR ranksDepth,
    GM_ADDR ranksFeat, GM_ADDR ranksBev, GM_ADDR intervalLengths, GM_ADDR intervalStarts, GM_ADDR gradDepth,
    GM_ADDR gradFeat, GM_ADDR workspace, GM_ADDR tiling)
{
    GET_TILING_DATA(bevPoolTiling, tiling);
    int32_t coreIdx = GetBlockIdx();
    int32_t channels = bevPoolTiling.stride0;
#if __CCE_AICORE__ == 220
    KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIV_ONLY);
#endif
    // Byte size of one channel row, and the same size rounded up to whole blocks.
    // Both kernel variants take identical sizes, so they are computed once here.
    const int32_t rowBytes = channels * sizeof(float);
    const int32_t rowBlocks = DivCeil(rowBytes, ONE_BLK_SIZE);
    const int32_t paddedRowBytes = rowBlocks * ONE_BLK_SIZE;

    if (TILING_KEY_IS(3)) {
        BEVPoolV2::BEVPoolV2GradKernel<float, true> alignedKernel(coreIdx, rowBytes, rowBlocks, paddedRowBytes,
            gradOut, depth, feat, ranksDepth, ranksFeat, ranksBev, intervalLengths, intervalStarts, gradDepth,
            gradFeat, bevPoolTiling);
        alignedKernel.Process();
    } else if (TILING_KEY_IS(2)) {
        BEVPoolV2::BEVPoolV2GradKernel<float, false> paddedKernel(coreIdx, rowBytes, rowBlocks, paddedRowBytes,
            gradOut, depth, feat, ranksDepth, ranksFeat, ranksBev, intervalLengths, intervalStarts, gradDepth,
            gradFeat, bevPoolTiling);
        paddedKernel.Process();
    }
    PipeBarrier<PIPE_ALL>();
}