/**
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*
* NOTE: Portions of this code were AI-generated and have been
* technically reviewed for functional accuracy and security
*/
/**
* \file eltwise_tiling_arch35.cpp
* \brief Eltwise tiling implementation (arch35)
*
* Tiling strategy:
* 1. Multi-core: divide total elements evenly across AI Cores
* 2. UB: divide per-core elements into UB-sized chunks
* 3. Buffer layout depends on dtype:
* - FP32: inputBuf(1) + accBuf(1) + outputBuf(1) = 3 buffers
* - FP16/BF16: inputBuf(1) + castBuf(1) + accBuf(1) + outputBuf(1) = 4 buffers
*
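* UB factor formula (availUb = ubSize - 2048 reserved bytes; the whole ubSize is used if it is below 2048):
* FP32: ubFactor = FloorAlign(availUb / (3 * 4), ubBlockSize)
* FP16/BF16: ubFactor = FloorAlign(availUb / (2 + 4 + 4 + 2), ubBlockSize)
* = FloorAlign(availUb / 12, ubBlockSize)
* In both cases ubFactor is then capped at FloorAlign(65535 / typeSize, ubBlockSize),
* so one chunk of the input dtype never exceeds 65535 bytes.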
*/
#include "register/op_def_registry.h"
#include "op_common/log/log.h"
#include "op_common/op_host/util/math_util.h"
#include "op_common/op_host/util/platform_util.h"
#include "../../op_kernel/arch35/eltwise_tiling_data.h"
#include "../../op_kernel/arch35/eltwise_tiling_key.h"
namespace optiling {
// Integer helpers from the shared op-host math utilities.
using Ops::Base::CeilDiv;
using Ops::Base::FloorAlign;
using Ops::Base::FloorDiv;

// Workspace size (in bytes) reported for workspace slot 0; no extra workspace is requested.
constexpr uint32_t WS_SYS_SIZE{0U};
// Largest accepted input count; also the bound used when writing the coefficient table.
constexpr uint32_t MAX_INPUT_NUM{32U};
/**
 * Values gathered from the platform, the inputs and the attributes that drive the tiling.
 */
struct TilingParams {
    uint64_t ubSize{0};                // UB capacity in bytes (platform value or fallback)
    int64_t coreNum{0};                // cores available for the launch
    uint32_t inputNum{0};              // number of inputs on the compute node
    ge::DataType dtype{ge::DT_FLOAT};  // data type of input 0
    int64_t totalNum{0};               // element count of input 0
    int64_t mode{1};                   // value of the mode attribute (validated to 0..2)
    int64_t typeSize{4};               // bytes per element of dtype
};
// Rank-1 shape {1}, substituted for rank-0 (scalar) shapes.
static const gert::Shape g_vec_1_shape{1};

/**
 * Returns a copy of in_shape, or the rank-1 shape {1} when in_shape has no dimensions.
 */
static inline const gert::Shape EnsureNotScalar(const gert::Shape& in_shape)
{
    const bool isScalar = (in_shape.GetDimNum() == 0);
    return isScalar ? g_vec_1_shape : in_shape;
}
/**
 * Queries the vector-core count and the UB capacity from the platform info.
 *
 * GetCoreNumAiv() is used first; if it reports 0, GetCoreNum() is used instead. If the
 * platform info is unavailable, the incoming values are kept. Any value that is still 0
 * at the end is replaced by a fixed fallback (40 cores, 253952 bytes), with a warning.
 *
 * @param context  tiling context providing the platform info
 * @param ubSize   [in/out] UB size in bytes
 * @param coreNum  [in/out] number of cores
 * @return always ge::GRAPH_SUCCESS
 */
static ge::graphStatus GetPlatformInfo(gert::TilingContext* context, uint64_t& ubSize, int64_t& coreNum)
{
    constexpr int64_t FALLBACK_CORE_NUM = 40;
    constexpr uint64_t FALLBACK_UB_SIZE = 253952;

    fe::PlatFormInfos* platformInfo = context->GetPlatformInfo();
    if (platformInfo != nullptr) {
        auto platform = platform_ascendc::PlatformAscendC(platformInfo);
        coreNum = platform.GetCoreNumAiv();
        if (coreNum == 0) {
            // No AIV count reported; fall back to the generic core count.
            coreNum = platform.GetCoreNum();
        }
        platform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
    }

    if (coreNum == 0) {
        OP_LOGW(context, "Eltwise: failed to get core num, using fallback %ld", FALLBACK_CORE_NUM);
        coreNum = FALLBACK_CORE_NUM;
    }
    if (ubSize == 0) {
        OP_LOGW(context, "Eltwise: failed to get ub size, using fallback %lu", FALLBACK_UB_SIZE);
        ubSize = FALLBACK_UB_SIZE;
    }
    return ge::GRAPH_SUCCESS;
}
/**
 * Requests one workspace slot and sets its size to WS_SYS_SIZE.
 *
 * @return ge::GRAPH_SUCCESS, or the failure produced by OP_CHECK_NULL_WITH_CONTEXT when the
 *         workspace-size array is unavailable.
 */
static ge::graphStatus GetWorkspaceSize(gert::TilingContext* context)
{
    constexpr size_t WORKSPACE_SLOT_NUM = 1;
    // Name kept as-is: the null-check macro reports the pointer's name in its log.
    size_t* currentWorkspace = context->GetWorkspaceSizes(WORKSPACE_SLOT_NUM);
    OP_CHECK_NULL_WITH_CONTEXT(context, currentWorkspace);
    currentWorkspace[0] = WS_SYS_SIZE;
    return ge::GRAPH_SUCCESS;
}
/**
 * Reads the input count, the data type and the element count used by the tiling.
 *
 * dtype and totalNum come from input 0. Every other input must have the same data type and
 * element count. A single totalNum / blockFactor / ubFactor and a single dtype template key
 * are produced for all inputs, so a mismatching input cannot be described by this tiling
 * and is rejected here.
 *
 * @param context   tiling context
 * @param inputNum  [out] number of inputs on the compute node, in [1, MAX_INPUT_NUM]
 * @param dtype     [out] data type of input 0
 * @param totalNum  [out] element count of input 0 (a rank-0 shape counts as one element)
 * @return ge::GRAPH_FAILED on a missing node info, descriptor or shape, an invalid input count,
 *         a negative element count, or a dtype / element-count mismatch between inputs.
 */
static ge::graphStatus GetInputInfo(gert::TilingContext* context, uint32_t& inputNum,
    ge::DataType& dtype, int64_t& totalNum)
{
    auto computeNodeInfo = context->GetComputeNodeInfo();
    OP_CHECK_NULL_WITH_CONTEXT(context, computeNodeInfo);
    inputNum = static_cast<uint32_t>(computeNodeInfo->GetInputsNum());
    OP_CHECK_IF(inputNum == 0 || inputNum > MAX_INPUT_NUM,
        OP_LOGE(context, "Eltwise: invalid inputNum %u", inputNum),
        return ge::GRAPH_FAILED);

    auto inputDesc = context->GetInputDesc(0);
    OP_CHECK_NULL_WITH_CONTEXT(context, inputDesc);
    dtype = inputDesc->GetDataType();
    auto inputShape0 = context->GetInputShape(0);
    OP_CHECK_NULL_WITH_CONTEXT(context, inputShape0);
    auto shape = EnsureNotScalar(inputShape0->GetStorageShape());
    totalNum = shape.GetShapeSize();
    OP_CHECK_IF(totalNum < 0,
        OP_LOGE(context, "Eltwise: invalid totalNum %ld (negative shape)", totalNum),
        return ge::GRAPH_FAILED);

    // The remaining inputs share input 0's tiling, so they must agree with it. Otherwise a
    // consumer reading totalNum elements of dtype from each input would presumably read past
    // a smaller input or misinterpret its bytes.
    for (uint32_t i = 1; i < inputNum; i++) {
        auto otherDesc = context->GetInputDesc(i);
        OP_CHECK_NULL_WITH_CONTEXT(context, otherDesc);
        OP_CHECK_IF(otherDesc->GetDataType() != dtype,
            OP_LOGE(context, "Eltwise: input %u dtype %d differs from input 0 dtype %d", i,
                static_cast<int>(otherDesc->GetDataType()), static_cast<int>(dtype)),
            return ge::GRAPH_FAILED);

        auto otherShape = context->GetInputShape(i);
        OP_CHECK_NULL_WITH_CONTEXT(context, otherShape);
        const int64_t otherNum = EnsureNotScalar(otherShape->GetStorageShape()).GetShapeSize();
        OP_CHECK_IF(otherNum != totalNum,
            OP_LOGE(context, "Eltwise: input %u element count %ld differs from input 0 element count %ld",
                i, otherNum, totalNum),
            return ge::GRAPH_FAILED);
    }
    return ge::GRAPH_SUCCESS;
}
/**
 * Reads the integer attribute at index 0 as the Eltwise mode. The mode is 1 when that
 * attribute is absent; values outside [0, 2] fail the tiling.
 * NOTE(review): assumes attribute 0 is `mode` in the op definition — confirm against op_def.
 *
 * @param context  tiling context
 * @param mode     [out] validated mode; left untouched if the attribute set is missing
 */
static ge::graphStatus GetModeAttr(gert::TilingContext* context, int64_t& mode)
{
    const auto* attrs = context->GetAttrs();
    OP_CHECK_NULL_WITH_CONTEXT(context, attrs);

    const int64_t* modeAttr = attrs->GetAttrPointer<int64_t>(0);
    mode = (modeAttr == nullptr) ? 1 : *modeAttr;

    OP_CHECK_IF(mode < 0 || mode > 2,
        OP_LOGE(context, "Eltwise: invalid mode %ld", mode),
        return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}
/**
 * Maps a supported data type to its byte width: 4 for DT_FLOAT, 2 for DT_FLOAT16 and DT_BF16.
 * Any other data type is logged as unsupported and fails the tiling.
 *
 * @param typeSize  [out] byte width; only written on success
 */
static ge::graphStatus GetTypeSize(gert::TilingContext* context, ge::DataType dtype, int64_t& typeSize)
{
    const bool isFourByte = (dtype == ge::DT_FLOAT);
    const bool isTwoByte = (dtype == ge::DT_FLOAT16) || (dtype == ge::DT_BF16);

    if (isFourByte) {
        typeSize = 4;
        return ge::GRAPH_SUCCESS;
    }
    if (isTwoByte) {
        typeSize = 2;
        return ge::GRAPH_SUCCESS;
    }
    OP_LOGE(context, "Eltwise: unsupported dtype %d", static_cast<int>(dtype));
    return ge::GRAPH_FAILED;
}
/**
 * Writes the tiling for a zero-element input. totalNum, blockFactor and ubFactor are all 0.
 * The launch uses one block. The template key is selected from (dtype, mode), as on the
 * regular path.
 */
static void SetEmptyTilingAndKey(gert::TilingContext* context, EltwiseTilingData* tiling,
    uint32_t inputNum, ge::DataType dtype, int64_t mode)
{
    tiling->inputNum = inputNum;
    tiling->totalNum = 0;
    tiling->blockFactor = 0;
    tiling->ubFactor = 0;

    context->SetBlockDim(1);

    uint32_t dType = static_cast<uint32_t>(dtype);
    uint32_t modeVal = static_cast<uint32_t>(mode);
    ASCENDC_TPL_SEL_PARAM(context, dType, modeVal);
}
/**
 * Computes ubFactor, the number of elements processed per UB pass.
 *
 * UB bytes consumed per element:
 *   - DT_FLOAT:  3 * sizeof(float)
 *   - otherwise: 2 * typeSize + 2 * sizeof(float)
 * Usable UB is ubSize minus a 2048-byte reserve. If ubSize is below the reserve, the whole
 * ubSize is used. The quotient is floor-aligned to ubBlockSize and then capped at
 * FloorAlign(65535 / typeSize, ubBlockSize). That cap keeps one chunk of the input dtype
 * within 65535 bytes (presumably a data-copy length limit — confirm against the kernel).
 *
 * @param ubFactor  [out] elements per UB pass
 * @return ge::GRAPH_FAILED if the result is not positive
 */
static ge::graphStatus ComputeUbFactor(gert::TilingContext* context, uint64_t ubSize,
    ge::DataType dtype, int64_t typeSize,
    int64_t ubBlockSize, int64_t& ubFactor)
{
    constexpr int64_t UB_SYS_OVERHEAD = 2048;
    constexpr int64_t MAX_BLOCK_LEN = 65535;
    constexpr int64_t FLOAT_BYTES = static_cast<int64_t>(sizeof(float));

    const int64_t totalUb = static_cast<int64_t>(ubSize);
    const int64_t availUbSize = (totalUb >= UB_SYS_OVERHEAD) ? (totalUb - UB_SYS_OVERHEAD) : totalUb;

    const int64_t bytesPerElem = (dtype == ge::DT_FLOAT)
        ? 3 * FLOAT_BYTES
        : 2 * typeSize + 2 * FLOAT_BYTES;

    const int64_t fitUbFactor = FloorAlign(availUbSize / bytesPerElem, ubBlockSize);
    const int64_t maxUbFactor = FloorAlign(MAX_BLOCK_LEN / typeSize, ubBlockSize);
    ubFactor = (fitUbFactor > maxUbFactor) ? maxUbFactor : fitUbFactor;

    OP_CHECK_IF(ubFactor <= 0,
        OP_LOGE(context, "Eltwise: ubFactor=%ld, UB too small", ubFactor),
        return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}
/**
 * Fills tiling->coeff with the per-input coefficients. In this file it is called only for mode 1.
 *
 * The first inputNum slots (at most MAX_INPUT_NUM) are first set to 1.0f. Values from the
 * float-list attribute at index 1 then overwrite slots [0, min(list size, MAX_INPUT_NUM)).
 *
 * The 1.0f pre-fill is needed because the tiling data is zero-filled beforehand
 * (InitTilingData). Without it, a list shorter than inputNum, or a non-empty list with no
 * data pointer, would leave the uncovered inputs weighted by 0.
 *
 * A length mismatch with inputNum, a list without data, or a missing attribute set is logged
 * as a warning. Uncovered inputs keep the 1.0f default.
 */
static void FillCoeff(gert::TilingContext* context, EltwiseTilingData* tiling, uint32_t inputNum)
{
    // Bound the default fill by the table size used below, even if inputNum was not validated.
    const uint32_t defaultNum = (inputNum < MAX_INPUT_NUM) ? inputNum : MAX_INPUT_NUM;
    for (uint32_t i = 0; i < defaultNum; i++) {
        tiling->coeff[i] = 1.0f;
    }

    const auto* attrs = context->GetAttrs();
    if (attrs == nullptr) {
        OP_LOGW(context, "Eltwise: attrs unavailable, using coeff 1.0 for all inputs");
        return;
    }
    const auto* coeffList = attrs->GetListFloat(1);
    if (coeffList == nullptr || coeffList->GetSize() == 0) {
        return;
    }

    const size_t coeffSize = coeffList->GetSize();
    const float* coeffData = coeffList->GetData();
    if (coeffData == nullptr) {
        OP_LOGW(context, "Eltwise: coeff list of size %zu has no data, using coeff 1.0 for all inputs",
            coeffSize);
        return;
    }
    if (coeffSize != inputNum) {
        OP_LOGW(context, "Eltwise: coeff size %zu differs from inputNum %u, missing coefficients default to 1.0",
            coeffSize, inputNum);
    }
    for (size_t i = 0; i < coeffSize && i < MAX_INPUT_NUM; i++) {
        tiling->coeff[i] = coeffData[i];
    }
}
/**
 * Obtains the EltwiseTilingData buffer from the context and zero-fills it.
 *
 * @param tiling  [out] pointer returned by context->GetTilingData<EltwiseTilingData>()
 * @return ge::GRAPH_FAILED if the buffer is missing or memset_s does not return EOK
 */
static ge::graphStatus InitTilingData(gert::TilingContext* context, EltwiseTilingData*& tiling)
{
    tiling = context->GetTilingData<EltwiseTilingData>();
    OP_CHECK_NULL_WITH_CONTEXT(context, tiling);

    constexpr size_t tilingBytes = sizeof(EltwiseTilingData);
    const bool cleared = (memset_s(tiling, tilingBytes, 0, tilingBytes) == EOK);
    OP_CHECK_IF(!cleared,
        OP_LOGE(context, "set tiling data error"), return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}
/**
 * Builds the tiling for a non-empty input and selects the launch configuration.
 *
 * - blockFactor: elements per core, i.e. CeilDiv(totalNum, coreNum) rounded up to a whole
 *   32-byte block of the input dtype.
 * - usedCoreNum: cores needed at that blockFactor. It never exceeds coreNum, because
 *   blockFactor >= CeilDiv(totalNum, coreNum).
 * - ubFactor: elements per UB pass, from ComputeUbFactor.
 * Coefficients are filled only when mode == 1. The template key is (dtype, mode).
 *
 * @return ge::GRAPH_FAILED if ComputeUbFactor fails; nothing is written to tiling in that case
 */
static ge::graphStatus FillTilingAndKey(gert::TilingContext* context, EltwiseTilingData* tiling,
    const TilingParams& params)
{
    // Elements of the input dtype that make up one 32-byte block.
    const int64_t ubBlockSize = 32 / params.typeSize;

    int64_t ubFactor = 0;
    OP_CHECK_IF(
        ComputeUbFactor(context, params.ubSize, params.dtype, params.typeSize,
            ubBlockSize, ubFactor) != ge::GRAPH_SUCCESS,
        OP_LOGE(context, "ComputeUbFactor error"), return ge::GRAPH_FAILED);

    const int64_t evenShare = CeilDiv(params.totalNum, params.coreNum);
    const int64_t blockFactor = CeilDiv(evenShare, ubBlockSize) * ubBlockSize;
    const int64_t usedCoreNum = CeilDiv(params.totalNum, blockFactor);

    tiling->totalNum = params.totalNum;
    tiling->blockFactor = blockFactor;
    tiling->ubFactor = ubFactor;
    tiling->inputNum = params.inputNum;
    if (params.mode == 1) {
        FillCoeff(context, tiling, params.inputNum);
    }

    context->SetBlockDim(usedCoreNum);
    uint32_t dType = static_cast<uint32_t>(params.dtype);
    uint32_t modeVal = static_cast<uint32_t>(params.mode);
    ASCENDC_TPL_SEL_PARAM(context, dType, modeVal);
    return ge::GRAPH_SUCCESS;
}
/**
 * Tiling entry point registered for Eltwise.
 *
 * Steps, in this order:
 *   1. query the platform;
 *   2. inspect the inputs;
 *   3. read the mode;
 *   4. request the workspace;
 *   5. zero the tiling data;
 *   6. resolve the dtype width.
 * A zero-element input then gets the empty tiling; anything else gets the multi-core / UB tiling.
 * Any failing step logs an error and fails the whole tiling.
 */
static ge::graphStatus EltwiseTilingFunc(gert::TilingContext* context)
{
    TilingParams tilingParams;

    // Collect platform, input and attribute information.
    OP_CHECK_IF(
        GetPlatformInfo(context, tilingParams.ubSize, tilingParams.coreNum) != ge::GRAPH_SUCCESS,
        OP_LOGE(context, "GetPlatformInfo error"), return ge::GRAPH_FAILED);
    OP_CHECK_IF(
        GetInputInfo(context, tilingParams.inputNum, tilingParams.dtype, tilingParams.totalNum) !=
            ge::GRAPH_SUCCESS,
        OP_LOGE(context, "GetInputInfo error"), return ge::GRAPH_FAILED);
    OP_CHECK_IF(
        GetModeAttr(context, tilingParams.mode) != ge::GRAPH_SUCCESS,
        OP_LOGE(context, "GetModeAttr error"), return ge::GRAPH_FAILED);

    // Prepare the framework-side outputs.
    OP_CHECK_IF(
        GetWorkspaceSize(context) != ge::GRAPH_SUCCESS,
        OP_LOGE(context, "GetWorkspaceSize error"), return ge::GRAPH_FAILED);
    EltwiseTilingData* tilingData = nullptr;
    OP_CHECK_IF(
        InitTilingData(context, tilingData) != ge::GRAPH_SUCCESS,
        OP_LOGE(context, "InitTilingData error"), return ge::GRAPH_FAILED);

    OP_CHECK_IF(
        GetTypeSize(context, tilingParams.dtype, tilingParams.typeSize) != ge::GRAPH_SUCCESS,
        OP_LOGE(context, "GetTypeSize error"), return ge::GRAPH_FAILED);

    if (tilingParams.totalNum != 0) {
        return FillTilingAndKey(context, tilingData, tilingParams);
    }
    SetEmptyTilingAndKey(context, tilingData, tilingParams.inputNum, tilingParams.dtype, tilingParams.mode);
    return ge::GRAPH_SUCCESS;
}
/**
 * Tiling-parse hook. The registered compile-info type is empty, so there is nothing to
 * parse and the hook always reports success.
 */
static ge::graphStatus TilingParseForEltwise([[maybe_unused]] gert::TilingParseContext* context)
{
    const ge::graphStatus parseResult = ge::GRAPH_SUCCESS;
    return parseResult;
}
// Empty compile-info type; Eltwise carries no compile-time tiling information.
struct EltwiseCompileInfo {
};

// Register the runtime tiling function and the tiling-parse hook for the Eltwise op.
IMPL_OP_OPTILING(Eltwise)
    .Tiling(EltwiseTilingFunc)
    .TilingParse<EltwiseCompileInfo>(TilingParseForEltwise);
}