/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
 *
 */
#include "kernel_operator.h"
#include "kernel_tiling/kernel_tiling.h"
using namespace AscendC;

/**
 * AI-core kernel that aggregates multi-camera / multi-scale features per anchor.
 *
 * One "task" is one (batch, anchor) pair. For every (point, camera) sample of the task whose
 * normalized location lies strictly inside (0, 1) in both axes, and for every scale, the kernel:
 *   1. converts the normalized location to pixel coordinates of that scale's feature map,
 *   2. loads the feature vectors of the (up to) four neighbouring pixels,
 *   3. scales them by the bilinear corner weights and by the per-(scale, group) attention weight,
 *   4. accumulates everything into a single result vector of numEmbeds elements.
 * The result vector is written to `out` at element offset (batch * numAnchors + anchor) * numEmbeds.
 *
 * All global-memory offsets are computed in 64-bit arithmetic so that large
 * batch / feature-map products do not wrap around 2^32.
 *
 * @tparam DTYPE_F element type of features, sampling locations, weights and output.
 * @tparam DTYPE_I element type of the spatial-shape and scale-start-index tensors.
 */
template<typename DTYPE_F, typename DTYPE_I>
class KernelDeformableAggregation {
public:
    __aicore__ inline KernelDeformableAggregation() {}

    /**
     * Reads the tiling parameters, derives local buffer sizes and binds the global tensors.
     *
     * @param mc_ms_feat         feature tensor; bound with bs * numFeats * numEmbeds elements.
     * @param spatial_shape      (h, w) pairs; bound with numCams * numScales * 2 elements.
     * @param scale_start_index  per (camera, scale) start row inside the feature tensor.
     * @param sampling_location  (w, h) location pairs; bs * numAnchors * numPoints * numCams * 2 elements.
     * @param weights            bs * numAnchors * numPoints * numCams * numScales * numGroups elements.
     * @param out                output tensor; bs * numAnchors * numEmbeds elements.
     * @param tiling_data        host-computed tiling parameters.
     * @param tmpPipe            pipe used later by GetLocalTensor() to allocate local buffers.
     */
    __aicore__ inline void Init(GM_ADDR mc_ms_feat, GM_ADDR spatial_shape, GM_ADDR scale_start_index,
        GM_ADDR sampling_location, GM_ADDR weights, GM_ADDR out, const DeformableAggregationTilingData* tiling_data, TPipe *tmpPipe)
    {
        pipe_ = tmpPipe;
        bs_ = tiling_data->bs;
        numFeats_ = tiling_data->numFeats;
        numEmbeds_ = tiling_data->numEmbeds;
        numAnchors_ = tiling_data->numAnchors;
        numPoints_ = tiling_data->numPoints;
        numCams_ = tiling_data->numCams;
        numScales_ = tiling_data->numScales;
        numGroups_ = tiling_data->numGroups;
        cAligned_ = tiling_data->cAligned;
        coreNum_ = tiling_data->coreNum;
        // NOTE(review): numGroups_ == 0 would divide by zero here; presumably rejected by the host tiling.
        numChannels_ = numEmbeds_ / numGroups_;

        // Element counts are rounded up to whole 32-byte blocks.
        const uint32_t blockBytes = 32;
        blockAlignFloat_ = blockBytes / sizeof(DTYPE_F);
        blockAlignInt_ = blockBytes / sizeof(DTYPE_I);

        weightBufSize_ = AlignUp(numScales_ * numGroups_, blockAlignFloat_);
        locBufSize_ = AlignUp(numPoints_ * numCams_ * 2, blockAlignFloat_);
        scaleStartBufSize_ = AlignUp(numCams_ * numScales_, blockAlignInt_);
        spatialShapeBufSize_ = AlignUp(numCams_ * numScales_ * 2, blockAlignInt_);
        // One cAligned_-wide slot per scale; vBuf_/weightMulBuf_ hold four such regions (one per corner).
        weightMulBufSize_ = numScales_ * cAligned_;

        // Start of each bilinear-corner region inside vLocal_ / weightMulLocal_.
        v1Offset_ = 0 * weightMulBufSize_;
        v2Offset_ = 1 * weightMulBufSize_;
        v3Offset_ = 2 * weightMulBufSize_;
        v4Offset_ = 3 * weightMulBufSize_;

        // Single block of numEmbeds elements (length given in bytes) for the padded copy-out.
        copyOutParams_ = {1, static_cast<uint32_t>(numEmbeds_ * sizeof(DTYPE_F)), 0, 0, 0};
        // Broadcast shapes: each of the numScales * numGroups weights is repeated numChannels times.
        srcShape_[0] = numScales_ * numGroups_;
        srcShape_[1] = 1;
        dstShape_[0] = numScales_ * numGroups_;
        dstShape_[1] = numChannels_;

        ASSERT(GetBlockNum() != 0 && "block dim can not be zero!");

        // Promote to 64 bit BEFORE multiplying; the operands are uint32_t and would otherwise wrap.
        const uint64_t anchorCount = static_cast<uint64_t>(bs_) * numAnchors_;
        uint64_t mcMsFeatGmLength = static_cast<uint64_t>(bs_) * numFeats_ * numEmbeds_;
        uint64_t scaleStartIndexLength = static_cast<uint64_t>(numCams_) * numScales_;
        uint64_t spatialShapeGmLength = scaleStartIndexLength * 2;
        uint64_t samplingLocationGmLength = anchorCount * numPoints_ * numCams_ * 2;
        uint64_t weightsGmLength = anchorCount * numPoints_ * numCams_ * numScales_ * numGroups_;
        uint64_t outGmLength = anchorCount * numEmbeds_;

        mcMsFeatGm_.SetGlobalBuffer((__gm__ DTYPE_F*)mc_ms_feat, mcMsFeatGmLength);
        samplingLocationGm_.SetGlobalBuffer((__gm__ DTYPE_F*)sampling_location, samplingLocationGmLength);
        weightsGm_.SetGlobalBuffer((__gm__ DTYPE_F*)weights, weightsGmLength);
        outGm_.SetGlobalBuffer((__gm__ DTYPE_F*)out, outGmLength);
        spatialShapesGm_.SetGlobalBuffer((__gm__ DTYPE_I*)spatial_shape, spatialShapeGmLength);
        scaleStartIndexGm_.SetGlobalBuffer((__gm__ DTYPE_I*)scale_start_index, scaleStartIndexLength);
    }

    /**
     * Allocates the local buffers from the pipe and caches their tensor views.
     * Must be called after Init(), which computes the buffer sizes used here.
     */
    __aicore__ inline void GetLocalTensor()
    {
        pipe_->InitBuffer(weightBuf_, weightBufSize_ * sizeof(DTYPE_F));
        pipe_->InitBuffer(locationBuf_, locBufSize_ * sizeof(DTYPE_F));
        pipe_->InitBuffer(scaleStartBuf_, scaleStartBufSize_ * sizeof(DTYPE_I));
        pipe_->InitBuffer(spatialShapeBuf_, spatialShapeBufSize_ * sizeof(DTYPE_I));
        // Four corner regions each for the sampled values and for the broadcast weights.
        pipe_->InitBuffer(vBuf_, 4 * weightMulBufSize_ * sizeof(DTYPE_F));
        pipe_->InitBuffer(weightMulBuf_, 4 * weightMulBufSize_ * sizeof(DTYPE_F));
        pipe_->InitBuffer(resBuf_, cAligned_ * sizeof(DTYPE_F));

        weightLocal_ = weightBuf_.Get<DTYPE_F>();
        locationLocal_ = locationBuf_.Get<DTYPE_F>();
        scaleStartLocal_ = scaleStartBuf_.Get<DTYPE_I>();
        spatialShapeLocal_ = spatialShapeBuf_.Get<DTYPE_I>();
        vLocal_ = vBuf_.Get<DTYPE_F>();
        weightMulLocal_ = weightMulBuf_.Get<DTYPE_F>();
        resLocal_ = resBuf_.Get<DTYPE_F>();
    }

    /**
     * Splits the bs * numAnchors tasks evenly (ceil) across coreNum cores and processes
     * the contiguous task range that belongs to the current block index.
     */
    __aicore__ inline void Process()
    {
        taskNum_ = bs_ * numAnchors_;
        taskNumPerCore_ = DivCeil(taskNum_, coreNum_);
        curBlockIdx_ = GetBlockIdx();
        startOffset_ = curBlockIdx_ * taskNumPerCore_;
        endOffset_ = (curBlockIdx_ + 1) * taskNumPerCore_;
        // The last core(s) may get a short or empty range.
        if (endOffset_ > taskNum_) {
            endOffset_ = taskNum_;
        }

        // Corner slots that are never loaded (out-of-image neighbours) must read as zero.
        Duplicate(vLocal_, static_cast<DTYPE_F>(0.0f), 4 * weightMulBufSize_);
        // The shape / start-index tables are task independent, so they are loaded once.
        // NOTE(review): the copy lengths are rounded up to a 32-byte block and may read a few
        // elements past the logical end of these tensors - confirm the GM allocations allow it.
        DataCopy(scaleStartLocal_, scaleStartIndexGm_, scaleStartBufSize_);
        DataCopy(spatialShapeLocal_, spatialShapesGm_, spatialShapeBufSize_);

        for (uint32_t taskIdx = startOffset_; taskIdx < endOffset_; ++taskIdx) {
            ComputeAndCopyOut(taskIdx);
        }
    }

    /**
     * Computes the aggregated feature vector of one task and writes it to global memory.
     *
     * @param taskIdx flat task index, equal to batchIdx * numAnchors + anchorIdx.
     */
    __aicore__ inline void ComputeAndCopyOut(int32_t taskIdx)
    {
        uint32_t batchIdx = taskIdx / numAnchors_;
        uint32_t anchorIdx = taskIdx % numAnchors_;
        // 64-bit flat anchor index: every GM offset below is derived from it without 32-bit wrap.
        const uint64_t anchorFlatIdx = static_cast<uint64_t>(batchIdx) * numAnchors_ + anchorIdx;
        uint64_t refOffsetGm = anchorFlatIdx * numEmbeds_;
        uint64_t locationOffsetGm = anchorFlatIdx * numPoints_ * numCams_ * 2;
        // NOTE(review): no explicit MTE2->S synchronization is issued between this copy (or the
        // table copies in Process()) and the scalar GetValue() reads below - confirm whether
        // one is required on the target platform.
        DataCopy(locationLocal_, samplingLocationGm_[locationOffsetGm], locBufSize_);
        Duplicate(resLocal_, static_cast<DTYPE_F>(0.0f), cAligned_);
        for (uint32_t pointIdx = 0; pointIdx < numPoints_; ++pointIdx) {
            for (uint32_t camIdx = 0; camIdx < numCams_; ++camIdx) {
                // Locations are stored as (w, h) pairs; samples on or outside the [0, 1] border are skipped.
                uint32_t locationOffsetLocal = (pointIdx * numCams_ + camIdx) * 2;
                float locW = locationLocal_.GetValue(locationOffsetLocal);
                if (locW <= 0 || locW >= 1) {
                    continue;
                }
                float locH = locationLocal_.GetValue(locationOffsetLocal + 1);
                if (locH <= 0 || locH >= 1) {
                    continue;
                }
                uint64_t weightOffsetGm = ((anchorFlatIdx * numPoints_ + pointIdx)
                                          * numCams_ + camIdx) * numScales_ * numGroups_;
                DataCopy(weightLocal_, weightsGm_[weightOffsetGm], weightBufSize_);
                // Vector unit must not read weightLocal_ before the MTE2 copy has landed.
                SetFlag<HardEvent::MTE2_V>(0);
                WaitFlag<HardEvent::MTE2_V>(0);
                // Expand [numScales * numGroups, 1] -> [numScales * numGroups, numChannels].
                // NOTE(review): the per-scale indexing below uses a stride of cAligned_, which matches
                // this layout only when cAligned_ == numGroups_ * numChannels_ - confirm with the tiling.
                BroadCast<DTYPE_F, 2, 1>(weightMulLocal_, weightLocal_, dstShape_, srcShape_);
                // Later MTE2 copies (feature loads below, next weight load) wait for the vector work issued so far.
                SetFlag<HardEvent::V_MTE2>(0);
                WaitFlag<HardEvent::V_MTE2>(0);
                // Replicate the broadcast weights into the other three corner regions (x + 0 == copy).
                for (uint32_t i = 1; i < 4; ++i) {
                    Adds(weightMulLocal_[i * weightMulBufSize_], weightMulLocal_, static_cast<DTYPE_F>(0.0f), weightMulBufSize_);
                }
                for (uint32_t scaleIdx = 0; scaleIdx < numScales_; ++scaleIdx) {
                    uint32_t localOffset = scaleIdx * cAligned_;
                    uint32_t scaleStartOffset = camIdx * numScales_ + scaleIdx;
                    uint32_t spatialShapeOffset = scaleStartOffset * 2;
                    uint32_t scaleStartIdx = scaleStartLocal_.GetValue(scaleStartOffset);
                    // Element offset of this (batch, camera, scale) feature map; 64 bit to avoid wrap.
                    uint64_t valueOffset = (static_cast<uint64_t>(batchIdx) * numFeats_ + scaleStartIdx) * numEmbeds_;

                    // Slot of this scale inside each of the four corner regions.
                    uint32_t localPtr1_ = v1Offset_ + localOffset;
                    uint32_t localPtr2_ = v2Offset_ + localOffset;
                    uint32_t localPtr3_ = v3Offset_ + localOffset;
                    uint32_t localPtr4_ = v4Offset_ + localOffset;

                    DTYPE_I h = spatialShapeLocal_.GetValue(spatialShapeOffset);
                    DTYPE_I w = spatialShapeLocal_.GetValue(spatialShapeOffset + 1);

                    // Normalized location -> pixel coordinates (pixel centres at +0.5).
                    float hIm = locH * h - 0.5f;
                    float wIm = locW * w - 0.5f;

                    DTYPE_I hLow = ScalarCast<float, DTYPE_I, RoundMode::CAST_FLOOR>(hIm);
                    DTYPE_I wLow = ScalarCast<float, DTYPE_I, RoundMode::CAST_FLOOR>(wIm);
                    DTYPE_I hHigh = hLow + 1;
                    DTYPE_I wHigh = wLow + 1;

                    // Fractional distances to the low / high neighbours.
                    float lh = hIm - hLow;
                    float lw = wIm - wLow;
                    float hh = 1 - lh;
                    float hw = 1 - lw;

                    // Feature map is indexed as [row][col][numEmbeds].
                    DTYPE_I wStride = numEmbeds_;
                    DTYPE_I hStride = w * wStride;
                    DTYPE_I hLowPtrOffset = hLow * hStride;
                    DTYPE_I hHighPtrOffset = hLowPtrOffset + hStride;
                    DTYPE_I wLowPtrOffset = wLow * wStride;
                    DTYPE_I wHighPtrOffset = wLowPtrOffset + wStride;

                    // Bilinear weights of the four corners.
                    float w1 = hh * hw;
                    float w2 = hh * lw;
                    float w3 = lh * hw;
                    float w4 = lh * lw;

                    // Load only the neighbours that lie inside the image; the others stay zero in vLocal_.
                    // The signed row/column offsets are added only on branches where they are non-negative.
                    if (hLow >= 0) {
                        basePtr_ = valueOffset + hLowPtrOffset;
                        if (wLow >= 0) {
                            realPtr_ = basePtr_ + wLowPtrOffset;
                            DataCopy(vLocal_[localPtr1_], mcMsFeatGm_[realPtr_], cAligned_);
                        }
                        if (wHigh <= w - 1) {
                            realPtr_ = basePtr_ + wHighPtrOffset;
                            DataCopy(vLocal_[localPtr2_], mcMsFeatGm_[realPtr_], cAligned_);
                        }
                    }
                    if (hHigh <= h - 1) {
                        basePtr_ = valueOffset + hHighPtrOffset;
                        if (wLow >= 0) {
                            realPtr_ = basePtr_ + wLowPtrOffset;
                            DataCopy(vLocal_[localPtr3_], mcMsFeatGm_[realPtr_], cAligned_);
                        }
                        if (wHigh <= w - 1) {
                            realPtr_ = basePtr_ + wHighPtrOffset;
                            DataCopy(vLocal_[localPtr4_], mcMsFeatGm_[realPtr_], cAligned_);
                        }
                    }
                    // Fold the bilinear corner weight into the attention weight of each corner slot.
                    Muls(weightMulLocal_[localPtr1_], weightMulLocal_[localPtr1_], static_cast<DTYPE_F>(w1), cAligned_);
                    Muls(weightMulLocal_[localPtr2_], weightMulLocal_[localPtr2_], static_cast<DTYPE_F>(w2), cAligned_);
                    Muls(weightMulLocal_[localPtr3_], weightMulLocal_[localPtr3_], static_cast<DTYPE_F>(w3), cAligned_);
                    Muls(weightMulLocal_[localPtr4_], weightMulLocal_[localPtr4_], static_cast<DTYPE_F>(w4), cAligned_);
                }
                // All feature loads must have landed before the vector unit consumes vLocal_.
                SetFlag<HardEvent::MTE2_V>(0);
                WaitFlag<HardEvent::MTE2_V>(0);
                Mul(weightMulLocal_, weightMulLocal_, vLocal_, 4 * weightMulBufSize_);
                // Reset the value buffer so unloaded corners of the next sample read as zero again.
                Duplicate(vLocal_, static_cast<DTYPE_F>(0.0f), 4 * weightMulBufSize_);
                SetFlag<HardEvent::V_MTE2>(0);
                WaitFlag<HardEvent::V_MTE2>(0);
                // Reduce the 4 * numScales weighted slots into the task's result vector.
                for (uint32_t i = 0; i < 4 * numScales_; ++i) {
                    Add(resLocal_, resLocal_, weightMulLocal_[i * cAligned_], cAligned_);
                }
            }
        }
        // Result must be fully computed before MTE3 writes it out ...
        SetFlag<HardEvent::V_MTE3>(0);
        WaitFlag<HardEvent::V_MTE3>(0);
        DataCopyPad(outGm_[refOffsetGm], resLocal_, copyOutParams_);
        // ... and fully written before the next task clears resLocal_.
        SetFlag<HardEvent::MTE3_V>(0);
        WaitFlag<HardEvent::MTE3_V>(0);
    }

private:
    TPipe *pipe_;

    TBuf<TPosition::VECCALC> weightBuf_, locationBuf_, scaleStartBuf_, spatialShapeBuf_;
    TBuf<TPosition::VECCALC> vBuf_, weightMulBuf_, resBuf_;

    GlobalTensor<DTYPE_F> mcMsFeatGm_, samplingLocationGm_, weightsGm_, outGm_;
    GlobalTensor<DTYPE_I> spatialShapesGm_, scaleStartIndexGm_;

    LocalTensor<DTYPE_F> locationLocal_, weightLocal_;
    LocalTensor<DTYPE_I> spatialShapeLocal_, scaleStartLocal_;
    LocalTensor<DTYPE_F> vLocal_, weightMulLocal_, resLocal_;

    // Element offsets into mcMsFeatGm_; 64 bit so large feature tensors do not wrap.
    uint64_t basePtr_, realPtr_;
    uint32_t coreNum_, curBlockIdx_;
    uint32_t taskNum_, taskNumPerCore_, startOffset_, endOffset_;
    uint32_t weightBufSize_, locBufSize_, scaleStartBufSize_, spatialShapeBufSize_, weightMulBufSize_;
    uint32_t bs_, numFeats_, numEmbeds_, numAnchors_, numPoints_, numCams_, numScales_, numGroups_, numChannels_, cAligned_;
    uint32_t blockAlignFloat_, blockAlignInt_;
    // Placeholder initializers; Init() overwrites these with multiples of weightMulBufSize_.
    uint32_t v1Offset_ = 0, v2Offset_ = 1, v3Offset_ = 2, v4Offset_ = 3;
    
    uint32_t srcShape_[2], dstShape_[2];
    DataCopyExtParams copyOutParams_;
};

/**
 * Device entry point of the deformable aggregation operator.
 *
 * Unpacks the tiling data, builds the kernel object with int32_t index tensors,
 * then runs its three phases in order: Init, GetLocalTensor, Process.
 * The `workspace` argument is not used by this function.
 */
extern "C" __global__ __aicore__ void deformable_aggregation(GM_ADDR mc_ms_feat, GM_ADDR spatial_shape,
    GM_ADDR scale_start_index, GM_ADDR sampling_location, GM_ADDR weights, GM_ADDR out, GM_ADDR workspace,
    GM_ADDR tiling)
{
    TPipe tpipe;
    GET_TILING_DATA(tilingData, tiling);

    KernelDeformableAggregation<DTYPE_MC_MS_FEAT, int32_t> kernel;
    // Bind global tensors and derive buffer sizes from the tiling parameters.
    kernel.Init(mc_ms_feat, spatial_shape, scale_start_index, sampling_location, weights, out,
        &tilingData, &tpipe);
    // Local buffers are sized by Init(), so they are allocated only afterwards.
    kernel.GetLocalTensor();
    kernel.Process();
}