/**
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file chunk_cat_tiling.cpp
* \brief Tiling implementation for the ChunkCat operator.
*/
#include "chunk_cat_tiling.h"
#include "tiling/platform/platform_ascendc.h"
namespace optiling {
// Size of one UB block; divided by the source dtype size to obtain the number of
// elements per block (presumably expressed in bytes, matching GetSizeByDataType).
constexpr uint32_t UB_BLOCK_SIZE = 32;
// Row-count unit used to derive the UB row limit on the non-aligned path of DoUbSplit.
constexpr uint32_t TRANS_BLOCK = 16;
// Amount added to the platform-reported UB size before it is split into input/output parts.
constexpr uint32_t RESERVE_UB = 256;
// Divisor 2: gives the input buffer half of the UB budget when input and output dtypes match;
// also used as the row multiplier on the half-aligned path of DoUbSplit.
constexpr uint32_t HALF = 2;
// Divisor 3: gives the input buffer one third of the UB budget when input and output dtypes differ.
constexpr uint32_t ONETHIRD = 3;
// The only tiling key set by this file (passed to SetTilingKey).
constexpr uint32_t DEFAUL_TILING_KEY = 0;
// Data types accepted for both the inputs and the output.
static const std::set<ge::DataType> supportedDtype =
{ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16};
// Renders the computed tiling fields as a single "name = value, ..." string for debug logging.
std::string ChunkCatTiling::TilingDataToString() const {
    std::string text;
    // Appends one "<label><value>" pair; labels already carry their ", " separators.
    const auto addField = [&text](const char* label, const std::string& value) {
        text += label;
        text += value;
    };

    addField("blockRowNum = ", std::to_string(blockRowNum_));
    addField(", blockColNum = ", std::to_string(blockColNum_));
    addField(", dim = ", std::to_string(dim_));
    addField(", numChunk = ", std::to_string(numChunk_));
    addField(", outputRow = ", std::to_string(outputRow_));
    addField(", outputCol = ", std::to_string(outputCol_));
    addField(", blockRowFactor = ", std::to_string(blockRowFactor_));
    addField(", blockColFactor = ", std::to_string(blockColFactor_));
    addField(", tailBlockRowFactor = ", std::to_string(tailBlockRowFactor_));
    addField(", tailBlockColFactor = ", std::to_string(tailBlockColFactor_));
    addField(", ubRowFactor = ", std::to_string(ubRowFactor_));
    addField(", ubColFactor = ", std::to_string(ubColFactor_));
    addField(", inputNum = ", std::to_string(inputNum_));
    addField(", inUbSize = ", std::to_string(inUbSize_));
    addField(", outUbSize = ", std::to_string(outUbSize_));
    addField(", isAllAlign = ", std::to_string(isAllAlign_));
    addField(", isHalfAlign = ", std::to_string(isHalfAlign_));
    addField(", isOneConcat = ", std::to_string(isOneConcat_));
    return text;
}
// Queries the platform for the AIV core count, the UB size and the system workspace size.
// Returns GRAPH_FAILED when the platform info is missing or when the core count / UB size is 0.
ge::graphStatus ChunkCatTiling::GetPlatformInfo()
{
    using platform_ascendc::CoreMemType;

    auto rawPlatformInfo = context_->GetPlatformInfo();
    OP_CHECK_NULL_WITH_CONTEXT(context_, rawPlatformInfo);
    platform_ascendc::PlatformAscendC platform(rawPlatformInfo);

    // Number of vector cores available for splitting the work.
    coreNum_ = platform.GetCoreNumAiv();
    OP_CHECK_IF(coreNum_ == 0, OP_LOGE(context_, "coreNum is 0"), return ge::GRAPH_FAILED);

    // UB capacity as reported by the platform; written into ubSize_ through the out-parameter.
    platform.GetCoreMemSize(CoreMemType::UB, ubSize_);
    OP_CHECK_IF(ubSize_ == 0, OP_LOGE(context_, "ubSize is 0"), return ge::GRAPH_FAILED);

    sysWorkspaceSize_ = platform.GetLibApiWorkSpaceSize();
    return ge::GRAPH_SUCCESS;
}
// Reads and validates the operator attributes (dim, numChunk), the number of dynamic inputs and
// the input/output data types, then derives the per-dtype block element count and the split of
// the UB budget between the input and output buffers.
// Returns GRAPH_FAILED on any missing descriptor or unsupported attribute/dtype combination.
ge::graphStatus ChunkCatTiling::GetInputInfo()
{
    const gert::RuntimeAttrs* attrs = context_->GetAttrs();
    // The attribute container itself may be absent; check it before reading individual attributes.
    OP_CHECK_NULL_WITH_CONTEXT(context_, attrs);

    // Attribute 0: the dimension to chunk along. Only dim == 0 is accepted here.
    const int64_t* dimPtr = attrs->GetAttrPointer<int64_t>(0);
    OP_CHECK_NULL_WITH_CONTEXT(context_, dimPtr);
    dim_ = *dimPtr;
    OP_CHECK_IF(dim_ != 0, OP_LOGE(context_, "dim must be 0 now"), return ge::GRAPH_FAILED);

    // Attribute 1: the number of chunks; must be strictly positive because it is used as a divisor.
    const int64_t* numChunkPtr = attrs->GetAttrPointer<int64_t>(1);
    OP_CHECK_NULL_WITH_CONTEXT(context_, numChunkPtr);
    numChunk_ = *numChunkPtr;
    OP_CHECK_IF(numChunk_ <= 0, OP_LOGE(context_, "numChunk must be greater than 0 now"), return ge::GRAPH_FAILED);

    // Number of tensors bound to the dynamic input at index 0.
    auto computeNodeInfo = context_->GetComputeNodeInfo();
    OP_CHECK_NULL_WITH_CONTEXT(context_, computeNodeInfo);
    auto anchorInstanceInfo = computeNodeInfo->GetInputInstanceInfo(0);
    OP_CHECK_NULL_WITH_CONTEXT(context_, anchorInstanceInfo);
    inputNum_ = anchorInstanceInfo->GetInstanceNum();
    OP_CHECK_IF(inputNum_ == 0, OP_LOGE(context_, "input num can not be 0"), return ge::GRAPH_FAILED);

    // The dtype of the first dynamic input is taken as the source dtype.
    auto inputDesc = context_->GetDynamicInputDesc(0, 0);
    OP_CHECK_NULL_WITH_CONTEXT(context_, inputDesc);
    const auto inputDataType = inputDesc->GetDataType();
    auto outputDesc = context_->GetOutputDesc(0);
    OP_CHECK_NULL_WITH_CONTEXT(context_, outputDesc);
    const auto outputDataType = outputDesc->GetDataType();
    OP_CHECK_IF(supportedDtype.count(inputDataType) == 0, OP_LOGE(context_, "input dtype is invalid"), return ge::GRAPH_FAILED);
    OP_CHECK_IF(supportedDtype.count(outputDataType) == 0, OP_LOGE(context_, "output dtype is invalid"), return ge::GRAPH_FAILED);
    if (inputDataType == ge::DT_FLOAT && outputDataType != ge::DT_FLOAT) {
        OP_LOGE(context_, "output dtype must be float when input dtype is float");
        return ge::GRAPH_FAILED;
    }

    srcDtypeSize_ = ge::GetSizeByDataType(inputDataType);
    OP_CHECK_IF(srcDtypeSize_ == 0, OP_LOGE(context_, "input dtype size can not be 0"), return ge::GRAPH_FAILED);
    srcEleUbBlock_ = UB_BLOCK_SIZE / srcDtypeSize_;

    // NOTE(review): RESERVE_UB is added to (not subtracted from) the platform-reported UB size.
    // This is kept exactly as before; confirm it matches the size the platform actually reports.
    const auto ubBudget = ubSize_ + RESERVE_UB;
    // With a dtype conversion the input gets one third of the budget, otherwise one half;
    // the output buffer takes whatever remains.
    if (inputDataType != outputDataType) {
        inUbSize_ = ubBudget / ONETHIRD;
    } else {
        inUbSize_ = ubBudget / HALF;
    }
    outUbSize_ = ubBudget - inUbSize_;
    return ge::GRAPH_SUCCESS;
}
// Walks every dynamic input, folds its shape into a 2-D (row, col) view around dim_ and
// accumulates the output geometry:
//   - outputRow_ : numChunk_ times the product of the dimensions before dim_ (must match for all inputs)
//   - outputCol_ : sum over inputs of ceil(shape[dim_] / numChunk_) times the product of trailing dimensions
// It also clears isAllAlign_ / isHalfAlign_ when an input's column count is not a multiple of a
// full / half UB block, and derives isOneConcat_.
// Returns GRAPH_FAILED when a shape is missing, an input has too few dimensions, the row sizes
// disagree, or the resulting output is empty.
ge::graphStatus ChunkCatTiling::CalculateOutputInfo()
{
    for (int64_t i = 0; i < inputNum_; i++) {
        auto inputTensorShapePtr = context_->GetDynamicInputShape(0, i);
        OP_CHECK_NULL_WITH_CONTEXT(context_, inputTensorShapePtr);
        // Bind by reference: the shape is only read, so no copy is needed.
        const gert::Shape& inputTensorShape = inputTensorShapePtr->GetStorageShape();
        const int64_t inputTensorDimNum = inputTensorShape.GetDimNum();
        // The chunk dimension must exist, otherwise GetDim(dim_) would be read out of range.
        OP_CHECK_IF(inputTensorDimNum <= dim_,
            OP_LOGE(context_, "input dim num must be greater than dim"), return ge::GRAPH_FAILED);

        const int64_t chunkDimSize = inputTensorShape.GetDim(dim_);
        // Size of one chunk along dim_, rounded up.
        const int64_t chunkCol = (chunkDimSize + numChunk_ - 1) / numChunk_;
        int64_t dim0Size = numChunk_;
        int64_t dim1Size = chunkCol;
        for (int64_t j = 0; j < dim_; j++) {
            dim0Size *= inputTensorShape.GetDim(j);
        }
        for (int64_t j = dim_ + 1; j < inputTensorDimNum; j++) {
            dim1Size *= inputTensorShape.GetDim(j);
        }

        // The first input fixes the row count; every later input has to agree with it.
        if (i == 0) {
            outputRow_ = dim0Size;
        } else if (outputRow_ != dim0Size) {
            OP_LOGE(context_, "op expects same sizes of 0,...,dim-1 dimensions for all tensors");
            return ge::GRAPH_FAILED;
        }

        if (dim1Size % srcEleUbBlock_ != 0) {
            isAllAlign_ = false;
        }
        if (dim1Size % (srcEleUbBlock_ / 2) != 0) {
            isHalfAlign_ = false;
        }
        outputCol_ += dim1Size;
    }

    // An empty output would make the later UB/block split divide by zero, so reject it here.
    OP_CHECK_IF(outputCol_ <= 0,
        OP_LOGE(context_, "output column size must be greater than 0"), return ge::GRAPH_FAILED);

    isOneConcat_ = (outputRow_ == 1) && (!isAllAlign_);
    return ge::GRAPH_SUCCESS;
}
// Chooses how many output rows (ubRowFactor_) and columns (ubColFactor_) are handled per UB pass.
// Two strategies exist: one for the non-aligned case (neither isAllAlign_ nor isOneConcat_),
// where the row limit is fixed first, and one for the aligned / single-row case, where the
// column limit is derived from the whole input buffer first.
void ChunkCatTiling::DoUbSplit()
{
    if (!isAllAlign_ && !isOneConcat_) {
        // Non-aligned path: fix the row limit, then size the columns from what is left of the buffer.
        uint32_t maxRows;
        if (isHalfAlign_) {
            maxRows = TRANS_BLOCK * HALF;
        } else {
            maxRows = TRANS_BLOCK * srcEleUbBlock_;
        }
        int64_t rowPasses = (outputRow_ + maxRows - 1) / maxRows;
        ubRowFactor_ = (outputRow_ + rowPasses - 1) / rowPasses;

        uint32_t maxCols = inUbSize_ / maxRows / srcDtypeSize_ - 2 * srcEleUbBlock_;
        int64_t colPasses = (outputCol_ + maxCols - 1) / maxCols;
        // Spread the columns evenly over the passes, round up to a whole block, clamp to the limit.
        ubColFactor_ = (outputCol_ + colPasses - 1) / colPasses;
        ubColFactor_ = (ubColFactor_ + srcEleUbBlock_ - 1) / srcEleUbBlock_ * srcEleUbBlock_;
        if (ubColFactor_ > maxCols) {
            ubColFactor_ = maxCols;
        }
        return;
    }

    // Aligned or single-row path: the column limit is the element capacity of the input buffer.
    uint32_t maxCols = inUbSize_ / srcDtypeSize_;
    if (isOneConcat_) {
        maxCols = maxCols - 32 * (srcEleUbBlock_ - 1);
    }
    int64_t colPasses = (outputCol_ + maxCols - 1) / maxCols;
    ubColFactor_ = (outputCol_ + colPasses - 1) / colPasses;
    ubColFactor_ = (ubColFactor_ + srcEleUbBlock_ - 1) / srcEleUbBlock_ * srcEleUbBlock_;
    if (ubColFactor_ > maxCols) {
        ubColFactor_ = maxCols;
    }

    // Several rows fit in one pass only when a single column pass covers the whole width.
    uint32_t maxRows = 1;
    if (colPasses == 1) {
        maxRows = maxCols / ubColFactor_;
    }
    int64_t rowPasses = (outputRow_ + maxRows - 1) / maxRows;
    ubRowFactor_ = (outputRow_ + rowPasses - 1) / rowPasses;
}
// Distributes the output over cores: rows are split first (in multiples of ubRowFactor_), then the
// cores left per row-block are used to split the columns (in multiples of ubColFactor_).
// Also records the size of the last (tail) block in each direction and the total cores used.
void ChunkCatTiling::DoBlockSplit()
{
    // Integer ceil-division and round-up-to-multiple; operand types are preserved as-is.
    const auto ceilDiv = [](auto total, auto step) { return (total + step - 1) / step; };
    const auto roundUp = [](auto value, auto unit) { return (value + unit - 1) / unit * unit; };

    // Row direction.
    blockRowFactor_ = ceilDiv(outputRow_, coreNum_);
    blockRowFactor_ = roundUp(blockRowFactor_, ubRowFactor_);
    blockRowNum_ = ceilDiv(outputRow_, blockRowFactor_);
    tailBlockRowFactor_ = outputRow_ - blockRowFactor_ * (blockRowNum_ - 1);
    if (tailBlockRowFactor_ == 0) {
        tailBlockRowFactor_ = blockRowFactor_;
    }

    // Column direction, using the cores not consumed by the row split.
    int64_t coresPerRowBlock = coreNum_ / blockRowNum_;
    blockColFactor_ = ceilDiv(outputCol_, coresPerRowBlock);
    blockColFactor_ = roundUp(blockColFactor_, ubColFactor_);
    blockColNum_ = ceilDiv(outputCol_, blockColFactor_);
    tailBlockColFactor_ = outputCol_ - blockColFactor_ * (blockColNum_ - 1);
    if (tailBlockColFactor_ == 0) {
        tailBlockColFactor_ = blockColFactor_;
    }

    usedCoreNum_ = blockRowNum_ * blockColNum_;
}
// Copies the computed tiling members into the tiling-data struct and publishes the workspace
// size, tiling key and block dimension on the tiling context.
// @param tilingData destination struct; the caller is expected to pass a valid, zeroed buffer.
void ChunkCatTiling::SetTilingData(ChunkCatTilingData* tilingData)
{
    // Alignment / concat mode flags.
    tilingData->isAllAlign = isAllAlign_;
    tilingData->isHalfAlign = isHalfAlign_;
    tilingData->isOneConcat = isOneConcat_;
    // Operator attributes and output geometry.
    tilingData->dim = dim_;
    tilingData->numChunk = numChunk_;
    tilingData->outputRow = outputRow_;
    tilingData->outputCol = outputCol_;
    // Per-core split.
    tilingData->blockRowFactor = blockRowFactor_;
    tilingData->blockColFactor = blockColFactor_;
    tilingData->tailBlockRowFactor = tailBlockRowFactor_;
    tilingData->tailBlockColFactor = tailBlockColFactor_;
    // Per-UB-pass split and buffer sizes.
    tilingData->ubRowFactor = ubRowFactor_;
    tilingData->ubColFactor = ubColFactor_;
    tilingData->inputNum = inputNum_;
    tilingData->inUbSize = inUbSize_;
    tilingData->outUbSize = outUbSize_;
    tilingData->blockRowNum = blockRowNum_;
    tilingData->blockColNum = blockColNum_;

    // GetWorkspaceSizes may return null; this function cannot report failure, so log the
    // problem and skip the write instead of dereferencing a null pointer.
    size_t* workspaceSize = context_->GetWorkspaceSizes(1);
    if (workspaceSize != nullptr) {
        *workspaceSize = sysWorkspaceSize_;
    } else {
        OP_LOGE(context_, "workspace sizes pointer is null");
    }
    context_->SetTilingKey(DEFAUL_TILING_KEY);
    context_->SetBlockDim(usedCoreNum_);
}
// Tiling entry point registered for ChunkCat: gathers platform and input information, computes
// the UB and block splits, then zeroes and fills the tiling-data buffer on the context.
// Returns GRAPH_FAILED as soon as any stage reports an error.
static ge::graphStatus Tiling4ChunkCat(gert::TilingContext* context)
{
    OP_LOGD(context, "ChunkCatTiling");
    ChunkCatTiling chunkCatTiling(context);

    // Stage 1: information gathering and validation; each stage must succeed before the next runs.
    const ge::graphStatus platformStatus = chunkCatTiling.GetPlatformInfo();
    OP_CHECK_IF(
        platformStatus != ge::GRAPH_SUCCESS, OP_LOGE(context, "GetPlatformInfo error"),
        return ge::GRAPH_FAILED);

    const ge::graphStatus inputStatus = chunkCatTiling.GetInputInfo();
    OP_CHECK_IF(
        inputStatus != ge::GRAPH_SUCCESS, OP_LOGE(context, "GetInputInfo error"),
        return ge::GRAPH_FAILED);

    const ge::graphStatus outputStatus = chunkCatTiling.CalculateOutputInfo();
    OP_CHECK_IF(
        outputStatus != ge::GRAPH_SUCCESS, OP_LOGE(context, "CalculateOutputInfo error"),
        return ge::GRAPH_FAILED);

    // Stage 2: the UB split has to run first because the block split rounds to the UB factors.
    chunkCatTiling.DoUbSplit();
    chunkCatTiling.DoBlockSplit();

    // Stage 3: clear the tiling buffer, then publish the results.
    ChunkCatTilingData* tilingBuffer = context->GetTilingData<ChunkCatTilingData>();
    OP_CHECK_NULL_WITH_CONTEXT(context, tilingBuffer);
    constexpr auto tilingBytes = sizeof(ChunkCatTilingData);
    const auto clearStatus = memset_s(tilingBuffer, tilingBytes, 0, tilingBytes);
    OP_CHECK_IF(
        clearStatus != EOK,
        OP_LOGE(context, "set tiling data error"), return ge::GRAPH_FAILED);
    chunkCatTiling.SetTilingData(tilingBuffer);

    OP_LOGD(context, "tiling data: %s", chunkCatTiling.TilingDataToString().c_str());
    return ge::GRAPH_SUCCESS;
}
// Tiling-parse hook registered for ChunkCat. It performs no work and ignores its context:
// it always returns GRAPH_SUCCESS.
static ge::graphStatus TilingPrepare4ChunkCat([[maybe_unused]] gert::TilingParseContext* context)
{
return ge::GRAPH_SUCCESS;
}
// Registers the ChunkCat tiling callbacks: Tiling4ChunkCat as the tiling function and
// TilingPrepare4ChunkCat as the tiling-parse function for ChunkCatCompileInfo.
IMPL_OP_OPTILING(ChunkCat)
.Tiling(Tiling4ChunkCat)
.TilingParse<ChunkCatCompileInfo>(TilingPrepare4ChunkCat);
}