* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file random_tiling_arch35.cpp
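* \brief Host-side tiling for random-number operators on arch35: SIMD/SIMT tiling data and splitting of the output into 32-bit-indexable blocks.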
*/
#include <algorithm>
#include "platform/platform_infos_def.h"
#include "op_host/math_tiling_templates_registry.h"
#include "log/log.h"
#include "util/math_util.h"
#include "op_common/op_host/util/platform_util.h"
#include "random_tiling_base.h"
#include "random_tiling_arch35.h"
#include <cstdlib>
namespace optiling {
// Lower bound applied to the per-core element count (normalCoreProNum) in DoBlockTiling().
static constexpr int64_t MIN_CORE_PRO = 256;
int64_t TensorSliceState::GetMaxOffsetBytes() const
{
int64_t maxOffsetBytes = elementSize;
for (int64_t dim = 0; dim < ndim; dim++) {
maxOffsetBytes += (shape[dim] - 1) * strides[dim] * elementSize;
}
return maxOffsetBytes;
}
/**
 * Returns true when neither numel nor GetMaxOffsetBytes() exceeds INDEX_32BIT_LIMIT.
 */
bool TensorSliceState::Is32bitIndexable() const
{
    const int64_t limit = static_cast<int64_t>(INDEX_32BIT_LIMIT);
    // Short-circuit: the offset is only computed when the element count already fits.
    return !(numel > limit) && !(GetMaxOffsetBytes() > limit);
}
int64_t TensorSliceState::GetDimToSplit() const
{
int64_t maxExtent = -1, dimToSplit = -1;
for (int64_t dim = ndim - 1; dim >= 0; dim--) {
if (shape[dim] == 0)
continue;
int64_t extent = (shape[dim] - 1) * std::abs(strides[dim]) * elementSize;
if (extent > maxExtent) {
maxExtent = extent;
dimToSplit = dim;
}
}
return dimToSplit;
}
/**
 * Narrows dimension `dim` to `size` entries beginning at index `start`:
 * gmOffset advances by start * strides[dim], shape[dim] becomes size, and
 * numel is recomputed as the product of all shape entries.
 */
void TensorSliceState::ReduceDimExtent(int64_t dim, int64_t start, int64_t size)
{
    shape[dim] = size;
    gmOffset += start * strides[dim];

    numel = 1;
    int64_t axis = 0;
    while (axis < ndim) {
        numel *= shape[axis];
        ++axis;
    }
}
/**
 * Splits this slice in two along `dim`.
 * `other` receives the first shape[dim] / 2 entries (same gmOffset as this slice had);
 * this slice keeps the remaining entries, with gmOffset advanced accordingly.
 */
void TensorSliceState::PartitionDim(int64_t dim, TensorSliceState& other)
{
    const int64_t lowerSize = shape[dim] / 2;
    const int64_t upperSize = shape[dim] - lowerSize;

    // Clone the layout into `other` before either half is narrowed.
    other.ndim = ndim;
    other.elementSize = elementSize;
    other.gmOffset = gmOffset;
    for (int64_t axis = 0; axis < ndim; ++axis) {
        other.strides[axis] = strides[axis];
        other.shape[axis] = shape[axis];
    }

    other.ReduceDimExtent(dim, 0, lowerSize);
    ReduceDimExtent(dim, lowerSize, upperSize);
}
/**
 * Initializes `state` from an output shape: copies the dimensions, sets gmOffset to 0,
 * numel to `outputSize`, elementSize from the data type, and fills strides so that the
 * last dimension has stride 1 and each earlier stride is the product of the later sizes.
 *
 * @param state        slice state to fill
 * @param outputTensor shape whose dimensions are copied
 * @param outputSize   element count stored into state.numel
 * @param outputDtype  data type used to look up the element size
 * @return ge::GRAPH_SUCCESS, or ge::GRAPH_FAILED when the shape has more than
 *         MAX_TENSOR_DIMS dimensions or the element size is not positive. On failure
 *         state.ndim is left at 0 so the state never describes out-of-range dimensions.
 */
ge::graphStatus InitTensorSliceState(
    TensorSliceState& state, const gert::Shape& outputTensor, int64_t outputSize, ge::DataType outputDtype)
{
    const int64_t dimNum = static_cast<int64_t>(outputTensor.GetDimNum());
    const auto elementSize = ge::GetSizeByDataType(outputDtype);

    state.numel = outputSize;
    state.gmOffset = 0;
    state.elementSize = elementSize;

    // The original loop only filled MAX_TENSOR_DIMS entries but kept the larger ndim, so the
    // stride initialization below (and every later loop over ndim) indexed past that bound.
    // NOTE(review): assumes MAX_TENSOR_DIMS is the capacity of shape[]/strides[] - confirm.
    if (dimNum > static_cast<int64_t>(MAX_TENSOR_DIMS)) {
        state.ndim = 0;
        OP_LOGE(
            "RandomTiling", "dim num %ld exceeds supported max %ld", dimNum, static_cast<int64_t>(MAX_TENSOR_DIMS));
        return ge::GRAPH_FAILED;
    }
    // A non-positive element size makes every byte-extent computation meaningless.
    if (elementSize <= 0) {
        state.ndim = 0;
        OP_LOGE("RandomTiling", "unsupported element size %ld", static_cast<int64_t>(elementSize));
        return ge::GRAPH_FAILED;
    }

    state.ndim = dimNum;
    for (int64_t dim = 0; dim < dimNum; dim++) {
        state.shape[dim] = outputTensor.GetDim(static_cast<size_t>(dim));
    }
    if (dimNum > 0) {
        state.strides[dimNum - 1] = 1;
        for (int64_t dim = dimNum - 2; dim >= 0; dim--) {
            state.strides[dim] = state.shape[dim + 1] * state.strides[dim + 1];
        }
    }
    return ge::GRAPH_SUCCESS;
}
/**
 * Fills simtTilingData.splitBlocks with (numel, gmOffset) entries such that every entry
 * satisfies TensorSliceState::Is32bitIndexable().
 *
 * If `state` is already indexable a single block is emitted. Otherwise the slice is
 * repeatedly halved along GetDimToSplit() using an explicit stack until every piece is
 * indexable. state.gmOffset is reset to 0 as a side effect.
 *
 * @return ge::GRAPH_SUCCESS when the whole tensor is covered by the emitted blocks;
 *         ge::GRAPH_FAILED when a slice cannot be split further, the work stack would
 *         overflow, or MAX_SPLIT_BLOCKS blocks are not enough to cover the tensor.
 */
ge::graphStatus CalcSplitBlocks(TensorSliceState& state, RandomUnifiedSimtTilingDataStruct& simtTilingData)
{
    state.gmOffset = 0;
    simtTilingData.splitBlockCount = 0;

    if (state.Is32bitIndexable()) {
        simtTilingData.splitBlocks[0].numel = state.numel;
        simtTilingData.splitBlocks[0].gmOffset = state.gmOffset;
        simtTilingData.splitBlockCount = 1;
        return ge::GRAPH_SUCCESS;
    }

    TensorSliceState stack[MAX_SPLIT_BLOCKS];
    int32_t top = 0;
    stack[top++] = state;
    while (top > 0 && simtTilingData.splitBlockCount < MAX_SPLIT_BLOCKS) {
        TensorSliceState cur = stack[--top];
        if (cur.Is32bitIndexable()) {
            int64_t blockIdx = simtTilingData.splitBlockCount;
            simtTilingData.splitBlocks[blockIdx].numel = cur.numel;
            simtTilingData.splitBlocks[blockIdx].gmOffset = cur.gmOffset;
            simtTilingData.splitBlockCount++;
            continue;
        }

        const int64_t dim = cur.GetDimToSplit();
        // GetDimToSplit() returns -1 when no dimension qualifies, and a dimension of size 1
        // cannot be halved (one half would be empty and the other unchanged, so the loop
        // would never make progress). Both previously led to out-of-bounds access or to a
        // table full of empty blocks.
        if (dim < 0 || cur.shape[dim] < 2) {
            OP_LOGE("RandomTiling", "slice is not 32-bit indexable and cannot be split further");
            return ge::GRAPH_FAILED;
        }
        // Two entries are pushed below; make sure the stack can hold them.
        if (static_cast<int32_t>(MAX_SPLIT_BLOCKS) - top < 2) {
            OP_LOGE("RandomTiling", "split stack overflow while partitioning output");
            return ge::GRAPH_FAILED;
        }
        TensorSliceState other;
        cur.PartitionDim(dim, other);
        // `other` (the lower half) is pushed last so it is processed first.
        stack[top++] = cur;
        stack[top++] = other;
    }

    // Leaving the loop with pending slices means the block table filled up before the
    // tensor was fully covered; reporting success here would silently drop output elements.
    if (top > 0) {
        OP_LOGE("RandomTiling", "output needs more than the supported number of split blocks");
        return ge::GRAPH_FAILED;
    }
    return ge::GRAPH_SUCCESS;
}
ge::graphStatus CalcExecutionPoliciesForBlocks(RandomUnifiedSimtTilingDataStruct& simtTilingData, uint32_t unrollFactor)
{
int64_t accumulatedOffset = simtTilingData.offset;
if (simtTilingData.splitBlockCount > 1) {
int64_t totalNumel = simtTilingData.outputSize;
int64_t grid = (totalNumel + SIMT_THREAD_GROUP_SIZE - 1) / SIMT_THREAD_GROUP_SIZE;
int64_t blocksPerAic = MAX_THREADS_PER_AIC / SIMT_THREAD_GROUP_SIZE;
grid = (AIC_CLUSTER_COUNT * blocksPerAic < grid) ? AIC_CLUSTER_COUNT * blocksPerAic : grid;
int64_t totalThreads = grid * SIMT_THREAD_GROUP_SIZE;
int64_t threadsPerRound = totalThreads * unrollFactor;
int64_t roundsNeeded = (totalNumel + threadsPerRound - 1) / threadsPerRound;
int64_t counterOffset = roundsNeeded * MAX_PRNG_COUNTER_INCR;
accumulatedOffset += counterOffset;
}
for (int64_t i = 0; i < simtTilingData.splitBlockCount; i++) {
int64_t numel = simtTilingData.splitBlocks[i].numel;
int64_t grid = (numel + SIMT_THREAD_GROUP_SIZE - 1) / SIMT_THREAD_GROUP_SIZE;
int64_t blocksPerAic = MAX_THREADS_PER_AIC / SIMT_THREAD_GROUP_SIZE;
grid = (AIC_CLUSTER_COUNT * blocksPerAic < grid) ? AIC_CLUSTER_COUNT * blocksPerAic : grid;
int64_t totalThreads = grid * SIMT_THREAD_GROUP_SIZE;
int64_t threadsPerRound = totalThreads * unrollFactor;
int64_t roundsNeeded = (numel + threadsPerRound - 1) / threadsPerRound;
int64_t counterOffset = roundsNeeded * MAX_PRNG_COUNTER_INCR;
simtTilingData.splitBlocks[i].kernelOffset = accumulatedOffset;
simtTilingData.splitBlocks[i].grid = grid;
simtTilingData.splitBlocks[i].totalThreads = totalThreads;
accumulatedOffset += counterOffset;
}
return ge::GRAPH_SUCCESS;
}
/**
 * Tiling-parse entry: stores the AIV core count and UB size reported by the platform
 * into the operator's RandomOperatorCompileInfo.
 *
 * @param context      tiling-parse context providing compile info and platform info
 * @param operatorName operator name, used only for logging
 * @return ge::GRAPH_SUCCESS, or ge::GRAPH_FAILED when compile info or platform info is
 *         missing or the reported core count / UB size is not positive.
 */
ge::graphStatus RandomTilingParseArch35(gert::TilingParseContext* context, const std::string& operatorName)
{
    // A std::string object must not be passed to a "%s" conversion; pass the C string.
    OP_LOGD(context, "Entering RandomTilingArch35 operator name : %s", operatorName.c_str());
    auto compileInfo = context->GetCompiledInfo<RandomOperatorCompileInfo>();
    OP_CHECK_NULL_WITH_CONTEXT(context, compileInfo);
    auto platformInfo = context->GetPlatformInfo();
    OP_CHECK_NULL_WITH_CONTEXT(context, platformInfo);
    auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);
    compileInfo->totalCoreNum = ascendcPlatform.GetCoreNumAiv();
    // Initialized so that a query which leaves the value untouched is caught by the check below.
    uint64_t ubSizePlatForm = 0;
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSizePlatForm);
    compileInfo->ubSize = static_cast<int64_t>(ubSizePlatForm);
    OP_CHECK_IF(
        (compileInfo->totalCoreNum <= 0 || compileInfo->ubSize <= 0),
        OP_LOGE(
            context, "GetHardwareInfo Failed, vectorCoreNum:%ld, ubSize:%ld.",
            static_cast<int64_t>(compileInfo->totalCoreNum), static_cast<int64_t>(compileInfo->ubSize)),
        return ge::GRAPH_FAILED);
    // Explicit casts keep the arguments in sync with the %ld conversions.
    OP_LOGD(
        context, "Get totalCoreNum:%ld, ubSize:%ld", static_cast<int64_t>(compileInfo->totalCoreNum),
        static_cast<int64_t>(compileInfo->ubSize));
    return ge::GRAPH_SUCCESS;
}
/**
 * Reads the elements of `constTensor` as type T and stores them, in order, as the
 * dimensions of `constShape` (which is reset to zero dimensions first).
 *
 * @return ge::GRAPH_SUCCESS; the null-check macro handles a missing data pointer.
 */
template <typename T>
static inline ge::graphStatus GetIntValue(
    const gert::TilingContext* context, const gert::Tensor* constTensor, gert::Shape& constShape)
{
    const T* values = constTensor->GetData<T>();
    OP_CHECK_NULL_WITH_CONTEXT(context, values);

    const size_t count = constTensor->GetShapeSize();
    constShape.SetDimNum(0);
    size_t pos = 0;
    while (pos < count) {
        constShape.AppendDim(values[pos]);
        ++pos;
    }
    return ge::GRAPH_SUCCESS;
}
/**
 * Reads the required input tensor at `constIdx` (int32 or int64 data) and stores its
 * values as the dimensions of `constShape`.
 *
 * @param context    tiling context that owns the input
 * @param constIdx   index of the required input to read
 * @param constShape receives the tensor values as dimensions
 * @return ge::GRAPH_SUCCESS, or ge::GRAPH_FAILED when the tensor or its descriptor is
 *         missing, the data type is not int32/int64, or the values cannot be read.
 */
ge::graphStatus ExtractTensorValue(const gert::TilingContext* context, const int64_t constIdx, gert::Shape& constShape)
{
    auto constTensor = context->GetRequiredInputTensor(constIdx);
    OP_CHECK_NULL_WITH_CONTEXT(context, constTensor);
    auto inputDescPtr = context->GetRequiredInputDesc(constIdx);
    OP_CHECK_NULL_WITH_CONTEXT(context, inputDescPtr);
    auto constDtype = inputDescPtr->GetDataType();
    auto ret = ge::GRAPH_FAILED;
    switch (constDtype) {
        case ge::DT_INT32:
            ret = GetIntValue<int32_t>(context, constTensor, constShape);
            break;
        case ge::DT_INT64:
            ret = GetIntValue<int64_t>(context, constTensor, constShape);
            break;
        default:
            // This path fails the tiling, so it is reported at error level; at debug level
            // the cause of the failure was invisible in normal logs.
            OP_LOGE(
                context->GetNodeName(), "ExtractTensorValue only support [int32, int64]. but is %s",
                Ops::Base::ToString(constDtype).c_str());
            return ge::GRAPH_FAILED;
    }
    OP_CHECK_IF(
        ret != ge::GRAPH_SUCCESS, OP_LOGE(context, "get const value failed, please check."), return ge::GRAPH_FAILED);
    OP_LOGI(context->GetNodeName(), "current const value is %s", Ops::Base::ToString(constShape).c_str());
    return ge::GRAPH_SUCCESS;
}
/**
 * Stores the tiling context and the operator configuration; both tiling-data
 * members are value-initialized.
 */
RandomTilingArch35::RandomTilingArch35(gert::TilingContext* context, const OpTilingConfig& config)
    : context_(context),
      config_(config),
      tilingData_{},
      simtTilingData_{}
{
}
/**
 * Runs the tiling pipeline in order and stops at the first failing step:
 * input/output/attr checks, platform info, BeforeProcess(), tiling-data fill
 * (SIMD or SIMT depending on config_.kernelMode), tiling key / workspace,
 * UniqueProcess(), and write-back to the context.
 *
 * @return ge::GRAPH_SUCCESS when every step succeeds, otherwise the status of the
 *         failing step.
 */
ge::graphStatus RandomTilingArch35::DoTiling()
{
    opName_ = context_->GetNodeName();
    OP_LOGD(opName_, "Start tiling for op: %s", opName_.c_str());
    const bool isSimd = (config_.kernelMode == RandomKernelMode::SIMD);

    auto ret = CheckInputsOutputsAndAttrs();
    if (ret != ge::GRAPH_SUCCESS) {
        OP_LOGE(opName_, "Check inputs/outputs/attrs failed");
        return ret;
    }
    ret = GetPlatformInfo();
    if (ret != ge::GRAPH_SUCCESS) {
        OP_LOGE(opName_, "Get platform info failed");
        return ret;
    }
    ret = BeforeProcess();
    if (ret != ge::GRAPH_SUCCESS) {
        OP_LOGE(opName_, "Before process failed");
        return ret;
    }
    ret = isSimd ? FillUnifiedTilingData() : FillUnifiedSimtTilingData();
    if (ret != ge::GRAPH_SUCCESS) {
        OP_LOGE(opName_, "Fill tiling data failed");
        return ret;
    }
    ret = CalcTilingKeyAndWorkspace();
    if (ret != ge::GRAPH_SUCCESS) {
        OP_LOGE(opName_, "Calc tiling key/workspace failed");
        return ret;
    }
    ret = UniqueProcess();
    if (ret != ge::GRAPH_SUCCESS) {
        OP_LOGE(opName_, "Unique process failed");
        return ret;
    }
    ret = WriteBackToContext();
    if (ret != ge::GRAPH_SUCCESS) {
        OP_LOGE(opName_, "Write tiling data to context failed");
        return ret;
    }
    OP_LOGI(
        "RandomTiling", "%s",
        (isSimd ? tilingData_.DumpTilingInfo() : simtTilingData_.DumpTilingInfo()).c_str());
    // "%s" needs a C string; the string object itself was passed here before.
    OP_LOGD(opName_, "Tiling success for op: %s", opName_.c_str());
    return ge::GRAPH_SUCCESS;
}
/**
 * Validates the operator against the rules in config_:
 *  - every required input and every output listed in the rules must have a descriptor
 *    and a shape, and must pass CheckTensor();
 *  - optional inputs are checked only when both descriptor and shape are present;
 *  - the attribute container must exist and every attribute check callback must return true.
 *
 * @return ge::GRAPH_SUCCESS when all checks pass, otherwise the first failing status.
 */
ge::graphStatus RandomTilingArch35::CheckInputsOutputsAndAttrs()
{
    OP_LOGI(opName_, "TilingContext: %s", RandomUtils::GetTilingContext(context_).c_str());

    // Required inputs.
    for (const auto& [idx, rule] : config_.inputCheckRules) {
        auto desc = context_->GetRequiredInputDesc(idx);
        OP_CHECK_NULL_WITH_CONTEXT(context_, desc);
        auto shapePtr = context_->GetRequiredInputShape(idx);
        OP_CHECK_NULL_WITH_CONTEXT(context_, shapePtr);
        const auto& storage = shapePtr->GetStorageShape();
        const auto status = CheckTensor(desc, storage, rule, "input_" + std::to_string(idx));
        if (status != ge::GRAPH_SUCCESS) {
            return status;
        }
    }

    // Optional inputs: silently skipped when absent.
    for (const auto& [idx, rule] : config_.optionalInputCheckRules) {
        auto desc = context_->GetOptionalInputDesc(idx);
        auto shapePtr = context_->GetOptionalInputShape(idx);
        if (desc == nullptr || shapePtr == nullptr) {
            continue;
        }
        const auto& storage = shapePtr->GetStorageShape();
        const auto status = CheckTensor(desc, storage, rule, "optional_input_" + std::to_string(idx));
        if (status != ge::GRAPH_SUCCESS) {
            return status;
        }
    }

    // Outputs.
    for (const auto& [idx, rule] : config_.outputCheckRules) {
        auto desc = context_->GetOutputDesc(idx);
        OP_CHECK_NULL_WITH_CONTEXT(context_, desc);
        auto shapePtr = context_->GetOutputShape(idx);
        OP_CHECK_NULL_WITH_CONTEXT(context_, shapePtr);
        const auto& storage = shapePtr->GetStorageShape();
        const auto status = CheckTensor(desc, storage, rule, "output_" + std::to_string(idx));
        if (status != ge::GRAPH_SUCCESS) {
            return status;
        }
    }

    // Attributes: the container must exist; each rule supplies its own predicate.
    auto attrs = context_->GetAttrs();
    OP_CHECK_NULL_WITH_CONTEXT(context_, attrs);
    for (const auto& rule : config_.attrCheckRules) {
        const auto& checkFunc = rule.second;
        if (checkFunc(context_)) {
            continue;
        }
        OP_LOGE(context_->GetNodeName(), "Attr %d check failed", rule.first);
        return ge::GRAPH_FAILED;
    }
    return ge::GRAPH_SUCCESS;
}
/**
 * Loads totalCoreNum_ and ubSize_, then reserves config_.DcacheSize bytes out of ubSize_.
 *
 * When the context has no platform info the values come from the compile info;
 * otherwise they are queried from the platform (AIV core count and UB size).
 *
 * @return ge::GRAPH_SUCCESS, or ge::GRAPH_FAILED when compile info is missing, the core
 *         count is not positive, or the UB size does not exceed the DCache size.
 */
ge::graphStatus RandomTilingArch35::GetPlatformInfo()
{
    auto platformInfo = context_->GetPlatformInfo();
    if (platformInfo == nullptr) {
        auto compileInfo = reinterpret_cast<const RandomOperatorCompileInfo*>(context_->GetCompileInfo());
        OP_CHECK_NULL_WITH_CONTEXT(context_, compileInfo);
        totalCoreNum_ = static_cast<int64_t>(compileInfo->totalCoreNum);
        ubSize_ = compileInfo->ubSize;
    } else {
        auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);
        auto aivNum = ascendcPlatform.GetCoreNumAiv();
        OP_CHECK_IF(
            (aivNum <= 0), OP_LOGE(opName_, "RandomTilingArch35 fail to get coreNum."), return ge::GRAPH_FAILED);
        totalCoreNum_ = aivNum;
        uint64_t ubSizePlatForm = 0;
        ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSizePlatForm);
        ubSize_ = ubSizePlatForm;
    }
    // Compare before subtracting: if ubSize_ is an unsigned type, subtracting first would
    // wrap around and the "<= 0" test would no longer detect an oversized DCache.
    OP_CHECK_IF(
        (static_cast<int64_t>(ubSize_) <= static_cast<int64_t>(config_.DcacheSize)),
        OP_LOGE(opName_, "ub size less than Dcache Size. please check."), return ge::GRAPH_FAILED);
    ubSize_ -= config_.DcacheSize;
    // Both values are 64-bit; "%d" did not match the argument width.
    OP_LOGI(
        opName_, "RandomTilingArch35::GetPlatformInfo ubSize_=%ld, totalCoreNum_=%ld", static_cast<int64_t>(ubSize_),
        static_cast<int64_t>(totalCoreNum_));
    return ge::GRAPH_SUCCESS;
}
/**
 * Chooses simtTilingData_.usedCoreNum for the SIMT kernel: the output is divided evenly
 * over totalCoreNum_, the per-core share is rounded up to config_.coreAlignSize, and the
 * number of cores needed for that share is capped at totalCoreNum_.
 *
 * @return ge::GRAPH_FAILED when totalCoreNum_ is not positive or coreAlignSize is 0.
 */
ge::graphStatus RandomTilingArch35::DoSimtBlockTiling()
{
    OP_CHECK_IF(
        (totalCoreNum_ <= 0), OP_LOGE(opName_, "totalCoreNum is less than or equal to 0. please check."),
        return ge::GRAPH_FAILED);
    OP_CHECK_IF(
        (config_.coreAlignSize == 0), OP_LOGE(opName_, "coreAlignSize is equal to 0. please check."),
        return ge::GRAPH_FAILED);

    // Even share per core, before alignment.
    const int64_t evenShare = Ops::Base::CeilDiv(simtTilingData_.outputSize, totalCoreNum_);
    // Share rounded up to the configured alignment.
    const int64_t alignedShare = Ops::Base::CeilAlign(evenShare, config_.coreAlignSize);
    // Cores required when each one handles an aligned share.
    const int64_t coresNeeded = Ops::Base::CeilDiv(simtTilingData_.outputSize, alignedShare);

    simtTilingData_.usedCoreNum = std::min(totalCoreNum_, coresNeeded);
    return ge::GRAPH_SUCCESS;
}
/**
 * Fills simtTilingData_ for the SIMT kernel: output size, seed/offset, used core count
 * and, when config_.enableSplitBlocks is set, the split-block table and its per-block
 * execution policies for the output at config_.splitOutputIndex.
 *
 * @return ge::GRAPH_SUCCESS, or the status of the first failing step.
 */
ge::graphStatus RandomTilingArch35::FillUnifiedSimtTilingData()
{
    auto ret = config_.getOutputSize(context_, simtTilingData_.outputSize);
    if (ret != ge::GRAPH_SUCCESS) {
        return ret;
    }
    OP_CHECK_IF(
        (simtTilingData_.outputSize < 0), OP_LOGE(opName_, "outputSize is less than 0. please check."),
        return ge::GRAPH_FAILED);
    ret = config_.getSeedAndOffset(context_, simtTilingData_.seed, simtTilingData_.offset);
    if (ret != ge::GRAPH_SUCCESS) {
        return ret;
    }
    ret = DoSimtBlockTiling();
    if (ret != ge::GRAPH_SUCCESS) {
        return ret;
    }
    if (!config_.enableSplitBlocks) {
        return ge::GRAPH_SUCCESS;
    }

    auto outputShape = context_->GetOutputShape(config_.splitOutputIndex);
    OP_CHECK_NULL_WITH_CONTEXT(context_, outputShape);
    // Bind by reference; the shape is only read.
    const auto& outputTensor = outputShape->GetStorageShape();
    auto outputDesc = context_->GetOutputDesc(config_.splitOutputIndex);
    OP_CHECK_NULL_WITH_CONTEXT(context_, outputDesc);
    auto outputDtype = outputDesc->GetDataType();

    TensorSliceState state;
    // The status was previously discarded, so a failed initialization would have been
    // followed by splitting an invalid slice state.
    ret = InitTensorSliceState(state, outputTensor, simtTilingData_.outputSize, outputDtype);
    if (ret != ge::GRAPH_SUCCESS) {
        OP_LOGE(opName_, "Init tensor slice state failed");
        return ret;
    }
    ret = CalcSplitBlocks(state, simtTilingData_);
    if (ret != ge::GRAPH_SUCCESS) {
        return ret;
    }
    ret = CalcExecutionPoliciesForBlocks(simtTilingData_, config_.unrollFactor);
    if (ret != ge::GRAPH_SUCCESS) {
        return ret;
    }
    return ge::GRAPH_SUCCESS;
}
/**
 * Fills tilingData_ for the SIMD kernel: output size (must be positive), key/counter,
 * buffer count, block tiling, UB tiling, and finally the shared temporary buffer size.
 *
 * @return ge::GRAPH_SUCCESS, or the status of the first failing step.
 */
ge::graphStatus RandomTilingArch35::FillUnifiedTilingData()
{
    if (const auto status = config_.getOutputSize(context_, tilingData_.outputSize); status != ge::GRAPH_SUCCESS) {
        return status;
    }
    OP_CHECK_IF(
        (tilingData_.outputSize <= 0), OP_LOGE(opName_, "outputSize is less than or equal to 0. please check."),
        return ge::GRAPH_FAILED);

    if (const auto status = config_.getKeyAndCounter(context_, tilingData_.key, tilingData_.counter);
        status != ge::GRAPH_SUCCESS) {
        return status;
    }
    if (const auto status = config_.getBufferNum(context_, bufNum_); status != ge::GRAPH_SUCCESS) {
        return status;
    }
    if (const auto status = DoBlockTiling(); status != ge::GRAPH_SUCCESS) {
        return status;
    }
    if (const auto status = DoUbTiling(); status != ge::GRAPH_SUCCESS) {
        return status;
    }

    tilingData_.sharedTmpBufSize = config_.sharedTmpBufSize;
    return ge::GRAPH_SUCCESS;
}
/**
 * Splits the output across cores for the SIMD kernel.
 * normalCoreProNum is the even share per core, rounded up to config_.coreAlignSize and
 * raised to at least MIN_CORE_PRO; usedCoreNum is the number of such shares needed;
 * tailCoreProNum is what remains for the last core.
 *
 * @return ge::GRAPH_FAILED when totalCoreNum_ is not positive or coreAlignSize is 0.
 */
ge::graphStatus RandomTilingArch35::DoBlockTiling()
{
    OP_CHECK_IF(
        (totalCoreNum_ <= 0), OP_LOGE(opName_, "totalCoreNum is less than or equal to 0. please check."),
        return ge::GRAPH_FAILED);
    tilingData_.normalCoreProNum = Ops::Base::CeilDiv(tilingData_.outputSize, totalCoreNum_);

    OP_CHECK_IF(
        (config_.coreAlignSize == 0), OP_LOGE(opName_, "coreAlignSize is equal to 0. please check."),
        return ge::GRAPH_FAILED);

    // Round the even share up to a multiple of the alignment, then apply the floor.
    const int64_t alignedShare =
        (tilingData_.normalCoreProNum + config_.coreAlignSize - 1) / config_.coreAlignSize * config_.coreAlignSize;
    tilingData_.normalCoreProNum = std::max(alignedShare, MIN_CORE_PRO);

    tilingData_.usedCoreNum = Ops::Base::CeilDiv(tilingData_.outputSize, tilingData_.normalCoreProNum);
    // Everything not covered by the first (usedCoreNum - 1) cores goes to the last one.
    tilingData_.tailCoreProNum = tilingData_.outputSize - tilingData_.normalCoreProNum * (tilingData_.usedCoreNum - 1);
    return ge::GRAPH_SUCCESS;
}
/**
 * Computes tilingData_.singleBufferSize: ubSize_ (after removing config_.sharedTmpBufSize)
 * divided by bufNum_, rounded down to a multiple of config_.ubAlignSize, or of the UB
 * block size when ubAlignSize is 0.
 *
 * @return ge::GRAPH_FAILED when bufNum_ is 0 or the UB block size is 0.
 */
ge::graphStatus RandomTilingArch35::DoUbTiling()
{
    ubSize_ -= config_.sharedTmpBufSize;
    OP_CHECK_IF((bufNum_ == 0), OP_LOGE(opName_, "bufNum_ is equal to 0. please check."), return ge::GRAPH_FAILED);
    tilingData_.singleBufferSize = ubSize_ / bufNum_;

    // An explicit alignment from the configuration takes precedence.
    if (config_.ubAlignSize != 0) {
        tilingData_.singleBufferSize = tilingData_.singleBufferSize / config_.ubAlignSize * config_.ubAlignSize;
        return ge::GRAPH_SUCCESS;
    }

    // Otherwise align down to the UB block size.
    auto ubBlockSize = Ops::Base::GetUbBlockSize(context_);
    OP_CHECK_IF(
        (ubBlockSize == 0), OP_LOGE(opName_, "ubBlockSize is equal to 0. please check."), return ge::GRAPH_FAILED);
    tilingData_.singleBufferSize = tilingData_.singleBufferSize / ubBlockSize * ubBlockSize;
    return ge::GRAPH_SUCCESS;
}
/**
 * Default policy: tiling key 100 and no workspace. Always succeeds.
 */
ge::graphStatus RandomTilingArch35::CalcTilingKeyAndWorkspace()
{
    constexpr uint64_t kDefaultTilingKey = 100;
    tilingKey_ = kDefaultTilingKey;
    workspaceSize_ = 0;
    return ge::GRAPH_SUCCESS;
}
/**
 * Publishes the tiling result to the context: workspace size, block dim, schedule mode,
 * tiling key, local memory size (only when config_.DcacheSize is non-zero) and the
 * tiling-data struct matching config_.kernelMode.
 *
 * @return ge::GRAPH_FAILED when the workspace array or tiling-data buffer is unavailable
 *         or SetLocalMemorySize fails.
 */
ge::graphStatus RandomTilingArch35::WriteBackToContext()
{
    const bool isSimd = (config_.kernelMode == RandomKernelMode::SIMD);

    size_t* workspaceArr = context_->GetWorkspaceSizes(1);
    OP_CHECK_NULL_WITH_CONTEXT(context_, workspaceArr);
    workspaceArr[0] = workspaceSize_;

    context_->SetBlockDim(isSimd ? tilingData_.usedCoreNum : simtTilingData_.usedCoreNum);
    context_->SetScheduleMode(config_.isNeedSyncAll);
    context_->SetTilingKey(tilingKey_);

    if (config_.DcacheSize != 0) {
        auto status = context_->SetLocalMemorySize(ubSize_);
        OP_CHECK_IF(
            (status != ge::GRAPH_SUCCESS),
            OP_LOGE(opName_, "SetLocalMemorySize ubSize = %ld failed.", static_cast<int64_t>(ubSize_)),
            return ge::GRAPH_FAILED);
    }

    if (!isSimd) {
        auto* simtOut = context_->GetTilingData<RandomUnifiedSimtTilingDataStruct>();
        OP_CHECK_IF(simtOut == nullptr, OP_LOGE(opName_, "simtTilingData ptr is null"), return ge::GRAPH_FAILED);
        *simtOut = simtTilingData_;
        return ge::GRAPH_SUCCESS;
    }

    auto* simdOut = context_->GetTilingData<RandomUnifiedTilingDataStruct>();
    OP_CHECK_IF(simdOut == nullptr, OP_LOGE(opName_, "tilingData ptr is null"), return ge::GRAPH_FAILED);
    *simdOut = tilingData_;
    return ge::GRAPH_SUCCESS;
}
/**
 * Validates one tensor against `rule`. Each constraint is skipped when the rule leaves it
 * unset: an empty dtypeSet, shapeSize == -1, an empty dimNumSet, or no customCheck callback.
 *
 * @param tensorDesc  descriptor providing the data type
 * @param tensorShape shape providing element count and rank
 * @param rule        constraints to apply
 * @param tensorName  label used in error logs
 * @return ge::GRAPH_SUCCESS when every active constraint holds, else ge::GRAPH_FAILED.
 */
ge::graphStatus RandomTilingArch35::CheckTensor(
    const gert::CompileTimeTensorDesc* tensorDesc, const gert::Shape& tensorShape, const TensorCheckRule& rule,
    const std::string& tensorName)
{
    // Data type must be one of the allowed types.
    const bool dtypeRestricted = !rule.dtypeSet.empty();
    if (dtypeRestricted && rule.dtypeSet.count(tensorDesc->GetDataType()) == 0) {
        OP_LOGE(
            context_->GetNodeName(), "Tensor %s dtype %d not in allowed set", tensorName.c_str(),
            tensorDesc->GetDataType());
        return ge::GRAPH_FAILED;
    }

    // Element count must match exactly when the rule specifies one.
    auto elemCount = tensorShape.GetShapeSize();
    const bool sizeRestricted = (rule.shapeSize != -1);
    if (sizeRestricted && elemCount != rule.shapeSize) {
        OP_LOGE(
            context_->GetNodeName(), "Tensor %s shape size %ld not match required %ld", tensorName.c_str(), elemCount,
            rule.shapeSize);
        return ge::GRAPH_FAILED;
    }

    // Rank must be one of the allowed ranks.
    auto rank = tensorShape.GetDimNum();
    const bool rankRestricted = !rule.dimNumSet.empty();
    if (rankRestricted && rule.dimNumSet.count(rank) == 0) {
        OP_LOGE(context_->GetNodeName(), "Tensor %s dim num %lu not in allowed set", tensorName.c_str(), rank);
        return ge::GRAPH_FAILED;
    }

    // Optional caller-supplied predicate.
    if (rule.customCheck) {
        if (!rule.customCheck(context_)) {
            OP_LOGE(context_->GetNodeName(), "Tensor %s custom check failed", tensorName.c_str());
            return ge::GRAPH_FAILED;
        }
    }
    return ge::GRAPH_SUCCESS;
}
}