* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file is_close_dag.h
* \brief if features > 0 return gradients, else return *gradients.
*/
#ifndef IS_CLOSE_DAG_H
#define IS_CLOSE_DAG_H
#include "atvoss/util/dag.h"
#include "atvoss/util/vec.h"
#include "atvoss/util/placeholder.h"
#ifndef INFINITY
#define INFINITY (__builtin_inff())
#endif
namespace IsCloseOp {
using namespace Ops::Base;
constexpr int COMPARE_MODE_LT = 0;
constexpr int COMPARE_MODE_E = 2;
constexpr int COMPARE_MODE_LE = 3;
constexpr int COMPARE_MODE_NE = 5;
constexpr int GREATER_SEL_MODE = 2;
constexpr float INF_CONST = INFINITY;
template <int equalNan>
struct NanEqualCompare : public Vec::ElemwiseQuaternaryOP<uint8_t, float, float, float, float> {
__aicore__ inline NanEqualCompare(
LocalTensor<uint8_t>& dst, LocalTensor<float>& src1, LocalTensor<float>& src2, float rtol, float atol,
uint32_t count)
{
#ifdef __CCE_AICORE__
uint32_t dtypeSize = sizeof(float);
constexpr uint64_t VECTOR_REG_WIDTH = 256UL;
uint32_t vl = VECTOR_REG_WIDTH / dtypeSize;
uint32_t loopNum = (count + vl - 1) / vl;
uint32_t vlSize = vl;
__ubuf__ float* src1Addr = (__ubuf__ float*)src1.GetPhyAddr();
__ubuf__ float* src2Addr = (__ubuf__ float*)src2.GetPhyAddr();
__ubuf__ uint8_t* dstAddr = (__ubuf__ uint8_t*)dst.GetPhyAddr();
MicroAPI::RegTensor<float> x1Reg;
MicroAPI::RegTensor<float> x2Reg;
MicroAPI::RegTensor<float> absSubReg;
MicroAPI::RegTensor<float> x2AbsReg;
MicroAPI::RegTensor<float> rtolTmpReg;
MicroAPI::RegTensor<float> rtolReg;
MicroAPI::RegTensor<uint8_t> resultReg;
MicroAPI::RegTensor<float> regTensorInfinity;
MicroAPI::RegTensor<uint8_t> regTensorOne;
MicroAPI::RegTensor<uint8_t> regTensorZero;
MicroAPI::RegTensor<uint32_t> reguint32;
MicroAPI::RegTensor<uint16_t> reguint16;
MicroAPI::RegTensor<uint8_t> reguint8;
MicroAPI::MaskReg mask;
MicroAPI::MaskReg equalCmpMask;
MicroAPI::MaskReg tmpMask1;
MicroAPI::MaskReg tmpMask2;
MicroAPI::MaskReg bothNanMask;
MicroAPI::MaskReg funcCmpMask;
MicroAPI::MaskReg finiteMask;
MicroAPI::MaskReg resultMask;
if constexpr (equalNan == 1) {
__VEC_SCOPE__
{
for (uint16_t loopIdx = 0; loopIdx < loopNum; loopIdx++) {
mask = MicroAPI::UpdateMask<float, MicroAPI::RegTraitNumOne>(count);
MicroAPI::Duplicate(regTensorOne, (uint8_t)1.0, mask);
MicroAPI::Duplicate(regTensorZero, (uint8_t)0.0, mask);
MicroAPI::DataCopy(x1Reg, (__ubuf__ float*)(src1Addr + loopIdx * vlSize));
MicroAPI::DataCopy(x2Reg, (__ubuf__ float*)(src2Addr + loopIdx * vlSize));
MicroAPI::Compare<float, CMPMODE::EQ>(equalCmpMask, x1Reg, x2Reg, mask);
MicroAPI::Compare<float, CMPMODE::NE>(tmpMask1, x1Reg, x1Reg, mask);
MicroAPI::Compare<float, CMPMODE::NE>(tmpMask2, x2Reg, x2Reg, mask);
MicroAPI::MaskAnd(bothNanMask, tmpMask1, tmpMask2, mask);
MicroAPI::MaskOr(tmpMask2, bothNanMask, equalCmpMask, mask);
MicroAPI::FusedAbsSub(absSubReg, x1Reg, x2Reg, mask);
MicroAPI::Abs(x2AbsReg, x2Reg, mask);
MicroAPI::Muls(rtolTmpReg, x2AbsReg, rtol, mask);
MicroAPI::Adds(rtolReg, rtolTmpReg, atol, mask);
MicroAPI::Compare<float, CMPMODE::LE>(funcCmpMask, absSubReg, rtolReg, mask);
MicroAPI::Duplicate(regTensorInfinity, (float)INF_CONST, mask);
MicroAPI::Compare<float, CMPMODE::LT>(finiteMask, x2AbsReg, regTensorInfinity, mask);
MicroAPI::MaskAnd(tmpMask1, funcCmpMask, finiteMask, mask);
MicroAPI::MaskOr(resultMask, tmpMask2, tmpMask1, mask);
MicroAPI::Select(resultReg, regTensorOne, regTensorZero, resultMask);
MicroAPI::Pack(reguint16, (AscendC::MicroAPI::RegTensor<uint32_t>&)resultReg);
MicroAPI::Pack(reguint8, reguint16);
AscendC::MicroAPI::MaskPack<AscendC::MicroAPI::HighLowPart::LOWEST>(tmpMask1, mask);
AscendC::MicroAPI::MaskPack<AscendC::MicroAPI::HighLowPart::LOWEST>(mask, tmpMask1);
MicroAPI::DataCopy((__ubuf__ uint8_t*)(dstAddr + loopIdx * vlSize), reguint8, mask);
}
}
} else {
__VEC_SCOPE__
{
for (uint16_t loopIdx = 0; loopIdx < loopNum; loopIdx++) {
mask = MicroAPI::UpdateMask<float, MicroAPI::RegTraitNumOne>(count);
MicroAPI::Duplicate(regTensorOne, (uint8_t)1.0f, mask);
MicroAPI::Duplicate(regTensorZero, (uint8_t)0.0f, mask);
MicroAPI::DataCopy(x1Reg, (__ubuf__ float*)(src1Addr + loopIdx * vlSize));
MicroAPI::DataCopy(x2Reg, (__ubuf__ float*)(src2Addr + loopIdx * vlSize));
MicroAPI::Compare<float, CMPMODE::EQ>(equalCmpMask, x1Reg, x2Reg, mask);
MicroAPI::FusedAbsSub(absSubReg, x1Reg, x2Reg, mask);
MicroAPI::Abs(x2AbsReg, x2Reg, mask);
MicroAPI::Muls(rtolTmpReg, x2AbsReg, rtol, mask);
MicroAPI::Adds(rtolReg, rtolTmpReg, atol, mask);
MicroAPI::Compare<float, CMPMODE::LE>(funcCmpMask, absSubReg, rtolReg, mask);
MicroAPI::Duplicate(regTensorInfinity, (float)INF_CONST, mask);
MicroAPI::Compare<float, CMPMODE::LT>(finiteMask, x2AbsReg, regTensorInfinity, mask);
MicroAPI::MaskAnd(tmpMask1, funcCmpMask, finiteMask, mask);
MicroAPI::MaskOr(resultMask, equalCmpMask, tmpMask1, mask);
MicroAPI::Select(resultReg, regTensorOne, regTensorZero, resultMask);
MicroAPI::Pack(reguint16, (AscendC::MicroAPI::RegTensor<uint32_t>&)resultReg);
MicroAPI::Pack(reguint8, reguint16);
AscendC::MicroAPI::MaskPack<AscendC::MicroAPI::HighLowPart::LOWEST>(tmpMask1, mask);
AscendC::MicroAPI::MaskPack<AscendC::MicroAPI::HighLowPart::LOWEST>(mask, tmpMask1);
MicroAPI::DataCopy((__ubuf__ uint8_t*)(dstAddr + loopIdx * vlSize), reguint8, mask);
}
}
}
#endif
}
};
template <typename T>
struct IsCloseDag {
using InputX1T = Bind<Vec::CopyInBrc<T>, Placeholder::In0<T>>;
using InputX2T = Bind<Vec::CopyInBrc<T>, Placeholder::In1<T>>;
using InputX1 = Bind<Vec::Cast<float, T, 0>, InputX1T>;
using InputX2 = Bind<Vec::Cast<float, T, 0>, InputX2T>;
using OpResult = Bind<NanEqualCompare<0>, InputX1, InputX2, Placeholder::Var<float, 0>, Placeholder::Var<float, 1>>;
using OpCopyOut = Bind<Vec::CopyOut<uint8_t>, Placeholder::Out0<uint8_t>, OpResult>;
using Outputs = Elems<OpCopyOut>;
using MemCfg = MemOptCfg<MemLevel::LEVEL_2>;
using OpDag = DAGSch<Outputs, void, MemCfg>;
};
template <typename T>
struct IsCloseEqualNanDag {
using InputX1T = Bind<Vec::CopyInBrc<T>, Placeholder::In0<T>>;
using InputX2T = Bind<Vec::CopyInBrc<T>, Placeholder::In1<T>>;
using InputX1 = Bind<Vec::Cast<float, T, 0>, InputX1T>;
using InputX2 = Bind<Vec::Cast<float, T, 0>, InputX2T>;
using OpResult = Bind<NanEqualCompare<1>, InputX1, InputX2, Placeholder::Var<float, 0>, Placeholder::Var<float, 1>>;
using OpCopyOut = Bind<Vec::CopyOut<uint8_t>, Placeholder::Out0<uint8_t>, OpResult>;
using Outputs = Elems<OpCopyOut>;
using MemCfg = MemOptCfg<MemLevel::LEVEL_2>;
using OpDag = DAGSch<Outputs, void, MemCfg>;
};
}
#endif