/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
 * MindIE is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *          http://license.coscl.org.cn/MulanPSL2
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 */
#ifndef __SRC_OPS_KERNEL_LA_PREPROCESS_H__
#define __SRC_OPS_KERNEL_LA_PREPROCESS_H__

#include <kernel_operator.h>


namespace mmdit_ops {

namespace kernels {

template <typename T, typename DST, int32_t HEAD_DIM = 128, int32_t QUEUE_DEPTH = 1>
class LaPreprocess {
public:
    __aicore__ inline LaPreprocess()
        : blockIdx_(AscendC::GetBlockIdx()), blockDim_(AscendC::GetBlockNum()) {}

    __aicore__ inline void Init(
        GM_ADDR query, GM_ADDR key, GM_ADDR value,
        GM_ADDR outQuery, GM_ADDR outKey, GM_ADDR outValue,
        const LaPreprocessTilingData *tiling, AscendC::TPipe *pipe)
    {
        batchSize_ = tiling->batchSize;
        headNum_ = tiling->headNum;

        qSeqLen_ = tiling->qSeqLen;
        kSeqLen_ = tiling->kSeqLen;
        vSeqLen_ = tiling->vSeqLen;

        alignLen_ = tiling->alignLen;
        ubSize_ = tiling->ubSize;

        if constexpr (std::is_same_v<T, bfloat16_t>) {
            blockSeqLen_ =
                (ubSize_ - HEAD_DIM * sizeof(T)) /
                (QUEUE_DEPTH * headNum_ * HEAD_DIM) / (sizeof(float) + sizeof(T));
        } else {
            blockSeqLen_ =
                (ubSize_ - HEAD_DIM * sizeof(T)) /
                (QUEUE_DEPTH * headNum_ * HEAD_DIM) / sizeof(T);
        }

        pipe_ = pipe;

        InitBuffers();
        InitGlobal(query, key, value, outQuery, outKey, outValue);
    }

    __aicore__ inline void Process()
    {
        SplitTask(qSeqLen_);
        CopyData(outQueryGm_, queryGm_);

        SplitTask(kSeqLen_);
        CopyData(outKeyGm_, keyGm_);

        SplitTask(vSeqLen_);
        CopyData(outValueGm_, valueGm_);
    }

private:
    __aicore__ inline void InitBuffers()
    {
        bufLen_ = blockSeqLen_ * headNum_ * HEAD_DIM;
        if constexpr (std::is_same_v<T, bfloat16_t>) {
            pipe_->InitBuffer(inQue_, QUEUE_DEPTH, bufLen_ * sizeof(float));
            pipe_->InitBuffer(outQue_, QUEUE_DEPTH, bufLen_ * sizeof(T));
        } else {
            pipe_->InitBuffer(movQueBind_, QUEUE_DEPTH, bufLen_ * sizeof(T));
        }
        pipe_->InitBuffer(zeroBuf_, HEAD_DIM * sizeof(DST));

        zeroTensor_ = zeroBuf_.Get<DST>();
        AscendC::Duplicate<DST>(zeroTensor_, static_cast<DST>(0.0), HEAD_DIM);
    }

    __aicore__ inline void InitGlobal(
        GM_ADDR query, GM_ADDR key, GM_ADDR value,
        GM_ADDR outQuery, GM_ADDR outKey, GM_ADDR outValue)
    {
        queryGm_.SetGlobalBuffer(reinterpret_cast<__gm__ T *>(query));
        keyGm_.SetGlobalBuffer(reinterpret_cast<__gm__ T *>(key));
        valueGm_.SetGlobalBuffer(reinterpret_cast<__gm__ T *>(value));
        outQueryGm_.SetGlobalBuffer(reinterpret_cast<__gm__ DST *>(outQuery));
        outKeyGm_.SetGlobalBuffer(reinterpret_cast<__gm__ DST *>(outKey));
        outValueGm_.SetGlobalBuffer(reinterpret_cast<__gm__ DST *>(outValue));
    }

    __aicore__ inline void SplitTask(uint32_t seqLen)
    {
        uint32_t tailLen = 0;

        curSeqLen_ = seqLen;

        singleSeqLen_ = curSeqLen_ / blockDim_;
        tailLen = curSeqLen_ % blockDim_;
        curSeq_ = singleSeqLen_ * blockIdx_;
        if (blockIdx_ < tailLen) {
            singleSeqLen_ += 1;
            curSeq_ += blockIdx_;
        } else {
            curSeq_ += tailLen;
        }

        curSeqAlignLen_ = AscendC::AlignUp(curSeqLen_, alignLen_);

        // for pad
        curPadSeqLen_ = curSeqAlignLen_ - curSeqLen_;

        singlePadSeqLen_ = curPadSeqLen_ / blockDim_;
        tailLen = curPadSeqLen_ % blockDim_;
        curPadSeq_ = singlePadSeqLen_ * blockIdx_ + curSeqLen_;
        if (blockIdx_ < tailLen) {
            singlePadSeqLen_ += 1;
            curPadSeq_ += blockIdx_;
        } else {
            curPadSeq_ += tailLen;
        }
    }

    __aicore__ inline void CopyIn(const AscendC::GlobalTensor<T>& src, uint32_t seqLen)
    {
        if constexpr (std::is_same_v<T, bfloat16_t>) {
            AscendC::LocalTensor<T> srcLocal = inQue_.template AllocTensor<T>();

            AscendC::DataCopyExtParams inCopyParams{
                1, seqLen * headNum_ * HEAD_DIM * static_cast<uint32_t>(sizeof(T)), 0, 0, 0};
            AscendC::DataCopyPadExtParams<T> inPadParams{false, 0, 0, 0};
            AscendC::DataCopyPad(srcLocal[bufLen_], src, inCopyParams, inPadParams);

            inQue_.EnQue(srcLocal);
            AscendC::LocalTensor<T> castLocal = inQue_.template DeQue<T>();
            AscendC::LocalTensor<DST> dstLocal = outQue_.template AllocTensor<DST>();

            AscendC::Cast<float, T>(
                castLocal.template ReinterpretCast<float>(), castLocal[bufLen_],
                AscendC::RoundMode::CAST_NONE, bufLen_);
            AscendC::PipeBarrier<PIPE_V>();
            AscendC::Cast<DST, float>(
                dstLocal, castLocal.template ReinterpretCast<float>(),
                AscendC::RoundMode::CAST_RINT, bufLen_);

            outQue_.EnQue(dstLocal);
            inQue_.FreeTensor(castLocal);
        } else {
            AscendC::LocalTensor<T> srcLocal = movQueBind_.template AllocTensor<T>();

            AscendC::DataCopyExtParams inCopyParams{
                1, seqLen * headNum_ * HEAD_DIM * static_cast<uint32_t>(sizeof(T)), 0, 0, 0};
            AscendC::DataCopyPadExtParams<T> inPadParams{false, 0, 0, 0};
            AscendC::DataCopyPad(srcLocal, src, inCopyParams, inPadParams);

            movQueBind_.EnQue(srcLocal);
        }
    }

    __aicore__ inline void CopyOut(const AscendC::GlobalTensor<DST>& dst, uint32_t seqLen)
    {
        if constexpr (std::is_same_v<T, bfloat16_t>) {
            AscendC::LocalTensor<DST> dstLocal = outQue_.template DeQue<DST>();

            uint32_t dstStride = (curSeqAlignLen_ - 1) * HEAD_DIM * sizeof(DST);

            for (uint32_t i = 0; i < seqLen; ++i) {
                AscendC::DataCopyExtParams outCopyParams{
                    static_cast<uint16_t>(headNum_), HEAD_DIM * sizeof(T), 0, dstStride, 0};
                AscendC::DataCopyPad(dst[i * HEAD_DIM], dstLocal[i * headNum_ * HEAD_DIM], outCopyParams);
            }

            outQue_.FreeTensor(dstLocal);
        } else {
            AscendC::LocalTensor<DST> dstLocal = movQueBind_.template DeQue<DST>();

            uint32_t dstStride = (curSeqAlignLen_ - 1) * HEAD_DIM * sizeof(DST);

            for (uint32_t i = 0; i < seqLen; ++i) {
                AscendC::DataCopyExtParams outCopyParams{
                    static_cast<uint16_t>(headNum_), HEAD_DIM * sizeof(DST), 0, dstStride, 0};
                AscendC::DataCopyPad(dst[i * HEAD_DIM], dstLocal[i * headNum_ * HEAD_DIM], outCopyParams);
            }

            movQueBind_.FreeTensor(dstLocal);
        }
    }

    __aicore__ inline void PadOut(const AscendC::GlobalTensor<DST>& dst, uint32_t seqLen)
    {
        for (uint32_t i = 0; i < headNum_; ++i) {
            for (uint32_t j = 0; j < seqLen; ++j) {
                AscendC::DataCopyExtParams outCopyParams{
                    1, HEAD_DIM * sizeof(DST), 0, 0, 0};
                AscendC::DataCopyPad(
                    dst[j * HEAD_DIM + i * curSeqAlignLen_ * HEAD_DIM], zeroTensor_, outCopyParams);
            }
        }
    }

    __aicore__ inline void CopyData(const AscendC::GlobalTensor<DST>& dst, const AscendC::GlobalTensor<T>& src)
    {
        for (uint32_t i = 0; i < batchSize_; ++i) {
            uint32_t inBatchOffset = i * curSeqLen_ * headNum_ * HEAD_DIM;
            uint32_t outBatchOffset = i * curSeqAlignLen_ * headNum_ * HEAD_DIM;

            if (singlePadSeqLen_ > 0) {
                PadOut(dst[outBatchOffset + curPadSeq_ * HEAD_DIM], singlePadSeqLen_);
            }

            for (uint32_t j = curSeq_; j < curSeq_ + singleSeqLen_; j += blockSeqLen_) {
                uint32_t seqLen =
                    j + blockSeqLen_ > curSeq_ + singleSeqLen_ ?
                    curSeq_ + singleSeqLen_ - j : blockSeqLen_;

                CopyIn(src[inBatchOffset + j * headNum_ * HEAD_DIM], seqLen);

                CopyOut(dst[outBatchOffset + j * HEAD_DIM], seqLen);
            }
        }
    }

    uint32_t blockIdx_;
    uint32_t blockDim_;

    uint32_t batchSize_;
    uint32_t qSeqLen_;
    uint32_t kSeqLen_;
    uint32_t vSeqLen_;

    uint32_t curSeqAlignLen_;

    uint32_t headNum_;
    uint32_t alignLen_;
    uint32_t ubSize_;
    uint32_t blockSeqLen_;

    uint32_t singleSeqLen_;
    uint32_t curSeq_;
    uint32_t curSeqLen_;

    uint32_t singlePadSeqLen_;
    uint32_t curPadSeq_;
    uint32_t curPadSeqLen_;

    uint32_t bufLen_;

    AscendC::TQue<AscendC::TPosition::VECIN, QUEUE_DEPTH> inQue_;
    AscendC::TQue<AscendC::TPosition::VECOUT, QUEUE_DEPTH> outQue_;
    AscendC::TQueBind<AscendC::TPosition::VECIN, AscendC::TPosition::VECOUT, QUEUE_DEPTH> movQueBind_;
    AscendC::TBuf<AscendC::TPosition::VECCALC> zeroBuf_;

    AscendC::LocalTensor<DST> zeroTensor_;

    AscendC::GlobalTensor<T> queryGm_;
    AscendC::GlobalTensor<T> keyGm_;
    AscendC::GlobalTensor<T> valueGm_;
    AscendC::GlobalTensor<DST> outQueryGm_;
    AscendC::GlobalTensor<DST> outKeyGm_;
    AscendC::GlobalTensor<DST> outValueGm_;

    AscendC::TPipe *pipe_ = nullptr;
};

} // namespace kernels

} // namespace mmdit_ops

#endif // __SRC_OPS_KERNEL_LA_PREPROCESS_H__