* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* NOTE: Portions of this code were AI-generated and have been
* technically reviewed for functional accuracy and security
*/
* \file log_space_tiling.cpp
* \brief LogSpace Tiling 实现(arch35 / ascend950)
*
* 完整实现 6 个 TilingKey(fp32/fp16/bf16 × NORMAL/SINGLE)。
* Tiling 计算的 stepF/logBase/分核策略与输出 dtype 无关;输出 dtype 仅用于选择 TilingKey,TilingKey 通过 (D_T_Y, MODE) 二元组下发。
*/
#include <cmath>
#include <climits>
#include <limits>
#include "register/op_def_registry.h"
#include "op_common/log/log.h"
#include "op_common/op_host/util/math_util.h"
#include "op_common/op_host/util/platform_util.h"
#include "../op_kernel/log_space_tiling_data.h"
#include "../op_kernel/log_space_tiling_key.h"
namespace optiling {
using Ops::Base::CeilDiv;
using Ops::Base::CeilAlign;
using Ops::Base::FloorDiv;
using Ops::Base::FloorAlign;
// ---- Workspace ----
// Size written into the single workspace slot requested by this operator.
constexpr uint32_t WS_SYS_SIZE{0U};

// ---- Work split ----
// Divisor used to size the core count: cores = CeilDiv(steps, MIN_PER_CORE), capped by the AIV core count.
constexpr int64_t MIN_PER_CORE{64};
// Element count stored in tiling->ubChunk (a fixed value, not derived from the UB size).
constexpr int64_t UB_CHUNK_ELEMS{2048};

// ---- Attribute indices, in declaration order: start, end, steps, base ----
constexpr int ATTR_IDX_START{0};
constexpr int ATTR_IDX_END{1};
constexpr int ATTR_IDX_STEPS{2};
constexpr int ATTR_IDX_BASE{3};

// ---- Kernel mode, the second TilingKey dimension next to the output dtype ----
constexpr uint32_t MODE_NORMAL{0};  // selected when steps >= 2
constexpr uint32_t MODE_SINGLE{1};  // selected when steps is 0 or 1
/**
 * Queries the AIV core count and the UB capacity from the platform description.
 *
 * @param context  tiling context that carries the platform info.
 * @param ubSize   [out] value reported by GetCoreMemSize(CoreMemType::UB).
 * @param coreNum  [out] value reported by GetCoreNumAiv().
 * @return GRAPH_SUCCESS, or GRAPH_FAILED when the platform info is missing or either value is 0.
 */
static ge::graphStatus GetPlatformInfo(gert::TilingContext* context, uint64_t& ubSize, int64_t& coreNum)
{
    fe::PlatFormInfos* socInfo = context->GetPlatformInfo();
    OP_CHECK_NULL_WITH_CONTEXT(context, socInfo);
    auto platform = platform_ascendc::PlatformAscendC(socInfo);

    // AIV core count; a zero count would make any core split meaningless.
    coreNum = platform.GetCoreNumAiv();
    OP_CHECK_IF(coreNum == 0, OP_LOGE(context, "coreNum is 0"), return ge::GRAPH_FAILED);

    // Unified Buffer capacity of a single core.
    platform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
    OP_CHECK_IF(ubSize == 0, OP_LOGE(context, "ubSize is 0"), return ge::GRAPH_FAILED);

    return ge::GRAPH_SUCCESS;
}
/**
 * Requests exactly one workspace slot and sets its size to WS_SYS_SIZE.
 *
 * @param context  tiling context whose workspace table is filled.
 * @return GRAPH_SUCCESS, or GRAPH_FAILED if the workspace table cannot be obtained.
 */
static ge::graphStatus GetWorkspaceSize(gert::TilingContext* context)
{
    constexpr size_t kWorkspaceSlots = 1;
    size_t* wsSizes = context->GetWorkspaceSizes(kWorkspaceSlots);
    OP_CHECK_NULL_WITH_CONTEXT(context, wsSizes);
    *wsSizes = WS_SYS_SIZE;
    return ge::GRAPH_SUCCESS;
}
static ge::graphStatus LogSpaceTilingFunc(gert::TilingContext* context)
{
uint64_t ubSize = 0;
int64_t coreNum = 0;
OP_CHECK_IF(GetPlatformInfo(context, ubSize, coreNum) != ge::GRAPH_SUCCESS,
OP_LOGE(context, "GetPlatformInfo error"), return ge::GRAPH_FAILED);
auto attrs = context->GetAttrs();
OP_CHECK_NULL_WITH_CONTEXT(context, attrs);
const float* startPtr = attrs->GetAttrPointer<float>(ATTR_IDX_START);
const float* endPtr = attrs->GetAttrPointer<float>(ATTR_IDX_END);
const int64_t* stepsPtr = attrs->GetAttrPointer<int64_t>(ATTR_IDX_STEPS);
const float* basePtr = attrs->GetAttrPointer<float>(ATTR_IDX_BASE);
OP_CHECK_NULL_WITH_CONTEXT(context, startPtr);
OP_CHECK_NULL_WITH_CONTEXT(context, endPtr);
OP_CHECK_NULL_WITH_CONTEXT(context, stepsPtr);
OP_CHECK_NULL_WITH_CONTEXT(context, basePtr);
const float startF = *startPtr;
const float endF = *endPtr;
const int64_t steps = *stepsPtr;
const float baseF = *basePtr;
OP_CHECK_IF(steps < 0, OP_LOGE(context, "steps must be >= 0, got %ld", steps),
return ge::GRAPH_FAILED);
OP_CHECK_IF(steps > static_cast<int64_t>(std::numeric_limits<uint32_t>::max()),
OP_LOGE(context, "steps must be <= UINT32_MAX (%u), got %ld",
std::numeric_limits<uint32_t>::max(), steps),
return ge::GRAPH_FAILED);
OP_CHECK_IF(baseF <= 0.0f, OP_LOGE(context, "base must be > 0, got %f", baseF),
return ge::GRAPH_FAILED);
auto outDesc = context->GetOutputDesc(0);
OP_CHECK_NULL_WITH_CONTEXT(context, outDesc);
ge::DataType dtype = outDesc->GetDataType();
OP_CHECK_IF(dtype != ge::DT_FLOAT && dtype != ge::DT_FLOAT16 && dtype != ge::DT_BF16,
OP_LOGE(context, "unsupported dtype %d", static_cast<int>(dtype)),
return ge::GRAPH_FAILED);
OP_CHECK_IF(GetWorkspaceSize(context) != ge::GRAPH_SUCCESS,
OP_LOGE(context, "GetWorkspaceSize error"), return ge::GRAPH_FAILED);
LogSpaceTilingData* tiling = context->GetTilingData<LogSpaceTilingData>();
OP_CHECK_NULL_WITH_CONTEXT(context, tiling);
OP_CHECK_IF(memset_s(tiling, sizeof(LogSpaceTilingData), 0, sizeof(LogSpaceTilingData)) != EOK,
OP_LOGE(context, "memset tiling data failed"), return ge::GRAPH_FAILED);
tiling->totalLen = static_cast<uint64_t>(steps);
tiling->startF = startF;
tiling->logBase = std::log(baseF);
tiling->ubChunk = static_cast<uint32_t>(UB_CHUNK_ELEMS);
uint32_t mode = MODE_NORMAL;
int64_t usedCoreNum = 1;
if (steps <= 0) {
tiling->coreNum = 1;
tiling->tileLen = 0;
tiling->tailTileLen = 0;
tiling->tailCoreIdx = 0;
tiling->stepF = 0.0f;
mode = MODE_SINGLE;
} else if (steps == 1) {
tiling->coreNum = 1;
tiling->tileLen = 1;
tiling->tailTileLen = 1;
tiling->tailCoreIdx = 0;
tiling->stepF = 0.0f;
mode = MODE_SINGLE;
usedCoreNum = 1;
} else {
int64_t maxCores = CeilDiv(steps, MIN_PER_CORE);
int64_t useCores = (maxCores < coreNum) ? maxCores : coreNum;
if (useCores < 1) useCores = 1;
int64_t tileLen = steps / useCores;
int64_t tailLen = steps - tileLen * (useCores - 1);
tiling->coreNum = static_cast<uint32_t>(useCores);
tiling->tileLen = static_cast<uint32_t>(tileLen);
tiling->tailTileLen = static_cast<uint32_t>(tailLen);
tiling->tailCoreIdx = static_cast<uint32_t>(useCores - 1);
tiling->stepF = (endF - startF) / static_cast<float>(steps - 1);
mode = MODE_NORMAL;
usedCoreNum = useCores;
}
context->SetBlockDim(static_cast<uint32_t>(usedCoreNum));
uint32_t dTypeY = static_cast<uint32_t>(dtype);
ASCENDC_TPL_SEL_PARAM(context, dTypeY, mode);
return ge::GRAPH_SUCCESS;
}
// Compile-info parse hook. LogSpace keeps no compile-time info (its compile-info struct is
// empty), so parsing has nothing to extract and always reports success.
static ge::graphStatus TilingParseForLogSpace([[maybe_unused]] gert::TilingParseContext* context)
{
    const ge::graphStatus parseStatus = ge::GRAPH_SUCCESS;
    return parseStatus;
}
// Empty compile-info type required by the TilingParse registration; LogSpace stores nothing in it.
struct LogSpaceCompileInfo {};

// Register the host tiling and compile-info parse callbacks for the LogSpace operator.
IMPL_OP_OPTILING(LogSpace)
    .Tiling(LogSpaceTilingFunc)
    .TilingParse<LogSpaceCompileInfo>(TilingParseForLogSpace);
}