* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file grid_sample_tiling_arch35.cpp
* \brief grid_sample_tiling_arch35 impl info
*/
#include "grid_sample_tiling_arch35.h"
#include "grid_sample_tiling.h"
namespace optiling {
static constexpr uint32_t VF_MAX_THREAD_NUM_2D = 1024;
static constexpr uint32_t VF_MAX_THREAD_NUM_3D = 512;
static constexpr uint32_t SIMT_COMMON_2D_TILING_KEY = 1000;
static constexpr uint32_t SIMT_COMMON_3D_TILING_KEY = 2000;
static constexpr uint32_t DCACHE_SIZE = 32 * 1024;
static constexpr uint32_t REGBASE_CCEC_CACHE_SIZE = 8 * 1024;
static constexpr uint32_t MAX_OUTPUT_SIZE = 2147483647;
static constexpr uint32_t USE_UINT64_TYPE = 100;
static const std::set<ge::DataType> supportDtype = {ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16};
bool GridSampleArch35Tiling::IsCapable()
{
return true;
}
ge::graphStatus GridSampleArch35Tiling::DoOpTiling()
{
tilingData.set_coreNumVar(coreNumVar);
tilingData.set_outD(outD);
tilingData.set_outH(outH);
tilingData.set_outW(outW);
tilingData.set_inN(inN);
tilingData.set_inC(inC);
tilingData.set_inD(inD);
tilingData.set_inH(inH);
tilingData.set_inW(inW);
tilingData.set_interpolationMode(interpolationMode);
tilingData.set_channelLast(channelLast);
tilingData.set_paddingMode(paddingMode);
tilingData.set_alignCorners(alignCorners);
int64_t outputD = outD == 0 ? 1 : outD;
outputSize = inN * inC * outputD * outH * outW;
inputSize = inN * inC * inH * inW;
gridSize = inN * outH * outW;
if (dimension == 0) {
gridSize = gridSize * DIM_2;
} else {
gridSize = gridSize * outD * DIM_3;
inputSize = inputSize * inD;
}
int32_t threadNum = VF_MAX_THREAD_NUM_2D;
if (outputSize > MAX_OUTPUT_SIZE || inputSize > MAX_OUTPUT_SIZE || gridSize > MAX_OUTPUT_SIZE || dimension != 0) {
threadNum = VF_MAX_THREAD_NUM_3D;
}
needCoreNum = (outputSize + threadNum - 1) / threadNum;
needCoreNum = std::min(needCoreNum, coreNumVar);
tilingData.set_needCoreNum(needCoreNum);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus GridSampleArch35Tiling::GetWorkspaceSize()
{
workspaceSize_ = SIZE_16 * LENGTH_1024 * LENGTH_1024;
return ge::GRAPH_SUCCESS;
}
ge::graphStatus GridSampleArch35Tiling::PostTiling()
{
context_->SetBlockDim(tilingData.get_needCoreNum());
size_t* workspaces = context_->GetWorkspaceSizes(1);
workspaces[0] = workspaceSize_;
gert::TilingData* rawTilingData = context_->GetRawTilingData();
OP_CHECK_IF(
rawTilingData == nullptr, OP_LOGE(context_->GetNodeType(), "GetRawTilingData failed."),
return ge::GRAPH_FAILED);
OP_CHECK_IF(
tilingData.GetDataSize() > rawTilingData->GetCapacity(),
OP_LOGE(
context_, "actual tiling data size %zu > context tiling data size %zu", tilingData.GetDataSize(),
rawTilingData->GetCapacity()),
return ge::GRAPH_FAILED);
tilingData.SaveToBuffer(rawTilingData->GetData(), rawTilingData->GetCapacity());
rawTilingData->SetDataSize(tilingData.GetDataSize());
return ge::GRAPH_SUCCESS;
}
uint64_t GridSampleArch35Tiling::GetTilingKey() const
{
uint32_t tilingKey = SIMT_COMMON_2D_TILING_KEY;
if (dimension != 0) {
tilingKey = SIMT_COMMON_3D_TILING_KEY;
}
if (outputSize > MAX_OUTPUT_SIZE || inputSize > MAX_OUTPUT_SIZE || gridSize > MAX_OUTPUT_SIZE) {
tilingKey += USE_UINT64_TYPE;
}
return tilingKey;
}
REGISTER_OPS_TILING_TEMPLATE(GridSample, GridSampleArch35Tiling, 2000);
}