* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <iostream>
#include <algorithm>
#include <mki/utils/platform/platform_info.h>
#include "aclnn/acl_meta.h"
#include "log/log.h"
#include "fftoperation/transpose.h"
#include "fftcore/dft_core.h"
#include "fftcore/fft_core_b.h"
#include "fftcore/fft_core_n.h"
#include "fftcore/fft_core_mix.h"
#include "fftcore/fft_core_stride.h"
#include "fftcore/dft_c2r_core.h"
#include "fftcore/fft_c2r_core.h"
#include "fftcore/dft_r2c_core.h"
#include "fftcore/fft_r2c_core.h"
#include "fftcore/fft_core_any.h"
#include "fftcore/dd_core.h"
#include "fftcore/dft_core_sep.h"
#include "fftcore/ddd_core_sep.h"
#include "fftcore/fft_core_b_sep.h"
#include "fftcore/fft_c2r_arch35_core.h"
#include "fftcore/fft_r2c_arch35_core.h"
#include "fftcore/fft_c2c_arch35_core.h"
#include "fftcore/fft_c2c2d_arch35_core.h"
#include "params/fft_c2c2d_arch35.h"
#include "fftplan/fft_plan_cache.h"
#include "fftcore/select_core.h"
#include "utils/include/utils/fft_common_func.h"
#include "utils/sip_lock.h"
#include "utils/assert.h"
#include "fft_api.h"
namespace AsdSip {
using namespace Mki;
using namespace AsdSip;
constexpr int K_FACTOR_2 = 2;
constexpr int K_LAST_DIM = -1;
constexpr int K_PENULTIMATE_DIM = -2;
constexpr int K_ANTEPENULTIMATE_DIM = -3;
constexpr int K_BLOCK_SIZE = 32;
constexpr int K_RADIX_2 = 2;
constexpr int K_RADIX_MIX = -2;
constexpr int K_RADIX_ANY = -1;
constexpr int K_N_FFT_32 = 32;
constexpr int K_N_FFT_128 = 128;
constexpr int K_N_FFT_256 = 256;
constexpr int K_N_FFT_1024 = 1024;
constexpr int K_N_FFT_2048 = 2048;
constexpr int K_N_FFT_4096 = 4096;
constexpr int K_N_FFT_8192 = 8192;
constexpr int K_N_FFT_32768 = 32768;
constexpr int K_N_FFT_65536 = 65536;
constexpr int K_VERTICAL_FACTOR_128 = 128;
constexpr int MAX_FFT_SIZE = 1 << 27;
std::vector<int64_t> RADIX_2 = {2};
std::vector<int64_t> RADIX_MIX = {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47};
std::vector<int64_t> RADIX_ARCH35_C2C_MIX = {2, 3, 5, 7, 11, 13, 17, 19};
std::vector<int64_t> deDuplicates(const std::vector<int64_t> &duplicates)
{
std::vector<int64_t> uniques;
for (int64_t item : duplicates) {
if (uniques.empty() || uniques.back() != item) {
uniques.push_back(item);
}
}
return uniques;
}
bool Support(const std::vector<int64_t> &uniques, const std::vector<int64_t> &radixSet)
{
for (int64_t factor : uniques) {
if (std::find(radixSet.begin(), radixSet.end(), factor) == radixSet.end()) {
return false;
}
}
return true;
}
int ChooseRadix(asdFftType fftType, const std::vector<int64_t> &uniques)
{
if (fftType == asdFftType::ASCEND_FFT_C2C) {
if (Support(uniques, RADIX_2)) {
return K_RADIX_2;
}
}
if (Support(uniques, RADIX_MIX)) {
return K_RADIX_MIX;
}
return K_RADIX_ANY;
}
int ChooseRadix(AsdSip::asdFftType fftType, int64_t fftSize)
{
std::vector<int64_t> factors = orderedFactorize(fftSize);
std::vector<int64_t> uniques = deDuplicates(factors);
return ChooseRadix(fftType, uniques);
}
bool Fft2dSupportVertical(int64_t batchSize, int64_t fftSizeX, int64_t fftSizeY, asdFftType fftType)
{
if (Mki::PlatformInfo::Instance().GetPlatformType() == Mki::PlatformType::ASCEND_950) {
return false;
}
if (batchSize != 1) {
return false;
}
switch (fftType) {
case asdFftType::ASCEND_FFT_C2R:
if (GetC2RInputSize(fftSizeY) % K_VERTICAL_FACTOR_128 != 0) {
return false;
}
break;
case asdFftType::ASCEND_FFT_R2C:
if (GetR2COutputSize(fftSizeY) % K_VERTICAL_FACTOR_128 != 0) {
return false;
}
break;
case asdFftType::ASCEND_FFT_C2C:
if (fftSizeY % K_VERTICAL_FACTOR_128 != 0) {
return false;
}
break;
default:
break;
}
std::vector<int64_t> factors = orderedFactorize(fftSizeX);
std::vector<int64_t> uniques = deDuplicates(factors);
if (!Support(uniques, RADIX_2)) {
return false;
}
if (fftSizeX < K_N_FFT_256 || fftSizeX > K_N_FFT_65536) {
return false;
}
return true;
}
bool Fft2dSupportFusing(int64_t fftSizeX, int64_t fftSizeY, int radixX, int radixY,
AsdSip::asdFftType fftType)
{
if (Mki::PlatformInfo::Instance().GetPlatformType() == Mki::PlatformType::ASCEND_950) {
return false;
}
switch (fftType) {
case asdFftType::ASCEND_FFT_C2C:
if (radixX == K_RADIX_2 && radixY == K_RADIX_2 && fftSizeX >= K_N_FFT_32 && fftSizeY >= K_N_FFT_32 &&
fftSizeX <= K_N_FFT_128 && fftSizeY <= K_N_FFT_128) {
return true;
}
break;
default:
break;
}
return false;
}
void getC2RCore(std::optional<FFTCoreType> &coreTypeOpt, int radix, unsigned nDoing, bool forward)
{
if (forward) {
ASDSIP_LOG(DEBUG) << "C2RCore forward.";
} else {
ASDSIP_LOG(DEBUG) << "C2RCore backward.";
}
if (Mki::PlatformInfo::Instance().GetPlatformType() == Mki::PlatformType::ASCEND_910B) {
if (nDoing <= K_N_FFT_1024) {
coreTypeOpt = FFTCoreType::kDftC2R;
} else {
if (radix == K_RADIX_MIX) {
coreTypeOpt = FFTCoreType::kFftC2R;
}
}
} else if (Mki::PlatformInfo::Instance().GetPlatformType() == Mki::PlatformType::ASCEND_950) {
if (nDoing <= K_N_FFT_1024) {
coreTypeOpt = FFTCoreType::kDftC2R;
} else {
if (radix == K_RADIX_MIX) {
coreTypeOpt = FFTCoreType::kFftC2RArch35;
}
}
}
return;
}
void getR2CCore(std::optional<FFTCoreType> &coreTypeOpt, int radix, unsigned nDoing, bool forward)
{
if (forward) {
ASDSIP_LOG(DEBUG) << "R2CCore forward.";
} else {
ASDSIP_LOG(DEBUG) << "R2CCore backward.";
}
if (Mki::PlatformInfo::Instance().GetPlatformType() == Mki::PlatformType::ASCEND_910B) {
if (nDoing <= K_N_FFT_1024) {
coreTypeOpt = FFTCoreType::kDftR2C;
} else {
if (radix == K_RADIX_MIX) {
coreTypeOpt = FFTCoreType::kFftR2C;
}
}
} else if (Mki::PlatformInfo::Instance().GetPlatformType() == Mki::PlatformType::ASCEND_950) {
if (nDoing <= K_N_FFT_1024) {
coreTypeOpt = FFTCoreType::kDftR2C;
} else {
if (radix == K_RADIX_MIX) {
coreTypeOpt = FFTCoreType::kFftR2CArch35;
}
}
}
return;
}
void getC2CCore(std::optional<FFTCoreType> &coreTypeOpt, int radix, unsigned nDoing, bool forward)
{
if (forward) {
ASDSIP_LOG(DEBUG) << "C2CCore forward.";
} else {
ASDSIP_LOG(DEBUG) << "C2CCore backward.";
}
if (Mki::PlatformInfo::Instance().GetPlatformType() == Mki::PlatformType::ASCEND_950) {
bool arch35Supported = nDoing > 1 && (radix == K_RADIX_2 ||
(radix == K_RADIX_MIX && Support(deDuplicates(orderedFactorize(nDoing)), RADIX_ARCH35_C2C_MIX)));
if (arch35Supported) {
coreTypeOpt = FFTCoreType::kFftC2CArch35;
} else if (nDoing <= K_N_FFT_256) {
coreTypeOpt = FFTCoreType::kDft;
} else {
ASDSIP_LOG(ERROR) << "ASCEND_950 C2C arch35 unsupported factorization for nDoing=" << nDoing;
throw std::runtime_error("ASCEND_950 C2C arch35 unsupported factorization.");
}
return;
}
if (nDoing <= K_N_FFT_256) {
coreTypeOpt = FFTCoreType::kDft;
return;
}
switch (radix) {
case K_RADIX_2:
coreTypeOpt = nDoing < K_N_FFT_32768 ? FFTCoreType::kFftB : FFTCoreType::kFftN;
return;
case K_RADIX_MIX:
if (nDoing > K_N_FFT_256) {
coreTypeOpt = FFTCoreType::kFftMix;
}
return;
default:
return;
}
}
std::unique_ptr<FftOperation> InitFftOpPtr(std::optional<FFTCoreType> coreTypeOpt, unsigned nDone, unsigned nDoing,
unsigned nLeft, unsigned batch, asdFftType fftType, bool forward,
FFTPlan &plan)
{
std::unique_ptr<FftOperation> unique(nullptr);
if (coreTypeOpt) {
switch (coreTypeOpt.value()) {
case FFTCoreType::kFftStride:
unique.reset(new FFTCoreStride(nDone, nDoing, nLeft, batch, fftType, forward));
break;
case FFTCoreType::kDft:
unique.reset(new DFTCore(nDone, nDoing, nLeft, batch, fftType, forward));
break;
case FFTCoreType::kFftB:
unique.reset(new FFTCoreB(nDone, nDoing, nLeft, batch, fftType, forward));
break;
case FFTCoreType::kFftN:
unique.reset(new FFTCoreN(nDone, nDoing, nLeft, batch, fftType, forward));
break;
case FFTCoreType::kFftMix:
unique.reset(new FftCoreMix(nDone, nDoing, nLeft, batch, fftType, forward));
break;
case FFTCoreType::kDftC2R:
unique.reset(new DftC2RCore(nDoing, batch, fftType, forward));
break;
case FFTCoreType::kFftC2R:
unique.reset(new FftC2RCore(nDone, nDoing, nLeft, batch, fftType, forward));
break;
case FFTCoreType::kFftC2RArch35:
unique.reset(new FftC2RCoreArch35(nDone, nDoing, nLeft, batch, fftType, forward));
break;
case FFTCoreType::kDftR2C:
unique.reset(new DftR2CCore(nDone, nDoing, nLeft, batch, fftType, forward));
break;
case FFTCoreType::kFftR2C:
unique.reset(new FftR2CCore(nDone, nDoing, nLeft, batch, fftType, forward));
break;
case FFTCoreType::kFftR2CArch35:
unique.reset(new FftR2CCoreArch35(nDone, nDoing, nLeft, batch, fftType, forward));
break;
case FFTCoreType::kFftC2CArch35:
unique.reset(new FftC2CCoreArch35(nDone, nDoing, nLeft, batch, fftType, forward));
break;
case FFTCoreType::kAny:
unique.reset(new FFTCoreAny(nDone, nDoing, nLeft, batch, fftType, forward));
break;
case FFTCoreType::kDftSep:
unique.reset(new DFTCoreSep(nDone, nDoing, nLeft, batch, fftType, forward));
break;
case FFTCoreType::kFftBSep:
unique.reset(new FFTCoreBSep(nDone, nDoing, nLeft, batch, fftType, forward));
break;
default:
throw std::runtime_error("Invalid coreType.");
break;
}
if (unique == nullptr) {
ASDSIP_LOG(ERROR) << "initialize fftcore failed, ptr is nullptr.";
throw std::runtime_error("ptr is nullptr.");
}
if (!unique->init()) {
plan.markFailed();
ASDSIP_LOG(ERROR) << "initialize fftcore failed.";
throw std::runtime_error("initialize fftcore failed.");
}
}
return unique;
}
std::unique_ptr<FftOperation> InitFft2DOpPtr(std::optional<FFTCoreType> coreTypeOpt, int64_t fftSizeX, int64_t fftSizeY,
int64_t batchSize, AsdSip::asdFftType fftType, bool forward, FFTPlan &plan,
int32_t arch35Mode = static_cast<int32_t>(OpParam::FftC2C2DArch35Mode::RADIX2))
{
std::unique_ptr<FftOperation> unique(nullptr);
if (coreTypeOpt) {
switch (coreTypeOpt.value()) {
case FFTCoreType::kDd:
unique.reset(new DdCore(fftSizeX, fftSizeY, batchSize, fftType, forward));
break;
case FFTCoreType::kFftC2C2DArch35:
unique.reset(new FftC2C2DCoreArch35(fftSizeX, fftSizeY, batchSize, fftType, forward, arch35Mode));
break;
default:
break;
}
if (unique == nullptr) {
ASDSIP_LOG(ERROR) << "initialize fftcore failed, ptr is nullptr.";
throw std::runtime_error("ptr is nullptr.");
}
if (!unique->init()) {
plan.markFailed();
ASDSIP_LOG(ERROR) << "initialize fftcore failed.";
}
}
return unique;
}
std::unique_ptr<FftOperation> InitFft3DOpPtr(std::optional<FFTCoreType> coreTypeOpt, int64_t fftSizeX, int64_t fftSizeY, int64_t fftSizeZ,
int64_t batchSize, AsdSip::asdFftType fftType, bool forward, FFTPlan &plan)
{
std::unique_ptr<FftOperation> unique(nullptr);
if (coreTypeOpt) {
switch (coreTypeOpt.value()) {
case FFTCoreType::kDddSep:
unique.reset(new DddCoreSep(fftSizeX, fftSizeY, fftSizeZ, batchSize, fftType, forward));
break;
default:
break;
}
if (unique == nullptr) {
ASDSIP_LOG(ERROR) << "initialize fftcore failed, ptr is nullptr.";
throw std::runtime_error("ptr is nullptr.");
}
if (!unique->init()) {
plan.markFailed();
ASDSIP_LOG(ERROR) << "initialize fftcore failed.";
}
}
return unique;
}
std::unique_ptr<FftOperation> getCore(const std::vector<int64_t> &uniques, unsigned nDone, unsigned nDoing,
unsigned nLeft, unsigned stride, unsigned batch, asdFftType fftType, bool forward,
FFTPlan &plan)
{
std::optional<FFTCoreType> coreTypeOpt = std::nullopt;
if (Mki::PlatformInfo::Instance().GetPlatformType() == Mki::PlatformType::ASCEND_910B) {
if (fftType == asdFftType::ASCEND_FFT_C2C) {
if (stride > 1) {
return InitFftOpPtr(FFTCoreType::kFftStride, nDone, nDoing, stride, batch, fftType, forward, plan);
}
}
if (fftType == asdFftType::ASCEND_FFT_C2C_SEP) {
if ((nDoing == 256 || nDoing == 512) && forward) {
return InitFftOpPtr(FFTCoreType::kFftBSep, nDone, nDoing, stride, batch, fftType, forward, plan);
}
return InitFftOpPtr(FFTCoreType::kDftSep, nDone, nDoing, stride, batch, fftType, forward, plan);
}
if (fftType == asdFftType::ASCEND_FFT_C2C) {
SelectCore::InitializeConfigs();
coreTypeOpt = SelectCore::FindConfig(batch, nDoing);
if (coreTypeOpt) {
return InitFftOpPtr(coreTypeOpt, nDone, nDoing, nLeft, batch, fftType, forward, plan);
}
}
}
int radix = ChooseRadix(fftType, uniques);
switch (fftType) {
case asdFftType::ASCEND_FFT_C2R:
getC2RCore(coreTypeOpt, radix, nDoing, forward);
break;
case asdFftType::ASCEND_FFT_R2C:
getR2CCore(coreTypeOpt, radix, nDoing, forward);
break;
case asdFftType::ASCEND_FFT_C2C:
getC2CCore(coreTypeOpt, radix, nDoing, forward);
break;
default:
break;
}
if (!coreTypeOpt) {
coreTypeOpt = FFTCoreType::kAny;
}
return InitFftOpPtr(coreTypeOpt, nDone, nDoing, nLeft, batch, fftType, forward, plan);
}
std::unique_ptr<FftOperation> GetCore2D(int64_t fftSizeX, int64_t fftSizeY, int radixX, int radixY, int64_t batchSize,
AsdSip::asdFftType fftType, bool forward, FFTPlan &plan)
{
if (Mki::PlatformInfo::Instance().GetPlatformType() == Mki::PlatformType::ASCEND_950) {
if (fftType == asdFftType::ASCEND_FFT_C2C && radixX == K_RADIX_2 && radixY == K_RADIX_2) {
return InitFft2DOpPtr(FFTCoreType::kFftC2C2DArch35, fftSizeX, fftSizeY, batchSize, fftType, forward, plan,
static_cast<int32_t>(OpParam::FftC2C2DArch35Mode::RADIX2));
}
if (fftType == asdFftType::ASCEND_FFT_C2C &&
Support(deDuplicates(orderedFactorize(fftSizeX)), RADIX_ARCH35_C2C_MIX) &&
Support(deDuplicates(orderedFactorize(fftSizeY)), RADIX_ARCH35_C2C_MIX)) {
return InitFft2DOpPtr(FFTCoreType::kFftC2C2DArch35, fftSizeX, fftSizeY, batchSize, fftType, forward, plan,
static_cast<int32_t>(OpParam::FftC2C2DArch35Mode::MIXED_RADIX));
}
if (fftType == asdFftType::ASCEND_FFT_C2C) {
ASDSIP_LOG(ERROR) << "ASCEND_950 2D C2C arch35 unsupported factorization: fftSizeX=" << fftSizeX
<< ", fftSizeY=" << fftSizeY;
throw std::runtime_error("ASCEND_950 2D C2C arch35 unsupported factorization.");
}
return nullptr;
}
std::optional<FFTCoreType> coreTypeOpt = std::nullopt;
if (fftType == asdFftType::ASCEND_FFT_C2C) {
if (radixX == K_RADIX_2 && radixY == K_RADIX_2 && fftSizeX >= K_N_FFT_32 && fftSizeY >= K_N_FFT_32 &&
fftSizeX <= K_N_FFT_128 && fftSizeY <= K_N_FFT_128) {
coreTypeOpt = FFTCoreType::kDd;
}
}
return InitFft2DOpPtr(coreTypeOpt, fftSizeX, fftSizeY, batchSize, fftType, forward, plan);
}
std::unique_ptr<FftOperation> GetCore3D(int64_t fftSizeX, int64_t fftSizeY, int64_t fftSizeZ, int64_t batchSize, AsdSip::asdFftType fftType, bool forward, FFTPlan &plan)
{
if (Mki::PlatformInfo::Instance().GetPlatformType() == Mki::PlatformType::ASCEND_950) {
return nullptr;
}
std::optional<FFTCoreType> coreTypeOpt = std::nullopt;
if (fftType == asdFftType::ASCEND_FFT_C2C_SEP) {
coreTypeOpt = FFTCoreType::kDddSep;
}
return InitFft3DOpPtr(coreTypeOpt, fftSizeX, fftSizeY, fftSizeZ, batchSize, fftType, forward, plan);
}
bool commonMatchFunc(const FFTPlan &plan, const Tensor &inData)
{
int64_t last = static_cast<int64_t>(inData.desc.dims.size()) - 1;
int64_t i = last;
if (static_cast<int64_t>(plan.fftSizes.size()) > last + 1) {
return false;
}
for (auto rit = plan.fftSizes.rbegin(); rit != plan.fftSizes.rend(); rit++, i--) {
int64_t dim = inData.desc.dims[i];
if (dim != *rit) {
return false;
}
}
return true;
}
void addFFTTransposeStep(FFTPlan &plan, int axis0, int axis1, const SVector<int64_t>& dims)
{
plan.steps.push_back(PlanStep{});
PlanStep &step = plan.steps.back();
step.operation = std::unique_ptr<FftOperation>(std::make_unique<Transpose>(axis0, axis1, dims));
}
void addFFTSteps(FFTPlan &plan, int dim, int batchSize, asdFftType fftType)
{
if (plan.fftSizes.size() <= static_cast<size_t>(dim) || plan.fftStrides.size() <= static_cast<size_t>(dim)) {
throw std::runtime_error("Invalid dim.");
}
int fftSize = plan.fftSizes[dim];
int fftStride = plan.fftStrides[dim];
std::vector<int64_t> factors = orderedFactorize(fftSize);
std::vector<int64_t> uniques = deDuplicates(factors);
unsigned nLeft = 1;
plan.steps.push_back(PlanStep{});
PlanStep &step = plan.steps.back();
step.operation = getCore(uniques, 1, fftSize, nLeft, fftStride, batchSize, fftType, plan.isForward(), plan);
}
inline void addFFTSteps(FFTPlan &plan, int dim, int batchSize)
{
addFFTSteps(plan, dim, batchSize, plan.fftType);
}
void addFFT2DStep(FFTPlan &plan, int radixX, int radixY)
{
int64_t fftSizeX = plan.fftSizes[0];
int64_t fftSizeY = plan.fftSizes[1];
plan.steps.push_back(PlanStep{});
PlanStep &step = plan.steps.back();
step.operation =
GetCore2D(fftSizeX, fftSizeY, radixX, radixY, plan.batchSize, plan.fftType, plan.isForward(), plan);
}
void addFFT3DStep(FFTPlan &plan, int64_t fftSizeX, int64_t fftSizeY, int64_t fftSizeZ)
{
plan.steps.push_back(PlanStep{});
PlanStep &step = plan.steps.back();
step.operation =
GetCore3D(fftSizeX, fftSizeY, fftSizeZ, plan.batchSize, plan.fftType, plan.isForward(), plan);
}
void init2DSteps(FFTPlan &plan)
{
int64_t batchSize = plan.batchSize;
int64_t fftSizeX = plan.fftSizes[0];
int64_t fftSizeY = plan.fftSizes[1];
int radixX = ChooseRadix(plan.fftType, fftSizeX);
int radixY = ChooseRadix(plan.fftType, fftSizeY);
SVector<int64_t> dimsStepOne;
SVector<int64_t> dimsStepTwo;
switch (plan.fftType) {
case asdFftType::ASCEND_FFT_C2R:
if (Fft2dSupportVertical(batchSize, fftSizeX, fftSizeY, plan.fftType)) {
plan.fftStrides[0] = GetC2RInputSize(fftSizeY);
addFFTSteps(plan, 0, batchSize, asdFftType::ASCEND_FFT_C2C);
addFFTSteps(plan, 1, batchSize * fftSizeX, asdFftType::ASCEND_FFT_C2R);
break;
}
dimsStepOne.push_back(batchSize);
dimsStepOne.push_back(fftSizeX);
dimsStepOne.push_back(GetC2RInputSize(fftSizeY));
addFFTTransposeStep(plan, K_LAST_DIM, K_PENULTIMATE_DIM, dimsStepOne);
addFFTSteps(plan, 0, batchSize * GetC2RInputSize(fftSizeY), asdFftType::ASCEND_FFT_C2C);
dimsStepTwo.push_back(batchSize);
dimsStepTwo.push_back(GetC2RInputSize(fftSizeY));
dimsStepTwo.push_back(fftSizeX);
addFFTTransposeStep(plan, K_LAST_DIM, K_PENULTIMATE_DIM, dimsStepTwo);
addFFTSteps(plan, 1, batchSize * fftSizeX, asdFftType::ASCEND_FFT_C2R);
break;
case asdFftType::ASCEND_FFT_R2C:
if (Fft2dSupportVertical(batchSize, fftSizeX, fftSizeY, plan.fftType)) {
plan.fftStrides[0] = GetC2RInputSize(fftSizeY);
addFFTSteps(plan, 1, batchSize * fftSizeX, asdFftType::ASCEND_FFT_R2C);
addFFTSteps(plan, 0, batchSize, asdFftType::ASCEND_FFT_C2C);
break;
}
addFFTSteps(plan, 1, batchSize * fftSizeX, asdFftType::ASCEND_FFT_R2C);
dimsStepOne.push_back(batchSize);
dimsStepOne.push_back(fftSizeX);
dimsStepOne.push_back(GetR2COutputSize(fftSizeY));
addFFTTransposeStep(plan, K_LAST_DIM, K_PENULTIMATE_DIM, dimsStepOne);
addFFTSteps(plan, 0, batchSize * GetR2COutputSize(fftSizeY), asdFftType::ASCEND_FFT_C2C);
dimsStepTwo.push_back(batchSize);
dimsStepTwo.push_back(GetR2COutputSize(fftSizeY));
dimsStepTwo.push_back(fftSizeX);
addFFTTransposeStep(plan, K_LAST_DIM, K_PENULTIMATE_DIM, dimsStepTwo);
break;
case asdFftType::ASCEND_FFT_C2C:
if (Mki::PlatformInfo::Instance().GetPlatformType() == Mki::PlatformType::ASCEND_950) {
if ((radixX == K_RADIX_2 && radixY == K_RADIX_2) ||
(Support(deDuplicates(orderedFactorize(fftSizeX)), RADIX_ARCH35_C2C_MIX) &&
Support(deDuplicates(orderedFactorize(fftSizeY)), RADIX_ARCH35_C2C_MIX))) {
addFFT2DStep(plan, radixX, radixY);
break;
}
ASDSIP_LOG(ERROR) << "ASCEND_950 2D C2C arch35 unsupported factorization: fftSizeX=" << fftSizeX
<< ", fftSizeY=" << fftSizeY;
throw std::runtime_error("ASCEND_950 2D C2C arch35 unsupported factorization.");
}
if (Fft2dSupportFusing(fftSizeX, fftSizeY, radixX, radixY, plan.fftType)) {
addFFT2DStep(plan, radixX, radixY);
break;
}
if (Fft2dSupportVertical(batchSize, fftSizeX, fftSizeY, plan.fftType)) {
plan.fftStrides[0] = fftSizeY;
addFFTSteps(plan, 1, batchSize * fftSizeX, asdFftType::ASCEND_FFT_C2C);
addFFTSteps(plan, 0, batchSize, asdFftType::ASCEND_FFT_C2C);
break;
}
addFFTSteps(plan, 1, batchSize * fftSizeX, asdFftType::ASCEND_FFT_C2C);
dimsStepOne.push_back(batchSize);
dimsStepOne.push_back(fftSizeX);
dimsStepOne.push_back(fftSizeY);
addFFTTransposeStep(plan, K_LAST_DIM, K_PENULTIMATE_DIM, dimsStepOne);
addFFTSteps(plan, 0, batchSize * fftSizeY, asdFftType::ASCEND_FFT_C2C);
dimsStepTwo.push_back(batchSize);
dimsStepTwo.push_back(fftSizeY);
dimsStepTwo.push_back(fftSizeX);
addFFTTransposeStep(plan, K_LAST_DIM, K_PENULTIMATE_DIM, dimsStepTwo);
break;
default:
break;
}
plan.markInitialized();
}
void addFFTC2CFirstTwoDimsStep(FFTPlan &plan, int batchSize, int64_t fftSizeX, int64_t fftSizeY, int64_t fftSizeZ)
{
SVector<int64_t> dimsStepOne;
SVector<int64_t> dimsStepTwo;
SVector<int64_t> dimsStepThree;
SVector<int64_t> dimsStepFour;
bool strideTag = false;
if (Fft2dSupportVertical(batchSize, fftSizeX, fftSizeY * fftSizeZ, plan.fftType)) {
plan.fftStrides[0] = fftSizeY * fftSizeZ;
strideTag = true;
addFFTSteps(plan, 0, batchSize, asdFftType::ASCEND_FFT_C2C);
} else {
dimsStepOne = {batchSize, fftSizeX, fftSizeY, fftSizeZ};
addFFTTransposeStep(plan, K_LAST_DIM, K_ANTEPENULTIMATE_DIM, dimsStepOne);
addFFTSteps(plan, 0, batchSize * fftSizeY * fftSizeZ, asdFftType::ASCEND_FFT_C2C);
}
if ((fftSizeX == 1) && Fft2dSupportVertical(batchSize, fftSizeY, fftSizeZ, plan.fftType)) {
plan.fftStrides[1] = fftSizeZ;
addFFTSteps(plan, 1, batchSize * fftSizeX, asdFftType::ASCEND_FFT_C2C);
} else {
if (strideTag) {
dimsStepOne = {batchSize, fftSizeX, fftSizeY, fftSizeZ};
addFFTTransposeStep(plan, K_LAST_DIM, K_PENULTIMATE_DIM, dimsStepOne);
addFFTSteps(plan, 1, batchSize * fftSizeX * fftSizeZ, asdFftType::ASCEND_FFT_C2C);
dimsStepTwo = {batchSize, fftSizeX, fftSizeZ, fftSizeY};
addFFTTransposeStep(plan, K_LAST_DIM, K_PENULTIMATE_DIM, dimsStepTwo);
} else {
dimsStepTwo = {batchSize, fftSizeZ, fftSizeY, fftSizeX};
addFFTTransposeStep(plan, K_LAST_DIM, K_PENULTIMATE_DIM, dimsStepTwo);
addFFTSteps(plan, 1, batchSize * fftSizeX * fftSizeZ, asdFftType::ASCEND_FFT_C2C);
dimsStepThree = {batchSize, fftSizeZ, fftSizeX, fftSizeY};
addFFTTransposeStep(plan, K_LAST_DIM, K_PENULTIMATE_DIM, dimsStepThree);
dimsStepFour = {batchSize, fftSizeZ, fftSizeY, fftSizeX};
addFFTTransposeStep(plan, K_LAST_DIM, K_ANTEPENULTIMATE_DIM, dimsStepFour);
}
}
}
void init3DSteps(FFTPlan &plan)
{
int64_t batchSize = plan.batchSize;
int64_t fftSizeX = plan.fftSizes[0];
int64_t fftSizeY = plan.fftSizes[1];
int64_t fftSizeZ = plan.fftSizes[2];
int radixY = ChooseRadix(plan.fftType, fftSizeY);
int radixZ = ChooseRadix(plan.fftType, fftSizeZ);
SVector<int64_t> dimsStepOne;
SVector<int64_t> dimsStepTwo;
switch (plan.fftType) {
case asdFftType::ASCEND_FFT_C2C_SEP:
plan.fftStrides = {fftSizeY * fftSizeZ, fftSizeZ, 1};
addFFT3DStep(plan, fftSizeX, fftSizeY, fftSizeZ);
break;
case asdFftType::ASCEND_FFT_C2R:
addFFTC2CFirstTwoDimsStep(plan, batchSize, fftSizeX, fftSizeY, GetC2RInputSize(fftSizeZ));
addFFTSteps(plan, 2, batchSize * fftSizeX * fftSizeY, asdFftType::ASCEND_FFT_C2R);
break;
case asdFftType::ASCEND_FFT_R2C:
addFFTSteps(plan, 2, batchSize * fftSizeX * fftSizeY, asdFftType::ASCEND_FFT_R2C);
addFFTC2CFirstTwoDimsStep(plan, batchSize, fftSizeX, fftSizeY, GetR2COutputSize(fftSizeZ));
break;
case asdFftType::ASCEND_FFT_C2C:
if (Fft2dSupportFusing(fftSizeY, fftSizeZ, radixY, radixZ, plan.fftType)) {
plan.batchSize = batchSize * fftSizeX;
plan.fftSizes = {fftSizeY, fftSizeZ};
radixZ = ChooseRadix(plan.fftType, fftSizeZ);
addFFT2DStep(plan, radixY, radixZ);
plan.batchSize = batchSize;
plan.fftSizes = {fftSizeX, fftSizeY, fftSizeZ};
if (Fft2dSupportVertical(batchSize, fftSizeX, fftSizeY * fftSizeZ, plan.fftType)) {
plan.fftStrides[0] = fftSizeY * fftSizeZ;
addFFTSteps(plan, 0, batchSize, asdFftType::ASCEND_FFT_C2C);
} else {
dimsStepOne = {batchSize, fftSizeX, fftSizeY, fftSizeZ};
addFFTTransposeStep(plan, K_LAST_DIM, K_ANTEPENULTIMATE_DIM, dimsStepOne);
addFFTSteps(plan, 0, batchSize * fftSizeY * fftSizeZ, asdFftType::ASCEND_FFT_C2C);
dimsStepTwo = {batchSize, fftSizeZ, fftSizeY, fftSizeX};
addFFTTransposeStep(plan, K_LAST_DIM, K_ANTEPENULTIMATE_DIM, dimsStepTwo);
}
break;
}
addFFTC2CFirstTwoDimsStep(plan, batchSize, fftSizeX, fftSizeY, fftSizeZ);
addFFTSteps(plan, 2, batchSize * fftSizeX * fftSizeY, asdFftType::ASCEND_FFT_C2C);
break;
default:
break;
}
plan.markInitialized();
}
inline size_t computeTempCachesSize(const FFTPlan &plan)
{
size_t num = plan.steps.size() <= K_FACTOR_2 ? 1 : K_FACTOR_2;
return num * getAlignedSize(plan.eleNum() * GetTensorElementSize(Mki::TensorDType::TENSOR_DTYPE_COMPLEX64) +
K_BLOCK_SIZE);
}
inline size_t computeWorkspaceSize(const FFTPlan &plan)
{
size_t workspaceSize = 0;
for (int64_t i = 0; i < static_cast<int64_t>(plan.steps.size()); i++) {
workspaceSize = std::max(workspaceSize, plan.steps[i].operation->EstimateWorkspaceSize());
}
return workspaceSize;
}
inline bool shouldAllocTempCaches(const FFTPlan &plan)
{
return plan.steps.size() > 1;
}
inline bool shouldAllocWorkspace(const FFTPlan &plan)
{
return computeWorkspaceSize(plan) != 0;
}
void recycleInterCaches(FFTPlan &plan, workspace::Workspace &wkspace)
{
int64_t num = plan.steps.size() <= K_FACTOR_2 ? 1 : K_FACTOR_2;
for (int64_t i = 0; i < num; i++) {
wkspace.recycleLast();
}
}
AspbStatus asdFftCreate(asdFftHandle &handle)
{
std::lock_guard<std::mutex> lock(fft_mtx);
handle = FFTPlanCache::makePlan();
return AsdSip::ErrorType::ACL_SUCCESS;
}
AspbStatus asdFftSetStream(asdFftHandle handle, void *stream)
{
std::lock_guard<std::mutex> lock(fft_mtx);
if (stream == nullptr) {
ASDSIP_LOG(ERROR) << "stream is nullptr.";
return ErrorType::ACL_ERROR_INVALID_PARAM;
}
if (!FFTPlanCache::doesPlanExist(handle)) {
ASDSIP_LOG(ERROR) << "Invalid handle.";
return ErrorType::ACL_ERROR_INVALID_PARAM;
}
FFTPlan &plan = FFTPlanCache::getPlan(handle);
plan.stream = stream;
return AsdSip::ErrorType::ACL_SUCCESS;
}
AspbStatus asdFftMakePlan1D(asdFftHandle handle, int64_t fftSize, asdFftType fftType, asdFftDirection direction,
int64_t batchSize, asdFft1dDimType dimType)
{
std::lock_guard<std::mutex> lock(fft_mtx);
if (!FFTPlanCache::doesPlanExist(handle)) {
ASDSIP_LOG(ERROR) << "Invalid handle.";
return ErrorType::ACL_ERROR_INVALID_PARAM;
}
FFTPlan &plan = FFTPlanCache::getPlan(handle);
if (fftSize <= 0 || fftSize > MAX_FFT_SIZE) {
ASDSIP_LOG(ERROR) << "Invalid fft_size.";
return ErrorType::ACL_ERROR_INVALID_PARAM;
}
if (batchSize <= 0) {
ASDSIP_LOG(ERROR) << "Invalid batch_size.";
return ErrorType::ACL_ERROR_INVALID_PARAM;
}
plan.fftType = fftType;
plan.direction = direction;
plan.fftSizes = {fftSize};
if (dimType == asdFft1dDimType::ASCEND_FFT_VERTICAL) {
plan.batchSize = 1;
plan.fftStrides = {batchSize};
} else {
plan.batchSize = batchSize;
plan.fftStrides = {1};
}
addFFTSteps(plan, 0, batchSize);
plan.markInitialized();
return AsdSip::ErrorType::ACL_SUCCESS;
}
AspbStatus commonParamCheck(asdFftHandle handle, int64_t fftSizeX, int64_t fftSizeY, int64_t batchSize)
{
if (!FFTPlanCache::doesPlanExist(handle)) {
ASDSIP_LOG(ERROR) << "Invalid handle.";
return ErrorType::ACL_ERROR_INVALID_PARAM;
}
if (fftSizeX <= 0 || fftSizeX > MAX_FFT_SIZE) {
ASDSIP_LOG(ERROR) << "Invalid fftSizeX.";
return ErrorType::ACL_ERROR_INVALID_PARAM;
}
if (fftSizeY <= 0 || fftSizeY > MAX_FFT_SIZE) {
ASDSIP_LOG(ERROR) << "Invalid fftSizeY.";
return ErrorType::ACL_ERROR_INVALID_PARAM;
}
if (batchSize <= 0) {
ASDSIP_LOG(ERROR) << "Invalid batchSize.";
return ErrorType::ACL_ERROR_INVALID_PARAM;
}
return AsdSip::ErrorType::ACL_SUCCESS;
}
AspbStatus commonParamCheck3D(asdFftHandle handle, int64_t fftSizeX, int64_t fftSizeY, int64_t fftSizeZ, int64_t batchSize)
{
AspbStatus checkStatus = commonParamCheck(handle, fftSizeX, fftSizeY, batchSize);
if (checkStatus != AsdSip::ErrorType::ACL_SUCCESS) {
return checkStatus;
}
if (fftSizeZ <= 0 || fftSizeZ > MAX_FFT_SIZE) {
ASDSIP_LOG(ERROR) << "Invalid fftSizeZ.";
return ErrorType::ACL_ERROR_INVALID_PARAM;
}
return AsdSip::ErrorType::ACL_SUCCESS;
}
AspbStatus asdFftMakePlan1D(asdFftHandle handle, int64_t fftSizeX, int64_t fftSizeY, asdFftType fftType,
asdFftDirection direction, int64_t batchSize, int32_t dim)
{
std::lock_guard<std::mutex> lock(fft_mtx);
AspbStatus checkStatus = commonParamCheck(handle, fftSizeX, fftSizeY, batchSize);
if (checkStatus != AsdSip::ErrorType::ACL_SUCCESS) {
return checkStatus;
}
if (dim <= 0) {
ASDSIP_LOG(ERROR) << "Invalid dim.";
return ErrorType::ACL_ERROR_INVALID_PARAM;
}
FFTPlan &plan = FFTPlanCache::getPlan(handle);
plan.fftType = fftType;
plan.direction = direction;
plan.batchSize = batchSize;
plan.fftSizes = {fftSizeX};
plan.fftStrides = {1};
if (dim == 1) {
plan.fftStrides = {fftSizeY};
}
if (dim < 0) {
dim = dim + static_cast<int64_t>(plan.fftSizes.size());
}
if (dim < 0) {
ASDSIP_LOG(ERROR) << "dim is invalid.";
return ErrorType::ACL_ERROR_INVALID_PARAM;
}
addFFTSteps(plan, 0, batchSize);
plan.markInitialized();
return AsdSip::ErrorType::ACL_SUCCESS;
}
AspbStatus asdFftDestroy(asdFftHandle handle)
{
std::lock_guard<std::mutex> lock(fft_mtx);
if (!FFTPlanCache::doesPlanExist(handle)) {
ASDSIP_LOG(ERROR) << "Invalid handle.";
return ErrorType::ACL_ERROR_INVALID_PARAM;
}
FFTPlanCache::destroyPlan(handle);
return AsdSip::ErrorType::ACL_SUCCESS;
}
AspbStatus asdFftMakePlan2D(asdFftHandle handle, int64_t fftSizeX, int64_t fftSizeY, asdFftType fftType,
asdFftDirection direction, int32_t batchSize)
{
std::lock_guard<std::mutex> lock(fft_mtx);
AspbStatus checkStatus = commonParamCheck(handle, fftSizeX, fftSizeY, batchSize);
if (checkStatus != AsdSip::ErrorType::ACL_SUCCESS) {
return checkStatus;
}
FFTPlan &plan = FFTPlanCache::getPlan(handle);
plan.fftType = fftType;
plan.direction = direction;
plan.batchSize = batchSize;
plan.fftSizes = {fftSizeX, fftSizeY};
plan.fftStrides = {1, 1};
init2DSteps(plan);
return AsdSip::ErrorType::ACL_SUCCESS;
}
AspbStatus asdFftMakePlan3D(
asdFftHandle handle,
int64_t fftSizeX,
int64_t fftSizeY,
int64_t fftSizeZ,
asdFftType fftType,
asdFftDirection direction,
int32_t batchSize)
{
std::lock_guard<std::mutex> lock(fft_mtx);
AspbStatus checkStatus = commonParamCheck3D(handle, fftSizeX, fftSizeY, fftSizeZ, batchSize);
if (checkStatus != AsdSip::ErrorType::ACL_SUCCESS) {
return checkStatus;
}
FFTPlan &plan = FFTPlanCache::getPlan(handle);
plan.fftType = fftType;
plan.direction = direction;
plan.batchSize = batchSize;
plan.fftSizes = {fftSizeX, fftSizeY, fftSizeZ};
plan.fftStrides = {1, 1, 1};
init3DSteps(plan);
return AsdSip::ErrorType::ACL_SUCCESS;
}
AspbStatus asdFftGetWorkspaceSize(asdFftHandle handle, size_t &workspaceSize)
{
std::lock_guard<std::mutex> lock(fft_mtx);
if (!FFTPlanCache::doesPlanExist(handle)) {
ASDSIP_LOG(ERROR) << "Invalid handle.";
return ErrorType::ACL_ERROR_INVALID_PARAM;
}
FFTPlan &plan = FFTPlanCache::getPlan(handle);
workspaceSize = 0;
if (shouldAllocTempCaches(plan)) {
workspaceSize += computeTempCachesSize(plan);
}
workspaceSize += computeWorkspaceSize(plan);
return AsdSip::ErrorType::ACL_SUCCESS;
}
AspbStatus asdFftSetWorkspace(asdFftHandle handle, void *workspace)
{
std::lock_guard<std::mutex> lock(fft_mtx);
if (!FFTPlanCache::doesPlanExist(handle)) {
ASDSIP_LOG(ERROR) << "Invalid handle.";
return ErrorType::ACL_ERROR_INVALID_PARAM;
}
FFTPlan &plan = FFTPlanCache::getPlan(handle);
plan.workspaceAddr = workspace;
return AsdSip::ErrorType::ACL_SUCCESS;
}
AspbStatus asdFftSynchronize(asdFftHandle handle)
{
std::lock_guard<std::mutex> lock(fft_mtx);
if (!FFTPlanCache::doesPlanExist(handle)) {
ASDSIP_LOG(ERROR) << "Invalid handle.";
return AsdSip::ErrorType::ACL_ERROR_INTERNAL_ERROR;
}
FFTPlan &plan = FFTPlanCache::getPlan(handle);
MkiRtStreamSynchronize(plan.stream);
return AsdSip::ErrorType::ACL_SUCCESS;
}
AspbStatus asdFftGetType(asdFftHandle handle, asdFftType &fftType)
{
std::lock_guard<std::mutex> lock(fft_mtx);
if (!FFTPlanCache::doesPlanExist(handle)) {
ASDSIP_LOG(ERROR) << "Invalid handle.";
return AsdSip::ErrorType::ACL_ERROR_INTERNAL_ERROR;
}
FFTPlan &plan = FFTPlanCache::getPlan(handle);
fftType = plan.fftType;
return AsdSip::ErrorType::ACL_SUCCESS;
}
std::vector<void *> allocInterCachesV2(FFTPlan &plan, workspace::Workspace &wkspace)
{
std::vector<void *> cache;
int64_t num = plan.steps.size() <= K_FACTOR_2 ? 1 : K_FACTOR_2;
for (int64_t i = 0; i < num; i++) {
size_t dataSize = getAlignedSize(
plan.eleNum() * GetTensorElementSize(Mki::TensorDType::TENSOR_DTYPE_COMPLEX64) + K_BLOCK_SIZE);
cache.push_back(wkspace.allocate(dataSize));
}
return cache;
}
AspbStatus asdFftExecV2(FFTPlan &plan, const aclTensor *input, const aclTensor *output)
{
if (!plan.isInitialized()) {
ASDSIP_LOG(ERROR) << "plan is not initilized.";
return ErrorType::ACL_ERROR_INVALID_PARAM;
}
if (shouldAllocTempCaches(plan) && shouldAllocWorkspace(plan) && !plan.hasWorkspace()) {
ASDSIP_LOG(ERROR) << "workspace has not been allocated.";
return AsdSip::ErrorType::ACL_ERROR_INTERNAL_ERROR;
}
workspace::Workspace wkspace(plan.workspaceAddr);
wkspace.Reset();
void* inputData = Mki::GetStorageAddr(input);
if (inputData == nullptr) {
ASDSIP_LOG(ERROR) << "input aclTensor data is nullptr.";
return AsdSip::ErrorType::ACL_ERROR_INTERNAL_ERROR;
}
void* outputData = Mki::GetStorageAddr(output);
if (outputData == nullptr) {
ASDSIP_LOG(ERROR) << "output aclTensor data is nullptr.";
return AsdSip::ErrorType::ACL_ERROR_INTERNAL_ERROR;
}
if (plan.steps.size() == 1) {
plan.steps[0].operation->Run(inputData, outputData, plan.stream, wkspace);
return AsdSip::ErrorType::ACL_SUCCESS;
}
std::vector<void *> tmpCache = allocInterCachesV2(plan, wkspace);
int ping = 0;
for (int64_t i = 0; i < static_cast<int64_t>(plan.steps.size()); i++) {
void *tmpIn = i == 0 ? inputData : tmpCache[1 - ping];
void *tmpOut = i == static_cast<int64_t>(plan.steps.size()) - 1 ? outputData : tmpCache[ping];
plan.steps[i].operation->Run(tmpIn, tmpOut, plan.stream, wkspace);
ping = 1 - ping;
}
recycleInterCaches(plan, wkspace);
return AsdSip::ErrorType::ACL_SUCCESS;
}
bool matchC2C_(const FFTPlan &plan, const aclTensor *input)
{
if (input == nullptr) {
return false;
}
aclDataType dataType = aclDataType::ACL_DT_UNDEFINED;
auto ret = aclGetDataType(input, &dataType);
if (ret != 0 || dataType != aclDataType::ACL_COMPLEX64) {
return false;
}
int64_t *dims = nullptr;
uint64_t tensorDimSize = 0;
ret = aclGetViewShape(input, &dims, &tensorDimSize);
if (ret != 0) {
delete[] dims;
dims = nullptr;
return false;
}
if (plan.fftStrides[0] != 1) {
delete[] dims;
dims = nullptr;
return true;
}
int64_t lastDim = static_cast<int64_t>(tensorDimSize) - 1;
int64_t i = lastDim;
for (auto iterPtr = plan.fftSizes.rbegin(); iterPtr != plan.fftSizes.rend(); iterPtr++, i--) {
int64_t dim = dims[i];
if (dim != *iterPtr) {
delete[] dims;
dims = nullptr;
return false;
}
}
delete[] dims;
return true;
}
bool matchC2R_(const FFTPlan &plan, const aclTensor *input)
{
if (input == nullptr) {
return false;
}
aclDataType dataType = aclDataType::ACL_DT_UNDEFINED;
auto ret = aclGetDataType(input, &dataType);
if (ret != 0 || dataType != aclDataType::ACL_COMPLEX64) {
return false;
}
int64_t *numOfDims = nullptr;
uint64_t dimSize = 0;
ret = aclGetViewShape(input, &numOfDims, &dimSize);
if (ret != 0) {
delete[] numOfDims;
numOfDims = nullptr;
return false;
}
int64_t lastDim = static_cast<int64_t>(dimSize) - 1;
int64_t dimIterIdx = lastDim;
for (auto rit = plan.fftSizes.rbegin(); rit != plan.fftSizes.rend(); rit++, dimIterIdx--) {
int64_t dim = numOfDims[dimIterIdx];
if (dimIterIdx == lastDim) {
if (dim != (*rit / K_FACTOR_2 + 1)) {
delete[] numOfDims;
numOfDims = nullptr;
return false;
}
} else {
if (*rit != dim) {
delete[] numOfDims;
numOfDims = nullptr;
return false;
}
}
}
delete[] numOfDims;
return true;
}
bool matchR2C_(const FFTPlan &plan, const aclTensor *input)
{
if (input == nullptr) {
return false;
}
aclDataType dataType = aclDataType::ACL_DT_UNDEFINED;
auto ret = aclGetDataType(input, &dataType);
if (ret != 0 || dataType != aclDataType::ACL_FLOAT) {
return false;
}
int64_t *dims = nullptr;
uint64_t dimsNum = 0;
ret = aclGetViewShape(input, &dims, &dimsNum);
if (ret != 0) {
delete[] dims;
dims = nullptr;
return false;
}
int64_t last = static_cast<int64_t>(dimsNum) - 1;
int64_t i = last;
for (auto rit = plan.fftSizes.rbegin(); rit != plan.fftSizes.rend(); rit++, i--) {
int64_t dim = dims[i];
if (dim != *rit) {
delete[] dims;
dims = nullptr;
return false;
}
}
delete[] dims;
dims = nullptr;
return true;
}
AsdSip::AspbStatus asdFftExecC2C(asdFftHandle handle, const aclTensor *input, const aclTensor *output)
{
std::lock_guard<std::mutex> lock(fft_mtx);
if (!FFTPlanCache::doesPlanExist(handle)) {
ASDSIP_LOG(ERROR) << "Invalid handle.";
return ErrorType::ACL_ERROR_INVALID_PARAM;
}
FFTPlan &plan = FFTPlanCache::getPlan(handle);
if (!matchC2C_(plan, input)) {
ASDSIP_LOG(ERROR) << "input does not match plan.";
return ErrorType::ACL_ERROR_INVALID_PARAM;
}
if (!matchC2C_(plan, output)) {
ASDSIP_LOG(ERROR) << "output does not match plan.";
return ErrorType::ACL_ERROR_INVALID_PARAM;
}
return asdFftExecV2(plan, input, output);
}
AspbStatus asdFftExecC2R(asdFftHandle handle, const aclTensor *input, const aclTensor *output)
{
std::lock_guard<std::mutex> lock(fft_mtx);
if (!FFTPlanCache::doesPlanExist(handle)) {
ASDSIP_LOG(ERROR) << "Invalid handle.";
return ErrorType::ACL_ERROR_INVALID_PARAM;
}
FFTPlan &plan = FFTPlanCache::getPlan(handle);
if (!matchC2R_(plan, input)) {
ASDSIP_LOG(ERROR) << "input does not match plan.";
return ErrorType::ACL_ERROR_INVALID_PARAM;
}
if (!matchR2C_(plan, output)) {
ASDSIP_LOG(ERROR) << "output does not match plan.";
return ErrorType::ACL_ERROR_INVALID_PARAM;
}
return asdFftExecV2(plan, input, output);
}
AspbStatus asdFftExecR2C(asdFftHandle handle, const aclTensor *input, const aclTensor *output)
{
std::lock_guard<std::mutex> lock(fft_mtx);
if (!FFTPlanCache::doesPlanExist(handle)) {
ASDSIP_LOG(ERROR) << "Invalid handle.";
return ErrorType::ACL_ERROR_INVALID_PARAM;
}
FFTPlan &plan = FFTPlanCache::getPlan(handle);
if (!matchR2C_(plan, input)) {
ASDSIP_LOG(ERROR) << "input does not match plan.";
return ErrorType::ACL_ERROR_INVALID_PARAM;
}
if (!matchC2R_(plan, output)) {
ASDSIP_LOG(ERROR) << "output does not match plan.";
return ErrorType::ACL_ERROR_INVALID_PARAM;
}
return asdFftExecV2(plan, input, output);
}
AspbStatus asdFftExecV2Separated(FFTPlan &plan, const aclTensor *inputReal, const aclTensor *inputImag,
const aclTensor *outputReal, const aclTensor *outputImag)
{
if (!plan.isInitialized()) {
ASDSIP_LOG(ERROR) << "plan is not initilized.";
return ErrorType::ACL_ERROR_INVALID_PARAM;
}
if (shouldAllocTempCaches(plan) && shouldAllocWorkspace(plan) && !plan.hasWorkspace()) {
ASDSIP_LOG(ERROR) << "workspace has not been allocated.";
return AsdSip::ErrorType::ACL_ERROR_INTERNAL_ERROR;
}
workspace::Workspace wkspace(plan.workspaceAddr);
wkspace.Reset();
void* inputRealData = Mki::GetStorageAddr(inputReal);
if (inputRealData == nullptr) {
ASDSIP_LOG(ERROR) << "input aclTensor data is nullptr.";
return AsdSip::ErrorType::ACL_ERROR_INTERNAL_ERROR;
}
void* inputImagData = Mki::GetStorageAddr(inputImag);
if (inputImagData == nullptr) {
ASDSIP_LOG(ERROR) << "input aclTensor data is nullptr.";
return AsdSip::ErrorType::ACL_ERROR_INTERNAL_ERROR;
}
void* outputRealData = Mki::GetStorageAddr(outputReal);
if (outputRealData == nullptr) {
ASDSIP_LOG(ERROR) << "output aclTensor data is nullptr.";
return AsdSip::ErrorType::ACL_ERROR_INTERNAL_ERROR;
}
void* outputImagData = Mki::GetStorageAddr(outputImag);
if (outputImagData == nullptr) {
ASDSIP_LOG(ERROR) << "output aclTensor data is nullptr.";
return AsdSip::ErrorType::ACL_ERROR_INTERNAL_ERROR;
}
if (plan.steps.size() == 1) {
plan.steps[0].operation->Run(inputRealData, inputImagData, outputRealData, outputImagData, plan.stream, wkspace);
return AsdSip::ErrorType::ACL_SUCCESS;
}
return AsdSip::ErrorType::ACL_SUCCESS;
}
AspbStatus asdFftExecC2CSeparated(asdFftHandle handle, const aclTensor *inputReal, const aclTensor *inputImag,
const aclTensor *outputReal, const aclTensor *outputImag)
{
std::lock_guard<std::mutex> lock(fft_mtx);
if (!FFTPlanCache::doesPlanExist(handle)) {
ASDSIP_LOG(ERROR) << "Invalid handle.";
return ErrorType::ACL_ERROR_INVALID_PARAM;
}
int64_t *viewDimsInReal = nullptr;
uint64_t viewDimsNumInReal = 0;
CHECK_STATUS_WITH_ACL_RETURN(aclGetViewShape(inputReal, &viewDimsInReal, &viewDimsNumInReal), "aclGetViewShape");
int64_t *viewDimsInImag = nullptr;
uint64_t viewDimsNumInImag = 0;
CHECK_STATUS_WITH_ACL_RETURN(aclGetViewShape(inputImag, &viewDimsInImag, &viewDimsNumInImag), "aclGetViewShape");
int64_t *viewDimsOutReal = nullptr;
uint64_t viewDimsNumOutReal = 0;
CHECK_STATUS_WITH_ACL_RETURN(aclGetViewShape(outputReal, &viewDimsOutReal, &viewDimsNumOutReal), "aclGetViewShape");
int64_t *viewDimsOutImag = nullptr;
uint64_t viewDimsNumOutImag = 0;
CHECK_STATUS_WITH_ACL_RETURN(aclGetViewShape(outputImag, &viewDimsOutImag, &viewDimsNumOutImag), "aclGetViewShape");
if ((viewDimsNumInReal != viewDimsNumInImag) || (viewDimsNumOutReal != viewDimsNumOutImag) ||
(viewDimsNumInReal != viewDimsNumOutReal)) {
ASDSIP_ELOG(ErrorType::ACL_ERROR_OP_INPUT_NOT_MATCH) << "invalid input/output format.";
delete[] viewDimsInReal;
delete[] viewDimsInImag;
delete[] viewDimsOutReal;
delete[] viewDimsOutImag;
return AsdSip::ErrorType::ACL_ERROR_INTERNAL_ERROR;
}
bool validFlag = true;
for (uint64_t dim = 0; dim < viewDimsNumInReal; dim++) {
if ((viewDimsInReal[dim] != viewDimsInImag[dim]) || (viewDimsInReal[dim] != viewDimsOutReal[dim])
|| (viewDimsOutReal[dim] != viewDimsOutImag[dim])) {
validFlag = false;
break;
}
}
if (!validFlag) {
ASDSIP_ELOG(ErrorType::ACL_ERROR_OP_INPUT_NOT_MATCH) << "invalid input/output shape.";
delete[] viewDimsInReal;
delete[] viewDimsInImag;
delete[] viewDimsOutReal;
delete[] viewDimsOutImag;
return AsdSip::ErrorType::ACL_ERROR_INTERNAL_ERROR;
}
delete[] viewDimsInReal;
delete[] viewDimsInImag;
delete[] viewDimsOutReal;
delete[] viewDimsOutImag;
FFTPlan &plan = FFTPlanCache::getPlan(handle);
return asdFftExecV2Separated(plan, inputReal, inputImag, outputReal, outputImag);
}
}