* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file transdata_tiling_arch35.cpp
* \brief calc tiling data for transdata AscendC kernel
*/
#include "transdata_tiling_arch35.h"
#include <algorithm>
#include "log/log.h"
#include "op_host/tiling_base_util.h"
#include "util/math_util.h"
#include "util/platform_util.h"
using namespace Ops::Base;
namespace optiling {
namespace transdata_asc {
ge::graphStatus TransDataTilingAscendC::GetHardwareInfo()
{
auto compileInfo = reinterpret_cast<const TransDataCompileInfo*>(context_->GetCompileInfo());
OP_CHECK_NULL_WITH_CONTEXT(context_, compileInfo);
coreNum_ = static_cast<uint32_t>(compileInfo->coreNum);
ubSize_ = compileInfo->ubSize;
blockSize_ = compileInfo->blockSize;
OP_CHECK_IF(
(coreNum_ <= 0U || ubSize_ <= 0L),
OP_LOGE(context_->GetNodeName(), "TransData GetHardwareInfo failed, core num: %u, ub size: %ld", coreNum_,
ubSize_),
return ge::GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
void TransDataTilingAscendC::ReshapeInShape()
{
auto dimCnt = inShape_.GetDimNum();
if (dimCnt > 1) {
return;
}
auto bakDim = inShape_.GetDim(0);
inShape_.SetDim(0, 1);
inShape_.AppendDim(1);
inShape_.AppendDim(bakDim);
}
bool TransDataTilingAscendC::GetShapeInfo()
{
auto xStorage = context_->GetInputShape(0);
OP_CHECK_NULL_WITH_CONTEXT(context_, xStorage);
inShape_ = Ops::Base::EnsureNotScalar(xStorage->GetStorageShape());
auto yStorage = context_->GetOutputShape(0);
OP_CHECK_NULL_WITH_CONTEXT(context_, yStorage);
outShape_ = Ops::Base::EnsureNotScalar(yStorage->GetStorageShape());
OP_CHECK_IF(
inShape_.GetShapeSize() == 0 || outShape_.GetShapeSize() == 0,
OP_LOGE(context_->GetNodeName(), "The input or output shape is empty!"),
return false);
ReshapeInShape();
return true;
}
bool TransDataTilingAscendC::GetTransFormatAndDType()
{
auto srcTd = context_->GetInputDesc(0);
auto dstTd = context_->GetOutputDesc(0);
OP_CHECK_NULL_WITH_CONTEXT(context_, srcTd);
OP_CHECK_NULL_WITH_CONTEXT(context_, dstTd);
auto srcFormat = static_cast<ge::Format>(ge::GetPrimaryFormat(srcTd->GetStorageFormat()));
dstFormat_ = static_cast<ge::Format>(ge::GetPrimaryFormat(dstTd->GetStorageFormat()));
srcDtype_ = srcTd->GetDataType();
dtypeSize_ = ge::GetSizeByDataType(srcDtype_);
std::vector<ge::Format> supportSrcFormat = {ge::FORMAT_ND, ge::FORMAT_NCL, ge::FORMAT_NCHW, ge::FORMAT_NHWC};
std::vector<ge::Format> supportDstFormat = {ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ_C0_16,
ge::FORMAT_FRACTAL_NZ_C0_32};
OP_CHECK_IF(
(std::find(supportSrcFormat.begin(), supportSrcFormat.end(), srcFormat) == supportSrcFormat.end() ||
std::find(supportDstFormat.begin(), supportDstFormat.end(), dstFormat_) == supportDstFormat.end()),
OP_LOGE(context_->GetNodeName(), "The input or output format is invalid!"),
return false);
return true;
}
bool TransDataTilingAscendC::GetTransNz2NdFormatAndDType()
{
auto srcTd = context_->GetInputDesc(0);
OP_CHECK_NULL_WITH_CONTEXT(context_, srcTd);
auto srcFormat = static_cast<ge::Format>(ge::GetPrimaryFormat(srcTd->GetStorageFormat()));
srcDtype_ = srcTd->GetDataType();
dtypeSize_ = ge::GetSizeByDataType(srcDtype_);
if (dtypeSize_ == ge::kDataTypeSizeBitOffset + B4_BIT_SIZE) {
dtypeSize_ = 1;
}
OP_CHECK_IF(
(dtypeSize_ <= 0), OP_LOGE(context_->GetNodeName(), "The dtypeSize is invalid!"),
return false);
std::vector<ge::Format> supportSrcFormat = {ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ_C0_2,
ge::FORMAT_FRACTAL_NZ_C0_4, ge::FORMAT_FRACTAL_NZ_C0_16,
ge::FORMAT_FRACTAL_NZ_C0_32};
std::map<ge::Format, int64_t> formatMap = {
{ge::FORMAT_FRACTAL_NZ, blockSize_ / dtypeSize_},
{ge::FORMAT_FRACTAL_NZ_C0_2, C0_2},
{ge::FORMAT_FRACTAL_NZ_C0_4, C0_4},
{ge::FORMAT_FRACTAL_NZ_C0_16, C0_16},
{ge::FORMAT_FRACTAL_NZ_C0_32, C0_32}
};
auto it = formatMap.find(srcFormat);
OP_CHECK_IF(
(it == formatMap.end()), OP_LOGE(context_->GetNodeName(), "The input format is invalid!"),
return false);
expectC0_ = it->second;
if (srcDtype_ == ge::DT_FLOAT4_E2M1 && srcFormat != ge::FORMAT_FRACTAL_NZ) {
expectC0_ = expectC0_ >> 1;
}
return true;
}
bool TransDataTilingAscendC::CalcNzToNdShapeSize()
{
auto xStorage = context_->GetInputShape(0);
OP_CHECK_NULL_WITH_CONTEXT(context_, xStorage);
inShape_ = Ops::Base::EnsureNotScalar(xStorage->GetStorageShape());
auto yStorage = context_->GetOutputShape(0);
OP_CHECK_NULL_WITH_CONTEXT(context_, yStorage);
outShape_ = Ops::Base::EnsureNotScalar(yStorage->GetStorageShape());
OP_CHECK_IF(
inShape_.GetDimNum() < (int64_t)nFour,
OP_LOGE(context_->GetNodeName(), "The input shape dim is less than 4!"), return false);
OP_CHECK_IF(
inShape_.GetDimNum() - outShape_.GetDimNum() != (int64_t)nTwo,
OP_LOGE(context_->GetNodeName(), "The input dim must be 2 higher than the output dim."), return false);
auto inputDimCnt = inShape_.GetDimNum();
c1_ = inShape_.GetDim(inputDimCnt - nFour);
n1_ = inShape_.GetDim(inputDimCnt - nThree);
n0_ = inShape_.GetDim(inputDimCnt - nTwo);
c0_ = inShape_.GetDim(inputDimCnt - 1);
OP_CHECK_IF(
(c1_ * n1_ * n0_ * c0_ == 0),
OP_LOGE(context_->GetNodeName(), "All input dim must not be zero"), return false);
auto outputDimCnt = outShape_.GetDimNum();
n_ = outShape_.GetDim(outputDimCnt - nTwo);
c_ = outShape_.GetDim(outputDimCnt - 1);
for (size_t i = 0; i < outputDimCnt - nTwo; i++) {
OP_CHECK_IF(
(inShape_.GetDim(i) != outShape_.GetDim(i) || inShape_.GetDim(i) == 0),
OP_LOGE(context_->GetNodeName(), "The input H dim must be same with output dim and not be zero."),
return false);
}
OP_CHECK_IF(
(n1_ != CeilDiv(n_, n0_)),
OP_LOGE(context_->GetNodeName(), "The dim N1 should be equal CeilDiv(N, N0)"), return false);
OP_CHECK_IF(
(c1_ != CeilDiv(c_, c0_)),
OP_LOGE(context_->GetNodeName(), "The dim C1 should be equal CeilDiv(C, C0)"), return false);
if (srcDtype_ == ge::DT_FLOAT4_E2M1 || srcDtype_ == ge::DT_FLOAT4_E1M2) {
c0_ = c0_ >> 1;
c_ = c_ >> 1;
}
OP_CHECK_IF(
(c0_ != expectC0_ || n0_ != N0_16),
OP_LOGE(context_->GetNodeName(), "The n0 should be 16 and c0 should be %ld", expectC0_), return false);
int64_t res = 1;
for (size_t i = 0; i < inputDimCnt - nFour; i++) {
res *= inShape_.GetDim(i);
}
h_ = res;
return true;
}
bool TransDataTilingAscendC::CalcC0Size()
{
auto dimCnt = outShape_.GetDimNum();
auto tmpC0 = outShape_.GetDim(dimCnt - 1);
OP_CHECK_IF(
dstFormat_ == ge::FORMAT_FRACTAL_NZ_C0_16 && tmpC0 != C0_16,
OP_LOGE(context_->GetNodeName(), "The c0 should be 16 when dst format is FRACTAL_NZ_C0_16!"),
return false);
OP_CHECK_IF(
dstFormat_ == ge::FORMAT_FRACTAL_NZ_C0_32 && tmpC0 != C0_32,
OP_LOGE(context_->GetNodeName(), "The c0 should be 32 when dst format is FRACTAL_NZ_C0_32!"),
return false);
int64_t expectC0 = tmpC0;
if (dtypeSize_ == 1U) {
expectC0 = C0_32;
} else if (dtypeSize_ == nTwo) {
expectC0 = C0_16;
}
OP_CHECK_IF(
dstFormat_ == ge::FORMAT_FRACTAL_NZ && dtypeSize_ == nTwo * nTwo && C0_8 != expectC0 && C0_16 != expectC0,
OP_LOGE(context_->GetNodeName(), "The c0 should be 8 or 16 when dst format is FRACTAL_NZ and dtype size is %zu!",
dtypeSize_),
return false);
OP_CHECK_IF(
dstFormat_ == ge::FORMAT_FRACTAL_NZ && tmpC0 != expectC0,
OP_LOGE(context_->GetNodeName(), "The c0 should be %ld when dst format is FRACTAL_NZ and dtype size is %zu!",
expectC0, dtypeSize_),
return false);
c0_ = tmpC0;
if (srcDtype_ == ge::DT_FLOAT4_E2M1 || srcDtype_ == ge::DT_FLOAT4_E1M2) {
c0_ = c0_ >> 1;
}
return true;
}
void TransDataTilingAscendC::CalcHSize()
{
int64_t res = 1;
auto dimCnt = inShape_.GetDimNum();
if (dimCnt <= nTwo) {
h_ = res;
return;
}
for (size_t i = 0; i < dimCnt - nTwo; i++) {
res *= inShape_.GetDim(i);
}
h_ = res;
}
void TransDataTilingAscendC::CalcNCSize()
{
auto dimCnt = inShape_.GetDimNum();
n_ = inShape_.GetDim(dimCnt - nTwo);
c_ = inShape_.GetDim(dimCnt - 1);
if (srcDtype_ == ge::DT_FLOAT4_E2M1 || srcDtype_ == ge::DT_FLOAT4_E1M2) {
c_ = c_ >> 1;
}
}
void TransDataTilingAscendC::CalcTilingKey()
{
int64_t ni = 16;
int64_t shapeSize = h_ * CeilAlign(n_, ni) * CeilAlign(c_, c0_);
tilingKey_ = (shapeSize > MAX_INT32_SIZE) ? TILING_MODE_SIMT_LARGE_SHAPE : TILING_MODE_SIMT;
}
void TransDataTilingAscendC::CalcBlockAndThreadNum()
{
bNum_ = coreNum_;
tNum_ = (tilingKey_ == TILING_MODE_SIMT) ? tNum512 : tNum256;
}
ge::graphStatus TransDataTilingAscendC::CalcTilingData()
{
OP_CHECK_IF(!GetShapeInfo(), OP_LOGE(context_->GetNodeName(), "Failed to get shape info!"), return ge::GRAPH_FAILED);
OP_CHECK_IF(
!GetTransFormatAndDType(), OP_LOGE(context_->GetNodeName(), "Failed to get format and dtype info!"),
return ge::GRAPH_FAILED);
OP_CHECK_IF(!CalcC0Size(), OP_LOGE(context_->GetNodeName(), "Failed to get c0 size!"), return ge::GRAPH_FAILED);
CalcHSize();
CalcNCSize();
CalcTilingKey();
CalcBlockAndThreadNum();
return ge::GRAPH_SUCCESS;
}
void TransDataTilingAscendC::WriteNzToNdTilingData()
{
context_->SetBlockDim(coreNum_);
context_->SetTilingKey(TILING_MODE_SIMT_NZ_TO_ND);
context_->SetLocalMemorySize(ubSize_ - SIMT_RSV_SIZE);
tilingNzToNdData_.set_h(h_);
tilingNzToNdData_.set_n(n_);
tilingNzToNdData_.set_c(c_);
tilingNzToNdData_.set_c1(c1_);
tilingNzToNdData_.set_n1(n1_);
tilingNzToNdData_.set_n0(n0_);
tilingNzToNdData_.set_c0(c0_);
tilingNzToNdData_.SaveToBuffer(context_->GetRawTilingData()->GetData(), context_->GetRawTilingData()->GetCapacity());
context_->GetRawTilingData()->SetDataSize(tilingNzToNdData_.GetDataSize());
}
void TransDataTilingAscendC::WriteTilingData()
{
context_->SetBlockDim(bNum_);
context_->SetTilingKey(tilingKey_);
context_->SetLocalMemorySize(ubSize_ - SIMT_RSV_SIZE);
tilingData_.set_c0(c0_);
tilingData_.set_h(h_);
tilingData_.set_n(n_);
tilingData_.set_c(c_);
tilingData_.set_tNum(tNum_);
tilingData_.SaveToBuffer(context_->GetRawTilingData()->GetData(), context_->GetRawTilingData()->GetCapacity());
context_->GetRawTilingData()->SetDataSize(tilingData_.GetDataSize());
}
std::string TransDataTilingAscendC::PrintTilingData()
{
std::string tilingStr;
tilingStr += std::to_string(c0_) + ",";
tilingStr += std::to_string(h_) + ",";
tilingStr += std::to_string(n_) + ",";
tilingStr += std::to_string(c_) + ",";
tilingStr += std::to_string(tNum_);
return tilingStr;
}
std::string TransDataTilingAscendC::PrintNz2NdTilingData()
{
std::string tilingStr;
tilingStr += std::to_string(h_) + ",";
tilingStr += std::to_string(n_) + ",";
tilingStr += std::to_string(c_) + ",";
tilingStr += std::to_string(c1_) + ",";
tilingStr += std::to_string(n1_) + ",";
tilingStr += std::to_string(n0_) + ",";
tilingStr += std::to_string(c0_);
return tilingStr;
}
ge::graphStatus TransDataTilingAscendC::DoNz2NdTiling()
{
OP_CHECK_IF(
!GetTransNz2NdFormatAndDType(), OP_LOGE(context_->GetNodeName(), "Failed to get format and dtype info!"),
return ge::GRAPH_FAILED);
OP_CHECK_IF(
!CalcNzToNdShapeSize(), OP_LOGE(context_->GetNodeName(), "The input shape is invalid!"),
return ge::GRAPH_FAILED);
WriteNzToNdTilingData();
size_t* currentWorkspace = context_->GetWorkspaceSizes(1);
OP_CHECK_NULL_WITH_CONTEXT(context_, currentWorkspace);
currentWorkspace[0] = kSyncWorkSpaceSize;
OP_LOGI(context_->GetNodeName(), "TransData tiling data: %s", PrintNz2NdTilingData().c_str());
return ge::GRAPH_SUCCESS;
}
ge::graphStatus TransDataTilingAscendC::DoTiling()
{
auto dstTd = context_->GetOutputDesc(0);
OP_CHECK_NULL_WITH_CONTEXT(context_, dstTd);
dstFormat_ = static_cast<ge::Format>(ge::GetPrimaryFormat(dstTd->GetStorageFormat()));
if (dstFormat_ == ge::FORMAT_ND) {
return DoNz2NdTiling();
}
OP_CHECK_IF(
(CalcTilingData() != ge::GRAPH_SUCCESS),
OP_LOGE(context_->GetNodeName(), "TransDataTilingAscendC failed to calc tiling data."),
return ge::GRAPH_FAILED);
OP_CHECK_IF(ubSize_ <= SIMT_RSV_SIZE, OP_LOGE(context_->GetNodeName(), "UB size too small for SIMT reserved size."),
return ge::GRAPH_FAILED);
WriteTilingData();
size_t* currentWorkspace = context_->GetWorkspaceSizes(1);
OP_CHECK_NULL_WITH_CONTEXT(context_, currentWorkspace);
currentWorkspace[0] = kSyncWorkSpaceSize;
OP_LOGI(context_->GetNodeName(), "TransData tiling data: %s", PrintTilingData().c_str());
return ge::GRAPH_SUCCESS;
}
}
ge::graphStatus Tiling4TransDataAscendC(gert::TilingContext* context)
{
transdata_asc::TransDataTilingAscendC tdTiling(context);
OP_CHECK_IF(
(tdTiling.GetHardwareInfo() != ge::GRAPH_SUCCESS),
OP_LOGE(context->GetNodeName(), "TransDataTilingAscendC failed to get hardware info."),
return ge::GRAPH_FAILED);
return tdTiling.DoTiling();
}
static ge::graphStatus TilingPrepare4TransData(gert::TilingParseContext* context)
{
auto compileInfo = context->GetCompiledInfo<TransDataCompileInfo>();
OP_CHECK_NULL_WITH_CONTEXT(context, compileInfo);
auto platformInfo = context->GetPlatformInfo();
OP_CHECK_NULL_WITH_CONTEXT(context, platformInfo);
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);
compileInfo->coreNum = ascendcPlatform.GetCoreNumAiv();
OP_CHECK_IF((compileInfo->coreNum <= 0), OP_LOGE(context->GetNodeName(), "The core num is invalid."),
return ge::GRAPH_FAILED);
uint64_t ubSize = 0;
ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
compileInfo->ubSize = static_cast<int64_t>(ubSize);
OP_CHECK_IF((compileInfo->ubSize <= 0), OP_LOGE(context->GetNodeName(), "Failed to get ub size."),
return ge::GRAPH_FAILED);
compileInfo->blockSize = Ops::Base::GetUbBlockSize(context);
OP_CHECK_IF((compileInfo->blockSize < 1),
OP_LOGE(context->GetNodeName(), "The block size is invalid, %ld.",
compileInfo->blockSize),
return ge::GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
IMPL_OP_OPTILING(TransData).Tiling(Tiling4TransDataAscendC).TilingParse<TransDataCompileInfo>(TilingPrepare4TransData);
}