* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file reduce_tiling.cpp
* \brief atvoss reduce template tiling
*/
#include "op_common/atvoss/reduce/reduce_tiling.h"
#include "op_common/op_host/util/math_util.h"
#include "reduce_tiling_batch_invariant.h"
namespace Ops {
namespace Base {
using namespace ReduceOpTmpl;
constexpr static int32_t MAX_INNER_A = 512;
constexpr static int32_t A_STEP_LEN = 4;
* Ensure that the returned shape is non-scalar.
* When the dim num of shape is 0, this shape is considered to express a scalar.
* This function returns the original shape when it receives a non-scalar shape,
* and returns the vector shape that returns a {1} when it receives a scalar shape
* @param in_shape input shape
* @return non-scalar shape
*/
inline const gert::Shape& EnsureNotScalar(const gert::Shape& inShape)
{
if (inShape.IsScalar()) {
static const gert::Shape scalarShape = {1};
return scalarShape;
}
return inShape;
}
namespace ReduceOpTmpl {
static bool CheckAllReduce(gert::TilingContext* context, int32_t outIdx)
{
auto out = context->GetOutputShape(outIdx);
OP_CHECK_IF(out == nullptr, OP_LOGE(context, "out is nullptr"), return false);
gert::Shape outShape = EnsureNotScalar(out->GetStorageShape());
return (outShape.GetShapeSize() == 1L);
}
static bool CheckIsContiguous(ReduceOpInputParam& opInput)
{
bool isContiguous{false};
for (size_t i = 0; i < opInput.shape.size(); i++) {
if (i != opInput.shape.size() - 1) {
isContiguous = (opInput.dimStrides[i] == opInput.dimStrides[i + 1] * opInput.shape[i + 1]);
} else {
isContiguous = (opInput.dimStrides[i] == 1UL);
}
}
return isContiguous;
}
ge::graphStatus GetInputShape(gert::TilingContext* context, int32_t idx, std::vector<int64_t>& shape)
{
auto xInput = context->GetInputShape(idx);
OP_CHECK_NULL_WITH_CONTEXT(context, xInput);
gert::Shape xInputShape;
if (context->InputIsView(idx)) {
xInputShape = EnsureNotScalar(xInput->GetShape());
} else {
xInputShape = EnsureNotScalar(xInput->GetStorageShape());
}
size_t shapeSize = xInputShape.GetDimNum();
OP_CHECK_IF((shapeSize >= static_cast<size_t>(MAX_DIM)),
OP_LOGE_FOR_INVALID_SHAPEDIM_WITH_REASON(context->GetNodeName(), "x", std::to_string(shapeSize).c_str(),
"shape dim size is over max dim, cannot support"),
return ge::GRAPH_FAILED);
shape.resize(shapeSize);
for (size_t i = 0; i < shapeSize; i++) {
shape[i] = xInputShape.GetDim(i);
if (shape[i] < 0) {
OP_LOGE_FOR_INVALID_SHAPEDIM_WITH_REASON(context->GetNodeName(), "x", std::to_string(shape[i]).c_str(),
"shape dim cannot support dynamic");
return ge::GRAPH_FAILED;
}
}
return ge::GRAPH_SUCCESS;
}
ge::graphStatus GetInputStride(const gert::TilingContext* context, int32_t idx, std::vector<int64_t>& dimStrides)
{
auto xStride = context->GetInputStride(idx);
size_t shapeSize = 0UL;
if (xStride != nullptr) {
shapeSize = xStride->GetDimNum();
OP_CHECK_IF(
(shapeSize >= static_cast<size_t>(MAX_DIM)),
OP_LOGE_FOR_INVALID_SHAPEDIM_WITH_REASON(context->GetNodeName(), "x", std::to_string(shapeSize).c_str(),
"slice dim size is over max slice dim, cannot support"),
return ge::GRAPH_FAILED);
}
if (shapeSize == 0UL || !context->InputIsView(idx)) {
auto xInput = context->GetInputShape(idx);
OP_CHECK_NULL_WITH_CONTEXT(context, xInput);
gert::Shape xInputShape = EnsureNotScalar(xInput->GetStorageShape());
dimStrides.resize(xInputShape.GetDimNum());
int64_t stride = 1L;
for (int32_t i = xInputShape.GetDimNum() - 1; i >= 0; i--) {
dimStrides[i] = stride;
if (dimStrides[i] < 0) {
OP_LOGE_FOR_INVALID_SHAPEDIM_WITH_REASON(context->GetNodeName(), "x",
std::to_string(dimStrides[i]).c_str(),
"dimStrides dim size cannot support dynamic");
return ge::GRAPH_FAILED;
}
stride = stride * xInputShape.GetDim(i);
}
} else {
dimStrides.resize(shapeSize);
for (size_t i = 0; i < shapeSize; i++) {
dimStrides[i] = xStride->GetStride(i);
if (dimStrides[i] < 0) {
OP_LOGE_FOR_INVALID_SHAPEDIM_WITH_REASON(context->GetNodeName(), "x",
std::to_string(dimStrides[i]).c_str(),
"dimStrides dim size cannot support dynamic");
return ge::GRAPH_FAILED;
}
}
}
return ge::GRAPH_SUCCESS;
}
ge::graphStatus GetInputDtype(gert::TilingContext* context, int32_t idx, ge::DataType& dtype)
{
auto inputDesc = context->GetInputDesc(idx);
OP_CHECK_NULL_WITH_CONTEXT(context, inputDesc);
dtype = inputDesc->GetDataType();
return ge::GRAPH_SUCCESS;
}
ge::graphStatus GetInputParam(gert::TilingContext* context, ReduceOpInputParam& opInput, int32_t inputIdx)
{
OP_CHECK_IF((GetInputDtype(context, inputIdx, opInput.inputDtype) == ge::GRAPH_FAILED),
OP_LOGE(context, "ReduceOpTmpl get input dtype failed"), return ge::GRAPH_FAILED);
OP_CHECK_IF((GetInputShape(context, inputIdx, opInput.shape) == ge::GRAPH_FAILED),
OP_LOGE(context, "ReduceOpTmpl get input shape failed"), return ge::GRAPH_FAILED);
OP_CHECK_IF(
(GetInputStride(context, inputIdx, opInput.dimStrides) == ge::GRAPH_FAILED),
OP_LOGE(context, "ReduceOpTmpl get input stride failed"), return ge::GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus GetInputParam(gert::TilingContext* context, ReduceOpInputParam& opInput, int32_t inputIdx,
int32_t axesIdx, int32_t outIdx)
{
ge::DataType axesDtype;
ge::graphStatus status = GetInputDtype(context, axesIdx, axesDtype);
OP_CHECK_IF((status == ge::GRAPH_FAILED), OP_LOGE(context, "ReduceOpTmpl get axes dtype failed"),
return ge::GRAPH_FAILED);
OP_CHECK_IF((GetInputDtype(context, inputIdx, opInput.inputDtype) == ge::GRAPH_FAILED),
OP_LOGE(context, "ReduceOpTmpl get input dtype failed"), return ge::GRAPH_FAILED);
OP_CHECK_IF((GetInputShape(context, inputIdx, opInput.shape) == ge::GRAPH_FAILED),
OP_LOGE(context, "ReduceOpTmpl get input shape failed"), return ge::GRAPH_FAILED);
OP_CHECK_IF((GetInputStride(context, inputIdx, opInput.dimStrides) == ge::GRAPH_FAILED),
OP_LOGE(context, "ReduceOpTmpl get input stride failed"), return ge::GRAPH_FAILED);
if (CheckAllReduce(context, outIdx) && CheckIsContiguous(opInput)) {
opInput.axes.resize(opInput.shape.size());
for (size_t i = 0; i < opInput.shape.size(); i++) {
opInput.axes[i] = i;
}
} else {
if (axesDtype == ge::DT_INT32) {
status = GetConstInputData<int32_t>(context, axesIdx, opInput.axes);
} else if (axesDtype == ge::DT_INT64) {
status = GetConstInputData<int64_t>(context, axesIdx, opInput.axes);
} else {
OP_LOGE_FOR_INVALID_DTYPE_WITH_REASON(context->GetNodeName(), "axes",
Ops::Base::ToString(axesDtype).c_str(),
"only support const input dtype in [int32, int64]");
status = ge::GRAPH_FAILED;
}
OP_CHECK_IF((status == ge::GRAPH_FAILED), OP_LOGE(context, "ReduceOpTmpl get axes const input failed"),
return ge::GRAPH_FAILED);
}
return ge::GRAPH_SUCCESS;
}
}
static void MakeWrapDim(const std::vector<int64_t>& shape, std::vector<int64_t>& axes)
{
size_t shapeSize = shape.size();
for (size_t i = 0; i < axes.size(); i++) {
if (axes[i] < 0) {
axes[i] += shapeSize;
}
}
std::sort(axes.begin(), axes.end());
}
template <class Pattern>
bool ReduceOpTiling::IsAxisA(int32_t idx)
{
if (Pattern::FirstA) {
return idx % CONST2 == 0;
} else {
return idx % CONST2 == 1;
}
}
void ReduceOpTiling::PrintTilingData()
{
OP_LOGI(context_,
"TilingData: factorACntPerCore:%lu, factorATotalCnt:%lu, ubFactorA:%lu, factorRCntPerCore:%lu, "
"factorRTotalCnt:%lu, ubFactorR:%lu, groupR:%lu, outSize:%lu, basicBlock:%lu, resultBlock:%lu, "
"coreNum:%u, useNddma:%u, meanVar:%lf",
tilingData_->factorACntPerCore, tilingData_->factorATotalCnt, tilingData_->ubFactorA,
tilingData_->factorRCntPerCore, tilingData_->factorRTotalCnt, tilingData_->ubFactorR, tilingData_->groupR,
tilingData_->outSize, tilingData_->basicBlock, tilingData_->resultBlock, context_->GetBlockDim(),
tilingData_->useNddma, tilingData_->meanVar);
OP_LOGI(
context_,
"TilingData: shape is:%s, stride is:%s, dstStride is:%s, sliceNum is:%s, sliceShape is:%s, sliceStride is:%s",
ReduceOpTmpl::VectorToString(tilingData_->shape, MAX_DIM).c_str(),
ReduceOpTmpl::VectorToString(tilingData_->stride, MAX_DIM).c_str(),
ReduceOpTmpl::VectorToString(tilingData_->dstStride, MAX_DIM).c_str(),
ReduceOpTmpl::VectorToString(tilingData_->sliceNum, MAX_DIM).c_str(),
ReduceOpTmpl::VectorToString(tilingData_->sliceShape, MAX_DIM).c_str(),
ReduceOpTmpl::VectorToString(tilingData_->sliceStride, MAX_DIM).c_str());
}
uint64_t ReduceOpTiling::Ratio()
{
if (opDag_.memLevel == MemLevel::LEVEL_2) {
return CeilDiv(opDag_.maxInputBytes, static_cast<uint64_t>(ge::GetSizeByDataType(opInput_.inputDtype)));
}
return 1UL;
}
ge::graphStatus ReduceOpTiling::AxesCheck(const std::vector<int64_t>& shape, const std::vector<int64_t>& axes)
{
int64_t shapeSize = static_cast<int64_t>(shape.size());
int64_t axesSize = static_cast<int64_t>(axes.size());
OP_CHECK_IF(
(axesSize > shapeSize),
OP_LOGE_FOR_INVALID_SHAPESIZE_WITH_REASON(context_->GetNodeName(), "axes", std::to_string(axesSize).c_str(),
"axes size over x shape size is illegal"),
return ge::GRAPH_FAILED);
for (int64_t i = 0; i < axesSize; i++) {
if (axes[i] >= shapeSize || axes[i] < 0) {
std::string reasonMsg = "illegal axis: " + std::to_string(i) + " dim: " + std::to_string(axes[i]) +
" out of shape range:[0, " + std::to_string(shapeSize) + ")";
OP_LOGE_FOR_INVALID_SHAPEDIM_WITH_REASON(context_->GetNodeName(), "axes", std::to_string(axes[i]).c_str(),
reasonMsg.c_str());
return ge::GRAPH_FAILED;
}
}
return ge::GRAPH_SUCCESS;
}
ge::graphStatus ReduceOpTiling::ParamCheck(ReduceOpInputParam& opInput)
{
int32_t dtypeSize = ge::GetSizeByDataType(opInput.inputDtype);
OP_CHECK_IF(dtypeSize <= 0,
OP_LOGE_FOR_INVALID_DTYPE_WITH_REASON(context_->GetNodeName(), "x", std::to_string(dtypeSize).c_str(),
"dtype of input x is illegal"),
return ge::GRAPH_FAILED);
OP_LOGD(context_, "view shape is:%s, strides:%s, axes:%s", ReduceOpTmpl::VectorToString(opInput.shape).c_str(),
ReduceOpTmpl::VectorToString(opInput.dimStrides).c_str(),
ReduceOpTmpl::VectorToString(opInput.axes).c_str());
MakeWrapDim(opInput.shape, opInput.axes);
OP_CHECK_IF((AxesCheck(opInput.shape, opInput.axes) == ge::GRAPH_FAILED), OP_LOGE(context_, "illegal axes"),
return ge::GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus ReduceOpTiling::PreProcessOptionalParam()
{
if (tilingData_ == nullptr) {
tilingData_ = context_->GetTilingData<ReduceOpTilingData>();
OP_CHECK_NULL_WITH_CONTEXT(context_, tilingData_);
}
OP_CHECK_IF((memset_s(tilingData_, sizeof(ReduceOpTilingData), 0, sizeof(ReduceOpTilingData)) != EOK),
OP_LOGE(context_, "memset tilingdata failed"), return ge::GRAPH_FAILED);
if (opInput_.dimStrides.empty()) {
size_t shapeSize = opInput_.shape.size();
opInput_.dimStrides.resize(shapeSize);
int64_t stride = 1L;
for (int32_t i = shapeSize - 1; i >= 0; i--) {
opInput_.dimStrides[i] = stride;
stride = stride * opInput_.shape[i];
}
}
if (compileInfo_ == nullptr) {
compileInfoPtr_ = std::make_unique<ReduceOpCompileInfo>();
OP_CHECK_NULL_WITH_CONTEXT(context_, compileInfoPtr_);
compileInfo_ = compileInfoPtr_.get();
}
compileInfo_->vectorCoreNum = GetAivCoreNum(context_);
OP_CHECK_IF(compileInfo_->vectorCoreNum == 0,
OP_LOGE_FOR_INVALID_VALUE_WITH_REASON(context_->GetNodeName(), "vectorCoreNum",
std::to_string(compileInfo_->vectorCoreNum).c_str(),
"The value of vector core num must be greater than 0"),
return ge::GRAPH_FAILED);
compileInfo_->ubSize = GetUbSize(context_);
OP_CHECK_IF(compileInfo_->ubSize == 0,
OP_LOGE_FOR_INVALID_VALUE_WITH_REASON(context_->GetNodeName(), "ubSize",
std::to_string(compileInfo_->ubSize).c_str(),
"The value of ub size must be greater than 0"),
return ge::GRAPH_FAILED);
compileInfo_->cacheLineSize = GetCacheLineSize(context_);
OP_CHECK_IF(compileInfo_->cacheLineSize == 0,
OP_LOGE_FOR_INVALID_VALUE_WITH_REASON(context_->GetNodeName(), "cacheLineSize",
std::to_string(compileInfo_->cacheLineSize).c_str(),
"The value of cacheline size must be greater than 0"),
return ge::GRAPH_FAILED);
compileInfo_->ubBlockSize = GetUbBlockSize(context_);
OP_CHECK_IF(compileInfo_->ubBlockSize == 0,
OP_LOGE_FOR_INVALID_VALUE_WITH_REASON(context_->GetNodeName(), "ubBlockSize",
std::to_string(compileInfo_->ubBlockSize).c_str(),
"The value of ub block size must be greater than 0"),
return ge::GRAPH_FAILED);
compileInfo_->vRegSize = GetVRegSize(context_);
OP_CHECK_IF(compileInfo_->vRegSize == 0,
OP_LOGE_FOR_INVALID_VALUE_WITH_REASON(context_->GetNodeName(), "vRegSize",
std::to_string(compileInfo_->vRegSize).c_str(),
"The value of vReg size must be greater than 0"),
return ge::GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
void ReduceOpTiling::EliminateOne(uint64_t* viewShape, int64_t* viewStride, int32_t& shapeSize)
{
int32_t dstIdx = 1;
for (size_t i = 0; i < opInput_.axes.size(); i++) {
opInput_.axes[i] = opInput_.axes[i] + 1;
}
int32_t eraseNum = 0;
for (size_t i = 0; i < opInput_.shape.size(); i++) {
auto iter = std::find(opInput_.axes.begin(), opInput_.axes.end(), i + 1);
bool isContiguous{false};
if (i != opInput_.shape.size() - 1) {
isContiguous = (opInput_.dimStrides[i] == opInput_.dimStrides[i + 1] * opInput_.shape[i + 1]);
} else {
isContiguous = (opInput_.dimStrides[i] == 1UL);
}
if (opInput_.shape[i] != 1L || !isContiguous) {
viewShape[dstIdx] = opInput_.shape[i];
viewStride[dstIdx] = opInput_.dimStrides[i];
if (iter != opInput_.axes.end()) {
*iter = *iter - eraseNum;
}
dstIdx++;
} else {
eraseNum++;
if (iter != opInput_.axes.end()) {
opInput_.axes.erase(iter);
}
}
}
shapeSize = dstIdx;
viewStride[0] = opInput_.dimStrides[0] * opInput_.shape[0];
OP_LOGI(context_, "after EliminateOne, viewShape is:%s, viewStride is:%s, axes:%s",
ReduceOpTmpl::VectorToString(viewShape, shapeSize).c_str(),
ReduceOpTmpl::VectorToString(viewStride, shapeSize).c_str(),
ReduceOpTmpl::VectorToString(opInput_.axes).c_str());
}
void ReduceOpTiling::PadDimForNonContiguous(uint64_t* viewShape, int64_t* viewStride, int32_t& shapeSize)
{
int32_t nonContiguousNum = 0;
for (int32_t i = 0; i < shapeSize - 1; i++) {
bool isContiguous = (viewStride[i] == viewStride[i + 1] * static_cast<int64_t>(viewShape[i + 1]));
if (!isContiguous) {
nonContiguousNum++;
}
}
int32_t j = nonContiguousNum + shapeSize - 1;
for (int32_t i = shapeSize - 1; i >= 0; i--) {
bool isContiguous{false};
if (i != shapeSize - 1) {
isContiguous = (viewStride[i] == viewStride[i + 1] * static_cast<int64_t>(viewShape[i + 1]));
} else {
isContiguous = (viewStride[i] == 1UL);
}
auto iter = std::find(opInput_.axes.begin(), opInput_.axes.end(), i);
bool isReduce = (iter != opInput_.axes.end());
if (!isContiguous && i != shapeSize - 1) {
viewShape[j] = 1UL;
viewStride[j] = viewStride[j + 1] * static_cast<int64_t>(viewShape[j + 1]);
if (!isReduce) {
opInput_.axes.emplace_back(j);
}
j--;
}
viewShape[j] = viewShape[i];
viewStride[j] = viewStride[i];
if (isReduce) {
*iter = j;
}
j--;
}
std::sort(opInput_.axes.begin(), opInput_.axes.end());
shapeSize = nonContiguousNum + shapeSize;
OP_LOGI(context_, "after PadDimForNonContiguous, viewShape is:%s, viewStride is:%s, axes:%s",
ReduceOpTmpl::VectorToString(viewShape, shapeSize).c_str(),
ReduceOpTmpl::VectorToString(viewStride, shapeSize).c_str(),
ReduceOpTmpl::VectorToString(opInput_.axes).c_str());
}
void ReduceOpTiling::MergeAxis(uint64_t* viewShape, int64_t* viewStride, int32_t& shapeSize)
{
int32_t tmpSize = 0;
for (int32_t i = 0; i < shapeSize;) {
sliceNum_[tmpSize] = 1UL;
sliceShape_[tmpSize] = viewShape[i];
sliceStride_[tmpSize] = viewStride[i];
viewStride_[tmpSize] = viewStride[i];
auto iter0 = std::find(opInput_.axes.begin(), opInput_.axes.end(), i);
bool isRAxis0 = iter0 != opInput_.axes.end();
uint64_t s = viewShape[i];
int32_t j = i + 1;
for (; j < shapeSize; j++) {
int64_t curContiguousStride = (j == shapeSize - 1 ? 1L : viewStride[j + 1] * viewShape[j + 1]);
bool isContiguous1 = (viewStride[j] == curContiguousStride);
auto iter1 = std::find(opInput_.axes.begin(), opInput_.axes.end(), j);
bool isRAxis1 = iter1 != opInput_.axes.end();
if (isRAxis0 != isRAxis1) {
break;
}
if (!isContiguous1) {
sliceNum_[tmpSize] = s;
}
s *= viewShape[j];
sliceShape_[tmpSize] = isContiguous1 ? sliceShape_[tmpSize] * viewShape[j] :
viewShape[j];
sliceStride_[tmpSize] = isContiguous1 ? curContiguousStride : sliceStride_[tmpSize];
viewStride_[tmpSize] = viewStride[j];
if (isRAxis1) {
opInput_.axes.erase(iter1);
}
}
i = j;
viewShape[tmpSize] = s;
tmpSize++;
if (isRAxis0) {
*iter0 = tmpSize - 1;
}
}
for (int32_t i = tmpSize; i < shapeSize; i++) {
viewShape[i] = 0;
}
shapeSize = tmpSize;
OP_LOGI(context_,
"after merge axis, viewShape:%s, viewStride:%s, sliceNum:%s, sliceShape:%s, sliceStride:%s, axes:%s",
ReduceOpTmpl::VectorToString(viewShape, shapeSize).c_str(),
ReduceOpTmpl::VectorToString(viewStride_, shapeSize).c_str(),
ReduceOpTmpl::VectorToString(sliceNum_, shapeSize).c_str(),
ReduceOpTmpl::VectorToString(sliceShape_, shapeSize).c_str(),
ReduceOpTmpl::VectorToString(sliceStride_, shapeSize).c_str(),
ReduceOpTmpl::VectorToString(opInput_.axes).c_str());
}
template <class Pattern>
bool ReduceOpTiling::IsEmtpyTensor(const uint64_t* shape)
{
for (int32_t i = 0; i < Pattern::Dim; i++) {
if (shape[i] == 0) {
return true;
}
}
return false;
}
template <class Pattern>
uint64_t ReduceOpTiling::CaculateReduceSize(const uint64_t* shape)
{
uint64_t dSize = ge::GetSizeByDataType(opInput_.inputDtype);
uint64_t ubBlockSize = compileInfo_->ubBlockSize / dSize;
int32_t dim = Pattern::TailA ? Pattern::Dim - AXES_STEP : Pattern::Dim - CONST1;
uint64_t r = 1;
for (int32_t i = dim; i > -1; i = i - AXES_STEP) {
if (i == Pattern::Dim - 1) {
r = r * CeilAlign(shape[i], ubBlockSize);
} else {
r = r * shape[i];
}
}
return r;
}
template <class Pattern>
void ReduceOpTiling::PadDimOne(uint64_t* shape)
{
int32_t padNum = Pattern::TailA ? CONST9 - Pattern::Dim : CONST8 - Pattern::Dim;
int32_t maxDim = Pattern::TailA ? CONST9 : CONST8;
if (padNum == 0) {
return;
}
for (int32_t i = 0; i < Pattern::Dim; ++i) {
shape[maxDim - 1 - i] = shape[Pattern::Dim - 1 - i];
viewStride_[maxDim - 1 - i] = viewStride_[Pattern::Dim - 1 - i];
sliceNum_[maxDim - 1 - i] = sliceNum_[Pattern::Dim - 1 - i];
sliceShape_[maxDim - 1 - i] = sliceShape_[Pattern::Dim - 1 - i];
sliceStride_[maxDim - 1 - i] = sliceStride_[Pattern::Dim - 1 - i];
}
for (int32_t i = 0; i < padNum; ++i) {
shape[i] = 1UL;
sliceNum_[i] = 1UL;
sliceShape_[i] = 1UL;
viewStride_[i] = viewStride_[padNum] * shape[padNum];
sliceStride_[i] = sliceStride_[padNum] * sliceNum_[padNum];
}
}
void ReduceOpTiling::TransformShape(uint64_t* shape, int64_t* viewStride, int32_t& shapeSize)
{
shape[0] = 1UL;
EliminateOne(shape, viewStride, shapeSize);
PadDimForNonContiguous(shape, viewStride, shapeSize);
MergeAxis(shape, viewStride, shapeSize);
}
void ReduceOpTiling::DoReduceTiling(ReduceTilingKey& key)
{
uint64_t newShape[MAX_DIM] = {0};
int64_t newStride[MAX_DIM] = {0};
int32_t newShapeSize = 0;
tilingKey_.batchInvariant = key.batchInvariant;
TransformShape(newShape, newStride, newShapeSize);
DoTilingMatchPattern(newShape, newShapeSize);
CalcUserWorkSpace();
GetTilingKey(key);
PrintTilingData();
}
* cacheLine切分找到硬件cacheLine大小的位置
* 例如: float32, shape:(2, 35, 7), cacheLine:256B
* cacheLine切分找到axis:1, step:10, outer:4
*/
template <class Pattern>
void ReduceOpTiling::ComputeCacheLineBlock(const uint64_t* shape)
{
uint64_t cacheSize = compileInfo_->cacheLineSize / ge::GetSizeByDataType(opInput_.inputDtype);
uint64_t dSize = (Pattern::TailA ? opDag_.minInputBytes : ge::GetSizeByDataType(opInput_.inputDtype));
uint64_t ubBlockSize = compileInfo_->ubBlockSize / dSize;
uint64_t cacheLineShape = 1UL;
uint64_t cacheLineStep = 1UL;
uint64_t cacheLineOuter = 1UL;
uint64_t aInCacheLine = 1UL;
uint64_t rInCacheLine = 1UL;
for (int32_t i = Pattern::Dim - 1; i > -1; --i) {
cacheLineShape *= shape[i];
if (cacheLineShape > cacheSize) {
cacheLineShape /= shape[i];
cacheLineStep = CeilDiv(cacheSize, cacheLineShape);
cacheLineStep = cacheLineStep > sliceShape_[i] ? FloorAlign(cacheLineStep, sliceShape_[i]) : cacheLineStep;
if (cacheLineStep >= sliceShape_[i]) {
cacheLineOuter = CeilDiv(shape[i], cacheLineStep);
} else {
cacheLineOuter = sliceNum_[i] * CeilDiv(sliceShape_[i], cacheLineStep);
}
cBlock_.axis = i;
break;
} else {
cacheLineStep = shape[i];
cBlock_.axis = i;
}
}
for (int32_t i = Pattern::Dim - 1; i > cBlock_.axis; --i) {
if (i == Pattern::Dim - 1) {
if (IsAxisA<Pattern>(i)) {
aInCacheLine = aInCacheLine * CeilAlign(shape[i], ubBlockSize);
} else {
rInCacheLine = rInCacheLine * CeilAlign(shape[i], ubBlockSize);
}
} else {
if (IsAxisA<Pattern>(i)) {
aInCacheLine = aInCacheLine * shape[i];
} else {
rInCacheLine = rInCacheLine * shape[i];
}
}
}
cBlock_.cacheLineStep = cacheLineStep;
cBlock_.cacheLineOuter = cacheLineOuter;
cBlock_.aSize = aInCacheLine;
cBlock_.rSize = rInCacheLine;
OP_LOGD(context_, "cacheLine Block axis:%d, cacheLineStep:%lu, cacheLineOuter:%lu, aSize:%lu, rSize:%lu",
cBlock_.axis, cBlock_.cacheLineStep, cBlock_.cacheLineOuter, cBlock_.aSize, cBlock_.rSize);
}
template <class Pattern>
void ReduceOpTiling::InitUnit(const uint64_t* shape)
{
int32_t axisInCacheLine = cBlock_.axis;
for (int32_t i = Pattern::FirstA ? 0 : 1; i < axisInCacheLine; i += AXES_STEP) {
unitA_.outer *= shape[i];
}
for (int32_t i = Pattern::FirstA ? 1 : 0; i < axisInCacheLine; i += AXES_STEP) {
unitR_.outer *= shape[i];
}
bool basicSplitA = IsAxisA<Pattern>(axisInCacheLine);
if (basicSplitA) {
unitA_.outer *= cBlock_.cacheLineOuter;
} else {
unitR_.outer *= cBlock_.cacheLineOuter;
}
}
template <class Pattern>
void ReduceOpTiling::ComputeCacheLineBlockAndUnit(const uint64_t* shape)
{
ComputeCacheLineBlock<Pattern>(shape);
InitUnit<Pattern>(shape);
}
* 计算UB内A轴的切分大小,最大512B或者按A轴分核低于85%
* 例如: float32, shape:(12800, 8), pattern:AR, cacheLine:256B, coreNum:64
* cacheLine切分后(1600, cacheLine(8, 8))
* 对1600做A轴切分,找到uintA inner:16(cacheline中A轴为8, 与16相乘后达到512B上限), outer:100
*/
template <class Pattern>
void ReduceOpTiling::ComputeUnitA(const uint64_t* shape)
{
int32_t axisInCacheLine = cBlock_.axis;
bool basicSplitA = IsAxisA<Pattern>(axisInCacheLine);
uint64_t outerA = (basicSplitA ? unitA_.outer / cBlock_.cacheLineOuter * shape[axisInCacheLine] : unitA_.outer);
uint64_t innerA = unitA_.inner;
uint64_t maxCacheA = MAX_INNER_A / opDag_.maxInputBytes;
uint64_t maxInnerA = Pattern::ID == PATTERN_A ? basicBlock_ * Ratio() / opDag_.maxInputBytes : maxCacheA;
uint64_t stepLen = Pattern::ID == PATTERN_A ? A_STEP_LEN : 1;
uint64_t bBlockNum = basicBlock_ * Ratio() / opDag_.maxInputBytes;
uint64_t ubBlockSize = compileInfo_->ubBlockSize / opDag_.minInputBytes;
uint64_t step = 1;
int32_t iA;
for (iA = basicSplitA ? axisInCacheLine : axisInCacheLine - 1; iA > -1; iA -= AXES_STEP) {
uint64_t axisLen = shape[iA];
step = (iA == axisInCacheLine ? cBlock_.cacheLineStep : 1UL);
uint64_t maxStep = (iA == axisInCacheLine ? cBlock_.cacheLineStep : 1UL);
bool splitHere = false;
double maxRate = 0.0f;
for (; step <= axisLen; step = step + stepLen) {
uint64_t s = step > sliceShape_[iA] ? FloorAlign(step, sliceShape_[iA]) : step;
if (iA == Pattern::Dim - 1 && s <= sliceShape_[iA]) {
s = FloorAlign(s, ubBlockSize);
}
uint64_t tmpInnerA = innerA * s;
uint64_t tmpOuterA = (s >= sliceShape_[iA] ?
outerA / axisLen * CeilDiv(axisLen, s) :
outerA / axisLen * sliceNum_[iA] * CeilDiv(sliceShape_[iA], s));
uint64_t aSize = tmpInnerA * cBlock_.aSize;
if (aSize <= maxInnerA && aSize * cBlock_.rSize <= bBlockNum) {
uint64_t tempCoreNum = (tmpOuterA * unitR_.outer) /
CeilDiv(tmpOuterA * unitR_.outer, compileInfo_->vectorCoreNum);
tempCoreNum = tempCoreNum > tmpOuterA ? FloorAlign(tempCoreNum, tmpOuterA) : tempCoreNum;
double rate = static_cast<double>(tempCoreNum) / static_cast<double>(compileInfo_->vectorCoreNum);
if (Pattern::TailA) {
maxStep = (rate > THRES_HOLD) ? s : maxStep;
} else {
maxStep = (rate > THRES_HOLD && rate > maxRate) ? s : maxStep;
}
maxRate = rate > maxRate ? rate : maxRate;
} else {
splitHere = true;
break;
}
}
if (splitHere || maxStep != axisLen || iA - AXES_STEP < 0) {
step = maxStep == 0 ? 1UL : maxStep;
innerA = innerA * step;
outerA = (step >= sliceShape_[iA] ? outerA / axisLen * CeilDiv(axisLen, step) :
outerA / axisLen * sliceNum_[iA] * CeilDiv(sliceShape_[iA], step));
break;
}
innerA *= (iA == Pattern::Dim - 1 ? CeilAlign(axisLen, ubBlockSize) : axisLen);
outerA /= axisLen;
}
AssembleUnit(unitA_, iA, innerA, outerA, step);
}
* 根据BasicBlock大小和UB内A轴的切分大小,计算UB内R轴的切分大小
* 用BasicBlock / UB内A的切分大小,找到满BasicBlock的R轴切分位置
*/
template <class Pattern>
void ReduceOpTiling::ComputeUnitR(const uint64_t* shape)
{
int32_t axisInCacheLine = cBlock_.axis;
bool basicSplitA = IsAxisA<Pattern>(axisInCacheLine);
uint64_t outerR = (basicSplitA ? unitR_.outer : unitR_.outer / cBlock_.cacheLineOuter * shape[axisInCacheLine]);
uint64_t innerR = unitR_.inner;
uint64_t outerA = unitA_.outer;
uint64_t innerA = unitA_.inner;
uint64_t step = 1UL;
uint64_t bBlockNum = basicBlock_ * Ratio() / opDag_.maxInputBytes;
uint64_t dSize = ge::GetSizeByDataType(opInput_.inputDtype);
uint64_t ubBlockSize = compileInfo_->ubBlockSize / dSize;
int32_t iR;
for (iR = basicSplitA ? axisInCacheLine - 1 : axisInCacheLine; iR > -1; iR -= AXES_STEP) {
uint64_t axisLen = shape[iR];
uint64_t effectiveLen = (iR == Pattern::Dim - 1) ? CeilAlign(axisLen, ubBlockSize) : axisLen;
if (innerA * innerR * cBlock_.aSize * cBlock_.rSize * effectiveLen <= bBlockNum) {
outerR = outerR / axisLen;
innerR *= effectiveLen;
continue;
}
step = std::min(bBlockNum / (innerA * innerR * cBlock_.aSize * cBlock_.rSize), axisLen);
if (step > sliceShape_[iR]) {
step = FloorAlign(step, sliceShape_[iR]);
}
if (iR == Pattern::Dim - 1) {
step = FloorAlign(step, ubBlockSize);
}
uint64_t minStep = (iR == axisInCacheLine ? cBlock_.cacheLineStep : 1UL);
for (uint64_t s = step; s > minStep; s--) {
uint64_t tmpS = (s > sliceShape_[iR] ? FloorAlign(s, sliceShape_[iR]) : s);
if (iR == Pattern::Dim - 1 && tmpS != sliceShape_[iR]) {
tmpS = FloorAlign(tmpS, ubBlockSize);
}
auto tmpOuterR = (tmpS >= sliceShape_[iR] ?
outerR / axisLen * CeilDiv(axisLen, tmpS) :
outerR / axisLen * sliceNum_[iR] * CeilDiv(sliceShape_[iR], tmpS));
uint64_t tempCoreNum = (outerA * tmpOuterR) / CeilDiv(outerA * tmpOuterR, compileInfo_->vectorCoreNum);
tempCoreNum = tempCoreNum > outerA ? FloorAlign(tempCoreNum, outerA) : tempCoreNum;
double rate = static_cast<double>(tempCoreNum) / static_cast<double>(compileInfo_->vectorCoreNum);
if (rate > THRES_HOLD) {
step = tmpS;
break;
}
}
innerR = innerR * step;
outerR = (step >= sliceShape_[iR] ? outerR / axisLen * CeilDiv(axisLen, step) :
outerR / axisLen * sliceNum_[iR] * CeilDiv(sliceShape_[iR], step));
break;
}
AssembleUnit(unitR_, iR, innerR, outerR, step);
}
* 针对R轴过小,UB内R轴全载,BasicBlock不能满载时,调整UB内A轴切分,上限Reduce计算后输出buffer大小
*/
template <class Pattern>
void ReduceOpTiling::ComputeProgressUnitA(const uint64_t* shape)
{
if (unitR_.idx != -1) {
return;
}
uint64_t innerA = unitA_.inner / unitA_.step;
uint64_t outerA = unitA_.outer / CeilDiv(shape[unitA_.idx], unitA_.step) * shape[unitA_.idx];
if (unitA_.step < sliceShape_[unitA_.idx]) {
outerA = unitA_.outer / sliceNum_[unitA_.idx] / CeilDiv(sliceShape_[unitA_.idx], unitA_.step) *
shape[unitA_.idx];
}
uint64_t bBlockNum = basicBlock_ * Ratio() / opDag_.maxInputBytes;
uint64_t maxInnerA = resultBlock_ / opDag_.maxInputBytes;
uint64_t ubBlockSize = compileInfo_->ubBlockSize / opDag_.minInputBytes;
uint64_t cacheSize = compileInfo_->cacheLineSize / opDag_.minInputBytes;
uint64_t innerR = unitR_.inner;
uint64_t step = 1;
int32_t iA;
for (iA = unitA_.idx; iA > -1; iA -= AXES_STEP) {
uint64_t axisLen = shape[iA];
bool splitHere = false;
uint64_t s = (iA == unitA_.idx ? unitA_.step + 1UL : 1UL);
uint64_t maxStep = (iA == unitA_.idx ? unitA_.step : 1UL);
for (; s <= axisLen; s++) {
uint64_t tmpS = s > sliceShape_[iA] ? FloorAlign(s, sliceShape_[iA]) : s;
if (iA == Pattern::Dim - 1) {
tmpS = CeilAlign(tmpS, ubBlockSize);
}
uint64_t tmpInnerA = innerA * tmpS;
uint64_t tmpOuterA = (tmpS >= sliceShape_[iA] ?
outerA / axisLen * CeilDiv(axisLen, tmpS) :
outerA / axisLen * sliceNum_[iA] * CeilDiv(sliceShape_[iA], tmpS));
double rate = static_cast<double>(tmpOuterA) /
static_cast<double>(CeilAlign(tmpOuterA, compileInfo_->vectorCoreNum));
bool isContinue = (tmpInnerA * innerR * cBlock_.aSize * cBlock_.rSize <= bBlockNum &&
tmpInnerA * cBlock_.aSize <= maxInnerA);
if (isContinue) {
if (tmpInnerA * cBlock_.aSize <= cacheSize) {
maxStep = tmpS;
} else {
maxStep = rate >= THRES_HOLD ? tmpS : maxStep;
}
continue;
} else {
splitHere = true;
break;
}
}
if (splitHere || iA - AXES_STEP < 0) {
step = maxStep;
innerA *= step;
outerA = (step >= sliceShape_[iA] ? outerA / axisLen * CeilDiv(axisLen, step) :
outerA / axisLen * sliceNum_[iA] * CeilDiv(sliceShape_[iA], step));
break;
}
innerA *= (iA == Pattern::Dim - 1 ? CeilAlign(axisLen, ubBlockSize) : axisLen);
outerA /= axisLen;
}
AssembleUnit(unitA_, iA, innerA, outerA, step);
}
template <class Pattern>
uint64_t ReduceOpTiling::TryGetReduceBlock(const uint64_t* shape, uint64_t preInputBufferNum, uint64_t preRestBufferNum,
uint64_t postBufferNum)
{
uint64_t r = CaculateReduceSize<Pattern>(shape);
uint64_t ubAvilSize = compileInfo_->ubSize - (CACHE_BUF_SIZE + opInput_.reservedSize);
uint64_t ratio = Ratio();
double basicBlock = static_cast<double>(ubAvilSize) /
(static_cast<double>(preInputBufferNum + preRestBufferNum * ratio) +
static_cast<double>(postBufferNum * ratio) / static_cast<double>(r));
return static_cast<uint64_t>(static_cast<double>(basicBlock * ratio) / static_cast<double>(r));
}
template <class Pattern>
ge::graphStatus ReduceOpTiling::CalcBasicBlock(const uint64_t* shape)
{
if (compileInfo_->ubSize <= CACHE_BUF_SIZE + opInput_.reservedSize) {
std::string reasonMsg = "ubSize: " + std::to_string(compileInfo_->ubSize) +
" is smaller than size: " + std::to_string(CACHE_BUF_SIZE + opInput_.reservedSize) +
", not support";
OP_LOGE_FOR_INVALID_VALUE_WITH_REASON(context_->GetNodeName(), "ubSize",
std::to_string(compileInfo_->ubSize).c_str(), reasonMsg.c_str());
return ge::GRAPH_FAILED;
}
uint64_t preInputBufferNum = opDag_.preInput * CONST2;
uint64_t preRestBufferNum = opDag_.preTempCalc + opDag_.preOutput * CONST2 + opInput_.reservedNode;
if (opDag_.reduceOpPos == 1) {
preInputBufferNum = preInputBufferNum * CONST2;
}
uint64_t postBufferNum = opDag_.postInput + opDag_.postTempCalc + opDag_.postOutput + CONST1;
uint64_t ratio = Ratio();
uint64_t ubAvilSize = compileInfo_->ubSize - (CACHE_BUF_SIZE + opInput_.reservedSize);
if (Pattern::ID == PATTERN_A) {
basicBlock_ = ubAvilSize / (preInputBufferNum + (preRestBufferNum + postBufferNum) * ratio);
basicBlock_ = FloorAlign(basicBlock_, compileInfo_->vRegSize);
resultBlock_ = basicBlock_ * ratio;
} else {
uint64_t resBlock = TryGetReduceBlock<Pattern>(shape, preInputBufferNum, preRestBufferNum, postBufferNum);
resultBlock_ = FloorAlign(resBlock, compileInfo_->cacheLineSize);
if (resultBlock_ < static_cast<uint64_t>(MAX_INNER_A)) {
resultBlock_ = MAX_INNER_A;
} else if (resultBlock_ > CACHE_BUF_SIZE) {
resultBlock_ = CACHE_BUF_SIZE;
}
if (ubAvilSize <= resultBlock_ * postBufferNum) {
std::string reasonMsg = "ubSize: " + std::to_string(compileInfo_->ubSize) + " is smaller than size: " +
std::to_string(CACHE_BUF_SIZE + opInput_.reservedSize +
resultBlock_ * postBufferNum) +
" not support";
OP_LOGE_FOR_INVALID_VALUE_WITH_REASON(context_->GetNodeName(), "ubSize",
std::to_string(compileInfo_->ubSize).c_str(), reasonMsg.c_str());
return ge::GRAPH_FAILED;
}
uint64_t preBufSize = ubAvilSize - resultBlock_ * postBufferNum;
basicBlock_ = preBufSize / (preInputBufferNum + preRestBufferNum * ratio);
basicBlock_ = FloorAlign(basicBlock_, compileInfo_->vRegSize);
}
if (basicBlock_ < compileInfo_->vRegSize) {
std::string reasonMsg = "basic block: " + std::to_string(basicBlock_) + " is too small, not support";
OP_LOGE_FOR_INVALID_VALUE_WITH_REASON(context_->GetNodeName(), "basicBlock",
std::to_string(basicBlock_).c_str(), reasonMsg.c_str());
return ge::GRAPH_FAILED;
}
OP_LOGI(context_, "preInputBufferNum:%lu, preRestBufferNum:%lu, postBufferNum:%lu, basicBlock:%ld, resBlock:%ld",
preInputBufferNum, preRestBufferNum, postBufferNum, basicBlock_, resultBlock_);
CalcUserBasicBlock(Pattern::ID == PATTERN_A);
return ge::GRAPH_SUCCESS;
}
template <class Pattern>
ge::graphStatus ReduceOpTiling::ComputeEmptyTiling(uint64_t* shape)
{
uint64_t outSize = 1;
for (int32_t dim = Pattern::Dim - 1; dim > -1; dim--) {
if (IsAxisA<Pattern>(dim)) {
outSize *= shape[dim];
}
}
tilingData_->outSize = outSize;
context_->SetBlockDim(compileInfo_->vectorCoreNum);
if (outSize == 0) {
return ge::GRAPH_SUCCESS;
}
uint64_t ubAvilSize = compileInfo_->ubSize - CACHE_BUF_SIZE;
basicBlock_ = FloorAlign(ubAvilSize / CONST2, compileInfo_->vRegSize);
uint64_t newshape[MAX_DIM] = {outSize};
ComputeCacheLineBlock<__reducePattern::A>(newshape);
unitA_.outer *= cBlock_.cacheLineOuter;
sliceShape_[0] = outSize;
sliceNum_[0] = 1;
ComputeUnitA<__reducePattern::A>(newshape);
SetTilingData<__reducePattern::A>(newshape);
return ge::GRAPH_SUCCESS;
}
template <class Pattern>
ge::graphStatus ReduceOpTiling::ComputeTiling(uint64_t* shape)
{
dimNum_ = Pattern::Dim;
OP_CHECK_IF((CalcBasicBlock<Pattern>(shape) == ge::GRAPH_FAILED),
OP_LOGE(context_, "calc basic block failed, maybe unsupport ubsize"), return ge::GRAPH_FAILED);
if (IsEmtpyTensor<Pattern>(shape)) {
return ComputeEmptyTiling<Pattern>(shape);
}
if (tilingKey_.batchInvariant && CheckIsContiguous(opInput_)) {
ComputeRFirst<Pattern>(shape);
if (unitR_.idx == -1) {
ComputeAWithRFullLoad<Pattern>(shape);
} else {
ComputeAWithRCannotFullLoad<Pattern>(shape);
}
SetTilingDataBatchInvariant<Pattern>(shape);
} else {
ComputeCacheLineBlockAndUnit<Pattern>(shape);
ComputeUnitA<Pattern>(shape);
ComputeUnitR<Pattern>(shape);
ComputeProgressUnitA<Pattern>(shape);
SetTilingData<Pattern>(shape);
}
OP_LOGI(context_,
"tiling step outerA:%lu, innerA:%lu, stepA:%lu, idxA:%d, outerR:%lu, innerR:%lu, stepR:%lu, idxR:%d",
unitA_.outer, unitA_.inner, unitA_.step, unitA_.idx, unitR_.outer, unitR_.inner, unitR_.step, unitR_.idx);
SetTilingKey<Pattern>();
return ge::GRAPH_SUCCESS;
}
template <class Pattern>
void ReduceOpTiling::SetTilingData(const uint64_t* shape)
{
uint64_t perCoreNum = CeilDiv(unitA_.outer * unitR_.outer, compileInfo_->vectorCoreNum);
uint64_t numBlocks = CeilDiv(unitA_.outer * unitR_.outer, perCoreNum);
if (unitA_.outer < numBlocks) {
auto tmpBlockDim = CeilAlign(numBlocks, unitA_.outer);
if (tmpBlockDim <= compileInfo_->vectorCoreNum) {
numBlocks = tmpBlockDim;
} else {
numBlocks = FloorAlign(numBlocks, unitA_.outer);
}
}
tilingData_->ubFactorA = unitA_.step;
uint64_t factorACntPerCore = CeilDiv(unitA_.outer, numBlocks);
tilingData_->factorACntPerCore = factorACntPerCore;
tilingData_->factorATotalCnt = unitA_.outer;
tilingData_->ubFactorR = unitR_.step;
uint64_t factorRCntPerCore = CeilDiv(unitR_.outer, CeilDiv(numBlocks, unitA_.outer));
tilingData_->factorRCntPerCore = factorRCntPerCore;
tilingData_->factorRTotalCnt = unitR_.outer;
tilingData_->groupR = CeilDiv(unitR_.outer, factorRCntPerCore);
if (tilingData_->groupR > 1) {
OP_CHECK_IF(context_->SetScheduleMode(1) != ge::GRAPH_SUCCESS,
OP_LOGE(context_->GetNodeName(), "Failed to set ScheduleMode!"), return);
}
for (int32_t i = 0; i < MAX_DIM; i++) {
tilingData_->shape[i] = shape[i];
tilingData_->stride[i] = viewStride_[i];
tilingData_->sliceNum[i] = sliceNum_[i];
tilingData_->sliceShape[i] = sliceShape_[i];
tilingData_->sliceStride[i] = sliceStride_[i];
}
tilingData_->basicBlock = basicBlock_;
tilingData_->resultBlock = resultBlock_;
tilingData_->coreNum = static_cast<int32_t>(compileInfo_->vectorCoreNum);
tilingData_->useNddma = IsUseNddma<Pattern>(shape);
ComputeStride<Pattern>(shape);
uint32_t realCore = CeilDiv(unitA_.outer, factorACntPerCore) * CeilDiv(unitR_.outer, factorRCntPerCore);
tilingData_->realCoreNum = realCore;
context_->SetBlockDim(realCore);
}
template <class Pattern>
int32_t ReduceOpTiling::IsUseNddma(const uint64_t* shape)
{
int32_t axis = cBlock_.axis;
uint64_t dSize = ge::GetSizeByDataType(opInput_.inputDtype);
uint64_t ubBlockSize = compileInfo_->ubBlockSize / dSize;
if (viewStride_[Pattern::Dim - 1] != 1UL) {
return 1;
}
if (shape[Pattern::Dim - 1] >= ubBlockSize) {
return 0;
}
if ((Pattern::Dim - 1 - axis > CONST2) || (Pattern::Dim - 1 - axis == CONST2 && cBlock_.cacheLineStep != 1UL)) {
return 1;
}
if (Pattern::TailA) {
uint64_t factorA = tilingData_->ubFactorA;
for (auto iA = unitA_.idx + AXES_STEP; iA < Pattern::Dim; iA += AXES_STEP) {
factorA = factorA * shape[iA];
}
if (factorA > ubBlockSize) {
return 0;
}
} else {
uint64_t factorR = tilingData_->ubFactorR;
for (auto iR = unitR_.idx + AXES_STEP; iR < Pattern::Dim; iR += AXES_STEP) {
factorR = factorR * shape[iR];
}
if (factorR > ubBlockSize) {
return 0;
}
}
return 1;
}
template <class Pattern>
void ReduceOpTiling::ComputeStride(const uint64_t* shape)
{
uint64_t s = 1UL;
uint64_t ds = 1UL;
for (int32_t dim = Pattern::Dim - 1; dim > -1; dim--) {
tilingData_->dstStride[dim] = ds;
s *= shape[dim];
if (IsAxisA<Pattern>(dim)) {
ds *= shape[dim];
}
}
double meanVar = static_cast<double>(1) / static_cast<double>(s / ds);
tilingData_->outSize = ds;
tilingData_->meanVar = static_cast<float>(meanVar);
}
template <class Pattern>
void ReduceOpTiling::SetTilingKey()
{
uint64_t groupR = tilingData_->groupR;
int32_t aCount = 0;
int32_t rCount = 0;
int32_t innerACount = 0;
int32_t innerRCount = 0;
if (groupR == 1UL) {
aCount = (unitA_.idx - (Pattern::FirstA ? 0 : 1)) / AXES_STEP + 1;
innerRCount = (unitR_.idx - (Pattern::FirstA ? 1 : 0)) / AXES_STEP + 1;
} else {
aCount = (unitA_.idx - (Pattern::FirstA ? 0 : 1)) / AXES_STEP + 1;
rCount = (unitR_.idx - (Pattern::FirstA ? 1 : 0)) / AXES_STEP + 1;
rCount = rCount + aCount;
}
int32_t innerID = Pattern::TailA ? 0 : 1;
tilingKey_.patternID = Pattern::ID * CONST10 + innerID;
tilingKey_.loopARCount = static_cast<uint32_t>(aCount * CONST10 + rCount);
tilingKey_.loopInnerARCount = static_cast<uint32_t>(innerACount * CONST10 + innerRCount);
tilingKey_.isContiguous = CheckIsContiguous(opInput_);
OP_LOGI(
context_, "patternID:%u, loopARCount:%u, loopInnerARCount:%u, isContiguous:%d, isBatchInvariant:%d", tilingKey_.patternID,
tilingKey_.loopARCount, tilingKey_.loopInnerARCount, tilingKey_.isContiguous ? 1 : 0, tilingKey_.batchInvariant ? 1 : 0);
}
void ReduceOpTiling::GetTilingKey(ReduceTilingKey& key) { key = tilingKey_; }
void ReduceOpTiling::CalcUserWorkSpace()
{
size_t* workspaces = context_->GetWorkspaceSizes(1);
uint64_t groupR = tilingData_->groupR;
uint64_t outSize = tilingData_->outSize;
uint64_t size = opDag_.maxInputBytes;
if (groupR > 1UL) {
workSpaceSize_ = compileInfo_->vectorCoreNum * CeilAlign(outSize * size, compileInfo_->cacheLineSize);
}
workspaces[0] = WORKSPACE_SIZE + workSpaceSize_;
}
ge::graphStatus ReduceOpTiling::DoTilingMatchPattern(uint64_t* shape, int32_t shapeSize)
{
switch (shapeSize) {
case CONST1:
return ComputeTiling<__reducePattern::A>(shape);
case CONST2:
return ComputeTiling<__reducePattern::AR>(shape);
case CONST3:
return ComputeTiling<__reducePattern::ARA>(shape);
case CONST4:
return ComputeTiling<__reducePattern::ARAR>(shape);
case CONST5:
PadDimOne<__reducePattern::ARARA>(shape);
return ComputeTiling<__reducePattern::ARARARARA>(shape);
case CONST6:
PadDimOne<__reducePattern::ARARAR>(shape);
return ComputeTiling<__reducePattern::ARARARAR>(shape);
case CONST7:
PadDimOne<__reducePattern::ARARARA>(shape);
return ComputeTiling<__reducePattern::ARARARARA>(shape);
case CONST8:
return ComputeTiling<__reducePattern::ARARARAR>(shape);
case CONST9:
return ComputeTiling<__reducePattern::ARARARARA>(shape);
default:
OP_LOGE(context_, "unsupport pattern");
return ge::GRAPH_FAILED;
}
}
}
}