/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/**
 * NOTE: Portions of this code were AI-generated and have been
 * technically reviewed for functional accuracy and security
 */

#ifndef APPROXIMATE_EQUAL_KERNEL_H_
#define APPROXIMATE_EQUAL_KERNEL_H_

#include "kernel_operator.h"
#include "kernel_tiling/kernel_tiling.h"
#include "approximate_equal_tiling_data.h"
#include "approximate_equal_tiling_key.h"

namespace NsApproximateEqual {

using AscendC::GlobalTensor;
using AscendC::LocalTensor;
using AscendC::TPipe;
using AscendC::TQue;
using AscendC::TBuf;
using AscendC::QuePosition;

template <typename T, uint32_t TILING_KEY>
class KernelApproximateEqual {
    static constexpr int32_t BUFFER_NUM = 2;

    static constexpr bool NEEDS_CAST =
        std::is_same_v<T, half> || std::is_same_v<T, bfloat16_t>;

public:
    __aicore__ inline KernelApproximateEqual() = default;

    __aicore__ inline void Init(GM_ADDR x1, GM_ADDR x2, GM_ADDR y,
                                const ApproximateEqualTilingData* tiling)
    {
        const int64_t blockIdx = AscendC::GetBlockIdx();
        const int64_t remainder = tiling->totalNum - tiling->blockFactor * blockIdx;
        blockLength_ = (remainder > tiling->blockFactor) ? tiling->blockFactor : remainder;
        if (blockLength_ < 0) { blockLength_ = 0; }
        ubLength_ = tiling->ubFactor;
        tolerance_ = tiling->tolerance;

        const int64_t offset = tiling->blockFactor * blockIdx;
        inputGMX1_.SetGlobalBuffer((__gm__ T*)x1 + offset, blockLength_);
        inputGMX2_.SetGlobalBuffer((__gm__ T*)x2 + offset, blockLength_);
        outputGMY_.SetGlobalBuffer((__gm__ uint8_t*)y + offset, blockLength_);

        pipe_.InitBuffer(inputQueueX1_, BUFFER_NUM, ubLength_ * sizeof(T));
        pipe_.InitBuffer(inputQueueX2_, BUFFER_NUM, ubLength_ * sizeof(T));
        pipe_.InitBuffer(outputQueueY_, BUFFER_NUM, ubLength_ * sizeof(uint8_t));
        pipe_.InitBuffer(bitmapBuf_, ubLength_ * sizeof(uint8_t));
        pipe_.InitBuffer(onesBuf_, ubLength_ * sizeof(uint8_t));
        if constexpr (NEEDS_CAST) {
            pipe_.InitBuffer(castBuf1_, ubLength_ * sizeof(float));
            pipe_.InitBuffer(castBuf2_, ubLength_ * sizeof(float));
        }

        LocalTensor<uint8_t> onesLocal = onesBuf_.Get<uint8_t>();
        AscendC::Duplicate<uint8_t>(onesLocal, static_cast<uint8_t>(1),
                                    static_cast<int32_t>(ubLength_));
    }

    __aicore__ inline void Process()
    {
        if (blockLength_ <= 0) { return; }
        const int64_t loopCount = (blockLength_ + ubLength_ - 1) / ubLength_;
        for (int64_t i = 0; i < loopCount; ++i) {
            const int64_t currentNum = (i == loopCount - 1) ? (blockLength_ - ubLength_ * i) : ubLength_;
            CopyIn(i, currentNum);
            Compute(currentNum);
            CopyOut(i, currentNum);
        }
    }

private:
    __aicore__ inline void CopyIn(int64_t progress, int64_t count)
    {
        LocalTensor<T> x1Local = inputQueueX1_.template AllocTensor<T>();
        LocalTensor<T> x2Local = inputQueueX2_.template AllocTensor<T>();
        AscendC::DataCopyParams copyParams;
        copyParams.blockCount = 1;
        copyParams.blockLen = static_cast<uint16_t>(count * sizeof(T));
        copyParams.srcStride = 0;
        copyParams.dstStride = 0;
        AscendC::DataCopyPad(x1Local, inputGMX1_[progress * ubLength_], copyParams, {false, 0, 0, 0});
        AscendC::DataCopyPad(x2Local, inputGMX2_[progress * ubLength_], copyParams, {false, 0, 0, 0});
        inputQueueX1_.EnQue(x1Local);
        inputQueueX2_.EnQue(x2Local);
    }

    __aicore__ inline void Compute(int64_t count)
    {
        // CompareScalar requires count*sizeof(T) to be 256-byte aligned.
        constexpr int64_t COMPARE_ALIGN_ELEMS = 256 / static_cast<int64_t>(sizeof(T));
        const int64_t cmpCount = ((count + COMPARE_ALIGN_ELEMS - 1) / COMPARE_ALIGN_ELEMS) * COMPARE_ALIGN_ELEMS;

        LocalTensor<T> x1Local = inputQueueX1_.template DeQue<T>();
        LocalTensor<T> x2Local = inputQueueX2_.template DeQue<T>();
        LocalTensor<uint8_t> yLocal = outputQueueY_.template AllocTensor<uint8_t>();

        LocalTensor<uint8_t> bitmap  = bitmapBuf_.template Get<uint8_t>();
        LocalTensor<uint8_t> onesLoc = onesBuf_.template Get<uint8_t>();

        if constexpr (std::is_same_v<T, float>) {
            ComputeFp32Body(x1Local, x2Local, bitmap, count, cmpCount);
            AscendC::Select<uint8_t, uint8_t>(yLocal, bitmap, onesLoc, static_cast<uint8_t>(0),
                                              AscendC::SELMODE::VSEL_TENSOR_SCALAR_MODE,
                                              static_cast<int32_t>(count));
        } else {
            LocalTensor<float> x1f = castBuf1_.template Get<float>();
            LocalTensor<float> x2f = castBuf2_.template Get<float>();
            AscendC::Cast<float, T>(x1f, x1Local, AscendC::RoundMode::CAST_NONE,
                                    static_cast<int32_t>(count));
            AscendC::Cast<float, T>(x2f, x2Local, AscendC::RoundMode::CAST_NONE,
                                    static_cast<int32_t>(count));
            AscendC::Sub<float>(x1f, x1f, x2f, static_cast<int32_t>(count));
            AscendC::Abs<float>(x1f, x1f, static_cast<int32_t>(count));
            AscendC::CompareScalar<float, uint8_t>(bitmap, x1f, tolerance_,
                                                   AscendC::CMPMODE::LT,
                                                   static_cast<int32_t>(cmpCount));
            AscendC::Select<uint8_t, uint8_t>(yLocal, bitmap, onesLoc, static_cast<uint8_t>(0),
                                              AscendC::SELMODE::VSEL_TENSOR_SCALAR_MODE,
                                              static_cast<int32_t>(count));
        }

        outputQueueY_.template EnQue<uint8_t>(yLocal);
        inputQueueX1_.FreeTensor(x1Local);
        inputQueueX2_.FreeTensor(x2Local);
    }

    __aicore__ inline void ComputeFp32Body(LocalTensor<T>& x1Local,
                                           LocalTensor<T>& x2Local,
                                           LocalTensor<uint8_t>& bitmap,
                                           int64_t count, int64_t cmpCount)
    {
        AscendC::Sub(x1Local, x1Local, x2Local, count);
        AscendC::Abs(x1Local, x1Local, count);
        AscendC::CompareScalar<T, uint8_t>(bitmap, x1Local, static_cast<T>(tolerance_),
                                           AscendC::CMPMODE::LT, cmpCount);
    }

    __aicore__ inline void CopyOut(int64_t progress, int64_t count)
    {
        LocalTensor<uint8_t> yLocal = outputQueueY_.template DeQue<uint8_t>();
        AscendC::DataCopyParams copyParams;
        copyParams.blockCount = 1;
        copyParams.blockLen = static_cast<uint16_t>(count * sizeof(uint8_t));
        copyParams.srcStride = 0;
        copyParams.dstStride = 0;
        AscendC::DataCopyPad(outputGMY_[progress * ubLength_], yLocal, copyParams);
        outputQueueY_.FreeTensor(yLocal);
    }

private:
    TPipe pipe_;
    TQue<QuePosition::VECIN, BUFFER_NUM> inputQueueX1_;
    TQue<QuePosition::VECIN, BUFFER_NUM> inputQueueX2_;
    TQue<QuePosition::VECOUT, BUFFER_NUM> outputQueueY_;
    TBuf<QuePosition::VECCALC> bitmapBuf_;
    TBuf<QuePosition::VECCALC> onesBuf_;
    TBuf<QuePosition::VECCALC> castBuf1_;
    TBuf<QuePosition::VECCALC> castBuf2_;

    GlobalTensor<T> inputGMX1_;
    GlobalTensor<T> inputGMX2_;
    GlobalTensor<uint8_t> outputGMY_;

    int64_t blockLength_ = 0;
    int64_t ubLength_ = 0;
    float tolerance_ = 1e-5f;
};

}  // namespace NsApproximateEqual

#endif  // APPROXIMATE_EQUAL_KERNEL_H_