/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
 */
#ifndef SCATTER_ADD_GRAD_H_
#define SCATTER_ADD_GRAD_H_

#include "kernel_operator.h"
#include "kernel_tiling/kernel_tiling.h"
#include "scatter_add_grad_base.h"
namespace ScatterAddGradNS {
using namespace AscendC;

// Backward kernel variant of scatter-add that is implemented as a gather:
// for every "head" handled by this core, gradIn[i] = gradOut[index[i]] where the
// index is first shifted to the head's position inside the UB-resident gradOut
// tile and then scaled by sizeof(T) before being handed to Gather.
//
// Work is split into tasks of `headTask` heads (the final task holds
// `headLastTask` heads). Two execution paths exist, selected by the base-class
// field `tilingMode`:
//   * tilingMode == 0 -> ComputeModeSmallData: all indices of a task are loaded at once.
//   * otherwise       -> ComputeModeLargeData: indices of one head are streamed
//                        through UB in chunks of `indexUbSize` elements.
//
// NOTE(review): tiling fields such as dimRange, dimRangeOut, paramsPro,
// bigCoreNum, curBlockIdx, *UbSize, *EachBlock and copyParamsOut come from
// ScatterAddGradBase<T>, which is not visible here - confirm their meaning there.
template <typename T>
class ScatterAddGradV1 : public ScatterAddGradBase<T> {
public:
    __aicore__ inline ScatterAddGradV1() {}

    // Binds the global-memory tensors and reserves the three UB buffers.
    //   gradOut    - GM address read as T, length this->gradOutNum (gather source).
    //   index      - GM address read as int32_t, length this->indexNum.
    //   gradIn     - GM address written as T, length this->gradInNum (gather result).
    //   tilingData - tiling structure forwarded to the base class and to InitNoTailTiling.
    __aicore__ inline void Init(GM_ADDR gradOut, GM_ADDR index, GM_ADDR gradIn, const ScatterAddGradTilingData* tilingData)
    {
        this->InitTiling(tilingData);
        InitNoTailTiling(tilingData);

        gradInGm.SetGlobalBuffer((__gm__ T *)gradIn, this->gradInNum);
        indexGm.SetGlobalBuffer((__gm__ int32_t *)index, this->indexNum);
        gradOutGm.SetGlobalBuffer((__gm__ T *)gradOut, this->gradOutNum);

        pipe.InitBuffer(inGradOutUb, this->gradOutUbSize * sizeof(T));
        pipe.InitBuffer(inIndexUb, this->indexUbSize * sizeof(int32_t));
        // The gather result holds one T per index, hence it is sized by indexUbSize.
        pipe.InitBuffer(outGradInUb, this->indexUbSize * sizeof(T));
    }

    // Runs every task assigned to this core. Tasks [0, taskNum - 1) contain
    // `headTask` heads; the final task contains `headLastTask` heads and is
    // skipped when that count is zero.
    __aicore__ inline void Process()
    {
        // taskNum is unsigned: without this guard "taskNum - 1" would wrap around
        // and the task loops below would run (almost) forever with bogus offsets.
        if (taskNum == 0) {
            return;
        }
        const uint64_t lastTaskId = taskNum - 1;

        if (this->tilingMode == 0) {
            // Byte length of one head row: gradIn rows hold T, index rows hold int32_t.
            this->copyParamsOut.blockLen = static_cast<uint32_t>(headIndexSize * sizeof(T));
            this->copyParamsIn.blockLen = static_cast<uint32_t>(headIndexSize * sizeof(int32_t));
            // Pad every index row in UB up to headIndexSizeAlign elements.
            this->copyParamsInPad.isPad = true;
            this->copyParamsInPad.rightPadding = headIndexSizeAlign - headIndexSize;
            for (uint64_t taskId = 0; taskId < lastTaskId; taskId++) {
                ComputeModeSmallData(taskId, headTask);
            }
            if (headLastTask != 0) {
                ComputeModeSmallData(lastTaskId, headLastTask);
            }
        } else {
            // Split the indices of one head into full UB chunks plus a tail.
            indexLoop = headIndexSize / this->indexUbSize;
            indexLast = headIndexSize - indexLoop * this->indexUbSize;
            // Byte length of the tail chunk written back by DataCopyPad.
            this->copyParamsOut.blockLen = static_cast<uint32_t>(indexLast * sizeof(T));
            // Pre-set the events that ComputeModeLargeData waits on before its first
            // copy/compute step; each one is waited once more after the last task so
            // that Set/Wait calls stay paired.
            SetFlag<HardEvent::MTE3_V>(0);
            SetFlag<HardEvent::V_MTE2>(0);
            SetFlag<HardEvent::V_MTE2>(1);
            for (uint64_t taskId = 0; taskId < lastTaskId; taskId++) {
                ComputeModeLargeData(taskId, headTask);
            }
            if (headLastTask != 0) {
                ComputeModeLargeData(lastTaskId, headLastTask);
            }
            WaitFlag<HardEvent::MTE3_V>(0);
            WaitFlag<HardEvent::V_MTE2>(0);
            WaitFlag<HardEvent::V_MTE2>(1);
        }
    }

private:
    // Derives the per-core task layout from the tiling data.
    // Cores with curBlockIdx < bigCoreNum use the "big" task parameters; the
    // remaining cores use the "small" ones and are assumed here to own exactly one
    // head less than a big core (headSmallCore = headBigCore - 1).
    __aicore__ inline void InitNoTailTiling(const ScatterAddGradTilingData *tiling_data)
    {
        auto headTaskSmall = tiling_data->headTaskSmall;
        auto taskNumSmall = tiling_data->taskNumSmall;
        auto headLastTaskSmall = tiling_data->headLastTaskSmall;
        auto headTaskBig = tiling_data->headTaskBig;
        auto taskNumBig = tiling_data->taskNumBig;
        auto headLastTaskBig = tiling_data->headLastTaskBig;

        // Element counts of one head in gradOut and in index/gradIn respectively.
        headOutSize = this->dimRangeOut * this->paramsPro;
        headIndexSize = this->dimRange * this->paramsPro;
        headIndexSizeAlign = AlignUp(headIndexSize, B32_DATA_NUM_PER_BLOCK);

        // Total number of heads owned by a big / small core.
        auto headBigCore = (taskNumBig - 1) * headTaskBig + headLastTaskBig;
        auto headSmallCore = headBigCore - 1;

        if (this->curBlockIdx < this->bigCoreNum) {
            taskNum = taskNumBig;
            headTask = headTaskBig;
            headLastTask = headLastTaskBig;
            headBaseId = this->curBlockIdx * headBigCore;
        } else {
            taskNum = taskNumSmall;
            headTask = headTaskSmall;
            headLastTask = headLastTaskSmall;
            // Skip all heads of the big cores, then the preceding small cores.
            headBaseId = this->bigCoreNum * headBigCore + (this->curBlockIdx - this->bigCoreNum) * headSmallCore;
        }
    }

    // Processes one task on the small-data path.
    //   taskId  - index of the task on this core (selects the first head).
    //   headNum - number of heads in this task.
    // All index rows and the matching gradOut tile are loaded in one go, the
    // indices are turned into offsets into the UB tile, and one Gather produces
    // the whole result, which is then written back row by row.
    __aicore__ inline void ComputeModeSmallData(uint64_t taskId, uint64_t headNum)
    {
        LocalTensor<int32_t> indexLocal = inIndexUb.Get<int32_t>();
        LocalTensor<T> gradOutLocal = inGradOutUb.Get<T>();
        LocalTensor<T> gradInLocal = outGradInUb.Get<T>();

        const uint64_t firstHeadId = headBaseId + headTask * taskId;
        const uint64_t indexOffset = firstHeadId * headIndexSize;
        const uint64_t outOffset = firstHeadId * headOutSize;

        // One copy block per head for both the index load and the result store.
        this->copyParamsIn.blockCount = static_cast<uint32_t>(headNum);
        this->copyParamsOut.blockCount = static_cast<uint32_t>(headNum);
        const uint64_t outAlign = AlignUp(headNum * headOutSize, this->paramsEachBlock);

        // Vector work of the previous task must be finished before MTE2 overwrites
        // the index / gradOut buffers.
        SetFlag<HardEvent::V_MTE2>(0);
        WaitFlag<HardEvent::V_MTE2>(0);
        DataCopyPad(indexLocal, indexGm[indexOffset], copyParamsIn, copyParamsInPad);
        DataCopy(gradOutLocal, gradOutGm[outOffset], outAlign);
        SetFlag<HardEvent::MTE2_V>(0);
        WaitFlag<HardEvent::MTE2_V>(0);

        // Shift the indices of every head by that head's start inside gradOutLocal.
        for (uint64_t head = 0; head < headNum; head++) {
            const int32_t indexLocalOffset = static_cast<int32_t>(head * headIndexSizeAlign);
            const int32_t outLocalOffset = static_cast<int32_t>(head * headOutSize);
            Adds(indexLocal[indexLocalOffset], indexLocal[indexLocalOffset], outLocalOffset, headIndexSizeAlign);
        }
        // Scale element indices by sizeof(T) (presumably Gather expects byte
        // offsets - see the AscendC Gather API).
        Muls(indexLocal, indexLocal, static_cast<int32_t>(sizeof(T)), headNum * headIndexSizeAlign);

        // The previous store from gradInLocal must be done before Gather rewrites it.
        SetFlag<HardEvent::MTE3_V>(0);
        WaitFlag<HardEvent::MTE3_V>(0);
        Gather(gradInLocal, gradOutLocal, indexLocal.ReinterpretCast<uint32_t>(), static_cast<uint32_t>(0),
            headNum * headIndexSizeAlign);
        SetFlag<HardEvent::V_MTE3>(0);
        WaitFlag<HardEvent::V_MTE3>(0);
        DataCopyPad(gradInGm[indexOffset], gradInLocal, this->copyParamsOut);
    }

    // Processes one task on the large-data path.
    //   taskId  - index of the task on this core (selects the first head).
    //   headNum - number of heads in this task.
    // The gradOut tile of the task is loaded once; the indices of each head are
    // then streamed through UB in `indexLoop` full chunks of `indexUbSize`
    // elements followed by an optional tail of `indexLast` elements.
    //
    // Event usage (IDs as used below; the initial SetFlag and final WaitFlag of
    // each event live in Process()):
    //   V_MTE2(1) - guards gradOutLocal, released at the end of this function.
    //   V_MTE2(0) - guards indexLocal, released after each Gather.
    //   MTE3_V(0) - guards gradInLocal, released after each store to GM.
    __aicore__ inline void ComputeModeLargeData(uint64_t taskId, uint64_t headNum)
    {
        LocalTensor<int32_t> indexLocal = inIndexUb.Get<int32_t>();
        LocalTensor<T> gradOutLocal = inGradOutUb.Get<T>();
        LocalTensor<T> gradInLocal = outGradInUb.Get<T>();

        const uint64_t firstHeadId = headBaseId + headTask * taskId;
        const uint64_t indexOffset = firstHeadId * headIndexSize;
        const uint64_t outOffset = firstHeadId * headOutSize;
        const uint64_t outAlign = AlignUp(headNum * headOutSize, this->paramsEachBlock);

        WaitFlag<HardEvent::V_MTE2>(1);
        DataCopy(gradOutLocal, gradOutGm[outOffset], outAlign);
        SetFlag<HardEvent::MTE2_V>(1);
        WaitFlag<HardEvent::MTE2_V>(1);

        for (uint64_t head = 0; head < headNum; head++) {
            // Start of this head inside gradOutLocal (elements) and inside
            // indexGm / gradInGm (elements).
            const uint64_t headOutOffset = head * headOutSize;
            const uint64_t headGmOffset = indexOffset + head * headIndexSize;

            for (uint64_t loop = 0; loop < indexLoop; loop++) {
                const uint64_t offset = this->indexUbSize * loop;
                WaitFlag<HardEvent::V_MTE2>(0);
                DataCopy(indexLocal, indexGm[headGmOffset + offset], this->indexUbSize);
                SetFlag<HardEvent::MTE2_V>(0);
                WaitFlag<HardEvent::MTE2_V>(0);
                Adds(indexLocal, indexLocal, static_cast<int32_t>(headOutOffset), this->indexUbSize);
                Muls(indexLocal, indexLocal, static_cast<int32_t>(sizeof(T)), this->indexUbSize);
                WaitFlag<HardEvent::MTE3_V>(0);
                Gather(gradInLocal, gradOutLocal, indexLocal.ReinterpretCast<uint32_t>(), static_cast<uint32_t>(0),
                    this->indexUbSize);
                SetFlag<HardEvent::V_MTE2>(0);
                SetFlag<HardEvent::V_MTE3>(0);
                WaitFlag<HardEvent::V_MTE3>(0);
                DataCopy(gradInGm[headGmOffset + offset], gradInLocal, this->indexUbSize);
                SetFlag<HardEvent::MTE3_V>(0);
            }
            if (indexLast != 0) {
                const uint64_t offset = this->indexUbSize * indexLoop;
                // The tail is loaded and transformed at block-aligned length, but only
                // indexLast elements are gathered and written back.
                const uint64_t indicesAlign = AlignUp(indexLast, this->indicesEachBlock);
                WaitFlag<HardEvent::V_MTE2>(0);
                DataCopy(indexLocal, indexGm[headGmOffset + offset], indicesAlign);
                SetFlag<HardEvent::MTE2_V>(0);
                WaitFlag<HardEvent::MTE2_V>(0);
                Adds(indexLocal, indexLocal, static_cast<int32_t>(headOutOffset), indicesAlign);
                Muls(indexLocal, indexLocal, static_cast<int32_t>(sizeof(T)), indicesAlign);
                WaitFlag<HardEvent::MTE3_V>(0);
                Gather(gradInLocal, gradOutLocal, indexLocal.ReinterpretCast<uint32_t>(), static_cast<uint32_t>(0),
                    indexLast);
                SetFlag<HardEvent::V_MTE2>(0);
                SetFlag<HardEvent::V_MTE3>(0);
                WaitFlag<HardEvent::V_MTE3>(0);
                DataCopyPad(gradInGm[headGmOffset + offset], gradInLocal, this->copyParamsOut);
                SetFlag<HardEvent::MTE3_V>(0);
            }
        }
        SetFlag<HardEvent::V_MTE2>(1);
    }

private:
    TPipe pipe;
    // UB buffers: gradOut tile (gather source), index chunk, gathered result.
    TBuf<TPosition::VECCALC> inGradOutUb, inIndexUb, outGradInUb;

    GlobalTensor<T> gradInGm, gradOutGm;
    GlobalTensor<int32_t> indexGm;

    uint64_t headOutSize = 0;         // elements of one head in gradOut
    uint64_t headIndexSize = 0;       // elements of one head in index / gradIn
    uint64_t headIndexSizeAlign = 0;  // headIndexSize aligned up to B32_DATA_NUM_PER_BLOCK
    uint64_t taskNum = 0;             // number of tasks on this core
    uint64_t headTask = 0;            // heads per task (all but the last task)
    uint64_t headLastTask = 0;        // heads in the last task
    uint64_t headBaseId = 0;          // first head owned by this core
    uint64_t indexLoop = 0;           // full index chunks per head (large-data path)
    uint64_t indexLast = 0;           // tail index elements per head (large-data path)

    // Parameters of the padded index load used on the small-data path;
    // blockCount / blockLen / padding are overwritten before use.
    DataCopyExtParams copyParamsIn = {1, 8, 0, 0, 0};
    DataCopyPadExtParams<int32_t> copyParamsInPad = {false, 0, 0, 0};
};
}
#endif