* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
*/
#include <cmath>
#include "kernel_operator.h"
#include "kernel_tiling/kernel_tiling.h"
using namespace AscendC;
// Queue depth used for every TQue and passed to TPipe::InitBuffer.
constexpr int32_t BUFFER_NUM = 2;
// Size in bytes of one fp32 element.
constexpr int32_t SIZE_OF_FP32 = 4;
// Number of fp32 elements that fit in 32 bytes (32 / SIZE_OF_FP32).
constexpr int32_t BLOCK_ALIGN = 32 / SIZE_OF_FP32;
// Border selector; value i picks the i-th of the 4 per-task entries in gradOut / argmaxIdx.
enum borderMode {
    top = 0,
    left = 1,
    bottom = 2,
    right = 3
};
/*
 * KernelBorderAlignGrad: backward pass of BorderAlign on AI Core.
 *
 * A "task" is one (batch, channel, box) triple. Task t maps to
 *   batchIdx    = t / (channels * boxSize)
 *   channelsIdx = t / boxSize % channels
 *   box row     = t % boxSize + batchIdx * boxSize   (row of `boxes`, 4 floats: x1, y1, x2, y2)
 * and owns 4 consecutive entries of gradOut / argmaxIdx, one per borderMode (top, left, bottom, right).
 * For each border the sample point selected by argmaxIdx is rebuilt from the box, and its gradient is
 * scattered with bilinear weights into gradInput through atomic adds.
 * gradInput is addressed as [batch][4 * channels][height][width]; border i uses channels
 * [i * channels, (i + 1) * channels).
 * NOTE(review): gradInput is only accumulated into; presumably the caller zero-initialises it.
 */
class KernelBorderAlignGrad {
public:
    __aicore__ inline KernelBorderAlignGrad() {}

    /**
     * Reads the tiling parameters and binds the global and local buffers.
     * @param gradOut    GM address of the output gradient (4 fp32 values per task).
     * @param boxes      GM address of the boxes (4 fp32 values per box).
     * @param argmaxIdx  GM address of the argmax indices (4 int32 values per task).
     * @param gradInput  GM address of the input gradient, accumulated with atomic adds.
     * @param tilingData Host-computed tiling parameters.
     */
    __aicore__ inline void Init(
        GM_ADDR gradOut, GM_ADDR boxes, GM_ADDR argmaxIdx, GM_ADDR gradInput, const BorderAlignGradTilingData* tilingData)
    {
        channels = tilingData->channels;
        boxSize = tilingData->boxSize;
        height = tilingData->height;
        width = tilingData->width;
        poolSize = tilingData->poolSize;
        batchSize = tilingData->batchSize;
        coreCompNum = tilingData->coreCompNum;
        taskLast = tilingData->taskLast;
        ASSERT(GetBlockNum() != 0 && "block dim can not be zero!");

        // 4 entries (one per border) for every (batch, channel, box) task.
        gradOutLength = batchSize * channels * boxSize * 4;
        argmaxIdxLength = batchSize * channels * boxSize * 4;
        boxesLength = batchSize * boxSize * 4;
        // gradInput is [batch][4 * channels][height][width], matching the scatter indexing below.
        gradInputLength = batchSize * 4 * channels * height * width;

        argmaxIdxGm.SetGlobalBuffer((__gm__ DTYPE_ARGMAXIDX *)argmaxIdx, argmaxIdxLength);
        gradOutGm.SetGlobalBuffer((__gm__ DTYPE_GRADOUT *)gradOut, gradOutLength);
        boxesGm.SetGlobalBuffer((__gm__ DTYPE_BOXES *)boxes, boxesLength);
        gradInputGm.SetGlobalBuffer((__gm__ DTYPE_GRADINPUT *)gradInput, gradInputLength);

        // InitBuffer sizes are in bytes.
        pipe.InitBuffer(inQueueGradOut, BUFFER_NUM, 4 * compNum * sizeof(float));
        pipe.InitBuffer(inQueueArgmaxIdx, BUFFER_NUM, 4 * compNum * sizeof(int32_t));
        // One 32-byte block holds the 4 coordinates of a box.
        pipe.InitBuffer(inQueueBoxes, BUFFER_NUM, BLOCK_ALIGN * SIZE_OF_FP32);
        // One 32-byte block per bilinear corner, so each staged scalar sits at an aligned UB address.
        pipe.InitBuffer(outQueueGradInput, BUFFER_NUM, 4 * BLOCK_ALIGN * SIZE_OF_FP32);
    }

    /**
     * Processes this core's share of tasks in chunks of at most compNum tasks.
     * Cores with index < taskLast take coreCompNum + 1 tasks, the others take coreCompNum.
     */
    __aicore__ inline void Process()
    {
        const int64_t coreIdx = GetBlockIdx();
        int64_t taskCount = coreCompNum;
        int64_t offset = coreCompNum * coreIdx + taskLast;
        if (coreIdx < taskLast) {
            taskCount = coreCompNum + 1;
            offset = taskCount * coreIdx;
        }
        const int64_t loopTimes = taskCount / compNum;
        const int64_t lastNum = taskCount % compNum;
        for (int64_t currentLoop = 0; currentLoop < loopTimes; currentLoop++) {
            ComputeAndCopyOut(currentLoop * compNum, offset, compNum);
        }
        if (lastNum != 0) {
            ComputeAndCopyOut(loopTimes * compNum, offset, lastNum);
        }
    }

    /**
     * Scatters the gradients of tasks [offset + index, offset + index + taskNum) into gradInput.
     * @param index   first task of the chunk, relative to this core's first task.
     * @param offset  global index of this core's first task.
     * @param taskNum number of tasks in the chunk (at most compNum).
     */
    __aicore__ inline void ComputeAndCopyOut(int64_t index, int64_t offset, int64_t taskNum)
    {
        LocalTensor<float> gradOutLocal = inQueueGradOut.AllocTensor<float>();
        LocalTensor<int32_t> argmaxIdxLocal = inQueueArgmaxIdx.AllocTensor<int32_t>();
        LocalTensor<float> boxesLocal = inQueueBoxes.AllocTensor<float>();
        LocalTensor<float> gradInputLocal = outQueueGradInput.AllocTensor<float>();

        DataCopyPadExtParams<float> floatPadParams {false, 0, 0, 0.0f};
        DataCopyPadExtParams<int32_t> intPadParams {false, 0, 0, 0};
        const int64_t taskStart = offset + index;
        // Copy exactly the 4 * taskNum entries of this chunk; a fixed 4 * compNum copy would read
        // past the end of gradOut / argmaxIdx on the tail chunk.
        DataCopyExtParams gradOutCopyParams {1, static_cast<uint32_t>(4 * taskNum * sizeof(float)), 0, 0, 0};
        DataCopyExtParams argmaxCopyParams {1, static_cast<uint32_t>(4 * taskNum * sizeof(int32_t)), 0, 0, 0};
        DataCopyPad(gradOutLocal, gradOutGm[taskStart * 4], gradOutCopyParams, floatPadParams);
        DataCopyPad(argmaxIdxLocal, argmaxIdxGm[taskStart * 4], argmaxCopyParams, intPadParams);
        // MTE2 writes must complete before the scalar GetValue reads below.
        PipeBarrier<PIPE_ALL>();

        DataCopyExtParams boxCopyParams {1, static_cast<uint32_t>(4 * sizeof(float)), 0, 0, 0};
        DataCopyExtParams outCopyParams {1, static_cast<uint32_t>(sizeof(float)), 0, 0, 0};
        const int64_t tasksPerBatch = static_cast<int64_t>(channels) * boxSize;
        const int64_t planeSize = static_cast<int64_t>(height) * width;

        for (int64_t currentTask = 0; currentTask < taskNum; currentTask++) {
            const int64_t taskIdx = taskStart + currentTask;
            const int64_t batchIdx = taskIdx / tasksPerBatch;
            const int64_t boxIdx = taskIdx % boxSize + batchIdx * boxSize;
            const int64_t channelsIdx = taskIdx / boxSize % channels;

            // Only the 4 coordinates of this box are needed (no over-read past the last box).
            DataCopyPad(boxesLocal, boxesGm[boxIdx * 4], boxCopyParams, floatPadParams);
            PipeBarrier<PIPE_ALL>();
            const float boxWidth = boxesLocal.GetValue(2) - boxesLocal.GetValue(0);
            const float boxHeight = boxesLocal.GetValue(3) - boxesLocal.GetValue(1);

            for (int32_t i = 0; i < 4; i++) {
                const float gradOutput = gradOutLocal.GetValue(4 * currentTask + i);
                const int32_t offsetArgmaxIdx = argmaxIdxLocal.GetValue(4 * currentTask + i);

                // Step along the border: top moves +x, left +y, bottom -x, right -y.
                float xStride = 0.0f;
                float yStride = 0.0f;
                switch (i) {
                    case borderMode::top:
                        xStride = boxWidth / poolSize;
                        break;
                    case borderMode::left:
                        yStride = boxHeight / poolSize;
                        break;
                    case borderMode::bottom:
                        xStride = -(boxWidth / poolSize);
                        break;
                    case borderMode::right:
                        yStride = -(boxHeight / poolSize);
                        break;
                    default:
                        break;
                }
                // Top/left start at (box[0], box[1]); bottom/right start at (box[2], box[3]).
                float x = boxesLocal.GetValue(i / 2 * 2);
                float y = boxesLocal.GetValue(i / 2 * 2 + 1);
                x += xStride * static_cast<float>(offsetArgmaxIdx);
                y += yStride * static_cast<float>(offsetArgmaxIdx);

                // Sample points outside [-1, height] x [-1, width] contribute no gradient.
                if (y < -1.0f || y > height || x < -1.0f || x > width) {
                    continue;
                }
                if (y <= 0.0f) {
                    y = 0.0f;
                }
                if (x <= 0.0f) {
                    x = 0.0f;
                }
                int32_t yLow = AscendC::ScalarCast<float, int32_t, AscendC::RoundMode::CAST_FLOOR>(y);
                int32_t xLow = AscendC::ScalarCast<float, int32_t, AscendC::RoundMode::CAST_FLOOR>(x);
                int32_t yHigh;
                int32_t xHigh;
                // Clamp to the last row/column so the four neighbours stay inside the feature map.
                if (yLow >= height - 1) {
                    yHigh = yLow = height - 1;
                    y = static_cast<float>(yLow);
                } else {
                    yHigh = yLow + 1;
                }
                if (xLow >= width - 1) {
                    xHigh = xLow = width - 1;
                    x = static_cast<float>(xLow);
                } else {
                    xHigh = xLow + 1;
                }

                // Bilinear weights of the (yLow,xLow), (yLow,xHigh), (yHigh,xLow), (yHigh,xHigh) neighbours.
                const float ly = y - yLow;
                const float lx = x - xLow;
                const float hy = 1.0f - ly;
                const float hx = 1.0f - lx;
                const float w1 = hy * hx;
                const float w2 = hy * lx;
                const float w3 = ly * hx;
                const float w4 = ly * lx;

                const int64_t planeBase = (batchIdx * channels * 4 + i * channels + channelsIdx) * planeSize;
                const int64_t rowLow = static_cast<int64_t>(yLow) * width;
                const int64_t rowHigh = static_cast<int64_t>(yHigh) * width;

                // Stage each value in its own 32-byte block so none is overwritten while MTE3 reads it.
                gradInputLocal.SetValue(0, w1 * gradOutput);
                gradInputLocal.SetValue(BLOCK_ALIGN, w2 * gradOutput);
                gradInputLocal.SetValue(2 * BLOCK_ALIGN, w3 * gradOutput);
                gradInputLocal.SetValue(3 * BLOCK_ALIGN, w4 * gradOutput);
                // Scalar UB writes must be visible before MTE3 reads them.
                PipeBarrier<PIPE_ALL>();
                AscendC::SetAtomicAdd<float>();
                DataCopyPad(gradInputGm[planeBase + rowLow + xLow], gradInputLocal, outCopyParams);
                DataCopyPad(gradInputGm[planeBase + rowLow + xHigh], gradInputLocal[BLOCK_ALIGN], outCopyParams);
                DataCopyPad(gradInputGm[planeBase + rowHigh + xLow], gradInputLocal[2 * BLOCK_ALIGN], outCopyParams);
                DataCopyPad(gradInputGm[planeBase + rowHigh + xHigh], gradInputLocal[3 * BLOCK_ALIGN], outCopyParams);
                AscendC::SetAtomicNone();
                // MTE3 must finish reading gradInputLocal before the next SetValue overwrites it.
                PipeBarrier<PIPE_ALL>();
            }
        }
        inQueueGradOut.FreeTensor(gradOutLocal);
        inQueueArgmaxIdx.FreeTensor(argmaxIdxLocal);
        inQueueBoxes.FreeTensor(boxesLocal);
        outQueueGradInput.FreeTensor(gradInputLocal);
    }

private:
    TPipe pipe;
    TQue<QuePosition::VECIN, BUFFER_NUM> inQueueGradOut;
    TQue<QuePosition::VECIN, BUFFER_NUM> inQueueArgmaxIdx;
    TQue<QuePosition::VECIN, BUFFER_NUM> inQueueBoxes;
    TQue<QuePosition::VECOUT, BUFFER_NUM> outQueueGradInput;
    GlobalTensor<int32_t> argmaxIdxGm;
    GlobalTensor<float> gradOutGm;
    GlobalTensor<float> boxesGm;
    GlobalTensor<float> gradInputGm;
    int64_t gradOutLength;
    int64_t argmaxIdxLength;
    int64_t boxesLength;
    int64_t gradInputLength;
    int32_t channels;
    int32_t boxSize;
    int32_t height;
    int32_t width;
    // Maximum number of tasks staged in UB per ComputeAndCopyOut call.
    int64_t compNum = 128;
    int32_t poolSize;
    int64_t batchSize;
    int64_t coreCompNum;
    int64_t taskLast;
};
/*
 * AI Core entry point for BorderAlign backward.
 * Unpacks the tiling data, then runs KernelBorderAlignGrad on this core's task range.
 * `workspace` is part of the operator signature but is not used by this kernel.
 */
extern "C" __global__ __aicore__ void border_align_grad(
    GM_ADDR gradOut, GM_ADDR boxes, GM_ADDR argmaxIdx, GM_ADDR gradInput, GM_ADDR workspace, GM_ADDR tiling)
{
    GET_TILING_DATA(tiling_data, tiling);
    KernelBorderAlignGrad kernel;
    kernel.Init(gradOut, boxes, argmaxIdx, gradInput, &tiling_data);
    kernel.Process();
}