/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

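/*!
 * \file max_pool2d.cpp
 * \brief Ascend C kernel for 2-D max pooling with a fixed 3x3 window, stride 2 and padding 1,
 *        operating on data indexed as [batch][row][column][channel].
 */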

#include "kernel_operator.h"
using namespace AscendC;

class KernelMaxPool2d {
public:
    __aicore__ inline KernelMaxPool2d() {}
    __aicore__ inline void Init(
        GM_ADDR x_trans, GM_ADDR y_trans, const MaxPool2dTilingData* tiling_data, TPipe* tmpPipe)
    {
        pipe = tmpPipe;
        ASSERT(GetBlockNum() != 0 && "block dim can not be zero!");
        dataAlign = blockNum / sizeof(DTYPE_X_TRANS);
        batchSize = tiling_data->batchSize;
        channel = tiling_data->channel;
        inHeight = tiling_data->inHeight;
        inWidth = tiling_data->inWidth;
        outHeight = tiling_data->outHeight;
        outWidth = tiling_data->outWidth;
        coreNum = tiling_data->coreNum;

        batchNum = channel * kernelSize;

        taskNum = batchSize * outHeight;
        taskNumPerCore = DivCeil(taskNum, coreNum);

        curBlockIdx = GetBlockIdx();
        startOffset = curBlockIdx * taskNumPerCore;
        endOffset = (curBlockIdx + 1) * taskNumPerCore;
        if (endOffset > taskNum) {
            endOffset = taskNum;
        }

        wBatch = (numAlign / channel - 1) / stride;
        validW = outWidth - 1;
        if (inWidth % 2 == 1) {
            validW = outWidth - 2;
        }
        wRounds = validW / wBatch;
        wTail = validW % wBatch;

        eventIdVToMte3 = static_cast<event_t>(pipe->AllocEventID<HardEvent::V_MTE3>());

        copyParams = {(uint16_t)kernelSize, uint32_t(batchNum * sizeof(DTYPE_X_TRANS)),
            uint32_t((inWidth - kernelSize) * channel * sizeof(DTYPE_X_TRANS)), 0, 0};

        xTransGm.SetGlobalBuffer(
            reinterpret_cast<__gm__ DTYPE_X_TRANS*>(x_trans), batchSize * inHeight * inWidth * channel);
        yTransGm.SetGlobalBuffer(
            reinterpret_cast<__gm__ DTYPE_X_TRANS*>(y_trans), batchSize * outHeight * outWidth * channel);

        pipe->InitBuffer(xPart1Ub, batchNum * kernelSize * sizeof(DTYPE_X_TRANS));
        pipe->InitBuffer(xPart2Ub, batchNum * sizeof(DTYPE_X_TRANS));
        pipe->InitBuffer(xPart3Ub, batchNum * sizeof(DTYPE_X_TRANS));

        pipe->InitBuffer(xBatchUb1, numAlign * sizeof(DTYPE_X_TRANS));
        pipe->InitBuffer(xBatchUb2, numAlign * sizeof(DTYPE_X_TRANS));
        pipe->InitBuffer(xBatchUb3, numAlign * sizeof(DTYPE_X_TRANS));
        pipe->InitBuffer(xBatchUb4, numAlign * sizeof(DTYPE_X_TRANS));

        pipe->InitBuffer(maxPart1Ub, batchNum * sizeof(DTYPE_X_TRANS));
        pipe->InitBuffer(maxPart2Ub, batchNum * sizeof(DTYPE_X_TRANS));

        pipe->InitBuffer(resUb, channel * sizeof(DTYPE_X_TRANS));
    }

    __aicore__ inline void Process()
    {
        ComputeNH();
    }

private:
    __aicore__ inline void MovePart()
    {
        DataCopy(xPart2Local, xTransGm[baseOffset * channel], batchNum - channel);
        DataCopy(xPart3Local, xTransGm[(baseOffset + inWidth) * channel], batchNum - channel);
        PipeBarrier<PIPE_ALL>();

        Max(maxPart2Local, xPart2Local, xPart3Local, batchNum - channel);
        Max(resLocal, maxPart2Local, maxPart2Local[channel], channel);

        SetFlag<HardEvent::V_MTE3>(eventIdVToMte3);
        WaitFlag<HardEvent::V_MTE3>(eventIdVToMte3);

        DataCopy(yTransGm[outOffset], resLocal, channel);

        for (uint32_t idx1 = 0; idx1 < wRounds; idx1++) {
            inOffset = baseOffset + (idx1 * wBatch + 1) * stride - padding;
            DataCopy(xBatchLocal1, xTransGm[inOffset * channel], (wBatch * stride + 1) * channel);
            DataCopy(xBatchLocal2, xTransGm[(inOffset + inWidth) * channel], (wBatch * stride + 1) * channel);
            PipeBarrier<PIPE_ALL>();
            Max(xBatchLocal3, xBatchLocal1, xBatchLocal2, (wBatch * stride + 1) * channel);

            for (uint32_t idx2 = 0; idx2 < wBatch; idx2++) {
                Max(xBatchLocal4[idx2 * channel], xBatchLocal3[idx2 * stride * channel],
                    xBatchLocal3[idx2 * stride * channel + channel], channel);
                Max(xBatchLocal4[idx2 * channel], xBatchLocal4[idx2 * channel],
                    xBatchLocal3[idx2 * stride * channel + 2 * channel], channel);
            }
            PipeBarrier<PIPE_ALL>();

            DataCopy(yTransGm[outOffset + (idx1 * wBatch + 1) * channel], xBatchLocal4, wBatch * channel);
        }
        if (wTail > 0) {
            inOffset = baseOffset + (wBatch * wRounds + 1) * stride - padding;
            DataCopy(xBatchLocal1, xTransGm[inOffset * channel], (wTail * stride + 1) * channel);
            DataCopy(xBatchLocal2, xTransGm[(inOffset + inWidth) * channel], (wTail * stride + 1) * channel);
            PipeBarrier<PIPE_ALL>();
            Max(xBatchLocal3, xBatchLocal1, xBatchLocal2, (wTail * stride + 1) * channel);

            for (uint32_t idx2 = 0; idx2 < wTail; idx2++) {
                Max(xBatchLocal4[idx2 * channel], xBatchLocal3[idx2 * stride * channel],
                    xBatchLocal3[idx2 * stride * channel + channel], channel);
                Max(xBatchLocal4[idx2 * channel], xBatchLocal4[idx2 * channel],
                    xBatchLocal3[idx2 * stride * channel + 2 * channel], channel);
            }
            PipeBarrier<PIPE_ALL>();

            DataCopy(yTransGm[outOffset + (wBatch * wRounds + 1) * channel], xBatchLocal4, wTail * channel);
        }
        if (inWidth % 2 == 1) {
            inOffset = baseOffset + (outWidth - 1) * stride - padding;
            DataCopy(xPart2Local, xTransGm[inOffset * channel], batchNum - channel);
            DataCopy(xPart3Local, xTransGm[(inOffset + inWidth) * channel], batchNum - channel);
            PipeBarrier<PIPE_ALL>();

            Max(maxPart2Local, xPart2Local, xPart3Local, batchNum - channel);
            Max(resLocal, maxPart2Local, maxPart2Local[channel], channel);

            SetFlag<HardEvent::V_MTE3>(eventIdVToMte3);
            WaitFlag<HardEvent::V_MTE3>(eventIdVToMte3);

            DataCopy(yTransGm[outOffset + (outWidth - 1) * channel], resLocal, channel);
        }
    }

    __aicore__ inline void MoveMain()
    {
        DataCopy(xPart1Local, xTransGm[baseOffset * channel], batchNum - channel);
        DataCopy(xPart2Local, xTransGm[(baseOffset + inWidth) * channel], batchNum - channel);
        DataCopy(xPart3Local, xTransGm[(baseOffset + inWidth * 2) * channel], batchNum - channel);
        PipeBarrier<PIPE_ALL>();

        Max(maxPart1Local, xPart1Local, xPart2Local, batchNum - channel);
        Max(maxPart2Local, maxPart1Local, xPart3Local, batchNum - channel);
        Max(xBatchLocal4, maxPart2Local, maxPart2Local[channel], channel);
        SetFlag<HardEvent::V_MTE3>(eventIdVToMte3);
        WaitFlag<HardEvent::V_MTE3>(eventIdVToMte3);
        DataCopy(yTransGm[outOffset], xBatchLocal4, channel);

        for (uint32_t idx1 = 0; idx1 < wRounds; idx1++) {
            inOffset = baseOffset + (idx1 * wBatch + 1) * stride - padding;
            DataCopy(xBatchLocal1, xTransGm[inOffset * channel], (wBatch * stride + 1) * channel);
            DataCopy(xBatchLocal2, xTransGm[(inOffset + inWidth) * channel], (wBatch * stride + 1) * channel);
            DataCopy(xBatchLocal3, xTransGm[(inOffset + inWidth * 2) * channel], (wBatch * stride + 1) * channel);

            PipeBarrier<PIPE_ALL>();

            Max(xBatchLocal1, xBatchLocal1, xBatchLocal2, (wBatch * stride + 1) * channel);
            Max(xBatchLocal3, xBatchLocal1, xBatchLocal3, (wBatch * stride + 1) * channel);

            for (uint32_t idx2 = 0; idx2 < wBatch; idx2++) {
                Max(xBatchLocal4[idx2 * channel], xBatchLocal3[idx2 * stride * channel],
                    xBatchLocal3[idx2 * stride * channel + channel], channel);
                Max(xBatchLocal4[idx2 * channel], xBatchLocal4[idx2 * channel],
                    xBatchLocal3[idx2 * stride * channel + 2 * channel], channel);
            }
            PipeBarrier<PIPE_ALL>();

            DataCopy(yTransGm[outOffset + (idx1 * wBatch + 1) * channel], xBatchLocal4, wBatch * channel);
        }
        if (wTail > 0) {
            inOffset = baseOffset + (wBatch * wRounds + 1) * stride - padding;
            DataCopy(xBatchLocal1, xTransGm[inOffset * channel], (wTail * stride + 1) * channel);
            DataCopy(xBatchLocal2, xTransGm[(inOffset + inWidth) * channel], (wTail * stride + 1) * channel);
            DataCopy(xBatchLocal3, xTransGm[(inOffset + inWidth * 2) * channel], (wTail * stride + 1) * channel);
            PipeBarrier<PIPE_ALL>();

            Max(xBatchLocal1, xBatchLocal1, xBatchLocal2, (wTail * stride + 1) * channel);
            Max(xBatchLocal3, xBatchLocal1, xBatchLocal3, (wTail * stride + 1) * channel);

            for (uint32_t idx2 = 0; idx2 < wTail; idx2++) {
                Max(xBatchLocal4[idx2 * channel], xBatchLocal3[idx2 * stride * channel],
                    xBatchLocal3[idx2 * stride * channel + channel], channel);
                Max(xBatchLocal4[idx2 * channel], xBatchLocal4[idx2 * channel],
                    xBatchLocal3[idx2 * stride * channel + 2 * channel], channel);
            }
            PipeBarrier<PIPE_ALL>();

            DataCopy(yTransGm[outOffset + (wBatch * wRounds + 1) * channel], xBatchLocal4, wTail * channel);
        }

        if (inWidth % 2 == 1) {
            inOffset = baseOffset + (outWidth - 1) * stride - padding;

            DataCopy(xPart1Local, xTransGm[inOffset * channel], batchNum - channel);
            DataCopy(xPart2Local, xTransGm[(inOffset + inWidth) * channel], batchNum - channel);
            DataCopy(xPart3Local, xTransGm[(inOffset + inWidth * 2) * channel], batchNum - channel);
            PipeBarrier<PIPE_ALL>();

            Max(maxPart1Local, xPart1Local, xPart2Local, batchNum - channel);
            Max(maxPart2Local, maxPart1Local, xPart3Local, batchNum - channel);
            Max(resLocal, maxPart2Local, maxPart2Local[channel], channel);
            SetFlag<HardEvent::V_MTE3>(eventIdVToMte3);
            WaitFlag<HardEvent::V_MTE3>(eventIdVToMte3);
            DataCopy(yTransGm[outOffset + (outWidth - 1) * channel], resLocal, channel);
        }
    }

    __aicore__ inline void ComputeNH()
    {
        xPart1Local = xPart1Ub.Get<DTYPE_X_TRANS>();
        xPart2Local = xPart2Ub.Get<DTYPE_X_TRANS>();
        xPart3Local = xPart3Ub.Get<DTYPE_X_TRANS>();

        maxPart1Local = maxPart1Ub.Get<DTYPE_X_TRANS>();
        maxPart2Local = maxPart2Ub.Get<DTYPE_X_TRANS>();

        xBatchLocal1 = xBatchUb1.Get<DTYPE_X_TRANS>();
        xBatchLocal2 = xBatchUb2.Get<DTYPE_X_TRANS>();
        xBatchLocal3 = xBatchUb3.Get<DTYPE_X_TRANS>();
        xBatchLocal4 = xBatchUb4.Get<DTYPE_X_TRANS>();

        resLocal = resUb.Get<DTYPE_X_TRANS>();

        for (uint32_t idx = startOffset; idx < endOffset; idx++) {
            high = idx % outHeight;
            batch = idx / outHeight;

            outOffset = idx * outWidth * channel;
            oriHeight = high * stride - padding;
            baseOffset = (batch * inHeight + oriHeight) * inWidth;

            if (oriHeight == -padding) {
                baseOffset = baseOffset + inWidth;
                MovePart();
            } else if (oriHeight + kernelSize > inHeight) {
                MovePart();
            } else {
                MoveMain();
            }
        }
    }

private:
    TPipe* pipe;
    GlobalTensor<DTYPE_X_TRANS> xTransGm, yTransGm;
    TBuf<TPosition::VECCALC> xPart1Ub, xPart2Ub, xPart3Ub, maxPart1Ub, maxPart2Ub, resUb, xBatchUb1, xBatchUb2,
        xBatchUb3, xBatchUb4;
    LocalTensor<DTYPE_X_TRANS> xPart1Local, xPart2Local, xPart3Local, maxPart1Local, maxPart2Local, xBatchLocal1,
        xBatchLocal2, xBatchLocal3, xBatchLocal4;
    LocalTensor<DTYPE_X_TRANS> resLocal;
    uint32_t batchSize;
    uint32_t channel;
    uint32_t inHeight;
    uint32_t inWidth;
    uint32_t outHeight;
    uint32_t outWidth;
    uint32_t coreNum;
    uint32_t numAlign = 64 * 64;

    uint32_t wBatch;
    uint32_t validW;

    uint32_t wRounds;
    uint32_t wTail;

    uint32_t oriHeight;
    uint32_t oriWidth;
    uint32_t inOffset;
    uint32_t baseOffset;
    uint32_t outOffset;

    uint32_t batch;
    uint32_t high;
    uint32_t wide;

    uint32_t taskNum;
    uint32_t taskNumPerCore;
    uint32_t curBlockIdx;
    uint32_t startOffset;
    uint32_t endOffset;
    uint32_t dataAlign;
    uint32_t blockNum = 32;
    uint32_t padding = 1;
    uint32_t stride = 2;
    uint32_t kernelSize = 3;
    uint32_t batchNum;

    event_t eventIdVToMte3;

    DataCopyExtParams copyParams;
    DataCopyPadExtParams<DTYPE_X_TRANS> padParams {false, 0, 0, 0};
};

/**
 * Kernel entry point: unpacks the tiling data and runs KernelMaxPool2d on this core.
 *
 * @param x_trans    global-memory address of the input tensor
 * @param y_trans    global-memory address of the output tensor
 * @param workspace  not used by this kernel
 * @param tiling     global-memory address of the serialized tiling data
 */
extern "C" __global__ __aicore__ void max_pool2d(GM_ADDR x_trans, GM_ADDR y_trans, GM_ADDR workspace, GM_ADDR tiling)
{
    TPipe kernelPipe;
    GET_TILING_DATA(tilingData, tiling);

    KernelMaxPool2d poolKernel;
    poolKernel.Init(x_trans, y_trans, &tilingData, &kernelPipe);
    poolKernel.Process();
}