/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
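/*!
 * \file lerp_f16_nddma_with_loops.h
 * \brief Half-precision lerp kernel class (LerpF16NddmaWithLoops): inputs are loaded with
 *        Ops::Base::BroadcastNddmaWithLoop and blended in float32 inside a per-core tile loop.
 */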
#ifndef ASCENDC_LERP_F16_NDDMA_WITH_LOOPS_H_
#define ASCENDC_LERP_F16_NDDMA_WITH_LOOPS_H_
#include "kernel_operator.h"
#include "atvoss/util/broadcast_utils.h"
namespace Lerp {
using AscendC::GlobalTensor;
using AscendC::LocalTensor;
using AscendC::TBuf;
using AscendC::TPipe;
using AscendC::TQue;
using AscendC::MicroAPI::MaskReg;
using AscendC::MicroAPI::RegTensor;
/**
 * Per-core kernel object that produces `y` from `start`, `end` and `weight`, all stored as half.
 *
 * Usage: call Init() once, then Process(). Process() walks the tiles assigned to the current
 * core; for every tile it loads the three inputs through Ops::Base::BroadcastNddmaWithLoop,
 * blends them in float32 (see Compute3 for the exact arithmetic) and writes the half result
 * back to global memory.
 *
 * Stage/queue mapping used throughout the class:
 *   - CopyIn0 / queIn0_ / bufferIn0_ : weight (tiling field input2Strides)
 *   - CopyIn1 / queIn1_ / bufferIn1_ : end    (tiling field input1Strides)
 *   - CopyIn2 / queIn2_ / bufferIn2_ : start  (tiling field input0Strides)
 *   - Compute3, CopyOut4 / queOut0_ / bufferOut0_ : y
 */
class LerpF16NddmaWithLoops {
public:
    __aicore__ inline LerpF16NddmaWithLoops() {}

    /**
     * Binds the global-memory operands and allocates the local queues.
     *
     * @param start          global address of the `start` input, read as half
     * @param end            global address of the `end` input, read as half
     * @param weight         global address of the `weight` input, read as half
     * @param y              global address of the output, written as half
     * @param workspace      not used by this kernel
     * @param tilingDataPtr  tiling description; the pointer is stored and dereferenced later by Process()
     * @param pipePtr        pipe used to allocate the queue buffers; the pointer is stored
     */
    __aicore__ inline void Init(GM_ADDR start, GM_ADDR end, GM_ADDR weight, GM_ADDR y, GM_ADDR workspace,
                                const LerpTilingData* tilingDataPtr, TPipe* pipePtr)
    {
        pipePtr_ = pipePtr;
        tilingDataPtr_ = tilingDataPtr;

        inputGmStart_.SetGlobalBuffer((__gm__ half*)start);
        inputGmEnd_.SetGlobalBuffer((__gm__ half*)end);
        inputGmWeight_.SetGlobalBuffer((__gm__ half*)weight);
        outputGmY_.SetGlobalBuffer((__gm__ half*)y);

        // Every queue gets two buffers, each large enough for elemNum half values.
        constexpr int64_t DOUBLE_BUFFER = 2;
        const int64_t bufferBytes = tilingDataPtr_->elemNum * sizeof(half);
        pipePtr_->InitBuffer(queIn0_, DOUBLE_BUFFER, bufferBytes);
        pipePtr_->InitBuffer(queIn1_, DOUBLE_BUFFER, bufferBytes);
        pipePtr_->InitBuffer(queIn2_, DOUBLE_BUFFER, bufferBytes);
        pipePtr_->InitBuffer(queOut0_, DOUBLE_BUFFER, bufferBytes);
    }

    /**
     * Runs the tile loop of the current core.
     *
     * The last core executes blockTail iterations, every other core blockFormer iterations.
     * The axes indices are initialised from the linear position blockFormer * blockIdx and
     * advanced once per iteration after the first one.
     */
    __aicore__ inline void Process()
    {
        const auto blockIdx = AscendC::GetBlockIdx();
        const bool isLastBlock = (blockIdx == AscendC::GetBlockNum() - 1);
        const int64_t ubLoopNum = isLastBlock ? tilingDataPtr_->blockTail : tilingDataPtr_->blockFormer;

        int64_t axesIndices[Ops::Base::BROADCAST_MAX_DIMS] = {0};
        Ops::Base::BroadcastGetAxesIndices(axesIndices, tilingDataPtr_->blockFormer * blockIdx,
                                           tilingDataPtr_->outputDims, tilingDataPtr_->ubSplitAxis,
                                           tilingDataPtr_->dimProductBeforeUbInner);

        for (int64_t ubLoopIdx = 0; ubLoopIdx < ubLoopNum; ++ubLoopIdx) {
            if (ubLoopIdx != 0) {
                Ops::Base::BroadcastUpdateAxesIndices(axesIndices, tilingDataPtr_->outputDims,
                                                      tilingDataPtr_->ubSplitAxis, tilingDataPtr_->ubOuter);
            }

            // The final chunk along the split axis has ubTail entries, all earlier chunks ubFormer.
            const bool isTailChunk = (axesIndices[tilingDataPtr_->ubSplitAxis] == tilingDataPtr_->ubOuter - 1);
            const int64_t ubSplitSize = isTailChunk ? tilingDataPtr_->ubTail : tilingDataPtr_->ubFormer;

            CopyIn0(ubSplitSize, axesIndices, ubLoopIdx);
            CopyIn1(ubSplitSize, axesIndices, ubLoopIdx);
            CopyIn2(ubSplitSize, axesIndices, ubLoopIdx);
            Compute3(ubSplitSize, axesIndices, ubLoopIdx);
            CopyOut4(ubSplitSize, axesIndices, ubLoopIdx);
        }
    }

private:
    /**
     * Decides whether a copy-in stage has to issue its global-memory load for this iteration.
     *
     * Returns true when the input's stride on the split axis is non-zero, when this is one of
     * the first two iterations of the core, or when
     * (blockIdx * blockFormer + ubLoopIdx) % ubOuter is 0 or 1. The conditions are evaluated in
     * that order, so the modulo is only computed when the first two are false.
     *
     * NOTE(review): the skipped case presumably relies on the queue's two buffers alternating, so
     * that the tensor handed out now still holds what was loaded two iterations ago — confirm
     * against the queue implementation.
     */
    __aicore__ inline bool ShouldIssueLoad(int64_t splitAxisStride, int64_t ubLoopIdx) const
    {
        if (splitAxisStride != 0) {
            return true;
        }
        if (ubLoopIdx <= 1) {
            return true;
        }
        return (AscendC::GetBlockIdx() * tilingDataPtr_->blockFormer + ubLoopIdx) % tilingDataPtr_->ubOuter <= 1;
    }

    /**
     * Copy-in stage for `weight`: allocates a tensor from queIn0_, optionally fills it from
     * inputGmWeight_ (see ShouldIssueLoad) and enqueues it for Compute3.
     */
    __aicore__ inline void CopyIn0(int64_t ubSplitSize, const int64_t (&axesIndices)[Ops::Base::BROADCAST_MAX_DIMS],
                                   int64_t ubLoopIdx)
    {
        bufferIn0_ = queIn0_.AllocTensor<half>();
        if (ShouldIssueLoad(tilingDataPtr_->input2Strides[tilingDataPtr_->ubSplitAxis], ubLoopIdx)) {
            Ops::Base::BroadcastNddmaWithLoop(inputGmWeight_, bufferIn0_, tilingDataPtr_->outputDims,
                                              tilingDataPtr_->outputStrides, tilingDataPtr_->input2Strides,
                                              axesIndices, tilingDataPtr_->ubSplitAxis, tilingDataPtr_->shapeLen,
                                              ubSplitSize, tilingDataPtr_->ubFormer);
        }
        queIn0_.EnQue<half>(bufferIn0_);
    }

    /**
     * Copy-in stage for `end`: allocates a tensor from queIn1_, optionally fills it from
     * inputGmEnd_ (see ShouldIssueLoad) and enqueues it for Compute3.
     */
    __aicore__ inline void CopyIn1(int64_t ubSplitSize, const int64_t (&axesIndices)[Ops::Base::BROADCAST_MAX_DIMS],
                                   int64_t ubLoopIdx)
    {
        bufferIn1_ = queIn1_.AllocTensor<half>();
        if (ShouldIssueLoad(tilingDataPtr_->input1Strides[tilingDataPtr_->ubSplitAxis], ubLoopIdx)) {
            Ops::Base::BroadcastNddmaWithLoop(inputGmEnd_, bufferIn1_, tilingDataPtr_->outputDims,
                                              tilingDataPtr_->outputStrides, tilingDataPtr_->input1Strides,
                                              axesIndices, tilingDataPtr_->ubSplitAxis, tilingDataPtr_->shapeLen,
                                              ubSplitSize, tilingDataPtr_->ubFormer);
        }
        queIn1_.EnQue<half>(bufferIn1_);
    }

    /**
     * Copy-in stage for `start`: allocates a tensor from queIn2_, optionally fills it from
     * inputGmStart_ (see ShouldIssueLoad) and enqueues it for Compute3.
     */
    __aicore__ inline void CopyIn2(int64_t ubSplitSize, const int64_t (&axesIndices)[Ops::Base::BROADCAST_MAX_DIMS],
                                   int64_t ubLoopIdx)
    {
        bufferIn2_ = queIn2_.AllocTensor<half>();
        if (ShouldIssueLoad(tilingDataPtr_->input0Strides[tilingDataPtr_->ubSplitAxis], ubLoopIdx)) {
            Ops::Base::BroadcastNddmaWithLoop(inputGmStart_, bufferIn2_, tilingDataPtr_->outputDims,
                                              tilingDataPtr_->outputStrides, tilingDataPtr_->input0Strides,
                                              axesIndices, tilingDataPtr_->ubSplitAxis, tilingDataPtr_->shapeLen,
                                              ubSplitSize, tilingDataPtr_->ubFormer);
        }
        queIn2_.EnQue<half>(bufferIn2_);
    }

    /**
     * Compute stage: dequeues the three input tiles, blends them and enqueues the output tile.
     *
     * Per element, with w = weight, s = start, e = end (all widened to float):
     *     w < 0.5 ? s + w * (e - s) : e + (w - 1) * (e - s)
     * The result is converted back to half with RoundMode::CAST_RINT.
     * NOTE(review): the formula assumes MulAddDst(dst, a, b) accumulates a * b into dst and that
     * Select takes its first source where the mask is set — confirm against the MicroAPI reference.
     *
     * axesIndices and ubLoopIdx are not used by this stage.
     */
    __aicore__ inline void Compute3(int64_t ubSplitSize, const int64_t (&axesIndices)[Ops::Base::BROADCAST_MAX_DIMS],
                                    int64_t ubLoopIdx)
    {
        bufferIn0_ = queIn0_.DeQue<half>();
        bufferIn1_ = queIn1_.DeQue<half>();
        bufferIn2_ = queIn2_.DeQue<half>();
        bufferOut0_ = queOut0_.AllocTensor<half>();
        __VEC_SCOPE__
        {
            RegTensor<half> vreg0;    // weight as loaded
            RegTensor<float> vreg1;   // weight as float
            RegTensor<half> vreg2;    // end as loaded
            RegTensor<float> vreg3;   // end as float, later the "w >= 0.5" branch result
            RegTensor<half> vreg4;    // start as loaded
            RegTensor<float> vreg5;   // start as float, later the "w < 0.5" branch result
            RegTensor<float> vreg6;   // end - start
            RegTensor<float> vreg7;   // constant 1.0
            RegTensor<float> vreg8;   // weight - 1
            RegTensor<float> vreg9;   // selected branch result
            RegTensor<half> vreg10;   // result converted to half
            MaskReg preg0;            // active-lane mask of the current step
            // Elements still to be processed; handed to UpdateMask on every step.
            uint32_t size = ubSplitSize * tilingDataPtr_->outputStrides[tilingDataPtr_->ubSplitAxis];
            MaskReg preg2;            // lanes where weight < 0.5
            preg0 = AscendC::MicroAPI::CreateMask<float>();
            AscendC::MicroAPI::Duplicate<float, AscendC::MicroAPI::MaskMergeMode::ZEROING, float>(
                vreg7, static_cast<float>(1.0), preg0);
            // Ceil-divide the tile's element count by the per-step element count.
            // NOTE(review): VECTOR_REG_WIDTH / 4 is presumably the number of 4-byte float lanes in
            // one vector register (VECTOR_REG_WIDTH being in bytes) — confirm.
            uint16_t vfLoopNum = (ubSplitSize * tilingDataPtr_->outputStrides[tilingDataPtr_->ubSplitAxis] +
                                  (AscendC::VECTOR_REG_WIDTH / 4) - 1) /
                                 (AscendC::VECTOR_REG_WIDTH / 4);
            __local_mem__ half* bufferIn0Addr = (__local_mem__ half*)bufferIn0_.GetPhyAddr();
            __local_mem__ half* bufferIn1Addr = (__local_mem__ half*)bufferIn1_.GetPhyAddr();
            __local_mem__ half* bufferIn2Addr = (__local_mem__ half*)bufferIn2_.GetPhyAddr();
            __local_mem__ half* bufferOut0Addr = (__local_mem__ half*)bufferOut0_.GetPhyAddr();
            for (uint16_t i = 0; i < vfLoopNum; i++) {
                preg0 = AscendC::MicroAPI::UpdateMask<float>(size);
                // weight -> float, and remember which lanes take the "w < 0.5" branch.
                AscendC::MicroAPI::DataCopy<half, AscendC::MicroAPI::LoadDist::DIST_UNPACK_B16>(
                    vreg0, bufferIn0Addr + i * (AscendC::VECTOR_REG_WIDTH / 4));
                AscendC::MicroAPI::Cast<float, half, castTrait0>(vreg1, vreg0, preg0);
                AscendC::MicroAPI::CompareScalar<float, AscendC::CMPMODE::LT>(preg2, vreg1, static_cast<float>(0.5),
                                                                              preg0);
                // end -> float
                AscendC::MicroAPI::DataCopy<half, AscendC::MicroAPI::LoadDist::DIST_UNPACK_B16>(
                    vreg2, bufferIn1Addr + i * (AscendC::VECTOR_REG_WIDTH / 4));
                AscendC::MicroAPI::Cast<float, half, castTrait0>(vreg3, vreg2, preg0);
                // start -> float
                AscendC::MicroAPI::DataCopy<half, AscendC::MicroAPI::LoadDist::DIST_UNPACK_B16>(
                    vreg4, bufferIn2Addr + i * (AscendC::VECTOR_REG_WIDTH / 4));
                AscendC::MicroAPI::Cast<float, half, castTrait0>(vreg5, vreg4, preg0);
                // vreg6 = end - start
                AscendC::MicroAPI::Sub<float, AscendC::MicroAPI::MaskMergeMode::ZEROING>(vreg6, vreg3, vreg5, preg0);
                // vreg5 = start + weight * (end - start)
                AscendC::MicroAPI::MulAddDst<float, AscendC::MicroAPI::MaskMergeMode::ZEROING>(vreg5, vreg1, vreg6,
                                                                                               preg0);
                // vreg8 = weight - 1
                AscendC::MicroAPI::Sub<float, AscendC::MicroAPI::MaskMergeMode::ZEROING>(vreg8, vreg1, vreg7, preg0);
                // vreg3 = end + (weight - 1) * (end - start)
                AscendC::MicroAPI::MulAddDst<float, AscendC::MicroAPI::MaskMergeMode::ZEROING>(vreg3, vreg8, vreg6,
                                                                                               preg0);
                // Pick the branch per lane, narrow to half and store.
                AscendC::MicroAPI::Select<float>(vreg9, vreg5, vreg3, preg2);
                AscendC::MicroAPI::Cast<half, float, castTrait1>(vreg10, vreg9, preg0);
                AscendC::MicroAPI::DataCopy<half, AscendC::MicroAPI::StoreDist::DIST_PACK_B32>(
                    bufferOut0Addr + i * (AscendC::VECTOR_REG_WIDTH / 4), vreg10, preg0);
            }
        }
        queIn0_.FreeTensor(bufferIn0_);
        queIn1_.FreeTensor(bufferIn1_);
        queIn2_.FreeTensor(bufferIn2_);
        queOut0_.EnQue<half>(bufferOut0_);
    }

    /**
     * Copy-out stage: dequeues the output tile and writes it to outputGmY_ as a single burst of
     * ubSplitSize * outputStrides[ubSplitAxis] half values, at the offset returned by
     * Ops::Base::BroadcastGetGmOffset for the current axes indices.
     *
     * ubLoopIdx is not used by this stage.
     */
    __aicore__ inline void CopyOut4(int64_t ubSplitSize, const int64_t (&axesIndices)[Ops::Base::BROADCAST_MAX_DIMS],
                                    int64_t ubLoopIdx)
    {
        bufferOut0_ = queOut0_.DeQue<half>();

        AscendC::DataCopyExtParams dataCopyExtParams;
        dataCopyExtParams.blockCount = 1;
        // blockLen is assigned a byte count (element count * sizeof(half)).
        dataCopyExtParams.blockLen =
            ubSplitSize * tilingDataPtr_->outputStrides[tilingDataPtr_->ubSplitAxis] * sizeof(half);

        int64_t gmOffset = Ops::Base::BroadcastGetGmOffset(axesIndices, tilingDataPtr_->outputStrides,
                                                           tilingDataPtr_->ubSplitAxis, tilingDataPtr_->ubFormer);
        AscendC::DataCopyPad(outputGmY_[gmOffset], bufferOut0_[0], dataCopyExtParams);

        queOut0_.FreeTensor(bufferOut0_);
    }

    // Set by Init(); null until then.
    TPipe* pipePtr_ = nullptr;
    const LerpTilingData* tilingDataPtr_ = nullptr;

    // Global-memory views of the operands, all addressed as half.
    GlobalTensor<half> inputGmStart_;
    GlobalTensor<half> inputGmEnd_;
    GlobalTensor<half> inputGmWeight_;
    GlobalTensor<half> outputGmY_;

    // Queues connecting the stages: In0 = weight, In1 = end, In2 = start, Out0 = y.
    TQue<AscendC::QuePosition::VECIN, 1> queIn0_;
    TQue<AscendC::QuePosition::VECIN, 1> queIn1_;
    TQue<AscendC::QuePosition::VECIN, 1> queIn2_;
    TQue<AscendC::QuePosition::VECOUT, 1> queOut0_;

    // Tensor currently held by each stage; numbering matches the queues above.
    LocalTensor<half> bufferIn0_;
    LocalTensor<half> bufferIn1_;
    LocalTensor<half> bufferIn2_;
    LocalTensor<half> bufferOut0_;

    // Cast configuration used for the half -> float conversions in Compute3.
    constexpr static AscendC::MicroAPI::CastTrait castTrait0 = {
        AscendC::MicroAPI::RegLayout::ZERO, AscendC::MicroAPI::SatMode::UNKNOWN,
        AscendC::MicroAPI::MaskMergeMode::ZEROING, AscendC::RoundMode::UNKNOWN};
    // Cast configuration used for the float -> half conversion in Compute3 (CAST_RINT rounding, NO_SAT).
    constexpr static AscendC::MicroAPI::CastTrait castTrait1 = {
        AscendC::MicroAPI::RegLayout::ZERO, AscendC::MicroAPI::SatMode::NO_SAT,
        AscendC::MicroAPI::MaskMergeMode::ZEROING, AscendC::RoundMode::CAST_RINT};
};
}
#endif