* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file matrix_diag_scatter_lower.h
* \brief matrix_diag_scatter
*/
#ifndef MATRIX_DIAG_SCATTER_LOWER_H
#define MATRIX_DIAG_SCATTER_LOWER_H
#include "kernel_operator.h"
#include "matrix_diag_tiling_data_struct.h"
#include "op_kernel/platform_util.h"
namespace MatrixDiagAsc {
using namespace AscendC;
template <typename T, typename U, typename Y>
class MatrixDiagScatterLower
{
public:
__aicore__ inline MatrixDiagScatterLower(){};
__aicore__ inline void Init(GM_ADDR x, GM_ADDR y, const MatrixDiagScatterTilingData* tilingData, TPipe* pipe);
__aicore__ inline void Process();
private:
__aicore__ inline void ProcessPerCore(uint64_t ubLoopNum);
__aicore__ inline void ProcessPerCoreInt8(uint64_t ubLoopNum);
__aicore__ inline void CopyIn(uint64_t i, int64_t totalNum);
__aicore__ inline void CopyOut(uint64_t i, int64_t totalNum);
__aicore__ inline void genIndex();
__aicore__ inline void VFProcess(uint32_t maskNum0, uint32_t processNum, uint32_t totalNum, U addrOffset,
uint16_t loopNum, __local_mem__ U *idxPtr, __local_mem__ T *srcPtr, __local_mem__ T *dstPtr);
__aicore__ inline void VFProcessInt8(uint32_t maskNumInt8, uint32_t processNum, uint32_t totalNum,
U addrOffset, uint16_t loopNum, __local_mem__ U *idxPtr, __local_mem__ T *srcPtr, __local_mem__ T *dstPtr);
__aicore__ inline void MainProcess(uint64_t loop, uint32_t maskNum, uint32_t processNum, uint32_t totalNum, U addrOffset,
uint16_t loopNum, __local_mem__ U *idxPtr);
private:
constexpr static int32_t BUFFER_NUM = 2;
int64_t blockIdx_;
TQue<QuePosition::VECIN, 1> xQueue_;
TQue<QuePosition::VECOUT, 1> yQueue_;
TQue<QuePosition::VECCALC, 1> indexBuf_;
GlobalTensor<T> xGM_;
GlobalTensor<T> yGM_;
const MatrixDiagScatterTilingData* tiling_;
int32_t vlLen_ = Ops::Base::GetVRegSize() / sizeof(T);
int64_t xGmOffset_ = 0;
int64_t yGmOffset_ = 0;
int64_t blockNum_ = 0;
int64_t blockFactor_ = 0;
int64_t nSize_ = 0;
int64_t batchUbFactor_ = 0;
int64_t batchNum_ = 0;
int64_t blockMainCount_ = 0;
int64_t blockTailFactor_ = 0;
int64_t realCoreNum_ = 0;
int64_t batchUbTailFactor_ = 0;
DataCopyExtParams copyInParams_{1, 0, 0, 0, 0};
DataCopyExtParams copyOutParams_{1, 0, 0, 0, 0};
DataCopyPadExtParams<T> padParams_{false, 0, 0, 0};
};
template <typename T, typename U, typename Y>
__aicore__ inline void MatrixDiagScatterLower<T, U, Y>::Init(GM_ADDR x, GM_ADDR y, const MatrixDiagScatterTilingData* tilingData, TPipe* pipe)
{
blockIdx_ = GetBlockIdx();
tiling_ = tilingData;
nSize_ = tiling_->nSize;
batchUbFactor_ = tiling_->batchUbFactor;
blockMainCount_ = tiling_->blockMainCount;
blockTailFactor_ = tiling_->blockTailFactor;
realCoreNum_ = tiling_->realCoreNum;
batchUbTailFactor_ = tiling_->batchUbTailFactor;
blockFactor_ = tiling_->blockFactor;
xGM_.SetGlobalBuffer((__gm__ T*)x);
yGM_.SetGlobalBuffer((__gm__ T*)y);
pipe->InitBuffer(xQueue_, BUFFER_NUM, batchUbFactor_ * nSize_ * sizeof(T));
pipe->InitBuffer(yQueue_, BUFFER_NUM, batchUbFactor_ * nSize_ * nSize_ * sizeof(T));
pipe->InitBuffer(indexBuf_, 1, Ops::Base::GetVRegSize());
}
template <typename T, typename U, typename Y>
__aicore__ inline void MatrixDiagScatterLower<T, U, Y>::genIndex()
{
uint32_t num = vlLen_;
LocalTensor<Y> indexLocal = indexBuf_.AllocTensor<Y>();
__local_mem__ U* dst = (__local_mem__ U*)indexLocal.GetPhyAddr();
__VEC_SCOPE__{
MicroAPI::RegTensor<Y> indexReg;
MicroAPI::RegTensor<U> niReg;
MicroAPI::RegTensor<U> tmp;
MicroAPI::RegTensor<U> tmp1;
MicroAPI::RegTensor<U> tmp2;
MicroAPI::RegTensor<U> subReg;
MicroAPI::RegTensor<U> addReg;
MicroAPI::RegTensor<U> tmp3;
MicroAPI::MaskReg mask;
mask = MicroAPI::UpdateMask<U>(num);
MicroAPI::Arange(indexReg, 0);
MicroAPI::Duplicate(niReg,U(nSize_));
MicroAPI::Div(tmp,(MicroAPI::RegTensor<U> &)indexReg,niReg,mask);
MicroAPI::Muls(tmp1,tmp,(U)(nSize_*nSize_),mask);
MicroAPI::Mul(subReg,tmp,niReg,mask);
MicroAPI::Sub(tmp2,(MicroAPI::RegTensor<U> &)indexReg,subReg,mask);
MicroAPI::Add(addReg,tmp1,tmp2,mask);
MicroAPI::Muls(tmp3,tmp2,U(nSize_),mask);
MicroAPI::Add(addReg,addReg,tmp3,mask);
MicroAPI::DataCopy(dst,addReg,mask);
}
indexBuf_.EnQue(indexLocal);
}
template <typename T, typename U, typename Y>
__aicore__ inline void MatrixDiagScatterLower<T, U, Y>::CopyIn(uint64_t i, int64_t totalNum)
{
copyInParams_.blockCount = 1;
copyInParams_.blockLen = static_cast<uint32_t>(totalNum * sizeof(T));
copyInParams_.srcStride = 0;
copyInParams_.dstStride = 0;
LocalTensor<T> xLocal = xQueue_.AllocTensor<T>();
int64_t curUbInFactor = batchUbFactor_ * nSize_;
int64_t blockOffsetX_ = blockIdx_ < blockMainCount_ ?
blockIdx_* blockFactor_ * curUbInFactor:
blockMainCount_ * blockFactor_ * curUbInFactor + (blockIdx_ - blockMainCount_) * blockTailFactor_ * curUbInFactor;
xGmOffset_ = blockOffsetX_ + i * curUbInFactor;
DataCopyPad(xLocal, xGM_[xGmOffset_], copyInParams_, padParams_);
xQueue_.EnQue(xLocal);
}
template <typename T, typename U, typename Y>
__aicore__ inline void MatrixDiagScatterLower<T, U, Y>::CopyOut(uint64_t i, int64_t totalNum)
{
copyOutParams_.blockCount = 1;
copyOutParams_.blockLen = static_cast<uint32_t>(totalNum * nSize_ * sizeof(T));
copyOutParams_.srcStride = 0;
copyOutParams_.dstStride = 0;
LocalTensor<T> yLocal = yQueue_.DeQue<T>();
int64_t curUbOutFactor = batchUbFactor_ * nSize_ * nSize_;
int64_t blockOffsetY_ = blockIdx_ < blockMainCount_ ?
blockIdx_* blockFactor_ * curUbOutFactor:
blockMainCount_ * blockFactor_ * curUbOutFactor + (blockIdx_ - blockMainCount_) * blockTailFactor_ * curUbOutFactor;
yGmOffset_ = blockOffsetY_ + i * curUbOutFactor;
DataCopyPad(yGM_[yGmOffset_], yLocal, copyOutParams_);
yQueue_.FreeTensor(yLocal);
}
template <typename T, typename U, typename Y>
__aicore__ inline void MatrixDiagScatterLower<T, U, Y>::Process()
{
uint64_t ubLoopNum;
if (blockIdx_ >= realCoreNum_) {
return;
}
if(blockIdx_ < blockMainCount_){
ubLoopNum = blockFactor_;
}
else {
ubLoopNum = blockTailFactor_;
}
genIndex();
if constexpr (sizeof(T) == sizeof(int8_t)) {
ProcessPerCoreInt8(ubLoopNum);
}
else {
ProcessPerCore(ubLoopNum);
}
}
template <typename T, typename U, typename Y>
__aicore__ inline void MatrixDiagScatterLower<T, U, Y>::MainProcess(uint64_t loop, uint32_t maskNum, uint32_t processNum,
uint32_t totalNum, U addrOffset, uint16_t loopNum, __local_mem__ U *idxPtr)
{
CopyIn(loop,totalNum);
LocalTensor<T> xLocal = xQueue_.DeQue<T>();
LocalTensor<T> yLocal = yQueue_.AllocTensor<T>();
Duplicate<T>(yLocal,(T)0,totalNum*nSize_);
auto *srcPtr = (__local_mem__ T *)xLocal.GetPhyAddr();
auto *dstPtr = (__local_mem__ T *)yLocal.GetPhyAddr();
if constexpr (sizeof(T) == sizeof(int8_t)) {
VFProcessInt8(maskNum, processNum, totalNum, addrOffset, loopNum, idxPtr, srcPtr, dstPtr);
}
else {
VFProcess(maskNum, processNum, totalNum, addrOffset, loopNum, idxPtr, srcPtr, dstPtr);
}
xQueue_.FreeTensor(xLocal);
yQueue_.EnQue(yLocal);
CopyOut(loop,totalNum);
}
template <typename T, typename U, typename Y>
__aicore__ inline void MatrixDiagScatterLower<T, U, Y>::ProcessPerCore(uint64_t ubLoopNum)
{
batchNum_ = vlLen_ / nSize_ ;
uint32_t totalNum = batchUbFactor_ * nSize_;
uint32_t processNum = totalNum<=vlLen_ ? totalNum : batchNum_ * nSize_;
uint16_t loopNum = totalNum / processNum;
LocalTensor<Y> indexLocal = indexBuf_.DeQue<Y>();
auto *idxPtr = (__local_mem__ U *)indexLocal.GetPhyAddr();
U addrOffset = processNum * nSize_;
uint32_t maskNum0 = processNum;
for(uint64_t i = 0; i < ubLoopNum; i++){
if (blockIdx_ == realCoreNum_ - 1 && i == ubLoopNum-1) {
totalNum = batchUbTailFactor_ * nSize_;
loopNum = totalNum / processNum;
}
MainProcess(i,maskNum0, processNum, totalNum, addrOffset, loopNum, idxPtr);
}
indexBuf_.FreeTensor(indexLocal);
}
template <typename T, typename U, typename Y>
__aicore__ inline void MatrixDiagScatterLower<T, U, Y>::VFProcess(uint32_t maskNum0, uint32_t processNum, uint32_t totalNum, U addrOffset,
uint16_t loopNum, __local_mem__ U *idxPtr, __local_mem__ T *srcPtr, __local_mem__ T *dstPtr)
{
uint32_t tailNum = totalNum - processNum * loopNum;
uint32_t maskTailNum = tailNum;
__VEC_SCOPE__
{
MicroAPI::RegTensor<T> srcReg;
MicroAPI::RegTensor<T> zeroReg;
MicroAPI::RegTensor<U> indexReg;
MicroAPI::RegTensor<T> lowerReg;
MicroAPI::RegTensor<T> higherReg;
MicroAPI::UnalignReg u0;
MicroAPI::MaskReg mask0;
MicroAPI::MaskReg mask1;
mask0 = MicroAPI::UpdateMask<T>(maskNum0);
mask1 = MicroAPI::UpdateMask<T>(maskTailNum);
MicroAPI::DataCopy(indexReg, idxPtr);
for (uint16_t j = 0; j < loopNum; j++) {
MicroAPI::DataCopyUnAlignPre(u0, srcPtr);
MicroAPI::DataCopyUnAlign<T, MicroAPI::PostLiteral::POST_MODE_UPDATE>(srcReg, u0, srcPtr, processNum );
MicroAPI::DataCopyScatter(dstPtr, srcReg, indexReg, mask0);
MicroAPI::Adds(indexReg, indexReg, addrOffset, mask0);
}
if (tailNum!=0) {
MicroAPI::DataCopyUnAlignPre(u0, srcPtr);
MicroAPI::DataCopyUnAlign<T, MicroAPI::PostLiteral::POST_MODE_UPDATE>(srcReg, u0, srcPtr, processNum );
MicroAPI::DataCopyScatter(dstPtr, srcReg, indexReg, mask1);
}
}
}
template <typename T, typename U, typename Y>
__aicore__ inline void MatrixDiagScatterLower<T, U, Y>::ProcessPerCoreInt8(uint64_t ubLoopNum)
{
uint32_t totalNum = batchUbFactor_ * nSize_;
int32_t halfVlLen = vlLen_/2;
batchNum_ = halfVlLen / nSize_ ;
uint32_t processNum = totalNum <= halfVlLen ? totalNum : batchNum_ * nSize_;
uint16_t loopNum = totalNum / processNum;
LocalTensor<Y> indexLocal = indexBuf_.DeQue<Y>();
auto *idxPtr = (__local_mem__ U *)indexLocal.GetPhyAddr();
uint32_t maskNumInt8 = processNum*2;
U addrOffset = processNum * nSize_;
uint32_t maskNum0 = processNum;
for(uint64_t i = 0; i < ubLoopNum; i++){
if (blockIdx_ == realCoreNum_ - 1 && i == ubLoopNum-1) {
totalNum = batchUbTailFactor_ * nSize_;
loopNum = totalNum / processNum;
}
MainProcess(i, maskNumInt8, processNum, totalNum, addrOffset, loopNum, idxPtr);
}
indexBuf_.FreeTensor(indexLocal);
}
template <typename T, typename U, typename Y>
__aicore__ inline void MatrixDiagScatterLower<T, U, Y>::VFProcessInt8(uint32_t maskNumInt8, uint32_t processNum, uint32_t totalNum,
U addrOffset, uint16_t loopNum, __local_mem__ U *idxPtr, __local_mem__ T *srcPtr, __local_mem__ T *dstPtr)
{
uint32_t tailNum = totalNum - processNum * loopNum;
uint32_t maskTailNum = tailNum*2;
__VEC_SCOPE__
{
MicroAPI::RegTensor<T> srcRegInt8;
MicroAPI::RegTensor<T> zeroRegInt8;
MicroAPI::RegTensor<U> indexRegInt8;
MicroAPI::RegTensor<T> lowerRegInt8;
MicroAPI::RegTensor<T> higherRegInt8;
MicroAPI::UnalignReg u0Int8;
MicroAPI::MaskReg mask0Int8;
MicroAPI::MaskReg mask1Int8;
mask0Int8 = MicroAPI::UpdateMask<T>(maskNumInt8);
mask1Int8 = MicroAPI::UpdateMask<T>(maskTailNum);
MicroAPI::DataCopy(indexRegInt8, idxPtr);
for (uint16_t j = 0; j < loopNum; j++) {
MicroAPI::DataCopyUnAlignPre(u0Int8, srcPtr);
MicroAPI::DataCopyUnAlign<T, MicroAPI::PostLiteral::POST_MODE_UPDATE>(srcRegInt8, u0Int8, srcPtr, processNum );
MicroAPI::Duplicate(zeroRegInt8, (T)0 );
MicroAPI::Interleave(lowerRegInt8, higherRegInt8, srcRegInt8, zeroRegInt8);
MicroAPI::DataCopyScatter(dstPtr, lowerRegInt8, indexRegInt8, mask0Int8);
MicroAPI::Adds(indexRegInt8, indexRegInt8, addrOffset, mask0Int8);
}
if (tailNum!=0) {
MicroAPI::DataCopyUnAlignPre(u0Int8, srcPtr);
MicroAPI::DataCopyUnAlign<T, MicroAPI::PostLiteral::POST_MODE_UPDATE>(srcRegInt8, u0Int8, srcPtr, processNum );
MicroAPI::Duplicate(zeroRegInt8, (U)0 );
MicroAPI::Interleave(lowerRegInt8, higherRegInt8, srcRegInt8, zeroRegInt8);
MicroAPI::DataCopyScatter(dstPtr, lowerRegInt8, indexRegInt8, mask1Int8);
}
}
}
}
#endif