* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file concat_tiling_arch35.cpp
* \brief concat tiling for ascendC impl
*/
#include "concat_tiling_arch35.h"
#include "log/log.h"
#include <cmath>
#include <sstream>
#include <cctype>
#include "op_common/op_host/util/shape_util.h"
#include "op_common/op_host/util/platform_util.h"
#include "op_api/op_util.h"
#include <algorithm>
using namespace std;
using namespace ge;
namespace optiling {
// Input/attribute indices and defaults (presumably for the Concat and Pack operators - confirm).
// None of these five are referenced in the portion of the file reviewed here.
constexpr size_t CONCAT_DIM_IDX = 0;
constexpr int64_t INVLID_CONCAT_DIM_IDX = static_cast<int64_t>(-1);
constexpr size_t PACK_ATTR_AXIS_IDX = 0;
constexpr size_t PACK_INPUT_IDX = 0;
constexpr int64_t PACK_AXIS_DEFAULT_VALUE = 1;
// Positions inside the merged two-element shape built by GetTensorListDim; DIM2 is also used as that
// vector's length.
constexpr int64_t DIM0 = 0;
constexpr int64_t DIM1 = 1;
constexpr int64_t DIM2 = 2;
// Divisor used by TilingBlock when comparing against half of the cores and when halving ubFactorDim1.
constexpr int64_t HALF = 2;
// Byte granularity used for the alignment checks and to derive everyBlockNumber.
constexpr int64_t BLOCK_SIZE = 32;
// Byte multiple tested by CheckDim1Align.
constexpr int64_t DIM1_ALIGN_THRESHOLD = 128;
// TilingUb divides the available UB by this when not every tensor is aligned.
constexpr int64_t BUFFER_NUM = 2;
// Byte count from which leastCopyNumber is derived (MIN_RESERVED_SIZE / dtypeSize).
constexpr int64_t MIN_RESERVED_SIZE = 2048;
// Workspace size written to the first workspace slot by TilingForConcatDSimt.
constexpr size_t SYSTEM_WORKSPACE_SIZE = 16777216;
// Bytes subtracted from ubSize in TilingUb before computing the available element count.
constexpr int64_t INDEX_USE_UB = 1024;
// Decimal place weights used by GenTilingKey to compose the tiling key.
constexpr int64_t TENS_DIGITS = 10;
constexpr int64_t HUNDREDS_DIGITS = 100;
constexpr int64_t THOUSANDS_DIGITS = 1000;
constexpr int64_t TEN_THOUSANDS_DIGITS = 10000;
// LEAST_ROWS caps the row count used to size ubFactorDim1 in TilingUb; LEAST_COLS is the byte floor that
// stops the halving loop in TilingBlock.
constexpr int64_t LEAST_ROWS = 64;
constexpr int64_t LEAST_COLS = 256;
// Not referenced in the portion of the file reviewed here.
constexpr bool ENABLE_DB = true;
// Element sizes in bytes compared against param.dtypeSize.
constexpr int64_t B64_BYTES = 8;
constexpr int64_t B32_BYTES = 4;
constexpr int64_t B16_BYTES = 2;
constexpr int64_t B8_BYTES = 1;
// Value compared with the result of ge::GetSizeByDataType to detect FP4 inputs.
// NOTE(review): presumably the "1000 + bit width" encoding used for sub-byte types - confirm.
constexpr int64_t B4_BYTES = 1004;
// Small integer literals used as factors and as indices into tiling candidates.
constexpr int64_t DIGIT_TWO = 2;
constexpr int64_t DIGIT_ONE = 1;
constexpr int64_t DIGIT_THREE = 3;
// Divisor applied to every dim1 length when FP4 inputs are re-expressed with dtypeSize B8_BYTES.
constexpr int64_t FP4_TO_B8_RATIO = 2;
// Alignment digit written into the tiling key when IsEnableGather holds.
constexpr int64_t GATHER_MODE = 3;
// Thresholds used by TilingForPureCopy to pick the core count and the column block count.
constexpr int64_t EVERY_CORE_THRESHOLD = 2048;
constexpr int64_t LEAST_BLOCK_BYTES = 512;
// Minimum per-tensor dim1 byte sizes required by IsEnablePureCopyTemplate, chosen by alignment/same-shape state.
constexpr int64_t PURE_COPY_COL_THRESHOLD_BASE = 256;
constexpr int64_t PURE_COPY_COL_THRESHOLD_ALIGN = 1024;
constexpr int64_t PURE_COPY_COL_THRESHOLD_NOALIGN = 2048;
// Byte size from which CalcLargeTensorNum counts a per-core piece as "large".
constexpr int64_t BLOCK_THRESHOLD = 49152;
// Minimum share of "large" pieces for IsEnablePureCopyTemplate to return true.
constexpr double LARGE_TENSOR_RATIO_THRESHOLD = 0.9;
// Not referenced in the portion of the file reviewed here.
constexpr int64_t PURE_COPY_NO_SPLIT_DIM1_TILINGKEY = 20001;
constexpr int64_t PURE_COPY_SPLIT_DIM1_TILINGKEY = 20002;
// Thresholds and tiling-key base used by IsEnableUsedSimt / TilingForConcatDSimt.
constexpr int64_t SIMT_PER_CORE_THRESHOLD = 65536;
constexpr int64_t SIMT_TILINGKEY_PREFIX = 30000;
constexpr int64_t SIMT_COMPARE_THRESHOLD = 1024;
// Not referenced in the portion of the file reviewed here.
constexpr int64_t SMALL_BAG = 128;
constexpr int64_t ALL_DATA_SMALL = 8192;
constexpr int32_t NUM_2 = 2;
constexpr int32_t NUM_3 = 3;
/**
 * Serialises tilingData into the raw tiling buffer owned by the context.
 *
 * @param context     tiling context providing the raw tiling buffer
 * @param tilingData  tiling structure exposing GetDataSize() and SaveToBuffer()
 * @return GRAPH_FAILED when the context has no raw tiling buffer or the buffer capacity is smaller than
 *         the tiling data; GRAPH_SUCCESS otherwise.
 */
template <typename T>
inline static ge::graphStatus ConcatSetTilingData(gert::TilingContext* context, T& tilingData)
{
    auto rawTilingData = context->GetRawTilingData();
    // Guard against a missing buffer instead of dereferencing a null pointer.
    if (rawTilingData == nullptr) {
        return ge::GRAPH_FAILED;
    }
    if (tilingData.GetDataSize() > rawTilingData->GetCapacity()) {
        return ge::GRAPH_FAILED;
    }
    tilingData.SaveToBuffer(rawTilingData->GetData(), rawTilingData->GetCapacity());
    rawTilingData->SetDataSize(tilingData.GetDataSize());
    return ge::GRAPH_SUCCESS;
}
/**
 * Logs one line per tensor with the stride stored in tilingData's stride list.
 */
template <typename T>
static inline void PrintTilingDataList(T &tilingData)
{
    auto strides = tilingData.get_strideList();
    int32_t tensorIdx = 0;
    while (tensorIdx < tilingData.get_tensorNum()) {
        OP_LOGI("[Concat list]","tensor: %d, stride: %ld", tensorIdx, strides[tensorIdx]);
        ++tensorIdx;
    }
}
/**
 * Logs the main fields of tilingData together with the chosen tiling key and core count, then logs the
 * per-tensor stride list through PrintTilingDataList.
 *
 * NOTE: the format string below is continued across lines with a trailing backslash, so any leading
 * whitespace on the continuation lines becomes part of the logged text.
 */
template <typename T>
static inline void PrintTilingData(T& tilingData, int64_t tilingKey, int64_t usedCoreNum)
{
OP_LOGI(
"[Concat]",
"ubSplitDim1: %d, dim: %d, blockFactor: %ld,tailBlockFactor: %ld,\
ubFactorDim0: %d,ubFactorDim1: %d,tailUbFactorDim0: %d, tailUbFactorDim1: %d,uoDim0: %ld,uoDim1: %ld,\
tensorNum: %d,catDim1: %ld,isnon: %d,tilingKey: %ld,usedCoreNum: %ld",
tilingData.get_ubSplitDim1(), tilingData.get_dim(), tilingData.get_blockFactor(),
tilingData.get_tailBlockFactor(), tilingData.get_ubFactorDim0(), tilingData.get_ubFactorDim1(),
tilingData.get_tailUbFactorDim0(), tilingData.get_tailUbFactorDim1(), tilingData.get_uoDim0(),
tilingData.get_uoDim1(), tilingData.get_tensorNum(), tilingData.get_catDim1(), tilingData.get_isNonContiguous(), tilingKey, usedCoreNum);
PrintTilingDataList(tilingData);
}
/**
 * Appends the shape of every instance of the dynamic input inputIdx to param.tensorList.
 *
 * @return GRAPH_SUCCESS, or the failure produced by the null checks on the node/instance info.
 */
inline static ge::graphStatus GetTensorList(
    const gert::TilingContext* context, ConcatTilingParam& param, int64_t inputIdx)
{
    auto computeNodeInfo = context->GetComputeNodeInfo();
    OP_CHECK_NULL_WITH_CONTEXT(context, computeNodeInfo);
    auto anchorInstanceInfo = computeNodeInfo->GetInputInstanceInfo(inputIdx);
    OP_CHECK_NULL_WITH_CONTEXT(context, anchorInstanceInfo);

    const uint32_t instanceCount = anchorInstanceInfo->GetInstanceNum();
    for (uint32_t instance = 0; instance < instanceCount; ++instance) {
        gert::Shape shape = GetShapeByAll(context, param.isNonContiguous, inputIdx, instance);
        const size_t rank = shape.GetDimNum();
        vector<int64_t> dims;
        dims.reserve(rank);
        for (size_t axis = 0; axis < rank; ++axis) {
            dims.push_back(shape.GetDim(axis));
        }
        param.tensorList.push_back(dims);
    }
    return ge::GRAPH_SUCCESS;
}
/**
 * Multiplies the dimensions tensorSize[startIdx] .. tensorSize[endIdx - 1].
 *
 * The requested range is clamped to the valid index range of tensorSize, so an out-of-range startIdx or
 * endIdx can no longer read outside the vector. An empty range yields 1.
 *
 * @param tensorSize  dimension list
 * @param startIdx    first index to include
 * @param endIdx      one past the last index to include
 * @return product of the selected dimensions
 */
inline static int64_t MergeDim(const vector<int64_t>& tensorSize, int64_t startIdx, int64_t endIdx)
{
    const int64_t first = std::max<int64_t>(startIdx, 0);
    const int64_t last = std::min<int64_t>(endIdx, static_cast<int64_t>(tensorSize.size()));
    int64_t ans = 1;
    for (int64_t i = first; i < last; i++) {
        ans *= tensorSize[i];
    }
    return ans;
}
/**
 * Collapses every shape in param.tensorList into two dimensions around param.dim: the product of the
 * dimensions before param.dim and the product of the dimensions from param.dim onwards. The pairs are
 * appended to param.mergeTensorList, after which both components of every entry of mergeTensorList are
 * appended to param.tensorListDim0 / param.tensorListDim1.
 */
inline static void GetTensorListDim(ConcatTilingParam& param)
{
    for (const auto& shape : param.tensorList) {
        const int64_t leading = MergeDim(shape, 0, param.dim);
        const int64_t trailing = MergeDim(shape, param.dim, shape.size());
        param.mergeTensorList.push_back(vector<int64_t>{leading, trailing});
    }
    for (const auto& merged : param.mergeTensorList) {
        param.tensorListDim0.push_back(merged[0]);
        param.tensorListDim1.push_back(merged[1]);
    }
}
/**
 * Fills param.sameShapeTensorDim1. With identical input shapes it is the first merged dim1 length;
 * otherwise it is the product of the first tensor's dimensions after param.dim. Nothing is written when
 * no dim1 lengths are recorded.
 */
inline static void GetTensorSameDim1(ConcatTilingParam& param)
{
    if (param.tensorListDim1.empty()) {
        return;
    }
    if (param.inputShapeSame == 1) {
        param.sameShapeTensorDim1 = param.tensorListDim1[0];
        return;
    }
    const auto& firstShape = param.tensorList[0];
    param.sameShapeTensorDim1 = MergeDim(firstShape, param.dim + 1, firstShape.size());
}
/**
 * Returns the sum of all elements of vec (0 for an empty vector).
 */
inline static int64_t CalcSum(const vector<int64_t>& vec)
{
    int64_t total = 0;
    for (auto it = vec.cbegin(); it != vec.cend(); ++it) {
        total += *it;
    }
    return total;
}
/**
 * Derives the merged output shape: catDim0 is the first merged dim0 (0 when there are no tensors),
 * catDim1 is the sum of all merged dim1 lengths, and isEmpty is set when their product is zero.
 */
inline static void GenerateOutputShape(ConcatTilingParam& param)
{
    param.catDim0 = param.tensorListDim0.empty() ? 0 : param.tensorListDim0[0];
    param.catDim1 = CalcSum(param.tensorListDim1);
    param.isEmpty = (param.catDim0 * param.catDim1) == 0;
}
/**
 * Counts the tensors whose dim1 byte length (dim1 * dtypeSize) is not a multiple of BLOCK_SIZE and
 * stores the count in param.noAlignTensorNum.
 */
inline static void CalcNoAlignTensorNum(ConcatTilingParam& param, int64_t dtypeSize)
{
    int64_t unalignedCount = 0;
    for (const auto& dim1 : param.tensorListDim1) {
        const bool aligned = (dim1 * dtypeSize % BLOCK_SIZE == 0);
        if (!aligned) {
            ++unalignedCount;
        }
    }
    param.noAlignTensorNum = unalignedCount;
    OP_LOGD("[Concat]", "noAlignTensorNum: %ld", param.noAlignTensorNum);
}
/**
 * Returns true when the dim1 byte length of every merged tensor is a multiple of BLOCK_SIZE
 * (also true for an empty list).
 */
inline static bool CheckCatDimAlign(vector<vector<int64_t>>& mergeTensorList, int64_t dtypeSize)
{
    return std::all_of(
        mergeTensorList.begin(), mergeTensorList.end(),
        [dtypeSize](const vector<int64_t>& merged) { return merged[DIM1] * dtypeSize % BLOCK_SIZE == 0; });
}
/**
 * Returns true when the dim1 byte length of every merged tensor is a multiple of DIM1_ALIGN_THRESHOLD
 * (also true for an empty list).
 */
inline static bool CheckDim1Align(vector<vector<int64_t>>& mergeTensorList, int64_t dtypeSize)
{
    return std::all_of(
        mergeTensorList.begin(), mergeTensorList.end(), [dtypeSize](const vector<int64_t>& merged) {
            return merged[DIM1] * dtypeSize % DIM1_ALIGN_THRESHOLD == 0;
        });
}
/**
 * Returns true when all shapes in tensorList are equal to each other (trivially true for zero or one
 * entries).
 */
inline static bool CheckInputShapeSame(vector<vector<int64_t>>& tensorList)
{
    const auto firstMismatch = std::adjacent_find(
        tensorList.begin(), tensorList.end(),
        [](const vector<int64_t>& lhs, const vector<int64_t>& rhs) { return lhs != rhs; });
    return firstMismatch == tensorList.end();
}
/**
 * Verifies that every merged dim1 length is divisible by FP4_TO_B8_RATIO; logs the first offending
 * length and returns GRAPH_FAILED otherwise.
 */
inline static ge::graphStatus CheckFP4Dim1Even(const ConcatTilingParam& param)
{
    for (auto it = param.tensorListDim1.begin(); it != param.tensorListDim1.end(); ++it) {
        const auto& tensorSize = *it;
        if (tensorSize % FP4_TO_B8_RATIO == 0) {
            continue;
        }
        OP_LOGE("[Concat]", "FP4 Dim1 must be even, but got %ld", tensorSize);
        return ge::GRAPH_FAILED;
    }
    return ge::GRAPH_SUCCESS;
}
/**
 * For FP4 inputs, re-expresses all dim1 quantities with dtypeSize B8_BYTES by dividing them by
 * FP4_TO_B8_RATIO. Does nothing when param.isFP4Type is not set.
 *
 * @return the failure from CheckFP4Dim1Even when a dim1 length is not divisible; GRAPH_SUCCESS otherwise.
 */
static ge::graphStatus ConvertFP4DimsToB8(ConcatTilingParam& param)
{
    if (!param.isFP4Type) {
        return ge::GRAPH_SUCCESS;
    }
    const ge::graphStatus evenCheck = CheckFP4Dim1Even(param);
    if (evenCheck != ge::GRAPH_SUCCESS) {
        return evenCheck;
    }

    // Scalar dim1 quantities.
    param.catDim1 /= FP4_TO_B8_RATIO;
    if (param.inputShapeSame == 1) {
        param.sameShapeTensorDim1 /= FP4_TO_B8_RATIO;
    }

    // Per-tensor dim1 lengths, in both the flat list and the merged shapes.
    for (auto& dim1 : param.tensorListDim1) {
        dim1 /= FP4_TO_B8_RATIO;
    }
    for (auto& merged : param.mergeTensorList) {
        merged[1] /= FP4_TO_B8_RATIO;
    }

    // Strides are only rescaled for non-contiguous inputs.
    if (param.isNonContiguous) {
        int16_t tensorIdx = 0;
        while (tensorIdx < param.tensorNum) {
            param.strideList[tensorIdx] /= FP4_TO_B8_RATIO;
            ++tensorIdx;
        }
    }

    param.dtypeSize = B8_BYTES;
    param.orgDtypeSize = B8_BYTES;
    return ge::GRAPH_SUCCESS;
}
/**
 * Fills the basic tiling parameters in param from the compile info and from param.tensorList.
 *
 * Expects param.tensorList, param.dim and param.dtypeSize to be populated already (they are read here).
 * The statement order matters: the merged shapes are built first, FP4 inputs are then rescaled, and the
 * alignment flags and block counts are computed from the (possibly rescaled) values.
 *
 * @return GRAPH_FAILED when the compile info is missing, the FP4 conversion fails or dtypeSize is not
 *         positive; GRAPH_SUCCESS otherwise.
 */
inline static ge::graphStatus CalcBaseTilingParam(const gert::TilingContext* context, ConcatTilingParam& param)
{
auto compileInfo = reinterpret_cast<const ConcatDCompileInfo*>(context->GetCompileInfo());
OP_CHECK_NULL_WITH_CONTEXT(context, compileInfo);
// The core count is capped at TILING_ARRAY_LENGTH.
// NOTE(review): the warning text hard-codes "72"; confirm it matches TILING_ARRAY_LENGTH.
param.totalCoreNum = min(static_cast<int64_t>(compileInfo->totalCoreNum), TILING_ARRAY_LENGTH);
if (compileInfo->totalCoreNum > TILING_ARRAY_LENGTH) {
OP_LOGW("[Concat]", "Currently, more than 72 cores are not supported, Only 72 cores are used.");
}
param.ubSize = compileInfo->ubSize;
param.tensorNum = param.tensorList.size();
param.gatherThreshold = compileInfo->vectorLen / DIGIT_TWO;
// Collapse every input into (dim0, dim1) and derive the merged output shape.
GetTensorListDim(param);
GenerateOutputShape(param);
// Remember the original dtype size before any re-interpretation below.
param.orgDtypeSize = param.dtypeSize;
param.inputShapeSame = CheckInputShapeSame(param.mergeTensorList) ? 1 : 0;
GetTensorSameDim1(param);
param.isFP4Type = (param.orgDtypeSize == B4_BYTES) ? 1 : 0;
OP_CHECK_IF(
ConvertFP4DimsToB8(param) != ge::GRAPH_SUCCESS,
OP_LOGE(context->GetNodeName(), "ConvertFP4DimsToB8 failed."), return ge::GRAPH_FAILED);
param.isAllTensorAlign = CheckCatDimAlign(param.mergeTensorList, param.dtypeSize) ? 1 : 0;
param.isDim1AllAlign = CheckDim1Align(param.mergeTensorList, param.dtypeSize) ? 1 : 0;
// dtypeSize is used as a divisor right below, so it must be positive.
OP_CHECK_IF(
param.dtypeSize <= 0,
OP_LOGE(
context->GetNodeName(), "param.dtypeSize must be greater than 0, param.dtypeSize: %ld", param.dtypeSize),
return ge::GRAPH_FAILED);
param.leastCopyNumber = MIN_RESERVED_SIZE / param.dtypeSize;
param.everyBlockNumber = BLOCK_SIZE / param.dtypeSize;
CalcNoAlignTensorNum(param, param.dtypeSize);
return ge::GRAPH_SUCCESS;
}
/**
 * Reads the first element of the required input tensor at dimIdx as type T and stores it in param.dim.
 *
 * @return GRAPH_SUCCESS, or the failure produced by the null checks on the tensor or its data.
 */
template <typename T>
inline static ge::graphStatus GetConcatDimInput(
    const gert::TilingContext* context, ConcatTilingParam& param, int64_t dimIdx)
{
    auto concatDimTensor = context->GetRequiredInputTensor(dimIdx);
    OP_CHECK_NULL_WITH_CONTEXT(context, concatDimTensor);

    const T* concatDimValPtr = concatDimTensor->GetData<T>();
    OP_CHECK_NULL_WITH_CONTEXT(context, concatDimValPtr);

    param.dim = *concatDimValPtr;
    return ge::GRAPH_SUCCESS;
}
/**
 * Returns true when dtype is NOT one of the data types listed as supported below.
 *
 * The lookup table is a function-local static so it is built once instead of on every call.
 */
inline static bool IsInvalidType(const DataType dtype)
{
    static const std::set<ge::DataType> supportedDtype = {
        ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_UINT8, ge::DT_INT8, ge::DT_UINT16, ge::DT_INT16,
        ge::DT_UINT32, ge::DT_INT32, ge::DT_UINT64, ge::DT_INT64, ge::DT_BOOL, ge::DT_DOUBLE, ge::DT_COMPLEX64,
        ge::DT_FLOAT8_E4M3FN, ge::DT_FLOAT8_E5M2, ge::DT_HIFLOAT8, ge::DT_FLOAT8_E8M0, ge::DT_FLOAT4_E1M2,
        ge::DT_FLOAT4_E2M1};
    return supportedDtype.count(dtype) == 0;
}
/**
 * Stores in param.dtypeSize the value ge::GetSizeByDataType reports for the data type of the first
 * instance of the dynamic input inputIndex.
 *
 * @return GRAPH_SUCCESS, or the failure produced by the null check on the input descriptor.
 */
inline static ge::graphStatus GetDtypeSize(
    const gert::TilingContext* context, ConcatTilingParam& param, size_t inputIndex)
{
    auto inputDesc = context->GetDynamicInputDesc(inputIndex, 0);
    OP_CHECK_NULL_WITH_CONTEXT(context, inputDesc);
    param.dtypeSize = ge::GetSizeByDataType(inputDesc->GetDataType());
    return ge::GRAPH_SUCCESS;
}
/**
 * Writes num back-to-back copies of src into dst, starting at index 0.
 *
 * Writing stops once TILING_ARRAY_LENGTH elements have been written or dst is full, whichever comes
 * first, so the function never writes past the end of dst. Elements of dst beyond the written range
 * are left untouched.
 *
 * @param dst  destination vector; must already be sized by the caller
 * @param src  sequence to repeat
 * @param num  number of repetitions
 */
template <typename T>
inline static void DupTensor(vector<T>& dst, const vector<T>& src, int64_t num)
{
    if (src.empty()) {
        return;
    }
    const int64_t writeLimit =
        std::min(static_cast<int64_t>(dst.size()), static_cast<int64_t>(TILING_ARRAY_LENGTH));
    int64_t index = 0;
    for (int64_t i = 0; i < num && index < writeLimit; i++) {
        for (int64_t j = 0; j < static_cast<int64_t>(src.size()) && index < writeLimit; j++) {
            dst[index] = src[j];
            index++;
        }
    }
}
/**
 * True when not all tensors are block aligned, all inputs share one shape, and the shared dim1 byte
 * length is below param.gatherThreshold.
 */
inline static bool IsEnableGather(const ConcatTilingParam& param)
{
    const bool unalignedSameShape = (param.isAllTensorAlign == 0 && param.inputShapeSame == 1);
    return unalignedSameShape && param.sameShapeTensorDim1 * param.dtypeSize < param.gatherThreshold;
}
/**
 * True when not all tensors are block aligned and the input shapes differ.
 */
inline static bool IsEnableScatter(const ConcatTilingParam& param)
{
    return param.isAllTensorAlign == 0 && param.inputShapeSame == 0;
}
/**
 * True when all tensors are block aligned, the concat dimension is 0 and the output is not empty.
 */
inline static bool IsEnableRowConcat(const ConcatTilingParam& param)
{
    return param.isAllTensorAlign == 1 && param.dim == 0 && !param.isEmpty;
}
/**
 * Accumulates statistics for one column piece of width tensorCol that is processed by rowsUsedCoreNum
 * row cores: (rowsUsedCoreNum - 1) pieces are counted as large when the regular row factor reaches
 * BLOCK_THRESHOLD bytes, one more when the tail row factor does; totalInputNum always grows by
 * rowsUsedCoreNum.
 */
inline static void CalcLargeTensorNum(
    const ConcatTilingParam& param, int64_t tensorCol, int64_t rowsUsedCoreNum, int64_t& largeInputNum,
    int64_t& totalInputNum)
{
    const bool regularIsLarge = tensorCol * param.ubFactorDim0 * param.dtypeSize >= BLOCK_THRESHOLD;
    const bool tailIsLarge = tensorCol * param.tailUbFactorDim0 * param.dtypeSize >= BLOCK_THRESHOLD;
    if (regularIsLarge) {
        largeInputNum += (rowsUsedCoreNum - 1);
    }
    if (tailIsLarge) {
        largeInputNum += 1;
    }
    totalInputNum += rowsUsedCoreNum;
}
/**
 * Decides whether the pure-copy template should be used.
 *
 * Two conditions must hold:
 *  1. every tensor's dim1 byte length reaches a threshold selected from the alignment / same-shape flags;
 *  2. the share of per-core pieces counted as "large" by CalcLargeTensorNum is at least
 *     LARGE_TENSOR_RATIO_THRESHOLD.
 *
 * @param rowsUsedCoreNum  number of cores along dim0 (used when blocks are split along dim1)
 * @param colsUsedCoreNum  number of cores along dim1 (used when blocks are split along dim1)
 */
inline static bool IsEnablePureCopyTemplate(
const ConcatTilingParam& param, int64_t rowsUsedCoreNum, int64_t colsUsedCoreNum)
{
// Pick the minimum dim1 byte length: smallest when both flags are set, largest when neither is.
int64_t threshold = 0;
if (param.isDim1AllAlign == 1 && param.inputShapeSame == 1) {
threshold = PURE_COPY_COL_THRESHOLD_BASE;
} else if (param.isDim1AllAlign == 1 || param.inputShapeSame == 1) {
threshold = PURE_COPY_COL_THRESHOLD_ALIGN;
} else {
threshold = PURE_COPY_COL_THRESHOLD_NOALIGN;
}
for (const auto& tensorSize : param.tensorListDim1) {
if (tensorSize * param.dtypeSize < threshold) {
return false;
}
}
int64_t totalInputNum = 0;
int64_t largeInputNum = 0;
if (param.blockSplitAxis == 0) {
// Blocks split along dim0: every core sees the full dim1 of every tensor.
for (const auto& tensorSize : param.tensorListDim1) {
CalcLargeTensorNum(param, tensorSize, param.usedCoreNum, largeInputNum, totalInputNum);
}
} else {
// Blocks split along dim1: walk the tensor range assigned to each column core.
for (int64_t i = 0; i < colsUsedCoreNum; i++) {
// The whole range lies inside one tensor.
if (param.startTensorIdx[i] == param.endTensorIdx[i]) {
int64_t tensorCol = param.endTensorOffset[i] - param.startTensorOffset[i];
CalcLargeTensorNum(param, tensorCol, rowsUsedCoreNum, largeInputNum, totalInputNum);
continue;
}
// First tensor: from the start offset to its end.
int16_t startIdx = param.startTensorIdx[i];
CalcLargeTensorNum(
param, param.tensorListDim1[startIdx] - param.startTensorOffset[i], rowsUsedCoreNum,
largeInputNum, totalInputNum);
// Tensors strictly between the first and the last are covered completely.
for (int16_t k = param.startTensorIdx[i] + 1; k < param.endTensorIdx[i]; k++) {
CalcLargeTensorNum(param, param.tensorListDim1[k], rowsUsedCoreNum, largeInputNum, totalInputNum);
}
// Last tensor: from its beginning up to the end offset.
int64_t lastTensorCol = param.endTensorOffset[i];
CalcLargeTensorNum(param, lastTensorCol, rowsUsedCoreNum, largeInputNum, totalInputNum);
}
}
// Avoid dividing by zero when nothing was counted.
if (totalInputNum <= 0) {
return false;
}
double largeRatio = static_cast<double>(largeInputNum) / static_cast<double>(totalInputNum);
if (largeRatio >= LARGE_TENSOR_RATIO_THRESHOLD) {
return true;
}
return false;
}
/**
 * Composes param.tilingKey from decimal digits:
 *   ones          - dtype size (the original dtype size when IsEnableScatter holds)
 *   tens          - 1 when the input shapes are identical, 2 otherwise
 *   hundreds      - 1 when all tensors are aligned, 2 otherwise, GATHER_MODE when IsEnableGather holds
 *   thousands     - always DIGIT_TWO
 *   ten-thousands - 1 when blockSplitAxis is 0, 0 otherwise
 * An empty output gets tiling key 0.
 */
inline static void GenTilingKey(ConcatTilingParam& param)
{
    if (param.isEmpty) {
        param.tilingKey = 0;
        return;
    }

    int64_t dtypeDigit = param.dtypeSize;
    if (IsEnableScatter(param)) {
        dtypeDigit = param.orgDtypeSize;
    }

    int64_t alignDigit = (param.isAllTensorAlign == 1) ? 1 : 2;
    if (IsEnableGather(param)) {
        alignDigit = GATHER_MODE;
    }

    const int64_t shapeDigit = (param.inputShapeSame == 1) ? 1 : 2;
    const int64_t firstDimDigit = DIGIT_TWO;
    const int64_t specTilingDigit = (param.blockSplitAxis == 0) ? 1 : 0;

    param.tilingKey = dtypeDigit + shapeDigit * TENS_DIGITS + alignDigit * HUNDREDS_DIGITS +
                      firstDimDigit * THOUSANDS_DIGITS + specTilingDigit * TEN_THOUSANDS_DIGITS;
}
/**
 * Validates dim against the rank of the first instance of input inputIdx and normalises it.
 *
 * @param dim        in/out: accepted range is [-rank, rank - 1]; a negative value is shifted by rank
 * @param strideDim  out: set to the normalised dim minus one
 * @return GRAPH_FAILED when dim is outside the accepted range; GRAPH_SUCCESS otherwise.
 */
inline static ge::graphStatus IsDimValid(const gert::TilingContext* context, int64_t& dim, int64_t inputIdx, bool isNonContiguous, int64_t& strideDim)
{
    gert::Shape firstShape = GetShapeByAll(context, isNonContiguous, inputIdx, 0);
    const int64_t rank = static_cast<int64_t>(firstShape.GetDimNum());
    const int64_t lowest = rank * static_cast<int64_t>(-1);
    const int64_t highest = rank - 1;
    if (dim < lowest || dim > highest) {
        return ge::GRAPH_FAILED;
    }
    dim = (dim < 0) ? dim + rank : dim;
    strideDim = dim - 1;
    return ge::GRAPH_SUCCESS;
}
/**
 * Checks that all shapes in tensorList have the same rank and agree with the first shape on every
 * dimension except realDim. An empty list is accepted.
 *
 * @return GRAPH_FAILED (with an error log) on the first mismatch; GRAPH_SUCCESS otherwise.
 */
inline static ge::graphStatus IsShapeValid(
    const gert::TilingContext* context, vector<vector<int64_t>>& tensorList, int64_t realDim)
{
    if (tensorList.empty()) {
        return ge::GRAPH_SUCCESS;
    }
    const auto& shape0 = tensorList[0];
    const int64_t dimSize = shape0.size();
    for (const auto& tensorSize : tensorList) {
        const int64_t curDimSize = tensorSize.size();
        OP_CHECK_IF(
            curDimSize != dimSize, OP_LOGE(context->GetNodeName(), "dimSize of input tensor should be equal."),
            return ge::GRAPH_FAILED);
        for (int64_t j = 0; j < dimSize; j++) {
            // The concat dimension itself is allowed to differ.
            if (realDim != j) {
                OP_CHECK_IF(
                    shape0[j] != tensorSize[j],
                    OP_LOGE(context->GetNodeName(), "dim %ld of input tensor should be equal.", j),
                    return ge::GRAPH_FAILED);
            }
        }
    }
    return ge::GRAPH_SUCCESS;
}
/**
 * UB tiling for the case where a full merged row (catDim1 elements) fits into UB: dim1 is not split and
 * as many rows as fit are processed per iteration.
 *
 * @param maxAvaliableUb  number of elements available in UB
 * @return GRAPH_FAILED when catDim1 or the resulting row factor is not positive; GRAPH_SUCCESS otherwise.
 */
static ge::graphStatus TilingUbForNosplitDim1(
    gert::TilingContext* context, int64_t maxAvaliableUb, ConcatTilingParam& param)
{
    OP_CHECK_IF(
        param.catDim1 <= 0,
        OP_LOGE(context->GetNodeName(), "param.catDim1 must be greater than 0, param.catDim1: %ld", param.catDim1),
        return ge::GRAPH_FAILED);
    param.ubFactorDim0 = min(maxAvaliableUb / param.catDim1, param.catDim0);
    OP_CHECK_IF(
        param.ubFactorDim0 <= 0,
        OP_LOGE(context->GetNodeName(), "ubFactorDim0 must be greater than 0, ubFactorDim0: %ld", param.ubFactorDim0),
        return ge::GRAPH_FAILED);

    // dim1 is kept whole.
    param.ubSplitDim1 = 0;
    param.uoDim1 = 1;
    param.ubFactorDim1 = param.catDim1;
    param.tailUbFactorDim1 = param.catDim1;

    // dim0 is processed in chunks of ubFactorDim0 rows; the tail chunk is full-sized when it divides evenly.
    param.uoDim0 = (param.catDim0 + param.ubFactorDim0 - 1) / param.ubFactorDim0;
    const auto rowRemainder = param.catDim0 % param.ubFactorDim0;
    param.tailUbFactorDim0 = (rowRemainder == 0) ? param.ubFactorDim0 : rowRemainder;
    return ge::GRAPH_SUCCESS;
}
/**
 * UB tiling for the case where a merged row does not fit into UB, so dim1 is split as well.
 *
 * For unaligned inputs that all share one shape, the dim1 factor is recomputed from the full UB budget
 * and rounded down to a multiple of the alignment factor; otherwise the storage-alignment reserve is
 * subtracted from the budget and maxDim1Factor is used.
 *
 * Both divisors used in the same-shape branch (the row count and the alignment factor) are validated
 * before use, so a zero value yields GRAPH_FAILED instead of a division by zero.
 *
 * @param maxAvaliableUb    number of elements available in UB
 * @param storageAlignUsed  elements reserved for storage alignment
 * @param maxDim1Factor     dim1 factor proposed by the caller
 * @return GRAPH_FAILED when any factor or divisor is not positive; GRAPH_SUCCESS otherwise.
 */
static ge::graphStatus TilingUbForSplitDim1(
    gert::TilingContext* context, int64_t maxAvaliableUb, int64_t storageAlignUsed, int64_t maxDim1Factor,
    ConcatTilingParam& param)
{
    int64_t realFactorDim1 = maxDim1Factor;
    if (param.isAllTensorAlign == 0 && param.inputShapeSame == 1) {
        // catDim0 feeds the divisor below; a non-positive value would fail the ubFactorDim0 check later anyway.
        OP_CHECK_IF(
            param.catDim0 <= 0,
            OP_LOGE(context->GetNodeName(), "param.catDim0 must be greater than 0, param.catDim0: %ld", param.catDim0),
            return ge::GRAPH_FAILED);
        realFactorDim1 = maxAvaliableUb / std::min(LEAST_ROWS, param.catDim0);
        OP_CHECK_IF(
            param.everyBlockNumber <= 0,
            OP_LOGE(
                context->GetNodeName(), "everyBlockNumber must be greater than 0, everyBlockNumber: %ld",
                param.everyBlockNumber),
            return ge::GRAPH_FAILED);
        int64_t alignFactorDim1 = param.everyBlockNumber;
        // inputShapeSame == 1 is already guaranteed by the enclosing condition.
        if (param.sameShapeTensorDim1 * param.dtypeSize <= param.gatherThreshold) {
            alignFactorDim1 = param.sameShapeTensorDim1;
        }
        // Same guard as in TilingBlock: sameShapeTensorDim1 may be the divisor.
        OP_CHECK_IF(
            alignFactorDim1 <= 0,
            OP_LOGE(
                context->GetNodeName(), "alignFactorDim1 must be greater than 0, alignFactorDim1: %ld",
                alignFactorDim1),
            return ge::GRAPH_FAILED);
        realFactorDim1 = realFactorDim1 / alignFactorDim1 * alignFactorDim1;
    } else {
        maxAvaliableUb -= storageAlignUsed;
    }
    param.ubFactorDim1 = min(realFactorDim1, param.catDim1);
    OP_CHECK_IF(
        param.ubFactorDim1 <= 0,
        OP_LOGE(
            context->GetNodeName(), "param.ubFactorDim1 must be greater than 0, param.ubFactorDim1: %ld",
            param.ubFactorDim1),
        return ge::GRAPH_FAILED);
    param.ubFactorDim0 = min(maxAvaliableUb / param.ubFactorDim1, param.catDim0);
    OP_CHECK_IF(
        param.ubFactorDim0 <= 0,
        OP_LOGE(
            context->GetNodeName(), "param.ubFactorDim0 must be greater than 0, param.ubFactorDim0: %ld",
            param.ubFactorDim0),
        return ge::GRAPH_FAILED);
    param.uoDim1 = (param.catDim1 + param.ubFactorDim1 - 1) / param.ubFactorDim1;
    param.uoDim0 = (param.catDim0 + param.ubFactorDim0 - 1) / param.ubFactorDim0;
    // Tail chunks are full-sized when the extent divides evenly.
    param.tailUbFactorDim0 = param.catDim0 % param.ubFactorDim0;
    if (param.tailUbFactorDim0 == 0) {
        param.tailUbFactorDim0 = param.ubFactorDim0;
    }
    param.tailUbFactorDim1 = param.catDim1 % param.ubFactorDim1;
    if (param.tailUbFactorDim1 == 0) {
        param.tailUbFactorDim1 = param.ubFactorDim1;
    }
    param.ubSplitDim1 = 1;
    return ge::GRAPH_SUCCESS;
}
/**
 * Computes the UB tiling: how many rows (dim0) and columns (dim1) of the merged output are processed
 * per iteration. Delegates to TilingUbForNosplitDim1 when a whole row fits below the computed dim1
 * factor, and to TilingUbForSplitDim1 otherwise.
 *
 * catDim0 is validated up front because min(LEAST_ROWS, catDim0) is used as a divisor; a non-positive
 * value would fail the row-factor checks of both delegates anyway.
 *
 * @return GRAPH_FAILED when dtypeSize, catDim0 or everyBlockNumber is not positive, or when the delegate
 *         fails; GRAPH_SUCCESS otherwise.
 */
static ge::graphStatus TilingUb(gert::TilingContext* context, ConcatTilingParam& param)
{
    OP_CHECK_IF(
        param.dtypeSize <= 0,
        OP_LOGE(
            context->GetNodeName(), "param.dtypeSize must be greater than 0, param.dtypeSize: %ld", param.dtypeSize),
        return ge::GRAPH_FAILED);
    OP_CHECK_IF(
        param.catDim0 <= 0,
        OP_LOGE(context->GetNodeName(), "param.catDim0 must be greater than 0, param.catDim0: %ld", param.catDim0),
        return ge::GRAPH_FAILED);
    // Element budget: UB bytes minus the reserved index area, divided by the element size.
    int64_t maxAvaliableUb = (param.ubSize - INDEX_USE_UB) / param.dtypeSize;
    if (param.isAllTensorAlign == 0) {
        maxAvaliableUb = maxAvaliableUb / BUFFER_NUM;
        maxAvaliableUb = std::min(maxAvaliableUb, static_cast<int64_t>(std::numeric_limits<uint16_t>::max()));
    }
    param.bufferSize = maxAvaliableUb;
    int64_t realFactorDim1 = maxAvaliableUb / std::min(LEAST_ROWS, param.catDim0);
    int64_t storageAlignUsed = 0;
    if (param.isAllTensorAlign == 0) {
        // Reserve one block per unaligned tensor plus one extra block.
        storageAlignUsed = param.everyBlockNumber * (param.noAlignTensorNum + 1);
        realFactorDim1 = (maxAvaliableUb - storageAlignUsed) / std::min(LEAST_ROWS, param.catDim0);
    }
    OP_CHECK_IF(
        param.everyBlockNumber <= 0,
        OP_LOGE(
            context->GetNodeName(), "param.everyBlockNumber must be greater than 0, param.everyBlockNumber: %ld",
            param.everyBlockNumber),
        return ge::GRAPH_FAILED);
    // Round the dim1 factor down to a whole number of blocks.
    realFactorDim1 = realFactorDim1 / param.everyBlockNumber * param.everyBlockNumber;
    if (param.catDim1 < realFactorDim1) {
        maxAvaliableUb = maxAvaliableUb - storageAlignUsed;
        OP_CHECK_IF(
            TilingUbForNosplitDim1(context, maxAvaliableUb, param) != ge::GRAPH_SUCCESS,
            OP_LOGE(context->GetNodeName(), "TilingUbForNosplitDim1 failed"), return ge::GRAPH_FAILED);
    } else {
        OP_CHECK_IF(
            TilingUbForSplitDim1(context, maxAvaliableUb, storageAlignUsed, realFactorDim1, param) != ge::GRAPH_SUCCESS,
            OP_LOGE(context->GetNodeName(), "TilingUbForSplitDim1 failed"), return ge::GRAPH_FAILED);
    }
    return ge::GRAPH_SUCCESS;
}
/**
 * Distributes the UB iterations over the cores.
 *
 * When there are more dim0 iterations than half of the cores, the cores split dim0 only
 * (blockSplitAxis = 0). Otherwise each dim0 iteration gets its own row of cores and the remaining cores
 * split dim1; to create enough dim1 iterations, ubFactorDim1 is repeatedly halved (and re-aligned) while
 * the loop conditions below hold.
 *
 * Reads uoDim0/uoDim1/ubFactorDim0/ubFactorDim1 produced by the UB tiling and may rewrite
 * ubFactorDim1, uoDim1 and tailUbFactorDim1.
 *
 * @return GRAPH_FAILED when any divisor or resulting factor is not positive; GRAPH_SUCCESS otherwise.
 */
inline static ge::graphStatus TilingBlock(gert::TilingContext* context, ConcatTilingParam& param)
{
if (param.uoDim0 > (param.totalCoreNum / HALF)) {
// Enough dim0 iterations: split dim0 across all cores.
OP_CHECK_IF(
param.totalCoreNum <= 0,
OP_LOGE(
context->GetNodeName(), "param.totalCoreNum must be greater than 0, param.totalCoreNum: %ld",
param.totalCoreNum),
return ge::GRAPH_FAILED);
param.blockFactor = (param.uoDim0 + param.totalCoreNum - 1) / param.totalCoreNum;
OP_CHECK_IF(
param.blockFactor <= 0,
OP_LOGE(
context->GetNodeName(), "param.blockFactor must be greater than 0, param.blockFactor: %ld",
param.blockFactor),
return ge::GRAPH_FAILED);
param.usedCoreNum = (param.uoDim0 + param.blockFactor - 1) / param.blockFactor;
param.tailBlockFactor = param.uoDim0 - (param.usedCoreNum - 1) * param.blockFactor;
param.blockSplitAxis = 0;
} else {
// Few dim0 iterations: one row of cores per dim0 iteration, leftCore cores available per row.
int64_t rowsUsedCoreNum = param.uoDim0;
OP_CHECK_IF(
rowsUsedCoreNum <= 0,
OP_LOGE(
context->GetNodeName(), "rowsUsedCoreNum must be greater than 0, rowsUsedCoreNum: %ld",
rowsUsedCoreNum),
return ge::GRAPH_FAILED);
int64_t leftCore = param.totalCoreNum / rowsUsedCoreNum;
// dim1 factors are kept a multiple of one block, or of the shared dim1 length when it is small enough.
int64_t alignFactorDim1 = param.everyBlockNumber;
if (param.inputShapeSame == 1 && param.sameShapeTensorDim1 * param.dtypeSize <= param.gatherThreshold) {
alignFactorDim1 = param.sameShapeTensorDim1;
}
OP_CHECK_IF(
alignFactorDim1 <= 0,
OP_LOGE(
context->GetNodeName(), "alignFactorDim1 must be greater than 0, alignFactorDim1: %ld",
alignFactorDim1),
return ge::GRAPH_FAILED);
// Halve ubFactorDim1 until there are at least leftCore dim1 iterations, or the piece would become
// narrower than LEAST_COLS bytes or smaller than twice leastCopyNumber elements.
while (param.uoDim1 < leftCore && param.ubFactorDim1 * param.dtypeSize >= LEAST_COLS &&
param.ubFactorDim1 * param.ubFactorDim0 >= HALF * param.leastCopyNumber) {
param.ubFactorDim1 = (param.ubFactorDim1 / HALF) / alignFactorDim1 * alignFactorDim1;
OP_CHECK_IF(
param.ubFactorDim1 <= 0,
OP_LOGE(
context->GetNodeName(), "param.ubFactorDim1 must be greater than 0, param.ubFactorDim1: %ld",
param.ubFactorDim1),
return ge::GRAPH_FAILED);
param.uoDim1 = (param.catDim1 + param.ubFactorDim1 - 1) / param.ubFactorDim1;
param.tailUbFactorDim1 = param.catDim1 % param.ubFactorDim1;
if (param.tailUbFactorDim1 == 0) {
param.tailUbFactorDim1 = param.ubFactorDim1;
}
}
if (param.uoDim1 > 1) {
// Split dim1 iterations across the cores of each row.
param.blockSplitAxis = 1;
OP_CHECK_IF(
leftCore <= 0,
OP_LOGE(context->GetNodeName(), "leftCore must be greater than 0, leftCore: %ld", leftCore),
return ge::GRAPH_FAILED);
param.blockFactor = (param.uoDim1 + leftCore - 1) / leftCore;
OP_CHECK_IF(
param.blockFactor <= 0,
OP_LOGE(
context->GetNodeName(), "param.blockFactor must be greater than 0, param.blockFactor: %ld",
param.blockFactor),
return ge::GRAPH_FAILED);
int64_t colsUsedCoreNum = (param.uoDim1 + param.blockFactor - 1) / param.blockFactor;
param.usedCoreNum = rowsUsedCoreNum * colsUsedCoreNum;
param.tailBlockFactor = param.uoDim1 - (colsUsedCoreNum - 1) * param.blockFactor;
} else {
// Only one dim1 iteration: each core handles exactly one dim0 iteration.
param.blockSplitAxis = 0;
param.blockFactor = 1;
param.tailBlockFactor = 1;
param.usedCoreNum = rowsUsedCoreNum;
}
}
return ge::GRAPH_SUCCESS;
}
/**
 * Copies the per-core end tensor indices and end tensor offsets from param into its staging arrays and
 * hands those arrays to tilingData.
 */
inline static void SetTensorListTilingData(ConcatTilingData& tilingData, ConcatTilingParam& param)
{
    const auto& endIndices = param.endTensorIdx;
    const auto& endOffsets = param.endTensorOffset;

    std::copy(endIndices.begin(), endIndices.end(), param.endIdxArr);
    std::copy(endOffsets.begin(), endOffsets.end(), param.endOffsetArr);

    tilingData.set_endTensorIdx(param.endIdxArr);
    tilingData.set_endTensorOffset(param.endOffsetArr);
}
template <typename T>
inline static void SetTilingData(T& tilingData, ConcatTilingParam& param)
{
tilingData.set_ubSplitDim1(param.ubSplitDim1);
tilingData.set_dim(static_cast<int16_t>(param.dim));
tilingData.set_blockFactor(param.blockFactor);
tilingData.set_tailBlockFactor(param.tailBlockFactor);
tilingData.set_ubFactorDim0(static_cast<int32_t>(param.ubFactorDim0));
tilingData.set_ubFactorDim1(static_cast<int32_t>(param.ubFactorDim1));
tilingData.set_tailUbFactorDim0(static_cast<int32_t>(param.tailUbFactorDim0));
tilingData.set_tailUbFactorDim1(static_cast<int32_t>(param.tailUbFactorDim1));
tilingData.set_uoDim0(param.uoDim0);
tilingData.set_uoDim1(param.uoDim1);
tilingData.set_tensorNum(param.tensorNum);
tilingData.set_catDim1(param.catDim1);
tilingData.set_sameShapeTensorDim1(param.sameShapeTensorDim1);
tilingData.set_isFP4Type(param.isFP4Type);
tilingData.set_bufferSize(static_cast<int32_t>(param.bufferSize));
tilingData.set_dtypeSize(static_cast<int16_t>(param.dtypeSize));
int64_t preLoadSize = std::min(TILING_PRELOAD_DIM1_LENGTH, static_cast<int64_t>(param.tensorListDim1.size()));
std::copy(param.tensorListDim1.begin(), param.tensorListDim1.begin() + preLoadSize, param.preLoadDim1Arr);
tilingData.set_preLoadDim1(param.preLoadDim1Arr);
tilingData.set_isNonContiguous(static_cast<int16_t>(param.isNonContiguous ? 1 : 0));
if (param.isNonContiguous) {
uint64_t strideList[NON_CON_TENSOR_SIZE];
std::copy(param.strideList.begin(), param.strideList.end(), strideList);
tilingData.set_strideList(strideList);
std::copy(param.concatDimList.begin(), param.concatDimList.end(), strideList);
tilingData.set_concatDimList(strideList);
}
}
/**
 * Cuts the concatenated dim1 axis into consecutive pieces of everyCoreData elements and records, for
 * each piece, the tensor index and in-tensor offset where it starts and ends
 * (blockStart*/blockEnd* lists). The lists are then replicated rowsUsedCoreNum times into
 * startTensorIdx / endTensorIdx / startTensorOffset / endTensorOffset via DupTensor.
 *
 * NOTE(review): the inner while loop relies on everyCoreData being positive to make progress - confirm
 * the caller guarantees this.
 */
inline static void CalcTensorList(ConcatTilingParam& param, int64_t everyCoreData, int64_t rowsUsedCoreNum)
{
// curOffset: elements already collected for the piece being built.
// curTensorOffset: position inside the current tensor where the last piece ended.
int64_t curOffset = 0;
int64_t curTensorOffset = 0;
for (int16_t i = 0; i < param.tensorNum; i++) {
// Equal list sizes mean no piece is currently open, so a new one starts at the beginning of tensor i.
if (param.blockStartTensorIdx.size() == param.blockEndTensorIdx.size()) {
param.blockStartTensorIdx.push_back(i);
param.blockStartTensorOffset.push_back(0);
}
// Close pieces for as long as the rest of tensor i can complete one.
while (curOffset + param.tensorListDim1[i] - curTensorOffset >= everyCoreData) {
param.blockEndTensorIdx.push_back(i);
curTensorOffset = curTensorOffset + everyCoreData - curOffset;
param.blockEndTensorOffset.push_back(curTensorOffset);
if (curTensorOffset == param.tensorListDim1[i]) {
// The piece ends exactly at the end of tensor i.
curOffset = 0;
break;
} else {
// Tensor i has elements left: the next piece starts inside it.
param.blockStartTensorIdx.push_back(i);
param.blockStartTensorOffset.push_back(curTensorOffset);
curOffset = 0;
}
}
// Carry the unconsumed part of tensor i into the open piece.
if (curTensorOffset == 0) {
curOffset += param.tensorListDim1[i];
} else {
curOffset = param.tensorListDim1[i] - curTensorOffset;
}
curTensorOffset = 0;
}
// A partially filled last piece ends at the end of the last tensor.
if (curOffset != 0) {
param.blockEndTensorIdx.push_back(param.tensorNum - 1);
param.blockEndTensorOffset.push_back(param.tensorListDim1[param.tensorNum - 1]);
}
DupTensor(param.startTensorIdx, param.blockStartTensorIdx, rowsUsedCoreNum);
DupTensor(param.endTensorIdx, param.blockEndTensorIdx, rowsUsedCoreNum);
DupTensor(param.startTensorOffset, param.blockStartTensorOffset, rowsUsedCoreNum);
DupTensor(param.endTensorOffset, param.blockEndTensorOffset, rowsUsedCoreNum);
}
/**
 * Tells whether 1-byte data can be re-interpreted as 2-byte data: the element size must be B8_BYTES,
 * all inputs must share one shape, and the shared dim1 length and every dim1 length must be even. For
 * non-contiguous inputs every stride must be even as well.
 */
inline static bool IsEnableb8ToB16(const ConcatTilingParam& param)
{
    const bool basicOk = param.dtypeSize == B8_BYTES && param.inputShapeSame == 1 &&
                         param.sameShapeTensorDim1 % DIGIT_TWO == 0;
    if (!basicOk) {
        return false;
    }
    // The first stride is inspected before the dim1 lengths, as a quick rejection.
    if (param.isNonContiguous && param.strideList[0] % DIGIT_TWO != 0) {
        return false;
    }
    for (const auto& dim1 : param.tensorListDim1) {
        if (dim1 % DIGIT_TWO != 0) {
            return false;
        }
    }
    if (param.isNonContiguous) {
        for (int16_t tensorIdx = 0; tensorIdx < param.tensorNum; ++tensorIdx) {
            if (param.strideList[tensorIdx] % DIGIT_TWO != 0) {
                return false;
            }
        }
    }
    return true;
}
/**
 * Re-interprets the element type for unaligned inputs:
 *  - 8-byte elements are treated as pairs of 4-byte elements (all dim1 quantities doubled);
 *  - 1-byte elements are treated as 2-byte elements (all dim1 quantities halved) when IsEnableb8ToB16
 *    allows it.
 * Nothing changes when all tensors are aligned or neither case applies. Always returns GRAPH_SUCCESS.
 */
static ge::graphStatus PreProcessForNoAlign(ConcatTilingParam& param)
{
    if (param.isAllTensorAlign == 1) {
        return ge::GRAPH_SUCCESS;
    }
    const bool splitB64 = (param.dtypeSize == B64_BYTES);
    if (!splitB64 && !IsEnableb8ToB16(param)) {
        return ge::GRAPH_SUCCESS;
    }

    // New element size and the block figures derived from it.
    param.dtypeSize = splitB64 ? B32_BYTES : B16_BYTES;
    param.leastCopyNumber = MIN_RESERVED_SIZE / param.dtypeSize;
    param.everyBlockNumber = BLOCK_SIZE / param.dtypeSize;

    // Scalar dim1 quantities.
    if (splitB64) {
        param.sameShapeTensorDim1 *= DIGIT_TWO;
        param.catDim1 *= DIGIT_TWO;
    } else {
        param.sameShapeTensorDim1 /= DIGIT_TWO;
        param.catDim1 /= DIGIT_TWO;
    }

    // Per-tensor dim1 lengths.
    for (auto& dim1 : param.tensorListDim1) {
        if (splitB64) {
            dim1 *= DIGIT_TWO;
        } else {
            dim1 /= DIGIT_TWO;
        }
    }
    for (auto& merged : param.mergeTensorList) {
        if (splitB64) {
            merged[1] *= DIGIT_TWO;
        } else {
            merged[1] /= DIGIT_TWO;
        }
    }

    // Strides only matter for non-contiguous inputs.
    if (param.isNonContiguous) {
        for (int16_t tensorIdx = 0; tensorIdx < param.tensorNum; ++tensorIdx) {
            if (splitB64) {
                param.strideList[tensorIdx] *= DIGIT_TWO;
            } else {
                param.strideList[tensorIdx] /= DIGIT_TWO;
            }
        }
    }
    return ge::GRAPH_SUCCESS;
}
/**
 * Lists candidate split counts for coreNum cores: for every m in [1, ceil(sqrt(coreNum) + 1)) the pair
 * m and coreNum / m is appended, in that order. The result may contain repeated values.
 */
inline static std::vector<int64_t> FindUniqueCut(int64_t coreNum)
{
    const int64_t upBound = static_cast<int64_t>(std::ceil(std::sqrt(coreNum) + 1.0));
    std::vector<int64_t> candidates;
    int64_t divisor = 1;
    while (divisor < upBound) {
        const int64_t quotient = coreNum / divisor;
        candidates.push_back(divisor);
        candidates.push_back(quotient);
        ++divisor;
    }
    return candidates;
}
/**
 * Chooses how many cores split the rows (m) and how many split the columns (n = coreNum / m).
 *
 * Every candidate m from FindUniqueCut is evaluated; candidates with m > rows or n > cols are skipped.
 * Each remaining candidate gets a score "delta" and the candidate with the smallest delta wins, ties
 * being broken in favour of the larger n.
 *
 * @return (m, n) of the best candidate, or (0, 0) when no candidate is feasible.
 */
static std::pair<int64_t, int64_t> AutoBlockTiling(int64_t rows, int64_t cols, int64_t coreNum)
{
std::vector<int64_t> candidateSet = FindUniqueCut(coreNum);
// Each entry is {m, n, m * n, delta}.
std::vector<std::vector<int64_t>> allTiling;
for (int64_t m : candidateSet) {
if (m > rows) {
continue;
}
// mPart / nPart: rows / columns handled by a regular (non-tail) core.
int64_t mPart = (rows + m - 1) / m;
int64_t n = coreNum / m;
if (n > cols) {
continue;
}
int64_t nPart = (cols + n - 1) / n;
// Start from the size of a regular block; when all cores are used, subtract the size of the tail
// block (0 when the split is exact in both directions).
int64_t delta = mPart * nPart;
if (m * n == coreNum) {
if (rows % m == 0 && cols % n == 0) {
delta = 0;
} else if (rows % m == 0) {
delta = delta - mPart * (cols - nPart * (n - 1));
} else if (cols % n == 0) {
delta = delta - nPart * (rows - mPart * (m - 1));
} else {
delta = delta - (cols - nPart * (n - 1)) * (rows - mPart * (m - 1));
}
}
allTiling.push_back({m, n, m * n, delta});
}
// Ascending delta, then descending n.
std::sort(allTiling.begin(), allTiling.end(), [](const std::vector<int64_t>& a, const std::vector<int64_t>& b) {
return std::make_pair(a[DIGIT_THREE], -a[DIGIT_ONE]) < std::make_pair(b[DIGIT_THREE], -b[DIGIT_ONE]);
});
if (allTiling.size() == 0) {
return std::make_pair(0, 0);
}
return std::make_pair(allTiling[0][0], allTiling[0][DIGIT_ONE]);
}
/**
 * Decides whether the SIMT template should be used.
 *
 * SIMT is rejected when there are no inputs, the inputs are non-contiguous, the total output byte size
 * reaches SIMT_PER_CORE_THRESHOLD per used core, the inputs are all aligned or share one shape, the
 * tensor count does not satisfy the modulo condition on the core count, the tensor count exceeds
 * TILING_COLS_OFFSET_LENGTH, or the dim1 lengths are too unbalanced (running max at least twice the
 * running min while the largest tensor exceeds SIMT_COMPARE_THRESHOLD bytes).
 *
 * The balance test is written as max >= 2 * min instead of max / min >= 2; the two agree for positive
 * lengths, and the multiplicative form does not divide by zero when an input has a zero-length dim1.
 */
static bool IsEnableUsedSimt(ConcatTilingParam& param)
{
    // Nothing to inspect; also protects the tensorListDim1[0] access below.
    if (param.tensorListDim1.empty()) {
        return false;
    }
    if (param.isNonContiguous) {
        return false;
    }
    const int64_t totalDataNum = param.catDim0 * param.catDim1 * param.dtypeSize;
    const int64_t useCoreNum = std::min(static_cast<int64_t>(param.tensorNum), param.totalCoreNum);
    if (totalDataNum >= useCoreNum * SIMT_PER_CORE_THRESHOLD) {
        return false;
    }
    if (param.isAllTensorAlign == 1 || param.inputShapeSame == 1) {
        return false;
    }
    if (static_cast<int64_t>(param.tensorNum) % (param.totalCoreNum + 1) <= (param.totalCoreNum / DIGIT_TWO)) {
        return false;
    }
    if (param.tensorNum > TILING_COLS_OFFSET_LENGTH) {
        return false;
    }
    int64_t maxDim1 = param.tensorListDim1[0];
    int64_t minDim1 = param.tensorListDim1[0];
    for (const auto& dim1 : param.tensorListDim1) {
        maxDim1 = std::max<int64_t>(maxDim1, dim1);
        minDim1 = std::min<int64_t>(minDim1, dim1);
        const bool unbalanced = maxDim1 >= DIGIT_TWO * minDim1;
        if (unbalanced && maxDim1 * param.catDim0 * param.dtypeSize > SIMT_COMPARE_THRESHOLD) {
            return false;
        }
    }
    return true;
}
/**
 * Appends to param.tensorColsOffset the running (inclusive) sum of the dim1 lengths. Each length is
 * scaled by dtypeSize / B64_BYTES when the element size exceeds B64_BYTES.
 */
inline static void CalcTensorColsOffset(ConcatTilingParam& param)
{
    int32_t digitSize = DIGIT_ONE;
    if (param.dtypeSize > B64_BYTES) {
        digitSize = param.dtypeSize / B64_BYTES;
    }
    int32_t runningOffset = 0;
    for (const auto& dim1 : param.tensorListDim1) {
        runningOffset = static_cast<int32_t>(runningOffset + dim1 * digitSize);
        param.tensorColsOffset.push_back(runningOffset);
    }
}
/**
 * Copies param.tensorColsOffset into the tensorColsOffset array of tilingData.
 *
 * The scratch array is zero-initialised and at most TILING_COLS_OFFSET_LENGTH entries are copied, so
 * unused slots are deterministic zeros and the copy cannot run past the array.
 */
inline static void SetTensorColsOffset(ConcatTilingDataForSimt& tilingData, ConcatTilingParam& param)
{
    int32_t tilingList[TILING_COLS_OFFSET_LENGTH] = {0};
    const int64_t copyCount = std::min(
        static_cast<int64_t>(param.tensorColsOffset.size()), static_cast<int64_t>(TILING_COLS_OFFSET_LENGTH));
    std::copy(param.tensorColsOffset.begin(), param.tensorColsOffset.begin() + copyCount, tilingList);
    tilingData.set_tensorColsOffset(tilingList);
}
/*!
 * \brief Log the scalar fields of the SIMT tiling data at info level.
 *
 * NOTE(review): every field is printed with "%d"; confirm that all four getters return a
 * 32-bit integer (catDim1 is set from an int64_t value by the caller).
 *
 * \param tilingData SIMT tiling data to print.
 */
static inline void PrintSimtTilingData(ConcatTilingDataForSimt& tilingData)
{
    const auto tensorNumPerCore = tilingData.get_tensorNumPerCore();
    const auto tensorNum = tilingData.get_tensorNum();
    const auto catDim0 = tilingData.get_catDim0();
    const auto catDim1 = tilingData.get_catDim1();
    OP_LOGI(
        "[Concat]", "tensorNumPerCore: %d, get_tensorNum: %d,catDim0: %d,catDim1: %d", tensorNumPerCore, tensorNum,
        catDim0, catDim1);
}
/*!
 * \brief Produce the tiling for the SIMT concat template and publish it through the context.
 *
 * Element sizes above B64_BYTES are handled as B64_BYTES-wide elements: catDim1 is scaled by
 * dtypeSize / B64_BYTES and the tiling key is built from the clamped element size.
 * The inputs are distributed over min(tensorNum, totalCoreNum) cores.
 *
 * \param context tiling context that receives block dim, tiling key, workspace size and tiling data.
 * \param param   tiling parameters; usedCoreNum, tensorNumPerCore, tilingKey and tensorColsOffset are updated.
 * \return ge::GRAPH_SUCCESS on success, ge::GRAPH_FAILED otherwise.
 */
static ge::graphStatus TilingForConcatDSimt(gert::TilingContext* context, ConcatTilingParam& param)
{
    ConcatTilingDataForSimt tilingData;
    int64_t dtypeSize = param.dtypeSize;
    int64_t catDim1 = static_cast<int64_t>(param.catDim1);
    if (dtypeSize > B64_BYTES) {
        catDim1 = (dtypeSize / B64_BYTES) * catDim1;
        dtypeSize = B64_BYTES;
    }
    param.usedCoreNum = std::min(static_cast<int64_t>(param.tensorNum), param.totalCoreNum);
    // usedCoreNum is the divisor below; refuse to continue instead of dividing by zero.
    OP_CHECK_IF(
        param.usedCoreNum <= 0,
        OP_LOGE(context->GetNodeName(), "simt tiling got invalid core num, tensor num or core num is not positive."),
        return ge::GRAPH_FAILED);
    param.tensorNumPerCore = (static_cast<int64_t>(param.tensorNum) + param.usedCoreNum - 1) / param.usedCoreNum;
    param.tilingKey = SIMT_TILINGKEY_PREFIX + dtypeSize;
    CalcTensorColsOffset(param);

    tilingData.set_tensorNumPerCore(param.tensorNumPerCore);
    tilingData.set_tensorNum(static_cast<int32_t>(param.tensorNum));
    tilingData.set_catDim0(static_cast<int32_t>(param.catDim0));
    tilingData.set_catDim1(catDim1);
    SetTensorColsOffset(tilingData, param);
    PrintSimtTilingData(tilingData);

    context->SetBlockDim(param.usedCoreNum);
    context->SetTilingKey(param.tilingKey);
    size_t* currentWorkspace = context->GetWorkspaceSizes(1);
    OP_CHECK_NULL_WITH_CONTEXT(context, currentWorkspace);
    currentWorkspace[0] = SYSTEM_WORKSPACE_SIZE;
    OP_CHECK_IF(
        ConcatSetTilingData(context, tilingData) != ge::GRAPH_SUCCESS,
        OP_LOGE(context->GetNodeName(), "SimtSetTilingData set tiling data fail."), return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}
/*!
 * \brief Try to tile the request with the pure-copy template.
 *
 * The core count is derived from the output volume, AutoBlockTiling chooses how many parts the
 * rows and columns are cut into, and IsEnablePureCopyTemplate has the final say. On success the
 * pure-copy tiling key, UB factors and buffer size are stored in param; on rejection the
 * per-block tensor lists filled on the way are cleared again.
 *
 * \param param tiling parameters, updated in place.
 * \return true when the pure-copy template is selected, false otherwise.
 */
static bool TilingForPureCopy(ConcatTilingParam& param)
{
    const int64_t coresByVolume =
        (param.catDim0 * param.catDim1 + EVERY_CORE_THRESHOLD - 1) / EVERY_CORE_THRESHOLD;
    param.usedCoreNum = std::min(param.totalCoreNum, coresByVolume);

    int64_t blockCols = (param.catDim1 + LEAST_BLOCK_BYTES - 1) / LEAST_BLOCK_BYTES;
    int64_t blockRows = param.catDim0;
    int64_t rowParts = 0;
    int64_t colParts = 0;
    std::tie(rowParts, colParts) = AutoBlockTiling(blockRows, blockCols, param.usedCoreNum);
    if (rowParts == 0 || colParts == 0) {
        return false;
    }

    // Per-part factor (rounded up) and the remainder left for the last part, for each axis.
    param.ubFactorDim0 = (param.catDim0 + rowParts - 1) / rowParts;
    param.tailUbFactorDim0 = param.catDim0 - param.ubFactorDim0 * (rowParts - 1);
    param.ubFactorDim1 = (param.catDim1 + colParts - 1) / colParts;
    param.tailUbFactorDim1 = param.catDim1 - param.ubFactorDim1 * (colParts - 1);
    if (param.tailUbFactorDim0 < 0 || param.tailUbFactorDim1 < 0) {
        return false;
    }

    int64_t colsPerBlock = (param.catDim1 + colParts - 1) / colParts;
    if (colParts <= 1) {
        param.blockSplitAxis = 0;
    } else {
        param.blockSplitAxis = 1;
        CalcTensorList(param, colsPerBlock, rowParts);
    }

    if (!IsEnablePureCopyTemplate(param, rowParts, colParts)) {
        // Rejected: drop the per-block tensor ranges so the regular tiling starts clean.
        param.blockStartTensorIdx.clear();
        param.blockEndTensorIdx.clear();
        param.blockStartTensorOffset.clear();
        param.blockEndTensorOffset.clear();
        return false;
    }

    if (ENABLE_DB) {
        param.ubSize = param.ubSize / DIGIT_TWO;
    }
    param.bufferSize = param.ubSize / param.dtypeSize;
    if (param.blockSplitAxis != 0) {
        param.tilingKey = PURE_COPY_SPLIT_DIM1_TILINGKEY;
    } else {
        param.tilingKey = PURE_COPY_NO_SPLIT_DIM1_TILINGKEY;
    }
    param.uoDim0 = rowParts;
    param.uoDim1 = colParts;
    param.inputShapeSame = 0;
    GetTensorSameDim1(param);
    return true;
}
/*!
 * \brief Run the non-SIMT tiling pipeline.
 *
 * Order of attempts: empty output (zero cores), pure-copy template, then the general path
 * (optional no-align pre-processing, UB tiling, block tiling, tiling-key generation).
 *
 * \param context tiling context, used for logging and forwarded to the UB/block tiling steps.
 * \param param   tiling parameters, updated in place.
 * \return ge::GRAPH_SUCCESS on success, ge::GRAPH_FAILED when a step fails.
 */
inline static ge::graphStatus DoTiling(gert::TilingContext* context, ConcatTilingParam& param)
{
    if (param.isEmpty) {
        param.usedCoreNum = 0;
        return ge::GRAPH_SUCCESS;
    }
    if (TilingForPureCopy(param)) {
        return ge::GRAPH_SUCCESS;
    }

    const bool isB64OrB8 = (param.dtypeSize == B64_BYTES) || (param.dtypeSize == B8_BYTES);
    if (isB64OrB8 && param.isAllTensorAlign == 0) {
        OP_CHECK_IF(
            PreProcessForNoAlign(param) != ge::GRAPH_SUCCESS,
            OP_LOGE(context->GetNodeName(), "check PreProcessForNoAlign failed"), return ge::GRAPH_FAILED);
    }

    // With double buffering enabled only half of the UB is available to one buffer.
    if (ENABLE_DB) {
        param.ubSize /= HALF;
    }

    OP_CHECK_IF(
        TilingUb(context, param) != ge::GRAPH_SUCCESS, OP_LOGE(context->GetNodeName(), "check tiling_ub failed"),
        return ge::GRAPH_FAILED);
    OP_CHECK_IF(
        TilingBlock(context, param) != ge::GRAPH_SUCCESS,
        OP_LOGE(context->GetNodeName(), "check tiling_block failed"), return ge::GRAPH_FAILED);

    if (param.blockSplitAxis == 1) {
        CalcTensorList(param, param.blockFactor * param.ubFactorDim1, param.uoDim0);
    }
    GenTilingKey(param);
    return ge::GRAPH_SUCCESS;
}
/*!
 * \brief Read the axis attribute (attribute index PACK_ATTR_AXIS_IDX) as type T and widen it to int64_t.
 *
 * The attribute pointer is typed with T so the template can be instantiated for any integral
 * attribute type, not only int64_t.
 *
 * NOTE(review): the null-check macro returns from this function on failure; since the return
 * type is int64_t the caller receives whatever status value the macro returns as an "axis" —
 * confirm that the subsequent axis range check is what rejects it.
 *
 * \tparam T stored type of the axis attribute.
 * \param context tiling context providing the attributes.
 * \return the axis value.
 */
template <typename T>
inline static int64_t GetAxis(gert::TilingContext* context)
{
    auto attrs = context->GetAttrs();
    OP_CHECK_NULL_WITH_CONTEXT(context, attrs);
    const T* axis = attrs->GetAttrPointer<T>(PACK_ATTR_AXIS_IDX);
    OP_CHECK_NULL_WITH_CONTEXT(context, axis);
    return static_cast<int64_t>(*axis);
}
/*!
 * \brief Validate the pack axis against the rank of the first input and normalise negative values.
 *
 * With rank r taken from the storage shape of input 0, the accepted range is [-(r + 1), r].
 * A negative axis is rewritten in place to axis + r + 1.
 *
 * \param context tiling context providing the input shape.
 * \param dim     pack axis; normalised in place when negative.
 * \return ge::GRAPH_SUCCESS when the axis is valid, ge::GRAPH_FAILED when it is out of range
 *         or the input shape is unavailable.
 */
inline static ge::graphStatus IsPackDimValid(gert::TilingContext* context, int64_t& dim)
{
    auto inputShapePtr = context->GetDynamicInputShape(PACK_INPUT_IDX, 0);
    OP_CHECK_NULL_WITH_CONTEXT(context, inputShapePtr);
    const gert::Shape& inputShape = inputShapePtr->GetStorageShape();
    const int64_t rank = static_cast<int64_t>(inputShape.GetDimNum());
    const int64_t minAxis = (rank + PACK_AXIS_DEFAULT_VALUE) * (-1);
    const int64_t maxAxis = rank;
    if (dim < minAxis || dim > maxAxis) {
        return ge::GRAPH_FAILED;
    }
    if (dim < 0) {
        dim = dim + rank + PACK_AXIS_DEFAULT_VALUE;
    }
    return ge::GRAPH_SUCCESS;
}
/*!
 * \brief Tell whether a data type is NOT supported by the pack tiling.
 *
 * The supported set is built once (function-local static) instead of on every call.
 *
 * \param dtype data type of the pack input.
 * \return true when dtype is outside the supported set, false when it is supported.
 */
bool IsInvalidTypeForPack(const DataType dtype)
{
    static const std::set<ge::DataType> supportedDtype = {
        ge::DT_FLOAT,  ge::DT_FLOAT16, ge::DT_BF16,  ge::DT_UINT8,  ge::DT_INT8,
        ge::DT_UINT16, ge::DT_INT16,   ge::DT_UINT32, ge::DT_INT32, ge::DT_UINT64,
        ge::DT_INT64,  ge::DT_BOOL,    ge::DT_DOUBLE, ge::DT_COMPLEX64, ge::DT_COMPLEX32};
    return supportedDtype.find(dtype) == supportedDtype.end();
}
/*!
 * \brief Verify that every pack input has exactly the same storage shape as input 0.
 *
 * Both the rank and every dimension are compared. Comparing the rank first prevents reading
 * past the end of the reference shape when a later input has more dimensions, and prevents a
 * false match when a later input has fewer dimensions that happen to equal a prefix.
 *
 * \param context tiling context providing the dynamic inputs.
 * \return ge::GRAPH_SUCCESS when all shapes match; ge::GRAPH_FAILED when there is no input,
 *         a shape differs, or a required pointer is unavailable.
 */
ge::graphStatus CheckInputShapeSameForPack(gert::TilingContext* context)
{
    auto computeNodeInfo = context->GetComputeNodeInfo();
    OP_CHECK_NULL_WITH_CONTEXT(context, computeNodeInfo);
    auto anchorInstanceInfo = computeNodeInfo->GetInputInstanceInfo(PACK_INPUT_IDX);
    OP_CHECK_NULL_WITH_CONTEXT(context, anchorInstanceInfo);
    const uint32_t inputNum = anchorInstanceInfo->GetInstanceNum();
    if (inputNum < 1) {
        return ge::GRAPH_FAILED;
    }

    auto firstInputTensorShapePtr = context->GetDynamicInputShape(PACK_INPUT_IDX, 0);
    OP_CHECK_NULL_WITH_CONTEXT(context, firstInputTensorShapePtr);
    const gert::Shape& firstInputTensorShape = firstInputTensorShapePtr->GetStorageShape();
    const size_t firstInputTensorDimNum = firstInputTensorShape.GetDimNum();

    for (uint32_t i = 1; i < inputNum; ++i) {
        auto inputTensorShapePtr = context->GetDynamicInputShape(PACK_INPUT_IDX, i);
        OP_CHECK_NULL_WITH_CONTEXT(context, inputTensorShapePtr);
        const gert::Shape& inputTensorShape = inputTensorShapePtr->GetStorageShape();
        if (inputTensorShape.GetDimNum() != firstInputTensorDimNum) {
            return ge::GRAPH_FAILED;
        }
        for (size_t j = 0; j < firstInputTensorDimNum; ++j) {
            if (inputTensorShape.GetDim(j) != firstInputTensorShape.GetDim(j)) {
                return ge::GRAPH_FAILED;
            }
        }
    }
    return ge::GRAPH_SUCCESS;
}
/*!
 * \brief Fill param.tensorList for pack: one copy of input 0's storage shape per input instance.
 *
 * Only the shape of input 0 is read; entries are appended to param.tensorList.
 *
 * \param context tiling context providing the dynamic inputs.
 * \param param   tiling parameters whose tensorList is extended.
 */
void GetTensorListForPack(gert::TilingContext* context, ConcatTilingParam& param)
{
    auto firstShapePtr = context->GetDynamicInputShape(PACK_INPUT_IDX, 0);
    gert::Shape firstShape = firstShapePtr->GetStorageShape();
    const size_t rank = firstShape.GetDimNum();

    vector<int64_t> dims;
    dims.reserve(rank);
    for (size_t axis = 0; axis < rank; ++axis) {
        dims.push_back(firstShape.GetDim(axis));
    }

    auto nodeInfo = context->GetComputeNodeInfo();
    auto instanceInfo = nodeInfo->GetInputInstanceInfo(PACK_INPUT_IDX);
    const uint32_t instanceCount = instanceInfo->GetInstanceNum();
    for (uint32_t added = 0; added < instanceCount; ++added) {
        param.tensorList.push_back(dims);
    }
}
/*!
 * \brief Tiling entry for Pack, implemented on top of the concat tiling.
 *
 * Steps: read and validate the pack axis, check that all inputs share one shape, validate the
 * data type, build the tensor list, compute the base parameters, run DoTiling, then publish the
 * tiling key, block dim, workspace size and tiling data (with or without the per-block tensor
 * arrays, depending on param.blockSplitAxis).
 *
 * \param context tiling context of the Pack node.
 * \return ge::GRAPH_SUCCESS on success, ge::GRAPH_FAILED on any validation or tiling failure.
 */
ge::graphStatus Tiling4PackToConcatForAscendC(gert::TilingContext* context)
{
    OP_LOGD(context->GetNodeName(), "Tiling4PackToConcatForAscendC running begin");
    ConcatTilingParam param;
    param.dim = GetAxis<int64_t>(context);
    OP_CHECK_IF(
        IsPackDimValid(context, param.dim) != ge::GRAPH_SUCCESS,
        OP_LOGE(context->GetNodeName(), "check pack_axis failed, please check pack_axis."), return ge::GRAPH_FAILED);
    OP_CHECK_IF(
        CheckInputShapeSameForPack(context) != ge::GRAPH_SUCCESS,
        OP_LOGE(
            context->GetNodeName(),
            "check pack input tensor shape failed, please make sure all input tensor shape same."),
        return ge::GRAPH_FAILED);
    auto inputDesc = context->GetDynamicInputDesc(PACK_INPUT_IDX, 0);
    OP_CHECK_NULL_WITH_CONTEXT(context, inputDesc);
    auto inputDataType = inputDesc->GetDataType();
    OP_CHECK_IF(
        IsInvalidTypeForPack(inputDataType),
        OP_LOGE(
            context->GetNodeName(),
            "input dtype only support uint8, int8, bool, float32, int32, uint32, int16, float16, bfloat16, uint16, "
            "int64, "
            "uint64, double, complex32, complex64 currently, please check."),
        return ge::GRAPH_FAILED);
    OP_CHECK_IF(
        GetDtypeSize(context, param, PACK_INPUT_IDX) != ge::GRAPH_SUCCESS,
        OP_LOGE(context->GetNodeName(), "GetDtypeSize failed."), return ge::GRAPH_FAILED);
    GetTensorListForPack(context, param);
    OP_CHECK_IF(
        CalcBaseTilingParam(context, param) != ge::GRAPH_SUCCESS,
        OP_LOGE(context->GetNodeName(), "CalcBaseTilingParam failed."), return ge::GRAPH_FAILED);
    OP_CHECK_IF(
        DoTiling(context, param) != ge::GRAPH_SUCCESS, OP_LOGE(context->GetNodeName(), "DoTiling failed."),
        return ge::GRAPH_FAILED);

    context->SetTilingKey(param.tilingKey);
    context->SetBlockDim(param.usedCoreNum);
    size_t* currentWorkspace = context->GetWorkspaceSizes(1);
    // Same guard as the SIMT path: never write through an unavailable workspace array.
    OP_CHECK_NULL_WITH_CONTEXT(context, currentWorkspace);
    currentWorkspace[0] = SYSTEM_WORKSPACE_SIZE;

    ge::graphStatus ret = ge::GRAPH_SUCCESS;
    if (param.blockSplitAxis == 1) {
        // Splitting along dim1 needs the per-block tensor ranges in the tiling data.
        ConcatTilingData tilingData;
        SetTilingData<ConcatTilingData>(tilingData, param);
        SetTensorListTilingData(tilingData, param);
        PrintTilingData(tilingData, param.tilingKey, param.usedCoreNum);
        ret = ConcatSetTilingData<ConcatTilingData>(context, tilingData);
    } else {
        ConcatTilingDataNoArray tilingData;
        SetTilingData<ConcatTilingDataNoArray>(tilingData, param);
        PrintTilingData(tilingData, param.tilingKey, param.usedCoreNum);
        ret = ConcatSetTilingData<ConcatTilingDataNoArray>(context, tilingData);
    }
    OP_CHECK_IF(
        ret != ge::GRAPH_SUCCESS, OP_LOGE(context->GetNodeName(), "PackSetTilingData set tiling data fail."),
        return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}
/*!
 * \brief Load the concat axis into param.dim.
 *
 * When dimIdx equals INVLID_CONCAT_DIM_IDX the axis is read from attribute 0; otherwise it is
 * read from the required input at dimIdx, as int32_t when that input is DT_INT32 and as
 * int64_t for any other data type.
 *
 * \param context tiling context providing attributes and inputs.
 * \param param   tiling parameters; dim is written.
 * \param dimIdx  index of the concat-dim input, or INVLID_CONCAT_DIM_IDX to use the attribute.
 * \return ge::GRAPH_SUCCESS on success, ge::GRAPH_FAILED otherwise.
 */
ge::graphStatus GetConcatDim(gert::TilingContext* context, ConcatTilingParam& param, int64_t dimIdx)
{
    if (dimIdx != INVLID_CONCAT_DIM_IDX) {
        auto dimDesc = context->GetRequiredInputDesc(dimIdx);
        OP_CHECK_NULL_WITH_CONTEXT(context, dimDesc);
        const bool isInt32Dim = (dimDesc->GetDataType() == ge::DT_INT32);
        const auto readStatus = isInt32Dim ? GetConcatDimInput<int32_t>(context, param, dimIdx) :
                                             GetConcatDimInput<int64_t>(context, param, dimIdx);
        OP_CHECK_IF(
            readStatus != ge::GRAPH_SUCCESS, OP_LOGE(context->GetNodeName(), "get concat_dim failed."),
            return ge::GRAPH_FAILED);
        return ge::GRAPH_SUCCESS;
    }

    auto attrs = context->GetAttrs();
    OP_CHECK_NULL_WITH_CONTEXT(context, attrs);
    const int64_t* axisAttr = attrs->GetAttrPointer<int64_t>(0);
    OP_CHECK_NULL_WITH_CONTEXT(context, axisAttr);
    param.dim = *axisAttr;
    return ge::GRAPH_SUCCESS;
}
/*!
 * \brief Fetch the shape of one dynamic input instance.
 *
 * \param context         tiling context providing the inputs.
 * \param isNonContiguous selects GetShape() when true and GetStorageShape() when false.
 * \param inputIdx        index of the dynamic input.
 * \param index           instance index inside that dynamic input.
 * \return a copy of the selected shape.
 */
gert::Shape GetShapeByAll(const gert::TilingContext* context, bool isNonContiguous, int inputIdx, int index)
{
    auto shapePtr = context->GetDynamicInputShape(inputIdx, index);
    if (!isNonContiguous) {
        return shapePtr->GetStorageShape();
    }
    return shapePtr->GetShape();
}
bool IsAllContiguous(gert::TilingContext* context, ConcatTilingParam ¶m, int64_t inputIdx)
{
auto computeNodeInfo = context->GetComputeNodeInfo();
OP_CHECK_NULL_WITH_CONTEXT(context, computeNodeInfo);
auto anchorInstanceInfo = computeNodeInfo->GetInputInstanceInfo(inputIdx);
OP_CHECK_NULL_WITH_CONTEXT(context, anchorInstanceInfo);
param.tensorNum = anchorInstanceInfo->GetInstanceNum();
for (int16_t i = 0; i < param.tensorNum; ++i) {
bool isViewI = context->DynamicInputIsView(inputIdx, i);
auto nonStrideI = context->GetDynamicInputStride(inputIdx, i);
if (isViewI && nonStrideI != nullptr && nonStrideI->GetDimNum() > 0) {
return false;
}
}
return true;
}
/*!
 * \brief Basic preconditions of the non-contiguous path.
 *
 * Requires 1 < tensorNum <= NON_CON_TENSOR_SIZE and strideDim >= 0.
 *
 * \param context tiling context, used for logging.
 * \param param   tiling parameters; tensorNum and strideDim are read.
 * \return ge::GRAPH_SUCCESS when both conditions hold, ge::GRAPH_FAILED otherwise.
 */
ge::graphStatus CheckNonConBasic(gert::TilingContext* context, ConcatTilingParam& param)
{
    // The message states the bounds that are actually enforced and the offending value.
    OP_CHECK_IF(param.tensorNum <= 1 || param.tensorNum > NON_CON_TENSOR_SIZE,
        OP_LOGE(context->GetNodeName(),
            "input tensor number should be greater than 1 and no more than %lld, but got %lld.",
            static_cast<long long>(NON_CON_TENSOR_SIZE), static_cast<long long>(param.tensorNum)),
        return ge::GRAPH_FAILED);
    OP_CHECK_IF(param.strideDim < 0,
        OP_LOGE(context->GetNodeName(),
            "non contiguous scenarious only support concat dim with greater than 0 or greater than -8."), return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}
/*!
 * \brief Validate the inputs for the non-contiguous path and collect per-tensor stride information.
 *
 * For every input instance:
 *  - its data type must equal the data type of instance 0;
 *  - if it is a strided view, the innermost stride must be 1 and every axis other than
 *    param.strideDim must satisfy stride[j] == stride[j + 1] * shape[j + 1];
 *  - param.strideList[i] receives the stride of axis strideDim (or the merged size of the
 *    axes after strideDim for a contiguous instance) and param.concatDimList[i] receives the
 *    size of the concat axis;
 *  - the size constraint built from SMALL_BAG and ALL_DATA_SMALL must hold.
 * On success param.isNonContiguous is set to true.
 *
 * \param context  tiling context providing descriptors and strides.
 * \param param    tiling parameters; strideList, concatDimList and isNonContiguous are written.
 * \param inputIdx index of the dynamic input holding the tensors to concatenate.
 * \return ge::GRAPH_SUCCESS when every instance passes, ge::GRAPH_FAILED otherwise.
 */
ge::graphStatus CheckNonContiguous(gert::TilingContext* context, ConcatTilingParam& param, int64_t inputIdx)
{
    if (CheckNonConBasic(context, param) != ge::GRAPH_SUCCESS) {
        return ge::GRAPH_FAILED;
    }
    auto input0Desc = context->GetDynamicInputDesc(inputIdx, 0);
    OP_CHECK_NULL_WITH_CONTEXT(context, input0Desc);
    auto input0DataType = input0Desc->GetDataType();
    for (int64_t i = 0; i < param.tensorNum; i++) {
        int64_t allData = 1;
        auto inputIDesc = context->GetDynamicInputDesc(inputIdx, i);
        OP_CHECK_NULL_WITH_CONTEXT(context, inputIDesc);
        OP_CHECK_IF(inputIDesc->GetDataType() != input0DataType,
            OP_LOGE(context->GetNodeName(), "non contiguous scenarious only support identical data type."), return ge::GRAPH_FAILED);
        bool isViewI = context->DynamicInputIsView(inputIdx, i);
        auto nonStrideI = context->GetDynamicInputStride(inputIdx, i);
        if (isViewI && nonStrideI != nullptr && nonStrideI->GetDimNum() > 0) {
            // The innermost axis must be densely packed.
            OP_CHECK_IF(nonStrideI->GetStride(param.tensorList[i].size() - 1) != 1,
                OP_LOGE(context->GetNodeName(),
                    "non contiguous scenarious, only the axis immediately preceding the concat dim is allowed to be non contiguous, while all other axis must be contiguous."),
                return ge::GRAPH_FAILED);
            // Walk outwards: every axis except strideDim must follow the dense-layout stride rule.
            for (int32_t j = param.tensorList[i].size() - 2; j >= 0; j--) {
                if (param.strideDim != j) {
                    OP_CHECK_IF(nonStrideI->GetStride(j) != nonStrideI->GetStride(j + 1) * param.tensorList[i][j + 1],
                        OP_LOGE(context->GetNodeName(),
                            "non contiguous scenarious, only the axis immediately preceding the concat dim is allowed to be non contiguous, while all other axis must be contiguous."),
                        return ge::GRAPH_FAILED);
                }
            }
            param.strideList[i] = static_cast<uint64_t>(nonStrideI->GetStride(param.strideDim));
        } else {
            param.strideList[i] = MergeDim(param.tensorList[i], param.strideDim + 1, param.tensorList[i].size());
        }
        param.concatDimList[i] = static_cast<uint32_t>(param.tensorList[i][param.dim]);
        allData = MergeDim(param.tensorList[i], 0, param.tensorList[i].size());
        // Rejected only when none of the three size conditions holds.
        OP_CHECK_IF(!(param.tensorListDim1[i] * param.dtypeSize >= SMALL_BAG)
            && !(param.strideList[i] * param.dtypeSize > SMALL_BAG)
            && !(allData * param.dtypeSize < param.totalCoreNum * ALL_DATA_SMALL),
            OP_LOGE(context->GetNodeName(),
                "In non contiguous scenarios, either the combined size of the concat dim and subsequent dim is at least 128 bytes, "
                "or the stride of the non contiguous axis is greater than 128 bytes, or the total data size of the tensor is less than 8192 bytes multiplied by the core number."),
            return ge::GRAPH_FAILED);
    }
    param.isNonContiguous = true;
    return ge::GRAPH_SUCCESS;
}
/*!
 * \brief Shared tiling flow for the concat operators.
 *
 * Steps: read the concat axis, detect non-contiguous inputs, validate axis / data type / shapes,
 * compute the base parameters, run the extra non-contiguous validation when needed, then either
 * hand over to the SIMT tiling or run DoTiling and publish tiling key, block dim, workspace size
 * and tiling data (with or without the per-block tensor arrays, depending on param.blockSplitAxis).
 *
 * \param context  tiling context of the node.
 * \param inputIdx index of the dynamic input holding the tensors to concatenate.
 * \param dimIdx   index of the concat-dim input, or INVLID_CONCAT_DIM_IDX to read the axis attribute.
 * \return ge::GRAPH_SUCCESS on success, ge::GRAPH_FAILED on any validation or tiling failure.
 */
ge::graphStatus TilingCommon(gert::TilingContext* context, int64_t inputIdx, int64_t dimIdx)
{
    ConcatTilingParam param;
    OP_CHECK_IF(
        GetConcatDim(context, param, dimIdx) != ge::GRAPH_SUCCESS,
        OP_LOGE(context->GetNodeName(), "get concat_dim failed."), return ge::GRAPH_FAILED);
    param.isNonContiguous = !(IsAllContiguous(context, param, inputIdx));
    OP_CHECK_IF(
        IsDimValid(context, param.dim, inputIdx, param.isNonContiguous, param.strideDim) != ge::GRAPH_SUCCESS,
        OP_LOGE(context->GetNodeName(), "check concat_dim failed, please check concat_dim."), return ge::GRAPH_FAILED);
    auto inputDesc = context->GetDynamicInputDesc(inputIdx, 0);
    OP_CHECK_NULL_WITH_CONTEXT(context, inputDesc);
    auto inputDataType = inputDesc->GetDataType();
    OP_CHECK_IF(
        IsInvalidType(inputDataType),
        OP_LOGE(
            context->GetNodeName(),
            "input dtype only support uint8, int8, bool, float32, int32, uint32, int16, float16, bfloat16, uint16, "
            "int64,"
            "uint64, double, complex64, HIFLOAT8、FLOAT8_E5M2、FLOAT8_E4M3FN、FLOAT8_E8M0 FLOAT4_E1M2、FLOAT4_E2M1 currently, please check."),
        return ge::GRAPH_FAILED);
    OP_CHECK_IF(
        GetDtypeSize(context, param, inputIdx) != ge::GRAPH_SUCCESS,
        OP_LOGE(context->GetNodeName(), "GetDtypeSize failed."), return ge::GRAPH_FAILED);
    OP_CHECK_IF(
        GetTensorList(context, param, inputIdx) != ge::GRAPH_SUCCESS,
        OP_LOGE(context->GetNodeName(), "GetTensorList failed."), return ge::GRAPH_FAILED);
    OP_CHECK_IF(
        IsShapeValid(context, param.tensorList, param.dim) != ge::GRAPH_SUCCESS,
        OP_LOGE(context->GetNodeName(), "check input shape failed."), return ge::GRAPH_FAILED);
    OP_CHECK_IF(
        CalcBaseTilingParam(context, param) != ge::GRAPH_SUCCESS,
        OP_LOGE(context->GetNodeName(), "CalcBaseTilingParam failed."), return ge::GRAPH_FAILED);
    if (param.isNonContiguous) {
        OP_CHECK_IF(CheckNonContiguous(context, param, inputIdx) != ge::GRAPH_SUCCESS,
            OP_LOGE(context->GetNodeName(), "input tensor non contiguous validation failed."), return ge::GRAPH_FAILED);
    }
    // The SIMT path publishes its own tiling data and returns directly.
    if (IsEnableUsedSimt(param)) {
        return TilingForConcatDSimt(context, param);
    }
    OP_CHECK_IF(
        DoTiling(context, param) != ge::GRAPH_SUCCESS, OP_LOGE(context->GetNodeName(), "DoTiling failed."),
        return ge::GRAPH_FAILED);

    context->SetTilingKey(param.tilingKey);
    context->SetBlockDim(param.usedCoreNum);
    size_t* currentWorkspace = context->GetWorkspaceSizes(1);
    // Same guard as the SIMT path: never write through an unavailable workspace array.
    OP_CHECK_NULL_WITH_CONTEXT(context, currentWorkspace);
    currentWorkspace[0] = SYSTEM_WORKSPACE_SIZE;

    ge::graphStatus ret = ge::GRAPH_SUCCESS;
    if (param.blockSplitAxis == 1) {
        // Splitting along dim1 needs the per-block tensor ranges in the tiling data.
        ConcatTilingData tilingData;
        SetTilingData<ConcatTilingData>(tilingData, param);
        SetTensorListTilingData(tilingData, param);
        PrintTilingData(tilingData, param.tilingKey, param.usedCoreNum);
        ret = ConcatSetTilingData<ConcatTilingData>(context, tilingData);
    } else {
        ConcatTilingDataNoArray tilingData;
        SetTilingData<ConcatTilingDataNoArray>(tilingData, param);
        PrintTilingData(tilingData, param.tilingKey, param.usedCoreNum);
        ret = ConcatSetTilingData<ConcatTilingDataNoArray>(context, tilingData);
    }
    OP_CHECK_IF(
        ret != ge::GRAPH_SUCCESS, OP_LOGE(context->GetNodeName(), "ConcatSetTilingData set tiling data fail."),
        return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}
/*!
 * \brief Tiling entry registered for the Concat operator.
 *
 * Delegates to TilingCommon with the tensor list at input index 1 and the concat-dim input at index 0.
 *
 * \param context tiling context of the Concat node.
 * \return the status returned by TilingCommon.
 */
static ge::graphStatus Tiling4ConcatForAscendC(gert::TilingContext* context)
{
    constexpr int64_t tensorListInputIdx = 1;
    constexpr int64_t concatDimInputIdx = 0;
    OP_LOGD(context->GetNodeName(), "Tiling4ConcatForAscendC running begin");
    return TilingCommon(context, tensorListInputIdx, concatDimInputIdx);
}
/*!
 * \brief Tiling-parse step: cache platform limits in the ConcatDCompileInfo of the node.
 *
 * Stores the AIV core count, the UB size and the vector register length, and fails when any of
 * them is not positive.
 *
 * \param context tiling parse context providing the compile info and the platform info.
 * \return ge::GRAPH_SUCCESS on success, ge::GRAPH_FAILED when a value is missing or not positive.
 */
ge::graphStatus TilingPrepareForConcat(gert::TilingParseContext* context)
{
    auto compileInfo = context->GetCompiledInfo<ConcatDCompileInfo>();
    OP_CHECK_NULL_WITH_CONTEXT(context, compileInfo);
    auto platformInfo = context->GetPlatformInfo();
    OP_CHECK_NULL_WITH_CONTEXT(context, platformInfo);
    auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);

    compileInfo->totalCoreNum = ascendcPlatform.GetCoreNumAiv();
    OP_CHECK_IF(
        (compileInfo->totalCoreNum <= 0),
        OP_LOGE(context->GetNodeName(), "TilingPrepareForConcat fail to get core num."), return ge::GRAPH_FAILED);

    // Start from 0 so that a query which leaves the value untouched is caught by the check below
    // instead of propagating an indeterminate size.
    uint64_t ubSize = 0;
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
    compileInfo->ubSize = static_cast<int64_t>(ubSize);
    OP_CHECK_IF(
        (compileInfo->ubSize <= 0), OP_LOGE(context->GetNodeName(), "TilingPrepareForConcat fail to get ub size."),
        return ge::GRAPH_FAILED);

    compileInfo->vectorLen = static_cast<int64_t>(Ops::Base::GetVRegSize(context));
    OP_CHECK_IF(
        (compileInfo->vectorLen <= 0),
        OP_LOGE(context->GetNodeName(), "TilingPrepareForConcat fail to get vectorLen."), return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}
// Tiling registration for the Concat operator:
//  - input CONCAT_DIM_IDX is declared as a tiling data dependency
//    (presumably so its value is readable during tiling — confirm against the framework docs);
//  - Tiling4ConcatForAscendC is the tiling function;
//  - TilingPrepareForConcat is the parse function that fills ConcatDCompileInfo.
IMPL_OP_OPTILING(Concat)
.TilingInputsDataDependency({CONCAT_DIM_IDX})
.Tiling(Tiling4ConcatForAscendC)
.TilingParse<ConcatDCompileInfo>(TilingPrepareForConcat);
}