* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* NOTE: Portions of this code were AI-generated and have been
* technically reviewed for functional accuracy and security
*/
#ifndef APPROXIMATE_EQUAL_KERNEL_H_
#define APPROXIMATE_EQUAL_KERNEL_H_
#include "kernel_operator.h"
#include "kernel_tiling/kernel_tiling.h"
#include "approximate_equal_tiling_data.h"
#include "approximate_equal_tiling_key.h"
namespace NsApproximateEqual {
using AscendC::GlobalTensor;
using AscendC::LocalTensor;
using AscendC::TPipe;
using AscendC::TQue;
using AscendC::TBuf;
using AscendC::QuePosition;
template <typename T, uint32_t TILING_KEY>
class KernelApproximateEqual {
static constexpr int32_t BUFFER_NUM = 2;
static constexpr bool NEEDS_CAST =
std::is_same_v<T, half> || std::is_same_v<T, bfloat16_t>;
public:
__aicore__ inline KernelApproximateEqual() = default;
__aicore__ inline void Init(GM_ADDR x1, GM_ADDR x2, GM_ADDR y,
const ApproximateEqualTilingData* tiling)
{
const int64_t blockIdx = AscendC::GetBlockIdx();
const int64_t remainder = tiling->totalNum - tiling->blockFactor * blockIdx;
blockLength_ = (remainder > tiling->blockFactor) ? tiling->blockFactor : remainder;
if (blockLength_ < 0) { blockLength_ = 0; }
ubLength_ = tiling->ubFactor;
tolerance_ = tiling->tolerance;
const int64_t offset = tiling->blockFactor * blockIdx;
inputGMX1_.SetGlobalBuffer((__gm__ T*)x1 + offset, blockLength_);
inputGMX2_.SetGlobalBuffer((__gm__ T*)x2 + offset, blockLength_);
outputGMY_.SetGlobalBuffer((__gm__ uint8_t*)y + offset, blockLength_);
pipe_.InitBuffer(inputQueueX1_, BUFFER_NUM, ubLength_ * sizeof(T));
pipe_.InitBuffer(inputQueueX2_, BUFFER_NUM, ubLength_ * sizeof(T));
pipe_.InitBuffer(outputQueueY_, BUFFER_NUM, ubLength_ * sizeof(uint8_t));
pipe_.InitBuffer(bitmapBuf_, ubLength_ * sizeof(uint8_t));
pipe_.InitBuffer(onesBuf_, ubLength_ * sizeof(uint8_t));
if constexpr (NEEDS_CAST) {
pipe_.InitBuffer(castBuf1_, ubLength_ * sizeof(float));
pipe_.InitBuffer(castBuf2_, ubLength_ * sizeof(float));
}
LocalTensor<uint8_t> onesLocal = onesBuf_.Get<uint8_t>();
AscendC::Duplicate<uint8_t>(onesLocal, static_cast<uint8_t>(1),
static_cast<int32_t>(ubLength_));
}
__aicore__ inline void Process()
{
if (blockLength_ <= 0) { return; }
const int64_t loopCount = (blockLength_ + ubLength_ - 1) / ubLength_;
for (int64_t i = 0; i < loopCount; ++i) {
const int64_t currentNum = (i == loopCount - 1) ? (blockLength_ - ubLength_ * i) : ubLength_;
CopyIn(i, currentNum);
Compute(currentNum);
CopyOut(i, currentNum);
}
}
private:
__aicore__ inline void CopyIn(int64_t progress, int64_t count)
{
LocalTensor<T> x1Local = inputQueueX1_.template AllocTensor<T>();
LocalTensor<T> x2Local = inputQueueX2_.template AllocTensor<T>();
AscendC::DataCopyParams copyParams;
copyParams.blockCount = 1;
copyParams.blockLen = static_cast<uint16_t>(count * sizeof(T));
copyParams.srcStride = 0;
copyParams.dstStride = 0;
AscendC::DataCopyPad(x1Local, inputGMX1_[progress * ubLength_], copyParams, {false, 0, 0, 0});
AscendC::DataCopyPad(x2Local, inputGMX2_[progress * ubLength_], copyParams, {false, 0, 0, 0});
inputQueueX1_.EnQue(x1Local);
inputQueueX2_.EnQue(x2Local);
}
__aicore__ inline void Compute(int64_t count)
{
constexpr int64_t COMPARE_ALIGN_ELEMS = 256 / static_cast<int64_t>(sizeof(T));
const int64_t cmpCount = ((count + COMPARE_ALIGN_ELEMS - 1) / COMPARE_ALIGN_ELEMS) * COMPARE_ALIGN_ELEMS;
LocalTensor<T> x1Local = inputQueueX1_.template DeQue<T>();
LocalTensor<T> x2Local = inputQueueX2_.template DeQue<T>();
LocalTensor<uint8_t> yLocal = outputQueueY_.template AllocTensor<uint8_t>();
LocalTensor<uint8_t> bitmap = bitmapBuf_.template Get<uint8_t>();
LocalTensor<uint8_t> onesLoc = onesBuf_.template Get<uint8_t>();
if constexpr (std::is_same_v<T, float>) {
ComputeFp32Body(x1Local, x2Local, bitmap, count, cmpCount);
AscendC::Select<uint8_t, uint8_t>(yLocal, bitmap, onesLoc, static_cast<uint8_t>(0),
AscendC::SELMODE::VSEL_TENSOR_SCALAR_MODE,
static_cast<int32_t>(count));
} else {
LocalTensor<float> x1f = castBuf1_.template Get<float>();
LocalTensor<float> x2f = castBuf2_.template Get<float>();
AscendC::Cast<float, T>(x1f, x1Local, AscendC::RoundMode::CAST_NONE,
static_cast<int32_t>(count));
AscendC::Cast<float, T>(x2f, x2Local, AscendC::RoundMode::CAST_NONE,
static_cast<int32_t>(count));
AscendC::Sub<float>(x1f, x1f, x2f, static_cast<int32_t>(count));
AscendC::Abs<float>(x1f, x1f, static_cast<int32_t>(count));
AscendC::CompareScalar<float, uint8_t>(bitmap, x1f, tolerance_,
AscendC::CMPMODE::LT,
static_cast<int32_t>(cmpCount));
AscendC::Select<uint8_t, uint8_t>(yLocal, bitmap, onesLoc, static_cast<uint8_t>(0),
AscendC::SELMODE::VSEL_TENSOR_SCALAR_MODE,
static_cast<int32_t>(count));
}
outputQueueY_.template EnQue<uint8_t>(yLocal);
inputQueueX1_.FreeTensor(x1Local);
inputQueueX2_.FreeTensor(x2Local);
}
__aicore__ inline void ComputeFp32Body(LocalTensor<T>& x1Local,
LocalTensor<T>& x2Local,
LocalTensor<uint8_t>& bitmap,
int64_t count, int64_t cmpCount)
{
AscendC::Sub(x1Local, x1Local, x2Local, count);
AscendC::Abs(x1Local, x1Local, count);
AscendC::CompareScalar<T, uint8_t>(bitmap, x1Local, static_cast<T>(tolerance_),
AscendC::CMPMODE::LT, cmpCount);
}
__aicore__ inline void CopyOut(int64_t progress, int64_t count)
{
LocalTensor<uint8_t> yLocal = outputQueueY_.template DeQue<uint8_t>();
AscendC::DataCopyParams copyParams;
copyParams.blockCount = 1;
copyParams.blockLen = static_cast<uint16_t>(count * sizeof(uint8_t));
copyParams.srcStride = 0;
copyParams.dstStride = 0;
AscendC::DataCopyPad(outputGMY_[progress * ubLength_], yLocal, copyParams);
outputQueueY_.FreeTensor(yLocal);
}
private:
TPipe pipe_;
TQue<QuePosition::VECIN, BUFFER_NUM> inputQueueX1_;
TQue<QuePosition::VECIN, BUFFER_NUM> inputQueueX2_;
TQue<QuePosition::VECOUT, BUFFER_NUM> outputQueueY_;
TBuf<QuePosition::VECCALC> bitmapBuf_;
TBuf<QuePosition::VECCALC> onesBuf_;
TBuf<QuePosition::VECCALC> castBuf1_;
TBuf<QuePosition::VECCALC> castBuf2_;
GlobalTensor<T> inputGMX1_;
GlobalTensor<T> inputGMX2_;
GlobalTensor<uint8_t> outputGMY_;
int64_t blockLength_ = 0;
int64_t ubLength_ = 0;
float tolerance_ = 1e-5f;
};
}
#endif