* Copyright (c) 2025-2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
 * \file abs_tiling_arch35.cpp
 * \brief Tiling implementation of the Abs operator (arch35).
 */
#include <iostream>
#include "abs_tiling_arch35.h"
#include <graph/utils/type_utils.h>
#include "tiling/tiling_api.h"
#include "tiling/platform/platform_ascendc.h"
#include "register/op_def_registry.h"
#include "register/tilingdata_base.h"
#include "../../op_kernel/arch35/abs_dag.h"
#include "../../op_kernel/arch35/abs_complex_dag.h"
using namespace ge;
using namespace AbsNs;
namespace optiling {
// Tiling keys written by AbsTiling::SetTilingData to select the kernel variant.
constexpr uint64_t ABS_TILING_KEY_ELEMENTWISE_BF16 = 101;     // output dtype is DT_BF16
constexpr uint64_t ABS_TILING_KEY_ELEMENTWISE_OTHER = 102;    // all remaining real dtypes
constexpr uint64_t ABS_TILING_KEY_ELEMENTWISE_COMPLEX = 103;  // DT_COMPLEX64 / DT_COMPLEX32 input
// Size in bytes reserved in workspace slot 0: 16 MiB.
constexpr uint64_t ABS_WORKSPACE_RESERVE_BYTE = 16ULL * 1024ULL * 1024ULL;
/**
 * Publish the tiling results to the framework.
 * Reserves the fixed workspace, selects the kernel variant via the tiling key,
 * and sets the block dimension computed by the elementwise tiling.
 * Expects CalcOutputDtype() and the DoTiling step of RunTiling() to have run,
 * so that the dtypes and `tiling` are populated.
 * @return GRAPH_SUCCESS, or GRAPH_FAILED if the workspace size array is unavailable.
 */
ge::graphStatus AbsTiling::SetTilingData()
{
    size_t* currentWorkspace = tilingContext->GetWorkspaceSizes(1);
    OP_CHECK_NULL_WITH_CONTEXT(tilingContext, currentWorkspace);
    currentWorkspace[0] = ABS_WORKSPACE_RESERVE_BYTE;

    // Priority: a BF16 output wins, then complex inputs, otherwise the generic kernel.
    const bool isComplexInput = (this->inputDtype == ge::DT_COMPLEX64) || (this->inputDtype == ge::DT_COMPLEX32);
    uint64_t tilingKey = ABS_TILING_KEY_ELEMENTWISE_OTHER;
    if (this->outputDtype == ge::DT_BF16) {
        tilingKey = ABS_TILING_KEY_ELEMENTWISE_BF16;
    } else if (isComplexInput) {
        tilingKey = ABS_TILING_KEY_ELEMENTWISE_COMPLEX;
    }
    tilingContext->SetTilingKey(tilingKey);

    tilingContext->SetBlockDim(tiling->baseTiling.blockNum);
    return ge::GRAPH_SUCCESS;
}
/**
 * Read and validate the dtypes of input 0 and output 0.
 * The results are cached in `inputDtype` / `outputDtype`.
 * Accepted pairings:
 *   - DT_COMPLEX64 -> DT_FLOAT
 *   - DT_COMPLEX32 -> DT_FLOAT16
 *   - any other input -> the same dtype as the input
 * @return GRAPH_SUCCESS when the pairing is valid; GRAPH_FAILED when a
 *         descriptor is missing or the pairing is rejected.
 */
ge::graphStatus AbsTiling::CalcOutputDtype()
{
    auto inputDesc = tilingContext->GetInputDesc(0);
    OP_CHECK_NULL_WITH_CONTEXT(tilingContext, inputDesc);
    this->inputDtype = inputDesc->GetDataType();

    auto outputDesc = tilingContext->GetOutputDesc(0);
    OP_CHECK_NULL_WITH_CONTEXT(tilingContext, outputDesc);
    this->outputDtype = outputDesc->GetDataType();

    switch (this->inputDtype) {
        case ge::DT_COMPLEX64: {
            // |a + bi| of a complex64 value is a 32-bit float.
            OP_CHECK_IF(this->outputDtype != ge::DT_FLOAT,
                OP_LOGE_FOR_INVALID_DTYPE_WITH_REASON(tilingContext->GetNodeName(), "outputDtype",
                    ge::TypeUtils::DataTypeToSerialString(this->outputDtype),
                    "when input is complex64, output dtype must be DT_FLOAT"),
                return ge::GRAPH_FAILED);
            break;
        }
        case ge::DT_COMPLEX32: {
            // |a + bi| of a complex32 value is a 16-bit float.
            OP_CHECK_IF(this->outputDtype != ge::DT_FLOAT16,
                OP_LOGE_FOR_INVALID_DTYPE_WITH_REASON(tilingContext->GetNodeName(), "outputDtype",
                    ge::TypeUtils::DataTypeToSerialString(this->outputDtype),
                    "when input is complex32, output dtype must be DT_FLOAT16"),
                return ge::GRAPH_FAILED);
            break;
        }
        default: {
            // Real dtypes: abs keeps the element type unchanged.
            OP_CHECK_IF(this->inputDtype != this->outputDtype,
                OP_LOGE_FOR_INVALID_DTYPES_WITH_REASON(tilingContext->GetNodeName(), "inputDtype, outputDtype",
                    ge::TypeUtils::DataTypeToSerialString(this->inputDtype) + ", " +
                        ge::TypeUtils::DataTypeToSerialString(this->outputDtype),
                    "input and output dtypes must match for non-complex inputs"),
                return ge::GRAPH_FAILED);
            break;
        }
    }
    return ge::GRAPH_SUCCESS;
}
/**
 * Entry point of the Abs tiling computation.
 * 1. Validates and caches the input/output dtypes (CalcOutputDtype).
 * 2. Runs the elementwise base tiling with the DAG matching the input dtype,
 *    writing into the context's AbsTilingData.
 * 3. Publishes workspace, tiling key and block dim (SetTilingData).
 * @return GRAPH_SUCCESS on success, GRAPH_FAILED otherwise.
 */
ge::graphStatus AbsTiling::RunTiling()
{
    ElewiseBaseTiling elewiseBaseTiling(tilingContext);
    // Compare against SUCCESS rather than FAILED so that any other error status is not treated as success.
    OP_CHECK_IF(CalcOutputDtype() != ge::GRAPH_SUCCESS,
        OP_LOGE(tilingContext, "get output dtype failed"),
        return ge::GRAPH_FAILED);

    tiling = tilingContext->GetTilingData<AbsTilingData>();
    // GetTilingData may return nullptr (e.g. missing or undersized tiling buffer).
    // Both the DoTiling calls below and SetTilingData dereference `tiling`.
    OP_CHECK_NULL_WITH_CONTEXT(tilingContext, tiling);

    ge::graphStatus res = ge::GRAPH_FAILED;
    switch (this->inputDtype) {
        case ge::DT_FLOAT16:
            res = elewiseBaseTiling.DoTiling<AbsOp::AbsDag<half, half>::OpDag>(tiling->baseTiling);
            break;
        case ge::DT_COMPLEX64:
            res = elewiseBaseTiling.DoTiling<AbsOp::AbscomplexDag<int64_t, float>::OpDag>(tiling->baseTiling);
            break;
        case ge::DT_COMPLEX32:
            res = elewiseBaseTiling.DoTiling<AbsOp::AbscomplexDag<int32_t, half>::OpDag>(tiling->baseTiling);
            break;
        case ge::DT_FLOAT:
            res = elewiseBaseTiling.DoTiling<AbsOp::AbsDag<float, float>::OpDag>(tiling->baseTiling);
            break;
        case ge::DT_BF16:
            // BF16 input is tiled with a float compute type in its DAG.
            res = elewiseBaseTiling.DoTiling<AbsOp::AbsDag<bfloat16_t, float>::OpDag>(tiling->baseTiling);
            break;
        case ge::DT_INT8:
            res = elewiseBaseTiling.DoTiling<AbsOp::AbsDag<int8_t, int8_t>::OpDag>(tiling->baseTiling);
            break;
        case ge::DT_INT16:
            res = elewiseBaseTiling.DoTiling<AbsOp::AbsDag<int16_t, int16_t>::OpDag>(tiling->baseTiling);
            break;
        case ge::DT_INT32:
            res = elewiseBaseTiling.DoTiling<AbsOp::AbsDag<int32_t, int32_t>::OpDag>(tiling->baseTiling);
            break;
        case ge::DT_INT64:
            res = elewiseBaseTiling.DoTiling<AbsOp::AbsDag<int64_t, int64_t>::OpDag>(tiling->baseTiling);
            break;
        default:
            OP_LOGE_FOR_INVALID_DTYPE_WITH_REASON(tilingContext->GetNodeName(), "inputDtype",
                ge::TypeUtils::DataTypeToSerialString(this->inputDtype),
                "input dtype must be in [DT_FLOAT16, DT_FLOAT, DT_BF16, DT_INT8, DT_INT16, "
                "DT_INT32, DT_INT64, DT_COMPLEX64, DT_COMPLEX32]");
            return ge::GRAPH_FAILED;
    }

    OP_CHECK_IF(res != ge::GRAPH_SUCCESS,
        OP_LOGE(tilingContext, "DoTiling failed"),
        return ge::GRAPH_FAILED);
    return SetTilingData();
}
/**
 * Tiling callback registered for the Abs op.
 * @param context Tiling context supplied by the framework; may be null, which is rejected.
 * @return Result of AbsTiling::RunTiling(), or GRAPH_FAILED for a null context.
 */
static ge::graphStatus TilingForAbs(gert::TilingContext *context)
{
    OP_LOGD("AbsTiling", "Enter TilingForAbs");
    // `context` is null on this path, so it must not be handed to the logger as the op
    // identifier (that risks dereferencing it). Use the same fixed tag as the debug logs.
    OP_CHECK_IF(context == nullptr,
        OP_LOGE("AbsTiling", "Tiling context is null"),
        return ge::GRAPH_FAILED);
    OP_LOGD("AbsTiling", "Enter new AbsTiling");
    AbsTiling absTiling(context);
    return absTiling.RunTiling();
}
/**
 * TilingParse callback for the Abs op.
 * Fills AbsCompileInfo with the platform's AIV core count and UB size.
 * @param context Tiling-parse context supplied by the framework.
 * @return GRAPH_SUCCESS, or GRAPH_FAILED when the context, compile info or platform
 *         info is missing, or when the platform reports no cores / no UB capacity.
 */
ge::graphStatus TilingPrepareForAbs(gert::TilingParseContext* context)
{
    // Guard before any use; the null checks below presumably log through the context itself.
    OP_CHECK_IF(context == nullptr,
        OP_LOGE("AbsTiling", "TilingParse context is null"),
        return ge::GRAPH_FAILED);
    auto compileInfoPtr = context->GetCompiledInfo<AbsCompileInfo>();
    OP_CHECK_NULL_WITH_CONTEXT(context, compileInfoPtr);
    fe::PlatFormInfos* platformInfoPtr = context->GetPlatformInfo();
    OP_CHECK_NULL_WITH_CONTEXT(context, platformInfoPtr);

    auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
    compileInfoPtr->coreNum = ascendcPlatform.GetCoreNumAiv();
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, compileInfoPtr->ubSize);

    // A zero core count or UB size cannot produce a valid tiling; fail early here
    // instead of letting it surface later during tiling.
    OP_CHECK_IF(compileInfoPtr->coreNum == 0,
        OP_LOGE(context->GetNodeName(), "coreNum must be greater than 0."),
        return ge::GRAPH_FAILED);
    OP_CHECK_IF(compileInfoPtr->ubSize == 0,
        OP_LOGE(context->GetNodeName(), "ubSize must be greater than 0."),
        return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}
// Register the Abs tiling callbacks:
//   - TilingForAbs as the tiling function;
//   - TilingPrepareForAbs as the TilingParse function, which fills AbsCompileInfo.
IMPL_OP_OPTILING(Abs)
    .Tiling(TilingForAbs)
    .TilingParse<AbsCompileInfo>(TilingPrepareForAbs);
}