* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License")
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file select_simt_tiling.cpp
* \brief select_simt_tiling source file
*/
#include "select_simt_tiling.h"
#include <graph/utils/type_utils.h>
#include "select_tiling.h"
#include "atvoss/broadcast/broadcast_tiling.h"
#include "math/select/op_kernel/arch35/select_struct.h"
#include "register/op_def_registry.h"
#include "util/math_util.h"
#include "op_host/math_tiling_templates_registry.h"
using namespace AscendC;
using namespace ge;
namespace optiling {
// Priority passed to REGISTER_OPS_TILING_TEMPLATE for SelectSimtTiling.
static constexpr uint64_t SELECT_COMMON_TILING_PRIORITY = 0;
// Arguments passed to GET_TPL_TILING_KEY in DoOpTiling.
static constexpr int64_t TILING_KEY = 999;
static constexpr int64_t USER_DEF = 1;

// Input positions of the Select op: condition, x1, x2.
static constexpr int32_t CONDITION_IDX = 0;
static constexpr int32_t X1_IDX = 1;
static constexpr int32_t X2_IDX = 2;

// Factor used when halving the thread count / comparing against half the core count.
static constexpr int64_t NUM_TWO = 2;
// DoOpTiling only halves threadNum_ while it is at least twice this value.
#ifdef DAVID_FPGA
static constexpr int64_t SMALL_CASE_THREAD_NUM = 64;
#else
static constexpr int64_t SMALL_CASE_THREAD_NUM = 128;
#endif

// Minimum B-axis byte count required by XDtypeImprove.
static constexpr int32_t MIN_SIZE = 512;
// Subtracted from the UB size before it is set as local memory size in PostTiling.
static constexpr int32_t DCACHE_SIZE = 128 * 1024;
// Value written into the single workspace slot in GetWorkspaceSize.
static constexpr uint64_t WORKSPACE_SIZE = 32;
// NOTE(review): not referenced in the code shown here.
static constexpr uint64_t BLOCK_NUM = 1;

// Element sizes in bytes; only INPUT_DTYPE_B64 is referenced in the code shown here.
static constexpr int64_t INPUT_DTYPE_B64 = 8;
static constexpr int64_t INPUT_DTYPE_B32 = 4;
static constexpr int64_t INPUT_DTYPE_B16 = 2;
// Validates the input dtypes, then caches the condition/x1/x2 shapes produced by
// InferSelectShape into conditionShape_, x1Shape_ and x2Shape_.
// Returns GRAPH_FAILED if the dtype check fails, shape inference fails, or fewer
// than three shapes are produced.
ge::graphStatus SelectSimtTiling::GetShapeAttrsInfo()
{
    OP_CHECK_IF(SelectCheckInputDtype(context_) != ge::GRAPH_SUCCESS,
        OP_LOGE(context_->GetNodeName(), "SelectCheckInputDtype error!"),
        return ge::GRAPH_FAILED);
    vector<gert::Shape> inputShapes;
    OP_CHECK_IF(InferSelectShape(context_, inputShapes) != ge::GRAPH_SUCCESS,
        OP_LOGE(context_->GetNodeName(), "InferSelectShape error!"),
        return ge::GRAPH_FAILED);
    // The reads below use fixed indices up to X2_IDX; refuse a short result instead of
    // reading past the end of the vector.
    OP_CHECK_IF(inputShapes.size() <= static_cast<size_t>(X2_IDX),
        OP_LOGE(context_->GetNodeName(), "InferSelectShape returned %zu shapes, expect at least %d",
            inputShapes.size(), X2_IDX + 1),
        return ge::GRAPH_FAILED);
    conditionShape_ = inputShapes[CONDITION_IDX];
    x1Shape_ = inputShapes[X1_IDX];
    x2Shape_ = inputShapes[X2_IDX];
    return ge::GRAPH_SUCCESS;
}
bool SelectSimtTiling::IsMatchAB()
{
int64_t pos = -1;
int64_t conditionDimNum = conditionShape_.GetDimNum();
for (int64_t i = 0; i < conditionDimNum; i++) {
if (conditionShape_.GetDim(i) == 1){
pos = i;
break;
}
aSize_ *= conditionShape_.GetDim(i);
}
if (pos < 0 || pos >= conditionDimNum){
return false;
}
for (int64_t i = pos; i < conditionDimNum; i++) {
if (conditionShape_.GetDim(i) != 1){
return false;
}
bSize_ *= x1Shape_.GetDim(i);
}
return true;
}
bool SelectSimtTiling::XDtypeImprove()
{
auto x1Dtype = context_->GetInputDesc(X1_IDX)->GetDataType();
int64_t xDtypeSize = ge::GetSizeByDataType(x1Dtype);
if (bSize_ * xDtypeSize < MIN_SIZE) {
return false;
}
if (xDtypeSize == INPUT_DTYPE_B64) {
return true;
}
if ((xDtypeSize < INPUT_DTYPE_B64) && (bSize_ * xDtypeSize % INPUT_DTYPE_B64) == 0) {
OP_LOGD(context_->GetNodeName(), "XDtypeImprove lastAxisBytes %ld, improve to INPUT_DTYPE_B64", bSize_);
bSize_ /= (INPUT_DTYPE_B64 / xDtypeSize);
return true;
}
return false;
}
// This template applies only when the condition shape matches the A/B split (IsMatchAB)
// and XDtypeImprove accepts the B axis. XDtypeImprove is evaluated only if IsMatchAB succeeds.
bool SelectSimtTiling::IsCapable()
{
    return IsMatchAB() && XDtypeImprove();
}
// Fills SelectSimtTilingData from aSize_/bSize_, then:
//   - picks the per-core thread count;
//   - splits the ySize_ = aSize_ * bSize_ output units across cores;
//   - sets the block dim and the tiling key.
// Returns GRAPH_FAILED when aivNum_ or threadNum_ is not positive, or when the tiling data
// pointer is null.
ge::graphStatus SelectSimtTiling::DoOpTiling()
{
    // Both values are used as divisors below (CeilDiv and the round-up to threadNum_).
    OP_CHECK_IF(aivNum_ <= 0 || threadNum_ <= 0,
        OP_LOGE(context_->GetNodeName(), "Invalid aivNum %ld or threadNum %ld",
            static_cast<int64_t>(aivNum_), static_cast<int64_t>(threadNum_)),
        return ge::GRAPH_FAILED);
    SelectSimtTilingData* tilingData = context_->GetTilingData<SelectSimtTilingData>();
    OP_CHECK_NULL_WITH_CONTEXT(context_, tilingData);

    ySize_ = aSize_ * bSize_;
    tilingData->aSize = aSize_;
    tilingData->bSize = bSize_;

    // Halve threadNum_ while it is at least 2 * SMALL_CASE_THREAD_NUM and the number of
    // threadNum_-sized chunks is below half the core count.
    while ((threadNum_ >= NUM_TWO * SMALL_CASE_THREAD_NUM) &&
           (Ops::Base::CeilDiv(ySize_, threadNum_) < (aivNum_ / NUM_TWO))) {
        threadNum_ = threadNum_ / NUM_TWO;
    }
    tilingData->threadNum = threadNum_;

    if (ySize_ < threadNum_) {
        // Fewer units than threadNum_: a single core takes everything.
        needCoreNum_ = 1;
        tilingData->needCoreNum = 1;
        tilingData->perCoreElements = ySize_;
        tilingData->lastCoreElements = ySize_;
    } else {
        // Even share per core, rounded up to a multiple of threadNum_; the last core takes
        // the remainder.
        int64_t perCoreElements = Ops::Base::CeilDiv(ySize_, aivNum_);
        perCoreElements = (perCoreElements + threadNum_ - 1) / threadNum_ * threadNum_;
        needCoreNum_ = Ops::Base::CeilDiv(ySize_, perCoreElements);
        const int64_t lastCoreElements = ySize_ - perCoreElements * (needCoreNum_ - 1);
        tilingData->needCoreNum = needCoreNum_;
        tilingData->perCoreElements = perCoreElements;
        tilingData->lastCoreElements = lastCoreElements;
    }
    context_->SetBlockDim(needCoreNum_);
    tilingKey = GET_TPL_TILING_KEY(TILING_KEY, USER_DEF);
    return ge::GRAPH_SUCCESS;
}
// No library-API tiling is performed by this template; always reports success.
ge::graphStatus SelectSimtTiling::DoLibApiTiling()
{
    const ge::graphStatus status = ge::GRAPH_SUCCESS;
    return status;
}
// Returns the tiling key stored in the tilingKey member (assigned in DoOpTiling).
uint64_t SelectSimtTiling::GetTilingKey() const
{
    const uint64_t key = tilingKey;
    return key;
}
// Requests one workspace slot and sets its size to WORKSPACE_SIZE.
// Fails (via OP_CHECK_NULL_WITH_CONTEXT) if the context returns no workspace array.
ge::graphStatus SelectSimtTiling::GetWorkspaceSize()
{
    constexpr size_t workspaceCount = 1;
    size_t* workspaces = context_->GetWorkspaceSizes(workspaceCount);
    OP_CHECK_NULL_WITH_CONTEXT(context_, workspaces);
    workspaces[0] = WORKSPACE_SIZE;
    return ge::GRAPH_SUCCESS;
}
// Sets the context's local memory size to the UB size minus DCACHE_SIZE.
// Returns GRAPH_FAILED when ubSize_ is smaller than DCACHE_SIZE: the negative difference
// would otherwise wrap to a huge value in the uint32_t cast.
ge::graphStatus SelectSimtTiling::PostTiling()
{
    OP_CHECK_IF(ubSize_ < DCACHE_SIZE,
        OP_LOGE(context_->GetNodeName(), "ubSize %ld is smaller than dcache size %d", ubSize_, DCACHE_SIZE),
        return ge::GRAPH_FAILED);
    context_->SetLocalMemorySize(static_cast<uint32_t>(ubSize_ - DCACHE_SIZE));
    return ge::GRAPH_SUCCESS;
}
// Reads the AIV core count and UB size into aivNum_/ubSize_:
//   - from the compile info (SelectCompileInfo) when no platform info is attached;
//   - otherwise from PlatformAscendC.
// Then records the core count in aicoreParams_.numBlocks.
// Returns GRAPH_FAILED on missing compile info or a non-positive core count, since later
// tiling divides by aivNum_.
ge::graphStatus SelectSimtTiling::GetPlatformInfo()
{
    auto platformInfo = context_->GetPlatformInfo();
    if (platformInfo == nullptr) {
        auto compileInfoPtr = reinterpret_cast<const SelectCompileInfo*>(context_->GetCompileInfo());
        OP_CHECK_IF(compileInfoPtr == nullptr, OP_LOGE(context_, "compile info is null"),
            return ge::GRAPH_FAILED);
        aivNum_ = compileInfoPtr->coreNum;
        ubSize_ = compileInfoPtr->ubSize;
        OP_LOGD(context_->GetNodeName(), "Get ubSize from compileInfo is: %ld", ubSize_);
        OP_LOGD(context_->GetNodeName(), "Get aivNum from compileInfo is: %ld", aivNum_);
    } else {
        auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);
        aivNum_ = ascendcPlatform.GetCoreNumAiv();
        // Initialized so a value is defined even if the out-parameter is not written.
        uint64_t ubSizePlatform = 0;
        ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSizePlatform);
        ubSize_ = static_cast<int64_t>(ubSizePlatform);
        OP_LOGD(context_->GetNodeName(), "Get ubSize from ascendcPlatform is: %ld", ubSize_);
        OP_LOGD(context_->GetNodeName(), "Get aivNum from ascendcPlatform is: %ld", aivNum_);
    }
    OP_CHECK_IF(aivNum_ <= 0, OP_LOGE(context_->GetNodeName(), "Invalid aivNum %ld", aivNum_),
        return ge::GRAPH_FAILED);
    aicoreParams_.numBlocks = aivNum_;
    return ge::GRAPH_SUCCESS;
}
// Registers SelectSimtTiling as a tiling template for the Select op with priority
// SELECT_COMMON_TILING_PRIORITY (0). NOTE(review): how priorities order candidate templates
// is defined by the registry macro, not here — confirm before changing the value.
REGISTER_OPS_TILING_TEMPLATE(Select, SelectSimtTiling, SELECT_COMMON_TILING_PRIORITY);
}