* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file polar_tiling.cpp
* \brief polar tiling
*/
#include "polar_tiling.h"
#include <graph/utils/type_utils.h>
using namespace ge;
namespace optiling {
static constexpr uint64_t INPUT_ABS = 0;
static constexpr uint64_t INPUT_ANGLE = 1;
static constexpr uint64_t OUTPUT_Y = 0;
ge::graphStatus PolarTiling::GetPlatformInfo()
{
OP_LOGD(context_, "PolarTiling GetPlatformInfo.");
compileInfo_ = static_cast<const PolarCompileInfo*>(context_->GetCompileInfo());
OP_CHECK_NULL_WITH_CONTEXT(context_, compileInfo_);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus PolarTiling::CheckDtype()
{
OP_LOGD(context_, "PolarTiling CheckDtype.");
auto input0Desc = context_->GetInputDesc(INPUT_ABS);
OP_CHECK_NULL_WITH_CONTEXT(context_, input0Desc);
ge::DataType input0Dtype = input0Desc->GetDataType();
auto input1Desc = context_->GetInputDesc(INPUT_ANGLE);
OP_CHECK_NULL_WITH_CONTEXT(context_, input1Desc);
ge::DataType input1Dtype = input1Desc->GetDataType();
auto outputDesc = context_->GetOutputDesc(OUTPUT_Y);
OP_CHECK_NULL_WITH_CONTEXT(context_, outputDesc);
ge::DataType outputDtype = outputDesc->GetDataType();
if (input0Dtype != ge::DT_FLOAT || input1Dtype != ge::DT_FLOAT || outputDtype != ge::DT_COMPLEX64) {
std::string dtypesStr = ge::TypeUtils::DataTypeToSerialString(input0Dtype) + ", " +
ge::TypeUtils::DataTypeToSerialString(input1Dtype) + " and " +
ge::TypeUtils::DataTypeToSerialString(outputDtype);
OP_LOGE_FOR_INVALID_DTYPES_WITH_REASON(context_->GetNodeName(), "abs, angle and y",
dtypesStr.c_str(), "The dtypes of abs and angle must be float, and the dtype of y must be complex64");
return ge::GRAPH_FAILED;
}
return ge::GRAPH_SUCCESS;
}
ge::graphStatus PolarTiling::CheckBroadcastAndMergeShape()
{
OP_LOGD(context_, "PolarTiling CheckBroadcastAndMergeShape.");
const gert::StorageShape* absStorageShape = context_->GetInputShape(INPUT_ABS);
OP_CHECK_NULL_WITH_CONTEXT(context_, absStorageShape);
const gert::StorageShape* angleStorageShape = context_->GetInputShape(INPUT_ANGLE);
OP_CHECK_NULL_WITH_CONTEXT(context_, angleStorageShape);
auto absShape = absStorageShape->GetStorageShape();
auto angleShape = angleStorageShape->GetStorageShape();
int64_t absDimNum = static_cast<int64_t>(absShape.GetDimNum());
int64_t angleDimNum = static_cast<int64_t>(angleShape.GetDimNum());
dimNum_ = std::max(absDimNum, angleDimNum);
OP_CHECK_IF(dimNum_ > POLAR_MAX_DIM,
OP_LOGE(context_, "dimNum %ld exceeds POLAR_MAX_DIM %ld", dimNum_, POLAR_MAX_DIM),
return ge::GRAPH_FAILED);
for (int64_t i = 0; i < dimNum_; i++) {
int64_t absOffset = i - (dimNum_ - absDimNum);
int64_t angleOffset = i - (dimNum_ - angleDimNum);
int64_t absDim = (absOffset >= 0) ? absShape.GetDim(absOffset) : 1;
int64_t angleDim = (angleOffset >= 0) ? angleShape.GetDim(angleOffset) : 1;
absDims_[i] = absDim;
angleDims_[i] = angleDim;
OP_CHECK_IF(absDim != angleDim && absDim != 1 && angleDim != 1,
OP_LOGE(context_, "Shapes not broadcastable at dim %ld: %ld vs %ld", i, absDim, angleDim),
return ge::GRAPH_FAILED);
mergedShape_[i] = std::max(absDim, angleDim);
}
totalElements_ = 1;
for (int64_t i = 0; i < dimNum_; i++) {
totalElements_ *= mergedShape_[i];
}
return ge::GRAPH_SUCCESS;
}
ge::graphStatus PolarTiling::CalcStride()
{
OP_LOGD(context_, "PolarTiling CalcStride.");
int64_t strideAbs = 1;
int64_t strideAngle = 1;
int64_t strideMerged = 1;
int64_t strideY = 1;
for (int64_t i = dimNum_ - 1; i >= 0; i--) {
absStride_[i] = (absDims_[i] == 1) ? 0 : strideAbs;
angleStride_[i] = (angleDims_[i] == 1) ? 0 : strideAngle;
mergedStride_[i] = strideMerged;
yStride_[i] = strideY;
strideAbs *= absDims_[i];
strideAngle *= angleDims_[i];
strideMerged *= mergedShape_[i];
strideY *= mergedShape_[i];
}
return ge::GRAPH_SUCCESS;
}
ge::graphStatus PolarTiling::GetShapeAttrsInfo()
{
if (CheckDtype() != ge::GRAPH_SUCCESS) return ge::GRAPH_FAILED;
if (CheckBroadcastAndMergeShape() != ge::GRAPH_SUCCESS) return ge::GRAPH_FAILED;
if (CalcStride() != ge::GRAPH_SUCCESS) return ge::GRAPH_FAILED;
return ge::GRAPH_SUCCESS;
}
ge::graphStatus PolarTiling::DoOpTiling()
{
OP_LOGD(context_, "PolarTiling DoOpTiling.");
int64_t coreNum = compileInfo_->coreNum;
int64_t elementsPerCore = totalElements_ / coreNum;
int64_t formerCore = totalElements_ % coreNum;
tilingData_.totalElements = totalElements_;
tilingData_.elementsPerCore = elementsPerCore;
tilingData_.coreNum = coreNum;
tilingData_.formerCore = formerCore;
tilingData_.dimNum = dimNum_;
for (int64_t i = 0; i < POLAR_MAX_DIM; i++) {
if (i < dimNum_) {
tilingData_.mergedStride[i] = mergedStride_[i];
tilingData_.absStride[i] = absStride_[i];
tilingData_.angleStride[i] = angleStride_[i];
tilingData_.yStride[i] = yStride_[i];
} else {
tilingData_.mergedStride[i] = 1;
tilingData_.absStride[i] = 0;
tilingData_.angleStride[i] = 0;
tilingData_.yStride[i] = 0;
}
}
blockDim_ = (totalElements_ < coreNum) ? totalElements_ : coreNum;
return ge::GRAPH_SUCCESS;
}
ge::graphStatus PolarTiling::PostTiling()
{
OP_LOGD(context_, "PolarTiling PostTiling.");
auto workspaces = context_->GetWorkspaceSizes(1);
OP_CHECK_NULL_WITH_CONTEXT(context_, workspaces);
workspaces[0] = 0;
auto res = context_->SetBlockDim(static_cast<uint32_t>(blockDim_));
OP_CHECK_IF((res != ge::GRAPH_SUCCESS), OP_LOGE(context_, "SetBlockDim failed."), return ge::GRAPH_FAILED);
errno_t ret = memcpy_s(
context_->GetRawTilingData()->GetData(), context_->GetRawTilingData()->GetCapacity(), &tilingData_,
sizeof(PolarTilingData));
if (ret != EOK) {
OP_LOGE(context_->GetNodeName(), "memcpy_s failed, ret=%d", ret);
return ge::GRAPH_FAILED;
}
context_->GetRawTilingData()->SetDataSize(sizeof(PolarTilingData));
return ge::GRAPH_SUCCESS;
}
static ge::graphStatus Tiling4Polar(gert::TilingContext* context)
{
OP_LOGD(context, "Tiling4Polar start.");
PolarTiling polarTiling(context);
auto ret = polarTiling.DoTiling();
OP_CHECK_IF((ret == ge::GRAPH_FAILED), OP_LOGD(context, "Tiling4Polar failed!"), return ge::GRAPH_FAILED);
OP_LOGD(context, "Tiling4Polar end.");
return ge::GRAPH_SUCCESS;
}
static ge::graphStatus TilingPrepare4PolarAscendc(gert::TilingParseContext* context)
{
OP_LOGD(context->GetNodeName(), "Enter TilingPrepare4PolarAscendc.");
auto compileInfo = context->GetCompiledInfo<PolarCompileInfo>();
OP_CHECK_NULL_WITH_CONTEXT(context, compileInfo);
auto platformInfo = context->GetPlatformInfo();
OP_CHECK_NULL_WITH_CONTEXT(context, platformInfo);
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);
compileInfo->coreNum = ascendcPlatform.GetCoreNumAiv();
OP_CHECK_IF(
(compileInfo->coreNum <= 0), OP_LOGE(context->GetNodeName(), "core num is negative."), return ge::GRAPH_FAILED);
OP_LOGD(context->GetNodeName(), "Exit TilingPrepare4PolarAscendc.");
return ge::GRAPH_SUCCESS;
}
static ge::graphStatus TilingPrepare4Polar(gert::TilingParseContext* context)
{
auto compileInfo = context->GetCompiledInfo<PolarCompileInfo>();
OP_CHECK_NULL_WITH_CONTEXT(context, compileInfo);
OP_LOGD("TilingPrepare4Polar", "Ascend C TilingPrepare4Polar success.");
return TilingPrepare4PolarAscendc(context);
}
IMPL_OP_OPTILING(Polar).Tiling(Tiling4Polar).TilingParse<PolarCompileInfo>(TilingPrepare4Polar);
}