* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file sign_float_dag.h
* \brief sign_float_dag.h
*/
#ifndef SIGNBIT_FLOAT_DAG_H
#define SIGNBIT_FLOAT_DAG_H
#include "atvoss/util/dag.h"
#include "atvoss/util/vec.h"
#include "atvoss/util/placeholder.h"
#ifdef __CCE_AICORE__
#include "op_kernel/platform_util.h"
#endif
using namespace Ops::Base;
namespace SignbitFloatOp {
constexpr int CAST_NONE_MODE = 0;
const int16_t STATE_BIT_SHF_VALUE = 31;
const int16_t DOUBLE_STATE_BIT_SHF_VALUE = 63;
template <class T>
struct FloatComputeCustom : public Vec::ElemwiseUnaryOP<uint8_t, T> {
__aicore__ inline FloatComputeCustom(LocalTensor<uint8_t>& dst, LocalTensor<T>& src1, uint32_t count)
{
#ifdef __CCE_AICORE__
constexpr uint32_t VECTOR_LENGTH = GetVRegSize();
uint32_t dtypeSize = sizeof(T);
uint32_t vl = VECTOR_LENGTH / dtypeSize;
uint16_t loopNum = CeilDivision(count, vl);
uint32_t vlSize = vl;
__ubuf__ T* src1Addr = (__ubuf__ T*)src1.GetPhyAddr();
__ubuf__ uint8_t* dstAddr = (__ubuf__ uint8_t*)dst.GetPhyAddr();
AscendC::MicroAPI::RegTensor<T, MicroAPI::RegTraitNumOne> vregInput1;
AscendC::MicroAPI::RegTensor<uint32_t, MicroAPI::RegTraitNumOne> vregOutput;
AscendC::MicroAPI::MaskReg mask;
__VEC_SCOPE__
{
for (uint16_t loopIdx = 0; loopIdx < loopNum; loopIdx++) {
mask = AscendC::MicroAPI::UpdateMask<T, MicroAPI::RegTraitNumOne>(count);
AscendC::MicroAPI::DataCopy(vregInput1, (__ubuf__ T*)(src1Addr + loopIdx * vlSize));
AscendC::MicroAPI::ShiftRights<uint32_t, int16_t>(
vregOutput, (MicroAPI::RegTensor<uint32_t>&)vregInput1, STATE_BIT_SHF_VALUE, mask);
AscendC::MicroAPI::DataCopy<uint8_t, MicroAPI::StoreDist::DIST_PACK4_B32>(
dstAddr + loopIdx * vlSize, (MicroAPI::RegTensor<uint8_t>&)vregOutput, mask);
}
}
#endif
}
};
template <class T>
struct DoubleComputeCustom : public Vec::ElemwiseUnaryOP<uint8_t, T> {
__aicore__ inline DoubleComputeCustom(LocalTensor<uint8_t>& dst, LocalTensor<T>& src1, uint32_t count)
{
#ifdef __CCE_AICORE__
constexpr uint32_t VECTOR_LENGTH = GetVRegSize();
uint32_t dtypeSize = sizeof(T);
uint32_t vl = VECTOR_LENGTH / dtypeSize;
uint16_t loopNum = CeilDivision(count, vl);
uint32_t vlSize = vl;
__ubuf__ T* src1Addr = (__ubuf__ T*)src1.GetPhyAddr();
__ubuf__ uint8_t* dstAddr = (__ubuf__ uint8_t*)dst.GetPhyAddr();
AscendC::MicroAPI::RegTensor<uint64_t, MicroAPI::RegTraitNumOne> vregOutput;
AscendC::MicroAPI::RegTensor<T, MicroAPI::RegTraitNumOne> vregInput1;
AscendC::MicroAPI::MaskReg tmpMask;
AscendC::MicroAPI::MaskReg mask;
uint32_t countTmp = count;
MicroAPI::RegTensor<uint32_t, MicroAPI::RegTraitNumOne> tmpReg;
__VEC_SCOPE__
{
for (uint16_t loopIdx = 0; loopIdx < loopNum; loopIdx++) {
mask = AscendC::MicroAPI::UpdateMask<T, MicroAPI::RegTraitNumOne>(count);
AscendC::MicroAPI::DataCopy(vregInput1, (__ubuf__ T*)(src1Addr + loopIdx * vlSize));
AscendC::MicroAPI::ShiftRights<uint64_t, int16_t>(
vregOutput, (MicroAPI::RegTensor<uint64_t>&)vregInput1, DOUBLE_STATE_BIT_SHF_VALUE, mask);
MicroAPI::Pack<uint32_t, uint64_t, MicroAPI::HighLowPart::LOWEST>(tmpReg, vregOutput);
MicroAPI::Pack<MicroAPI::HighLowPart::LOWEST>(tmpMask, mask);
AscendC::MicroAPI::DataCopy<uint8_t, MicroAPI::StoreDist::DIST_PACK4_B32>(
dstAddr + loopIdx * vlSize, (MicroAPI::RegTensor<uint8_t>&)tmpReg, tmpMask);
}
}
#endif
}
};
}
template <class T>
struct SignbitFloatCompute {
using Input0 = Bind<Vec::CopyIn<T>, Placeholder::In0<T>>;
using OpCastIn0 = Bind<Vec::Cast<float, T, SignbitFloatOp::CAST_NONE_MODE>, Input0>;
using OpFloatCompute = Bind<SignbitFloatOp::FloatComputeCustom<float>, OpCastIn0>;
using OpCopyOut = Bind<Vec::CopyOut<uint8_t>, Placeholder::Out0<uint8_t>, OpFloatCompute>;
using Outputs = Elems<OpCopyOut>;
using MemCfg = MemOptCfg<MemLevel::LEVEL_2>;
using OpDag = DAGSch<Outputs, void, MemCfg>;
};
template <class T>
struct SignbitDoubleCompute {
using Input0 = Bind<Vec::CopyIn<T>, Placeholder::In0<T>>;
using OpDoubleCompute = Bind<SignbitFloatOp::DoubleComputeCustom<T>, Input0>;
using OpCopyOut = Bind<Vec::CopyOut<uint8_t>, Placeholder::Out0<uint8_t>, OpDoubleCompute>;
using Outputs = Elems<OpCopyOut>;
using MemCfg = MemOptCfg<MemLevel::LEVEL_2>;
using OpDag = DAGSch<Outputs, void, MemCfg>;
};
#endif