* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <set>
#include <vector>
#include <algorithm>
#include "register/op_impl_registry.h"
#include "split_v_tiling_arch35.h"
#include "util/const_util.h"
#include "atvoss/broadcast/broadcast_tiling.h"
#include "util/math_util.h"
#include "op_host/tiling_base_util.h"
#include "log/log.h"
using namespace gert;
using namespace ge;
using namespace AscendC;
namespace optiling {
constexpr int64_t HUGE_SPLIT_NUM = 4L * 1024;
constexpr int64_t HUGE_M = 64;
constexpr int64_t SIZE_SPLIT_INDEX = 1;
constexpr int64_t SPLIT_DIM_INDEX = 2;
constexpr int64_t SPLIT_DIM_INDEX_SAME_LEN = 0;
constexpr int64_t BASE_256 = 256;
constexpr int64_t BASE_512 = 512;
constexpr int64_t PURE_MOVE_BASE_LEN = 128;
constexpr int64_t PURE_MOVE_BASE_UB_SIZE = 48 * 1024;
constexpr int64_t PURE_MOVE_BASE_UB_SIZE_M1 = 24 * 1024;
constexpr int64_t USED_MIN_UB_SIZE = 8 * 1024;
constexpr uint64_t WORK_SPACE_SIZE = 16 * 1024 * 1024;
constexpr int64_t SPLIT_KEY_PURE_MOVE = 100;
constexpr int64_t SPLIT_KEY_UB_SPLIT = 101;
constexpr int64_t SPLITV_KEY_PURE_MOVE = 102;
constexpr int64_t SPLITV_KEY_UB_SPLIT_WITH_PURE_MOVE = 103;
constexpr int64_t SPLITV_KEY_PURE_MOVE_1 = 104;
constexpr int64_t SIMT_KEY = 200;
constexpr int64_t SIMT_SAME_LEN_KEY = 201;
constexpr int64_t SIMT_KEY_SPLIT_TENSOR = 202;
constexpr int64_t SIMT_SAME_LEN_KEY_SPLIT_TENSOR = 203;
constexpr int64_t BLOCK_SIZE = 32;
constexpr double PURE_MOVE_RATIO = 0.9;
constexpr double PURE_MOVE_RATIO_M1 = 0.8;
constexpr int64_t INT16_MAX_VALUE = 32767;
constexpr int64_t PREF_BRANCH_GSIZE = 15;
constexpr int64_t MAX_SIMT_PER_CORE_SIZE = 131072;
constexpr int32_t SIMT_THREAD_NUM = 1024;
constexpr int64_t MAX_NUM = 9223372036854775807;
constexpr int32_t BASE_TWO = 2;
constexpr int32_t BASE_FOUR = 4;
template <typename T>
std::string SplitVTiling::ArrayToString(const T *vec, size_t num) const
{
std::stringstream ss;
for (size_t i = 0; i < num; i++) {
ss << vec[i] << " ";
}
return ss.str();
}
template <typename T>
std::string SplitVTiling::VectorToString(const std::vector<T>& vec) const
{
std::stringstream ss;
for (size_t i = 0; i < vec.size(); i++) {
ss << vec[i] << " ";
}
return ss.str();
}
template <typename T>
std::string Shape2String(const T& shape)
{
std::ostringstream oss;
oss << "[";
if (shape.GetDimNum() > 0) {
for (size_t i = 0; i < shape.GetDimNum() - 1; ++i) {
oss << shape.GetDim(i) << ", ";
}
oss << shape.GetDim(shape.GetDimNum() - 1);
}
oss << "]";
return oss.str();
}
ge::graphStatus SplitVTiling::ModifySizeSplitList()
{
int32_t negNum = 0;
int32_t negIdx = 0;
int64_t sizeSplitsTotal = 0;
int64_t splitDimShape = inputShape_.GetDim(splitDim_);
for (int64_t i = 0; i < numSplit_; i++) {
if (oriSizeSplits_[i] < 0) {
negNum++;
negIdx = i;
continue;
}
sizeSplitsTotal += oriSizeSplits_[i];
nonZeroSplitCnt_ = (oriSizeSplits_[i] == 0) ? nonZeroSplitCnt_ : (nonZeroSplitCnt_ + 1);
pureOutIdx_ = (oriSizeSplits_[i] == 0) ? pureOutIdx_ : i;
}
OP_CHECK_IF((negNum == 0 && sizeSplitsTotal != splitDimShape),
OP_LOGE(context_->GetNodeName(), "sizeSplitsTotal:%ld not equal splitDimShape:%ld",
sizeSplitsTotal, splitDimShape), return ge::GRAPH_FAILED);
OP_CHECK_IF(negNum > 1,
OP_LOGE(context_->GetNodeName(), "sizeSplits negtive num:%d cannot be bigger than 1",
negNum), return ge::GRAPH_FAILED);
if (negNum > 0) {
OP_CHECK_IF(sizeSplitsTotal > splitDimShape,
OP_LOGE(context_->GetNodeName(),
"sizeSplitsTotal:%ld cannot be bigger than splitDimShape:%ld", sizeSplitsTotal, splitDimShape),
return ge::GRAPH_FAILED);
oriSizeSplits_[negIdx] = splitDimShape - sizeSplitsTotal;
nonZeroSplitCnt_ = (oriSizeSplits_[negIdx] == 0) ? nonZeroSplitCnt_ : (nonZeroSplitCnt_ + 1);
pureOutIdx_ = (oriSizeSplits_[negIdx] == 0) ? pureOutIdx_ : negIdx;
negIdx_ = negIdx;
negValue_ = oriSizeSplits_[negIdx];
}
int64_t firstElement = oriSizeSplits_[0];
for (size_t i = 1; i < oriSizeSplits_.size(); i++) {
if (oriSizeSplits_[i] != firstElement) {
isSameLenMode_ = false;
break;
}
}
return ge::GRAPH_SUCCESS;
}
template <typename T>
bool SplitVTiling::GetData(const gert::Tensor* tensor, std::vector<int64_t>& values) const
{
size_t shape_size = tensor->GetShapeSize();
values.resize(shape_size);
auto* tensor_data = tensor->GetData<T>();
if (tensor_data == nullptr) {
return false;
}
for (size_t i = 0; i < shape_size; i++) {
values[i] = static_cast<int64_t>(*(tensor_data + i));
}
return true;
}
bool SplitVTiling::GetSizeSplitsList()
{
const gert::Tensor* sizeSplitsT = context_->GetInputTensor(SIZE_SPLIT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, sizeSplitsT);
if (sizeSplitsT->GetDataType() == ge::DT_INT32) {
isInt32_ = 1L;
return GetData<int32_t>(sizeSplitsT, oriSizeSplits_);
} else if (sizeSplitsT->GetDataType() == ge::DT_INT64) {
return GetData<int64_t>(sizeSplitsT, oriSizeSplits_);
} else if (sizeSplitsT->GetDataType() == ge::DT_UINT32) {
isInt32_ = 1L;
return GetData<uint32_t>(sizeSplitsT, oriSizeSplits_);
} else if (sizeSplitsT->GetDataType() == ge::DT_UINT64) {
return GetData<uint64_t>(sizeSplitsT, oriSizeSplits_);
}
OP_LOGE(context_->GetNodeType(), "Get const input error, unsupported data type.");
return false;
}
ge::graphStatus SplitVTiling::GetInputParams()
{
auto xInput = context_->GetInputShape(0);
OP_CHECK_NULL_WITH_CONTEXT(context_, xInput);
const gert::Shape& xInputShape = Ops::Base::EnsureNotScalar(xInput->GetStorageShape());
inputShape_ = xInputShape;
OP_CHECK_IF(!GetSizeSplitsList(),
OP_LOGE(context_->GetNodeName(), "SplitV tiling get size_splits failed"),
return ge::GRAPH_FAILED);
OP_CHECK_IF(!Ops::Base::GetConstInt(context_, SPLIT_DIM_INDEX, splitDim_),
OP_LOGE(context_->GetNodeName(), "SplitV tiling get split_dim failed"),
return ge::GRAPH_FAILED);
int64_t inputXDimNum = static_cast<int64_t>(inputShape_.GetDimNum());
OP_CHECK_IF((splitDim_ < -inputXDimNum || splitDim_ >= inputXDimNum),
OP_LOGE(context_->GetNodeName(),
"SplitDim:%ld value out range, inputXDimNum:%ld", splitDim_, inputXDimNum), return ge::GRAPH_FAILED);
if (splitDim_ < 0) {
splitDim_ = splitDim_ + inputXDimNum;
}
auto attrs = context_->GetAttrs();
OP_CHECK_NULL_WITH_CONTEXT(context_, attrs);
auto attr0 = attrs->GetInt(0);
OP_CHECK_NULL_WITH_CONTEXT(context_, attr0);
numSplit_ = *attr0;
OP_CHECK_IF(numSplit_ < 0,
OP_LOGE(context_->GetNodeName(), "numSplit_:%ld must be greater than 0", numSplit_),
return ge::GRAPH_FAILED);
int64_t sizeSplitsNum = static_cast<int64_t>(oriSizeSplits_.size());
OP_CHECK_IF(numSplit_ != sizeSplitsNum,
OP_LOGE(context_->GetNodeName(), "attr numSplit_:%ld not equal sizeSplitsNum:%ld",
numSplit_, sizeSplitsNum), return ge::GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus SplitVTiling::InitParams(int32_t maxCoreNum, uint32_t ubSize)
{
coreNum_ = std::min(maxCoreNum, static_cast<int32_t>(MAX_CORE_COUNT));
ubSize_ = ubSize;
auto xInputDesc = context_->GetInputDesc(0);
OP_CHECK_NULL_WITH_CONTEXT(context_, xInputDesc);
ge::DataType xDtype = xInputDesc->GetDataType();
xDtypeSize_ = ge::GetSizeByDataType(xDtype);
if (GetInputParams() != ge::GRAPH_SUCCESS) {
return ge::GRAPH_FAILED;
}
if (ModifySizeSplitList() != ge::GRAPH_SUCCESS) {
return ge::GRAPH_FAILED;
}
OP_LOGI(context_->GetNodeName(), "coreNum_:%d ubSize_:%d xDtypeSize_:%ld inputShape_:%s "
"splitDim_:%ld numSplit_:%ld oriSizeSplits_:%s nonZeroSplitCnt_:%ld pureOutIdx_:%ld",
coreNum_, ubSize_, xDtypeSize_, Shape2String(inputShape_).c_str(), splitDim_, numSplit_,
VectorToString(oriSizeSplits_).c_str(), nonZeroSplitCnt_, pureOutIdx_);
return ge::GRAPH_SUCCESS;
}
void SplitVTiling::FuseAllShape()
{
fusedShape_[0] = 1;
fusedShape_[1] = inputShape_.GetShapeSize();
OP_LOGI(context_->GetNodeName(), "After fusedShape_:%ld %ld, pureOutIdx_:%ld",
fusedShape_[0], fusedShape_[1], pureOutIdx_);
}
void SplitVTiling::SetPureMoveTilingMode()
{
int64_t cores = std::min(static_cast<int64_t>(coreNum_), Ops::Base::CeilDiv(fusedShape_[0] * fusedShape_[1] * xDtypeSize_, BASE_512));
SetBlockSplitInfo(1, cores);
tilingKey_ = SPLITV_KEY_PURE_MOVE_1;
}
void SplitVTiling::FuseInputShape()
{
int64_t outerSize = 1;
int64_t innerSize = 1;
for (size_t i = 0; i < static_cast<size_t>(splitDim_); i++) {
outerSize *= inputShape_.GetDim(i);
}
for (size_t i = splitDim_ + 1; i < inputShape_.GetDimNum(); i++) {
innerSize *= inputShape_.GetDim(i);
}
splitStride_ = innerSize;
fusedShape_[0] = outerSize;
fusedShape_[1] = inputShape_.GetDim(splitDim_) * innerSize;
sizeSplits_.resize(numSplit_);
for (int64_t i = 0; i < numSplit_; i++) {
sizeSplits_[i] = oriSizeSplits_[i] * innerSize;
}
OP_LOGI(context_->GetNodeName(), "After fused shape, splitStride_:%ld fusedShape_:%ld %ld, sizeSplits_:%s",
splitStride_, fusedShape_[0], fusedShape_[1], VectorToString(sizeSplits_).c_str());
}
void SplitVTiling::CalEmptyInputTiling()
{
realCoreNum_ = 0;
blockDim_ = 1;
tilingKey_ = SPLITV_KEY_PURE_MOVE;
}
void SplitVTiling::FindAllPossibleCutCnt(std::set<int64_t>& cutCountSet, int64_t cores) const
{
int64_t upBound = static_cast<int64_t>(std::ceil(std::sqrt(cores) + 1.0));
int64_t n = 1;
for (int64_t m = 1; m < upBound; m++) {
n = cores / m;
cutCountSet.insert(m);
cutCountSet.insert(n);
}
}
void SplitVTiling::GetAllCutInfoNode(int64_t M, int64_t N, int64_t cores,
const std::set<int64_t>& cutCountSet, std::vector<DualSplitNode>& splitNodeList) const
{
int64_t n = 0;
int64_t mFactor = 0;
int64_t nFactor = 0;
int64_t delta = 0;
int64_t blockNum = BLOCK_SIZE / xDtypeSize_;
for (int64_t m : cutCountSet) {
if (m <= 0) {
continue;
}
n = cores / m;
if (m > M || n > N || n <= 0) {
continue;
}
mFactor = Ops::Base::CeilDiv(M, m);
nFactor = Ops::Base::CeilDiv(N, n);
delta = mFactor * nFactor;
if (m * n == cores) {
if (M % m == 0 && N % n == 0) {
delta = 0;
} else if (M % m == 0) {
delta = delta - mFactor * (N / n);
} else if (N % n == 0) {
delta = delta - nFactor * (M / m);
} else {
delta = delta - (M / m) * (N / n);
}
}
if ((M / m) < blockNum && M > blockNum && splitNodeList.size() > 0) {
continue;
}
splitNodeList.push_back({m, n, m * n, delta});
OP_LOGD(context_->GetNodeName(), "m:%ld n:%ld delta:%ld", m, n, delta);
}
}
void SplitVTiling::ChooseBestSplitInfo(std::vector<DualSplitNode>& splitNodeList, DualSplitNode& splitInfo) const
{
std::sort(splitNodeList.begin(), splitNodeList.end());
splitInfo.m = splitNodeList.front().m;
splitInfo.n = splitNodeList.front().n;
splitInfo.t = splitNodeList.front().t;
splitInfo.delta = splitNodeList.front().delta;
for (const auto& node : splitNodeList) {
splitInfo.m =
numSplit_ > HUGE_SPLIT_NUM && fusedShape_[0] >= HUGE_M && node.m == SPLIT_DIM_INDEX ? node.m : splitInfo.m;
splitInfo.n =
numSplit_ > HUGE_SPLIT_NUM && fusedShape_[0] >= HUGE_M && node.m == SPLIT_DIM_INDEX ? node.n : splitInfo.n;
splitInfo.t =
numSplit_ > HUGE_SPLIT_NUM && fusedShape_[0] >= HUGE_M && node.m == SPLIT_DIM_INDEX ? node.t : splitInfo.t;
splitInfo.delta = numSplit_ > HUGE_SPLIT_NUM && fusedShape_[0] >= HUGE_M && node.m == SPLIT_DIM_INDEX ?
node.delta :
splitInfo.delta;
}
OP_LOGD(context_->GetNodeName(), "choose best split m:%ld n:%ld t:%ld delta:%ld",
splitInfo.m, splitInfo.n, splitInfo.t, splitInfo.delta);
}
void SplitVTiling::CalBlockSplitTwoAxis(int64_t rowM, int64_t colN, int64_t dataBytes, int64_t coreNum,
DualSplitNode& splitInfo) const
{
int64_t M = rowM;
int64_t N = Ops::Base::CeilDiv(colN, BASE_512 / dataBytes);
int64_t cores = std::min(coreNum, Ops::Base::CeilDiv(rowM * colN * dataBytes, USED_MIN_UB_SIZE));
OP_LOGD(context_->GetNodeName(), "get split info rowM:%ld colN:%ld M:%ld N:%ld cores:%ld", rowM, colN, M, N, cores);
std::set<int64_t> cutCountSet;
std::vector<DualSplitNode> splitNodeList;
FindAllPossibleCutCnt(cutCountSet, cores);
GetAllCutInfoNode(M, N, cores, cutCountSet, splitNodeList);
ChooseBestSplitInfo(splitNodeList, splitInfo);
if (splitInfo.m > rowM || splitInfo.n > colN) {
splitInfo.m = std::min(splitInfo.m, rowM);
splitInfo.n = std::min(splitInfo.n, colN);
}
if (splitInfo.m == 0 || splitInfo.n == 0) {
splitInfo.m = std::min(cores, rowM);
splitInfo.n = 1;
}
}
void SplitVTiling::SetBlockSplitInfo(int64_t mBlockCnt, int64_t nBlockCnt)
{
mBlockCount_ = mBlockCnt;
mBlockFactor_ = Ops::Base::CeilDiv(fusedShape_[0], mBlockCount_);
mBlockFactorCount_ = (fusedShape_[0] % mBlockCount_ == 0) ? mBlockCount_ : (fusedShape_[0] % mBlockCount_);
mBlockFactorTail_ = fusedShape_[0] / mBlockCount_;
nBlockCount_ = nBlockCnt;
nBlockFactor_ = Ops::Base::CeilDiv(fusedShape_[1], nBlockCount_);
nBlockFactorCount_ = (fusedShape_[1] % nBlockCount_ == 0) ? nBlockCount_ : (fusedShape_[1] % nBlockCount_);
nBlockFactorTail_ = fusedShape_[1] / nBlockCount_;
realCoreNum_ = mBlockCount_ * nBlockCount_;
blockDim_ = realCoreNum_;
OP_LOGI(context_->GetNodeName(), "Get block split TotalNum-BlockCnt-MainFactor-MainCnt-TailFactor, "
"M:%ld %ld %ld %ld %ld, N:%ld %ld %ld %ld %ld",
fusedShape_[0], mBlockCount_, mBlockFactor_, mBlockFactorCount_, mBlockFactorTail_,
fusedShape_[1], nBlockCount_, nBlockFactor_, nBlockFactorCount_, nBlockFactorTail_);
}
void SplitVTiling::CalBlockTilingParams()
{
DualSplitNode splitInfo;
CalBlockSplitTwoAxis(fusedShape_[0], fusedShape_[1], xDtypeSize_, coreNum_, splitInfo);
SetBlockSplitInfo(splitInfo.m, splitInfo.n);
}
void SplitVTiling::CalcSplitLenCond(int64_t curSplitNumN, int64_t& condNum, int64_t& total) const
{
if (curSplitNumN * xDtypeSize_ > PURE_MOVE_BASE_LEN) {
condNum++;
}
total++;
}
void SplitVTiling::CalcSplitUBSizeCond(int64_t curSplitNumN, int64_t& condNum, int64_t& total)
{
int64_t limitUb = fusedShape_[0] == 1 ? PURE_MOVE_BASE_UB_SIZE_M1 : PURE_MOVE_BASE_UB_SIZE;
if (mBlockFactor_ * curSplitNumN * xDtypeSize_ > limitUb) {
condNum += mBlockFactorCount_;
}
if (mBlockFactorTail_ * curSplitNumN * xDtypeSize_ > limitUb) {
condNum += (mBlockCount_ - mBlockFactorCount_);
}
total += mBlockCount_;
}
int64_t SplitVTiling::UpdataNextBlockFactor(int64_t sizeSplitIdx, int64_t handedSplitSize,
int64_t leftSplitSize, int64_t& handedBlockNumN)
{
int64_t leftBlockFactorN = (handedSplitSize >= nBlockFactor_ * nBlockFactorCount_) ? \
nBlockFactorTail_ : nBlockFactor_;
nBlockSplitOffsetEnd_[handedBlockNumN] = sizeSplitIdx + 1;
handedBlockNumN++;
if (handedBlockNumN < MAX_CORE_COUNT && handedBlockNumN < nBlockCount_) {
nBlockSplitOffsetStart_[handedBlockNumN] = (leftSplitSize == 0) ? (sizeSplitIdx + 1) : sizeSplitIdx;
}
return leftBlockFactorN;
}
void SplitVTiling::UpdateMAxisSplitOffset()
{
for (int64_t i = 1; i < mBlockCount_; i++) {
if (memcpy_s(nBlockSplitOffsetStart_ + i * nBlockCount_, (MAX_CORE_COUNT - i * nBlockCount_) * sizeof(int64_t),
nBlockSplitOffsetStart_, nBlockCount_ * sizeof(int64_t)) != EOK) {
OP_LOGE(context_->GetNodeType(),
"memcpy_s offset start failed, i:%ld mBlockCount_:%ld nBlockCount_:%ld",
i, mBlockCount_, nBlockCount_);
}
if (memcpy_s(nBlockSplitOffsetEnd_ + i * nBlockCount_, (MAX_CORE_COUNT - i * nBlockCount_) * sizeof(int64_t),
nBlockSplitOffsetEnd_, nBlockCount_ * sizeof(int64_t)) != EOK) {
OP_LOGE(context_->GetNodeType(),
"memcpy_s offset end failed, i:%ld mBlockCount_:%ld nBlockCount_:%ld",
i, mBlockCount_, nBlockCount_);
}
}
}
void SplitVTiling::CalcTilingKey(double condN, double condM, double totalCompareN, double totalCompareM)
{
double nowRatio = (fusedShape_[0] == 1) ? PURE_MOVE_RATIO_M1 : PURE_MOVE_RATIO;
if ((condN > totalCompareN * nowRatio && condM > totalCompareM * nowRatio) || numSplit_ == 1) {
tilingKey_ = SPLITV_KEY_PURE_MOVE;
} else {
tilingKey_ = SPLITV_KEY_UB_SPLIT_WITH_PURE_MOVE;
}
OP_LOGD(context_->GetNodeName(), "condN:%f condM:%f totalCompareN:%f totalCompareM:%f tilingKey_:%ld numSplit_:%ld",
condN, condM, totalCompareN, totalCompareM, tilingKey_, numSplit_);
}
int64_t SplitVTiling::SplitPrefix(int64_t i)
{
int64_t shape = 0;
for (int64_t index = 0; index < i + 1; index++) {
if (index < numSplit_) {
shape += oriSizeSplits_[index] * splitStride_;
}
}
return shape;
}
void SplitVTiling::CountSplitPrefix()
{
for (int64_t i = 0; i < MAX_CORE_COUNT; i++) {
nBlockSplitPrefixStart_[i] = SplitPrefix(nBlockSplitOffsetStart_[i]);
nBlockSplitPrefixEnd_[i] = SplitPrefix(nBlockSplitOffsetEnd_[i]);
}
}
void SplitVTiling::FillSIMTSplitVTilingData()
{
OP_LOGD(context_->GetNodeName(), "Entering FillSIMTSplitVTilingData.");
simtSplitVTilingData_.set_mSize(static_cast<int32_t>(mSimtSize_));
simtSplitVTilingData_.set_nSize(static_cast<int32_t>(nSimtSize_));
simtSplitVTilingData_.set_sizeAfterSplit(static_cast<int32_t>(simtSizeAfterSplit_));
simtSplitVTilingData_.set_splitNum(static_cast<int32_t>(numSplit_));
simtSplitVTilingData_.set_realCoreNum(static_cast<int32_t>(realCoreNum_));
simtSplitVTilingData_.set_colOffset(colOffset_);
OP_LOGI(context_->GetNodeName(), "TilingData mSize:%d nSize:%d sizeAfterSplit:%d splitNum:%d realCoreNum:%d ",
simtSplitVTilingData_.get_mSize(), simtSplitVTilingData_.get_nSize(), simtSplitVTilingData_.get_sizeAfterSplit(),
simtSplitVTilingData_.get_splitNum(), simtSplitVTilingData_.get_realCoreNum());
OP_LOGI(context_->GetNodeName(), "TilingData colOffset: %s",
ArrayToString(simtSplitVTilingData_.get_colOffset(), realCoreNum_).c_str());
simtSplitVTilingData_.SaveToBuffer(context_->GetRawTilingData()->GetData(), context_->GetRawTilingData()->GetCapacity());
context_->GetRawTilingData()->SetDataSize(simtSplitVTilingData_.GetDataSize());
}
void SplitVTiling::SetTilingMode()
{
int64_t leftSplitSize = 0;
int64_t handedSplitSize = 0;
int64_t leftBlockFactorN = nBlockFactor_;
int64_t handedBlockNumN = 0;
int64_t condNum[BASE_2] = {0, 0};
int64_t totalCompareN = 0;
int64_t totalCompareM = 0;
for (int64_t i = 0; i < numSplit_; i++) {
leftSplitSize = sizeSplits_[i];
while (leftSplitSize > leftBlockFactorN) {
CalcSplitLenCond(leftBlockFactorN, condNum[0], totalCompareN);
CalcSplitUBSizeCond(leftBlockFactorN, condNum[1], totalCompareM);
leftSplitSize -= leftBlockFactorN;
handedSplitSize += leftBlockFactorN;
leftBlockFactorN = UpdataNextBlockFactor(i, handedSplitSize, leftSplitSize, handedBlockNumN);
}
if (leftSplitSize <= 0) {
continue;
}
CalcSplitLenCond(leftSplitSize, condNum[0], totalCompareN);
CalcSplitUBSizeCond(leftSplitSize, condNum[1], totalCompareM);
leftBlockFactorN -= leftSplitSize;
handedSplitSize += leftSplitSize;
leftSplitSize = 0;
if (leftBlockFactorN == 0) {
leftBlockFactorN = UpdataNextBlockFactor(i, handedSplitSize, leftSplitSize, handedBlockNumN);
}
}
OP_LOGD(context_->GetNodeName(), "handedSplitSize:%ld handedBlockNumN:%ld nBlockCount_:%ld",
handedSplitSize, handedBlockNumN, nBlockCount_);
UpdateMAxisSplitOffset();
CountSplitPrefix();
CalcTilingKey(condNum[0], condNum[1], totalCompareN, totalCompareM);
}
int64_t SplitVTiling::CalInputDataSize() {
for (size_t i = 0; i < static_cast<size_t>(splitDim_); i++) {
mSimtSize_ *= inputShape_.GetDim(i);
}
for (size_t i = splitDim_; i < inputShape_.GetDimNum(); i++) {
nSimtSize_ *= inputShape_.GetDim(i);
}
if (inputShape_.GetDim(splitDim_) <= 0) {
return false;
}
simtSizeAfterSplit_ = nSimtSize_ / inputShape_.GetDim(splitDim_);
return xDtypeSize_ * mSimtSize_ * nSimtSize_;
}
void SplitVTiling::DoDiffLenModeSplitVSIMTTiling(int32_t maxCoreNum)
{
if (numSplit_ <= static_cast<int64_t>(maxCoreNum / BASE_TWO) && static_cast<int64_t>(mSimtSize_ * maxFuseSizeAfterSplit_) > SIMT_THREAD_NUM) {
int32_t perBlkSize = static_cast<int32_t>(maxCoreNum / numSplit_);
int32_t preSum = 0;
int32_t curSplitLen = 0;
for (int32_t i = 0; i < numSplit_; i++) {
int32_t curNLen = oriSizeSplits_[i] * simtSizeAfterSplit_;
int32_t curBlkFactor = Ops::Base::CeilDiv(curNLen, perBlkSize);
for (int32_t j = 0; j < perBlkSize; j++) {
if (curNLen >= curBlkFactor) {
curSplitLen = curBlkFactor;
curNLen -= curSplitLen;
} else {
curSplitLen = curNLen;
curNLen -= curSplitLen;
}
preSum += curSplitLen;
colOffset_[i * perBlkSize + j] = static_cast<uint32_t>(preSum);
}
}
realCoreNum_ = perBlkSize * numSplit_;
tilingKey_ = SIMT_KEY_SPLIT_TENSOR;
} else {
int32_t preSum = 0;
for (int32_t i = 0; i < numSplit_; i++) {
preSum += oriSizeSplits_[i] * simtSizeAfterSplit_;
colOffset_[i] = static_cast<uint32_t>(preSum);
}
realCoreNum_ = numSplit_ <= maxCoreNum ? numSplit_ : maxCoreNum;
tilingKey_ = SIMT_KEY;
}
FillSIMTSplitVTilingData();
}
ge::graphStatus SplitVTiling::DoSplitVSIMTTiling(int32_t maxCoreNum)
{
OP_LOGD("SplitTilingForAscendC", "DoSplitVSIMTTiling start");
DoDiffLenModeSplitVSIMTTiling(maxCoreNum);
blockDim_ = realCoreNum_;
SetBlockDimAndTilingKey();
SetWorkspaceSize();
OP_LOGD("SplitTilingForAscendC", "DoSplitVSIMTTiling end");
return ge::GRAPH_SUCCESS;
}
void SplitVTiling::CalSplitSizeDiff(int32_t maxCoreNum, int64_t &maxSizeSplit,
int64_t &minSizeSplit, int32_t countPerCore)
{
if (numSplit_ <= maxCoreNum) {
for (size_t i = 0; i < oriSizeSplits_.size(); i++) {
if (oriSizeSplits_[i] == 0) {
continue;
}
maxSizeSplit = maxSizeSplit > oriSizeSplits_[i] ? maxSizeSplit : oriSizeSplits_[i];
minSizeSplit = minSizeSplit < oriSizeSplits_[i] ? minSizeSplit : oriSizeSplits_[i];
}
} else {
for (int32_t i = 0; i < maxCoreNum; i++) {
int32_t curSizeSplits = 0;
for (int32_t j = 0; j < countPerCore; j++) {
if (static_cast<int64_t>(j * maxCoreNum + i) >= numSplit_) {
break;
}
curSizeSplits += oriSizeSplits_[j * maxCoreNum + i];
}
if (curSizeSplits == 0) {
continue;
}
maxSizeSplit = maxSizeSplit > curSizeSplits ? maxSizeSplit : curSizeSplits;
minSizeSplit = minSizeSplit < curSizeSplits ? minSizeSplit : curSizeSplits;
}
}
maxFuseSizeAfterSplit_ = maxSizeSplit * simtSizeAfterSplit_;
minFuseSizeAfterSplit_ = minSizeSplit * simtSizeAfterSplit_;
}
bool SplitVTiling::IsDoSplitVSIMT(int32_t maxCoreNum) {
if (isSameLenMode_) {
return false;
}
if (numSplit_ > static_cast<int64_t>(MAX_COL_OFFSET_COUNT) || numSplit_ <= static_cast<int64_t>(maxCoreNum / BASE_FOUR)) {
return false;
}
if (CalInputDataSize() > maxCoreNum * MAX_SIMT_PER_CORE_SIZE) {
return false;
}
int32_t countPerCore = 1;
if (numSplit_ > static_cast<int64_t>(maxCoreNum / BASE_TWO) && numSplit_ <= static_cast<int64_t>(maxCoreNum)) {
countPerCore = 1;
} else if (numSplit_ > maxCoreNum) {
countPerCore = Ops::Base::CeilDiv(static_cast<int32_t>(numSplit_), maxCoreNum);
}
int64_t maxSizeSplit = 0;
int64_t minSizeSplit = MAX_NUM;
CalSplitSizeDiff(maxCoreNum, maxSizeSplit, minSizeSplit, countPerCore);
int32_t divPara = 1;
if (numSplit_ <= static_cast<int64_t>(maxCoreNum / BASE_TWO)) {
divPara = static_cast<int32_t>(maxCoreNum / numSplit_);
}
if (maxSizeSplit * maxCoreNum / inputShape_.GetDim(splitDim_) / divPara >= BASE_TWO &&
mSimtSize_ * (maxFuseSizeAfterSplit_ - minFuseSizeAfterSplit_) / divPara >= SIMT_THREAD_NUM) {
return false;
}
return true;
}
void SplitVTiling::FillTilingData()
{
OP_LOGD(context_->GetNodeName(), "Entering FillTilingData.");
tilingData_.set_ubSize(static_cast<int64_t>(ubSize_));
tilingData_.set_splitDim(splitDim_);
tilingData_.set_sizeAfterSplitDim(splitStride_);
tilingData_.set_mBlockFactor(mBlockFactor_);
tilingData_.set_mBlockFactorTail(mBlockFactorTail_);
tilingData_.set_mBlockFactorNum(mBlockFactorCount_);
tilingData_.set_mBlockCount(mBlockCount_);
tilingData_.set_gUBFactor(gUBFactor_);
tilingData_.set_gUBFactorTail(gUBFactorTail_);
tilingData_.set_gSize(numSplit_);
tilingData_.set_gUBCount(gUBCount_);
tilingData_.set_nBlockFactor(nBlockFactor_);
tilingData_.set_nBlockFactorAlign(Ops::Base::CeilAlign(nBlockFactor_, BLOCK_SIZE / xDtypeSize_));
tilingData_.set_nBlockFactorTail(nBlockFactorTail_);
tilingData_.set_nBlockFactorTailAlign(Ops::Base::CeilAlign(nBlockFactorTail_, BLOCK_SIZE / xDtypeSize_));
tilingData_.set_nBlockFactorNum(nBlockFactorCount_);
tilingData_.set_nBlockCount(nBlockCount_);
tilingData_.set_realCoreNum(realCoreNum_);
tilingData_.set_blockFactor(blockFactor_);
tilingData_.set_blockFactorTail(blockFactorTail_);
tilingData_.set_nSize(fusedShape_[1]);
tilingData_.set_nSizeAlign(Ops::Base::CeilAlign(fusedShape_[1], BLOCK_SIZE / xDtypeSize_));
tilingData_.set_pureOutIdx(pureOutIdx_);
tilingData_.set_negIdx(negIdx_);
tilingData_.set_negValue(negValue_);
tilingData_.set_isInt32(isInt32_);
tilingData_.set_mSize(fusedShape_[0]);
tilingData_.set_nBlockSplitOffset(nBlockSplitOffsetStart_);
tilingData_.set_nBlockSplitOffsetEnd(nBlockSplitOffsetEnd_);
tilingData_.set_nBlockSplitPrefixStart(nBlockSplitPrefixStart_);
tilingData_.set_nBlockSplitPrefixEnd(nBlockSplitPrefixEnd_);
tilingData_.SaveToBuffer(context_->GetRawTilingData()->GetData(), context_->GetRawTilingData()->GetCapacity());
context_->GetRawTilingData()->SetDataSize(tilingData_.GetDataSize());
OP_LOGI(context_->GetNodeName(), "TilingData ubSize:%u splitDim:%ld sizeAfterSplitDim:%ld mBlockFactor:%ld "
"mBlockFactorTail:%ld mBlockFactorNum:%ld mBlockCount:%ld gUBFactor:%ld gUBFactorTail:%ld gSize:%ld gUBCount:%ld "
"nBlockFactor:%ld nBlockFactorAlign:%ld nBlockFactorTail:%ld nBlockFactorTailAlign:%ld nBlockFactorNum:%ld "
"nBlockCount:%ld realCoreNum:%ld blockFactor:%ld blockFactorTail:%ld nSize:%ld nSizeAlign:%ld "
"pureOutIdx_:%ld",
ubSize_, splitDim_, splitStride_, mBlockFactor_, mBlockFactorTail_, mBlockFactorCount_, mBlockCount_, gUBFactor_,
gUBFactorTail_, numSplit_, gUBCount_, nBlockFactor_, tilingData_.get_nBlockFactorAlign(), nBlockFactorTail_,
tilingData_.get_nBlockFactorTailAlign(), nBlockFactorCount_, nBlockCount_, realCoreNum_, blockFactor_,
blockFactorTail_, fusedShape_[1], tilingData_.get_nSizeAlign(), pureOutIdx_);
OP_LOGI(context_->GetNodeName(), "TilingData nBlockSplitOffset: %s",
ArrayToString(tilingData_.get_nBlockSplitOffset(), realCoreNum_).c_str());
OP_LOGI(context_->GetNodeName(), "TilingData nBlockSplitOffsetEnd:%s",
ArrayToString(tilingData_.get_nBlockSplitOffsetEnd(), realCoreNum_).c_str());
}
void SplitVTiling::SetBlockDimAndTilingKey() const
{
context_->SetBlockDim(blockDim_);
context_->SetTilingKey(tilingKey_);
OP_LOGI(context_->GetNodeName(), "Tiling blockDim_:%ld tilingKey_:%ld", blockDim_, tilingKey_);
}
void SplitVTiling::SetWorkspaceSize() const
{
size_t* workspaces = context_->GetWorkspaceSizes(1);
workspaces[0] = WORK_SPACE_SIZE;
}
ge::graphStatus SplitVTiling::DoSplitVTiling(int32_t maxCoreNum, uint32_t ubSize)
{
OP_CHECK_IF((InitParams(maxCoreNum, ubSize) != ge::GRAPH_SUCCESS),
OP_LOGE(context_->GetNodeName(), "SplitV Tiling InitParams Failed."),
return ge::GRAPH_FAILED);
if (IsDoSplitVSIMT(maxCoreNum)) {
return DoSplitVSIMTTiling(maxCoreNum);
}
if (numSplit_ != 0 && numSplit_ != 1 && isSameLenMode_) {
FuseInputShapeSameLen();
} else if (nonZeroSplitCnt_ == 1) {
FuseAllShape();
} else {
FuseInputShape();
}
if (numSplit_ != 0 && numSplit_ != 1 && isSameLenMode_) {
tilingKey_ = (numSplit_ == 1 || fusedShape_[1] >= (PURE_MOVE_BASE_LEN / xDtypeSize_)) ?
SPLIT_KEY_PURE_MOVE : SPLIT_KEY_UB_SPLIT;
CalcSameLenTilingInfo();
} else if (fusedShape_[0] == 0 || fusedShape_[1] == 0) {
CalEmptyInputTiling();
} else if (nonZeroSplitCnt_ == 1) {
SetPureMoveTilingMode();
} else {
CalBlockTilingParams();
SetTilingMode();
}
FillTilingData();
SetBlockDimAndTilingKey();
SetWorkspaceSize();
return ge::GRAPH_SUCCESS;
}
ge::graphStatus SplitVTilingAscendC(gert::TilingContext* context, int32_t maxCoreNum, uint32_t ubSize,
int32_t isSameLen)
{
SplitVTiling tiling(context);
return isSameLen == 1 ? tiling.DoSplitTiling(maxCoreNum, ubSize) :
tiling.DoSplitVTiling(maxCoreNum, ubSize);
}
int64_t SplitVTiling::CeilDiv(int64_t value, int64_t factor) const {
int64_t valueNum = 0;
if (factor == 0) {
OP_LOGE(context_->GetNodeName(), "SplitTiling CeilDiv divideNum is 0!");
return value;
}
if (value % factor == 0) {
valueNum = value / factor;
} else {
valueNum = value / factor + 1;
}
return valueNum;
}
int64_t SplitVTiling::FloorAlign(int64_t value, int64_t factor) const {
int64_t valueNum = 0;
if (factor == 0) {
OP_LOGE(context_->GetNodeName(), "SplitTiling FloorAlign divideNum is 0!");
return value;
}
if (value % factor == 0 || value < factor) {
valueNum = value;
} else {
valueNum = value / factor * factor;
}
return valueNum;
}
ge::graphStatus SplitVTiling::GetInputParamsSameLen() {
auto xInput = context_->GetInputShape(1);
OP_CHECK_NULL_WITH_CONTEXT(context_, xInput);
const gert::Shape& xInputShape = Ops::Base::EnsureNotScalar(xInput->GetStorageShape());
auto attrs = context_->GetAttrs();
OP_CHECK_NULL_WITH_CONTEXT(context_, attrs);
auto attr0 = attrs->GetInt(0);
OP_CHECK_NULL_WITH_CONTEXT(context_, attr0);
inputShape_ = xInputShape;
numSplit_ = *attr0;
int64_t inputXDimNum = static_cast<int64_t>(inputShape_.GetDimNum());
if (numSplit_ == 1) {
splitDim_ = 0;
} else {
OP_CHECK_IF(!Ops::Base::GetConstInt(context_, SPLIT_DIM_INDEX_SAME_LEN, splitDim_),
OP_LOGE(context_->GetNodeName(), "SplitTiling get split_dim failed"),
return ge::GRAPH_FAILED);
}
OP_CHECK_IF((splitDim_ < -inputXDimNum || splitDim_ >= inputXDimNum),
OP_LOGE(context_->GetNodeName(),
"SplitDim:%ld value out range, inputXDimNum:%ld", splitDim_, inputXDimNum), return ge::GRAPH_FAILED);
if (splitDim_ < 0) {
splitDim_ = splitDim_ + inputXDimNum;
}
OP_CHECK_IF(numSplit_ <= 0,
OP_LOGE(context_->GetNodeName(), "numSplit_:%ld must be greater than 0", numSplit_),
return ge::GRAPH_FAILED);
const int64_t dim = inputShape_.GetDim(splitDim_);
OP_CHECK_IF(dim % numSplit_ != 0,
OP_LOGE(context_->GetNodeName(),
"the split_dim dimension of x_shape must be divided by num_split!"
"split_dim is %ld, x_shape on split_dim is %ld",
splitDim_, dim),
return GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus SplitVTiling::InitParamsSameLen(int32_t maxCoreNum, uint32_t ubSize) {
coreNum_ = std::min(maxCoreNum, static_cast<int32_t>(MAX_CORE_COUNT));
ubSize_ = ubSize;
auto xInputDesc = context_->GetInputDesc(1);
OP_CHECK_NULL_WITH_CONTEXT(context_, xInputDesc);
ge::DataType xDtype = xInputDesc->GetDataType();
xDtypeSize_ = ge::GetSizeByDataType(xDtype);
OP_CHECK_IF(xDtypeSize_ <= 0,
OP_LOGE(context_->GetNodeName(),
"xDtypeSize_:%ld must be greater than 0", xDtypeSize_),
return ge::GRAPH_FAILED);
if (GetInputParamsSameLen() != ge::GRAPH_SUCCESS) {
return ge::GRAPH_FAILED;
}
return ge::GRAPH_SUCCESS;
}
void SplitVTiling::FuseInputShapeSameLen() {
int64_t mSize = 1;
int64_t nSize = inputShape_.GetDim(splitDim_) / numSplit_;
for (size_t i = 0; i < static_cast<size_t>(splitDim_); i++) {
mSize *= inputShape_.GetDim(i);
}
for (size_t i = splitDim_ + 1; i < inputShape_.GetDimNum(); i++) {
nSize *= inputShape_.GetDim(i);
}
fusedShape_[0] = mSize;
fusedShape_[1] = nSize;
mBlockFactorCount_ = fusedShape_[0];
nBlockFactorCount_ = fusedShape_[1];
OP_LOGI(context_->GetNodeName(), "After SameLen fused shape, fusedShape_:%ld %ld",
fusedShape_[0], fusedShape_[1]);
}
void SplitVTiling::CalcSameLenTilingInfo() {
OP_LOGD(context_->GetNodeName(), "Entering SplitTiling CalcSameLenTilingInfo.");
int64_t blockEleNum = BLOCK_SIZE / xDtypeSize_;
int64_t halfUbEleNum = static_cast<int64_t>(ubSize_ / BASE_2) / static_cast<int64_t>(xDtypeSize_);
int64_t quarterUbEleNum = ((static_cast<int64_t>(ubSize_)) / BASE_2 - BASE_256) / BASE_2 / xDtypeSize_;
int64_t ubOutSize = 0;
if (fusedShape_[1] > halfUbEleNum) {
mBlockFactor_ = 1;
mBlockCount_ = fusedShape_[0];
mBlockFactorTail_ = 1;
gUBFactor_ = 1;
gUBCount_ = numSplit_;
gUBFactorTail_ = 1;
nBlockFactor_ = FloorAlign(halfUbEleNum, blockEleNum);
nBlockCount_ = CeilDiv(fusedShape_[1], halfUbEleNum);
nBlockFactorTail_ = fusedShape_[1] % nBlockFactor_ == 0 ?
nBlockFactor_ : fusedShape_[1] % nBlockFactor_;
} else {
nBlockFactor_ = fusedShape_[1];
nBlockCount_ = 1;
nBlockFactorTail_ = fusedShape_[1];
if (tilingKey_ == SPLIT_KEY_PURE_MOVE) {
CalcSameLenCopyTilingInfo(halfUbEleNum, blockEleNum);
} else if (tilingKey_ == SPLIT_KEY_UB_SPLIT) {
CalcSameLenSplitTilingInfo(quarterUbEleNum, blockEleNum);
}
}
ubOutSize = mBlockCount_ * gUBCount_ * nBlockCount_;
realCoreNum_ = ubOutSize < coreNum_ ? ubOutSize : coreNum_;
blockDim_ = realCoreNum_;
blockFactor_ = ubOutSize / realCoreNum_;
blockFactorTail_ = ubOutSize % realCoreNum_;
}
void SplitVTiling::CalcSameLenCopyTilingInfo(int64_t halfUbEleNum, int64_t blockEleNum) {
int64_t spaceWithoutN = 0;
int64_t factorTmp = 0;
spaceWithoutN = halfUbEleNum / (CeilDiv(fusedShape_[1], blockEleNum) * blockEleNum);
factorTmp = std::floor(std::sqrt(spaceWithoutN));
mBlockFactor_ =
FloorAlign(factorTmp, blockEleNum) >= fusedShape_[0] ? fusedShape_[0] : FloorAlign(factorTmp, blockEleNum);
mBlockCount_ = CeilDiv(fusedShape_[0], mBlockFactor_);
mBlockFactorTail_ = fusedShape_[0] % mBlockFactor_ == 0 ? mBlockFactor_ : fusedShape_[0] % mBlockFactor_;
gUBFactor_ = factorTmp >= numSplit_ ? numSplit_ : factorTmp;
gUBCount_ = CeilDiv(numSplit_, gUBFactor_);
gUBFactorTail_ = numSplit_ % gUBFactor_ == 0 ? gUBFactor_ : numSplit_ % gUBFactor_;
}
void SplitVTiling::CalcSameLenSplitTilingInfo(int64_t halfUbEleNum, int64_t blockEleNum) {
int64_t factorTmp = 0;
int64_t nAlignNum = CeilDiv(fusedShape_[1], blockEleNum) * blockEleNum;
int64_t mnAlignNum = fusedShape_[0] * nAlignNum;
factorTmp = std::floor(std::sqrt(halfUbEleNum / fusedShape_[1]));
mBlockFactor_ =
FloorAlign(factorTmp, blockEleNum) >= fusedShape_[0] ? fusedShape_[0] : FloorAlign(factorTmp, blockEleNum);
mBlockCount_ = CeilDiv(fusedShape_[0], mBlockFactor_);
mBlockFactorTail_ = fusedShape_[0] % mBlockFactor_ == 0 ? mBlockFactor_ : fusedShape_[0] % mBlockFactor_;
gUBFactor_ = factorTmp - 1;
mnAlignNum = mBlockFactor_ * nAlignNum;
gUBFactor_ =
AdjustUbFactor(gUBFactor_, mnAlignNum) > numSplit_ ? numSplit_ : AdjustUbFactor(gUBFactor_, mnAlignNum);
gUBCount_ = CeilDiv(numSplit_, gUBFactor_);
gUBFactorTail_ = numSplit_ % gUBFactor_ == 0 ? gUBFactor_ : numSplit_ % gUBFactor_;
}
int64_t SplitVTiling::AdjustUbFactor(int64_t factor, int64_t alignFactor) const {
int64_t adjustFactor = factor;
if (xDtypeSize_ == BASE_2 && factor * alignFactor > INT16_MAX_VALUE) {
adjustFactor = INT16_MAX_VALUE / alignFactor;
} else if (xDtypeSize_ == 1 && factor * alignFactor > INT16_MAX_VALUE * BASE_2) {
adjustFactor = INT16_MAX_VALUE * BASE_2 / alignFactor;
}
return adjustFactor;
}
ge::graphStatus SplitVTiling::DoSplitTiling(int32_t maxCoreNum, uint32_t ubSize) {
OP_LOGD("SplitTilingForAscendC", "DoSplitTiling start");
OP_CHECK_IF((InitParamsSameLen(maxCoreNum, ubSize) != ge::GRAPH_SUCCESS),
OP_LOGE(context_->GetNodeName(), "SplitTiling InitParams Failed."),
return ge::GRAPH_FAILED);
FuseInputShapeSameLen();
int64_t nSwitchSize = PURE_MOVE_BASE_LEN / xDtypeSize_;
if (fusedShape_[0] == 0 || fusedShape_[1] == 0) {
tilingKey_ = SPLIT_KEY_PURE_MOVE;
realCoreNum_ = 0;
blockDim_ = 1;
} else {
tilingKey_ = (numSplit_ == 1 || fusedShape_[1] >= nSwitchSize) ?
SPLIT_KEY_PURE_MOVE : SPLIT_KEY_UB_SPLIT;
CalcSameLenTilingInfo();
}
FillTilingData();
SetBlockDimAndTilingKey();
SetWorkspaceSize();
OP_LOGD("SplitTilingForAscendC", "DoSplitTiling end");
return ge::GRAPH_SUCCESS;
}
graphStatus Tiling4SplitV(TilingContext* context) {
OP_LOGD(context->GetNodeName(), "begin to do Tiling4SplitV");
auto compile_info = static_cast<const SplitVCompileInfo*>(context->GetCompileInfo());
OP_LOGD(context->GetNodeName(), "AscendC splitV tiling start");
int32_t maxCoreNum = static_cast<int32_t>(compile_info->core_num);
uint32_t ubSizePlatform = compile_info->ubSizePlatform;
OP_CHECK_IF(ubSizePlatform <= 0, OP_LOGE(context->GetNodeName(),
"get ubSize <= 0 error"),
return GRAPH_FAILED);
int32_t isSameLen = 0;
OP_CHECK_IF(SplitVTilingAscendC(context, maxCoreNum, ubSizePlatform, isSameLen) != ge::GRAPH_SUCCESS,
OP_LOGE(context->GetNodeName(),
"AscendC splitV tiling function call failed"),
return ge::GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
graphStatus TilingPrepare4SplitV(TilingParseContext* context) {
auto compile_info = context->GetCompiledInfo<SplitVCompileInfo>();
OP_CHECK_NULL_WITH_CONTEXT(context, compile_info);
OP_LOGD(context->GetNodeName(), "AscendC splitV tiling prepare");
auto platformInfo = context->GetPlatformInfo();
OP_CHECK_NULL_WITH_CONTEXT(context, platformInfo);
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);
compile_info->core_num = ascendcPlatform.GetCoreNumAiv();
OP_CHECK_IF((compile_info->core_num <= 0),
OP_LOGE(context->GetNodeName(),
"The core num is invaild"),
return ge::GRAPH_FAILED);
uint64_t ubSize = 0;
ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
compile_info->ubSizePlatform = static_cast<uint32_t>(ubSize);
OP_CHECK_IF((compile_info->ubSizePlatform <= 0),
OP_LOGE(context->GetNodeName(), "The ubSize is invalid."),
return ge::GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
IMPL_OP_OPTILING(SplitV).Tiling(Tiling4SplitV).TilingParse<SplitVCompileInfo>(TilingPrepare4SplitV);
}