* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file addr_tiling.cpp
* \brief addr_tiling source file
*/
#include "addr_tiling.h"
#include <graph/utils/type_utils.h>
#include "../../op_kernel/arch35/addr_struct.h"
#include "log/log.h"
#include "platform/platform_info.h"
#include "register/op_impl_registry.h"
#include "register/tilingdata_base.h"
#include "platform/platform_ascendc.h"
#include "atvoss/broadcast/broadcast_tiling.h"
#include "op_host/util/fp16.h"
#include "util/bfloat16.h"
#include "op_host/math_tiling_templates_registry.h"
using namespace Ops::Base;
using namespace AscendC;
using namespace ge;
using namespace ops;
namespace optiling {
constexpr static uint64_t ADDR_COMMON_TILING_PRIORITY = 0;
constexpr static uint32_t INPUT_IDX_X1 = 0;
constexpr static uint32_t INPUT_IDX_X2 = 1;
constexpr static uint32_t INPUT_IDX_X3 = 2;
constexpr static uint32_t INPUT_IDX_BETA = 3;
constexpr static uint32_t INPUT_IDX_ALPHA = 4;
constexpr static uint32_t OUTPUT_IDX_Y = 0;
constexpr static int32_t nTwo = 2;
constexpr static uint16_t FP16_BF16_ZERO = 0x0000;
constexpr static uint16_t FP16_ONE = 0x3C00;
constexpr static uint16_t BF16_ONE = 0x3F80;
constexpr static uint64_t TILINGKET_ZERO = 0;
constexpr static uint64_t TILINGKET_NOT_ZERO = 1;
static ge::graphStatus TilingPrepareForAddr(gert::TilingParseContext* context)
{
auto compileInfoPtr = context->GetCompiledInfo<AddrCompileInfo>();
OP_CHECK_NULL_WITH_CONTEXT(context, compileInfoPtr);
fe::PlatFormInfos* platformInfoPtr = context->GetPlatformInfo();
OP_CHECK_NULL_WITH_CONTEXT(context, platformInfoPtr);
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
compileInfoPtr->coreNum = ascendcPlatform.GetCoreNumAiv();
ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, compileInfoPtr->ubSize);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus AddrTiling::GetShapeAttrsInfo()
{
return ge::GRAPH_SUCCESS;
}
bool AddrTiling::IsCapable()
{
return true;
}
bool AddrTiling::CheckDtype()
{
auto x1Desc = context_->GetInputDesc(INPUT_IDX_X1);
OP_CHECK_NULL_WITH_CONTEXT(context_, x1Desc);
x1Dtype_ = x1Desc->GetDataType();
auto x2Desc = context_->GetInputDesc(INPUT_IDX_X2);
OP_CHECK_NULL_WITH_CONTEXT(context_, x2Desc);
x2Dtype_ = x2Desc->GetDataType();
auto x3Desc = context_->GetInputDesc(INPUT_IDX_X3);
OP_CHECK_NULL_WITH_CONTEXT(context_, x3Desc);
x3Dtype_ = x3Desc->GetDataType();
auto betaDesc = context_->GetInputDesc(INPUT_IDX_BETA);
OP_CHECK_NULL_WITH_CONTEXT(context_, betaDesc);
betaDtype_ = betaDesc->GetDataType();
auto alphaDesc = context_->GetInputDesc(INPUT_IDX_ALPHA);
OP_CHECK_NULL_WITH_CONTEXT(context_, alphaDesc);
alphaDtype_ = alphaDesc->GetDataType();
auto yDesc = context_->GetOutputDesc(OUTPUT_IDX_Y);
OP_CHECK_NULL_WITH_CONTEXT(context_, yDesc);
yDtype_ = yDesc->GetDataType();
if (x1Dtype_ != x2Dtype_ || x2Dtype_ != x3Dtype_ || x3Dtype_ != betaDtype_ || betaDtype_ != alphaDtype_ ||
alphaDtype_ != yDtype_) {
OP_LOGE(
context_->GetNodeName(), "check datetype failed, x1: %s, x2: %s, x3: %s, beta: %s, alpha: %s, y: %s.",
ge::TypeUtils::DataTypeToSerialString(x1Dtype_).c_str(),
ge::TypeUtils::DataTypeToSerialString(x2Dtype_).c_str(),
ge::TypeUtils::DataTypeToSerialString(x3Dtype_).c_str(),
ge::TypeUtils::DataTypeToSerialString(betaDtype_).c_str(),
ge::TypeUtils::DataTypeToSerialString(alphaDtype_).c_str(),
ge::TypeUtils::DataTypeToSerialString(yDtype_).c_str());
return false;
}
return true;
}
bool AddrTiling::CheckBroadcast()
{
int64_t n = x2Shape_.GetDim(0);
int64_t m = x3Shape_.GetDim(0);
size_t x1DimNum = x1Shape_.GetDimNum();
if (x1DimNum == 1) {
int64_t inM = x1Shape_.GetDim(0);
if (inM == 1 || inM == m) {
return true;
}
OP_LOGE(context_->GetNodeName(), "size mismatch, x1:[%ld], x2: [%ld], x3: [%ld].", inM, n, m);
} else if (x1DimNum == static_cast<size_t>(nTwo)) {
int64_t inN = x1Shape_.GetDim(0);
int64_t inM = x1Shape_.GetDim(1);
if ((inN == 1 || inN == n) && (inM == 1 || inM == m)) {
return true;
}
OP_LOGE(context_->GetNodeName(), "size mismatch, x1:[%ld, %ld], x2: [%ld], x3: [%ld].", inN, inM, n, m);
} else {
OP_LOGE(context_->GetNodeName(), "x1 Dim should be 1 or 2, but got %d.", static_cast<int32_t>(x1DimNum));
}
return false;
}
bool AddrTiling::CheckShapes()
{
auto x1StorageShape = context_->GetInputShape(INPUT_IDX_X1);
OP_CHECK_NULL_WITH_CONTEXT(context_, x1StorageShape);
x1Shape_ = EnsureNotScalar(x1StorageShape->GetStorageShape());
auto x2StorageShape = context_->GetInputShape(INPUT_IDX_X2);
OP_CHECK_NULL_WITH_CONTEXT(context_, x2StorageShape);
x2Shape_ = EnsureNotScalar(x2StorageShape->GetStorageShape());
auto x3StorageShape = context_->GetInputShape(INPUT_IDX_X3);
OP_CHECK_NULL_WITH_CONTEXT(context_, x3StorageShape);
x3Shape_ = EnsureNotScalar(x3StorageShape->GetStorageShape());
auto betaStorageShape = context_->GetInputShape(INPUT_IDX_BETA);
OP_CHECK_NULL_WITH_CONTEXT(context_, betaStorageShape);
betaShape_ = EnsureNotScalar(betaStorageShape->GetStorageShape());
auto alphaStorageShape = context_->GetInputShape(INPUT_IDX_ALPHA);
OP_CHECK_NULL_WITH_CONTEXT(context_, alphaStorageShape);
alphaShape_ = EnsureNotScalar(alphaStorageShape->GetStorageShape());
auto yStorageShape = context_->GetOutputShape(OUTPUT_IDX_Y);
OP_CHECK_NULL_WITH_CONTEXT(context_, yStorageShape);
yShape_ = EnsureNotScalar(yStorageShape->GetStorageShape());
if (x1Shape_.GetShapeSize() <= 0L || x2Shape_.GetShapeSize() <= 0L || x3Shape_.GetShapeSize() <= 0L ||
yShape_.GetShapeSize() <= 0L) {
OP_LOGE(context_->GetNodeName(), "The x1, x2, x3, y shape should not be empty.");
return false;
}
OP_CHECK_IF(
x2Shape_.GetDimNum() != 1, OP_LOGE(context_->GetNodeName(), "the x2 shape rank must be 1."), return false);
OP_CHECK_IF(
x3Shape_.GetDimNum() != 1, OP_LOGE(context_->GetNodeName(), "the x3 shape rank must be 1."), return false);
OP_CHECK_IF(!CheckBroadcast(), OP_LOGE(context_->GetNodeName(), "CheckBroadcast failed."), return false);
OP_CHECK_IF(
yShape_.GetDimNum() != nTwo ||
(yShape_.GetDim(0) != x2Shape_.GetDim(0) || yShape_.GetDim(1) != x3Shape_.GetDim(0)),
OP_LOGE(context_->GetNodeName(), "y Shape must be [%ld, %ld].", x2Shape_.GetDim(0), x3Shape_.GetDim(0)),
return false);
return true;
}
template <typename T>
ge::graphStatus AddrTiling::GetConstData(uint32_t inputIdx, bool isEmpty, T emptyValue, T& data)
{
if (isEmpty) {
data = emptyValue;
} else {
auto tensor = context_->GetInputTensor(inputIdx);
OP_CHECK_NULL_WITH_CONTEXT(context_, tensor);
const T* value = tensor->GetData<T>();
if (value == nullptr) {
OP_LOGE(context_->GetNodeName(), "const tensor is null.");
return ge::GRAPH_FAILED;
}
data = value[0];
}
return ge::GRAPH_SUCCESS;
}
template <typename OpDag1, typename OpDag2, typename OpDag3, typename T>
ge::graphStatus AddrTiling::BroadcastTiling(T beta, T alpha)
{
bool isBetaZero = (beta == 0);
ge::graphStatus ret = ge::GRAPH_SUCCESS;
vector<gert::Shape> inputShapes;
if (isBetaZero) {
inputShapes.push_back(x2ReShape_);
inputShapes.push_back(x3Shape_);
BroadcastBaseTiling<OpDag2> brcBaseTiling(context_);
brcBaseTiling.SetScalar(alpha);
brcBaseTiling.SetOpInputStorageShapes(inputShapes);
ret = brcBaseTiling.DoTiling();
tilingKey_ = GET_TPL_TILING_KEY(brcBaseTiling.GetSchMode(), TILINGKET_ZERO, TILINGKET_NOT_ZERO);
} else {
inputShapes.push_back(x1Shape_);
inputShapes.push_back(x2ReShape_);
inputShapes.push_back(x3Shape_);
BroadcastBaseTiling<OpDag3> brcBaseTiling(context_);
brcBaseTiling.SetScalar(beta);
brcBaseTiling.SetScalar(alpha);
brcBaseTiling.SetOpInputStorageShapes(inputShapes);
ret = brcBaseTiling.DoTiling();
tilingKey_ = GET_TPL_TILING_KEY(brcBaseTiling.GetSchMode(), TILINGKET_NOT_ZERO, TILINGKET_NOT_ZERO);
}
return ret;
}
ge::graphStatus AddrTiling::DoOpTiling4Float(bool betaEmpty, bool alphaEmpty)
{
float beta = 0;
float alpha = 0;
OP_CHECK_IF(
GetConstData<float>(INPUT_IDX_BETA, betaEmpty, float(1), beta) != ge::GRAPH_SUCCESS ||
GetConstData<float>(INPUT_IDX_ALPHA, alphaEmpty, float(1), alpha) != ge::GRAPH_SUCCESS,
OP_LOGE(context_->GetNodeName(), "get float beta or alpha failed!"), return ge::GRAPH_FAILED);
OP_LOGI(context_->GetNodeName(), "beta = %f, alpha = %f", beta, alpha);
return BroadcastTiling<
AddrOp::AddrWithoutAlphaCommon<float>::OpDag, AddrOp::AddrWithoutBetaCommon<float>::OpDag,
AddrOp::AddrWithBetaWithAlphaCommon<float>::OpDag, float>(beta, alpha);
}
ge::graphStatus AddrTiling::DoOpTiling4Half(bool betaEmpty, bool alphaEmpty)
{
float beta = 0;
float alpha = 0;
uint16_t betaTmp = 0;
uint16_t alphaTmp = 0;
ge::graphStatus ret = ge::GRAPH_SUCCESS;
if (x1Dtype_ == ge::DT_FLOAT16) {
OP_CHECK_IF(
GetConstData<uint16_t>(INPUT_IDX_BETA, betaEmpty, FP16_ONE, betaTmp) != ge::GRAPH_SUCCESS ||
GetConstData<uint16_t>(INPUT_IDX_ALPHA, alphaEmpty, FP16_ONE, alphaTmp) != ge::GRAPH_SUCCESS,
OP_LOGE(context_->GetNodeName(), "get fp16 beta or alpha failed!"), return ge::GRAPH_FAILED);
beta = float(*(reinterpret_cast<fp16_t*>(&betaTmp)));
alpha = float(*(reinterpret_cast<fp16_t*>(&alphaTmp)));
} else {
OP_CHECK_IF(
GetConstData<uint16_t>(INPUT_IDX_BETA, betaEmpty, BF16_ONE, betaTmp) != ge::GRAPH_SUCCESS ||
GetConstData<uint16_t>(INPUT_IDX_ALPHA, alphaEmpty, BF16_ONE, alphaTmp) != ge::GRAPH_SUCCESS,
OP_LOGE(context_->GetNodeName(), "get bf16 beta or alpha failed!"), return ge::GRAPH_FAILED);
beta = float(*(reinterpret_cast<bfloat16*>(&betaTmp)));
alpha = float(*(reinterpret_cast<bfloat16*>(&alphaTmp)));
}
OP_LOGI(context_->GetNodeName(), "beta = %f, alpha = %f", beta, alpha);
ret = BroadcastTiling<
AddrOp::AddrWithoutAlphaBf16Fp16<half>::OpDag, AddrOp::AddrWithoutBetaBf16Fp16<half>::OpDag,
AddrOp::AddrWithBetaWithAlphaBf16Fp16<half>::OpDag, float>(beta, alpha);
return ret;
}
ge::graphStatus AddrTiling::DoOpTiling()
{
OP_CHECK_IF(!CheckDtype(), OP_LOGE(context_->GetNodeName(), "CheckDtype error!"), return ge::GRAPH_FAILED);
OP_CHECK_IF(!CheckShapes(), OP_LOGE(context_->GetNodeName(), "CheckShapes error!"), return ge::GRAPH_FAILED);
x2ReShape_.AppendDim(x2Shape_.GetDim(0));
x2ReShape_.AppendDim(1);
bool betaEmpty = (betaShape_.GetShapeSize() == 0);
bool alphaEmpty = (alphaShape_.GetShapeSize() == 0);
ge::graphStatus ret = ge::GRAPH_SUCCESS;
if (x1Dtype_ == ge::DT_FLOAT) {
ret = DoOpTiling4Float(betaEmpty, alphaEmpty);
} else if (x1Dtype_ == ge::DT_FLOAT16 || x1Dtype_ == ge::DT_BF16) {
ret = DoOpTiling4Half(betaEmpty, alphaEmpty);
} else if (x1Dtype_ == ge::DT_BOOL || x1Dtype_ == ge::DT_INT8) {
int8_t tmpBeta = 0;
int8_t tmpAlpha = 0;
OP_CHECK_IF(
GetConstData<int8_t>(INPUT_IDX_BETA, betaEmpty, int8_t(1), tmpBeta) != ge::GRAPH_SUCCESS ||
GetConstData<int8_t>(INPUT_IDX_ALPHA, alphaEmpty, int8_t(1), tmpAlpha) != ge::GRAPH_SUCCESS,
OP_LOGE(context_->GetNodeName(), "get int8 beta or alpha failed!"), return ge::GRAPH_FAILED);
int32_t beta = static_cast<int32_t>(tmpBeta);
int32_t alpha = static_cast<int32_t>(tmpAlpha);
OP_LOGI(context_->GetNodeName(), "beta = %d, alpha = %d", beta, alpha);
ret = BroadcastTiling<
AddrOp::AddrWithoutAlphaInt8::OpDag, AddrOp::AddrWithoutBetaInt8::OpDag,
AddrOp::AddrWithBetaWithAlphaInt8::OpDag, int32_t>(beta, alpha);
} else if (x1Dtype_ == ge::DT_UINT8) {
uint8_t tmpBeta = 0;
uint8_t tmpAlpha = 0;
OP_CHECK_IF(
GetConstData<uint8_t>(INPUT_IDX_BETA, betaEmpty, uint8_t(1), tmpBeta) != ge::GRAPH_SUCCESS ||
GetConstData<uint8_t>(INPUT_IDX_ALPHA, alphaEmpty, uint8_t(1), tmpAlpha) != ge::GRAPH_SUCCESS,
OP_LOGE(context_->GetNodeName(), "get int8 beta or alpha failed!"), return ge::GRAPH_FAILED);
int32_t beta = static_cast<int32_t>(tmpBeta);
int32_t alpha = static_cast<int32_t>(tmpAlpha);
OP_LOGI(context_->GetNodeName(), "beta = %d, alpha = %d", beta, alpha);
ret = BroadcastTiling<
AddrOp::AddrWithoutAlphaUint8::OpDag, AddrOp::AddrWithoutBetaUint8::OpDag,
AddrOp::AddrWithBetaWithAlphaUint8::OpDag, int32_t>(beta, alpha);
} else {
OP_LOGE(
context_->GetNodeName(), "Input dtype is only support fp16, bf16, fp32, uint8, int8, bool, while got %s!",
ge::TypeUtils::DataTypeToSerialString(x1Dtype_).c_str());
return ge::GRAPH_FAILED;
}
return ret;
}
ge::graphStatus AddrTiling::DoLibApiTiling()
{
return ge::GRAPH_SUCCESS;
}
uint64_t AddrTiling::GetTilingKey() const
{
return tilingKey_;
}
ge::graphStatus AddrTiling::GetWorkspaceSize()
{
return ge::GRAPH_SUCCESS;
}
ge::graphStatus AddrTiling::PostTiling()
{
return ge::GRAPH_SUCCESS;
}
ge::graphStatus AddrTiling::GetPlatformInfo()
{
return ge::GRAPH_SUCCESS;
}
ge::graphStatus TilingForAddr(gert::TilingContext* context)
{
OP_LOGD("AddrTiling", "Enter TilingForAddr");
if (context == nullptr) {
OP_LOGE("AddrTiling", "Tiling context is nullptr");
return ge::GRAPH_FAILED;
}
auto compileInfo = reinterpret_cast<const AddrCompileInfo*>(context->GetCompileInfo());
OP_CHECK_NULL_WITH_CONTEXT(context, compileInfo);
OP_LOGD("AddrTiling", "Enter ascendc AddrTiling");
return Ops::Math::OpTiling::TilingRegistry::GetInstance().DoTilingImpl(context);
}
IMPL_OP_OPTILING(Addr)
.Tiling(TilingForAddr)
.TilingParse<AddrCompileInfo>(TilingPrepareForAddr)
.TilingInputsDataDependency({INPUT_IDX_BETA, INPUT_IDX_ALPHA});
REGISTER_OPS_TILING_TEMPLATE(Addr, AddrTiling, ADDR_COMMON_TILING_PRIORITY);
}