/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
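/*!
* \file rasterizer_tiling.cpp
* \brief Tiling implementation for the Rasterizer operator: parameter checks, tiling data and workspace sizing.
*/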
#include "log/log.h"
#include "util/math_util.h"
#include "op_host/tiling_util.h"
#include "op_host/tiling_templates_registry.h"
#include "rasterizer_tiling.h"
namespace optiling {
// Expected number of dims of the v, f and findices shapes (checked in CheckParam).
static constexpr int64_t DIM_NUM2 = 2;
// Expected number of dims of the barycentric shape (checked in CheckParam).
static constexpr int64_t DIM_NUM3 = 3;
// Required extent of f dim1 and barycentric dim2; also a factor in the workspace-size formula.
static constexpr int64_t DIM_VAL3 = 3;
// Required extent of v dim1.
static constexpr int64_t DIM_VAL4 = 4;
// Fixed 32 MiB term added to the workspace size in FillTilingData.
static constexpr size_t WORK_SPACE_SIZE = 32 * 1024 * 1024;
// Element count used in the workspace-size formula (DIM_VAL3 * MAX_PROC_ELENUM * sizeof(uint32_t)).
static constexpr size_t MAX_PROC_ELENUM = 1920;
// Element count used in the workspace-size formula (BUFFER_NUM * RSV * sizeof(uint32_t)).
static constexpr size_t RSV = 64;
// Multiplier applied to RSV in the workspace-size formula.
static constexpr size_t BUFFER_NUM = 2;
// Positional indices shared by input/output shape lookups, attribute lookups and shape-dim lookups.
static constexpr uint32_t IDX_0 = 0;
static constexpr uint32_t IDX_1 = 1;
static constexpr uint32_t IDX_2 = 2;
static constexpr uint32_t IDX_3 = 3;
// Inclusive upper bound accepted for the height and width attributes.
static constexpr uint32_t MAX_SHAPE_VALUE = 4096;
// Value rejected for the height and width attributes (they must be greater than this).
static constexpr uint32_t MIN_SHAPE_VALUE = 0;
/**
 * \brief Validate the shapes and attributes of a Rasterizer tiling request.
 *
 * Checks run in this order; the first failure logs an error and returns ge::GRAPH_FAILED:
 *   1. attrs, the four attribute pointers (index 0 width, 1 height, 2 occlusionTruncation,
 *      3 useDepthPrior), both input shapes and both output shapes are non-null;
 *   2. v, f and findices have 2 dims, barycentric has 3 dims;
 *   3. f dim1 == 3, barycentric dim2 == 3, v dim1 == 4;
 *   4. dim0 of v and f fit in uint32_t (FillTilingData stores them in uint32_t variables);
 *   5. 0 < height <= 4096 and 0 < width <= 4096;
 *   6. findices and barycentric agree on dim0/dim1, and those equal height/width;
 *   7. useDepthPrior == 0.
 *
 * \param context Tiling context that supplies the output shapes and attributes; must not be null.
 * \param vShape  Storage shape of v; may be null (reported as a failure).
 * \param fShape  Storage shape of f; may be null (reported as a failure).
 * \return ge::GRAPH_SUCCESS when every check passes, ge::GRAPH_FAILED otherwise.
 */
static ge::graphStatus CheckParam(gert::TilingContext* context, const gert::StorageShape* vShape,
                                  const gert::StorageShape* fShape)
{
    auto attrs = context->GetAttrs();
    OP_CHECK_NULL_WITH_CONTEXT(context, attrs);

    // Every attribute pointer is validated before it is dereferenced, here or later in FillTilingData.
    const uint32_t* width = attrs->GetAttrPointer<uint32_t>(IDX_0);
    const uint32_t* height = attrs->GetAttrPointer<uint32_t>(IDX_1);
    const float* occlusionTruncation = attrs->GetAttrPointer<float>(IDX_2);
    const uint32_t* useDepthPrior = attrs->GetAttrPointer<uint32_t>(IDX_3);
    OP_CHECK_NULL_WITH_CONTEXT(context, width);
    OP_CHECK_NULL_WITH_CONTEXT(context, height);
    OP_CHECK_NULL_WITH_CONTEXT(context, occlusionTruncation);
    OP_CHECK_NULL_WITH_CONTEXT(context, useDepthPrior);

    const gert::StorageShape* findicesShape = context->GetOutputShape(IDX_0);
    const gert::StorageShape* baryShape = context->GetOutputShape(IDX_1);
    OP_CHECK_NULL_WITH_CONTEXT(context, vShape);
    OP_CHECK_NULL_WITH_CONTEXT(context, fShape);
    OP_CHECK_NULL_WITH_CONTEXT(context, findicesShape);
    OP_CHECK_NULL_WITH_CONTEXT(context, baryShape);

    const auto& vStorage = vShape->GetStorageShape();
    const auto& fStorage = fShape->GetStorageShape();
    const auto& findicesStorage = findicesShape->GetStorageShape();
    const auto& baryStorage = baryShape->GetStorageShape();

    // Ranks must be verified before any GetDim() call below indexes into the shapes.
    OP_CHECK_IF(vStorage.GetDimNum() != DIM_NUM2 || fStorage.GetDimNum() != DIM_NUM2 ||
                    findicesStorage.GetDimNum() != DIM_NUM2,
                OP_LOGE(context, "v/f/findices dim num is not 2, please check"),
                return ge::GRAPH_FAILED);
    OP_CHECK_IF(baryStorage.GetDimNum() != DIM_NUM3,
                OP_LOGE(context, "barycentric dim num is not 3, please check"),
                return ge::GRAPH_FAILED);

    OP_CHECK_IF(fStorage.GetDim(IDX_1) != DIM_VAL3 || baryStorage.GetDim(IDX_2) != DIM_VAL3,
                OP_LOGE(context, "dim1 of f and dim2 of barycentric should be 3, please check"),
                return ge::GRAPH_FAILED);
    OP_CHECK_IF(vStorage.GetDim(IDX_1) != DIM_VAL4,
                OP_LOGE(context, "dim1 of v should be 4, please check"),
                return ge::GRAPH_FAILED);

    // FillTilingData narrows dim0 of v/f to uint32_t; a round-trip cast detects negative or
    // too-large values that would otherwise be truncated silently.
    const int64_t vertexCount = vStorage.GetDim(IDX_0);
    const int64_t faceCount = fStorage.GetDim(IDX_0);
    OP_CHECK_IF(static_cast<int64_t>(static_cast<uint32_t>(vertexCount)) != vertexCount ||
                    static_cast<int64_t>(static_cast<uint32_t>(faceCount)) != faceCount,
                OP_LOGE(context, "dim0 of v/f should be in the range of uint32, please check"),
                return ge::GRAPH_FAILED);

    const uint32_t heightVal = *height;
    const uint32_t widthVal = *width;
    OP_CHECK_IF(heightVal > MAX_SHAPE_VALUE || widthVal > MAX_SHAPE_VALUE ||
                    heightVal == MIN_SHAPE_VALUE || widthVal == MIN_SHAPE_VALUE,
                OP_LOGE(context, "height/width should be no greater than 4096 and greater than 0, please check"),
                return ge::GRAPH_FAILED);

    OP_CHECK_IF(findicesStorage.GetDim(IDX_0) != baryStorage.GetDim(IDX_0) ||
                    findicesStorage.GetDim(IDX_1) != baryStorage.GetDim(IDX_1) ||
                    findicesStorage.GetDim(IDX_0) != static_cast<int64_t>(heightVal) ||
                    findicesStorage.GetDim(IDX_1) != static_cast<int64_t>(widthVal),
                OP_LOGE(context,
                        "dim0 and dim1 of findices/barycentric should be equal to height and width, please check"),
                return ge::GRAPH_FAILED);

    OP_CHECK_IF(*useDepthPrior != 0,
                OP_LOGE(context, "useDepthPrior should be 0, please check"),
                return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}
/**
 * \brief Write the Rasterizer tiling data, workspace size, tiling key and block dim into the context.
 *
 * Reads numFaces / numVertices from dim0 of f / v and the width, height, occlusionTruncation and
 * useDepthPrior attributes (indices 0..3), serialises them through RasterizerTilingData into the
 * context's raw tiling buffer, requests one workspace entry and sets tiling key 1 and a block dim
 * equal to the AIV core count reported by the platform.
 *
 * Every pointer obtained here is checked before use. Because the function returns void, a null
 * pointer is reported through OP_LOGE and the function returns without touching the tiling data.
 *
 * \param context Tiling context to populate; must not be null.
 * \param vShape  Storage shape of v; dim0 is taken as the vertex count.
 * \param fShape  Storage shape of f; dim0 is taken as the face count.
 */
void FillTilingData(gert::TilingContext* context, const gert::StorageShape* vShape,
                    const gert::StorageShape* fShape)
{
    auto attrs = context->GetAttrs();
    auto rawTilingData = context->GetRawTilingData();
    if (vShape == nullptr || fShape == nullptr || attrs == nullptr || rawTilingData == nullptr) {
        OP_LOGE(context, "shape/attrs/raw tiling data is nullptr, tiling data is not filled");
        return;
    }

    const uint32_t* width = attrs->GetAttrPointer<uint32_t>(IDX_0);
    const uint32_t* height = attrs->GetAttrPointer<uint32_t>(IDX_1);
    const float* occlusionTruncation = attrs->GetAttrPointer<float>(IDX_2);
    const uint32_t* useDepthPrior = attrs->GetAttrPointer<uint32_t>(IDX_3);
    if (width == nullptr || height == nullptr || occlusionTruncation == nullptr || useDepthPrior == nullptr) {
        OP_LOGE(context, "width/height/occlusionTruncation/useDepthPrior attr is nullptr, tiling data is not filled");
        return;
    }

    // Acquire the workspace slot before writing anything so a failure leaves the tiling data untouched.
    size_t* workSpaceSize = context->GetWorkspaceSizes(1);
    if (workSpaceSize == nullptr) {
        OP_LOGE(context, "workspace sizes is nullptr, tiling data is not filled");
        return;
    }

    // dim0 is int64_t; the counts are stored as uint32_t.
    const uint32_t numFaces = static_cast<uint32_t>(fShape->GetStorageShape().GetDim(IDX_0));
    const uint32_t numVertices = static_cast<uint32_t>(vShape->GetStorageShape().GetDim(IDX_0));

    RasterizerTilingData tiling;
    tiling.set_numFaces(numFaces);
    tiling.set_numVertices(numVertices);
    tiling.set_height(*height);
    tiling.set_width(*width);
    tiling.set_occlusionTruncation(*occlusionTruncation);
    tiling.set_useDepthPrior(*useDepthPrior);
    tiling.SaveToBuffer(rawTilingData->GetData(), rawTilingData->GetCapacity());
    rawTilingData->SetDataSize(tiling.GetDataSize());

    auto platformInfoPtr = context->GetPlatformInfo();
    auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
    auto aivCoreNum = ascendcPlatform.GetCoreNumAiv();

    // Workspace = per-core term (height * width * (sizeof(int32_t) + sizeof(float))) for every AIV core
    //           + DIM_VAL3 * MAX_PROC_ELENUM uint32_t elements
    //           + BUFFER_NUM * RSV uint32_t elements
    //           + the fixed WORK_SPACE_SIZE term.
    const size_t perCoreBytes =
        static_cast<size_t>(*height) * static_cast<size_t>(*width) * (sizeof(int32_t) + sizeof(float));
    workSpaceSize[0] = perCoreBytes * aivCoreNum
                       + DIM_VAL3 * MAX_PROC_ELENUM * sizeof(uint32_t)
                       + BUFFER_NUM * RSV * sizeof(uint32_t) + WORK_SPACE_SIZE;

    context->SetTilingKey(1);
    context->SetBlockDim(aivCoreNum);
}
/**
 * \brief Tiling entry point registered for the Rasterizer operator.
 *
 * Fetches the shapes of input 0 (v) and input 1 (f), validates the request with CheckParam, makes
 * sure the raw tiling buffer exists, and then lets FillTilingData populate the context.
 *
 * \param context Tiling context supplied by the framework; must not be null.
 * \return ge::GRAPH_SUCCESS on success; ge::GRAPH_FAILED when validation fails or the raw tiling
 *         buffer is missing.
 */
static ge::graphStatus RasterizerTilingFunc(gert::TilingContext* context)
{
    OP_LOGI(context, "Enter in RasterizerTilingFunc");
    const gert::StorageShape* vShape = context->GetInputShape(IDX_0);
    const gert::StorageShape* fShape = context->GetInputShape(IDX_1);
    OP_CHECK_IF(CheckParam(context, vShape, fShape) != ge::GRAPH_SUCCESS,
                OP_LOGE(context, "CheckParam is failed"),
                return ge::GRAPH_FAILED);

    // FillTilingData returns void and writes through this buffer, so a missing buffer has to be
    // turned into a failure status here.
    OP_CHECK_NULL_WITH_CONTEXT(context, context->GetRawTilingData());

    FillTilingData(context, vShape, fShape);
    return ge::GRAPH_SUCCESS;
}
/**
 * \brief Tiling-parse hook for the Rasterizer operator.
 *
 * The parse context is not inspected; the hook always reports success.
 *
 * \return ge::GRAPH_SUCCESS unconditionally.
 */
static ge::graphStatus TilingParseForRasterizer(gert::TilingParseContext* /* context */)
{
    return ge::GRAPH_SUCCESS;
}
// Register the tiling and tiling-parse callbacks for the Rasterizer operator.
IMPL_OP_OPTILING(Rasterizer)
    .Tiling(RasterizerTilingFunc)
    .TilingParse<RasterizerCompileInfo>(TilingParseForRasterizer);
}