/**
 * This program is free software, you can redistribute it and/or modify it.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
 * BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
/*!
 * \file pow2_tiling.cpp
 * \brief Tiling function and op registration for the Pow2 operator.
 */
#include "log/log.h"
#include "util/math_util.h"
#include "tiling/platform/platform_ascendc.h"
#include "register/op_impl_registry.h"
#include <graph/utils/type_utils.h>
#include "../op_kernel/pow2_tiling_data.h"
#include "../op_kernel/pow2_tiling_key.h"
namespace optiling {
// Byte granularity used throughout this file: the total byte length is rounded up to it
// and tile / per-core sizes are counted in units of it.
constexpr uint32_t SELECT_NEED{256};
constexpr uint32_t BLOCK_SIZE{SELECT_NEED};
// Divisor applied to the per-buffer UB block count when sizing a tile (default dtypes).
constexpr uint32_t DEFAULT_UB_NUM{10};
// Same divisor, used when any input or output dtype is int8 / uint8.
constexpr uint32_t INT8_UB_NUM{24};
// Capacity of the fixed-size shape / stride arrays.
constexpr uint32_t DIMS_LIMIT{10};
// The UB size is divided by this before tile sizing.
// NOTE(review): presumably the double-buffering factor of the kernel - confirm.
constexpr uint32_t BUFFER_NUM{2};
// Not referenced in the visible part of this file.
constexpr uint32_t WS_SYS_SIZE{0};
// Empty compile-info type; in this file it is only used as the template argument of
// TilingParse<> in the op registration.
struct Pow2CompileInfo {};
// Intermediate tiling results, filled by GetShapeAttrsInfo and CalculateCoreBlockNums
// and copied into Pow2TilingData by Pow2TilingFunc.
struct Pow2ShapeInfo {
    // Element count of output 0 (storage shape).
    uint64_t inputNum = 0;
    // Bytes per element of the output dtype.
    uint64_t inputBytes = 0;
    // Tile size counted in BLOCK_SIZE-byte blocks, and the same tile counted in elements.
    uint64_t tileBlockNum = 0;
    uint64_t tileDataNum = 0;
    // Total output byte length rounded up to a multiple of BLOCK_SIZE.
    uint64_t inputLengthAlign256 = 0;
    // Elements per core for the base block share ("small") and for base share + 1 block ("big").
    uint64_t smallCoreDataNum = 0;
    uint64_t bigCoreDataNum = 0;
    // Elements in the last tile of a small / big core.
    uint64_t smallTailDataNum = 0;
    uint64_t bigTailDataNum = 0;
    // Number of tiles on a small / big core.
    uint64_t finalSmallTileNum = 0;
    uint64_t finalBigTileNum = 0;
    // Remainder of (total blocks % core count).
    uint64_t tailBlockNum = 0;
    // True when the corresponding input holds exactly one element.
    bool is_input0_scalar = true;
    bool is_input1_scalar = true;
    // Larger of the two input ranks.
    uint64_t yDim = 0;
    // True when the rank-aligned shape of the input equals the broadcast shape.
    bool isSameX1 = true;
    bool isSameX2 = true;
    // Per-dimension element strides aligned to the broadcast rank; the input strides are 0
    // on dimensions where that input has extent 1.
    uint64_t strideX1[DIMS_LIMIT] = {};
    uint64_t strideX2[DIMS_LIMIT] = {};
    uint64_t strideY[DIMS_LIMIT] = {};
    // Product of all dimensions of input 1 / input 0.
    uint64_t X2TotalNum = 0;
    uint64_t X1TotalNum = 0;
};
// Tiling-parse hook registered for Pow2; it does no work and always reports success.
static ge::graphStatus TilingParseForPow2([[maybe_unused]] gert::TilingParseContext* context)
{
return ge::GRAPH_SUCCESS;
}
// Reads the UB memory size (CoreMemType::UB) and the core count from the platform description.
//   context  tiling context; must not be null.
//   ubSize   out: UB size as reported by GetCoreMemSize.
//   coreNum  out: core count as reported by GetCoreNum.
// Returns GRAPH_FAILED when the context or its platform info is null, or when either
// reported value is zero (or negative for coreNum); GRAPH_SUCCESS otherwise.
static ge::graphStatus GetPlatformInfo(gert::TilingContext* context, uint64_t& ubSize, int64_t& coreNum)
{
    OP_CHECK_IF(context == nullptr, OP_LOGE(context, "context is nullptr"), return ge::GRAPH_FAILED);
    // Validate the platform-info pointer before handing it to PlatformAscendC.
    auto* platformInfo = context->GetPlatformInfo();
    OP_CHECK_IF(platformInfo == nullptr, OP_LOGE(context, "platform info is nullptr"), return ge::GRAPH_FAILED);
    auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);
    // Start from a defined value so the zero check below is meaningful even if the
    // query does not write the output.
    ubSize = 0;
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
    coreNum = ascendcPlatform.GetCoreNum();
    OP_CHECK_IF(coreNum <= 0, OP_LOGE(context, "coreNum is 0"), return ge::GRAPH_FAILED);
    OP_CHECK_IF(ubSize == 0, OP_LOGE(context, "ubSize is 0"), return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}
// Publishes the workspace requirement of the op: a single workspace entry whose size is
// the library API workspace size reported by the platform plus a user part of 0 bytes.
// Returns GRAPH_FAILED when the context, its platform info or the workspace-size array is null.
static ge::graphStatus GetWorkspaceSize(gert::TilingContext* context)
{
    OP_CHECK_IF(context == nullptr, OP_LOGE(context, "context is nullptr"), return ge::GRAPH_FAILED);
    auto* platformInfo = context->GetPlatformInfo();
    OP_CHECK_IF(platformInfo == nullptr, OP_LOGE(context, "platform info is nullptr"), return ge::GRAPH_FAILED);
    auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);
    const size_t usrSize = 0;
    const uint32_t sysWorkspaceSize = ascendcPlatform.GetLibApiWorkSpaceSize();
    // Request one workspace slot and make sure it exists before writing to it.
    size_t* currentWorkspace = context->GetWorkspaceSizes(1);
    OP_CHECK_IF(
        currentWorkspace == nullptr, OP_LOGE(context, "workspace sizes is nullptr"), return ge::GRAPH_FAILED);
    currentWorkspace[0] = usrSize + sysWorkspaceSize;
    return ge::GRAPH_SUCCESS;
}
// Derives the element/tile sizes and the broadcast stride tables for Pow2.
//   context  tiling context; inputs 0/1 and output 0 are inspected.
//   ubSize   UB size used to size one tile.
//   info     out: inputNum, inputBytes, tileBlockNum, tileDataNum, inputLengthAlign256,
//            scalar flags, total element counts, yDim, strides and isSame flags.
// Returns GRAPH_FAILED when a required pointer is null, the output is empty, the dtype
// length is 0, an input rank exceeds DIMS_LIMIT, or the input shapes cannot be broadcast.
static ge::graphStatus GetShapeAttrsInfo(gert::TilingContext* context, uint64_t ubSize, Pow2ShapeInfo& info)
{
    OP_CHECK_IF(
        context == nullptr || context->GetOutputShape(0) == nullptr, OP_LOGE(context, "context is nullptr"),
        return ge::GRAPH_FAILED);
    const auto* x1Storage = context->GetInputShape(0);
    const auto* x2Storage = context->GetInputShape(1);
    const auto* x1Desc = context->GetInputDesc(0);
    const auto* x2Desc = context->GetInputDesc(1);
    const auto* yDesc = context->GetOutputDesc(0);
    OP_CHECK_IF(
        x1Storage == nullptr || x2Storage == nullptr || x1Desc == nullptr || x2Desc == nullptr || yDesc == nullptr,
        OP_LOGE(context, "input/output shape or desc is nullptr"), return ge::GRAPH_FAILED);

    const auto dataType1 = x1Desc->GetDataType();
    const auto dataType2 = x2Desc->GetDataType();
    const auto dataType3 = yDesc->GetDataType();

    // ---- element count and element size (taken from the output) ----
    info.inputNum = context->GetOutputShape(0)->GetStorageShape().GetShapeSize();
    OP_CHECK_IF(info.inputNum == 0, OP_LOGE(context, "output shape size is 0"), return ge::GRAPH_FAILED);
    uint32_t typeLength = 0;
    ge::TypeUtils::GetDataTypeLength(dataType3, typeLength);
    const uint64_t inputLength = info.inputNum * typeLength;
    // Equals typeLength as long as the product above did not overflow; a zero length is rejected.
    info.inputBytes = inputLength / info.inputNum;
    OP_CHECK_IF(info.inputBytes == 0, OP_LOGE(context, "GetShapeAttrsInfo error"), return ge::GRAPH_FAILED);

    // ---- tile size: any 8-bit integer dtype selects the larger UB divisor ----
    const auto isEightBit = [](auto type) { return type == ge::DT_INT8 || type == ge::DT_UINT8; };
    const uint64_t ubDataNumber =
        (isEightBit(dataType1) || isEightBit(dataType2) || isEightBit(dataType3)) ? INT8_UB_NUM : DEFAULT_UB_NUM;
    info.tileBlockNum = (ubSize / BUFFER_NUM / BLOCK_SIZE) / ubDataNumber;
    info.tileDataNum = (info.tileBlockNum * BLOCK_SIZE) / info.inputBytes;
    info.inputLengthAlign256 = ((inputLength + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;

    // ---- input shapes ----
    const gert::Shape& x1Shape = x1Storage->GetStorageShape();
    const gert::Shape& x2Shape = x2Storage->GetStorageShape();
    info.is_input0_scalar = (x1Shape.GetShapeSize() == 1);
    info.is_input1_scalar = (x2Shape.GetShapeSize() == 1);

    const uint64_t dimNum1 = x1Shape.GetDimNum();
    const uint64_t dimNum2 = x2Shape.GetDimNum();
    const uint64_t dimMax = (dimNum1 > dimNum2) ? dimNum1 : dimNum2;
    // The local tables and the stride arrays in `info` only hold DIMS_LIMIT entries.
    OP_CHECK_IF(
        dimMax > DIMS_LIMIT, OP_LOGE(context, "input dim num exceeds the supported limit"),
        return ge::GRAPH_FAILED);

    // ---- right-align both shapes to dimMax (missing leading dims count as 1) and broadcast ----
    uint64_t alignedX1[DIMS_LIMIT] = {0};
    uint64_t alignedX2[DIMS_LIMIT] = {0};
    uint64_t alignedY[DIMS_LIMIT] = {0};
    const uint64_t pad1 = dimMax - dimNum1;
    const uint64_t pad2 = dimMax - dimNum2;
    uint64_t x1Total = 1;
    uint64_t x2Total = 1;
    bool sameAsY1 = true;
    bool sameAsY2 = true;
    for (uint64_t i = 0; i < dimMax; ++i) {
        const uint64_t d1 = (i >= pad1) ? static_cast<uint64_t>(x1Shape.GetDim(i - pad1)) : 1;
        const uint64_t d2 = (i >= pad2) ? static_cast<uint64_t>(x2Shape.GetDim(i - pad2)) : 1;
        OP_CHECK_IF((d1 != d2 && d1 != 1 && d2 != 1),
            OP_LOGE(context, "Broadcast Fail,Please check your input shape"), return ge::GRAPH_FAILED);
        alignedX1[i] = d1;
        alignedX2[i] = d2;
        alignedY[i] = (d1 > d2) ? d1 : d2;
        // Padding dims are 1, so multiplying them in leaves the products unchanged.
        x1Total *= d1;
        x2Total *= d2;
        sameAsY1 = sameAsY1 && (d1 == alignedY[i]);
        sameAsY2 = sameAsY2 && (d2 == alignedY[i]);
    }
    info.X1TotalNum = x1Total;
    info.X2TotalNum = x2Total;
    info.isSameX1 = sameAsY1;
    info.isSameX2 = sameAsY2;
    info.yDim = dimMax;

    // ---- strides: innermost dimension has stride 1; entries at and beyond dimMax stay 0 ----
    for (uint32_t i = 0; i < DIMS_LIMIT; ++i) {
        info.strideX1[i] = 0;
        info.strideX2[i] = 0;
        info.strideY[i] = 0;
    }
    uint64_t runX1 = 1;
    uint64_t runX2 = 1;
    uint64_t runY = 1;
    // Counting down with `i-- > 0` is safe for dimMax == 0 (both inputs rank 0): the body is skipped.
    for (uint64_t i = dimMax; i-- > 0;) {
        info.strideY[i] = runY;
        // An extent of 1 is broadcast along this dimension, so the input does not advance.
        info.strideX1[i] = (alignedX1[i] == 1) ? 0 : runX1;
        info.strideX2[i] = (alignedX2[i] == 1) ? 0 : runX2;
        runY *= alignedY[i];
        runX1 *= alignedX1[i];
        runX2 *= alignedX2[i];
    }
    return ge::GRAPH_SUCCESS;
}
// Splits the BLOCK_SIZE-aligned data across `coreNum` cores and each core's share into tiles.
//   coreNum  number of cores to use; must be positive.
//   info     in: inputLengthAlign256, inputBytes, tileBlockNum, tileDataNum;
//            out: tailBlockNum and the small*/big* per-core and per-tile counts.
//   context  only used for logging.
// A "small" core receives totalBlocks / coreNum blocks, a "big" core one block more;
// tailBlockNum is totalBlocks % coreNum.
// Returns GRAPH_FAILED when coreNum is not positive or tileBlockNum / inputBytes is 0.
static ge::graphStatus CalculateCoreBlockNums(int64_t coreNum, Pow2ShapeInfo& info, gert::TilingContext* context)
{
    OP_CHECK_IF(
        (coreNum <= 0 || info.tileBlockNum == 0 || info.inputBytes == 0),
        OP_LOGE(context, "CalculateCoreBlockNums error: coreNum, tileBlockNum and inputBytes must be positive"),
        return ge::GRAPH_FAILED);
    // coreNum is known to be positive here, so the conversion is value-preserving.
    const uint64_t cores = static_cast<uint64_t>(coreNum);
    const uint64_t totalBlocks = info.inputLengthAlign256 / BLOCK_SIZE;
    const uint64_t smallBlocks = totalBlocks / cores;
    info.tailBlockNum = totalBlocks % cores;

    // For a core owning `blocks` blocks: element count, tile count (rounded up) and the
    // element count of the last tile (a full tile when the blocks divide evenly).
    const auto splitIntoTiles = [&info](uint64_t blocks, uint64_t& dataNum, uint64_t& tileNum, uint64_t& tailNum) {
        dataNum = blocks * BLOCK_SIZE / info.inputBytes;
        const uint64_t fullTiles = blocks / info.tileBlockNum;
        tileNum = (blocks % info.tileBlockNum == 0) ? fullTiles : fullTiles + 1;
        tailNum = dataNum - info.tileDataNum * fullTiles;
        if (tailNum == 0) {
            tailNum = info.tileDataNum;
        }
    };
    splitIntoTiles(smallBlocks, info.smallCoreDataNum, info.finalSmallTileNum, info.smallTailDataNum);
    splitIntoTiles(smallBlocks + 1, info.bigCoreDataNum, info.finalBigTileNum, info.bigTailDataNum);
    return ge::GRAPH_SUCCESS;
}
// Tiling entry point for Pow2: gathers platform and shape information, decides how many
// cores to use, fills Pow2TilingData, and sets workspace size, block dim and tiling key.
// Returns GRAPH_FAILED as soon as any step fails.
static ge::graphStatus Pow2TilingFunc(gert::TilingContext* context)
{
    // Check the context before its first use, as the helper functions in this file do.
    OP_CHECK_IF(context == nullptr, OP_LOGE(context, "context is nullptr"), return ge::GRAPH_FAILED);
    Pow2TilingData* tiling = context->GetTilingData<Pow2TilingData>();
    OP_CHECK_NULL_WITH_CONTEXT(context, tiling);
    OP_CHECK_IF(
        memset_s(tiling, sizeof(Pow2TilingData), 0, sizeof(Pow2TilingData)) != EOK,
        OP_LOGE(context, "set tiling data error"), return ge::GRAPH_FAILED);

    uint64_t ubSize = 0;
    int64_t coreNum = 0;
    ge::graphStatus ret = GetPlatformInfo(context, ubSize, coreNum);
    OP_CHECK_IF(ret != ge::GRAPH_SUCCESS, OP_LOGE(context, "GetPlatformInfo error"), return ge::GRAPH_FAILED);

    Pow2ShapeInfo shapeInfo;
    ret = GetShapeAttrsInfo(context, ubSize, shapeInfo);
    OP_CHECK_IF(ret != ge::GRAPH_SUCCESS, OP_LOGE(context, "GetShapeAttrsInfo error"), return ge::GRAPH_FAILED);

    // One tile covers everything -> a single core; otherwise never use more cores than
    // there are BLOCK_SIZE-byte blocks of data.
    const uint64_t totalBlocks = shapeInfo.inputLengthAlign256 / BLOCK_SIZE;
    if (shapeInfo.tileDataNum >= shapeInfo.inputNum) {
        coreNum = 1;
    } else if (static_cast<uint64_t>(coreNum) >= totalBlocks) {
        coreNum = static_cast<int64_t>(totalBlocks);
    }

    ret = CalculateCoreBlockNums(coreNum, shapeInfo, context);
    OP_CHECK_IF(ret != ge::GRAPH_SUCCESS, OP_LOGE(context, "CalculateCoreBlockNums error"), return ge::GRAPH_FAILED);

    // The 64-bit intermediate results are narrowed to 32 bits when copied into the tiling data.
    tiling->smallCoreDataNum = static_cast<uint32_t>(shapeInfo.smallCoreDataNum);
    tiling->bigCoreDataNum = static_cast<uint32_t>(shapeInfo.bigCoreDataNum);
    tiling->tileDataNum = static_cast<uint32_t>(shapeInfo.tileDataNum);
    tiling->smallTailDataNum = static_cast<uint32_t>(shapeInfo.smallTailDataNum);
    tiling->bigTailDataNum = static_cast<uint32_t>(shapeInfo.bigTailDataNum);
    tiling->finalSmallTileNum = static_cast<uint32_t>(shapeInfo.finalSmallTileNum);
    tiling->finalBigTileNum = static_cast<uint32_t>(shapeInfo.finalBigTileNum);
    tiling->tailBlockNum = static_cast<uint32_t>(shapeInfo.tailBlockNum);
    tiling->X1TotalNum = static_cast<uint32_t>(shapeInfo.X1TotalNum);
    tiling->X2TotalNum = static_cast<uint32_t>(shapeInfo.X2TotalNum);
    tiling->yDim = static_cast<uint32_t>(shapeInfo.yDim);
    tiling->is_input0_scalar = shapeInfo.is_input0_scalar;
    tiling->is_input1_scalar = shapeInfo.is_input1_scalar;
    tiling->isSameX1 = shapeInfo.isSameX1;
    tiling->isSameX2 = shapeInfo.isSameX2;
    for (uint32_t i = 0; i < DIMS_LIMIT; ++i) {
        tiling->strideX1[i] = static_cast<uint32_t>(shapeInfo.strideX1[i]);
        tiling->strideX2[i] = static_cast<uint32_t>(shapeInfo.strideX2[i]);
        tiling->strideY[i] = static_cast<uint32_t>(shapeInfo.strideY[i]);
    }

    OP_CHECK_IF(GetWorkspaceSize(context) != ge::GRAPH_SUCCESS, OP_LOGE(context, "GetWorkspaceSize error"), return ge::GRAPH_FAILED);
    context->SetBlockDim(static_cast<uint32_t>(coreNum));
    const uint32_t tilingKey = GET_TPL_TILING_KEY(ELEMENTWISE_TPL_SCH_MODE_0);
    context->SetTilingKey(tilingKey);
    return ge::GRAPH_SUCCESS;
}
// Registers Pow2TilingFunc as the tiling function and TilingParseForPow2 (with
// Pow2CompileInfo as compile-info type) as the tiling-parse function of the Pow2 op.
IMPL_OP_OPTILING(Pow2).Tiling(Pow2TilingFunc).TilingParse<Pow2CompileInfo>(TilingParseForPow2);
}