/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
 */
#ifndef _SCATTER_MEAN_GRAD_LARGE_H_
#define _SCATTER_MEAN_GRAD_LARGE_H_

#include "kernel_operator.h"
#include "kernel_tiling/kernel_tiling.h"
#include "scatter_mean_grad_base.h"
namespace ScatterMeanGradNS {
using namespace AscendC;

template <typename T>
/**
 * Scatter-mean backward kernel that walks gradOut in slices of at most gradOutUbSize elements.
 *
 * gradOut is viewed as consecutive "heads" of headOutSize = dimRangeOut * paramsPro elements, and index/gradIn
 * as heads of headIndexSize = dimRange * paramsPro elements. Each task owns one slice of one gradOut head:
 *   1. it loads the slice of gradOut and count and computes gradOut / count element-wise;
 *   2. it scans every index of the same head in chunks of indexUbSize; an index value that falls inside the
 *      slice (index values are positions within the head) pulls the averaged value into a zeroed buffer;
 *   3. it atomically adds that buffer into gradIn at the index positions.
 * Slices of a head are disjoint, so each index is matched by at most one slice.
 *
 * NOTE(review): the atomic adds presumably require gradIn to be zero-initialised before launch — confirm on
 * the host side.
 */
class ScatterMeanGradLarge : public ScatterMeanGradBase<T> {
public:
    __aicore__ inline ScatterMeanGradLarge() {}

    /**
     * @brief Read tiling, bind the global tensors and allocate the local (VECCALC) scratch buffers.
     * @param gradOut    GM address of gradOut (element type T); divided element-wise by count.
     * @param index      GM address of the int32 scatter indices.
     * @param count      GM address of the per-element divisors (element type T).
     * @param gradIn     GM address of the result (element type T); written with atomic adds.
     * @param tilingData tiling parameters for this launch.
     */
    __aicore__ inline void Init(GM_ADDR gradOut, GM_ADDR index, GM_ADDR count, GM_ADDR gradIn, const ScatterMeanGradTilingData* tilingData)
    {
        this->InitTiling(tilingData);
        InitNoTailTiling(tilingData);
        gradInGm.SetGlobalBuffer((__gm__ T *)gradIn, this->gradInNum);
        indexGm.SetGlobalBuffer((__gm__ int32_t *)index, this->indexNum);
        gradOutGm.SetGlobalBuffer((__gm__ T *)gradOut, this->gradOutNum);
        countGm.SetGlobalBuffer((__gm__ T *)count, this->countNum);

        pipe.InitBuffer(inGradOutUb, this->gradOutUbSize * sizeof(T));
        pipe.InitBuffer(inIndexUb, this->indexUbSize * sizeof(int32_t));
        pipe.InitBuffer(outGradInUb, this->indexUbSize * sizeof(T));
        pipe.InitBuffer(inCountUb, this->gradOutUbSize * sizeof(T));
    }

    /**
     * @brief Run every task assigned to this core.
     * Task ids are global (baseTaskNum + local id); id / taskEachHead selects the head and
     * id % taskEachHead selects the slice inside that head.
     */
    __aicore__ inline void Process()
    {
        for (uint64_t taskId = 0; taskId < taskNum; taskId++) {
            const uint64_t taskIdAll = baseTaskNum + taskId;
            const uint64_t headBaseId = taskIdAll / taskEachHead;
            const uint64_t headPartId = taskIdAll % taskEachHead;
            ComputeModePart(GetSliceDealNum(headPartId), headBaseId, headPartId);
        }
    }

private:
    /**
     * @brief Derive per-core task range and per-head sizes from the tiling data.
     * Cores with curBlockIdx < bigCoreNum take taskNumBig tasks; the remaining cores take taskNumSmall
     * tasks and start right after all big-core tasks.
     */
    __aicore__ inline void InitNoTailTiling(const ScatterMeanGradTilingData *tiling_data)
    {
        const auto taskNumSmall = tiling_data->taskNumSmall;
        const auto taskNumBig = tiling_data->taskNumBig;
        taskEachHead = tiling_data->taskEachHead;
        headOutSize = this->dimRangeOut * this->paramsPro;
        headIndexSize = this->dimRange * this->paramsPro;

        if (this->curBlockIdx < this->bigCoreNum) {
            taskNum = taskNumBig;
            baseTaskNum = this->curBlockIdx * taskNum;
        } else {
            taskNum = taskNumSmall;
            baseTaskNum = this->bigCoreNum * taskNumBig + (this->curBlockIdx - this->bigCoreNum) * taskNum;
        }

        // Indices of one head are processed in indexLoop full chunks plus an indexLast-element tail.
        indexLoop = headIndexSize / this->indexUbSize;
        indexLast = headIndexSize - indexLoop * this->indexUbSize;

        // DataCopyPad length is in bytes of the gradIn element type (was hard-coded to sizeof(float)).
        this->copyParamsOut.blockLen = static_cast<uint32_t>(indexLast * sizeof(T));
    }

    /**
     * @brief Number of gradOut elements handled by slice headPartId of a head.
     * Every slice is gradOutUbSize long except the last one, which takes the remainder
     * (or a full slice when headOutSize divides evenly).
     */
    __aicore__ inline uint64_t GetSliceDealNum(uint64_t headPartId) const
    {
        if (headPartId != taskEachHead - 1) {
            return this->gradOutUbSize;
        }
        const uint64_t lastDealNum = headOutSize % this->gradOutUbSize;
        return lastDealNum == 0 ? this->gradOutUbSize : lastDealNum;
    }

    /**
     * @brief Scalar gather for one chunk of indices.
     * For each of the first `count` entries of indexLocal, if its value lies in
     * [baseOutOffset, baseOutOffset + dealNum), copy the matching element of gradOutLocal into
     * gradInLocal at the same position. Other positions of gradInLocal are left untouched.
     * Negative index values never match.
     */
    __aicore__ inline void GatherSlice(LocalTensor<int32_t> &indexLocal, LocalTensor<T> &gradOutLocal,
        LocalTensor<T> &gradInLocal, uint64_t count, uint64_t baseOutOffset, uint64_t dealNum)
    {
        for (uint64_t idx = 0; idx < count; idx++) {
            const int32_t indexValue = indexLocal.GetValue(idx);
            if (indexValue < 0) {
                continue;
            }
            const uint64_t outPos = static_cast<uint64_t>(indexValue);
            if (outPos >= baseOutOffset && outPos - baseOutOffset < dealNum) {
                gradInLocal.SetValue(idx, gradOutLocal.GetValue(outPos - baseOutOffset));
            }
        }
    }

    /**
     * @brief Process one slice of one gradOut head.
     * @param dealNum    number of valid gradOut elements in this slice.
     * @param headBaseId head index (selects the gradOut and index/gradIn head).
     * @param headPartId slice index inside the head.
     *
     * NOTE(review): there is no explicit MTE2->S wait between loading indexLocal and the scalar GetValue
     * reads, nor V->S / S->MTE3 flags around the scalar SetValue writes; this presumably relies on
     * compiler-inserted synchronisation — confirm before changing the pipeline.
     */
    __aicore__ inline void ComputeModePart(uint64_t dealNum, uint64_t headBaseId, uint64_t headPartId)
    {
        LocalTensor<int32_t> indexLocal = inIndexUb.Get<int32_t>();
        LocalTensor<T> gradOutLocal = inGradOutUb.Get<T>();
        LocalTensor<T> countLocal = inCountUb.Get<T>();
        LocalTensor<T> gradInLocal = outGradInUb.Get<T>();

        const uint64_t indexOffset = headBaseId * headIndexSize;
        // Offset of this slice inside its head; index values are compared against it.
        const uint64_t baseOutOffset = headPartId * this->gradOutUbSize;
        const uint64_t outOffset = headBaseId * headOutSize + baseOutOffset;
        const uint64_t outAlign = AlignUp(dealNum, this->paramsEachBlock);

        // Padding elements beyond dealNum are divided as well but never read by the gather.
        DataCopy(gradOutLocal, gradOutGm[outOffset], outAlign);
        DataCopy(countLocal, countGm[outOffset], outAlign);
        PipeBarrier<PIPE_ALL>();
        Div(gradOutLocal, gradOutLocal, countLocal, outAlign);

        // Pre-set so the first WaitFlag below does not block; each chunk re-arms it after its MTE3 copy,
        // so gradInLocal is not cleared while a previous copy-out may still read it.
        SetFlag<HardEvent::MTE3_V>(EVENT_ID0);
        for (uint64_t loop = 0; loop < indexLoop; loop++) {
            const uint64_t offset = loop * this->indexUbSize;
            DataCopy(indexLocal, indexGm[indexOffset + offset], this->indexUbSize);

            WaitFlag<HardEvent::MTE3_V>(EVENT_ID0);
            Duplicate(gradInLocal, static_cast<T>(0), this->indexUbSize);
            GatherSlice(indexLocal, gradOutLocal, gradInLocal, this->indexUbSize, baseOutOffset, dealNum);
            SetAtomicAdd<T>();
            DataCopy(gradInGm[indexOffset + offset], gradInLocal, this->indexUbSize);
            SetAtomicNone();
            SetFlag<HardEvent::MTE3_V>(EVENT_ID0);
        }
        if (indexLast != 0) {
            const uint64_t offset = indexLoop * this->indexUbSize;
            const uint64_t indicesAlign = AlignUp(indexLast, this->indicesEachBlock);
            DataCopy(indexLocal, indexGm[indexOffset + offset], indicesAlign);

            WaitFlag<HardEvent::MTE3_V>(EVENT_ID0);
            Duplicate(gradInLocal, static_cast<T>(0), indicesAlign);
            GatherSlice(indexLocal, gradOutLocal, gradInLocal, indexLast, baseOutOffset, dealNum);
            SetAtomicAdd<T>();
            // Tail is written with DataCopyPad so only indexLast elements reach gradIn.
            DataCopyPad(gradInGm[indexOffset + offset], gradInLocal, this->copyParamsOut);
            SetAtomicNone();
            SetFlag<HardEvent::MTE3_V>(EVENT_ID0);
        }
        // Drain the last copy-out before the buffers are reused by the next task.
        WaitFlag<HardEvent::MTE3_V>(EVENT_ID0);
    }

private:
    TPipe pipe;
    TBuf<TPosition::VECCALC> inGradOutUb, inIndexUb, outGradInUb, inCountUb;

    GlobalTensor<T> gradInGm, gradOutGm, countGm;
    GlobalTensor<int32_t> indexGm;

    uint64_t headOutSize;    // gradOut elements per head (dimRangeOut * paramsPro)
    uint64_t headIndexSize;  // index / gradIn elements per head (dimRange * paramsPro)
    uint64_t taskNum;        // tasks owned by this core
    uint64_t baseTaskNum;    // global id of this core's first task
    uint64_t indexLoop;      // full indexUbSize chunks per head
    uint64_t indexLast;      // leftover indices per head after the full chunks
    uint64_t taskEachHead;   // slices (tasks) per gradOut head
};
}
#endif