* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file upsample_nearest_tiling.cpp
* \brief Host-side tiling implementation for the UpsampleNearest operator.
*/
#include "register/op_impl_registry.h"
#include "register/tilingdata_base.h"
#include "tiling/tiling_api.h"
#include "tiling/platform/platform_ascendc.h"
#include "upsample_nearest_tiling.h"
namespace optiling {
// ---- Dimension indices into the 4-entry inputShapes/outputShapes arrays ----
// NHWC layout.
constexpr int8_t NHWC_N_INDEX{0};
constexpr int8_t NHWC_H_INDEX{1};
constexpr int8_t NHWC_W_INDEX{2};
constexpr int8_t NHWC_C_INDEX{3};
// NCHW layout.
constexpr int8_t NCHW_N_INDEX{0};
constexpr int8_t NCHW_C_INDEX{1};
constexpr int8_t NCHW_H_INDEX{2};
constexpr int8_t NCHW_W_INDEX{3};

// ---- Indices into the output-size attribute array ----
constexpr int8_t OUT_H_INDEX{0};  // 4-D case: output height
constexpr int8_t OUT_W_INDEX{1};  // 4-D case: output width
constexpr int8_t OUT_L_INDEX{0};  // 3-D case: output length

// ---- Dimension indices of a 3-D NLC origin shape ----
constexpr int8_t NLC_N_INDEX{0};
constexpr int8_t NLC_L_INDEX{1};
constexpr int8_t NLC_C_INDEX{2};

// ---- Attribute positions passed to RuntimeAttrs::GetAttrPointer ----
constexpr uint32_t OUTPUT_SIZE_ATTR{0};
constexpr uint32_t SCALE_H_ATTR{1};
constexpr uint32_t SCALE_W_ATTR{2};
constexpr uint32_t EXACT_ATTR{3};
constexpr uint32_t SCALE_L_ATTR{1};  // same slot as SCALE_H_ATTR

// ---- Data-type codes (not referenced in the visible part of this file) ----
constexpr uint32_t DATE_TYPE_FLOAT16{1};
constexpr uint32_t DATE_TYPE_FLOAT{2};
constexpr uint32_t DATE_TYPE_HALF{3};

// Workspace size (32 MiB) written to workspace slot 0 by GetWorkSpace().
constexpr uint64_t WORK_SPACE_SIZE{32ULL * 1024ULL * 1024ULL};

// Element sizes in bytes returned by GetDataTypeVal().
constexpr uint32_t BYTE_LEN_4{4};
constexpr uint32_t BYTE_LEN_2{2};

// Supported input ranks.
constexpr uint32_t NHWC_DIM_SIZE{4};
constexpr uint32_t NLC_DIM_SIZE{3};

// Not referenced in the visible part of this file.
constexpr uint32_t ADDR_ALIGN_SIZE{512};

// ---- Tiling keys selected by GetTilingKey() ----
constexpr uint32_t COMMON_TILING_KEY{1000};
constexpr uint32_t SMALL_CW_TILING_KEY{1001};
constexpr uint32_t SMALL_C_TILING_KEY{1002};
constexpr uint32_t SMALL_NCH_TILING_KEY{1003};

// Thresholds used by GetTilingKey() when choosing among the keys above.
constexpr uint32_t MAX_SMALL_SHPAE{8192};
constexpr uint32_t MAX_SMALL_SACALE{2};

// Value passed to TilingContext::SetScheduleMode on ASCEND310P.
constexpr uint8_t SCHEDULE_MODE{1};
class UpsampleNearestTiling {
public:
explicit UpsampleNearestTiling(gert::TilingContext* context) : tilingContext(context) {};
ge::graphStatus RunBigKernelTiling();
private:
ge::graphStatus ParseInputAttrs();
inline float ComputeScaleValue(int64_t inputSize, int64_t outputSize, const float scale) const;
void GetWorkSpace() const;
uint32_t GetDataTypeVal() const;
int64_t GetBestAvergingCols(uint32_t coreNumPlatform);
uint32_t GetNeedCoreNum(uint32_t coreNumPlatform);
void FillTilingData();
void GetTilingKey();
template <typename T1, typename T2>
inline T1 CeilA2B(T1 a, T2 b) const;
template <typename T1>
inline int64_t Ceil(T1 x);
private:
UpsampleNearestTilingData tilingData;
gert::TilingContext* tilingContext = nullptr;
ge::DataType dataType = ge::DT_UNDEFINED;
ge::Format inputFormat = ge::Format::FORMAT_NHWC;
uint8_t dim = 0;
float realScaleH = 0.0f;
float realScaleW = 0.0f;
int64_t tailColStartList[MAX_CORE_CONT] = {0};
int64_t tailColEndList[MAX_CORE_CONT] = {0};
int64_t tailRowStartList[MAX_CORE_CONT] = {0};
int64_t tailRowEndList[MAX_CORE_CONT] = {0};
int64_t outputShapes[4] = {0};
int64_t inputShapes[4] = {0};
bool exactMode = true;
uint32_t tilingKey = 1000;
uint32_t needCoreNum = 1;
};
/**
 * Entry point of the tiling flow.
 *
 * Parses the attributes, selects the tiling key, computes the core partition from the
 * core count stored in the compile info, sets the workspace size and serializes the
 * tiling data.
 *
 * @return ge::GRAPH_SUCCESS on success; ge::GRAPH_FAILED when the context or compile
 *         info is missing or attribute parsing fails.
 */
ge::graphStatus UpsampleNearestTiling::RunBigKernelTiling()
{
    if (tilingContext == nullptr) {
        return ge::GRAPH_FAILED;
    }
    if (ParseInputAttrs() == ge::GRAPH_FAILED) {
        return ge::GRAPH_FAILED;
    }
    GetTilingKey();

    auto compileInfo = reinterpret_cast<const UpsampleNearestCompileInfo*>(tilingContext->GetCompileInfo());
    if (compileInfo == nullptr) {
        // Without compile info there is no core count to partition against.
        return ge::GRAPH_FAILED;
    }
    const uint32_t totalCoreNum = compileInfo->totalCoreNum;
    needCoreNum = GetNeedCoreNum(totalCoreNum);

    GetWorkSpace();
    FillTilingData();
    return ge::GRAPH_SUCCESS;
}
/**
 * Reads the input shape, format, data type and the operator attributes from the
 * tiling context and fills dim, inputShapes, outputShapes, realScaleH/W, exactMode,
 * inputFormat and dataType.
 *
 * 4-D inputs keep their layout (NCHW or NHWC indices); 3-D NLC inputs are stored in
 * the NHWC slots as N,1,L,C.
 *
 * @return ge::GRAPH_SUCCESS on success; ge::GRAPH_FAILED when a required shape,
 *         descriptor or attribute is missing, the rank is neither 3 nor 4, the
 *         output-size attribute is too short, or the data type is inconsistent.
 */
ge::graphStatus UpsampleNearestTiling::ParseInputAttrs()
{
    const gert::RuntimeAttrs* attrs = tilingContext->GetAttrs();
    if (attrs == nullptr) {
        return ge::GRAPH_FAILED;
    }
    auto srcShape = tilingContext->GetInputShape(0);
    auto inputDesc = tilingContext->GetInputDesc(0);
    if (srcShape == nullptr || inputDesc == nullptr) {
        return ge::GRAPH_FAILED;
    }

    // Validate the rank before touching the fixed 4-entry shape arrays.
    const size_t dimNum = srcShape->GetStorageShape().GetDimNum();
    if (dimNum != NHWC_DIM_SIZE && dimNum != NLC_DIM_SIZE) {
        return ge::GRAPH_FAILED;
    }
    dim = static_cast<uint8_t>(dimNum);

    const gert::ContinuousVector* outputSizeAttr = attrs->GetAttrPointer<gert::ContinuousVector>(OUTPUT_SIZE_ATTR);
    const bool* exactAttr = attrs->GetAttrPointer<bool>(EXACT_ATTR);
    if (outputSizeAttr == nullptr || exactAttr == nullptr) {
        return ge::GRAPH_FAILED;
    }
    // The 4-D path reads output sizes [H, W]; the 3-D path reads [L].
    const size_t neededOutputSizes =
        (dim == NHWC_DIM_SIZE) ? static_cast<size_t>(OUT_W_INDEX) + 1 : static_cast<size_t>(OUT_L_INDEX) + 1;
    const int64_t* outputSizeArray = reinterpret_cast<const int64_t*>(outputSizeAttr->GetData());
    if (outputSizeArray == nullptr || outputSizeAttr->GetSize() < neededOutputSizes) {
        return ge::GRAPH_FAILED;
    }
    exactMode = *exactAttr;

    // Start with output == input; the resized axes are overwritten below.
    const auto& inputShape = srcShape->GetOriginShape();
    for (uint8_t i = 0; i < dim; i++) {
        inputShapes[i] = inputShape.GetDim(i);
        outputShapes[i] = inputShape.GetDim(i);
    }
    inputFormat = static_cast<ge::Format>(GetPrimaryFormat(inputDesc->GetStorageFormat()));

    if (dim == NHWC_DIM_SIZE) {
        const float* scaleHAttr = attrs->GetAttrPointer<float>(SCALE_H_ATTR);
        const float* scaleWAttr = attrs->GetAttrPointer<float>(SCALE_W_ATTR);
        if (scaleHAttr == nullptr || scaleWAttr == nullptr) {
            return ge::GRAPH_FAILED;
        }
        const bool isNchw = (inputFormat == ge::Format::FORMAT_NCHW);
        const int8_t hIndex = isNchw ? NCHW_H_INDEX : NHWC_H_INDEX;
        const int8_t wIndex = isNchw ? NCHW_W_INDEX : NHWC_W_INDEX;
        outputShapes[hIndex] = outputSizeArray[OUT_H_INDEX];
        outputShapes[wIndex] = outputSizeArray[OUT_W_INDEX];
        realScaleH = ComputeScaleValue(inputShapes[hIndex], outputShapes[hIndex], *scaleHAttr);
        realScaleW = ComputeScaleValue(inputShapes[wIndex], outputShapes[wIndex], *scaleWAttr);
    } else {
        // 3-D NLC input: re-map to N,1,L,C so the rest of the tiling can use NHWC indices.
        const float* scaleLAttr = attrs->GetAttrPointer<float>(SCALE_L_ATTR);
        if (scaleLAttr == nullptr) {
            return ge::GRAPH_FAILED;
        }
        // Read L and C from the origin shape first: inputShapes[2] is about to be
        // overwritten with L, so it must not be used as the channel count afterwards.
        const int64_t length = inputShape.GetDim(NLC_L_INDEX);
        const int64_t channels = inputShape.GetDim(NLC_C_INDEX);
        inputShapes[NHWC_H_INDEX] = 1;
        inputShapes[NHWC_W_INDEX] = length;
        inputShapes[NHWC_C_INDEX] = channels;
        outputShapes[NHWC_H_INDEX] = 1;
        outputShapes[NHWC_W_INDEX] = outputSizeArray[OUT_L_INDEX];
        outputShapes[NHWC_C_INDEX] = channels;
        realScaleH = 1.0f;
        realScaleW = ComputeScaleValue(inputShapes[NHWC_W_INDEX], outputShapes[NHWC_W_INDEX], *scaleLAttr);
    }

    const auto srcDtype = inputDesc->GetDataType();
    if (dataType == ge::DT_UNDEFINED) {
        dataType = srcDtype;
    } else if (srcDtype != dataType) {
        return ge::GRAPH_FAILED;
    }
    return ge::GRAPH_SUCCESS;
}
/**
 * Requests one workspace slot from the tiling context and sets its size to
 * WORK_SPACE_SIZE.
 */
void UpsampleNearestTiling::GetWorkSpace() const
{
    size_t* const workspaceSizes = tilingContext->GetWorkspaceSizes(1);
    workspaceSizes[0] = WORK_SPACE_SIZE;
}
void UpsampleNearestTiling::GetTilingKey()
{
if (inputFormat == ge::Format::FORMAT_NCHW) {
tilingKey = SMALL_NCH_TILING_KEY;
} else {
int64_t inputC = inputShapes[NHWC_C_INDEX];
int64_t outputH = inputShapes[NHWC_H_INDEX];
int64_t outputW = inputShapes[NHWC_W_INDEX];
int64_t inputW = outputShapes[NHWC_W_INDEX];
if (inputC * inputW < MAX_SMALL_SHPAE && inputC * outputW < MAX_SMALL_SHPAE && outputH > 1) {
tilingKey = SMALL_CW_TILING_KEY;
} else if (inputC < MAX_SMALL_SHPAE && outputW < MAX_SMALL_SHPAE && realScaleW < MAX_SMALL_SACALE) {
tilingKey = SMALL_C_TILING_KEY;
} else {
tilingKey = COMMON_TILING_KEY;
}
auto ascendc_platform = platform_ascendc::PlatformAscendC(tilingContext->GetPlatformInfo());
platform_ascendc::SocVersion nearestSocVersion = ascendc_platform.GetSocVersion();
if (nearestSocVersion == platform_ascendc::SocVersion::ASCEND310P) {
tilingContext->SetScheduleMode(SCHEDULE_MODE);
}
}
}
/**
 * Determines the scale factor for one axis.
 *
 * @param inputSize  extent of the axis in the input
 * @param outputSize extent of the axis in the output
 * @param scale      scale taken from the operator attribute
 * @return 1 when the input is 4-D and the extent is unchanged; otherwise `scale` when
 *         it is positive, else inputSize / outputSize.
 */
inline float UpsampleNearestTiling::ComputeScaleValue(int64_t inputSize, int64_t outputSize, const float scale) const
{
    const bool sameExtentIn4d = (dim == NHWC_DIM_SIZE) && (inputSize == outputSize);
    if (sameExtentIn4d) {
        return 1.0f;
    }
    if (scale > 0) {
        return scale;
    }
    // No usable explicit scale: fall back to the size ratio.
    const float numerator = static_cast<float>(inputSize);
    const float denominator = static_cast<float>(outputSize);
    return numerator / denominator;
}
/**
 * Computes (a + b - 1) / b, converted to T1.
 *
 * @return the quotient rounded up for the positive operands used in this file,
 *         or `a` unchanged when `b` is zero.
 */
template <typename T1, typename T2>
inline auto UpsampleNearestTiling::CeilA2B(T1 a, T2 b) const -> T1
{
    if (b == 0) {
        // Avoid dividing by zero; the dividend is handed back as-is.
        return a;
    }
    return (a + b - 1) / b;
}
/**
 * Maps dataType to its element size in bytes.
 *
 * @return BYTE_LEN_4 for DT_FLOAT, BYTE_LEN_2 for DT_FLOAT16 and DT_BF16,
 *         0 for any other type.
 */
uint32_t UpsampleNearestTiling::GetDataTypeVal() const
{
    uint32_t byteLen = 0;
    switch (dataType) {
        case ge::DT_FLOAT:
            byteLen = BYTE_LEN_4;
            break;
        case ge::DT_FLOAT16:
        case ge::DT_BF16:
            byteLen = BYTE_LEN_2;
            break;
        default:
            break;
    }
    return byteLen;
}
/**
 * Computes the minimum column-slice width that GetNeedCoreNum() uses when splitting
 * the output width.
 *
 * @param coreNumPlatform number of cores available for the split
 * @return the output width when tilingKey is SMALL_CW_TILING_KEY; otherwise the value
 *         found by the divisor search below, defaulting to 32 / element-size
 *         (or the output width when the element size is 0).
 */
int64_t UpsampleNearestTiling::GetBestAvergingCols(uint32_t coreNumPlatform)
{
    const bool isNchw = (inputFormat == ge::Format::FORMAT_NCHW);
    const int64_t outH = isNchw ? outputShapes[NCHW_H_INDEX] : outputShapes[NHWC_H_INDEX];
    const int64_t outW = isNchw ? outputShapes[NCHW_W_INDEX] : outputShapes[NHWC_W_INDEX];

    // This key always uses whole output rows, so the search result would be discarded.
    if (tilingKey == SMALL_CW_TILING_KEY) {
        return outW;
    }

    const uint32_t elemBytes = GetDataTypeVal();
    int64_t bestCols = outW;
    if (elemBytes > 0) {
        bestCols = 32 / elemBytes;
    }

    // Look for the smallest column split that divides both the core count and the
    // output width, and whose complementary row split divides the output height.
    // NOTE(review): the value stored on a match is the row-split factor
    // (coreNumPlatform / colSplit), not a column count such as outW / colSplit -
    // confirm this is intended.
    const int64_t coreCount = coreNumPlatform;
    for (int64_t colSplit = 1; colSplit <= coreCount; colSplit++) {
        if ((coreCount % colSplit != 0) || (outW % colSplit != 0)) {
            continue;
        }
        const int64_t rowSplit = coreCount / colSplit;
        if (outH % rowSplit == 0) {
            bestCols = rowSplit;
            break;
        }
    }
    return bestCols;
}
uint32_t UpsampleNearestTiling::GetNeedCoreNum(uint32_t coreNumPlatform)
{
int64_t outputH = outputShapes[NHWC_H_INDEX];
int64_t outputW = outputShapes[NHWC_W_INDEX];
if (inputFormat == ge::Format::FORMAT_NCHW) {
outputH = outputShapes[NCHW_H_INDEX];
outputW = outputShapes[NCHW_W_INDEX];
}
int64_t realCoreNum = 0;
int64_t slideSizeW = 0;
int64_t slideSizeH = 0;
int64_t colGroupNum = 0;
int64_t groupRowCoreNum = 0;
int64_t groupColCoreNum = 0;
int64_t minAvergingCols = GetBestAvergingCols(coreNumPlatform);
if (outputH < coreNumPlatform) {
int64_t tailAvergingCols = std::max(CeilA2B(outputW, coreNumPlatform), minAvergingCols);
colGroupNum = std::min(static_cast<int64_t>(coreNumPlatform), CeilA2B(outputW, tailAvergingCols));
slideSizeW = tailAvergingCols;
} else {
colGroupNum = 1;
slideSizeW = outputW;
}
groupColCoreNum = colGroupNum > 0 ? coreNumPlatform / colGroupNum : 0;
int64_t row = groupColCoreNum > 0 ? outputH / groupColCoreNum : 0;
int64_t tailAvergingRows = std::max(row, static_cast<int64_t>(1));
groupRowCoreNum = std::min(groupColCoreNum, CeilA2B(outputH, tailAvergingRows));
realCoreNum = colGroupNum * groupRowCoreNum;
int64_t tailRowRemainder = outputH - groupRowCoreNum * tailAvergingRows;
slideSizeH = tailAvergingRows;
int64_t realNeedCoreNum = 0;
int64_t tailRowOffset = 0;
int64_t tempTailRowRemainder = tailRowRemainder;
for (int64_t coreIndex = 0; coreIndex < realCoreNum; coreIndex++) {
int64_t groupColIndex = groupRowCoreNum > 0 ? coreIndex / groupRowCoreNum : 0;
tailColStartList[coreIndex] = groupColIndex * slideSizeW;
tailColEndList[coreIndex] = std::min((groupColIndex + 1) * slideSizeW, outputW);
int64_t groupRowIndex = groupRowCoreNum > 0 ? coreIndex % groupRowCoreNum : 0;
if (groupRowIndex == 0) {
tempTailRowRemainder = tailRowRemainder;
tailRowOffset = 0;
}
tailRowStartList[coreIndex] = groupRowIndex * slideSizeH + tailRowOffset;
if (tempTailRowRemainder > 0) {
tempTailRowRemainder -= 1;
tailRowOffset += 1;
}
tailRowEndList[coreIndex] = std::min((groupRowIndex + 1) * slideSizeH + tailRowOffset, outputH);
realNeedCoreNum++;
}
realNeedCoreNum = realNeedCoreNum < 1 ? 1 : realNeedCoreNum;
return realNeedCoreNum;
}
/**
 * Copies every computed tiling value into tilingData, publishes the block dimension
 * and tiling key on the context, and serializes tilingData into the context's raw
 * tiling buffer.
 */
void UpsampleNearestTiling::FillTilingData()
{
    // Scalar parameters.
    tilingData.set_dataType(GetDataTypeVal());
    tilingData.set_scaleH(realScaleH);
    tilingData.set_scaleW(realScaleW);
    tilingData.set_exactMode(exactMode);
    tilingData.set_needCoreNum(needCoreNum);

    // Shapes and per-core ranges.
    tilingData.set_inputShapes(inputShapes);
    tilingData.set_outputShapes(outputShapes);
    tilingData.set_tailColStartList(tailColStartList);
    tilingData.set_tailColEndList(tailColEndList);
    tilingData.set_tailRowStartList(tailRowStartList);
    tilingData.set_tailRowEndList(tailRowEndList);

    // Launch configuration.
    tilingContext->SetBlockDim(needCoreNum);
    tilingContext->SetTilingKey(tilingKey);

    // Serialize into the buffer owned by the context and record the used size.
    auto* rawTilingData = tilingContext->GetRawTilingData();
    tilingData.SaveToBuffer(rawTilingData->GetData(), rawTilingData->GetCapacity());
    rawTilingData->SetDataSize(tilingData.GetDataSize());
}
/**
 * Tiling callback registered for UpsampleNearest: builds an UpsampleNearestTiling
 * around the context and returns the status of its RunBigKernelTiling().
 */
static ge::graphStatus tiling4UpsampleNearestTiling(gert::TilingContext* context)
{
    UpsampleNearestTiling tiler(context);
    const ge::graphStatus status = tiler.RunBigKernelTiling();
    return status;
}
/**
 * Tiling-parse callback: stores the platform's AIV core count in the operator's
 * compile info so the tiling callback can read it later.
 *
 * @return ge::GRAPH_SUCCESS on success; ge::GRAPH_FAILED when the context or its
 *         compile-info storage is missing.
 */
static ge::graphStatus tilingPrepareTiling(gert::TilingParseContext* context)
{
    if (context == nullptr) {
        return ge::GRAPH_FAILED;
    }
    auto compileInfo = context->GetCompiledInfo<UpsampleNearestCompileInfo>();
    if (compileInfo == nullptr) {
        return ge::GRAPH_FAILED;
    }
    auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
    compileInfo->totalCoreNum = ascendcPlatform.GetCoreNumAiv();
    return ge::GRAPH_SUCCESS;
}
// Registers the tiling and tiling-parse callbacks for the UpsampleNearest operator;
// UpsampleNearestCompileInfo is the compile-info type filled by tilingPrepareTiling.
IMPL_OP_OPTILING(UpsampleNearest)
.Tiling(tiling4UpsampleNearestTiling)
.TilingParse<UpsampleNearestCompileInfo>(tilingPrepareTiling);
}