/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
 *
 */
#ifndef _SCATTER_MAX_V1_H_
#define _SCATTER_MAX_V1_H_

#include "kernel_operator.h"

using namespace AscendC;

// Byte granularity used when sizing Compare operands: KernelScatterMaxArgmaxV1 rounds the
// byte length of the compared data up to a multiple of this value to derive _srcMaskNum.
constexpr uint64_t MASK_ALIGN_SIZE = 256;


/**
 * Common state of the scatter-max kernels.
 *
 * The constructor copies the tiling parameters, derives the per-core work split
 * (evenly distributed indices plus the leftover indices) and binds the global tensors.
 * Derived kernels are expected to set _srcLoop themselves.
 */
class KernelScatterMaxBase {
public:
    __aicore__ inline KernelScatterMaxBase() = delete;

    /**
     * @param src         global address of the source values
     * @param idx         global address of the index values
     * @param res         global address of the result tensor
     * @param argmax      global address of the argmax tensor
     * @param tiling_data tiling parameters; read once and cached in the members
     * @param pipe        pipe stored for later local-buffer allocation by derived kernels
     */
    __aicore__ inline KernelScatterMaxBase(
        GM_ADDR src, GM_ADDR idx, GM_ADDR res, GM_ADDR argmax, ScatterMaxTilingDataV1* tiling_data, TPipe* pipe)
        : _pipe(pipe)
    {
        uint64_t blockIdx = GetBlockIdx();
        uint64_t blockNum = GetBlockNum();

        ASSERT(blockNum != 0 && "block dim can not be zero!");

        // Values taken verbatim from the tiling structure.
        _srcElemNum = tiling_data->srcElemNum;
        _idxElemNum = tiling_data->idxElemNum;
        _resElemNum = tiling_data->resElemNum;
        _tailElemNum = tiling_data->tailElemNum;
        _elemNumPerBlock = tiling_data->elemNumPerBlock;
        _idxNumPerCore = tiling_data->idxNumPerCore;
        _idxBatchNum = tiling_data->idxBatchNum;
        _tailBatchNum = tiling_data->tailBatchNum;
        _srcBatchNum = tiling_data->srcBatchNum;
        _coreNumPerTail = tiling_data->coreNumPerTail;
        _leftSrcNumBigCore = tiling_data->leftSrcNumBigCore;
        _leftSrcBigCoreNum = tiling_data->leftSrcBigCoreNum;
        _leftSrcBatchNum = tiling_data->leftSrcBatchNum;

        // Element counts rounded up to a multiple of _elemNumPerBlock, and the row sizes in bytes.
        _tailElemNumAlign = AlignUp(_tailElemNum, _elemNumPerBlock);
        _idxBatchNumAlign = AlignUp(_idxBatchNum, _elemNumPerBlock);
        _srcBatchNumAlign = AlignUp(_srcBatchNum, _elemNumPerBlock);
        _leftSrcBatchNumAlign = AlignUp(_leftSrcBatchNum, _elemNumPerBlock);
        _tailSize = _tailElemNum * sizeof(DTYPE_SRC);
        _tailSizeAlign = _tailElemNumAlign * sizeof(DTYPE_SRC);

        // Evenly distributed part: every core owns _idxNumPerCore consecutive indices.
        _idxBaseOffset = _idxNumPerCore * blockIdx;
        if (_idxNumPerCore == 0) {
            _idxLoop = 0;
        } else {
            _idxLoop = ceilDiv(_idxNumPerCore, _idxBatchNum);
        }

        // Leftover part: the index handled by this core sits after all evenly distributed ones.
        const uint64_t evenIdxTotal = _idxNumPerCore * blockNum;
        _leftSrcIdxPos = (_coreNumPerTail == 0) ? evenIdxTotal : evenIdxTotal + blockIdx / _coreNumPerTail;

        // Start from "no leftover work" and override it for the cores that take part.
        _leftSrcBaseOffset = 0;
        _leftSrcNumCurCore = 0;
        const uint64_t leftIdxNum = _idxElemNum % blockNum;
        // When _coreNumPerTail is 0 the product is 0, so the modulo below is never evaluated.
        if (blockIdx < leftIdxNum * _coreNumPerTail) {
            const uint64_t rankInTail = blockIdx % _coreNumPerTail;
            const uint64_t rowStart = _leftSrcIdxPos * _tailElemNum;
            if (rankInTail < _leftSrcBigCoreNum) {
                // "Big" cores each take _leftSrcNumBigCore elements of the row.
                _leftSrcNumCurCore = _leftSrcNumBigCore;
                _leftSrcBaseOffset = rowStart + rankInTail * _leftSrcNumBigCore;
            } else {
                // Remaining cores take one element less and start after all big cores.
                _leftSrcNumCurCore = _leftSrcNumBigCore - 1;
                _leftSrcBaseOffset = rowStart + _leftSrcBigCoreNum * _leftSrcNumBigCore +
                                     (rankInTail - _leftSrcBigCoreNum) * _leftSrcNumCurCore;
            }
        }
        // ceilDiv already yields 0 for a zero batch size.
        _leftSrcLoop = ceilDiv(_leftSrcNumCurCore, _leftSrcBatchNum);

        _srcGM.SetGlobalBuffer((__gm__ DTYPE_SRC*)src, _srcElemNum);
        _idxGM.SetGlobalBuffer((__gm__ DTYPE_INDEX*)idx, _idxElemNum);
        _resGM.SetGlobalBuffer((__gm__ DTYPE_RES*)res, _resElemNum);
        _argmaxGM.SetGlobalBuffer((__gm__ DTYPE_ARGMAX*)argmax, _resElemNum);
    }

protected:
    /**
     * Division rounded up.
     * @return 0 when b is 0, otherwise (a + b - 1) / b converted to T1
     */
    template<typename T1, typename T2>
    __aicore__ inline T1 ceilDiv(T1 a, T2 b)
    {
        if (b == 0) {
            return 0;
        }
        return (a + b - 1) / b;
    }

protected:
    // Pipe and local buffers shared by all derived kernels.
    TPipe* _pipe;
    TBuf<TPosition::VECCALC> _srcBuf;
    TBuf<TPosition::VECCALC> _idxBuf;

    // Global tensors bound in the constructor.
    GlobalTensor<DTYPE_SRC> _srcGM;
    GlobalTensor<DTYPE_INDEX> _idxGM;
    GlobalTensor<DTYPE_RES> _resGM;
    GlobalTensor<DTYPE_ARGMAX> _argmaxGM;

    // Local tensor views obtained from _srcBuf / _idxBuf by derived kernels.
    LocalTensor<DTYPE_SRC> _srcLocal;
    LocalTensor<DTYPE_INDEX> _idxLocal;

    // Overall element counts and row ("tail") geometry.
    uint64_t _srcElemNum;
    uint64_t _idxElemNum;
    uint64_t _resElemNum;
    uint64_t _tailElemNum;
    uint64_t _tailElemNumAlign;
    uint64_t _tailSize;
    uint64_t _tailSizeAlign;
    uint64_t _elemNumPerBlock;

    // Evenly distributed indices of this core.
    uint64_t _idxNumPerCore;
    uint64_t _idxBatchNum;
    uint64_t _idxBatchNumAlign;
    uint64_t _idxBaseOffset;
    uint64_t _idxLoop;

    // Batch sizes for source data; _srcLoop is assigned by the derived kernels.
    uint64_t _tailBatchNum;
    uint64_t _srcBatchNum;
    uint64_t _srcBatchNumAlign;
    uint64_t _srcLoop;

    // Leftover indices that are split across several cores.
    uint64_t _coreNumPerTail;
    uint64_t _leftSrcNumBigCore;
    uint64_t _leftSrcBigCoreNum;
    uint64_t _leftSrcNumCurCore;
    uint64_t _leftSrcBatchNum;
    uint64_t _leftSrcBatchNumAlign;
    uint64_t _leftSrcIdxPos;
    uint64_t _leftSrcBaseOffset;
    uint64_t _leftSrcLoop;
};

/**
 * Scatter-max kernel that accumulates source rows into the result tensor with atomic max.
 *
 * @tparam smallTail when true, the source buffer holds _tailBatchNum whole rows and rows are
 *                   processed in groups; when false, each row is processed in chunks of
 *                   _srcBatchNum elements.
 */
template<bool smallTail>
class KernelScatterMaxV1 : public KernelScatterMaxBase {
public:
    __aicore__ inline KernelScatterMaxV1() = delete;

    /** Forwards all arguments to the base class and sets the per-row chunk count. */
    __aicore__ inline KernelScatterMaxV1(
        GM_ADDR src, GM_ADDR idx, GM_ADDR res, GM_ADDR argmax, ScatterMaxTilingDataV1* tiling_data, TPipe* pipe)
        : KernelScatterMaxBase(src, idx, res, argmax, tiling_data, pipe)
    {
        if constexpr (smallTail) {
            // Rows are never chunked in the small-tail path.
            _srcLoop = 0;
        } else {
            _srcLoop = ceilDiv(_tailElemNum, _srcBatchNum);
        }
    }

public:
    /**
     * Runs the evenly distributed index batches first, then resets the pipe, re-allocates the
     * buffers and processes this core's share of the leftover index.
     */
    __aicore__ inline void Process()
    {
        initBatchProcessBuffer();
        for (uint64_t i = 0; i < _idxLoop; i++) {
            batchProcess(i);
        }

        _pipe->Reset();
        initLeftSrcBuffer();
        for (uint64_t i = 0; i < _leftSrcLoop; i++) {
            processLeftSrc(i);
        }
    }

private:
    /** Allocates the index buffer and the source buffer for the batched phase. */
    __aicore__ inline void initBatchProcessBuffer()
    {
        _pipe->InitBuffer(_idxBuf, _idxBatchNumAlign * sizeof(DTYPE_INDEX));
        if constexpr (smallTail) {
            _pipe->InitBuffer(_srcBuf, _tailBatchNum * _tailSizeAlign);
        } else {
            _pipe->InitBuffer(_srcBuf, _srcBatchNumAlign * sizeof(DTYPE_SRC));
        }
    }

    /** Allocates the buffers for the leftover phase (one block of indices, one source batch). */
    __aicore__ inline void initLeftSrcBuffer()
    {
        _pipe->InitBuffer(_idxBuf, _elemNumPerBlock * sizeof(DTYPE_INDEX));
        _pipe->InitBuffer(_srcBuf, _leftSrcBatchNumAlign * sizeof(DTYPE_SRC));
    }

    /**
     * Loads the i-th batch of this core's indices and scatters the matching source rows.
     * @param i batch number in [0, _idxLoop)
     */
    __aicore__ inline void batchProcess(uint64_t i)
    {
        uint64_t idxOffset = _idxBaseOffset + i * _idxBatchNum;
        uint64_t idxLoadNum = min(_idxBatchNum, _idxNumPerCore - i * _idxBatchNum);
        uint64_t idxLoadNumAlign = AlignUp(idxLoadNum, _elemNumPerBlock);

        _idxLocal = _idxBuf.Get<DTYPE_INDEX>();
        // NOTE(review): the loaded indices are later read with GetValue; no MTE2->scalar event is
        // set between this copy and those reads — confirm whether one is required.
        DataCopy(_idxLocal, _idxGM[idxOffset], idxLoadNumAlign);

        if constexpr (smallTail) {
            uint64_t tailLoop = ceilDiv(idxLoadNum, _tailBatchNum);
            for (uint64_t k = 0; k < tailLoop; k++) {
                tailWisebatchProcess(k, idxOffset, idxLoadNum);
            }
        } else {
            for (uint64_t k = 0; k < idxLoadNum; k++) {
                elemWiseBatchProcess(k, idxOffset);
            }
        }
    }

    /**
     * Small-tail path: loads up to _tailBatchNum whole rows at once and writes each row to the
     * result row selected by its index, with atomic max enabled.
     * @param k          group number inside the current index batch
     * @param idxOffset  global position of the first index of the current batch
     * @param idxLoadNum number of valid indices in the current batch
     */
    __aicore__ inline void tailWisebatchProcess(uint64_t k, uint64_t idxOffset, uint64_t idxLoadNum)
    {
        uint64_t tailOffset = idxOffset + k * _tailBatchNum;
        uint64_t tailLoadNum = min(_tailBatchNum, idxLoadNum - k * _tailBatchNum);
        DataCopyExtParams copyParams = {static_cast<uint16_t>(tailLoadNum), static_cast<uint32_t>(_tailSize), 0, 0, 0};

        _srcLocal = _srcBuf.Get<DTYPE_SRC>();
        DataCopyPad(_srcLocal, _srcGM[tailOffset * _tailElemNum], copyParams, {0, 0, 0, 0});
        SetFlag<HardEvent::MTE2_MTE3>(EVENT_ID0);
        WaitFlag<HardEvent::MTE2_MTE3>(EVENT_ID0);

        SetAtomicMax<DTYPE_RES>();
        for (uint64_t n = 0; n < tailLoadNum; n++) {
            // Batch-local position of the row's index inside _idxLocal.
            DTYPE_INDEX idxVal = _idxLocal.GetValue(k * _tailBatchNum + n);
            // Each row occupies _tailElemNumAlign elements in the local buffer.
            DataCopyPad(_resGM[idxVal * _tailElemNum], _srcLocal[n * _tailElemNumAlign],
                {1, static_cast<uint32_t>(_tailSize), 0, 0, 0});
        }
        SetAtomicNone();
        SetFlag<HardEvent::MTE3_MTE2>(EVENT_ID0);
        WaitFlag<HardEvent::MTE3_MTE2>(EVENT_ID0);
    }

    /**
     * Large-tail path: scatters one source row in _srcLoop chunks of at most _srcBatchNum
     * elements, each written to the same chunk position of the target result row.
     * @param k         position of the row's index inside the current batch (local to _idxLocal)
     * @param idxOffset global position of the first index of the current batch
     */
    __aicore__ inline void elemWiseBatchProcess(uint64_t k, uint64_t idxOffset)
    {
        // _idxLocal was filled starting at _idxGM[idxOffset], so it is addressed with the
        // batch-local position k; idxOffset + k is only valid as a global position.
        DTYPE_INDEX idxVal = _idxLocal.GetValue(k);
        uint64_t idxPos = idxOffset + k;

        for (uint64_t n = 0; n < _srcLoop; n++) {
            uint64_t chunkOffset = n * _srcBatchNum;
            uint64_t srcOffset = idxPos * _tailElemNum + chunkOffset;
            // The destination must advance with the chunk as well, otherwise every chunk
            // would be merged into the beginning of the result row.
            uint64_t resOffset = idxVal * _tailElemNum + chunkOffset;
            uint64_t srcLoadNum = min(_srcBatchNum, _tailElemNum - chunkOffset);
            uint64_t srcLoadNumAlign = AlignUp(srcLoadNum, _elemNumPerBlock);

            _srcLocal = _srcBuf.Get<DTYPE_SRC>();
            DataCopy(_srcLocal, _srcGM[srcOffset], srcLoadNumAlign);
            SetFlag<HardEvent::MTE2_MTE3>(EVENT_ID0);
            WaitFlag<HardEvent::MTE2_MTE3>(EVENT_ID0);

            SetAtomicMax<DTYPE_RES>();
            DataCopyPad(_resGM[resOffset], _srcLocal,
                {1, static_cast<uint32_t>(srcLoadNum * sizeof(DTYPE_RES)), 0, 0, 0});
            SetAtomicNone();
            SetFlag<HardEvent::MTE3_MTE2>(EVENT_ID0);
            WaitFlag<HardEvent::MTE3_MTE2>(EVENT_ID0);
        }
    }

    /**
     * Leftover phase: scatters the i-th batch of this core's slice of the leftover row.
     * @param i batch number in [0, _leftSrcLoop)
     */
    __aicore__ inline void processLeftSrc(uint64_t i)
    {
        uint64_t srcOffset = _leftSrcBaseOffset + i * _leftSrcBatchNum;
        uint64_t srcLoadNum = min(_leftSrcBatchNum, _leftSrcNumCurCore - i * _leftSrcBatchNum);
        uint64_t srcLoadNumAlign = AlignUp(srcLoadNum, _elemNumPerBlock);

        _idxLocal = _idxBuf.Get<DTYPE_INDEX>();
        _srcLocal = _srcBuf.Get<DTYPE_SRC>();

        // Only element 0 of the loaded index block is used.
        DataCopy(_idxLocal, _idxGM[_leftSrcIdxPos], _elemNumPerBlock);
        DataCopy(_srcLocal, _srcGM[srcOffset], srcLoadNumAlign);
        SetFlag<HardEvent::MTE2_MTE3>(EVENT_ID0);
        WaitFlag<HardEvent::MTE2_MTE3>(EVENT_ID0);

        DTYPE_INDEX idxVal = _idxLocal.GetValue(0);
        // srcOffset % _tailElemNum is the position of this slice inside its row.
        uint64_t resOffset = idxVal * _tailElemNum + srcOffset % _tailElemNum;

        SetAtomicMax<DTYPE_RES>();
        DataCopyPad(_resGM[resOffset], _srcLocal, {1, static_cast<uint32_t>(srcLoadNum * sizeof(DTYPE_RES)), 0, 0, 0});
        SetAtomicNone();
        SetFlag<HardEvent::MTE3_MTE2>(EVENT_ID0);
        WaitFlag<HardEvent::MTE3_MTE2>(EVENT_ID0);
    }
};

/**
 * Scatter-max kernel that fills the argmax tensor: for every source row it compares the
 * source values with the values in the result tensor and merges the source position into
 * the argmax tensor with atomic max.
 *
 * @tparam smallTail when true, the source buffer holds _tailBatchNum whole rows and rows are
 *                   processed in groups; when false, each row is processed in chunks of
 *                   _srcBatchNum elements.
 */
template<bool smallTail>
class KernelScatterMaxArgmaxV1 : public KernelScatterMaxBase {
public:
    __aicore__ inline KernelScatterMaxArgmaxV1() = delete;

    /** Forwards all arguments to the base class and sets the per-row chunk count. */
    __aicore__ inline KernelScatterMaxArgmaxV1(
        GM_ADDR src, GM_ADDR idx, GM_ADDR res, GM_ADDR argmax, ScatterMaxTilingDataV1* tiling_data, TPipe* pipe)
        : KernelScatterMaxBase(src, idx, res, argmax, tiling_data, pipe)
    {
        if constexpr (smallTail) {
            // Rows are never chunked in the small-tail path.
            _srcLoop = 0;
        } else {
            _srcLoop = ceilDiv(_tailElemNum, _srcBatchNum);
        }
    }

public:
    /**
     * Runs the evenly distributed index batches first, then resets the pipe, re-allocates the
     * buffers and processes this core's share of the leftover index.
     */
    __aicore__ inline void Process()
    {
        initBatchProcessBuffer();
        for (uint64_t i = 0; i < _idxLoop; i++) {
            batchProcess(i);
        }

        _pipe->Reset();
        initLeftSrcBuffer();
        for (uint64_t i = 0; i < _leftSrcLoop; i++) {
            processLeftSrc(i);
        }
    }

private:
    /**
     * Allocates source/result/argmax/mask/index buffers for the batched phase and derives
     * _srcMaskNum, the element count passed to Compare.
     */
    __aicore__ inline void initBatchProcessBuffer()
    {
        uint64_t maskBitNum = AscendCUtils::GetBitSize(sizeof(uint8_t));

        if constexpr (smallTail) {
            _pipe->InitBuffer(_srcBuf, _tailBatchNum * _tailSizeAlign);
            _pipe->InitBuffer(_resBuf, _tailSizeAlign);
            _pipe->InitBuffer(_argmaxBuf, _tailSizeAlign);
            _srcMaskNum = AlignUp(_tailSize, MASK_ALIGN_SIZE) / sizeof(DTYPE_SRC);
        } else {
            _pipe->InitBuffer(_srcBuf, _srcBatchNumAlign * sizeof(DTYPE_SRC));
            _pipe->InitBuffer(_resBuf, _srcBatchNumAlign * sizeof(DTYPE_RES));
            _pipe->InitBuffer(_argmaxBuf, _srcBatchNumAlign * sizeof(DTYPE_ARGMAX));
            _srcMaskNum = AlignUp(_srcBatchNum * sizeof(DTYPE_SRC), MASK_ALIGN_SIZE) / sizeof(DTYPE_SRC);
        }

        // One mask bit per compared element.
        uint64_t maskBufSize = ceilDiv(_srcMaskNum, maskBitNum) * sizeof(uint8_t);
        _pipe->InitBuffer(_maskBuf, maskBufSize);
        _pipe->InitBuffer(_idxBuf, _idxBatchNumAlign * sizeof(DTYPE_INDEX));
    }

    /** Allocates the buffers for the leftover phase and recomputes _srcMaskNum for it. */
    __aicore__ inline void initLeftSrcBuffer()
    {
        _pipe->InitBuffer(_idxBuf, _elemNumPerBlock * sizeof(DTYPE_INDEX));
        _pipe->InitBuffer(_srcBuf, _leftSrcBatchNumAlign * sizeof(DTYPE_SRC));
        _pipe->InitBuffer(_resBuf, _leftSrcBatchNumAlign * sizeof(DTYPE_RES));
        _pipe->InitBuffer(_argmaxBuf, _leftSrcBatchNumAlign * sizeof(DTYPE_ARGMAX));

        _srcMaskNum = AlignUp(_leftSrcBatchNumAlign * sizeof(DTYPE_SRC), MASK_ALIGN_SIZE) / sizeof(DTYPE_SRC);
        uint64_t maskBitNum = AscendCUtils::GetBitSize(sizeof(uint8_t));
        uint64_t maskBufSize = ceilDiv(_srcMaskNum, maskBitNum) * sizeof(uint8_t);
        _pipe->InitBuffer(_maskBuf, maskBufSize);
    }

    /**
     * Loads the i-th batch of this core's indices and updates argmax for the matching rows.
     * @param i batch number in [0, _idxLoop)
     */
    __aicore__ inline void batchProcess(uint64_t i)
    {
        uint64_t idxOffset = _idxBaseOffset + i * _idxBatchNum;
        uint64_t idxLoadNum = min(_idxBatchNum, _idxNumPerCore - i * _idxBatchNum);
        uint64_t idxLoadNumAlign = AlignUp(idxLoadNum, _elemNumPerBlock);

        _idxLocal = _idxBuf.Get<DTYPE_INDEX>();
        // NOTE(review): the loaded indices are later read with GetValue; no MTE2->scalar event is
        // set between this copy and those reads — confirm whether one is required.
        DataCopy(_idxLocal, _idxGM[idxOffset], idxLoadNumAlign);

        if constexpr (smallTail) {
            uint64_t tailLoop = ceilDiv(idxLoadNum, _tailBatchNum);
            for (uint64_t k = 0; k < tailLoop; k++) {
                tailWisebatchProcess(k, idxOffset, idxLoadNum);
            }
        } else {
            for (uint64_t k = 0; k < idxLoadNum; k++) {
                elemWiseBatchProcess(k, idxOffset);
            }
        }
    }

    /**
     * Small-tail path: loads up to _tailBatchNum whole source rows, then for each row loads
     * the target result/argmax rows, rebuilds the argmax row and writes it back.
     * @param k          group number inside the current index batch
     * @param idxOffset  global position of the first index of the current batch
     * @param idxLoadNum number of valid indices in the current batch
     */
    __aicore__ inline void tailWisebatchProcess(uint64_t k, uint64_t idxOffset, uint64_t idxLoadNum)
    {
        uint64_t tailOffset = idxOffset + k * _tailBatchNum;
        uint64_t tailLoadNum = min(_tailBatchNum, idxLoadNum - k * _tailBatchNum);
        DataCopyExtParams copyParams = {static_cast<uint16_t>(tailLoadNum), static_cast<uint32_t>(_tailSize), 0, 0, 0};

        _srcLocal = _srcBuf.Get<DTYPE_SRC>();
        _resLocal = _resBuf.Get<DTYPE_RES>();
        _argmaxLocal = _argmaxBuf.Get<DTYPE_ARGMAX>();
        // Same storage viewed as float, used as the working type for Select.
        auto _argmaxFloatLocal = _argmaxLocal.ReinterpretCast<float>();
        _maskLocal = _maskBuf.Get<uint8_t>();

        DataCopyPad(_srcLocal, _srcGM[tailOffset * _tailElemNum], copyParams, {0, 0, 0, 0});

        SetAtomicMax<DTYPE_ARGMAX>();
        for (uint64_t n = 0; n < tailLoadNum; n++) {
            // Batch-local position of the row's index inside _idxLocal.
            DTYPE_INDEX idxVal = _idxLocal.GetValue(k * _tailBatchNum + n);
            uint64_t resOffset = idxVal * _tailElemNum;
            // Global position of the source row; this is the value stored as argmax.
            int64_t srcGlobalPos = tailOffset + n;

            DataCopy(_resLocal, _resGM[resOffset], _tailElemNumAlign);
            DataCopy(_argmaxLocal, _argmaxGM[resOffset], _tailElemNumAlign);
            SetFlag<HardEvent::MTE2_V>(EVENT_ID0);
            WaitFlag<HardEvent::MTE2_V>(EVENT_ID0);

            // Mask marks elements where src differs from res. Presumably Select then keeps the
            // loaded argmax for those and takes srcGlobalPos for the others — confirm against
            // the Select tensor-scalar mode documentation.
            Compare(_maskLocal, _srcLocal[n * _tailElemNumAlign], _resLocal, CMPMODE::NE, _srcMaskNum);
            Cast(_argmaxFloatLocal, _argmaxLocal, RoundMode::CAST_NONE, _tailElemNumAlign);
            Select(_argmaxFloatLocal, _maskLocal, _argmaxFloatLocal, static_cast<float>(srcGlobalPos),
                SELMODE::VSEL_TENSOR_SCALAR_MODE, _tailElemNum);
            Cast(_argmaxLocal, _argmaxFloatLocal, RoundMode::CAST_RINT, _tailElemNumAlign);
            SetFlag<HardEvent::V_MTE3>(EVENT_ID0);
            WaitFlag<HardEvent::V_MTE3>(EVENT_ID0);

            DataCopyPad(_argmaxGM[resOffset], _argmaxLocal, {1, static_cast<uint32_t>(_tailSize), 0, 0, 0});
            SetFlag<HardEvent::MTE3_MTE2>(EVENT_ID0);
            WaitFlag<HardEvent::MTE3_MTE2>(EVENT_ID0);
        }
        SetAtomicNone();
    }

    /**
     * Large-tail path: updates the argmax row of one source row in _srcLoop chunks of at most
     * _srcBatchNum elements.
     * @param k         position of the row's index inside the current batch (local to _idxLocal)
     * @param idxOffset global position of the first index of the current batch
     */
    __aicore__ inline void elemWiseBatchProcess(uint64_t k, uint64_t idxOffset)
    {
        // _idxLocal was filled starting at _idxGM[idxOffset], so it is addressed with the
        // batch-local position k; idxOffset + k is only valid as a global position.
        DTYPE_INDEX idxVal = _idxLocal.GetValue(k);
        // Global position of the source row; this is the value stored as argmax.
        int64_t idxPos = idxOffset + k;

        _srcLocal = _srcBuf.Get<DTYPE_SRC>();
        _resLocal = _resBuf.Get<DTYPE_RES>();
        _argmaxLocal = _argmaxBuf.Get<DTYPE_ARGMAX>();
        // Same storage viewed as float, used as the working type for Select.
        auto _argmaxFloatLocal = _argmaxLocal.ReinterpretCast<float>();
        _maskLocal = _maskBuf.Get<uint8_t>();

        for (uint64_t n = 0; n < _srcLoop; n++) {
            uint64_t resOffset = idxVal * _tailElemNum + n * _srcBatchNum;
            uint64_t srcOffset = idxPos * _tailElemNum + n * _srcBatchNum;
            uint64_t srcLoadNum = min(_srcBatchNum, _tailElemNum - n * _srcBatchNum);
            uint64_t srcLoadNumAlign = AlignUp(srcLoadNum, _elemNumPerBlock);

            DataCopy(_srcLocal, _srcGM[srcOffset], srcLoadNumAlign);
            DataCopy(_resLocal, _resGM[resOffset], srcLoadNumAlign);
            DataCopy(_argmaxLocal, _argmaxGM[resOffset], srcLoadNumAlign);
            SetFlag<HardEvent::MTE2_V>(EVENT_ID0);
            WaitFlag<HardEvent::MTE2_V>(EVENT_ID0);

            Compare(_maskLocal, _srcLocal, _resLocal, CMPMODE::NE, _srcMaskNum);
            Cast(_argmaxFloatLocal, _argmaxLocal, RoundMode::CAST_NONE, srcLoadNumAlign);
            // Only the current chunk is resident in the local buffers, so the element count is
            // the chunk length, not the full row length.
            Select(_argmaxFloatLocal, _maskLocal, _argmaxFloatLocal, static_cast<float>(idxPos),
                SELMODE::VSEL_TENSOR_SCALAR_MODE, srcLoadNum);
            Cast(_argmaxLocal, _argmaxFloatLocal, RoundMode::CAST_RINT, srcLoadNumAlign);
            SetFlag<HardEvent::V_MTE3>(EVENT_ID0);
            WaitFlag<HardEvent::V_MTE3>(EVENT_ID0);

            SetAtomicMax<DTYPE_ARGMAX>();
            DataCopyPad(_argmaxGM[resOffset], _argmaxLocal,
                {1, static_cast<uint32_t>(srcLoadNum * sizeof(DTYPE_ARGMAX)), 0, 0, 0});
            SetAtomicNone();
            SetFlag<HardEvent::MTE3_MTE2>(EVENT_ID0);
            WaitFlag<HardEvent::MTE3_MTE2>(EVENT_ID0);
        }
    }

    /**
     * Leftover phase: updates argmax for the i-th batch of this core's slice of the leftover row.
     * @param i batch number in [0, _leftSrcLoop)
     */
    __aicore__ inline void processLeftSrc(uint64_t i)
    {
        uint64_t srcOffset = _leftSrcBaseOffset + i * _leftSrcBatchNum;
        uint64_t srcLoadNum = min(_leftSrcBatchNum, _leftSrcNumCurCore - i * _leftSrcBatchNum);
        uint64_t srcLoadNumAlign = AlignUp(srcLoadNum, _elemNumPerBlock);
        // Global position of the leftover index; this is the value stored as argmax.
        int64_t idxPos = _leftSrcIdxPos;

        _idxLocal = _idxBuf.Get<DTYPE_INDEX>();
        _srcLocal = _srcBuf.Get<DTYPE_SRC>();
        _resLocal = _resBuf.Get<DTYPE_RES>();
        _argmaxLocal = _argmaxBuf.Get<DTYPE_ARGMAX>();
        auto _argmaxFloatLocal = _argmaxLocal.ReinterpretCast<float>();
        _maskLocal = _maskBuf.Get<uint8_t>();

        // Only element 0 of the loaded index block is used.
        DataCopy(_idxLocal, _idxGM[_leftSrcIdxPos], _elemNumPerBlock);
        DataCopy(_srcLocal, _srcGM[srcOffset], srcLoadNumAlign);

        // NOTE(review): GetValue directly follows the index copy with no event in between —
        // confirm whether an MTE2->scalar synchronisation is required here.
        DTYPE_INDEX idxVal = _idxLocal.GetValue(0);
        // srcOffset % _tailElemNum is the position of this slice inside its row.
        uint64_t resOffset = idxVal * _tailElemNum + srcOffset % _tailElemNum;

        DataCopy(_resLocal, _resGM[resOffset], srcLoadNumAlign);
        DataCopy(_argmaxLocal, _argmaxGM[resOffset], srcLoadNumAlign);
        SetFlag<HardEvent::MTE2_V>(EVENT_ID0);
        WaitFlag<HardEvent::MTE2_V>(EVENT_ID0);

        Compare(_maskLocal, _srcLocal, _resLocal, CMPMODE::NE, _srcMaskNum);
        Cast(_argmaxFloatLocal, _argmaxLocal, RoundMode::CAST_NONE, srcLoadNumAlign);
        Select(_argmaxFloatLocal, _maskLocal, _argmaxFloatLocal, static_cast<float>(idxPos),
            SELMODE::VSEL_TENSOR_SCALAR_MODE, srcLoadNum);
        Cast(_argmaxLocal, _argmaxFloatLocal, RoundMode::CAST_RINT, srcLoadNumAlign);
        SetFlag<HardEvent::V_MTE3>(EVENT_ID0);
        WaitFlag<HardEvent::V_MTE3>(EVENT_ID0);

        SetAtomicMax<DTYPE_ARGMAX>();
        DataCopyPad(
            _argmaxGM[resOffset], _argmaxLocal, {1, static_cast<uint32_t>(srcLoadNum * sizeof(DTYPE_ARGMAX)), 0, 0, 0});
        SetAtomicNone();
        SetFlag<HardEvent::MTE3_MTE2>(EVENT_ID0);
        WaitFlag<HardEvent::MTE3_MTE2>(EVENT_ID0);
    }

private:
    // Local buffers that exist only in the argmax kernel.
    TBuf<TPosition::VECCALC> _resBuf;
    TBuf<TPosition::VECCALC> _argmaxBuf;
    TBuf<TPosition::VECCALC> _maskBuf;

    LocalTensor<DTYPE_RES> _resLocal;
    LocalTensor<DTYPE_ARGMAX> _argmaxLocal;
    LocalTensor<uint8_t> _maskLocal;

    // Element count passed to Compare; derived from a MASK_ALIGN_SIZE-aligned byte length.
    uint64_t _srcMaskNum;
};

#endif