/*
Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
*/
#include "kernel_operator.h"
using namespace AscendC;

namespace {
// Row counts used to lay out the per-chunk UB buffers. Each "row" holds one
// per-box quantity and is copyLen elements long.
constexpr int32_t ONE_DIM{1};    // quadratic coefficient a1 of the first radius case
constexpr int32_t FOUR_DIM{4};   // coefficient a2 / base of a3; also a 4-row copy length
constexpr int32_t PCR_DIM{2};    // two rows: x and y components
constexpr int32_t B_DIM{3};      // three radius cases (rows of b / c / sqrt buffers)

// Element counts passed to AlignUp when rounding per-chunk lengths.
constexpr int32_t FLOAT_EIGHT_NUM{8};
constexpr int32_t FLOAT_SIXTEEN_NUM{16};
constexpr int32_t FLOAT_THIRTY_TWO_NUM{32};
constexpr int32_t FLOAT_SIXTY_FOUR_NUM{64};
constexpr int32_t FLOAT_ONE_HUNDRED_TWENTY_EIGHT_NUM{128};

// Row indices of the angle-related rows in gt_boxes / ret_boxes.
constexpr int32_t COPYLEN_SIX{6};
constexpr int32_t COPYLEN_SEVEN{7};
} // namespace

/**
 * Per-core vector kernel that turns ground-truth boxes into heat-map targets.
 *
 * For every box it produces:
 *   - center_int : trunc(((x, y) - pcr) / voxelSize / featureMapStride), x row then y row;
 *   - radius     : max(trunc(min(r1, r2, r3)), minRadius), where r1..r3 are the three
 *                  quadratic-root cases built from dx, dy and minOverLap;
 *   - mask       : 1 if dx > 0, dy > 0, 0 <= cx < featureMapSizeX and 0 <= cy < featureMapSizeY;
 *   - ind        : cy * featureMapSizeX + cx, forced to 0 where mask is 0;
 *   - ret_boxes  : (dimSize + 1) rows, all forced to 0 where mask is 0.
 *
 * Boxes are read attribute-row by attribute-row: row i of gt_boxes starts at element
 * numObjs * i (presumably a [dimSize, numObjs] layout; TODO confirm with the host side).
 * The core's slice of boxes is processed in chunks of at most singleProcessTaskNum boxes.
 */
class KernelGaussian {
public:
    __aicore__ inline KernelGaussian() = delete;

    /**
     * Reads the tiling, binds the GM tensors and carves the UB work buffers.
     *
     * @param gt_boxes    GM input boxes (float).
     * @param center_int  GM output integer centres (int32, x row then y row).
     * @param radius      GM output radii (int32).
     * @param mask        GM output validity mask (uint8).
     * @param ind         GM output flattened heat-map indices (int32).
     * @param ret_boxes   GM output regression rows (float, dimSize + 1 rows).
     * @param tiling_data Host-computed tiling parameters.
     * @param pipe        Pipe that owns the UB allocations.
     */
    __aicore__ inline KernelGaussian(
        GM_ADDR gt_boxes,
        GM_ADDR center_int,
        GM_ADDR radius,
        GM_ADDR mask,
        GM_ADDR ind,
        GM_ADDR ret_boxes,
        const GaussianTilingData& tiling_data,
        TPipe* pipe)
        : pipe_(pipe)
    {
        InitTask(tiling_data);
        InitGM(gt_boxes, center_int, radius, mask, ind, ret_boxes);
        InitBuffer();
    }

    /** Runs every chunk of this core's slice (defined out of class). */
    __aicore__ inline void Process();

private:
    /**
     * Copies the tiling parameters and derives this core's workload.
     *
     * coreRepeatTimes is always initialised. It is 0 when the chunk size or the core's
     * slice is empty, so Process() never iterates on an indeterminate count.
     */
    __aicore__ inline void InitTask(const GaussianTilingData& tiling)
    {
        coreId = GetBlockIdx();
        usedCoreNum = tiling.usedCoreNum;
        numObjs = tiling.numObjs;
        totalCoreTaskNum = tiling.totalCoreTaskNum;
        coreProcessTaskNum = tiling.coreProcessTaskNum;
        lastCoreProcessTaskNum = tiling.lastCoreProcessTaskNum;
        singleProcessTaskNum = tiling.singleProcessTaskNum;
        featureMapSizeX = tiling.featureMapSizeX;
        featureMapSizeY = tiling.featureMapSizeY;
        voxelXSize = tiling.voxelXSize;
        voxelYSize = tiling.voxelYSize;
        prcX = tiling.prcX;
        prcY = tiling.prcY;
        featureMapStride = tiling.featureMapStride;
        numMaxObjs = tiling.numMaxObjs;
        minRadius = tiling.minRadius;
        minOverLap = tiling.minOverLap;
        dimSize = tiling.dimSize;
        normBbox = tiling.normBbox;
        flipAngle = tiling.flipAngle;

        // The last used core takes the remainder of the work.
        curCoreTaskNum = coreProcessTaskNum;
        if (unlikely(coreId == usedCoreNum - 1)) {
            curCoreTaskNum = lastCoreProcessTaskNum;
        }

        // Quadratic "a" coefficients of the three radius cases.
        a1 = ONE_DIM;
        a2 = FOUR_DIM;
        a3 = FOUR_DIM * minOverLap;

        // Default to no work; previously this stayed uninitialised on the early return.
        coreRepeatTimes = 0;
        if (singleProcessTaskNum <= 0 || curCoreTaskNum <= 0) {
            return;
        }
        coreRepeatTimes = (curCoreTaskNum - 1) / singleProcessTaskNum + 1;  // ceil-div
    }

    /** Binds the raw GM addresses to typed global tensors. */
    __aicore__ inline void InitGM(GM_ADDR gt_boxes,
                                  GM_ADDR center_int,
                                  GM_ADDR radius,
                                  GM_ADDR mask,
                                  GM_ADDR ind,
                                  GM_ADDR ret_boxes)
    {
        gtBoxesGm.SetGlobalBuffer((__gm__ float*)(gt_boxes));
        centerIntGm.SetGlobalBuffer((__gm__ int32_t*)(center_int));
        radiusGm.SetGlobalBuffer((__gm__ int32_t*)(radius));
        maskGm.SetGlobalBuffer((__gm__ uint8_t*)(mask));
        indGm.SetGlobalBuffer((__gm__ int32_t*)(ind));
        retBoxesGm.SetGlobalBuffer((__gm__ float*)(ret_boxes));
    }

    /**
     * Allocates one UB work buffer per intermediate, each sized for a full chunk of
     * singleProcessTaskNum boxes. Skipped entirely for a non-positive chunk size,
     * since no chunk will be processed and a zero/negative length is meaningless.
     */
    __aicore__ inline void InitBuffer()
    {
        if (singleProcessTaskNum <= 0) {
            return;
        }
        pipe_->InitBuffer(gtBoxesQue_, singleProcessTaskNum * dimSize * sizeof(float));
        pipe_->InitBuffer(pcrUB, singleProcessTaskNum * PCR_DIM * sizeof(float));
        pipe_->InitBuffer(voxelSizeUB, singleProcessTaskNum * PCR_DIM * sizeof(float));
        pipe_->InitBuffer(featureMapStrideUB, singleProcessTaskNum * sizeof(float));
        pipe_->InitBuffer(coordUB, singleProcessTaskNum * PCR_DIM * sizeof(float));
        pipe_->InitBuffer(centerIntUB, singleProcessTaskNum * PCR_DIM * sizeof(int32_t));
        pipe_->InitBuffer(centerFloatUB, singleProcessTaskNum * PCR_DIM * sizeof(float));
        pipe_->InitBuffer(dxUb, singleProcessTaskNum * sizeof(float));
        pipe_->InitBuffer(dyUb, singleProcessTaskNum * sizeof(float));
        pipe_->InitBuffer(sumDxDyUB, singleProcessTaskNum * sizeof(float));
        pipe_->InitBuffer(mulDxDyUB, singleProcessTaskNum * sizeof(float));
        pipe_->InitBuffer(bUB, singleProcessTaskNum * B_DIM * sizeof(float));
        pipe_->InitBuffer(cUB, singleProcessTaskNum * B_DIM * sizeof(float));
        pipe_->InitBuffer(sqrtUB, singleProcessTaskNum * B_DIM * sizeof(float));
        pipe_->InitBuffer(rUB, singleProcessTaskNum * sizeof(float));
        pipe_->InitBuffer(radiusUB, singleProcessTaskNum * sizeof(int32_t));
        // Bit mask produced by CompareScalar; over-sized for the packed bits.
        pipe_->InitBuffer(cmpUB, singleProcessTaskNum * sizeof(int32_t));
        // Holds half values; sized generously with sizeof(float).
        pipe_->InitBuffer(maskHalfUB, singleProcessTaskNum * sizeof(float));
        pipe_->InitBuffer(maskUB, singleProcessTaskNum * sizeof(int32_t));
        pipe_->InitBuffer(indUB, singleProcessTaskNum * sizeof(int32_t));
        pipe_->InitBuffer(indFloatUB, singleProcessTaskNum * sizeof(float));
        pipe_->InitBuffer(retBoxesUB, singleProcessTaskNum * (dimSize + 1) * sizeof(float));
    }

    /**
     * Processes one chunk of `actualTaskNum` boxes: copy in, compute, copy out.
     *
     * Each per-box quantity is stored as a row of copyLen elements, where copyLen is
     * actualTaskNum rounded up to a multiple of 8. Row r of a multi-row buffer
     * starts at element copyLen * r. Lanes past actualTaskNum are computed on but
     * never written back to GM.
     *
     * @param taskIdx        Index of this chunk within the core's slice.
     * @param actualTaskNum  Number of valid boxes in this chunk (expected > 0).
     */
    __aicore__ inline void ProcessSingle(uint64_t taskIdx, uint32_t actualTaskNum)
    {
        // Index of the first box of this chunk; computed in 64 bits.
        uint64_t singleBaseGmOffset =
            static_cast<uint64_t>(coreId) * coreProcessTaskNum + taskIdx * singleProcessTaskNum;
        uint32_t copyLen = AlignUp(actualTaskNum, FLOAT_EIGHT_NUM);
        uint32_t halfCopyLen = AlignUp(actualTaskNum, FLOAT_SIXTEEN_NUM);
        // CompareScalar element counts, rounded to 64 floats / 128 halves (256 bytes).
        uint32_t floatCmpLen = AlignUp(copyLen, FLOAT_SIXTY_FOUR_NUM);
        uint32_t halfCmpLen = AlignUp(copyLen, FLOAT_ONE_HUNDRED_TWENTY_EIGHT_NUM);

        LocalTensor<float> gtBoxes = gtBoxesQue_.Get<float>();
        LocalTensor<float> pcr = pcrUB.Get<float>();
        LocalTensor<float> voxelSize = voxelSizeUB.Get<float>();
        LocalTensor<float> featureMapStride_ = featureMapStrideUB.Get<float>();
        LocalTensor<float> coord = coordUB.Get<float>();
        LocalTensor<int32_t> centerInt = centerIntUB.Get<int32_t>();
        LocalTensor<float> centerFloat = centerFloatUB.Get<float>();
        LocalTensor<float> dx = dxUb.Get<float>();
        LocalTensor<float> dy = dyUb.Get<float>();
        LocalTensor<float> sumDxDy = sumDxDyUB.Get<float>();
        LocalTensor<float> mulDxDy = mulDxDyUB.Get<float>();
        LocalTensor<float> bLocal = bUB.Get<float>();
        LocalTensor<float> cLocal = cUB.Get<float>();
        LocalTensor<float> sqrLocal = sqrtUB.Get<float>();
        LocalTensor<float> rLocal = rUB.Get<float>();
        LocalTensor<int32_t> radiusLocal = radiusUB.Get<int32_t>();
        LocalTensor<uint8_t> cmpLocal = cmpUB.Get<uint8_t>();
        LocalTensor<half> maskHalf = maskHalfUB.Get<half>();
        LocalTensor<uint8_t> mask = maskUB.Get<uint8_t>();
        LocalTensor<int32_t> indLocal = indUB.Get<int32_t>();
        LocalTensor<float> indFloatLocal = indFloatUB.Get<float>();
        LocalTensor<float> retBoxes = retBoxesUB.Get<float>();

        // Broadcast the scalar parameters into full rows for element-wise ops.
        Duplicate(pcr, prcX, copyLen);
        Duplicate(pcr[copyLen], prcY, copyLen);
        Duplicate(voxelSize, voxelXSize, copyLen);
        Duplicate(voxelSize[copyLen], voxelYSize, copyLen);
        Duplicate(featureMapStride_, static_cast<float>(featureMapStride), copyLen);
        Duplicate(maskHalf, static_cast<half>(1.0), halfCopyLen);

        // Copy in exactly actualTaskNum boxes per attribute row. Copying copyLen
        // elements, as DataCopy did, could read up to 7 elements past the valid
        // range, including past the end of the last row.
        DataCopyExtParams gtCopyParams {1, static_cast<uint32_t>(actualTaskNum * sizeof(float)), 0, 0, 0};
        DataCopyPadExtParams<float> gtPadParams {false, 0, 0, 0.0f};
        for (int32_t i = 0; i < dimSize; i++) {
            DataCopyPad(gtBoxes[copyLen * i], gtBoxesGm[singleBaseGmOffset + static_cast<uint64_t>(numObjs) * i],
                        gtCopyParams, gtPadParams);
        }
        PipeBarrier<PIPE_ALL>();

        // coord = ((x, y) - pcr) / voxelSize / featureMapStride; centres are its truncation.
        Sub(coord, gtBoxes, pcr, copyLen);
        Sub(coord[copyLen], gtBoxes[copyLen], pcr[copyLen], copyLen);
        Div(coord, coord, voxelSize, copyLen);
        Div(coord[copyLen], coord[copyLen], voxelSize[copyLen], copyLen);
        Div(coord, coord, featureMapStride_, copyLen);
        Div(coord[copyLen], coord[copyLen], featureMapStride_, copyLen);
        Cast(centerInt, coord, RoundMode::CAST_TRUNC, copyLen * PCR_DIM);
        Cast(centerFloat, coord, RoundMode::CAST_TRUNC, copyLen * PCR_DIM);

        // Box extents (rows 3 and 4) in feature-map cells.
        Div(dx, gtBoxes[copyLen * B_DIM], voxelSize, copyLen);
        Div(dy, gtBoxes[copyLen * FOUR_DIM], voxelSize[copyLen], copyLen);
        Div(dx, dx, featureMapStride_, copyLen);
        Div(dy, dy, featureMapStride_, copyLen);
        Add(sumDxDy, dx, dy, copyLen);
        Mul(mulDxDy, dx, dy, copyLen);

        // Three quadratic cases, one per row:
        //   b = (s, 2s, -2*ov*s), c = (p*(1-ov)/(1+ov), p*(1-ov), p*(ov-1)), s = dx + dy, p = dx * dy
        //   r_k = (b_k + sqrt(b_k^2 - 4 * a_k * c_k)) / 2
        Muls(bLocal, sumDxDy, 1.0f, copyLen);
        Muls(bLocal[copyLen], sumDxDy, 2.0f, copyLen);
        Muls(bLocal[copyLen * PCR_DIM], sumDxDy, (-2.0f * minOverLap), copyLen);
        Muls(cLocal, mulDxDy, (1.0f - minOverLap) / (1.0f + minOverLap), copyLen);
        Muls(cLocal[copyLen], mulDxDy, (1.0f - minOverLap), copyLen);
        Muls(cLocal[copyLen * PCR_DIM], mulDxDy, (minOverLap - 1.0f), copyLen);
        Muls(cLocal, cLocal, 4.0f, copyLen * B_DIM);
        Muls(cLocal, cLocal, a1, copyLen);
        Muls(cLocal[copyLen], cLocal[copyLen], a2, copyLen);
        Muls(cLocal[copyLen * PCR_DIM], cLocal[copyLen * PCR_DIM], a3, copyLen);
        Mul(sqrLocal, bLocal, bLocal, copyLen * B_DIM);
        Sub(sqrLocal, sqrLocal, cLocal, copyLen * B_DIM);
        Sqrt(sqrLocal, sqrLocal, copyLen * B_DIM);
        Add(sqrLocal, sqrLocal, bLocal, copyLen * B_DIM);
        Muls(sqrLocal, sqrLocal, 0.5f, copyLen * B_DIM);
        Min(rLocal, sqrLocal, sqrLocal[copyLen], copyLen);
        Min(rLocal, rLocal, sqrLocal[copyLen * PCR_DIM], copyLen);
        Cast(radiusLocal, rLocal, RoundMode::CAST_TRUNC, copyLen);
        Maxs(radiusLocal, radiusLocal, minRadius, copyLen);

        // mask: starts at 1 and is zeroed by each failing condition.
        CompareScalar(cmpLocal, dx, 0.0f, CMPMODE::GT, floatCmpLen);
        Select(maskHalf, cmpLocal, maskHalf, (half)0.0, SELMODE::VSEL_TENSOR_SCALAR_MODE, copyLen);
        CompareScalar(cmpLocal, dy, 0.0f, CMPMODE::GT, floatCmpLen);
        Select(maskHalf, cmpLocal, maskHalf, (half)0.0, SELMODE::VSEL_TENSOR_SCALAR_MODE, copyLen);
        CompareScalar(cmpLocal, centerFloat, 0.0f, CMPMODE::GE, floatCmpLen);
        Select(maskHalf, cmpLocal, maskHalf, (half)0.0, SELMODE::VSEL_TENSOR_SCALAR_MODE, copyLen);
        CompareScalar(cmpLocal, centerFloat[copyLen], 0.0f, CMPMODE::GE, floatCmpLen);
        Select(maskHalf, cmpLocal, maskHalf, (half)0.0, SELMODE::VSEL_TENSOR_SCALAR_MODE, copyLen);
        CompareScalar(cmpLocal, centerFloat, static_cast<float>(featureMapSizeX), CMPMODE::LT, floatCmpLen);
        Select(maskHalf, cmpLocal, maskHalf, (half)0.0, SELMODE::VSEL_TENSOR_SCALAR_MODE, copyLen);
        CompareScalar(cmpLocal, centerFloat[copyLen], static_cast<float>(featureMapSizeY), CMPMODE::LT, floatCmpLen);
        Select(maskHalf, cmpLocal, maskHalf, (half)0.0, SELMODE::VSEL_TENSOR_SCALAR_MODE, copyLen);
        Cast(mask, maskHalf, RoundMode::CAST_NONE, copyLen);
        // Final validity bits, reused below to zero ind and ret_boxes for invalid boxes.
        CompareScalar(cmpLocal, maskHalf, (half)1.0, CMPMODE::EQ, halfCmpLen);

        // ind = cy * featureMapSizeX + cx, zeroed where mask is 0.
        Muls(indLocal, centerInt[copyLen], featureMapSizeX, copyLen);
        Add(indLocal, indLocal, centerInt, copyLen);
        Cast(indFloatLocal, indLocal, RoundMode::CAST_TRUNC, copyLen);
        Select(indFloatLocal, cmpLocal, indFloatLocal, 0.0f, SELMODE::VSEL_TENSOR_SCALAR_MODE, copyLen);
        Cast(indLocal, indFloatLocal, RoundMode::CAST_TRUNC, copyLen);

        // ret rows 0-1: sub-cell offset of the centre.
        for (int32_t i = 0; i < PCR_DIM; i++) {
            Sub(retBoxes[copyLen * i], coord[copyLen * i], centerFloat[copyLen * i], copyLen);
        }
        // ret rows 2-5: copied from gt rows 2-5; rows 3-5 optionally log-normalised.
        Muls(retBoxes[copyLen * PCR_DIM], gtBoxes[copyLen * PCR_DIM], 1.0f, copyLen * FOUR_DIM);
        if (normBbox) {
            Log(retBoxes[copyLen * B_DIM], retBoxes[copyLen * B_DIM], copyLen * B_DIM);
        }
        // ret rows 6-7: angle encoding from gt row 6; order is (sin, cos) or, when
        // flipAngle is set, (cos, sin). Computed once instead of being overwritten.
        if (flipAngle) {
            Cos(retBoxes[copyLen * COPYLEN_SIX], gtBoxes[copyLen * COPYLEN_SIX], copyLen);
            Sin(retBoxes[copyLen * COPYLEN_SEVEN], gtBoxes[copyLen * COPYLEN_SIX], copyLen);
        } else {
            Sin(retBoxes[copyLen * COPYLEN_SIX], gtBoxes[copyLen * COPYLEN_SIX], copyLen);
            Cos(retBoxes[copyLen * COPYLEN_SEVEN], gtBoxes[copyLen * COPYLEN_SIX], copyLen);
        }
        // Remaining gt rows 7.. shift down by one because the angle took two rows.
        for (int32_t i = COPYLEN_SEVEN; i < dimSize; i++) {
            Muls(retBoxes[copyLen * (i + 1)], gtBoxes[copyLen * i], 1.0f, copyLen);
        }
        for (int32_t i = 0; i <= dimSize; i++) {
            Select(retBoxes[copyLen * i], cmpLocal, retBoxes[copyLen * i], 0.0f,
                   SELMODE::VSEL_TENSOR_SCALAR_MODE, copyLen);
        }
        PipeBarrier<PIPE_ALL>();

        // Copy out only the valid lanes. DataCopyExtParams::blockLen is 32-bit, so
        // the byte counts are no longer narrowed through uint16_t.
        DataCopyExtParams int32CopyParams {1, static_cast<uint32_t>(actualTaskNum * sizeof(int32_t)), 0, 0, 0};
        DataCopyExtParams maskCopyParams {1, static_cast<uint32_t>(actualTaskNum * sizeof(uint8_t)), 0, 0, 0};
        DataCopyExtParams retCopyParams {1, static_cast<uint32_t>(actualTaskNum * sizeof(float)), 0, 0, 0};

        DataCopyPad(centerIntGm[singleBaseGmOffset], centerInt, int32CopyParams);
        DataCopyPad(centerIntGm[singleBaseGmOffset + totalCoreTaskNum], centerInt[copyLen], int32CopyParams);
        DataCopyPad(radiusGm[singleBaseGmOffset], radiusLocal, int32CopyParams);
        DataCopyPad(maskGm[singleBaseGmOffset], mask, maskCopyParams);
        DataCopyPad(indGm[singleBaseGmOffset], indLocal, int32CopyParams);
        for (int32_t i = 0; i <= dimSize; i++) {
            DataCopyPad(retBoxesGm[singleBaseGmOffset + static_cast<uint64_t>(numMaxObjs) * i],
                        retBoxes[copyLen * i], retCopyParams);
        }
        PipeBarrier<PIPE_ALL>();
    }

private:
    TPipe* pipe_;
    // UB work buffers (one per intermediate; see InitBuffer for sizes).
    TBuf<TPosition::VECCALC> gtBoxesQue_, pcrUB, voxelSizeUB, featureMapStrideUB, coordUB;
    TBuf<TPosition::VECCALC> centerIntUB, centerFloatUB, dxUb, dyUb, sumDxDyUB, mulDxDyUB;
    TBuf<TPosition::VECCALC> bUB, cUB, sqrtUB, rUB, radiusUB, cmpUB;
    TBuf<TPosition::VECCALC> maskHalfUB, maskUB, indUB, indFloatUB, retBoxesUB;
    // GM views of the kernel inputs/outputs.
    GlobalTensor<float> gtBoxesGm, retBoxesGm;
    GlobalTensor<int32_t> centerIntGm, radiusGm, indGm;
    GlobalTensor<uint8_t> maskGm;
    // Quadratic "a" coefficients of the three radius cases.
    float a1, a2, a3;
    // Tiling parameters.
    float prcX, prcY, voxelXSize, voxelYSize, minOverLap;
    int32_t numMaxObjs, numObjs, featureMapStride, minRadius, featureMapSizeX, featureMapSizeY, dimSize;
    bool normBbox, flipAngle;
    // Work partitioning for this core.
    int32_t coreId, usedCoreNum, totalCoreTaskNum, coreProcessTaskNum;
    int32_t lastCoreProcessTaskNum, singleProcessTaskNum, curCoreTaskNum, coreRepeatTimes;
};

/**
 * Runs every chunk of this core's slice. There are coreRepeatTimes chunks of
 * singleProcessTaskNum boxes; the last one holds the remainder
 * ((curCoreTaskNum - 1) % singleProcessTaskNum + 1 boxes).
 *
 * Returns immediately for a non-positive chunk size or an empty slice. In those
 * cases the trip count and the remainder below would be derived from invalid values
 * (division by zero / a zero-length chunk).
 */
__aicore__ inline void KernelGaussian::Process()
{
    if (singleProcessTaskNum <= 0 || curCoreTaskNum <= 0) {
        return;
    }
    for (int32_t i = 0; i < coreRepeatTimes; ++i) {
        uint32_t actualTaskNum = static_cast<uint32_t>(singleProcessTaskNum);
        if (unlikely(i == coreRepeatTimes - 1)) {
            actualTaskNum = static_cast<uint32_t>((curCoreTaskNum - 1) % singleProcessTaskNum + 1);
        }
        ProcessSingle(static_cast<uint64_t>(i), actualTaskNum);
        PipeBarrier<PIPE_ALL>();
    }
}

/**
 * Kernel entry point. Reads the tiling data, then builds and runs a KernelGaussian
 * for this core. Nothing is computed when GetSysWorkSpacePtr() returns nullptr.
 * The `workspace` argument is not read directly here.
 */
extern "C" __global__ __aicore__ void gaussian(GM_ADDR gt_boxes, GM_ADDR center_int,
                                               GM_ADDR radius, GM_ADDR mask,
                                               GM_ADDR ind, GM_ADDR ret_boxes,
                                               GM_ADDR workspace, GM_ADDR tiling)
{
    GET_TILING_DATA(tiling_data, tiling);
    TPipe pipe;
    if (GetSysWorkSpacePtr() != nullptr) {
        KernelGaussian kernel(gt_boxes, center_int, radius, mask, ind, ret_boxes, tiling_data, &pipe);
        kernel.Process();
    }
}