* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file stft_generalized.h
* \brief
*/
#ifndef STFT_GENERALIZED_H
#define STFT_GENERALIZED_H
#include "kernel_operator.h"
#include "lib/matmul_intf.h"
namespace STFTND {
using namespace AscendC;
using namespace matmul;
template <typename T, int32_t bufferNum, const MatmulConfig& MM_CFG>
class STFTGeneralized {
public:
typedef MatmulType<AscendC::TPosition::GM, CubeFormat::ND, T> aType;
typedef MatmulType<AscendC::TPosition::GM, CubeFormat::ND, T, true> bType;
typedef MatmulType<AscendC::TPosition::GM, CubeFormat::ND, T> cType;
typedef MatmulType<AscendC::TPosition::GM, CubeFormat::ND, T> biasType;
Matmul<aType, bType, cType, biasType, MM_CFG> mm0;
Matmul<aType, bType, cType, biasType, MM_CFG> mm1;
Matmul<aType, bType, cType, biasType, MM_CFG> mm2;
Matmul<aType, bType, cType, biasType, MM_CFG> mm3;
Matmul<aType, bType, cType, biasType, MM_CFG>* mm;
__aicore__ inline STFTGeneralized(){};
__aicore__ inline void Init(
GM_ADDR x, GM_ADDR plan, GM_ADDR window, GM_ADDR y, GM_ADDR workspace, STFTGeneralizedTilingData* tilingData,
TPipe* pipeIn)
{
pipe = pipeIn;
tiling = tilingData;
if (tiling->usedCoreNum <= GetBlockIdx()) {
return;
}
uint64_t inputGmSize = (uint64_t)tiling->batch * (((uint64_t)(tiling->inputSize + tiling->nfft) * sizeof(T) + BLOCK_SIZE - 1) /
BLOCK_SIZE * BLOCK_SIZE / sizeof(T));
inputGm.SetGlobalBuffer((__gm__ T*)x, inputGmSize);
uint64_t splitWindowGmSize = (uint64_t)tiling->batch * tiling->matmulN * tiling->nfftAlign;
splitWindowGm.SetGlobalBuffer((__gm__ T*)workspace, splitWindowGmSize);
uint64_t splitWindowWorkspaceSize = (splitWindowGmSize * sizeof(T) + WORKSPACE_ALIGN_SIZE - 1) /
WORKSPACE_ALIGN_SIZE * WORKSPACE_ALIGN_SIZE / sizeof(T);
uint64_t realGmSize = (uint64_t)tiling->batch * tiling->matmulM * tiling->matmulN;
uint64_t imagGmSize = realGmSize;
uint64_t outputGmSize = IMAG_AND_REAL * realGmSize;
matmulOutputGm.SetGlobalBuffer((__gm__ T*)workspace + splitWindowWorkspaceSize, outputGmSize);
outputGm.SetGlobalBuffer((__gm__ T*)y, outputGmSize);
uint64_t realWindowSize = (uint64_t)tiling->matmulM * tiling->nfftAlign;
uint64_t imagWindowSize = realWindowSize;
if (window == nullptr) {
planGm.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(plan), realWindowSize * IMAG_AND_REAL);
} else {
uint64_t matmulWorkspaceSize =
(((uint64_t)tiling->batch * tiling->matmulN * tiling->matmulM * sizeof(T) * IMAG_AND_REAL +
WORKSPACE_ALIGN_SIZE - 1) /
WORKSPACE_ALIGN_SIZE) *
WORKSPACE_ALIGN_SIZE / sizeof(T);
uint64_t planOffset = splitWindowWorkspaceSize + matmulWorkspaceSize;
planGm.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(workspace) + planOffset, realWindowSize * IMAG_AND_REAL);
}
bufferSize = tiling->copyUBSize / IMAG_AND_REAL / bufferNum / BLOCK_SIZE * BLOCK_SIZE;
if (tiling->nfft == tiling->nfftAlign && (tiling->hopLength * sizeof(T)) % BLOCK_SIZE == 0 &&
tiling->nfftAlign * sizeof(T) <= bufferSize && tiling->hopLength * sizeof(T) <= bufferSize) {
pipe->InitBuffer(outCopy, bufferNum, bufferSize);
pipe->InitBuffer(inCopy, bufferNum, bufferSize);
isAligned = true;
} else {
if (tiling->splitWindowSkipNum > 0) {
if (tiling->splitWindowSkipNum * tiling->hopLength >= tiling->nfftAlign) {
uint32_t oneBufferSize = (tiling->copyUBSize / 2) / BLOCK_SIZE * BLOCK_SIZE;
pipe->InitBuffer(inCopy, bufferNum, oneBufferSize);
pipe->InitBuffer(outCopy, bufferNum, oneBufferSize);
} else {
uint32_t skipHopLength = tiling->splitWindowSkipNum * tiling->hopLength;
uint32_t onceNInUB =
(tiling->copyUBSize / sizeof(T) - tiling->nfftAlign) / (skipHopLength + tiling->nfftAlign);
uint32_t inputSize = (onceNInUB * skipHopLength + tiling->nfftAlign) * sizeof(T);
uint32_t outputSize = onceNInUB * tiling->nfftAlign * sizeof(T);
pipe->InitBuffer(inCopy, bufferNum, inputSize);
pipe->InitBuffer(outCopy, bufferNum, outputSize);
}
} else {
bufferSize *= DOUBLE_SIZE;
pipe->InitBuffer(inCopy, bufferNum, bufferSize);
isAligned = false;
}
}
maskCount = tiling->maskUBSize / BLOCK_SIZE * BLOCK_SIZE / sizeof(int32_t);
pipe->InitBuffer(gatherIn, GATHER_BUFFER_NUM, maskCount * sizeof(T));
pipe->InitBuffer(gatherOut, GATHER_BUFFER_NUM, maskCount * sizeof(T));
pipe->InitBuffer(maskUB, tiling->maskUBSize);
pipe->InitBuffer(tempUB, tiling->maskUBSize);
if (tiling->normalized) {
SetRootNfft();
}
}
__aicore__ inline void Process()
{
SyncAll();
auto blockIdx = GetBlockIdx();
if (tiling->usedCoreNum <= blockIdx) {
return;
}
uint32_t nIdx = blockIdx % tiling->matmulNCoreNum;
uint32_t nFactor = tiling->matmulNCoreFactor;
uint32_t nOffset = nIdx * nFactor;
uint32_t mIdx = (blockIdx / tiling->matmulNCoreNum) % tiling->matmulMCoreNum;
uint32_t mFactor = tiling->matmulMCoreFactor;
uint32_t mOffset = mIdx * mFactor;
uint32_t bIdx = (blockIdx / tiling->matmulNCoreNum / tiling->matmulMCoreNum) % tiling->batchCoreNum;
uint32_t bFactor = tiling->batchCoreFactor;
uint32_t bOffset = bIdx * bFactor;
bool isTailM = false;
bool isTailN = false;
if (nIdx >= tiling->matmulNCoreNum - tiling->matmulNTailCoreNum) {
nOffset = (tiling->matmulNCoreNum - tiling->matmulNTailCoreNum) * nFactor;
nFactor = tiling->matmulN % nFactor;
isTailN = true;
}
if (mIdx >= tiling->matmulMCoreNum - tiling->matmulMTailCoreNum) {
mOffset = (tiling->matmulMCoreNum - tiling->matmulMTailCoreNum) * mFactor +
(mIdx + tiling->matmulMTailCoreNum - tiling->matmulMCoreNum) * (mFactor - 1);
mFactor = mFactor - 1;
isTailM = true;
}
if (bIdx >= tiling->batchCoreNum - tiling->batchTailCoreNum) {
bOffset = (tiling->batchCoreNum - tiling->batchTailCoreNum) * bFactor +
(bIdx + tiling->batchTailCoreNum - tiling->batchCoreNum) * (bFactor - 1);
bFactor = bFactor - 1;
}
if (!isTailM) {
mm = !isTailN ? &mm0 : &mm1;
} else {
mm = !isTailN ? &mm2 : &mm3;
}
for (int i = 0; i < bFactor; i++) {
int64_t inputOffset = (int64_t)(bOffset + i) * (tiling->inputSize + tiling->nfft) +
(int64_t)nOffset * tiling->hopLength;
int64_t splitWindowOffset = ((int64_t)(bOffset + i) * tiling->matmulN + nOffset) * tiling->nfftAlign;
int64_t outputOffset =
(((int64_t)(bOffset + i) * tiling->matmulM + mOffset) * tiling->matmulN + nOffset) * IMAG_AND_REAL;
int64_t realOffset = (int64_t)(bOffset + i) * tiling->matmulM * tiling->matmulN +
(int64_t)mOffset * tiling->matmulN +
(int64_t)nIdx * mFactor * tiling->matmulNCoreFactor;
int64_t imagOffset = realOffset;
int64_t a1Offset = (int64_t)mOffset * tiling->nfftAlign;
int64_t a2Offset = a1Offset;
int64_t planOffset = IMAG_AND_REAL * mOffset * tiling->nfftAlign;
int64_t matmulOutputOffset = IMAG_AND_REAL * realOffset;
SplitWindows(inputOffset, splitWindowOffset, nFactor);
if (i == 0) {
GenerateGatherMask(nFactor);
}
StftMatmul(splitWindowOffset, planOffset, matmulOutputOffset);
GatherRealAndImag(matmulOutputOffset, outputOffset, mFactor, nFactor);
}
}
private:
__aicore__ inline void GenerateGatherMask(uint32_t nFactor)
{
uint32_t alignNum = BLOCK_SIZE / sizeof(T);
uint32_t nAlign = (nFactor + alignNum - 1) / alignNum * alignNum;
if (nAlign * IMAG_AND_REAL < maskCount) {
GenerateGatherMaskSmallN(nFactor);
} else {
GenerateGatherMaskLargeN();
}
}
__aicore__ inline void GenerateGatherMaskLargeN()
{
LocalTensor<int32_t> maskTemp = maskUB.template Get<int32_t>(maskCount);
uint32_t alignNum = REPEAT_SIZE / sizeof(T);
uint32_t nAlignInUB = (maskCount / IMAG_AND_REAL) / alignNum * alignNum;
uint32_t maskTotalNum = nAlignInUB * IMAG_AND_REAL;
ArithProgression<int32_t>(maskTemp, (int32_t)0, (int32_t)2, maskTotalNum);
AscendC::PipeBarrier<PIPE_V>();
uint64_t maskBit1[2] = {0xAAAAAAAAAAAAAAAA, 0xAAAAAAAAAAAAAAAA};
if (sizeof(T) == DWORD_SIZE) {
maskBit1[1] = 0;
}
UnaryRepeatParams repeatParams;
repeatParams.dstBlkStride = 1;
repeatParams.srcBlkStride = 1;
repeatParams.dstRepStride = BLOCK_NUM;
repeatParams.srcRepStride = BLOCK_NUM;
Muls(maskTemp, maskTemp, 0, maskBit1, maskTotalNum * sizeof(int32_t) / REPEAT_SIZE, repeatParams);
LocalTensor<int32_t> temp = tempUB.template Get<int32_t>(maskCount);
event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
SetFlag<HardEvent::V_S>(eventIdVToS);
WaitFlag<HardEvent::V_S>(eventIdVToS);
int32_t START = -2;
int32_t STEP = 2;
ArithProgression<int32_t>(temp, START, STEP, maskTotalNum);
AscendC::PipeBarrier<PIPE_V>();
int32_t offset = nAlignInUB * sizeof(T);
Adds(temp, temp, offset, maskTotalNum);
AscendC::PipeBarrier<PIPE_V>();
uint64_t maskBit2[2] = {0x5555555555555555, 0x5555555555555555};
if (sizeof(T) == DWORD_SIZE) {
maskBit2[1] = 0;
}
Muls(temp, temp, 0, maskBit2, maskTotalNum * sizeof(int32_t) / REPEAT_SIZE, repeatParams);
AscendC::PipeBarrier<PIPE_V>();
Add(maskTemp, maskTemp, temp, maskTotalNum);
AscendC::PipeBarrier<PIPE_V>();
mask = maskTemp.ReinterpretCast<uint32_t>();
}
__aicore__ inline void GenerateGatherMaskSmallN(uint32_t nFactor)
{
LocalTensor<int32_t> maskTemp = maskUB.template Get<int32_t>(maskCount);
uint32_t alignNum = BLOCK_SIZE / sizeof(T);
uint32_t nAlign = (nFactor + alignNum - 1) / alignNum * alignNum;
uint32_t onceMInUB = maskCount / nAlign / IMAG_AND_REAL;
int32_t offset = nAlign * sizeof(T);
uint32_t blocksForRealAndImg = nAlign * IMAG_AND_REAL / alignNum;
for (int32_t i = 0; i < alignNum / IMAG_AND_REAL; i++) {
maskTemp.SetValue(i * IMAG_AND_REAL, i * sizeof(T));
maskTemp.SetValue(i * IMAG_AND_REAL + 1, i * sizeof(T) + offset);
}
event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
SetFlag<HardEvent::S_V>(eventIdSToV);
WaitFlag<HardEvent::S_V>(eventIdSToV);
uint32_t i = 1;
for (; i <= (blocksForRealAndImg / BINARY_ADD_COEF); i *= IMAG_AND_REAL) {
uint32_t maskAddrOffset = i * alignNum;
int32_t maskValOffset = i * (alignNum / IMAG_AND_REAL) * sizeof(T);
Adds(maskTemp[maskAddrOffset], maskTemp, maskValOffset, maskAddrOffset);
AscendC::PipeBarrier<PIPE_V>();
}
if (i > (blocksForRealAndImg / BINARY_ADD_COEF) && i < blocksForRealAndImg) {
int32_t lastMaskAddrOffset = i * alignNum;
uint32_t lastMaskCalcNum = (nAlign * IMAG_AND_REAL) - lastMaskAddrOffset;
int32_t lastMaskValOffset = i * (alignNum / IMAG_AND_REAL) * sizeof(T);
Adds(maskTemp[lastMaskAddrOffset], maskTemp, lastMaskValOffset, lastMaskCalcNum);
AscendC::PipeBarrier<PIPE_V>();
}
i = 1;
for (; i <= (onceMInUB / BINARY_ADD_COEF); i *= IMAG_AND_REAL) {
uint32_t maskAddrOffset = i * nAlign * IMAG_AND_REAL;
int32_t maskValOffset = maskAddrOffset * sizeof(T);
Adds(maskTemp[maskAddrOffset], maskTemp, maskValOffset, maskAddrOffset);
AscendC::PipeBarrier<PIPE_V>();
}
if (i > (onceMInUB / BINARY_ADD_COEF) && i < onceMInUB) {
uint32_t lastMaskAddrOffset = i * nAlign * IMAG_AND_REAL;
int32_t lastMaskValOffset = lastMaskAddrOffset * sizeof(T);
uint32_t lastMaskCalcNum = (onceMInUB - i) * nAlign * IMAG_AND_REAL;
Adds(maskTemp[lastMaskAddrOffset], maskTemp, lastMaskValOffset, lastMaskCalcNum);
AscendC::PipeBarrier<PIPE_V>();
}
mask = maskTemp.ReinterpretCast<uint32_t>();
}
__aicore__ inline void GatherRealAndImag(
int64_t matmulOffset, int64_t outputOffset, uint32_t mFactor, uint32_t nFactor)
{
uint32_t alignNum = BLOCK_SIZE / sizeof(T);
uint32_t nAlign = (nFactor + alignNum - 1) / alignNum * alignNum;
if (nAlign * IMAG_AND_REAL < maskCount) {
GatherRealAndImagSmallN(matmulOffset, outputOffset, mFactor, nFactor);
} else {
GatherRealAndImagLargeN(matmulOffset, outputOffset, mFactor, nFactor);
}
}
__aicore__ inline void SetRootNfft()
{
if (sizeof(T) == DWORD_SIZE) {
rootNfft = tiling->rootNfft;
} else {
LocalTensor<float> src = tempUB.template Get<float>(BLOCK_SIZE / sizeof(T));
LocalTensor<T> dst = tempUB.template Get<T>(BLOCK_SIZE / sizeof(T));
src.SetValue(0, tiling->rootNfft);
event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
SetFlag<HardEvent::S_V>(eventIdSToV);
WaitFlag<HardEvent::S_V>(eventIdSToV);
Cast(dst, src, RoundMode::CAST_ROUND, 1);
event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
SetFlag<HardEvent::V_S>(eventIdVToS);
WaitFlag<HardEvent::V_S>(eventIdVToS);
rootNfft = dst.GetValue(0);
}
}
__aicore__ inline void GatherRealAndImagSmallN(
int64_t matmulOffset, int64_t outputOffset, uint32_t mFactor, uint32_t nFactor)
{
uint32_t alignNum = BLOCK_SIZE / sizeof(T);
uint32_t nAlign = (nFactor + alignNum - 1) / alignNum * alignNum;
uint32_t onceM = maskCount / nAlign / IMAG_AND_REAL;
uint32_t mQuotient = mFactor / onceM;
uint32_t mRemainder = mFactor % onceM;
int64_t copyInGMOffsetCoef = onceM * nFactor * IMAG_AND_REAL;
int64_t copyOutGMOffsetCoef = onceM * tiling->matmulN * IMAG_AND_REAL;
DataCopyPadExtParams<T> dataCopyPadExtParams;
dataCopyPadExtParams.isPad = false;
dataCopyPadExtParams.leftPadding = 0;
dataCopyPadExtParams.rightPadding = 0;
dataCopyPadExtParams.paddingValue = 0;
DataCopyExtParams copyInParams;
copyInParams.blockCount = onceM * IMAG_AND_REAL;
copyInParams.blockLen = nFactor * sizeof(T);
copyInParams.srcStride = 0;
copyInParams.dstStride = 0;
DataCopyExtParams copyOutParams;
copyOutParams.blockCount = onceM;
copyOutParams.blockLen = nFactor * IMAG_AND_REAL * sizeof(T);
copyOutParams.srcStride = (IMAG_AND_REAL * (nAlign - nFactor)) / alignNum;
copyOutParams.dstStride = IMAG_AND_REAL * (tiling->matmulN - nFactor) * sizeof(T);
for (uint32_t i = 0; i < mQuotient; i++) {
LocalTensor<T> inputLocal = gatherIn.template AllocTensor<T>();
DataCopyPad(
inputLocal, matmulOutputGm[matmulOffset + i * copyInGMOffsetCoef], copyInParams, dataCopyPadExtParams);
gatherIn.EnQue(inputLocal);
inputLocal = gatherIn.template DeQue<T>();
LocalTensor<T> outputLocal = gatherOut.template AllocTensor<T>();
Gather(outputLocal, inputLocal, mask, 0, onceM * nAlign * IMAG_AND_REAL);
if (tiling->normalized) {
AscendC::PipeBarrier<PIPE_V>();
Muls(outputLocal, outputLocal, rootNfft, onceM * nAlign * IMAG_AND_REAL);
}
gatherOut.EnQue(outputLocal);
gatherIn.FreeTensor(inputLocal);
outputLocal = gatherOut.template DeQue<T>();
DataCopyPad(outputGm[outputOffset + i * copyOutGMOffsetCoef], outputLocal, copyOutParams);
gatherOut.FreeTensor(outputLocal);
}
if (mRemainder > 0) {
DataCopyExtParams copyInParamsRemainder;
copyInParamsRemainder.blockCount = mRemainder * IMAG_AND_REAL;
copyInParamsRemainder.blockLen = nFactor * sizeof(T);
copyInParamsRemainder.srcStride = 0;
copyInParamsRemainder.dstStride = 0;
DataCopyExtParams copyOutParamsRemainder;
copyOutParamsRemainder.blockCount = mRemainder;
copyOutParamsRemainder.blockLen = nFactor * IMAG_AND_REAL * sizeof(T);
copyOutParamsRemainder.srcStride = (IMAG_AND_REAL * (nAlign - nFactor)) / alignNum;
copyOutParamsRemainder.dstStride = IMAG_AND_REAL * (tiling->matmulN - nFactor) * sizeof(T);
LocalTensor<T> inputLocal = gatherIn.template AllocTensor<T>();
DataCopyPad(
inputLocal, matmulOutputGm[matmulOffset + mQuotient * copyInGMOffsetCoef], copyInParamsRemainder,
dataCopyPadExtParams);
gatherIn.EnQue(inputLocal);
inputLocal = gatherIn.template DeQue<T>();
LocalTensor<T> outputLocal = gatherOut.template AllocTensor<T>();
Gather(outputLocal, inputLocal, mask, 0, mRemainder * nAlign * IMAG_AND_REAL);
if (tiling->normalized) {
AscendC::PipeBarrier<PIPE_V>();
Muls(outputLocal, outputLocal, rootNfft, mRemainder * nAlign * IMAG_AND_REAL);
}
gatherOut.EnQue(outputLocal);
gatherIn.FreeTensor(inputLocal);
outputLocal = gatherOut.template DeQue<T>();
DataCopyPad(outputGm[outputOffset + mQuotient * copyOutGMOffsetCoef], outputLocal, copyOutParamsRemainder);
gatherOut.FreeTensor(outputLocal);
}
}
__aicore__ inline void GatherRealAndImagLargeN(
int64_t matmulOffset, int64_t outputOffset, uint32_t mFactor, uint32_t nFactor)
{
uint32_t alignNum = REPEAT_SIZE / sizeof(T);
uint32_t nAlignInUB = (maskCount / IMAG_AND_REAL) / alignNum * alignNum;
uint32_t nQuotient = nFactor / nAlignInUB;
uint32_t nRemainder = nFactor % nAlignInUB;
DataCopyPadExtParams<T> dataCopyPadExtParams;
dataCopyPadExtParams.isPad = false;
dataCopyPadExtParams.leftPadding = 0;
dataCopyPadExtParams.rightPadding = 0;
dataCopyPadExtParams.paddingValue = 0;
DataCopyExtParams copyInParams;
copyInParams.blockCount = IMAG_AND_REAL;
copyInParams.blockLen = nAlignInUB * sizeof(T);
copyInParams.srcStride = (nFactor - nAlignInUB) * sizeof(T);
copyInParams.dstStride = 0;
DataCopyExtParams copyOutParams;
copyOutParams.blockCount = 1;
copyOutParams.blockLen = nAlignInUB * IMAG_AND_REAL * sizeof(T);
copyOutParams.srcStride = 0;
copyOutParams.dstStride = 0;
DataCopyExtParams copyInParamsRemainder;
copyInParamsRemainder.blockCount = IMAG_AND_REAL;
copyInParamsRemainder.blockLen = nRemainder * sizeof(T);
copyInParamsRemainder.srcStride = (nFactor - nRemainder) * sizeof(T);
copyInParamsRemainder.dstStride = ((nAlignInUB - nRemainder) * sizeof(T)) / BLOCK_SIZE;
DataCopyExtParams copyOutParamsRemainder;
copyOutParamsRemainder.blockCount = 1;
copyOutParamsRemainder.blockLen = nRemainder * IMAG_AND_REAL * sizeof(T);
copyOutParamsRemainder.srcStride = 0;
copyOutParamsRemainder.dstStride = 0;
for (int64_t i = 0; i < mFactor; i++) {
int64_t curMatmulOffset = matmulOffset + i * nFactor * IMAG_AND_REAL;
int64_t curOutputOffset = outputOffset + i * tiling->matmulN * IMAG_AND_REAL;
for (int64_t j = 0; j < nQuotient; j++) {
LocalTensor<T> inputLocal = gatherIn.template AllocTensor<T>();
DataCopyPad(
inputLocal, matmulOutputGm[curMatmulOffset + j * nAlignInUB], copyInParams, dataCopyPadExtParams);
gatherIn.EnQue(inputLocal);
inputLocal = gatherIn.template DeQue<T>();
LocalTensor<T> outputLocal = gatherOut.template AllocTensor<T>();
Gather(outputLocal, inputLocal, mask, 0, nAlignInUB * IMAG_AND_REAL);
if (tiling->normalized) {
AscendC::PipeBarrier<PIPE_V>();
Muls(outputLocal, outputLocal, rootNfft, nAlignInUB * IMAG_AND_REAL);
}
gatherOut.EnQue(outputLocal);
gatherIn.FreeTensor(inputLocal);
outputLocal = gatherOut.template DeQue<T>();
DataCopyPad(outputGm[curOutputOffset + j * nAlignInUB * IMAG_AND_REAL], outputLocal, copyOutParams);
gatherOut.FreeTensor(outputLocal);
}
if (nRemainder > 0) {
LocalTensor<T> inputLocal = gatherIn.template AllocTensor<T>();
DataCopyPad(
inputLocal, matmulOutputGm[curMatmulOffset + nQuotient * nAlignInUB], copyInParamsRemainder,
dataCopyPadExtParams);
gatherIn.EnQue(inputLocal);
inputLocal = gatherIn.template DeQue<T>();
LocalTensor<T> outputLocal = gatherOut.template AllocTensor<T>();
Gather(outputLocal, inputLocal, mask, 0, nAlignInUB * IMAG_AND_REAL);
if (tiling->normalized) {
AscendC::PipeBarrier<PIPE_V>();
Muls(outputLocal, outputLocal, rootNfft, nRemainder * IMAG_AND_REAL);
}
gatherOut.EnQue(outputLocal);
gatherIn.FreeTensor(inputLocal);
outputLocal = gatherOut.template DeQue<T>();
DataCopyPad(
outputGm[curOutputOffset + nQuotient * nAlignInUB * IMAG_AND_REAL], outputLocal,
copyOutParamsRemainder);
gatherOut.FreeTensor(outputLocal);
}
}
}
__aicore__ inline void StftMatmul(int64_t splitWindowOffset, int64_t planOffset, int64_t matmulOutputOffset)
{
mm->SetTensorA(planGm[planOffset]);
mm->SetTensorB(splitWindowGm[splitWindowOffset], true);
mm->IterateAll(matmulOutputGm[matmulOutputOffset]);
mm->End();
}
__aicore__ inline void SplitWindowsUseVecSkip(int64_t inputOffset, int64_t outputOffset, int32_t nFactor)
{
uint32_t splitWindowSkipNum = tiling->splitWindowSkipNum;
uint32_t windowsForOnceSkip = (nFactor + splitWindowSkipNum - 1) / splitWindowSkipNum;
uint32_t tailSkipIdx = splitWindowSkipNum - ((windowsForOnceSkip * splitWindowSkipNum) - nFactor);
uint32_t skipHopLength = splitWindowSkipNum * tiling->hopLength;
uint32_t onceNInUB = (tiling->copyUBSize / sizeof(T) - tiling->nfftAlign) / (skipHopLength + tiling->nfftAlign);
uint32_t burstLen = tiling->nfftAlign * sizeof(T);
uint32_t copyOutDstGap = (splitWindowSkipNum - 1) * tiling->nfftAlign * sizeof(T);
uint32_t onceSkipInInputGM = onceNInUB * skipHopLength;
uint32_t copyLength = (onceNInUB - 1) * skipHopLength + tiling->nfft;
int64_t onceSkipInOutputGm = onceNInUB * splitWindowSkipNum * tiling->nfftAlign;
DataCopyPadExtParams<T> dataCopyPadExtParams;
dataCopyPadExtParams.isPad = false;
DataCopyExtParams copyOutParam;
copyOutParam.blockCount = onceNInUB;
copyOutParam.blockLen = burstLen;
copyOutParam.srcStride = 0;
copyOutParam.dstStride = copyOutDstGap;
for (int64_t i = 0; i < splitWindowSkipNum; i++) {
uint32_t curSplitN = i >= tailSkipIdx ? windowsForOnceSkip - 1 : windowsForOnceSkip;
uint32_t splitQuotient = curSplitN / onceNInUB;
uint32_t splitRemainder = curSplitN % onceNInUB;
for (int64_t j = 0; j < splitQuotient; j++) {
LocalTensor<T> inputLocal = inCopy.template AllocTensor<T>();
DataCopyExtParams copyInParam;
copyInParam.blockLen = copyLength * sizeof(T);
copyInParam.blockCount = 1;
copyInParam.srcStride = 0;
copyInParam.dstStride = 0;
DataCopyPad(
inputLocal, inputGm[inputOffset + i * tiling->hopLength + j * onceSkipInInputGM], copyInParam,
dataCopyPadExtParams);
inCopy.EnQue(inputLocal);
inputLocal = inCopy.template DeQue<T>();
LocalTensor<T> outputLocal = outCopy.template AllocTensor<T>();
for (uint32_t k = 0; k < onceNInUB; k++) {
DataCopy(outputLocal[k * tiling->nfftAlign], inputLocal[k * skipHopLength], tiling->nfftAlign);
}
outCopy.EnQue(outputLocal);
inCopy.FreeTensor(inputLocal);
outputLocal = outCopy.template DeQue<T>();
DataCopyPad(
splitWindowGm[outputOffset + i * tiling->nfftAlign + j * onceSkipInOutputGm], outputLocal,
copyOutParam);
outCopy.FreeTensor(outputLocal);
}
if (splitRemainder > 0) {
DataCopyExtParams copyOutParamRemainder;
copyOutParamRemainder.blockCount = splitRemainder;
copyOutParamRemainder.blockLen = burstLen;
copyOutParamRemainder.srcStride = 0;
copyOutParamRemainder.dstStride = copyOutDstGap;
uint32_t copyLength = (splitRemainder - 1) * skipHopLength + tiling->nfft;
LocalTensor<T> inputLocal = inCopy.template AllocTensor<T>();
DataCopyExtParams copyInParam;
copyInParam.blockLen = copyLength * sizeof(T);
copyInParam.blockCount = 1;
copyInParam.srcStride = 0;
copyInParam.dstStride = 0;
DataCopyPad(
inputLocal, inputGm[inputOffset + i * tiling->hopLength + splitQuotient * onceSkipInInputGM],
copyInParam, dataCopyPadExtParams);
inCopy.EnQue(inputLocal);
inputLocal = inCopy.template DeQue<T>();
LocalTensor<T> outputLocal = outCopy.template AllocTensor<T>();
for (uint32_t k = 0; k < splitRemainder; k++) {
DataCopy(outputLocal[k * tiling->nfftAlign], inputLocal[k * skipHopLength], tiling->nfftAlign);
}
outCopy.EnQue(outputLocal);
inCopy.FreeTensor(inputLocal);
outputLocal = outCopy.template DeQue<T>();
DataCopyPad(
splitWindowGm[outputOffset + i * tiling->nfftAlign + splitQuotient * onceSkipInOutputGm],
outputLocal, copyOutParamRemainder);
outCopy.FreeTensor(outputLocal);
}
}
}
__aicore__ inline void SplitWindowsUseMTESkip(int64_t inputOffset, int64_t outputOffset, int32_t nFactor)
{
uint32_t splitWindowSkipNum = tiling->splitWindowSkipNum;
uint32_t windowsForOnceSkip = (nFactor + splitWindowSkipNum - 1) / splitWindowSkipNum;
uint32_t tailSkipIdx = splitWindowSkipNum - ((windowsForOnceSkip * splitWindowSkipNum) - nFactor);
uint32_t skipHopLength = splitWindowSkipNum * tiling->hopLength;
uint32_t oneBufferSize = (tiling->copyUBSize / 2) / BLOCK_SIZE * BLOCK_SIZE;
uint32_t onceNInUB = oneBufferSize / (tiling->nfftAlign * sizeof(T));
uint32_t burstLenIn = tiling->nfft * sizeof(T);
uint32_t burstLenOut = tiling->nfftAlign * sizeof(T);
uint32_t copyInSrcGap = (skipHopLength - tiling->nfft) * sizeof(T);
uint32_t padLength = ((tiling->nfftAlign - tiling->nfft) * sizeof(T) % BLOCK_SIZE) / sizeof(T);
uint32_t copyInDstGap = (tiling->nfftAlign - tiling->nfft) * sizeof(T) / BLOCK_SIZE;
uint32_t copyOutDstGap = (splitWindowSkipNum - 1) * tiling->nfftAlign * sizeof(T);
uint32_t onceSkipInInputGM = onceNInUB * skipHopLength;
int64_t onceSkipInOutputGm = onceNInUB * splitWindowSkipNum * tiling->nfftAlign;
DataCopyPadExtParams<T> dataCopyPadExtParams;
dataCopyPadExtParams.isPad = padLength > 0 ? true : false;
dataCopyPadExtParams.leftPadding = 0;
dataCopyPadExtParams.rightPadding = padLength;
dataCopyPadExtParams.paddingValue = 0;
DataCopyExtParams copyInParams;
copyInParams.blockCount = onceNInUB;
copyInParams.blockLen = burstLenIn;
copyInParams.srcStride = copyInSrcGap;
copyInParams.dstStride = copyInDstGap;
DataCopyExtParams copyOutParams;
copyOutParams.blockCount = onceNInUB;
copyOutParams.blockLen = burstLenOut;
copyOutParams.srcStride = 0;
copyOutParams.dstStride = copyOutDstGap;
for (int64_t i = 0; i < splitWindowSkipNum; i++) {
uint32_t curSplitN = i >= tailSkipIdx ? windowsForOnceSkip - 1 : windowsForOnceSkip;
uint32_t splitQuotient = curSplitN / onceNInUB;
uint32_t splitRemainder = curSplitN % onceNInUB;
for (int64_t j = 0; j < splitQuotient; j++) {
LocalTensor<T> inputLocal = inCopy.template AllocTensor<T>();
DataCopyPad(
inputLocal, inputGm[inputOffset + i * tiling->hopLength + j * onceSkipInInputGM], copyInParams,
dataCopyPadExtParams);
inCopy.EnQue(inputLocal);
inputLocal = inCopy.template DeQue<T>();
LocalTensor<T> outputLocal = outCopy.template AllocTensor<T>();
DataCopy(outputLocal, inputLocal, onceNInUB * tiling->nfftAlign);
outCopy.EnQue(outputLocal);
inCopy.FreeTensor(inputLocal);
outputLocal = outCopy.template DeQue<T>();
DataCopyPad(
splitWindowGm[outputOffset + i * tiling->nfftAlign + j * onceSkipInOutputGm], outputLocal,
copyOutParams);
outCopy.FreeTensor(outputLocal);
}
if (splitRemainder > 0) {
DataCopyExtParams copyInParamsRemainder;
copyInParamsRemainder.blockCount = splitRemainder;
copyInParamsRemainder.blockLen = burstLenIn;
copyInParamsRemainder.srcStride = copyInSrcGap;
copyInParamsRemainder.dstStride = copyInDstGap;
DataCopyExtParams copyOutParamsRemainder;
copyOutParamsRemainder.blockCount = splitRemainder;
copyOutParamsRemainder.blockLen = burstLenOut;
copyOutParamsRemainder.srcStride = 0;
copyOutParamsRemainder.dstStride = copyOutDstGap;
LocalTensor<T> inputLocal = inCopy.template AllocTensor<T>();
DataCopyPad(
inputLocal, inputGm[inputOffset + i * tiling->hopLength + splitQuotient * onceSkipInInputGM],
copyInParamsRemainder, dataCopyPadExtParams);
inCopy.EnQue(inputLocal);
inputLocal = inCopy.template DeQue<T>();
LocalTensor<T> outputLocal = outCopy.template AllocTensor<T>();
DataCopy(outputLocal, inputLocal, splitRemainder * tiling->nfftAlign);
outCopy.EnQue(outputLocal);
inCopy.FreeTensor(inputLocal);
outputLocal = outCopy.template DeQue<T>();
DataCopyPad(
splitWindowGm[outputOffset + i * tiling->nfftAlign + splitQuotient * onceSkipInOutputGm],
outputLocal, copyOutParamsRemainder);
outCopy.FreeTensor(outputLocal);
}
}
}
__aicore__ inline void SplitWindowsNormal(int64_t inputOffset, int64_t outputOffset, int32_t nFactor)
{
int32_t loopCount = tiling->nfft * sizeof(T) / bufferSize;
int32_t tailCount = (tiling->nfft * sizeof(T) % bufferSize) / sizeof(T);
if (!isAligned) {
for (int32_t i = 0; i < nFactor; i++) {
for (int32_t j = 0; j < loopCount; j++) {
LocalTensor<T> inputLocal = inCopy.template AllocTensor<T>();
DataCopy(
inputLocal, inputGm[inputOffset + i * tiling->hopLength + j * bufferSize / sizeof(T)],
bufferSize / sizeof(T));
event_t eventIdMTE2ToMTE3 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_MTE3));
SetFlag<HardEvent::MTE2_MTE3>(eventIdMTE2ToMTE3);
WaitFlag<HardEvent::MTE2_MTE3>(eventIdMTE2ToMTE3);
DataCopy(
splitWindowGm[outputOffset + i * tiling->nfftAlign + j * bufferSize / sizeof(T)], inputLocal,
bufferSize / sizeof(T));
event_t eventIdMTE3ToMTE2 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_MTE2));
SetFlag<HardEvent::MTE3_MTE2>(eventIdMTE3ToMTE2);
WaitFlag<HardEvent::MTE3_MTE2>(eventIdMTE3ToMTE2);
inCopy.FreeTensor(inputLocal);
}
if (tailCount > 0) {
DataCopyPadExtParams<T> dataCopyPadExtParams;
dataCopyPadExtParams.isPad = false;
DataCopyExtParams copyParams;
copyParams.blockCount = 1;
copyParams.blockLen = tailCount * sizeof(T);
copyParams.srcStride = 0;
copyParams.dstStride = 0;
LocalTensor<T> inputLocal = inCopy.template AllocTensor<T>();
DataCopyPad(
inputLocal, inputGm[inputOffset + i * tiling->hopLength + loopCount * bufferSize / sizeof(T)],
copyParams, dataCopyPadExtParams);
event_t eventIdMTE2ToMTE3 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_MTE3));
SetFlag<HardEvent::MTE2_MTE3>(eventIdMTE2ToMTE3);
WaitFlag<HardEvent::MTE2_MTE3>(eventIdMTE2ToMTE3);
DataCopyPad(
splitWindowGm[outputOffset + i * tiling->nfftAlign + loopCount * bufferSize / sizeof(T)],
inputLocal, copyParams);
event_t eventIdMTE3ToMTE2 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_MTE2));
SetFlag<HardEvent::MTE3_MTE2>(eventIdMTE3ToMTE2);
WaitFlag<HardEvent::MTE3_MTE2>(eventIdMTE3ToMTE2);
inCopy.FreeTensor(inputLocal);
}
}
} else {
int32_t total = 0;
while (total < nFactor) {
LocalTensor<T> inputLocal = inCopy.template AllocTensor<T>();
int64_t inputLeft =
((int64_t)tiling->batch * (tiling->inputSize + tiling->nfft) - (int64_t)total * tiling->hopLength - inputOffset) *
sizeof(T);
int32_t copyLength = inputLeft > bufferSize ? bufferSize / sizeof(T) : inputLeft / sizeof(T);
DataCopyPadExtParams<T> dataCopyPadExtParams;
dataCopyPadExtParams.isPad = false;
DataCopyExtParams copyParams;
copyParams.blockCount = 1;
copyParams.blockLen = copyLength * sizeof(T);
copyParams.srcStride = 0;
copyParams.dstStride = 0;
DataCopyPad(
inputLocal, inputGm[inputOffset + total * tiling->hopLength], copyParams, dataCopyPadExtParams);
inCopy.EnQue(inputLocal);
inputLocal = inCopy.template DeQue<T>();
LocalTensor<T> outputLocal = outCopy.template AllocTensor<T>();
int32_t bufferLeft = copyLength;
int32_t outputBufferLeft = bufferSize / sizeof(T);
int32_t i = 0;
while (bufferLeft >= (int32_t)tiling->nfftAlign && outputBufferLeft >= (int32_t)tiling->nfftAlign &&
total + i < nFactor) {
DataCopy(outputLocal[i * tiling->nfftAlign], inputLocal[i * tiling->hopLength], tiling->nfftAlign);
i++;
bufferLeft -= tiling->hopLength;
outputBufferLeft -= tiling->nfftAlign;
}
outCopy.EnQue(outputLocal);
inCopy.FreeTensor(inputLocal);
outputLocal = outCopy.template DeQue<T>();
DataCopy(splitWindowGm[outputOffset + total * tiling->nfftAlign], outputLocal, i * tiling->nfftAlign);
outCopy.FreeTensor(outputLocal);
total += i;
}
}
}
__aicore__ inline void SplitWindows(int64_t inputOffset, int64_t outputOffset, int32_t nFactor)
{
if (tiling->splitWindowSkipNum > 0) {
if (tiling->splitWindowSkipNum * tiling->hopLength >= tiling->nfftAlign) {
return SplitWindowsUseMTESkip(inputOffset, outputOffset, nFactor);
}
return SplitWindowsUseVecSkip(inputOffset, outputOffset, nFactor);
}
return SplitWindowsNormal(inputOffset, outputOffset, nFactor);
}
uint32_t BLOCK_NUM = 8;
uint32_t BLOCK_SIZE = 32;
uint32_t WORKSPACE_ALIGN_SIZE = 512;
uint32_t REPEAT_SIZE = 256;
uint32_t IMAG_AND_REAL = 2;
uint32_t DWORD_SIZE = 4;
uint32_t GATHER_BUFFER_NUM = 2;
uint32_t BINARY_ADD_COEF = 2;
int32_t DOUBLE_SIZE = 2;
int32_t maskCount;
int32_t bufferSize;
bool isAligned;
T rootNfft;
STFTGeneralizedTilingData* tiling;
TPipe* pipe;
TQue<QuePosition::VECIN, bufferNum> inCopy;
TQue<QuePosition::VECOUT, bufferNum> outCopy;
TQue<QuePosition::VECIN, bufferNum> gatherIn;
TQue<QuePosition::VECOUT, bufferNum> gatherOut;
GlobalTensor<T> inputGm;
GlobalTensor<T> outputGm;
GlobalTensor<T> splitWindowGm;
GlobalTensor<T> matmulOutputGm;
GlobalTensor<T> planGm;
TBuf<> maskUB;
TBuf<> tempUB;
LocalTensor<uint32_t> mask;
};
}
#endif