* This file is part of the OpenBOAT project at Harbin Institute of Technology (HIT)
* and is contributed to the CANN Open Software.
*
* Copyright (c) 2025 AISS Group, Harbin Institute of Technology (HIT).
* All Rights Reserved.
*
* Authors (accounts):
* - Cao Xiaojuan <@c15503545287>
* - Su Tonghua <@sutonghua>
*
* This program is free software: you can redistribute it and/or modify it.
* Licensed under the CANN Open Software License Agreement Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* See the LICENSE file at the root of the repository for the full text of the License.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTIES OF ANY KIND, EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
*/
* \file slice_v3_tiling.cpp
* \brief
*/
#include "log/log.h"
#include "util/math_util.h"
#include "register/op_impl_registry.h"
#include <graph/utils/type_utils.h>
#include "tiling/platform/platform_ascendc.h"
#include "../op_kernel/slice_v3_tiling_data.h"
#include "../op_kernel/slice_v3_tiling_key.h"
namespace optiling {
constexpr uint32_t WS_SYS_SIZE = 512U;
constexpr uint32_t RESERVED_BYTES = 512U;
static void ComputeCumShape(const std::vector<int64_t>& shape, int64_t* cumShape) {
int n = shape.size();
if (n == 0) return;
cumShape[n - 1] = 1;
for (int i = n - 2; i >= 0; i--) {
cumShape[i] = cumShape[i + 1] * shape[i + 1];
}
}
struct SliceV3CompileInfo {};
const uint32_t BLOCK_SIZE = 32;
static ge::graphStatus GetWorkspaceSize(gert::TilingContext* context)
{
auto ascendcPlatform = platform_ascendc:: PlatformAscendC(context->GetPlatformInfo());
uint32_t sysWorkspaceSize = ascendcPlatform.GetLibApiWorkSpaceSize();
size_t* currentWorkspace = context->GetWorkspaceSizes(1);
OP_CHECK_NULL_WITH_CONTEXT(context, currentWorkspace);
currentWorkspace[0] = WS_SYS_SIZE + sysWorkspaceSize;
return ge::GRAPH_SUCCESS;
}
static ge::graphStatus GetShapeAttrsInfo(const gert::TilingContext* context, int64_t& totalIdx, ge::DataType& dataType)
{
auto inputX = context->GetInputShape(0);
OP_CHECK_NULL_WITH_CONTEXT(context, inputX);
totalIdx = inputX->GetStorageShape().GetShapeSize();
const std::set<ge::DataType> supportedDtype = {ge::DT_FLOAT, ge::DT_INT32,
ge::DT_FLOAT16, ge::DT_INT16,
ge::DT_BF16,ge::DT_INT8, ge::DT_UINT8};
auto inputDesc = context->GetInputDesc(0);
OP_CHECK_NULL_WITH_CONTEXT(context, inputDesc);
dataType = inputDesc->GetDataType();
if (supportedDtype.count(dataType) == 0) {
OP_LOGE(context, "invalid dtype");
return ge::GRAPH_FAILED;
}
return ge::GRAPH_SUCCESS;
}
static ge::graphStatus SliceV3TilingFunc(gert::TilingContext* context)
{
SliceV3TilingData* tiling = context->GetTilingData<SliceV3TilingData>();
OP_CHECK_NULL_WITH_CONTEXT(context, tiling);
OP_CHECK_IF(
memset_s(tiling, sizeof(SliceV3TilingData), 0, sizeof(SliceV3TilingData)) != EOK,
OP_LOGE(context, "set tiling data error"), return ge::GRAPH_FAILED);
uint64_t ubSize;
uint32_t coreNum;
auto plat = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
plat.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
coreNum = plat.GetCoreNum();
OP_CHECK_IF(
GetWorkspaceSize(context) != ge::GRAPH_SUCCESS,
OP_LOGE(context, "GetWorkspaceSize error"),
return ge::GRAPH_FAILED);
int64_t totalIdx = 0;
ge::DataType dataType;
OP_CHECK_IF(
GetShapeAttrsInfo(context, totalIdx, dataType) != ge::GRAPH_SUCCESS,
OP_LOGE(context, "GetShapeAttrsInfo error"), return ge::GRAPH_FAILED);
uint32_t dim = context->GetInputShape(0)->GetStorageShape().GetDimNum();
tiling->dim = dim;
uint32_t typeLen = 0;
ge::TypeUtils::GetDataTypeLength(context->GetInputDesc(0)->GetDataType(), typeLen);
OP_CHECK_IF(typeLen == 0, OP_LOGE(context, "dtype len error"), return ge::GRAPH_FAILED);
auto beginTensor = context->GetInputTensor(1);
auto sizeTensor = context->GetInputTensor(2);
const int64_t* beginPtr = beginTensor->GetData<int64_t>();
const int64_t* sizePtr = sizeTensor->GetData<int64_t>();
std::vector<int64_t> begin(dim);
std::vector<int64_t> size(dim);
for (uint32_t i = 0; i < dim; ++i) {
begin[i] = beginPtr[i];
size[i] = sizePtr[i];
tiling->begin[i] = begin[i];
}
std::vector<int64_t> inputDims(dim), outputDims(dim);
for (uint32_t i = 0; i < dim; ++i) {
inputDims[i] = context->GetInputShape(0)->GetStorageShape().GetDim(i);
outputDims[i] = size[i];
}
ComputeCumShape(inputDims, tiling->inputCumShape);
ComputeCumShape(outputDims, tiling->outputCumShape);
int64_t totalBlocks = 1;
for (uint32_t i = 0; i < dim - 1; ++i) {
if (outputDims[i] <= 0) {
totalBlocks = 0;
break;
}
totalBlocks *= outputDims[i];
}
int64_t lastDim = (dim > 0 ? outputDims[dim - 1] : 0);
uint32_t useCoreNum = std::min((uint32_t)std::max<int64_t>(totalBlocks, 1), coreNum);
if (totalBlocks == 0) useCoreNum = 1;
if (useCoreNum == 0) {
OP_LOGE(context, "useCoreNum cannot be zero");
return ge::GRAPH_FAILED;
}
context->SetBlockDim(useCoreNum);
uint64_t availUb = (ubSize * 9 / 10) / 2;
uint32_t alignNum = BLOCK_SIZE / typeLen;
if (alignNum == 0) alignNum = 1;
uint32_t maxTile = static_cast<uint32_t>(availUb / typeLen);
uint32_t alignedMaxTile = (maxTile >= alignNum)
? (maxTile & (~(alignNum - 1)))
: maxTile;
if (alignedMaxTile == 0)
alignedMaxTile = alignNum;
int64_t tileDataNum = 1;
if (lastDim > 0) {
tileDataNum = (lastDim <= alignedMaxTile) ? lastDim : alignedMaxTile;
if (tileDataNum < alignNum)
tileDataNum = alignNum;
}
tiling->outerLoopNum = totalBlocks / useCoreNum;
tiling->tailCoreNum = totalBlocks % useCoreNum;
tiling->tileDataNum = tileDataNum;
tiling->lastDimSize = lastDim;
return ge::GRAPH_SUCCESS;
}
static ge::graphStatus TilingParseForSliceV3([[maybe_unused]] gert::TilingParseContext* context)
{
return ge::GRAPH_SUCCESS;
}
IMPL_OP_OPTILING(SliceV3).Tiling(SliceV3TilingFunc).TilingParse<SliceV3CompileInfo>(TilingParseForSliceV3);
}