* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file stft.h
* \brief
*/
#ifndef STFT_H
#define STFT_H
#include "kernel_tiling/kernel_tiling.h"
#include "kernel_operator.h"
#include "lib/matrix/matmul/matmul.h"
#include "lib/matmul_intf.h"
#include "lib/pad/kernel_operator_pad_intf.h"
#include "kernel_operator_mm_intf.h"
#include "kernel_operator_fixpipe_intf.h"
namespace STFTND {
using namespace AscendC;
using namespace matmul;
constexpr int32_t ALIGNMENT_SIZE = 32;
constexpr int64_t CONTINUOUS_DATA = 256;
constexpr int64_t FLOAT_BYTES = 4;
constexpr int64_t FLOAT_MASK = 64;
constexpr uint32_t BLOCK_NUM = 8;
constexpr int32_t REAL_IMAG_PER_ROW = 4;
constexpr int32_t REAL_IMAG_COLS = 8;
constexpr int32_t REAL_IMAG = 2;
constexpr int PING_PONG_NUM = 2;
constexpr int BLOCK_BYTES = 32;
constexpr int BLOCK_ALIGN_NUM = 8;
constexpr int MODE_MOVE_LEFT = 4;
constexpr int FLAG_MOVE_LEFT = 8;
constexpr int TAIL_ROW_NUM = 2;
constexpr int TWO_BLOCKS = 2;
constexpr int FOUR_BLOCKS = 4;
constexpr int EIGHT_BLOCKS = 8;
constexpr int DIFFS_ONE_BLOCK = 4;
constexpr int C_V_DOUBLE = 2;
constexpr int WORKSPACE_SIZE_ALIGN = 512;
constexpr int64_t DOUBLE_BUFFER = 2;
constexpr int INDEX_K_3 = 3;
constexpr int64_t THREE_QUE = 3;
constexpr int BASE_N = 96;
template <typename T, int32_t bufferNum>
class StftND {
public:
__aicore__ inline StftND(){};
__aicore__ inline void Init(GM_ADDR x, GM_ADDR window, GM_ADDR y, GM_ADDR workspace, STFTTilingData* tilingData)
{
inTilingData = tilingData;
size_t windowSplitWorkspaceSize =
(((inTilingData->blockNum * inTilingData->frameCount * inTilingData->nfft * sizeof(T) + 511) / 512) * 512) *
inTilingData->aivBatchLoop / sizeof(T);
if (g_coreType == AIV) {
inputGm.SetGlobalBuffer((__gm__ T*)x, inTilingData->batch * inTilingData->inputSize);
windowSplitWorkspaceGm.SetGlobalBuffer(
(__gm__ T*)workspace,
inTilingData->blockNum * inTilingData->aivBatchLoop * inTilingData->frameCount * inTilingData->nfft);
outputGm.SetGlobalBuffer(
(__gm__ T*)y, inTilingData->matmulM * inTilingData->frameCount * inTilingData->batch * DOUBLE_BUFFER);
gmReal.SetGlobalBuffer(
reinterpret_cast<__gm__ T*>(workspace) + windowSplitWorkspaceSize,
inTilingData->batch * inTilingData->matmulM * inTilingData->frameCount * DOUBLE_BUFFER);
gmImag.SetGlobalBuffer(
reinterpret_cast<__gm__ T*>(workspace) + windowSplitWorkspaceSize,
inTilingData->batch * inTilingData->matmulM * inTilingData->frameCount * DOUBLE_BUFFER);
pipe.InitBuffer(
inQueueInput, bufferNum,
(inTilingData->blkFrame * inTilingData->hop + (inTilingData->nfft - inTilingData->hop)) * sizeof(T));
pipe.InitBuffer(outQueueOutput, bufferNum, inTilingData->blkFrame * inTilingData->nfft * sizeof(T));
}
if (g_coreType == AIC) {
auto blockIdx = GetBlockIdx();
a1Global.SetGlobalBuffer(
reinterpret_cast<__gm__ T*>(window), inTilingData->matmulM * inTilingData->nfft * DOUBLE_BUFFER);
bGlobal.SetGlobalBuffer(
reinterpret_cast<__gm__ T*>(workspace),
inTilingData->blockNum * inTilingData->nfft * inTilingData->frameCount * inTilingData->aivBatchLoop);
matMulWorkspaceGm.SetGlobalBuffer(
reinterpret_cast<__gm__ T*>(workspace) + windowSplitWorkspaceSize,
inTilingData->matmulM * inTilingData->frameCount * inTilingData->batch * DOUBLE_BUFFER);
curCoreM_ = inTilingData->aicTotalLen;
curCoreN_ = inTilingData->mmTilingData.N;
curCoreK_ = inTilingData->mmTilingData.singleCoreK;
if (blockIdx % inTilingData->aicMatmulMCore >= inTilingData->aicMTailIdx) {
curCoreM_ = inTilingData->aicTailLen;
}
baseN_ = BASE_N;
baseMK_ = baseM_ * baseK_;
baseKN_ = baseK_ * baseN_;
baseMN_ = baseM_ * baseN_;
L1N_ = baseN_;
midTailM_ = curCoreM_ % L1M_;
midTailN_ = curCoreN_ % L1N_;
if (midTailM_ == 0) {
midTailM_ = L1M_;
}
if (midTailN_ == 0) {
midTailN_ = L1N_;
}
pipe.InitBuffer(inQueueL1, depthL1_, baseMK_ * sizeof(T) * inK_);
pipe.InitBuffer(inQueueA0, depthA0_, baseMK_ * sizeof(T));
pipe.InitBuffer(inQueueB0, depthB0_, baseKN_ * sizeof(T));
pipe.InitBuffer(outQueueCO, depthC0_, baseMN_ * sizeof(T));
}
}
__aicore__ inline void Process()
{
uint32_t totalMLen = inTilingData->aicTotalLen;
uint32_t frameCount = inTilingData->frameCount;
uint32_t aicTailLen = inTilingData->aicTailLen;
uint32_t aicBatchLoop = inTilingData->aicBatchLoop;
uint32_t matmulM = inTilingData->matmulM;
if (g_coreType == AIV) {
auto blockIdx = GetBlockIdx();
SplitFrameNormal(0, blockIdx);
pipe.InitBuffer(realInQueue, PING_PONG_NUM, inTilingData->sizePerRepeat);
pipe.InitBuffer(imagInQueue, PING_PONG_NUM, inTilingData->sizePerRepeat);
pipe.InitBuffer(complexOutQueue, PING_PONG_NUM, REAL_IMAG * inTilingData->sizePerRepeat);
pipe.InitBuffer(mask, REAL_IMAG * inTilingData->sizePerRepeat);
uint32_t gatherSizePerRepeat = inTilingData->sizePerRepeat;
uint32_t windowLoop = inTilingData->aivWindowLoop;
uint32_t outUbGap = inTilingData->aivGatherUbGap;
uint32_t aivTotalEvenMLen = inTilingData->aivTotalEvenRow;
uint32_t aivTailEvenMLen = inTilingData->aivTailEvenRow;
uint32_t aivBatchLoop = inTilingData->aivBatchLoop;
uint32_t frameCountAlign = inTilingData->frameCountAlign;
uint32_t frameCountDiff = frameCountAlign - frameCount;
uint32_t aivMTailIdx = inTilingData->aivMTailIdx;
uint32_t aivTailLoop = inTilingData->aivTailLoop;
uint32_t aivBatchTailIdx = inTilingData->aivBatchTailIdx;
int32_t numPerRepeat = gatherSizePerRepeat / sizeof(T);
int32_t maskCol = (numPerRepeat * REAL_IMAG) / BLOCK_ALIGN_NUM;
rowEven = inTilingData->aivTotalEvenRow;
rowOdd = inTilingData->aivTotalOddRow;
if (blockIdx % inTilingData->aivWindowLoop >= inTilingData->aivMTailIdx) {
rowEven = inTilingData->aivTailEvenRow;
rowOdd = inTilingData->aivTailOddRow;
}
if (blockIdx % DOUBLE_BUFFER == 0) {
N = rowEven * inTilingData->frameCountAlign;
} else {
N = rowOdd * inTilingData->frameCountAlign;
}
LocalTensor<int32_t> maskTemp = mask.template Get<int32_t>();
LocalTensor<uint32_t> maskBuf;
uint16_t flag_id_mte3 = 3;
AscendC::CrossCoreSetFlag<DOUBLE_BUFFER, PIPE_MTE3>(flag_id_mte3);
for (int32_t i = 1; i < aivBatchLoop; i++) {
if (i >= aivTailLoop && blockIdx >= aivBatchTailIdx) {
continue;
}
SplitFrameNormal(i, blockIdx);
}
CreateGatherMask(maskTemp, maskCol, REAL_IMAG_COLS, 0, REAL_IMAG * gatherSizePerRepeat);
int32_t offsetBase = 0;
if (unlikely(blockIdx % windowLoop < aivMTailIdx)) {
offsetBase = ((blockIdx % windowLoop) / C_V_DOUBLE) * totalMLen * frameCount +
(blockIdx % C_V_DOUBLE) * (aivTotalEvenMLen * REAL_IMAG) * frameCount;
} else {
offsetBase = (aivMTailIdx / C_V_DOUBLE) * totalMLen * frameCount +
(((blockIdx % windowLoop) - aivMTailIdx) % C_V_DOUBLE) * (aivTailEvenMLen * REAL_IMAG) *
frameCount +
((blockIdx % windowLoop) - aivMTailIdx) / C_V_DOUBLE * aicTailLen * frameCount;
}
int repeats = (N * sizeof(T) + gatherSizePerRepeat - 1) / gatherSizePerRepeat;
maskBuf = maskTemp.ReinterpretCast<uint32_t>();
for (int32_t i = 0; i < aivBatchLoop; i++) {
if (i >= aivTailLoop && blockIdx >= aivBatchTailIdx) {
continue;
}
uint64_t flag_id_fix = 1;
AscendC::WaitEvent(flag_id_fix);
int32_t curBatch = (blockIdx / windowLoop) * aicBatchLoop + i;
int32_t gmRealOffset = curBatch * matmulM * frameCount * DOUBLE_BUFFER;
int32_t gmImagOffset = gmRealOffset + frameCount;
int32_t outputOffset = curBatch * matmulM * frameCount * DOUBLE_BUFFER;
gmRealOffset = gmRealOffset + offsetBase;
gmImagOffset = gmImagOffset + offsetBase;
outputOffset = outputOffset + offsetBase;
auto gmRealLocal = gmReal[gmRealOffset];
auto gmImagLocal = gmImag[gmImagOffset];
auto outputGmLocal = outputGm[outputOffset];
int32_t len = numPerRepeat;
int32_t nBurst = len / frameCountAlign;
int32_t offset = nBurst * frameCount * DOUBLE_BUFFER;
for (int j = 0; j < repeats; j++) {
if (j == repeats - 1) {
len = N - numPerRepeat * j;
nBurst = len / frameCountAlign;
}
LocalTensor<T> realBuf = realInQueue.template AllocTensor<T>();
DataCopyExtParams copyGmToUbParams{
static_cast<uint16_t>(nBurst), static_cast<uint32_t>(frameCount * sizeof(T)),
static_cast<uint32_t>(frameCount * sizeof(T)), 0, 0};
DataCopyPadExtParams<T> padParams{false, 0, static_cast<uint8_t>(frameCountDiff), 0};
DataCopyPad(realBuf, gmRealLocal[j * offset], copyGmToUbParams, padParams);
realInQueue.EnQue(realBuf);
LocalTensor<T> imagBuf = imagInQueue.template AllocTensor<T>();
DataCopyPad(imagBuf, gmImagLocal[j * offset], copyGmToUbParams, padParams);
imagInQueue.EnQue(imagBuf);
realBuf = realInQueue.template DeQue<T>();
imagBuf = imagInQueue.template DeQue<T>();
LocalTensor<T> complexBuf = complexOutQueue.template AllocTensor<T>();
Gather(complexBuf, realBuf, maskBuf, 0, REAL_IMAG * len);
complexOutQueue.EnQue(complexBuf);
realInQueue.FreeTensor(realBuf);
imagInQueue.FreeTensor(imagBuf);
complexBuf = complexOutQueue.template DeQue<T>();
DataCopyExtParams copyUbToGmParams{
static_cast<uint16_t>(nBurst), static_cast<uint32_t>(frameCount * sizeof(T) * REAL_IMAG),
outUbGap, 0, 0};
DataCopyPad(outputGmLocal[j * offset], complexBuf, copyUbToGmParams);
complexOutQueue.FreeTensor(complexBuf);
}
}
}
if (g_coreType == AIC) {
int blockIdx = GetBlockIdx();
int32_t aicMatmulMCore = inTilingData->aicMatmulMCore;
uint32_t aicTailLoop = inTilingData->aicTailLoop;
uint32_t aicBatchTailIdx = inTilingData->aicBatchTailIdx;
int32_t curBatchBase = (blockIdx / aicMatmulMCore) * aicBatchLoop;
int32_t innerReminder = blockIdx % aicMatmulMCore;
int32_t aicMTailIdx = inTilingData->aicMTailIdx;
int32_t nfft = inTilingData->nfft;
globalM_ = inTilingData->mmTilingData.M;
globalN_ = inTilingData->mmTilingData.N;
globalK_ = inTilingData->mmTilingData.Ka;
int32_t a1GlobalOffset = 0;
int32_t outputOffsetBase = 0;
if (unlikely(blockIdx % aicMatmulMCore < aicMTailIdx)) {
a1GlobalOffset = innerReminder * totalMLen * nfft;
outputOffsetBase = innerReminder * totalMLen * frameCount;
} else {
a1GlobalOffset = aicMTailIdx * totalMLen * nfft + (innerReminder - aicMTailIdx) * aicTailLen * nfft;
outputOffsetBase =
aicMTailIdx * totalMLen * frameCount + (innerReminder - aicMTailIdx) * aicTailLen * frameCount;
}
uint64_t flag_id_mte3 = 3;
AscendC::WaitEvent(flag_id_mte3);
for (int i = 0; i < aicBatchLoop; i++) {
if (i >= aicTailLoop && blockIdx >= aicBatchTailIdx) {
continue;
}
int32_t curBatch = curBatchBase + i;
int32_t bGlobalOffset =
blockIdx * aicBatchLoop * frameCount * inTilingData->nfft + i * frameCount * inTilingData->nfft;
int32_t outputOffset = curBatch * matmulM * frameCount * REAL_IMAG + outputOffsetBase;
bGM = bGlobal[bGlobalOffset];
aGM = a1Global[a1GlobalOffset];
cGM = matMulWorkspaceGm[outputOffset];
IterateAllTB();
uint16_t flag_id_fix = 1;
AscendC::CrossCoreSetFlag<DOUBLE_BUFFER, PIPE_FIX>(flag_id_fix);
}
}
}
__aicore__ inline void CreateGatherMask(
LocalTensor<int32_t>& dst, const int32_t row, const int32_t col, const int32_t startValue,
const int32_t diffValue)
{
int32_t repeatNum = row / BLOCK_NUM;
int32_t tailNum = row % BLOCK_NUM;
int32_t realNum = BLOCK_BYTES / (sizeof(T) * REAL_IMAG);
for (int32_t i = 0; i < REAL_IMAG_COLS / REAL_IMAG; i++) {
dst.SetValue(i * REAL_IMAG, startValue + i * realNum);
dst.SetValue(i * REAL_IMAG + 1, startValue + diffValue + i * realNum);
}
auto eventID = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
SetFlag<HardEvent::S_V>(eventID);
WaitFlag<HardEvent::S_V>(eventID);
if (repeatNum == 0) {
AddsComputeAlignment(dst, dst, row, col, realNum);
return;
}
AddsComputeAlignment(dst, dst, REAL_IMAG, col, realNum);
PipeBarrier<PIPE_V>();
AddsComputeAlignment(dst, dst, REAL_IMAG, col * TWO_BLOCKS, TWO_BLOCKS * realNum);
PipeBarrier<PIPE_V>();
AddsComputeAlignment(dst, dst, REAL_IMAG, col * FOUR_BLOCKS, FOUR_BLOCKS * realNum);
PipeBarrier<PIPE_V>();
AddsComputeAlignment(dst, dst, repeatNum, col * EIGHT_BLOCKS, EIGHT_BLOCKS * realNum);
PipeBarrier<PIPE_V>();
if (tailNum != 0) {
int32_t lastPrevOffset = repeatNum * EIGHT_BLOCKS * (BLOCK_BYTES / sizeof(T)) - BLOCK_BYTES / sizeof(T);
auto lastPrevDst = dst[lastPrevOffset];
AddsComputeAlignment(lastPrevDst, lastPrevDst, tailNum + 1, col, realNum);
}
}
__aicore__ inline void AddsComputeAlignment(
LocalTensor<int32_t>& dst, LocalTensor<int32_t>& src, const int32_t rowCount, const int32_t colCount,
const int32_t scalar)
{
int32_t repeatTimesOneRow = (colCount * FLOAT_BYTES + CONTINUOUS_DATA - 1) / CONTINUOUS_DATA;
int32_t countOneRow32B = (colCount * FLOAT_BYTES + ALIGNMENT_SIZE - 1) / ALIGNMENT_SIZE;
UnaryRepeatParams repeatParams;
for (int32_t i = 0; i < rowCount; i++) {
repeatParams.dstBlkStride = 1;
repeatParams.srcBlkStride = 1;
repeatParams.dstRepStride = CONTINUOUS_DATA / ALIGNMENT_SIZE;
repeatParams.srcRepStride = CONTINUOUS_DATA / ALIGNMENT_SIZE;
if ((repeatTimesOneRow - 1) > 0) {
Adds(
dst[i * ((countOneRow32B * ALIGNMENT_SIZE) / FLOAT_BYTES)], src[0],
static_cast<int32_t>(i * sizeof(T)), FLOAT_MASK, repeatTimesOneRow - 1, repeatParams);
}
int32_t mask =
FLOAT_MASK - (((repeatTimesOneRow * CONTINUOUS_DATA) - (colCount * FLOAT_BYTES)) / FLOAT_BYTES);
Adds(
dst[i * (countOneRow32B * ALIGNMENT_SIZE / FLOAT_BYTES) +
(repeatTimesOneRow - 1) * (CONTINUOUS_DATA / FLOAT_BYTES)],
src[(repeatTimesOneRow - 1) * (CONTINUOUS_DATA / FLOAT_BYTES)],
static_cast<int32_t>(i * sizeof(T) * scalar), mask, 1, repeatParams);
}
}
__aicore__ inline void SplitFrameNormal(int32_t batchLoopIndex, int32_t blockIdx)
{
int32_t windowLength = inTilingData->frameCount;
int32_t windowLoop = inTilingData->aivWindowLoop;
int32_t batchLoop = inTilingData->aivBatchLoop;
int32_t nfft = inTilingData->nfft;
int32_t blkFrame = inTilingData->blkFrame;
int32_t curBatch = (blockIdx / windowLoop) * batchLoop + batchLoopIndex;
int32_t batchOffset = blockIdx / DOUBLE_BUFFER;
int32_t loopBlkFrame = windowLength / blkFrame;
int32_t remainingFrame = windowLength % blkFrame;
int32_t halfBlkFrame = loopBlkFrame / DOUBLE_BUFFER;
int32_t inputGmOffset = curBatch * (inTilingData->inputSize + nfft);
int32_t windowSplitOffset =
batchOffset * batchLoop * nfft * windowLength + batchLoopIndex * nfft * windowLength;
if (GetBlockIdx() % DOUBLE_BUFFER == 0) {
for (int32_t i = 0; i < halfBlkFrame; i++) {
CopyIn(i, blkFrame, inputGmOffset);
Compute(i, blkFrame);
CopyOut(i, blkFrame, windowSplitOffset);
}
} else {
for (int32_t i = halfBlkFrame; i < loopBlkFrame; i++) {
CopyIn(i, blkFrame, inputGmOffset);
Compute(i, blkFrame);
CopyOut(i, blkFrame, windowSplitOffset);
}
if (remainingFrame != 0) {
CopyIn(loopBlkFrame, remainingFrame, inputGmOffset);
Compute(loopBlkFrame, remainingFrame);
CopyOut(loopBlkFrame, remainingFrame, windowSplitOffset);
}
}
}
__aicore__ inline int MatmulCeil(int num1, int num2)
{
ASSERT(num2 > 0);
return (num1 + num2 - 1) / num2;
}
private:
__aicore__ inline void CopyIn(int32_t progress, int32_t remainingFrame, int32_t inputBaseOffset)
{
int inputGmStart = progress * inTilingData->blkFrame * inTilingData->hop;
int copySize = remainingFrame * inTilingData->hop + (inTilingData->nfft - inTilingData->hop);
LocalTensor<T> inputLocal = inQueueInput.template AllocTensor<T>();
DataCopy(inputLocal, inputGm[inputGmStart + inputBaseOffset], copySize);
inQueueInput.EnQue(inputLocal);
}
__aicore__ inline void Compute(int32_t progress, int32_t remainingFrame)
{
LocalTensor<T> inputLocal = inQueueInput.template DeQue<T>();
LocalTensor<T> outputLocal = outQueueOutput.template AllocTensor<T>();
DataCopyParams dataCopyParams;
dataCopyParams.blockCount = 1;
dataCopyParams.blockLen = inTilingData->nfft / BLOCK_NUM;
dataCopyParams.srcStride = 0;
dataCopyParams.dstStride = 0;
for (int32_t j = 0; j < remainingFrame; j++) {
DataCopy(outputLocal[j * inTilingData->nfft], inputLocal[j * inTilingData->hop], dataCopyParams);
}
outQueueOutput.EnQue(outputLocal);
inQueueInput.FreeTensor(inputLocal);
}
__aicore__ inline void CopyOut(int32_t progress, int32_t remainingFrame, int32_t outputBaseOffset)
{
LocalTensor<T> outputLocal2 = outQueueOutput.template DeQue<T>();
DataCopy(
windowSplitWorkspaceGm[progress * inTilingData->blkFrame * inTilingData->nfft + outputBaseOffset],
outputLocal2, remainingFrame * inTilingData->nfft);
outQueueOutput.FreeTensor(outputLocal2);
}
__aicore__ inline void IterateAllTB()
{
int midM = MatmulCeil(curCoreM_, baseM_);
int midN = MatmulCeil(curCoreN_, baseN_);
int iterAM = 0;
int iterBN = 0;
int iterK = 0;
int useL1M = curCoreM_ < L1M_ ? curCoreM_ : L1M_;
int useL1N = curCoreN_ < L1N_ ? curCoreN_ : L1N_;
LocalTensor<T> tmpB;
LocalTensor<T> tmpA;
LocalTensor<T> b1Local;
LocalTensor<T> a1Local;
tmpA = inQueueL1.AllocTensor<T>();
CopyGMND2LMNZ(tmpA, aGM, iterAM, iterK, useL1M, L1K_, globalK_);
tmpB = inQueueL1.AllocTensor<T>();
CopyGMND2LMNZTranspose(tmpB, bGM, iterK, iterBN, L1K_, useL1N, globalK_);
inQueueL1.EnQue(tmpA);
inQueueL1.EnQue(tmpB);
for (int idxAMOut = 0; idxAMOut < midM; ++idxAMOut) {
if (idxAMOut == midM - 1) {
useL1M = midTailM_;
}
iterAM = idxAMOut * baseM_;
iterBN = 0;
a1Local = inQueueL1.DeQue<T>();
for (int idxBNOut = 0; idxBNOut < midN - 1; ++idxBNOut) {
if (unlikely(idxBNOut == midN - DOUBLE_BUFFER)) {
tmpB = inQueueL1.AllocTensor<T>();
CopyGMND2LMNZTranspose(tmpB, bGM, iterK, iterBN + baseN_, L1K_, midTailN_, globalK_);
inQueueL1.EnQue(tmpB);
} else {
tmpB = inQueueL1.AllocTensor<T>();
CopyGMND2LMNZTranspose(tmpB, bGM, iterK, iterBN + baseN_, L1K_, useL1N, globalK_);
inQueueL1.EnQue(tmpB);
}
b1Local = inQueueL1.DeQue<T>();
computeTail(a1Local, b1Local, iterAM, iterBN, useL1M, useL1N, L1K_, false);
inQueueL1.FreeTensor(b1Local);
iterBN += baseN_;
}
if (likely(idxAMOut < midM - 1)) {
int nextM = iterAM + baseM_;
int nextK = 0;
int nextN = 0;
int nextUseL1M = L1M_;
if (unlikely(idxAMOut == midM - DOUBLE_BUFFER)) {
nextUseL1M = midTailM_;
}
tmpA = inQueueL1.AllocTensor<T>();
CopyGMND2LMNZ(tmpA, aGM, nextM, nextK, nextUseL1M, L1K_, globalK_);
inQueueL1.EnQue(tmpA);
b1Local = inQueueL1.DeQue<T>();
computeTail(a1Local, b1Local, iterAM, iterBN, useL1M, midTailN_, L1K_, false);
inQueueL1.FreeTensor(b1Local);
inQueueL1.FreeTensor(a1Local);
tmpB = inQueueL1.AllocTensor<T>();
if (midN != 1) {
CopyGMND2LMNZTranspose(tmpB, bGM, nextK, nextN, L1K_, useL1N, globalK_);
}
inQueueL1.EnQue(tmpB);
} else {
b1Local = inQueueL1.DeQue<T>();
computeTail(a1Local, b1Local, iterAM, iterBN, useL1M, midTailN_, L1K_, false);
inQueueL1.FreeTensor(b1Local);
inQueueL1.FreeTensor(a1Local);
}
}
}
__aicore__ inline void computeTail(
LocalTensor<T>& a1Local, LocalTensor<T>& b1Local, int curAM, int curBN, int useL1M, int useL1N, int useL1K,
bool fixpipeBias)
{
int nCount = MatmulCeil(useL1N, baseN_);
int kCount = MatmulCeil(useL1K, baseK_);
int sizeInM = useL1M;
int sizeInN = useL1N;
int sizeInK = baseK_;
int tailSizeInK = useL1K - (kCount - 1) * baseK_;
LocalTensor<T> a0Local;
LocalTensor<T> b0Local;
LocalTensor<T> c0Local = outQueueCO.AllocTensor<T>();
LoadDataA0v5(useL1M, 0, sizeInM, sizeInK, a1Local);
LoadDataB0v5(useL1K, 0, sizeInK, sizeInN, b1Local);
a0Local = inQueueA0.DeQue<T>();
b0Local = inQueueB0.DeQue<T>();
Mmad(
c0Local, a0Local, b0Local,
{static_cast<uint16_t>(sizeInM), static_cast<uint16_t>(sizeInN), static_cast<uint16_t>(sizeInK), 0, false,
true});
AscendC::PipeBarrier<PIPE_M>();
inQueueA0.FreeTensor(a0Local);
inQueueB0.FreeTensor(b0Local);
LoadDataA0v5(useL1M, 1, sizeInM, sizeInK, a1Local);
LoadDataB0v5(useL1K, 1, sizeInK, sizeInN, b1Local);
a0Local = inQueueA0.DeQue<T>();
b0Local = inQueueB0.DeQue<T>();
Mmad(
c0Local, a0Local, b0Local,
{static_cast<uint16_t>(sizeInM), static_cast<uint16_t>(sizeInN), static_cast<uint16_t>(sizeInK), 0, false,
false});
AscendC::PipeBarrier<PIPE_M>();
inQueueA0.FreeTensor(a0Local);
inQueueB0.FreeTensor(b0Local);
LoadDataA0v5(useL1M, DOUBLE_BUFFER, sizeInM, sizeInK, a1Local);
LoadDataB0v5(useL1K, DOUBLE_BUFFER, sizeInK, sizeInN, b1Local);
a0Local = inQueueA0.DeQue<T>();
b0Local = inQueueB0.DeQue<T>();
Mmad(
c0Local, a0Local, b0Local,
{static_cast<uint16_t>(sizeInM), static_cast<uint16_t>(sizeInN), static_cast<uint16_t>(sizeInK), 0, false,
false});
AscendC::PipeBarrier<PIPE_M>();
inQueueA0.FreeTensor(a0Local);
inQueueB0.FreeTensor(b0Local);
LoadDataA0v5(useL1M, INDEX_K_3, sizeInM, sizeInK, a1Local);
LoadDataB0v5(useL1K, INDEX_K_3, sizeInK, sizeInN, b1Local);
a0Local = inQueueA0.DeQue<T>();
b0Local = inQueueB0.DeQue<T>();
Mmad(
c0Local, a0Local, b0Local,
{static_cast<uint16_t>(sizeInM), static_cast<uint16_t>(sizeInN), static_cast<uint16_t>(sizeInK), 0, false,
false});
AscendC::PipeBarrier<PIPE_M>();
inQueueA0.FreeTensor(a0Local);
inQueueB0.FreeTensor(b0Local);
LoadDataA0v5(useL1M, DIFFS_ONE_BLOCK, sizeInM, tailSizeInK, a1Local);
LoadDataB0v5(useL1K, DIFFS_ONE_BLOCK, tailSizeInK, sizeInN, b1Local);
a0Local = inQueueA0.DeQue<T>();
b0Local = inQueueB0.DeQue<T>();
Mmad(
c0Local, a0Local, b0Local,
{static_cast<uint16_t>(sizeInM), static_cast<uint16_t>(sizeInN), static_cast<uint16_t>(tailSizeInK), 0,
false, false});
AscendC::PipeBarrier<PIPE_M>();
inQueueA0.FreeTensor(a0Local);
inQueueB0.FreeTensor(b0Local);
outQueueCO.EnQue<T>(c0Local);
FixPipeC0Tail(curAM, 0, curBN, 0, sizeInM, sizeInN, fixpipeBias);
}
__aicore__ inline void LoadDataA0v5(int useL1M, int idxInK, int sizeInM, int sizeInK, const LocalTensor<T>& a1Local)
{
LocalTensor<T> a0Local = inQueueA0.AllocTensor<T>();
int useL1MCeil = MatmulCeil(useL1M, cubeCol_);
int sizeInMCeil = MatmulCeil(sizeInM, cubeCol_);
int sizeInKCeil = MatmulCeil(sizeInK, cubeRow_);
LoadData3DParamsV2<T> param;
param.l1H = 1;
param.l1W = useL1MCeil * cubeCol_;
param.kExtension = cubeRow_ * sizeInKCeil;
param.mExtension = cubeCol_ * sizeInMCeil;
param.strideW = 1;
param.strideH = 1;
param.filterW = 1;
param.filterH = 1;
param.dilationFilterW = 1;
param.dilationFilterH = 1;
param.channelSize = cubeRow_ * sizeInKCeil;
LoadData(a0Local, a1Local[cubeCol_ * useL1MCeil * baseK_ * idxInK], param);
inQueueA0.EnQue(a0Local);
}
__aicore__ inline void LoadDataB0v5(int useL1N, int idxInK, int sizeInK, int sizeInN, const LocalTensor<T>& b1Local)
{
LocalTensor<T> b0Local = inQueueB0.AllocTensor<T>();
uint8_t RepeatCeil = MatmulCeil(sizeInK, cubeRow_) * MatmulCeil(sizeInN, cubeCol_);
int sizeInNCeil = MatmulCeil(sizeInN, cubeCol_) * cubeCol_;
RepeatCeil = RepeatCeil == 1 ? DOUBLE_BUFFER : RepeatCeil;
LoadData2dParams param = {0, RepeatCeil, 1, 0, 0, false, inc};
LoadData(b0Local, b1Local[idxInK * baseK_ * sizeInNCeil], param);
inQueueB0.EnQue(b0Local);
}
__aicore__ inline void FixPipeC0Tail(int aM, int inM, int bN, int inN, int sizeInM, int sizeInN, bool setAtomicAdd)
{
LocalTensor<T> c0Local = outQueueCO.DeQue<T>();
if (setAtomicAdd) {
SetAtomicAdd<T>();
}
auto intriParams = AscendC::FixpipeParamsV220(sizeInN,
sizeInM,
MatmulCeil(sizeInM, cubeCol_) * cubeCol_,
globalN_,
false);
intriParams.quantPre = QuantMode_t::NoQuant;
AscendC::Fixpipe(cGM[(aM + inM * baseM_) * globalN_ + bN + inN * baseN_], c0Local, intriParams);
if (setAtomicAdd) {
SetAtomicNone();
}
outQueueCO.FreeTensor(c0Local);
}
__aicore__ inline void CopyGMND2LMNZ(
const LocalTensor<T>& dst, const GlobalTensor<T>& src, const uint16_t row, const uint16_t col,
const uint16_t height, const uint16_t width, const uint16_t gCol)
{
Nd2NzParams param = {1, height, width, 0, gCol, static_cast<uint16_t>(MatmulCeil(height, cubeCol_) * cubeCol_),
1, 0};
DataCopy(dst, src[row * gCol + col], param);
}
__aicore__ inline void CopyGMND2LMNZTranspose(
const LocalTensor<T>& dst, const GlobalTensor<T>& src, const uint16_t row, const uint16_t col,
const uint16_t height, const uint16_t width, const uint16_t gCol)
{
Nd2NzParams param = {1, width, height, 0, gCol, static_cast<uint16_t>(MatmulCeil(width, cubeCol_) * cubeCol_),
1, 0};
DataCopy(dst, src[col * gCol + row], param);
}
private:
TPipe pipe;
TQue<QuePosition::VECIN, bufferNum> inQueueInput;
TQue<QuePosition::VECOUT, bufferNum> outQueueOutput;
TQue<QuePosition::VECIN, PING_PONG_NUM> realInQueue;
TQue<QuePosition::VECIN, PING_PONG_NUM> imagInQueue;
TQue<QuePosition::VECOUT, PING_PONG_NUM> complexOutQueue;
GlobalTensor<T> inputGm;
GlobalTensor<T> outputGm;
GlobalTensor<T> windowSplitWorkspaceGm;
GlobalTensor<T> matMulWorkspaceGm;
GlobalTensor<T> gmReal;
GlobalTensor<T> gmImag;
GlobalTensor<T> aGM;
GlobalTensor<T> bGM;
GlobalTensor<T> cGM;
TBuf<> mask;
constexpr static int baseM_ = 96;
int baseN_;
constexpr static int baseK_ = 80;
constexpr static int cubeRow_ = 8;
constexpr static int cubeCol_ = 16;
int baseMK_;
int baseKN_;
int baseMN_;
constexpr static int depthL1_ = 3;
constexpr static int depthA0_ = 2;
constexpr static int depthB0_ = 2;
constexpr static int depthC0_ = 2;
constexpr static int inK_ = 5;
constexpr static int L1M_ = baseM_;
int L1N_;
constexpr static int L1K_ = inK_ * baseK_;
int globalM_;
int globalN_;
int globalK_;
int curCoreM_;
int curCoreN_;
int curCoreK_;
int midTailM_;
int midTailN_;
TQue<QuePosition::A1, THREE_QUE> inQueueL1;
TQue<QuePosition::A2, DOUBLE_BUFFER> inQueueA0;
TQue<QuePosition::B2, DOUBLE_BUFFER> inQueueB0;
TQue<QuePosition::CO1, DOUBLE_BUFFER> outQueueCO;
GlobalTensor<T> a1Global;
GlobalTensor<T> bGlobal;
int32_t N;
int32_t rowEven;
int32_t rowOdd;
STFTTilingData* inTilingData;
};
}
#endif