/**
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/**
 * NOTE: Portions of this code were AI-generated and have been
 * technically reviewed for functional accuracy and security.
 */
/*!
 * \file ndtri_tiling.cpp
 * \brief Ndtri tiling implementation (arch35 / Ascend950)
 */
#include "register/op_def_registry.h"
#include "op_common/log/log.h"
#include "op_common/op_host/util/math_util.h"
#include "op_common/op_host/util/platform_util.h"
#include "../op_kernel/ndtri_tiling_data.h"
#include "../op_kernel/ndtri_tiling_key.h"
namespace optiling {
// User workspace bytes; added to the system workspace size when workspace[0] is filled in.
constexpr uint32_t WS_USER_SIZE = 0U;
// Index of the input whose dtype and storage shape drive the tiling.
static constexpr size_t IDX_SELF = 0;
// Element size in bytes used when the input dtype is DT_FLOAT.
constexpr int64_t TYPE_SIZE_FP32 = 4;
// Element size in bytes used for every other accepted dtype (DT_FLOAT16 / DT_BF16).
constexpr int64_t TYPE_SIZE_FP16_BF16 = 2;
// Bytes subtracted from the platform-reported UB size before the tile size is derived.
constexpr int64_t RESERVED_UB = 48 * 1024;
// Divisor that turns the available UB bytes into a per-tile element count.
// NOTE(review): presumably the kernel's per-element UB footprint -- confirm against the kernel.
constexpr int64_t BYTE_PER_ELEM = 64;
// The per-tile element count is rounded down to a multiple of this value.
constexpr int64_t TILE_ALIGN = 256;
// Shape {1}, substituted for shapes that have no dimensions.
static const gert::Shape K_VEC_1_SHAPE = {1};

/**
 * @brief Maps a zero-dimensional shape to the one-element shape {1}.
 *
 * @param in_shape Shape to inspect.
 * @return A copy of K_VEC_1_SHAPE when in_shape has no dimensions, otherwise a copy of in_shape.
 */
static inline const gert::Shape EnsureNotScalar(const gert::Shape& in_shape)
{
    return (in_shape.GetDimNum() == 0) ? K_VEC_1_SHAPE : in_shape;
}
/**
 * @brief Queries the platform for the values the tiling computation needs.
 *
 * @param context           Tiling context that supplies the platform info.
 * @param ubSize            [out] Size reported for CoreMemType::UB.
 * @param coreNum           [out] Core count reported by GetCoreNumAiv().
 * @param sysWorkspaceSize  [out] Size reported by GetLibApiWorkSpaceSize(); only written on success.
 * @return ge::GRAPH_SUCCESS, or ge::GRAPH_FAILED when the platform info is missing
 *         or the core count / UB size is zero.
 */
static ge::graphStatus GetPlatformInfo(
    gert::TilingContext* context, uint64_t& ubSize, int64_t& coreNum,
    uint32_t& sysWorkspaceSize)
{
    fe::PlatFormInfos* platformInfoPtr = context->GetPlatformInfo();
    OP_CHECK_NULL_WITH_CONTEXT(context, platformInfoPtr);

    platform_ascendc::PlatformAscendC platform(platformInfoPtr);

    // A zero core count cannot be used as a divisor when splitting the work.
    coreNum = platform.GetCoreNumAiv();
    OP_CHECK_IF(coreNum == 0, OP_LOGE(context, "coreNum is 0"),
                return ge::GRAPH_FAILED);

    platform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
    OP_CHECK_IF(ubSize == 0, OP_LOGE(context, "ubSize is 0"),
                return ge::GRAPH_FAILED);

    sysWorkspaceSize = platform.GetLibApiWorkSpaceSize();
    return ge::GRAPH_SUCCESS;
}
/**
 * @brief Reads the dtype of the IDX_SELF input and verifies it is supported.
 *
 * Accepted dtypes are DT_FLOAT, DT_FLOAT16 and DT_BF16. The check uses direct
 * comparisons, so no container is allocated per call and the function does not
 * depend on <set> being included.
 *
 * @param context Tiling context that supplies the input descriptor.
 * @param dtype   [out] Dtype of the input; written before the support check runs.
 * @return ge::GRAPH_SUCCESS when the dtype is supported, ge::GRAPH_FAILED when the
 *         descriptor is missing or the dtype is not one of the accepted values.
 */
static ge::graphStatus CheckDtype(gert::TilingContext* context, ge::DataType& dtype)
{
    auto selfDesc = context->GetInputDesc(IDX_SELF);
    OP_CHECK_NULL_WITH_CONTEXT(context, selfDesc);
    dtype = selfDesc->GetDataType();

    const bool isSupported =
        (dtype == ge::DT_FLOAT) || (dtype == ge::DT_FLOAT16) || (dtype == ge::DT_BF16);
    OP_CHECK_IF(!isSupported,
                OP_LOGE(context, "Ndtri: unsupported dtype %d",
                        static_cast<int>(dtype)),
                return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}
/**
 * @brief Computes the element count of the IDX_SELF input.
 *
 * The storage shape is passed through EnsureNotScalar() before its size is taken.
 *
 * @param context  Tiling context that supplies the input shape.
 * @param totalNum [out] Element count; written before the positivity check runs.
 * @return ge::GRAPH_SUCCESS when the count is positive, ge::GRAPH_FAILED when the
 *         shape is missing or the count is zero or negative.
 */
static ge::graphStatus GetTotalNum(gert::TilingContext* context, int64_t& totalNum)
{
    auto selfShapePtr = context->GetInputShape(IDX_SELF);
    OP_CHECK_NULL_WITH_CONTEXT(context, selfShapePtr);

    const gert::Shape& storageShape = selfShapePtr->GetStorageShape();
    totalNum = EnsureNotScalar(storageShape).GetShapeSize();

    OP_CHECK_IF(totalNum <= 0,
                OP_LOGE(context, "Ndtri: totalNum must > 0, got %ld", totalNum),
                return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}
/**
 * @brief Fills tiling->blockFactor and tiling->ubFactor and reports the number of cores used.
 *
 * @param context      Tiling context (used for the UB block size query and for logging).
 * @param dtype        Input dtype; DT_FLOAT selects TYPE_SIZE_FP32, anything else TYPE_SIZE_FP16_BF16.
 * @param totalNum     Total number of elements to process.
 * @param ubSize       UB size in bytes as reported by the platform.
 * @param coreNum      Number of cores available for the split.
 * @param tiling       Tiling data whose blockFactor and ubFactor fields are written.
 * @param usedCoreNum  [out] Number of cores that receive work.
 * @param alignElem    [out] Number of elements in one UB block for the selected element size.
 * @return ge::GRAPH_SUCCESS, or ge::GRAPH_FAILED when a derived quantity is not positive.
 */
static ge::graphStatus DoTiling(
    gert::TilingContext* context, ge::DataType dtype, int64_t totalNum,
    uint64_t ubSize, int64_t coreNum,
    NdtriTilingData* tiling, int64_t& usedCoreNum, int64_t& alignElem)
{
    const int64_t blockBytes = Ops::Base::GetUbBlockSize(context);

    int64_t elemBytes = TYPE_SIZE_FP16_BF16;
    if (dtype == ge::DT_FLOAT) {
        elemBytes = TYPE_SIZE_FP32;
    }
    // NOTE(review): both candidate sizes are positive constants, so this check cannot fire today.
    OP_CHECK_IF(elemBytes <= 0, OP_LOGE(context, "typeSize<=0"),
                return ge::GRAPH_FAILED);

    alignElem = blockBytes / elemBytes;
    OP_CHECK_IF(alignElem <= 0, OP_LOGE(context, "alignElem<=0"),
                return ge::GRAPH_FAILED);

    if (totalNum >= alignElem) {
        // Even split across cores, with each core's share rounded up to a whole UB block.
        const int64_t evenShare = Ops::Base::CeilDiv(totalNum, coreNum);
        tiling->blockFactor = Ops::Base::CeilAlign(evenShare, alignElem);
        usedCoreNum = Ops::Base::CeilDiv(totalNum, tiling->blockFactor);
    } else {
        // Less than one UB block of data: a single core takes everything.
        tiling->blockFactor = totalNum;
        usedCoreNum = 1;
    }
    OP_CHECK_IF(usedCoreNum == 0, OP_LOGE(context, "usedCoreNum is 0"),
                return ge::GRAPH_FAILED);

    const int64_t usableUb = static_cast<int64_t>(ubSize) - RESERVED_UB;
    OP_CHECK_IF(usableUb <= 0, OP_LOGE(context, "availableUb<=0"),
                return ge::GRAPH_FAILED);

    // Elements per tile: usable bytes / BYTE_PER_ELEM, rounded down to TILE_ALIGN,
    // but never less than one UB block of elements.
    // NOTE(review): when the rounded value is below alignElem the fallback may need
    // more than usableUb bytes -- confirm this cannot happen on supported platforms.
    const int64_t alignedTile = Ops::Base::FloorAlign(usableUb / BYTE_PER_ELEM, TILE_ALIGN);
    tiling->ubFactor = (alignedTile < alignElem) ? alignElem : alignedTile;
    return ge::GRAPH_SUCCESS;
}
/**
 * @brief Selects the tiling-key template parameters from the dtype and the alignment of totalNum.
 *
 * @param context   Tiling context handed to ASCENDC_TPL_SEL_PARAM.
 * @param dtype     Input dtype; DT_FLOAT and DT_FLOAT16 map to their own keys, anything else to C_DT_BF16.
 * @param totalNum  Total number of elements.
 * @param alignElem Number of elements per UB block; the aligned flag is 1 only when this is
 *                  positive and divides totalNum exactly.
 */
static void DispatchTilingKey(
    gert::TilingContext* context, ge::DataType dtype, int64_t totalNum, int64_t alignElem)
{
    uint32_t dtypeKey = static_cast<uint32_t>(C_DT_BF16);
    switch (dtype) {
        case ge::DT_FLOAT:
            dtypeKey = static_cast<uint32_t>(C_DT_FLOAT);
            break;
        case ge::DT_FLOAT16:
            dtypeKey = static_cast<uint32_t>(C_DT_FLOAT16);
            break;
        default:
            break;
    }

    // Short-circuit keeps the modulo from running with a non-positive divisor.
    const bool misaligned = (alignElem <= 0) || (totalNum % alignElem != 0);
    uint32_t isAlign = misaligned ? 0U : 1U;

    ASCENDC_TPL_SEL_PARAM(context, dtypeKey, isAlign);
}
/**
 * @brief Tiling entry point for the Ndtri operator.
 *
 * Steps, in order:
 *   1. query platform values (UB size, core count, system workspace size);
 *   2. validate the input dtype and compute the total element count;
 *   3. publish the workspace size;
 *   4. zero the tiling data, store totalNum and compute the block / tile factors;
 *   5. set the block dimension and select the tiling key.
 *
 * @param context Tiling context supplied by the framework.
 * @return ge::GRAPH_SUCCESS on success, ge::GRAPH_FAILED if any step fails.
 */
static ge::graphStatus NdtriTilingFunc(gert::TilingContext* context)
{
    uint64_t ubSize = 0;
    int64_t coreNum = 0;
    uint32_t sysWorkspaceSize = 0;
    OP_CHECK_IF(GetPlatformInfo(context, ubSize, coreNum, sysWorkspaceSize) != ge::GRAPH_SUCCESS,
                OP_LOGE(context, "GetPlatformInfo error"),
                return ge::GRAPH_FAILED);

    // Initialised so the variable never holds an indeterminate value, even if
    // CheckDtype returns before assigning it.
    ge::DataType dtype = ge::DT_UNDEFINED;
    OP_CHECK_IF(CheckDtype(context, dtype) != ge::GRAPH_SUCCESS,
                OP_LOGE(context, "CheckDtype error"),
                return ge::GRAPH_FAILED);

    int64_t totalNum = 0;
    OP_CHECK_IF(GetTotalNum(context, totalNum) != ge::GRAPH_SUCCESS,
                OP_LOGE(context, "GetTotalNum error"),
                return ge::GRAPH_FAILED);

    size_t* currentWorkspace = context->GetWorkspaceSizes(1);
    OP_CHECK_NULL_WITH_CONTEXT(context, currentWorkspace);
    currentWorkspace[0] = WS_USER_SIZE + sysWorkspaceSize;

    NdtriTilingData* tiling = context->GetTilingData<NdtriTilingData>();
    OP_CHECK_NULL_WITH_CONTEXT(context, tiling);
    OP_CHECK_IF(memset_s(tiling, sizeof(NdtriTilingData), 0, sizeof(NdtriTilingData)) != EOK,
                OP_LOGE(context, "set tiling data error"),
                return ge::GRAPH_FAILED);
    tiling->totalNum = totalNum;

    int64_t usedCoreNum = 0;
    int64_t alignElem = 0;
    OP_CHECK_IF(DoTiling(context, dtype, totalNum, ubSize, coreNum,
                         tiling, usedCoreNum, alignElem) != ge::GRAPH_SUCCESS,
                OP_LOGE(context, "DoTiling error"),
                return ge::GRAPH_FAILED);

    // DoTiling yields 1 <= usedCoreNum <= coreNum, so the explicit narrowing is lossless.
    // The status is checked so a failed write of the block dimension fails the tiling.
    OP_CHECK_IF(context->SetBlockDim(static_cast<uint32_t>(usedCoreNum)) != ge::GRAPH_SUCCESS,
                OP_LOGE(context, "SetBlockDim error"),
                return ge::GRAPH_FAILED);

    DispatchTilingKey(context, dtype, totalNum, alignElem);
    return ge::GRAPH_SUCCESS;
}
/**
 * @brief Tiling-parse hook for Ndtri; performs no work and ignores its argument.
 *
 * @return Always ge::GRAPH_SUCCESS.
 */
static ge::graphStatus TilingParseForNdtri(gert::TilingParseContext* /* context */)
{
    return ge::GRAPH_SUCCESS;
}
// Empty compile-info type; used only as the template argument of TilingParse below.
struct NdtriCompileInfo {};
// Registers NdtriTilingFunc as the tiling function and TilingParseForNdtri as the
// tiling-parse function for the "Ndtri" operator.
IMPL_OP_OPTILING(Ndtri)
.Tiling(NdtriTilingFunc)
.TilingParse<NdtriCompileInfo>(TilingParseForNdtri);
}