/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
 * \file pows_tiling.cpp
 * \brief Tiling implementation for the Pows operator.
 */
#include "pows_tiling.h"
#include <graph/utils/type_utils.h>
#include "register/op_impl_registry.h"
#include "util/math_util.h"
#include "log/log.h"
#include "platform/platform_ascendc.h"
namespace optiling {
// Index of the input queried through GetInputTensor / GetInputDesc.
static constexpr uint32_t INPUT_IDX = 0;

// Dimension and attribute indices; not referenced by the code visible in this file.
static constexpr uint32_t DIM_0 = 0;
static constexpr uint32_t DIM_1 = 1;
static constexpr uint32_t DIM_2 = 2;
static constexpr uint32_t DIM_3 = 3;
static constexpr uint32_t ATTR_DIM_INDEX = 0;
static constexpr uint32_t ATTR_APPROXIMATE_INDEX = 1;
static constexpr uint32_t ATTR_ACTIVATE_LEFT_INDEX = 2;

// Element sizes in bytes and the per-dtype factors used to divide the UB size
// when the buffer size is derived (see Get*TilingData).
static constexpr uint32_t FP16_DTYPE_BYTES = 2;
static constexpr uint32_t FP32_DTYPE_BYTES = 4;
static constexpr uint32_t FP16_COEXISTING_NUM = 7;
static constexpr uint32_t FP32_COEXISTING_NUM = 3;

// Base amount added to the workspace size in Tiling4Pows.
static constexpr uint32_t WORK_SPACE_SIZE = 32;

// Not referenced by the code visible in this file.
static constexpr uint32_t SPLIT_FACTOR = 2;
static constexpr uint32_t SPLIT_ERROR_STATUS = 10000;
static constexpr int64_t APPROXIMATE_USING_TANH = 1;
static constexpr int64_t APPROXIMATE_USING_ERF = 0;

// Block size in bytes; divided by the element size to obtain blockSize in elements.
static constexpr int64_t BYTES_ONE_BLOCK = 32;
// Inputs with fewer elements than this take the single-core (non-cut) tiling path.
static constexpr int64_t MULTI_CORE_SHAPE_SIZE_LIMIT = 4096;
// The buffer size is rounded down to a multiple of this value.
static constexpr int64_t BUFFER_SIZE_ALIGN_LENGTH = 256;
/*!
 * \brief Serializes the tiling data into the raw tiling buffer held by the context.
 * \param context Tiling context providing the raw tiling data buffer.
 * \param tilingData Tiling result to be saved.
 * \return ge::GRAPH_SUCCESS once the data is saved and its size recorded; if the context
 *         exposes no raw tiling buffer, the status returned by OP_CHECK_NULL_WITH_CONTEXT.
 */
inline static ge::graphStatus SetTilingDataForPows(gert::TilingContext* context, PowsTilingData& tilingData)
{
    auto rawTilingData = context->GetRawTilingData();
    // Fail explicitly instead of dereferencing a null buffer pointer.
    OP_CHECK_NULL_WITH_CONTEXT(context, rawTilingData);
    tilingData.SaveToBuffer(rawTilingData->GetData(), rawTilingData->GetCapacity());
    rawTilingData->SetDataSize(tilingData.GetDataSize());
    return ge::GRAPH_SUCCESS;
}
/*!
 * \brief Fills the tiling data for the non-cut path: one core, one main loop covering
 *        all tilingParam.x elements, and no tail work.
 * \param tilingData Tiling result to fill.
 * \param tilingParam Source of x, bufSize and blockSize.
 */
static void GetTilingDataNonCut(PowsTilingData& tilingData, TilingParam& tilingParam)
{
    // Round the buffer size down to a multiple of BUFFER_SIZE_ALIGN_LENGTH.
    const auto alignedBufSize = tilingParam.bufSize / BUFFER_SIZE_ALIGN_LENGTH * BUFFER_SIZE_ALIGN_LENGTH;

    // A single core processes the whole input in a single pass.
    tilingData.set_realCoreNum(1);
    tilingData.set_numPerCore(tilingParam.x);
    tilingData.set_dataLength(tilingParam.x);
    tilingData.set_mainCoreLoopNum(1);
    tilingData.set_mainCoreTailLength(0);

    // No tail core is involved.
    tilingData.set_tailCoreLoopNum(0);
    tilingData.set_tailCoreTailLength(0);

    tilingData.set_bufSize(alignedBufSize);
    tilingData.set_blockSize(tilingParam.blockSize);
}
/*!
 * \brief Fills the tiling data for the multi-core path.
 *
 * The element count is divided over tilingParam.coreNum cores, the per-core amount is
 * rounded up to a multiple of blockSize, and each core walks its share in chunks that
 * are bounded by the aligned buffer size.
 *
 * \param tilingData Tiling result to fill.
 * \param tilingParam Source of x, coreNum, bufSize and blockSize.
 */
static void GetTilingDataBigCase(PowsTilingData& tilingData, TilingParam& tilingParam)
{
    // Buffer size rounded down to a multiple of BUFFER_SIZE_ALIGN_LENGTH.
    int64_t alignedBufSize = tilingParam.bufSize / BUFFER_SIZE_ALIGN_LENGTH * BUFFER_SIZE_ALIGN_LENGTH;

    // Per-core element count, rounded up to a multiple of blockSize.
    int64_t perCore = Ops::Base::CeilDiv(tilingParam.x, tilingParam.coreNum);
    int64_t perCoreAligned = (perCore + tilingParam.blockSize - 1) / tilingParam.blockSize * tilingParam.blockSize;

    // Number of cores needed once each core takes perCoreAligned elements.
    int64_t usedCores = Ops::Base::CeilDiv(tilingParam.x, perCoreAligned);

    // Elements handled per loop iteration: block-aligned buffer size, capped by the per-core amount.
    int64_t loopLength = alignedBufSize / tilingParam.blockSize * tilingParam.blockSize;
    if (loopLength > perCoreAligned) {
        loopLength = perCoreAligned;
    }

    // Tail values stay zero when x is an exact multiple of perCoreAligned.
    int64_t tailLoopNum = 0;
    int64_t tailRemainder = 0;
    if (tilingParam.x % perCoreAligned != 0) {
        // Elements left for the last core after the preceding cores took perCoreAligned each.
        int64_t tailTotal = tilingParam.x - perCoreAligned * (usedCores - 1);
        tailLoopNum = tailTotal / loopLength;
        tailRemainder = tailTotal % loopLength;
    }

    tilingData.set_numPerCore(perCoreAligned);
    tilingData.set_realCoreNum(usedCores);
    tilingData.set_mainCoreLoopNum(perCoreAligned / loopLength);
    tilingData.set_mainCoreTailLength(perCoreAligned % loopLength);
    tilingData.set_dataLength(loopLength);
    tilingData.set_tailCoreLoopNum(tailLoopNum);
    tilingData.set_tailCoreTailLength(tailRemainder);
    tilingData.set_bufSize(alignedBufSize);
    tilingData.set_blockSize(tilingParam.blockSize);
}
/*!
 * \brief Chooses the tiling strategy by element count: inputs below
 *        MULTI_CORE_SHAPE_SIZE_LIMIT use the non-cut path, all others the multi-core path.
 * \param tilingData Tiling result to fill.
 * \param tilingParam Tiling inputs; bufSize and blockSize must already be set.
 */
static void GetTilingData(PowsTilingData& tilingData, TilingParam& tilingParam)
{
    const bool useMultiCore = !(tilingParam.x < MULTI_CORE_SHAPE_SIZE_LIMIT);
    if (useMultiCore) {
        GetTilingDataBigCase(tilingData, tilingParam);
        return;
    }
    GetTilingDataNonCut(tilingData, tilingParam);
}
/*!
 * \brief Computes the tiling data for float16 input and tags it with TILINGKEY_101.
 * \param tilingData Tiling result to fill.
 * \param tilingParam Tiling inputs; bufSize and blockSize are written here before tiling.
 */
static void GetFp16TilingData(PowsTilingData& tilingData, TilingParam& tilingParam)
{
    // Elements per BYTES_ONE_BLOCK-sized block for a 2-byte element type.
    tilingParam.blockSize = BYTES_ONE_BLOCK / FP16_DTYPE_BYTES;
    // UB size divided by (element bytes * FP16_COEXISTING_NUM).
    tilingParam.bufSize = tilingParam.ubSize / (FP16_DTYPE_BYTES * FP16_COEXISTING_NUM);

    GetTilingData(tilingData, tilingParam);

    const int64_t key = static_cast<int64_t>(PowsTilingKey::TILINGKEY_101);
    tilingData.set_tilingKey(key);
}
/*!
 * \brief Computes the tiling data for bfloat16 input and tags it with TILINGKEY_201.
 *        Buffer and block sizes use the same 2-byte constants as the float16 path.
 * \param tilingData Tiling result to fill.
 * \param tilingParam Tiling inputs; bufSize and blockSize are written here before tiling.
 */
static void GetBf16TilingData(PowsTilingData& tilingData, TilingParam& tilingParam)
{
    // Elements per BYTES_ONE_BLOCK-sized block for a 2-byte element type.
    tilingParam.blockSize = BYTES_ONE_BLOCK / FP16_DTYPE_BYTES;
    // UB size divided by (element bytes * FP16_COEXISTING_NUM).
    tilingParam.bufSize = tilingParam.ubSize / (FP16_DTYPE_BYTES * FP16_COEXISTING_NUM);

    GetTilingData(tilingData, tilingParam);

    tilingData.set_tilingKey(static_cast<int64_t>(PowsTilingKey::TILINGKEY_201));
}
/*!
 * \brief Computes the tiling data for float32 input and tags it with TILINGKEY_301.
 * \param tilingData Tiling result to fill.
 * \param tilingParam Tiling inputs; bufSize and blockSize are written here before tiling.
 */
static void GetFp32TilingData(PowsTilingData& tilingData, TilingParam& tilingParam)
{
    // Elements per BYTES_ONE_BLOCK-sized block for a 4-byte element type.
    tilingParam.blockSize = BYTES_ONE_BLOCK / FP32_DTYPE_BYTES;
    // UB size divided by (element bytes * FP32_COEXISTING_NUM).
    tilingParam.bufSize = tilingParam.ubSize / (FP32_DTYPE_BYTES * FP32_COEXISTING_NUM);

    GetTilingData(tilingData, tilingParam);

    tilingData.set_tilingKey(static_cast<int64_t>(PowsTilingKey::TILINGKEY_301));
}
/*!
 * \brief Validates the first input of the operator: the tensor and its descriptor must
 *        exist, the dtype must be float16, bfloat16 or float32, and the dtype size must
 *        be positive.
 * \param context Tiling context to query.
 * \return ge::GRAPH_SUCCESS when every check passes, a failure status otherwise.
 */
static ge::graphStatus CheckInputParams(const gert::TilingContext* context)
{
    auto input = context->GetInputTensor(INPUT_IDX);
    OP_CHECK_NULL_WITH_CONTEXT(context, input);

    // The descriptor is dereferenced below, so make sure it is present first.
    auto inputDesc = context->GetInputDesc(INPUT_IDX);
    OP_CHECK_NULL_WITH_CONTEXT(context, inputDesc);

    auto dtype = inputDesc->GetDataType();
    int32_t typeSize = ge::GetSizeByDataType(dtype);
    OP_CHECK_IF(
        dtype != ge::DT_FLOAT16 && dtype != ge::DT_BF16 && dtype != ge::DT_FLOAT,
        OP_LOGE_FOR_INVALID_DTYPE(context->GetNodeName(), "x1",
            ge::TypeUtils::DataTypeToSerialString(dtype).c_str(), "fp16, fp32 or bf16"),
        return ge::GRAPH_FAILED);
    OP_CHECK_IF(
        (typeSize <= 0), OP_LOGE(context, "typeSize is invalid %d, please check.", typeSize), return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}
/*!
 * \brief Tiling-parse step: stores the AIV core count and the UB size reported by the
 *        platform into the operator's PowsCompileInfo.
 * \param context Tiling parse context providing the compile info object and platform info.
 * \return ge::GRAPH_SUCCESS when both values are positive, a failure status otherwise.
 */
static ge::graphStatus TilingPrepare4Pows(gert::TilingParseContext* context)
{
    OP_LOGD(context, "TilingPrepare4Pows enter.");
    auto compileInfo = context->GetCompiledInfo<PowsCompileInfo>();
    OP_CHECK_NULL_WITH_CONTEXT(context, compileInfo);
    auto platformInfo = context->GetPlatformInfo();
    OP_CHECK_NULL_WITH_CONTEXT(context, platformInfo);
    auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);

    compileInfo->totalCoreNum = ascendcPlatform.GetCoreNumAiv();
    OP_CHECK_IF(
        (compileInfo->totalCoreNum <= 0), OP_LOGE(context, "TilingPrepare4Pows fail to get core num."),
        return ge::GRAPH_FAILED);

    // Start from zero so that the check below rejects the value if the platform query
    // does not write it, instead of reading an indeterminate value.
    uint64_t ubSizePlatForm = 0;
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSizePlatForm);
    compileInfo->ubSizePlatForm = static_cast<int64_t>(ubSizePlatForm);
    OP_CHECK_IF(
        (compileInfo->ubSizePlatForm <= 0), OP_LOGE(context, "TilingPrepare4Pows fail to get ub size."),
        return ge::GRAPH_FAILED);

    OP_LOGD(context, "TilingPrepare4Pows exit.");
    return ge::GRAPH_SUCCESS;
}
/*!
 * \brief Collects the tiling inputs: the element count of the first input (product of
 *        its storage-shape dimensions) plus the core count and UB size from the compile info.
 * \param context Tiling context to query.
 * \param tilingParam Receives x, coreNum and ubSize.
 * \return ge::GRAPH_SUCCESS on success; a failure status when the input tensor or the
 *         compile info is missing.
 */
static ge::graphStatus GetTillingParam(const gert::TilingContext* context, TilingParam& tilingParam)
{
    auto inputTensor = context->GetInputTensor(INPUT_IDX);
    OP_CHECK_NULL_WITH_CONTEXT(context, inputTensor);
    // Bind by reference; the shape is only read.
    const auto& inputShape = inputTensor->GetStorageShape();
    int64_t x{1};
    for (size_t i = 0; i < inputShape.GetDimNum(); i++) {
        x *= inputShape.GetDim(i);
    }

    auto compileInfo = reinterpret_cast<const PowsCompileInfo*>(context->GetCompileInfo());
    // The compile info is dereferenced below, so make sure it is present first.
    OP_CHECK_NULL_WITH_CONTEXT(context, compileInfo);

    tilingParam.x = x;
    tilingParam.coreNum = compileInfo->totalCoreNum;
    tilingParam.ubSize = compileInfo->ubSizePlatForm;
    OP_LOGI(
        context, "tilingParm is x: %ld, coreNum: %ld, ubSize: %ld", tilingParam.x, tilingParam.coreNum,
        tilingParam.ubSize);
    return ge::GRAPH_SUCCESS;
}
/*!
 * \brief Dispatches to the dtype-specific tiling routine: float16 and bfloat16 have
 *        their own routines, every other dtype goes through the float32 routine.
 * \param dtype Data type of the input.
 * \param tilingParam Tiling inputs, updated by the selected routine.
 * \param tilingData Tiling result to fill.
 */
static void GetTillingData(ge::DataType dtype, TilingParam& tilingParam, PowsTilingData& tilingData)
{
    switch (dtype) {
        case ge::DT_FLOAT16:
            GetFp16TilingData(tilingData, tilingParam);
            break;
        case ge::DT_BF16:
            GetBf16TilingData(tilingData, tilingParam);
            break;
        default:
            GetFp32TilingData(tilingData, tilingParam);
            break;
    }
}
/*!
 * \brief Tiling entry point for Pows: validates the input, gathers the tiling inputs,
 *        computes the dtype-specific tiling data, saves it to the context and sets the
 *        block dim, tiling key and workspace size.
 * \param context Tiling context of the operator.
 * \return ge::GRAPH_SUCCESS on success, a failure status if any step fails.
 */
static ge::graphStatus Tiling4Pows(gert::TilingContext* context)
{
    OP_LOGD(context, "Tiling4Pows enter.");
    OP_CHECK_IF(
        CheckInputParams(context) != ge::GRAPH_SUCCESS, OP_LOGE(context, "InputParams not valid."),
        return ge::GRAPH_FAILED);

    TilingParam tilingParam;
    OP_CHECK_IF(
        GetTillingParam(context, tilingParam) != ge::GRAPH_SUCCESS, OP_LOGE(context, "Get Tiling Param Failed."),
        return ge::GRAPH_FAILED);

    // The descriptor is dereferenced right away, so make sure it is present first.
    auto inputDesc = context->GetInputDesc(INPUT_IDX);
    OP_CHECK_NULL_WITH_CONTEXT(context, inputDesc);
    auto dtype = inputDesc->GetDataType();

    PowsTilingData tilingData;
    GetTillingData(dtype, tilingParam, tilingData);
    OP_CHECK_IF(
        SetTilingDataForPows(context, tilingData) != ge::GRAPH_SUCCESS,
        OP_LOGE(context, "PowsSetTilingData set tiling data fail."), return ge::GRAPH_FAILED);

    context->SetBlockDim(tilingData.get_realCoreNum());
    context->SetTilingKey(tilingData.get_tilingKey());

    size_t* workspaces = context->GetWorkspaceSizes(1);
    // Do not write through a null workspace array.
    OP_CHECK_NULL_WITH_CONTEXT(context, workspaces);
    workspaces[0] = WORK_SPACE_SIZE + tilingParam.coreNum * BYTES_ONE_BLOCK;

    // Adjacent literals are used instead of backslash continuation so that source
    // indentation can never leak into the log message.
    OP_LOGI(
        context,
        "tilingData is bufSize: %ld, tilingKey: %ld, numPerCore: %ld, realCoreNum: %ld, "
        "mainCoreLoopNum: %ld, mainCoreTailLength: %ld, tailCoreloopNum: %ld, tailCoreTailLength: %ld, "
        "dataLength: %ld, blockSize: %ld",
        tilingData.get_bufSize(), tilingData.get_tilingKey(), tilingData.get_numPerCore(), tilingData.get_realCoreNum(),
        tilingData.get_mainCoreLoopNum(), tilingData.get_mainCoreTailLength(), tilingData.get_tailCoreLoopNum(),
        tilingData.get_tailCoreTailLength(), tilingData.get_dataLength(), tilingData.get_blockSize());
    OP_LOGD(context, "Tiling4Pows exit.");
    return ge::GRAPH_SUCCESS;
}
// Register the tiling entry points of the Pows operator: Tiling4Pows as the tiling
// function and TilingPrepare4Pows as the tiling-parse function for PowsCompileInfo.
IMPL_OP_OPTILING(Pows).Tiling(Tiling4Pows).TilingParse<PowsCompileInfo>(TilingPrepare4Pows);
}