/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
#include "log/log.h"
#include "util/math_util.h"
#include "tiling/platform/platform_ascendc.h"
#include "register/op_impl_registry.h"
#include <graph/utils/type_utils.h>
#include <cmath>
#include "../op_kernel/round_tiling_data.h"
#include "../op_kernel/round_tiling_key.h"
#include"util/platform_util.h"
namespace optiling {
// Per-dtype divisors used by TilingFunc when sizing one tile: the number of UB blocks
// (ubSize / blockSize) is divided by one of these values. The selection depends on the input
// data type and on whether the "decimals" attribute is non-zero.
// Typed constexpr constants are used instead of macros so the names respect the enclosing scope.
constexpr uint32_t UB_DATA_NUM_INT32 = 4U;
constexpr uint32_t UB_DATA_NUM_F16_BF16_NO_DECIMAL = 8U;
constexpr uint32_t UB_DATA_NUM_F16_BF16_WITH_DECIMAL = 10U;
constexpr uint32_t UB_DATA_NUM_FLOAT_NO_DECIMAL = 5U;
constexpr uint32_t UB_DATA_NUM_FLOAT_WITH_DECIMAL = 6U;
// NOTE(review): BUFFER_NUM is not referenced by the tiling code visible in this file;
// presumably it mirrors the kernel-side buffer count — confirm before removing.
constexpr uint64_t BUFFER_NUM = 2;
// Compile-info type handed to TilingParse<> when the Round tiling is registered.
// It has no members, so it carries no data.
struct RoundCompileInfo {
};
/**
 * @brief Tiling-parse callback for Round.
 *
 * Nothing is parsed; the function only validates the context pointer through
 * OP_CHECK_NULL_WITH_CONTEXT and then reports success.
 *
 * @param context Tiling-parse context supplied by the framework.
 * @return ge::GRAPH_SUCCESS when the null check passes.
 */
static ge::graphStatus TilingParseRound([[maybe_unused]] gert::TilingParseContext* context)
{
    OP_CHECK_NULL_WITH_CONTEXT(context, context);
    return ge::GRAPH_SUCCESS;
}
/**
 * @brief Fills in the single workspace size entry for the Round operator.
 *
 * The operator requests no user workspace, so the entry equals the value reported by
 * PlatformAscendC::GetLibApiWorkSpaceSize().
 *
 * @param context Tiling context; validated with OP_CHECK_NULL_WITH_CONTEXT.
 * @return ge::GRAPH_SUCCESS when the workspace entry was written; the null-check macro handles
 *         a null context or a null workspace array.
 */
static ge::graphStatus GetWorkspaceSize(gert::TilingContext* context)
{
    OP_CHECK_NULL_WITH_CONTEXT(context, context);

    auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
    const size_t userWorkspaceSize = 0;
    const uint64_t sysWorkspaceSize = ascendcPlatform.GetLibApiWorkSpaceSize();

    // Request exactly one workspace slot; the returned array must be checked before it is written.
    size_t* currentWorkspace = context->GetWorkspaceSizes(1);
    OP_CHECK_NULL_WITH_CONTEXT(context, currentWorkspace);
    currentWorkspace[0] = userWorkspaceSize + sysWorkspaceSize;
    return ge::GRAPH_SUCCESS;
}
/**
 * @brief Tiling entry point for the Round operator.
 *
 * Steps performed:
 *  1. Reads the optional integer attribute at index 0 ("decimals", default 0).
 *  2. Queries the UB size, the UB block size and the core count from the platform.
 *  3. Chooses a per-dtype divisor (UB_DATA_NUM_*) and derives how many blocks / elements fit in
 *     one tile.
 *  4. Splits the block-aligned input over the cores: every core receives
 *     totalBlockNum / coreNum blocks ("small" share) and totalBlockNum % coreNum blocks are left
 *     over; the "big" share is one block larger than the small one.
 *  5. Writes the result into RoundTilingData, sets tiling key 0, the block dim and the
 *     workspace size.
 *
 * @param context Tiling context supplied by the framework.
 * @return ge::GRAPH_SUCCESS on success. ge::GRAPH_FAILED when a required object is missing, the
 *         input is empty, the data type is not INT32 / FLOAT16 / BF16 / FLOAT, a platform query
 *         yields a zero divisor, or the workspace setup fails.
 */
static ge::graphStatus TilingFunc(gert::TilingContext* context)
{
    OP_CHECK_NULL_WITH_CONTEXT(context, context);
    RoundTilingData* tiling = context->GetTilingData<RoundTilingData>();
    OP_CHECK_NULL_WITH_CONTEXT(context, tiling);

    // Optional attribute 0: number of decimals. Stays 0 when the attribute is absent.
    int64_t decimals = 0;
    const gert::RuntimeAttrs* attrs = context->GetAttrs();
    if (attrs != nullptr) {
        const int64_t* decimalsPtr = attrs->GetInt(0);
        if (decimalsPtr != nullptr) {
            decimals = *decimalsPtr;
        }
    }

    // Platform parameters.
    auto blockSize = Ops::Base::GetUbBlockSize(context);
    uint64_t ubSize = 0;
    auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
    auto coreNum = ascendcPlatform.GetCoreNum();
    OP_CHECK_IF(
        coreNum == 0 || blockSize == 0, OP_LOGE(context, "Round tiling: core number or UB block size is zero"),
        return ge::GRAPH_FAILED);

    // Input description.
    auto inputShape = context->GetInputShape(0);
    OP_CHECK_NULL_WITH_CONTEXT(context, inputShape);
    auto inputDesc = context->GetInputDesc(0);
    OP_CHECK_NULL_WITH_CONTEXT(context, inputDesc);

    // Keep the shape size signed until it is validated, so a negative size cannot wrap into a
    // huge unsigned element count.
    const int64_t shapeSize = inputShape->GetStorageShape().GetShapeSize();
    OP_CHECK_IF(
        shapeSize <= 0, OP_LOGE(context, "Round tiling: input element count must be positive"),
        return ge::GRAPH_FAILED);
    const uint64_t inputNum = static_cast<uint64_t>(shapeSize);

    const ge::DataType dataType = inputDesc->GetDataType();
    uint32_t typeLength = 0;
    ge::TypeUtils::GetDataTypeLength(dataType, typeLength);
    // typeLength is a divisor below; it stays 0 if the lookup did not produce a size.
    OP_CHECK_IF(
        typeLength == 0, OP_LOGE(context, "Round tiling: failed to get input data type length"),
        return ge::GRAPH_FAILED);

    // Per-dtype divisor: the number of UB-resident variables the UB has to be shared between.
    const bool hasDecimals = (decimals != 0);
    uint64_t ubDataNumber = 0;
    switch (dataType) {
        case ge::DT_INT32:
            ubDataNumber = UB_DATA_NUM_INT32;
            break;
        case ge::DT_FLOAT16:
        case ge::DT_BF16:
            ubDataNumber = hasDecimals ? UB_DATA_NUM_F16_BF16_WITH_DECIMAL : UB_DATA_NUM_F16_BF16_NO_DECIMAL;
            break;
        case ge::DT_FLOAT:
            ubDataNumber = hasDecimals ? UB_DATA_NUM_FLOAT_WITH_DECIMAL : UB_DATA_NUM_FLOAT_NO_DECIMAL;
            break;
        default:
            OP_LOGE(context, "Round tiling: unsupported input data type");
            return ge::GRAPH_FAILED;
    }

    // The UB holds (ubSize / blockSize) blocks in total. ubDataNumber variables have to be kept
    // in UB, so each variable gets ubSize / blockSize / ubDataNumber blocks per tile.
    // NOTE(review): the original comment also divided by BUFFER_NUM, but the code never did;
    // presumably the UB_DATA_NUM_* factors already account for the buffering — confirm.
    const uint64_t tileBlockNum = ubSize / blockSize / ubDataNumber;
    // Elements per tile: tileBlockNum blocks of blockSize bytes each, typeLength bytes per element.
    const uint64_t tileDataNum = (tileBlockNum * blockSize) / typeLength;
    // Both values are divisors below; they are zero when the UB size query returned nothing usable.
    OP_CHECK_IF(
        tileBlockNum == 0 || tileDataNum == 0, OP_LOGE(context, "Round tiling: UB size is too small for one tile"),
        return ge::GRAPH_FAILED);

    // Number of blocks needed for the input once its byte length is rounded up to whole blocks.
    const uint64_t inputLength = inputNum * typeLength;
    const uint64_t totalBlockNum = (inputLength + blockSize - 1) / blockSize;

    if (inputNum <= tileDataNum) {
        // The whole input fits into one tile: a single core is enough.
        coreNum = 1;
    } else if (totalBlockNum < coreNum) {
        // Never use more cores than there are blocks. totalBlockNum >= 1 here because inputNum and
        // typeLength are both non-zero, so coreNum stays >= 1.
        coreNum = static_cast<decltype(coreNum)>(totalBlockNum);
    }

    // Small share: the evenly divided part every core receives.
    uint64_t everyCoreInputBlockNum = totalBlockNum / coreNum;
    const uint64_t tailBlockNum = totalBlockNum % coreNum;
    const uint64_t smallCoreDataNum = everyCoreInputBlockNum * blockSize / typeLength;
    const uint64_t smallTileNum = everyCoreInputBlockNum / tileBlockNum;
    const uint64_t finalSmallTileNum =
        (everyCoreInputBlockNum % tileBlockNum) == 0 ? smallTileNum : smallTileNum + 1;
    uint64_t smallTailDataNum = smallCoreDataNum - (tileDataNum * smallTileNum);
    if (smallTailDataNum == 0) {
        // An exact multiple of the tile size: the last tile is a full one.
        smallTailDataNum = tileDataNum;
    }

    // Big share: one block more than the small share; only computed when blocks are left over.
    // NOTE(review): presumably tailBlockNum cores take the big share on the kernel side — confirm.
    uint64_t bigCoreDataNum = 0;
    uint64_t finalBigTileNum = 0;
    uint64_t bigTailDataNum = 0;
    if (tailBlockNum != 0) {
        everyCoreInputBlockNum += 1;
        bigCoreDataNum = everyCoreInputBlockNum * blockSize / typeLength;
        const uint64_t bigTileNum = bigCoreDataNum / tileDataNum;
        finalBigTileNum = (everyCoreInputBlockNum % tileBlockNum) == 0 ? bigTileNum : bigTileNum + 1;
        bigTailDataNum = bigCoreDataNum - (tileDataNum * bigTileNum);
        if (bigTailDataNum == 0) {
            bigTailDataNum = tileDataNum;
        }
    }

    tiling->smallCoreDataNum = smallCoreDataNum;
    tiling->bigCoreDataNum = bigCoreDataNum;
    tiling->tileDataNum = tileDataNum;
    tiling->smallTailDataNum = smallTailDataNum;
    tiling->bigTailDataNum = bigTailDataNum;
    tiling->finalSmallTileNum = finalSmallTileNum;
    tiling->finalBigTileNum = finalBigTileNum;
    tiling->tailBlockNum = tailBlockNum;

    // The tiling data stores the scale factor 10^decimals, not the raw attribute value.
    const float decimalsScale = static_cast<float>(std::pow(10, decimals));
    context->SetTilingKey(0);
    tiling->decimals = decimalsScale;
    context->SetBlockDim(coreNum);

    OP_CHECK_IF(
        GetWorkspaceSize(context) != ge::GRAPH_SUCCESS, OP_LOGE(context, "GetWorkspaceSize error"),
        return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}
// Registers the tiling implementation for op "Round": TilingFunc is the tiling function and
// TilingParseRound is the tiling-parse function, with RoundCompileInfo as the compile-info type.
IMPL_OP_OPTILING(Round).Tiling(TilingFunc).TilingParse<RoundCompileInfo>(TilingParseRound);
}