* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
*
*/
#include "kernel_operator.h"
#include "kernel_tiling/kernel_tiling.h"
using namespace AscendC;
// Per-core kernel object that produces grad_mc_ms_feat, grad_sampling_location and
// grad_weights from mc_ms_feat, spatial_shape, scale_start_index, sampling_location,
// weights and grad_output.
//
// DTYPE_F is the element type of every feature/weight/gradient tensor; the
// spatial_shape and scale_start_index inputs are always read as int32_t.
//
// Usage (see the kernel entry point): construct with the GM addresses, the tiling
// data and a TPipe, then call Process() once.
template<typename DTYPE_F>
class KernelDeformableAggregationGrad {
public:
__aicore__ inline KernelDeformableAggregationGrad() = delete;
// Stores the pipe pointer, then (in this order) derives the per-core task split,
// binds the global tensors and reserves the local buffers. InitGM and InitBuffer
// both read members written by InitTask, so the order matters.
__aicore__ inline KernelDeformableAggregationGrad(
GM_ADDR mc_ms_feat,
GM_ADDR spatial_shape,
GM_ADDR scale_start_index,
GM_ADDR sampling_location,
GM_ADDR weights,
GM_ADDR grad_output,
GM_ADDR grad_mc_ms_feat,
GM_ADDR grad_sampling_location,
GM_ADDR grad_weights,
const DeformableAggregationGradTilingData& tiling_data,
TPipe* pipe)
: pipe_(pipe)
{
InitTask(tiling_data);
InitGM(mc_ms_feat, spatial_shape, scale_start_index,
sampling_location, weights, grad_output,
grad_mc_ms_feat, grad_sampling_location, grad_weights);
InitBuffer();
}
// Runs Prepare() once and then ProcessSingle() for every chunk of this core's tasks.
__aicore__ inline void Process();
private:
// Copies the tiling fields into members and computes this core's task range.
__aicore__ inline void InitTask(const DeformableAggregationGradTilingData& tiling)
{
usedCoreNum_ = tiling.usedCoreNum;
avgWeightNum_ = tiling.avgWeightNum;
tailWeightNum_ = tiling.tailWeightNum;
coreId = GetBlockIdx();
// Every core starts avgWeightNum_ tasks after the previous one.
taskOffset = coreId * avgWeightNum_;
// All cores handle avgWeightNum_ tasks except the last used core, which handles tailWeightNum_.
totalTaskNum_ = avgWeightNum_;
if (coreId == usedCoreNum_ - 1) {
totalTaskNum_ = tailWeightNum_;
}
// Chunk length: at most totalTaskNum_, but never below 1 so the division below is safe.
singleProcessTaskLen_ = min(tiling.singleProcessTaskLen, totalTaskNum_);
singleProcessTaskLen_ = max(singleProcessTaskLen_, (uint32_t)1);
// Ceiling division of totalTaskNum_ by the chunk length (for totalTaskNum_ >= 1).
taskRepeatTimes = (totalTaskNum_ - 1) / singleProcessTaskLen_ + 1;
pts_ = tiling.numPoints;
cam_ = tiling.numCams;
scale_ = tiling.numScale;
group_ = tiling.numGroups;
numEmbeds = tiling.numEmbeds;
numFeat = tiling.numFeat;
numAnchors = tiling.numAnchors;
// numEmbeds / group_: how many consecutive embedding elements share one weight
// (this is the replication factor used by BroadCast in ProcessSingle).
totalGroups = numEmbeds / group_;
blockSize_ = 32;
// Number of DTYPE_F elements that fit in blockSize_ (32) bytes; used as the AlignUp granularity.
blockDataNum_ = blockSize_ / sizeof(DTYPE_F);
}
// Binds every GlobalTensor member to its GM address.
// sampling_location, weights, grad_output, grad_sampling_location and grad_weights are
// bound already advanced to this core's first task; mc_ms_feat, spatial_shape,
// scale_start_index and grad_mc_ms_feat are bound at their base address.
__aicore__ inline void InitGM(GM_ADDR mc_ms_feat, GM_ADDR spatial_shape, GM_ADDR scale_start_index,
GM_ADDR sampling_location, GM_ADDR weights, GM_ADDR grad_output,
GM_ADDR grad_mc_ms_feat, GM_ADDR grad_sampling_location, GM_ADDR grad_weights)
{
// Per task: pts_ * cam_ locations of 2 values each.
int64_t samplingLocationOffset = taskOffset * pts_ * cam_ * 2;
// Per task: pts_ * cam_ * scale_ * group_ weights.
int64_t weightOffset = taskOffset * pts_ * cam_ * scale_ * group_;
mcMsFeatGm.SetGlobalBuffer((__gm__ DTYPE_F*)(mc_ms_feat));
spatialShapeGm.SetGlobalBuffer((__gm__ int32_t*)(spatial_shape));
scaleStartLocationGm.SetGlobalBuffer((__gm__ int32_t*)(scale_start_index));
samplingLocationGm.SetGlobalBuffer((__gm__ DTYPE_F*)(sampling_location) + samplingLocationOffset);
weightGm.SetGlobalBuffer((__gm__ DTYPE_F*)(weights) + weightOffset);
// Per task: numEmbeds output-gradient values.
outputGradGm.SetGlobalBuffer((__gm__ DTYPE_F*)(grad_output) + taskOffset * numEmbeds);
gradMcMsFeatGm.SetGlobalBuffer((__gm__ DTYPE_F*)(grad_mc_ms_feat));
gradSamplingLocalGm.SetGlobalBuffer((__gm__ DTYPE_F*)(grad_sampling_location) + samplingLocationOffset);
gradWeightsGm.SetGlobalBuffer((__gm__ DTYPE_F*)(grad_weights) + weightOffset);
}
// Reserves all local buffers on the pipe. Sizes are in bytes.
__aicore__ inline void InitBuffer()
{
uint64_t singleWeightOffset = scale_ * group_;
uint64_t samplingOffset = pts_ * cam_ * 2;
pipe_->InitBuffer(weightQue_, AlignUp(singleWeightOffset, blockDataNum_) * sizeof(DTYPE_F));
// NOTE(review): this size is not rounded up, while ProcessSingle copies
// AlignUp(actualWeightNum * numEmbeds, blockDataNum_) elements into the buffer.
// The two only match when that product is a multiple of blockDataNum_ - confirm.
pipe_->InitBuffer(gradOutputQue_, singleProcessTaskLen_ * numEmbeds * sizeof(DTYPE_F));
pipe_->InitBuffer(scaleStartLocationQue_, AlignUp(cam_ * scale_, B32_DATA_NUM_PER_BLOCK) * sizeof(int32_t));
pipe_->InitBuffer(samplingLocationQue_, AlignUp(samplingOffset, blockDataNum_) * sizeof(DTYPE_F));
pipe_->InitBuffer(spatialShapeQue_, AlignUp(cam_ * scale_ * 2, B32_DATA_NUM_PER_BLOCK) * sizeof(int32_t));
pipe_->InitBuffer(topGradMcMsFeatQue_, numEmbeds * sizeof(DTYPE_F));
// The 4 * numEmbeds buffers hold one numEmbeds-long slice per bilinear corner.
pipe_->InitBuffer(gradValueQue_, 4 * numEmbeds * sizeof(DTYPE_F));
pipe_->InitBuffer(vQue_, 4 * numEmbeds * sizeof(DTYPE_F));
// One numEmbeds-long slice per scale.
pipe_->InitBuffer(featureQue_, scale_ * numEmbeds * sizeof(DTYPE_F));
pipe_->InitBuffer(gradWeightsQue_, scale_ * group_ * sizeof(DTYPE_F));
pipe_->InitBuffer(pointGradWeightQue_, 4 * numEmbeds * sizeof(DTYPE_F));
pipe_->InitBuffer(gradSamplingQue_, blockDataNum_ * sizeof(DTYPE_F));
// Two numEmbeds-long accumulators, one per sampling-location component.
pipe_->InitBuffer(pointGradQue_, 2 * numEmbeds * sizeof(DTYPE_F));
pipe_->InitBuffer(weightBrobQue_, scale_ * numEmbeds * sizeof(DTYPE_F));
}
// Fetches a LocalTensor view for every buffer, zeroes the accumulators and loads the
// scale-start and spatial-shape tables from GM. Called once per core from Process().
__aicore__ inline void Prepare()
{
int32_t scaleStartNum = AlignUp(cam_ * scale_, B32_DATA_NUM_PER_BLOCK);
int32_t spatialShapeNum = AlignUp(cam_ * scale_ * 2, B32_DATA_NUM_PER_BLOCK);
scaleStartLocation = scaleStartLocationQue_.Get<int32_t>();
spatialShape = spatialShapeQue_.Get<int32_t>();
weight = weightQue_.Get<DTYPE_F>();
gradOutput = gradOutputQue_.Get<DTYPE_F>();
samplingLocation = samplingLocationQue_.Get<DTYPE_F>();
gradWeightsLocal = gradWeightsQue_.Get<DTYPE_F>();
gradSamplingLocal = gradSamplingQue_.Get<DTYPE_F>();
gradValueLocal = gradValueQue_.Get<DTYPE_F>();
topGradMcMsFeatLocal = topGradMcMsFeatQue_.Get<DTYPE_F>();
vLocal = vQue_.Get<DTYPE_F>();
featureLocal = featureQue_.Get<DTYPE_F>();
pointGradWeightLocal = pointGradWeightQue_.Get<DTYPE_F>();
pointGradSum = pointGradQue_.Get<DTYPE_F>();
weightBrobLocal = weightBrobQue_.Get<DTYPE_F>();
// ProcessSingle relies on these three buffers being zero on entry to each
// (point, camera) iteration; it re-zeroes them itself after use.
Duplicate(pointGradSum, (DTYPE_F)0, 2 * numEmbeds);
Duplicate(featureLocal, (DTYPE_F)0, scale_ * numEmbeds);
Duplicate(vLocal, (DTYPE_F)0, numEmbeds * 4);
// NOTE(review): these tables are later read with GetValue() in ProcessSingle. No
// SetFlag/WaitFlag pair targeting the scalar unit is visible in this class - confirm
// whether an MTE2-to-scalar synchronisation is required here.
DataCopy(scaleStartLocation, scaleStartLocationGm, scaleStartNum);
DataCopy(spatialShape, spatialShapeGm, spatialShapeNum);
}
// Processes one chunk of tasks.
//   taskIdx         - index of the chunk within this core (0 .. taskRepeatTimes - 1).
//   actualWeightNum - number of tasks in this chunk.
// For every task, point and camera whose sampling location lies strictly inside (0, 1)
// in both components, this accumulates into grad_mc_ms_feat and writes one
// grad_sampling_location pair and scale_ * group_ grad_weights values.
// Locations outside that range are skipped and nothing is written for them.
__aicore__ inline void ProcessSingle(uint64_t taskIdx, uint32_t actualWeightNum)
{
uint64_t singleWeightOffset = scale_ * group_;
uint32_t weightCopyLen = AlignUp(singleWeightOffset, blockDataNum_);
int32_t gradOuputNum = AlignUp(actualWeightNum * numEmbeds, blockDataNum_);
int32_t samplingLocationNum = AlignUp(pts_ * cam_ * 2, blockDataNum_);
// Element offset of this chunk inside this core's slice of grad_output.
uint64_t gradOutputOffset = taskIdx * singleProcessTaskLen_ * numEmbeds;
// Let earlier vector work finish before GM data is copied into local buffers.
SetFlag<HardEvent::V_MTE2>(0);
WaitFlag<HardEvent::V_MTE2>(0);
DataCopy(gradOutput, outputGradGm[gradOutputOffset], gradOuputNum);
for (int32_t weightNumId = 0; weightNumId < actualWeightNum; weightNumId++) {
// Global task index divided by numAnchors selects which numFeat * numEmbeds feature slab to use.
int64_t curBatch = (taskOffset + taskIdx * singleProcessTaskLen_ + weightNumId) / numAnchors;
int64_t featOffset = curBatch * numFeat * numEmbeds;
// Offset of this task's locations relative to this core's slice of sampling_location.
uint64_t samplingLocationOffset = (taskIdx * singleProcessTaskLen_ + weightNumId) * pts_ * cam_ * 2;
// NOTE(review): the values copied here are read with GetValue() a few lines below with
// no intervening SetFlag/WaitFlag pair - confirm whether an MTE2-to-scalar sync is needed.
DataCopy(samplingLocation, samplingLocationGm[samplingLocationOffset], samplingLocationNum);
for (int32_t ptsId = 0; ptsId < pts_; ptsId++) {
for (int32_t camId = 0; camId < cam_; camId++) {
int32_t locOffset = ptsId * cam_ + camId;
// Each location is stored as two consecutive values: first locW, then locH.
float locW = samplingLocation.GetValue(locOffset * 2);
float locH = samplingLocation.GetValue(locOffset * 2 + 1);
// Only locations strictly inside (0, 1) x (0, 1) contribute.
if (locW <= 0 || locW >= 1 || locH <=0 || locH >=1) {
continue;
}
uint64_t weightGmOffset = (((taskIdx * singleProcessTaskLen_ + weightNumId) * pts_ + ptsId) * cam_ + camId) * singleWeightOffset;
uint64_t samplingLocationCopyOutOffset = samplingLocationOffset + (ptsId * cam_ + camId) * 2;
DataCopy(weight, weightGm[weightGmOffset], weightCopyLen);
// The vector unit must see the freshly copied weights.
SetFlag<HardEvent::MTE2_V>(0);
WaitFlag<HardEvent::MTE2_V>(0);
// Replicate each of the scale_ * group_ weights totalGroups times, giving
// scale_ * numEmbeds elements that line up with the embedding layout.
uint32_t dstShape_[2] = {scale_ * group_, totalGroups};
uint32_t srcShape_[2] = {scale_ * group_, 1};
BroadCast<DTYPE_F, 2, 1>(weightBrobLocal, weight, dstShape_, srcShape_);
SetFlag<HardEvent::V_MTE2>(0);
WaitFlag<HardEvent::V_MTE2>(0);
for (int32_t scaleId = 0; scaleId < scale_; scaleId++) {
int32_t scaleStartOffset = camId * scale_ + scaleId;
int32_t scaleStartIdx = scaleStartLocation.GetValue(scaleStartOffset);
// Element offset of this (camera, scale) feature map inside one feature slab.
int64_t featureOffset = (int64_t)scaleStartIdx * numEmbeds;
// spatial_shape stores (h, w) pairs per (camera, scale).
int32_t h = spatialShape.GetValue(scaleStartOffset * 2);
int32_t w = spatialShape.GetValue(scaleStartOffset * 2 + 1);
// Map the normalised location to continuous pixel coordinates.
float hIm = locH * h - (float)0.5;
float wIm = locW * w - (float)0.5;
int32_t hLow = ScalarCast<float, int32_t, AscendC::RoundMode::CAST_FLOOR>(hIm);
int32_t wLow = ScalarCast<float, int32_t, AscendC::RoundMode::CAST_FLOOR>(wIm);
int32_t hHigh = hLow + 1;
int32_t wHigh = wLow + 1;
// Fractional distances to the low corner (lh, lw) and their complements (hh, hw).
float lh = hIm - hLow;
float lw = wIm - wLow;
float hh = 1 - lh;
float hw = 1 - lw;
// Features are laid out row-major with numEmbeds contiguous values per pixel.
int32_t wStride = numEmbeds;
int32_t hStride = w * wStride;
int32_t hLowPtrOffset = hLow * hStride;
int32_t hHighPtrOffset = hLowPtrOffset + hStride;
int32_t wLowPtrOffset = wLow * wStride;
int32_t wHighPtrOffset = wLowPtrOffset + wStride;
// Bilinear weights for corners 1..4 = (hLow,wLow), (hLow,wHigh), (hHigh,wLow), (hHigh,wHigh).
float w1 = hh * hw;
float w2 = hh * lw;
float w3 = lh * hw;
float w4 = lh * lw;
// Element offsets of the four corners; only used below when the corner is in bounds.
uint64_t ptr1 = featureOffset + hLowPtrOffset + wLowPtrOffset;
uint64_t ptr2 = featureOffset + hLowPtrOffset + wHighPtrOffset;
uint64_t ptr3 = featureOffset + hHighPtrOffset + wLowPtrOffset;
uint64_t ptr4 = featureOffset + hHighPtrOffset + wHighPtrOffset;
uint64_t weightOffset = scaleId * numEmbeds;
uint64_t gradOuputBaseOffset = weightNumId * numEmbeds;
// gradValueLocal may still be the source of an earlier copy-out; wait before overwriting it.
SetFlag<HardEvent::MTE3_V>(0);
WaitFlag<HardEvent::MTE3_V>(0);
// topGrad = broadcast weight * grad_output for this task.
Mul(topGradMcMsFeatLocal, weightBrobLocal[weightOffset], gradOutput[gradOuputBaseOffset], numEmbeds);
// Per-corner feature gradient: topGrad scaled by that corner's bilinear weight.
Muls(gradValueLocal, topGradMcMsFeatLocal, static_cast<DTYPE_F>(w1), numEmbeds);
Muls(gradValueLocal[numEmbeds * 1], topGradMcMsFeatLocal, static_cast<DTYPE_F>(w2), numEmbeds);
Muls(gradValueLocal[numEmbeds * 2], topGradMcMsFeatLocal, static_cast<DTYPE_F>(w3), numEmbeds);
Muls(gradValueLocal[numEmbeds * 3], topGradMcMsFeatLocal, static_cast<DTYPE_F>(w4), numEmbeds);
SetFlag<HardEvent::V_MTE3>(0);
WaitFlag<HardEvent::V_MTE3>(0);
// Copies to grad_mc_ms_feat are issued with atomic add enabled, so contributions to
// the same pixel accumulate. For every in-bounds corner the matching input feature is
// also loaded into vLocal; out-of-bounds corners keep the zero written by Duplicate.
SetAtomicAdd<DTYPE_F>();
if (hLow >= 0 && wLow >=0) {
DataCopy(gradMcMsFeatGm[featOffset + ptr1], gradValueLocal, numEmbeds);
DataCopy(vLocal, mcMsFeatGm[featOffset + ptr1], numEmbeds);
}
if (hLow >= 0 && wHigh <= w - 1) {
DataCopy(gradMcMsFeatGm[featOffset + ptr2], gradValueLocal[numEmbeds * 1], numEmbeds);
DataCopy(vLocal[numEmbeds], mcMsFeatGm[featOffset + ptr2], numEmbeds);
}
if (hHigh <= h - 1 && wLow >= 0) {
DataCopy(gradMcMsFeatGm[featOffset + ptr3], gradValueLocal[numEmbeds * 2], numEmbeds);
DataCopy(vLocal[numEmbeds * 2], mcMsFeatGm[featOffset + ptr3], numEmbeds);
}
if (hHigh <= h - 1 && wHigh <= w - 1) {
DataCopy(gradMcMsFeatGm[featOffset + ptr4], gradValueLocal[numEmbeds * 3], numEmbeds);
DataCopy(vLocal[numEmbeds * 3], mcMsFeatGm[featOffset + ptr4], numEmbeds);
}
SetAtomicNone();
SetFlag<HardEvent::MTE2_V>(0);
WaitFlag<HardEvent::MTE2_V>(0);
// Weight-gradient term for this scale:
// (w1*v1 + w2*v2 + w3*v3 + w4*v4) * grad_output, reduced per weight after the scale loop.
Muls(featureLocal[weightOffset], vLocal, static_cast<DTYPE_F>(w1), numEmbeds);
Axpy(featureLocal[weightOffset], vLocal[numEmbeds], static_cast<DTYPE_F>(w2), numEmbeds);
Axpy(featureLocal[weightOffset], vLocal[numEmbeds * 2], static_cast<DTYPE_F>(w3), numEmbeds);
Axpy(featureLocal[weightOffset], vLocal[numEmbeds * 3], static_cast<DTYPE_F>(w4), numEmbeds);
Mul(featureLocal[weightOffset], featureLocal[weightOffset], gradOutput[gradOuputBaseOffset], numEmbeds);
// Corner differences: slice0 = v2 - v1, slice2 = v4 - v3 (along w);
// slice1 = v3 - v1, slice3 = v4 - v2 (along h).
Sub(pointGradWeightLocal, vLocal[numEmbeds * 1], vLocal, numEmbeds);
Sub(pointGradWeightLocal[numEmbeds * 2], vLocal[numEmbeds * 3], vLocal[numEmbeds * 2], numEmbeds);
Sub(pointGradWeightLocal[numEmbeds * 1], vLocal[numEmbeds * 2], vLocal, numEmbeds);
Sub(pointGradWeightLocal[numEmbeds * 3], vLocal[numEmbeds * 3], vLocal[numEmbeds * 1], numEmbeds);
// Reset vLocal so corners skipped in the next iteration read as zero, and make sure
// the reset has completed before the next GM load into vLocal.
Duplicate(vLocal, (DTYPE_F)0, numEmbeds * 4);
SetFlag<HardEvent::V_MTE2>(0);
WaitFlag<HardEvent::V_MTE2>(0);
// slice0 = hh*(v2 - v1) + lh*(v4 - v3): derivative of the interpolated value along w.
Muls(pointGradWeightLocal, pointGradWeightLocal, static_cast<DTYPE_F>(hh), numEmbeds);
Axpy(pointGradWeightLocal, pointGradWeightLocal[numEmbeds * 2], static_cast<DTYPE_F>(lh), numEmbeds);
// slice1 = hw*(v3 - v1) + lw*(v4 - v2): derivative of the interpolated value along h.
Muls(pointGradWeightLocal[numEmbeds * 1], pointGradWeightLocal[numEmbeds * 1], static_cast<DTYPE_F>(hw), numEmbeds);
Axpy(pointGradWeightLocal[numEmbeds * 1], pointGradWeightLocal[numEmbeds * 3], static_cast<DTYPE_F>(lw), numEmbeds);
// Chain rule: multiply by topGrad, then by w (resp. h) because wIm = locW * w - 0.5
// and hIm = locH * h - 0.5.
Mul(pointGradWeightLocal, pointGradWeightLocal, topGradMcMsFeatLocal, numEmbeds);
Mul(pointGradWeightLocal[numEmbeds], pointGradWeightLocal[numEmbeds], topGradMcMsFeatLocal, numEmbeds);
Muls(pointGradWeightLocal, pointGradWeightLocal, (DTYPE_F)w, numEmbeds);
Muls(pointGradWeightLocal[numEmbeds], pointGradWeightLocal[numEmbeds], (DTYPE_F)h, numEmbeds);
// Accumulate both location-gradient slices over all scales.
Add(pointGradSum, pointGradSum, pointGradWeightLocal, numEmbeds * 2);
}
// gradWeightsLocal / gradSamplingLocal may still be the source of the previous
// iteration's copy-out; wait before overwriting them.
SetFlag<HardEvent::MTE3_V>(0);
WaitFlag<HardEvent::MTE3_V>(0);
// Reduce scale_ * group_ rows of totalGroups elements into one grad_weights value each,
// and 2 rows of numEmbeds elements into the two grad_sampling_location values.
Sum(gradWeightsLocal, featureLocal, {scale_ * group_, totalGroups, totalGroups});
Sum(gradSamplingLocal, pointGradSum, {2, numEmbeds, numEmbeds});
SetFlag<HardEvent::V_MTE3>(0);
WaitFlag<HardEvent::V_MTE3>(0);
// Clear the accumulators for the next (point, camera) pair.
Duplicate(featureLocal, (DTYPE_F)0, scale_ * numEmbeds);
Duplicate(pointGradSum, (DTYPE_F)0, 2 * numEmbeds);
// Copy out exactly 2 location gradients and scale_ * group_ weight gradients
// (lengths given in bytes), without padding to a full block.
DataCopyExtParams locationCopyParams {1, (uint32_t)(2 * sizeof(DTYPE_F)), 0, 0, 0};
DataCopyExtParams weightsCopyParams {1, (uint32_t)(scale_ * group_ * sizeof(DTYPE_F)), 0, 0, 0};
DataCopyPad(gradSamplingLocalGm[samplingLocationCopyOutOffset], gradSamplingLocal, locationCopyParams);
DataCopyPad(gradWeightsGm[weightGmOffset], gradWeightsLocal, weightsCopyParams);
}
}
}
}
private:
// Pipe supplied by the caller; used only to reserve the buffers below.
TPipe* pipe_;
// Global-memory views of the inputs ...
GlobalTensor<DTYPE_F> mcMsFeatGm, samplingLocationGm, weightGm, outputGradGm;
// ... and of the three gradient outputs.
GlobalTensor<DTYPE_F> gradMcMsFeatGm, gradSamplingLocalGm, gradWeightsGm;
GlobalTensor<int32_t> spatialShapeGm, scaleStartLocationGm;
// Local buffers (TBuf, despite the "Que" suffix) backing the LocalTensor members below.
TBuf<TPosition::VECCALC> weightQue_, gradOutputQue_, samplingLocationQue_, scaleStartLocationQue_, spatialShapeQue_;
TBuf<TPosition::VECCALC> gradWeightsQue_, gradSamplingQue_, gradValueQue_;
TBuf<TPosition::VECCALC> topGradMcMsFeatQue_, vQue_, featureQue_, pointGradWeightQue_, pointGradQue_, weightBrobQue_;
// Tensor views obtained once in Prepare() and reused by ProcessSingle().
LocalTensor<int32_t> scaleStartLocation, spatialShape;
LocalTensor<DTYPE_F> weight, gradOutput, samplingLocation;
LocalTensor<DTYPE_F> gradWeightsLocal, gradSamplingLocal, gradValueLocal;
LocalTensor<DTYPE_F> topGradMcMsFeatLocal, vLocal, featureLocal, pointGradWeightLocal, pointGradSum, weightBrobLocal;
// Core split taken from the tiling data.
uint32_t usedCoreNum_, avgWeightNum_, tailWeightNum_, coreId;
// This core's task count, chunk length and number of chunks.
uint32_t totalTaskNum_, singleProcessTaskLen_, taskRepeatTimes;
// Problem dimensions taken from the tiling data; totalGroups = numEmbeds / group_.
uint32_t pts_, cam_, scale_, group_, numEmbeds, numFeat, numAnchors, totalGroups;
uint32_t blockSize_, blockDataNum_;
// Index of this core's first task (coreId * avgWeightNum_).
int64_t taskOffset;
};
// Drives the whole per-core computation: one-time setup via Prepare(), then one
// ProcessSingle() call per chunk. Every chunk holds singleProcessTaskLen_ tasks except
// the final one, which holds whatever remains of totalTaskNum_.
template<typename DTYPE_F>
__aicore__ inline void KernelDeformableAggregationGrad<DTYPE_F>::Process()
{
    Prepare();

    // Size of the final chunk; always in [1, singleProcessTaskLen_].
    const uint32_t tailChunkLen = (totalTaskNum_ - 1) % singleProcessTaskLen_ + 1;

    for (uint32_t chunkIdx = 0; chunkIdx < taskRepeatTimes; ++chunkIdx) {
        const uint32_t chunkLen =
            unlikely(chunkIdx == taskRepeatTimes - 1) ? tailChunkLen : singleProcessTaskLen_;
        ProcessSingle(chunkIdx, chunkLen);
    }
}
// Kernel entry point. Unpacks the tiling data, builds the per-core kernel object over
// the given GM addresses and runs it. `workspace` is accepted but not used here.
extern "C" __global__ __aicore__ void deformable_aggregation_grad(
    GM_ADDR mc_ms_feat, GM_ADDR spatial_shape, GM_ADDR scale_start_index,
    GM_ADDR sampling_location, GM_ADDR weights, GM_ADDR grad_output,
    GM_ADDR grad_mc_ms_feat, GM_ADDR grad_sampling_location, GM_ADDR grad_weights,
    GM_ADDR workspace, GM_ADDR tiling)
{
    GET_TILING_DATA(tilingData, tiling);

    TPipe tpipe;
    KernelDeformableAggregationGrad<DTYPE_MC_MS_FEAT> kernel(
        mc_ms_feat, spatial_shape, scale_start_index,
        sampling_location, weights, grad_output,
        grad_mc_ms_feat, grad_sampling_location, grad_weights,
        tilingData, &tpipe);
    kernel.Process();
}