* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file tensor_equal.h
* \brief Kernel that compares two tensors and writes a one-byte equality flag to global memory.
*/
#ifndef TENSOR_EQUAL_H_
#define TENSOR_EQUAL_H_
#include "kernel_operator.h"
#include "op_kernel/platform_util.h"
namespace TensorEqual {
using namespace AscendC;
// Queue depth of the two input queues and the buffer count passed to InitBuffer for them.
constexpr int64_t DOUBLE_BUFFER = 2;
// Tiling key for which Process() returns without comparing any data.
constexpr int64_t EMPTY_SHAPE_TILINGKEY = 101;
// Tiling key for which Init() seeds the output with DIFF_SHAPE_OUTPUT and Process() returns immediately.
constexpr int64_t DIFF_SHAPE_TILINGKEY = 111;
// Length handed to InitBuffer for resultBuf_ and saveBuf_ (presumably a byte count -- confirm against InitBuffer).
constexpr int64_t OUTPUT_SIZE = 256;
// Value block 0 writes to the output in Init() for every tiling key other than DIFF_SHAPE_TILINGKEY.
constexpr uint8_t NORMAL_OUTPUT = 1;
// Value block 0 writes to the output in Init() when the tiling key is DIFF_SHAPE_TILINGKEY.
constexpr uint8_t DIFF_SHAPE_OUTPUT = 0;
/**
 * Kernel object that compares two global-memory tensors tile by tile and clears a
 * one-byte flag in global memory when a mismatch is found.
 *
 * Integral element types are processed as raw bytes (InputType == uint8_t); every
 * other T is processed as itself.
 *
 * Usage: construct with the tiling data and pipe, call Init() once, then Process().
 */
template <typename T>
class TensorEqualKernel {
    // Element type actually used for the GM/UB tensors and the vector compare.
    using InputType = std::conditional_t<std::is_integral_v<T>, uint8_t, T>;

public:
    // Both arguments are stored by reference, so they must outlive the kernel object.
    // The initialiser list follows member declaration order (pipe_, then tilingData_).
    __aicore__ inline TensorEqualKernel(const TensorEqualTilingData& tilingData, TPipe& pipe)
        : pipe_(pipe), tilingData_(tilingData)
    {}

    // Binds the GM buffers, seeds the output flag (block 0 only) and allocates the UB buffers.
    __aicore__ inline void Init(GM_ADDR input_x, GM_ADDR input_y, GM_ADDR output_z, GM_ADDR workspace);
    // Runs the tile loop for this block and updates the output flag on mismatch.
    __aicore__ inline void Process();

private:
    // Copies one tile of x and y from GM into the input queues.
    __aicore__ inline void CopyIn(int64_t offset, int64_t dataLen);
    // Compares the queued tile and leaves the per-tile result in resultBuf_.
    __aicore__ inline void Compute(int64_t dataLen);

private:
    TQue<QuePosition::VECIN, DOUBLE_BUFFER> inputXQueue_;
    TQue<QuePosition::VECIN, DOUBLE_BUFFER> inputYQueue_;
    TBuf<QuePosition::VECCALC> saveBuf_;    // accumulator over all tiles of this block
    TBuf<QuePosition::VECCALC> resultBuf_;  // result of the most recent Compute()
    GlobalTensor<InputType> inputXGm_;
    GlobalTensor<InputType> inputYGm_;
    GlobalTensor<uint8_t> outputZGm_;
    int64_t blockIdx_ = 0;
    int64_t blockOffset_ = 0;
    // Zero-initialised so no member is ever read with an indeterminate value.
    int64_t totalCoreNum_ = 0;
    int64_t ubFactor_ = 0;
    int64_t bufferSize_ = 0;
    TPipe& pipe_;
    const TensorEqualTilingData& tilingData_;
};
/**
 * Prepares this block for Process(): records the block index, binds the global-memory
 * tensors, lets block 0 write the initial output value, allocates the UB buffers and
 * finally calls SyncAll().
 *
 * @param input_x   GM address of the first input
 * @param input_y   GM address of the second input
 * @param output_z  GM address of the output flag (accessed as uint8_t)
 * @param workspace not referenced in this function
 */
template <typename T>
__aicore__ inline void TensorEqualKernel<T>::Init(GM_ADDR input_x, GM_ADDR input_y, GM_ADDR output_z, GM_ADDR workspace)
{
blockIdx_ = GetBlockIdx();
// Blocks beyond usedCoreNum leave here and therefore skip everything below, including SyncAll().
// NOTE(review): if the launch can use more blocks than usedCoreNum, confirm the remaining
// blocks cannot stall in SyncAll() waiting for the ones that returned here.
if (blockIdx_ >= tilingData_.usedCoreNum) {
return;
}
// Start of this block's slice: (perCoreLoopTimes - 1) * ubFactor + perCoreTailFactor elements
// per preceding block, scaled by sizeof(T) / sizeof(InputType) so it counts InputType elements.
blockOffset_ = blockIdx_ * (sizeof(T) / sizeof(InputType)) *
((tilingData_.perCoreLoopTimes - 1) * tilingData_.ubFactor + tilingData_.perCoreTailFactor);
// Size used for each input queue buffer: ubFactor elements of T.
bufferSize_ = tilingData_.ubFactor * sizeof(T);
inputXGm_.SetGlobalBuffer((__gm__ InputType *)(input_x) + blockOffset_);
inputYGm_.SetGlobalBuffer((__gm__ InputType *)(input_y) + blockOffset_);
outputZGm_.SetGlobalBuffer((__gm__ uint8_t *)(output_z));
// Only block 0 seeds the output: DIFF_SHAPE_OUTPUT for the diff-shape tiling key, NORMAL_OUTPUT otherwise.
if (blockIdx_ == 0) {
uint32_t initOutput = 1;
uint8_t globalInitValue = tilingData_.tilingKey == DIFF_SHAPE_TILINGKEY ? DIFF_SHAPE_OUTPUT : NORMAL_OUTPUT;
InitGlobalMemory(outputZGm_, initOutput, globalInitValue);
// Set and immediately wait on an MTE3_MTE2 event; presumably this orders the write issued by
// InitGlobalMemory before later MTE2 traffic -- confirm against the pipeline model.
auto mteWaitMTE3EventID = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_MTE2));
SetFlag<HardEvent::MTE3_MTE2>(mteWaitMTE3EventID);
WaitFlag<HardEvent::MTE3_MTE2>(mteWaitMTE3EventID);
}
// Two buffers of bufferSize_ per input queue, plus OUTPUT_SIZE each for the result and accumulator buffers.
pipe_.InitBuffer(inputXQueue_, DOUBLE_BUFFER, bufferSize_);
pipe_.InitBuffer(inputYQueue_, DOUBLE_BUFFER, bufferSize_);
pipe_.InitBuffer(resultBuf_, OUTPUT_SIZE);
pipe_.InitBuffer(saveBuf_, OUTPUT_SIZE);
// Presumably ensures no block inspects the output flag before block 0 has initialised it -- TODO confirm.
SyncAll();
}
/**
 * Copies one tile of both inputs from global memory into freshly allocated tensors of
 * the two input queues and enqueues them.
 *
 * @param offset  start of the tile inside this block's slice, counted in elements of T
 * @param dataLen length of the tile, counted in elements of T
 */
template <typename T>
__aicore__ inline void TensorEqualKernel<T>::CopyIn(int64_t offset, int64_t dataLen)
{
    // The GM tensors are typed InputType (uint8_t for integral T), while Process() hands in
    // offsets counted in T elements. Rescale with the same sizeof(T) / sizeof(InputType)
    // factor Init() applies to blockOffset_; without it, tiles of multi-byte integral types
    // overlap and the end of the slice is never read. The factor is 1 for all other types.
    const int64_t gmOffset = offset * static_cast<int64_t>(sizeof(T) / sizeof(InputType));

    // One burst of dataLen * sizeof(T) bytes, no strides, no padding.
    DataCopyExtParams inParams = { 1, static_cast<uint32_t>(dataLen * sizeof(T)), 0, 0, 0 };
    DataCopyPadExtParams<InputType> padParams = { false, 0, 0, 0};

    LocalTensor<InputType> xLocal = inputXQueue_.AllocTensor<InputType>();
    LocalTensor<InputType> yLocal = inputYQueue_.AllocTensor<InputType>();
    DataCopyPad(xLocal, inputXGm_[gmOffset], inParams, padParams);
    DataCopyPad(yLocal, inputYGm_[gmOffset], inParams, padParams);
    inputXQueue_.EnQue(xLocal);
    inputYQueue_.EnQue(yLocal);
}
/**
 * Compares the tile most recently enqueued by CopyIn() and stores a per-tile result at the
 * start of resultBuf_, then releases both input tensors.
 *
 * @param dataLen tile length in elements of T (converted to dataLen * sizeof(T) below)
 */
template <typename T>
__aicore__ inline void TensorEqualKernel<T>::Compute(int64_t dataLen)
{
LocalTensor<InputType> xLocal = inputXQueue_.DeQue<InputType>();
LocalTensor<InputType> yLocal = inputYQueue_.DeQue<InputType>();
__ubuf__ InputType *inputXAddr = (__ubuf__ InputType *)xLocal.GetPhyAddr();
__ubuf__ InputType *inputYAddr = (__ubuf__ InputType *)yLocal.GetPhyAddr();
// Per-iteration stride; presumably the vector register width in bytes -- confirm GetVRegSize().
uint16_t strideVReg = Ops::Base::GetVRegSize();
uint32_t dataLenVf = dataLen * sizeof(T);
// Ceiling division: number of strideVReg-sized steps needed to cover dataLenVf.
uint16_t repeatTimes = (dataLenVf + strideVReg - 1) / strideVReg;
LocalTensor<uint32_t> resultLocal = resultBuf_.Get<uint32_t>();
__ubuf__ uint32_t *resultAddr = (__ubuf__ uint32_t *)resultLocal.GetPhyAddr();
__VEC_SCOPE__
{
AscendC::MicroAPI::RegTensor<InputType> xReg, yReg;
AscendC::MicroAPI::RegTensor<uint16_t> tmpU16Reg, tmpU16RegAdd;
AscendC::MicroAPI::RegTensor<uint32_t> tmpU32Reg;
AscendC::MicroAPI::AddrReg offSetReg;
AscendC::MicroAPI::MaskReg bakMaskRegHigh, bakMaskRegLow, maskReg, cmpMaskReg;
AscendC::MicroAPI::MaskReg allMaskReg = AscendC::MicroAPI::CreateMask<uint8_t, AscendC::MicroAPI::MaskPattern::ALL>();
// Accumulator for the "not equal" masks of all iterations; created with pattern ALLF
// (presumably all lanes cleared -- confirm the MaskPattern definition).
AscendC::MicroAPI::MaskReg bakMaskReg = AscendC::MicroAPI::CreateMask<uint8_t, AscendC::MicroAPI::MaskPattern::ALLF>();
for (uint16_t i = 0; i < repeatTimes; i++) {
// Load the i-th register-sized piece of x and y through an address register built from (i, strideVReg).
offSetReg = AscendC::MicroAPI::CreateAddrReg<uint8_t>(i, strideVReg);
AscendC::MicroAPI::DataCopy(xReg, inputXAddr, offSetReg);
AscendC::MicroAPI::DataCopy(yReg, inputYAddr, offSetReg);
// Mask derived from the remaining count; presumably UpdateMask also decrements dataLenVf -- TODO confirm.
maskReg = AscendC::MicroAPI::UpdateMask<uint8_t>(dataLenVf);
// NE compare under maskReg, then OR the outcome into the running mask.
AscendC::MicroAPI::Compare<InputType, CMPMODE::NE>(cmpMaskReg, xReg, yReg, maskReg);
AscendC::MicroAPI::MaskOr(bakMaskReg, bakMaskReg, cmpMaskReg, allMaskReg);
}
// Turn the accumulated mask into a number: fill a register with 1, ReduceMax it under the
// mask, and store the outcome as uint32_t at resultAddr. The register width used depends on InputType.
// NOTE(review): this relies on ReduceMax yielding 0 when no mask lane is set -- confirm.
if constexpr (std::is_same<InputType, uint8_t>::value) {
// Byte mask: unpack its HIGHEST and LOWEST parts into two masks, reduce each over
// uint16_t lanes, add the two results and widen to uint32_t.
AscendC::MicroAPI::Duplicate(tmpU16Reg, 1);
AscendC::MicroAPI::MaskUnPack<AscendC::MicroAPI::HighLowPart::HIGHEST>(bakMaskRegHigh, bakMaskReg);
AscendC::MicroAPI::MaskUnPack<AscendC::MicroAPI::HighLowPart::LOWEST>(bakMaskRegLow, bakMaskReg);
AscendC::MicroAPI::ReduceMax(tmpU16RegAdd, tmpU16Reg, bakMaskRegHigh);
AscendC::MicroAPI::ReduceMax(tmpU16Reg, tmpU16Reg, bakMaskRegLow);
AscendC::MicroAPI::Add(tmpU16Reg, tmpU16Reg, tmpU16RegAdd, allMaskReg);
AscendC::MicroAPI::UnPack<uint32_t, uint16_t, AscendC::MicroAPI::HighLowPart::LOWEST>(tmpU32Reg, tmpU16Reg);
AscendC::MicroAPI::DataCopy(resultAddr, tmpU32Reg, allMaskReg);
} else if constexpr (std::is_same<InputType, half>::value) {
// half: the mask is applied directly to uint16_t lanes, then widened to uint32_t.
AscendC::MicroAPI::Duplicate(tmpU16Reg, 1);
AscendC::MicroAPI::ReduceMax(tmpU16Reg, tmpU16Reg, bakMaskReg);
AscendC::MicroAPI::UnPack<uint32_t, uint16_t, AscendC::MicroAPI::HighLowPart::LOWEST>(tmpU32Reg, tmpU16Reg);
AscendC::MicroAPI::DataCopy(resultAddr, tmpU32Reg, allMaskReg);
} else {
// Every other InputType: the mask is applied directly to uint32_t lanes.
AscendC::MicroAPI::Duplicate(tmpU32Reg, 1);
AscendC::MicroAPI::ReduceMax(tmpU32Reg, tmpU32Reg, bakMaskReg);
AscendC::MicroAPI::DataCopy(resultAddr, tmpU32Reg, allMaskReg);
}
}
// Hand both input buffers back to their queues for the next CopyIn().
inputXQueue_.FreeTensor(xLocal);
inputYQueue_.FreeTensor(yLocal);
}
/**
 * Runs the comparison for this block: walks all tiles of the block's slice, accumulates
 * the per-tile results of Compute() in saveBuf_, and writes 0 to the output flag when the
 * accumulated value is non-zero and the flag still reads NORMAL_OUTPUT.
 *
 * Returns immediately for blocks outside usedCoreNum and for the diff-shape and
 * empty-shape tiling keys.
 */
template <typename T>
__aicore__ inline void TensorEqualKernel<T>::Process()
{
    if (blockIdx_ >= tilingData_.usedCoreNum) {
        return;
    }
    if (tilingData_.tilingKey == DIFF_SHAPE_TILINGKEY || tilingData_.tilingKey == EMPTY_SHAPE_TILINGKEY) {
        return;
    }

    // Only element 0 of the accumulator is inspected later, so only it is cleared.
    LocalTensor<uint32_t> mismatchAcc = saveBuf_.Get<uint32_t>();
    AscendC::Duplicate<uint32_t>(mismatchAcc, 0, 1);

    // The last used block takes its loop count and tail length from the tail-core fields.
    int64_t loopCount = tilingData_.perCoreLoopTimes;
    int64_t lastTileLen = tilingData_.perCoreTailFactor;
    if (blockIdx_ == tilingData_.usedCoreNum - 1) {
        loopCount = tilingData_.tailCoreLoopTimes;
        lastTileLen = tilingData_.tailCoreTailFactor;
    }

    // All tiles but the last one are full ubFactor-sized tiles.
    const int64_t fullTiles = loopCount - 1;
    for (int64_t tile = 0; tile < fullTiles; ++tile) {
        CopyIn(tile * tilingData_.ubFactor, tilingData_.ubFactor);
        Compute(tilingData_.ubFactor);
        LocalTensor<uint32_t> tileResult = resultBuf_.Get<uint32_t>();
        mismatchAcc = mismatchAcc + tileResult;
    }

    // Final tile of length lastTileLen.
    CopyIn(fullTiles * tilingData_.ubFactor, lastTileLen);
    Compute(lastTileLen);
    LocalTensor<uint32_t> tailResult = resultBuf_.Get<uint32_t>();
    mismatchAcc = mismatchAcc + tailResult;

    // Vector work must finish before the scalar read of the accumulator below.
    auto vectorToScalarEvent = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
    SetFlag<HardEvent::V_S>(vectorToScalarEvent);
    WaitFlag<HardEvent::V_S>(vectorToScalarEvent);

    if (mismatchAcc(0) == 0) {
        return;
    }
    if (outputZGm_.GetValue(0) != NORMAL_OUTPUT) {
        return;
    }
    outputZGm_.SetValue(0, 0);
    AscendC::DataCacheCleanAndInvalid<uint8_t, AscendC::CacheLine::SINGLE_CACHE_LINE, AscendC::DcciDst::CACHELINE_ALL>(outputZGm_);
}
}
#endif