* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
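/*!
* \file stft_tiling_base.cpp
* \brief Base tiling for the STFT operator: reads hardware limits (core counts, workspace, on-chip buffer sizes) and parses the input shape and operator attributes.
*/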
#include "exe_graph/runtime/shape.h"
#include "op_host/math_tiling_templates_registry.h"
#include "log/log.h"
#include "platform/platform_info.h"
#include "stft_tiling_base.h"
using namespace AscendC;
using namespace ge;
namespace optiling {
// Largest input rank accepted by STFT tiling: rank 2 is read as (batch, signal), lower ranks as a single signal.
constexpr size_t INPUT_MAX_DIM_NUM{2};
/**
 * @brief Populate the hardware limits used by STFT tiling.
 *
 * If the tiling context provides platform info, the values are queried through
 * platform_ascendc::PlatformAscendC:
 * - core counts (total / AIC / AIV);
 * - the library-API workspace size;
 * - the UB, L1, L0A, L0B and L0C buffer sizes.
 *
 * Otherwise the same fields are copied from the STFTCompileInfo attached to the
 * context.
 *
 * @return ge::GRAPH_SUCCESS on success; ge::GRAPH_FAILED when platform info is
 *         absent and no compile info is available either.
 */
ge::graphStatus STFTBaseTiling::GetPlatformInfo()
{
    auto platformPtr = context_->GetPlatformInfo();
    if (platformPtr == nullptr) {
        // No platform info on the context: fall back to the sizes stored in STFTCompileInfo.
        auto compileInfoPtr = reinterpret_cast<const STFTCompileInfo*>(context_->GetCompileInfo());
        OP_CHECK_IF(
            compileInfoPtr == nullptr, OP_LOGE(context_->GetNodeName(), "compile info is null"),
            return ge::GRAPH_FAILED);
        coreNum = compileInfoPtr->coreNum;
        aivCoreNum = compileInfoPtr->aivCoreNum;
        aicCoreNum = compileInfoPtr->aicCoreNum;
        sysWorkspaceSize = compileInfoPtr->sysWorkspaceSize;
        ubSize = compileInfoPtr->ubSize;
        l1Size = compileInfoPtr->l1Size;
        l0ASize = compileInfoPtr->l0ASize;
        l0BSize = compileInfoPtr->l0BSize;
        l0CSize = compileInfoPtr->l0CSize;
        return ge::GRAPH_SUCCESS;
    }

    auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformPtr);
    coreNum = ascendcPlatform.GetCoreNum();
    aicCoreNum = ascendcPlatform.GetCoreNumAic();
    aivCoreNum = ascendcPlatform.GetCoreNumAiv();
    sysWorkspaceSize = ascendcPlatform.GetLibApiWorkSpaceSize();

    // Each on-chip buffer must be queried with its own CoreMemType; querying UB
    // for the L0 buffers would report the UB capacity for L0A/L0B/L0C.
    // Initialised to 0 so a failed query never leaves indeterminate values.
    uint64_t ubSizePlatform = 0;
    uint64_t l1SizePlatform = 0;
    uint64_t l0ASizePlatform = 0;
    uint64_t l0BSizePlatform = 0;
    uint64_t l0CSizePlatform = 0;
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSizePlatform);
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::L1, l1SizePlatform);
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::L0_A, l0ASizePlatform);
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::L0_B, l0BSizePlatform);
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::L0_C, l0CSizePlatform);

    ubSize = static_cast<int64_t>(ubSizePlatform);
    l1Size = static_cast<int64_t>(l1SizePlatform);
    l0ASize = static_cast<int64_t>(l0ASizePlatform);
    l0BSize = static_cast<int64_t>(l0BSizePlatform);
    l0CSize = static_cast<int64_t>(l0CSizePlatform);
    return ge::GRAPH_SUCCESS;
}
/**
 * @brief Read the STFT input descriptor, input shape and operator attributes.
 *
 * Input requirements:
 * - dtype is FLOAT16, FLOAT or COMPLEX64;
 * - rank is 1 or 2. For rank 2, dim 0 is the batch and dim 1 the signal
 *   length. A rank-1 input is treated as a single batch.
 *
 * Attributes are read by index:
 * 0 = hop, 1 = winLength, 2 = normalized, 3 = onesided, 4 = returnComplex,
 * 5 = nfft. hop, winLength and nfft must be positive, and nfft must not
 * exceed the signal length.
 *
 * On success sets dtype, hop, winLength, normalized, onesided, returnComplex,
 * nfft, batch and inputSize (= signal length - nfft).
 *
 * @return ge::GRAPH_SUCCESS on success, ge::GRAPH_FAILED on any missing or
 *         invalid input/attribute.
 */
ge::graphStatus STFTBaseTiling::GetShapeAttrsInfo()
{
    auto input = context_->GetInputDesc(0);
    OP_CHECK_NULL_WITH_CONTEXT(context_, input);
    dtype = input->GetDataType();
    if (dtype != ge::DataType::DT_FLOAT16 && dtype != ge::DataType::DT_FLOAT && dtype != ge::DataType::DT_COMPLEX64) {
        OP_LOGE(context_->GetNodeName(), "STFT: invalid dtype.");
        return ge::GRAPH_FAILED;
    }

    // Check the shape pointer like every other context pointer before dereferencing it.
    auto inputShapePtr = context_->GetInputShape(0);
    OP_CHECK_NULL_WITH_CONTEXT(context_, inputShapePtr);
    const auto& inputShape = inputShapePtr->GetStorageShape();
    const size_t dimNum = inputShape.GetDimNum();
    if (dimNum > INPUT_MAX_DIM_NUM) {
        OP_LOGE(context_->GetNodeName(), "STFT: input shape dim(=%zu) > 2, please check.", dimNum);
        return ge::GRAPH_FAILED;
    }
    if (dimNum == 0) {
        // A scalar has no signal axis; reading GetDim(0) below would be out of range.
        OP_LOGE(context_->GetNodeName(), "STFT: input shape dim(=0) is invalid, expect 1 or 2.");
        return ge::GRAPH_FAILED;
    }

    auto attrs = context_->GetAttrs();
    OP_CHECK_NULL_WITH_CONTEXT(context_, attrs);

    const int64_t* hopPtr = attrs->GetAttrPointer<int64_t>(0);
    OP_CHECK_NULL_WITH_CONTEXT(context_, hopPtr);
    hop = *hopPtr;
    if (hop <= 0) {
        OP_LOGE(context_->GetNodeName(), "STFT: invalid hop attr.");
        return ge::GRAPH_FAILED;
    }

    const int64_t* winLengthPtr = attrs->GetAttrPointer<int64_t>(1);
    OP_CHECK_NULL_WITH_CONTEXT(context_, winLengthPtr);
    winLength = *winLengthPtr;
    if (winLength <= 0) {
        OP_LOGE(context_->GetNodeName(), "STFT: invalid winLength attr.");
        return ge::GRAPH_FAILED;
    }

    const bool* normalizedPtr = attrs->GetAttrPointer<bool>(2);
    OP_CHECK_NULL_WITH_CONTEXT(context_, normalizedPtr);
    normalized = *normalizedPtr;

    const bool* onesidedPtr = attrs->GetAttrPointer<bool>(3);
    OP_CHECK_NULL_WITH_CONTEXT(context_, onesidedPtr);
    onesided = *onesidedPtr;

    const bool* returnComplexPtr = attrs->GetAttrPointer<bool>(4);
    OP_CHECK_NULL_WITH_CONTEXT(context_, returnComplexPtr);
    returnComplex = *returnComplexPtr;

    const int64_t* nfftPtr = attrs->GetAttrPointer<int64_t>(5);
    OP_CHECK_NULL_WITH_CONTEXT(context_, nfftPtr);
    nfft = *nfftPtr;
    if (nfft <= 0) {
        OP_LOGE(context_->GetNodeName(), "STFT: invalid nfft attr.");
        return ge::GRAPH_FAILED;
    }

    const bool hasBatchDim = (dimNum == INPUT_MAX_DIM_NUM);
    const int64_t signalLength = hasBatchDim ? inputShape.GetDim(1) : inputShape.GetDim(0);
    if (signalLength < nfft) {
        // nfft must fit inside the signal; otherwise inputSize would be negative.
        OP_LOGE(context_->GetNodeName(), "STFT: nfft(=%lld) > input length(=%lld), please check.",
            static_cast<long long>(nfft), static_cast<long long>(signalLength));
        return ge::GRAPH_FAILED;
    }
    batch = hasBatchDim ? inputShape.GetDim(0) : 1;
    inputSize = signalLength - nfft;
    return ge::GRAPH_SUCCESS;
}
}