/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file random_infershape_base.cpp
* \brief Common InferShape helpers in namespace ops::randomCommon.
*/
#include "random_infershape_base.h"
namespace ops {
namespace randomCommon {
/**
 * @brief Copy dimension values from a shape tensor's data buffer into an output shape.
 *
 * @tparam T Element type of the buffer (this file instantiates it with int32_t and int64_t).
 * @param outputShape Shape that receives the dimensions; its dim count is set to xShapeSize.
 * @param xShapeSize  Number of elements read from xShapeData.
 * @param xShapeData  Buffer holding at least xShapeSize elements.
 * @return ge::GRAPH_SUCCESS after all dimensions are written; ge::GRAPH_FAILED when xShapeData is
 *         nullptr while xShapeSize is non-zero (outputShape is left untouched in that case).
 */
template <typename T>
ge::graphStatus HandleShapeTensor(gert::Shape& outputShape, size_t xShapeSize, const T* xShapeData)
{
    // Never index a missing buffer. A zero-sized request is still valid with a null pointer
    // because the loop below does not execute.
    if (xShapeData == nullptr && xShapeSize != 0U) {
        return ge::GRAPH_FAILED;
    }

    outputShape.SetDimNum(xShapeSize);
    for (size_t dimIdx = 0U; dimIdx < xShapeSize; ++dimIdx) {
        outputShape.SetDim(dimIdx, xShapeData[dimIdx]);
    }
    return ge::GRAPH_SUCCESS;
}
/**
 * @brief Propagate "unknown" markers to the outputs when the input shape is not fully known.
 *
 * If Ops::Base::IsUnknownRank(inShape) holds, outShape and every present auxiliary output are
 * marked with Ops::Base::SetUnknownRank. Otherwise, if Ops::Base::IsUnknownShape(inShape) holds,
 * outShape is marked with SetUnknownShape using inShape's dim count and every present auxiliary
 * output is marked with SetUnknownShape using a dim count of 1.
 *
 * @param context     Inference context used to fetch the auxiliary output shapes.
 * @param inShape     Input shape to inspect.
 * @param outShape    Primary output shape to mark.
 * @param maskIndex   Output index of the mask output; a negative value means "not present".
 * @param offsetIndex Output index of the offset output; a negative value means "not present".
 * @return true when one of the two unknown cases was handled, false when inShape is neither.
 */
bool InferShapeForUnknow(
    gert::InferShapeContext* context, const gert::Shape& inShape, gert::Shape& outShape, int64_t& maskIndex,
    int64_t& offsetIndex)
{
    // The rank test runs first; the shape test is evaluated only when the rank is known.
    const bool rankUnknown = Ops::Base::IsUnknownRank(inShape);
    if (!rankUnknown && !Ops::Base::IsUnknownShape(inShape)) {
        return false;
    }

    // Marks one auxiliary output with the same kind of "unknown" as the primary output.
    const auto markAuxOutput = [context, rankUnknown](int64_t auxIndex) {
        if (auxIndex < 0) {
            return;
        }
        gert::Shape* auxShape = context->GetOutputShape(auxIndex);
        if (rankUnknown) {
            Ops::Base::SetUnknownRank(*auxShape);
        } else {
            Ops::Base::SetUnknownShape(1, *auxShape);
        }
    };

    if (rankUnknown) {
        Ops::Base::SetUnknownRank(outShape);
    } else {
        Ops::Base::SetUnknownShape(inShape.GetDimNum(), outShape);
    }
    markAuxOutput(maskIndex);
    markAuxOutput(offsetIndex);
    return true;
}
/**
 * @brief Fill outShape from the shape tensor's data, viewed as elements of type T.
 *
 * When the tensor exposes no data pointer, outShape is set via Ops::Base::SetUnknownShape with
 * xShapeSize dims and the call still counts as handled.
 *
 * @tparam T        Element type used to read the tensor data.
 * @param inTensor   Non-null shape tensor.
 * @param outShape   Shape to populate.
 * @param xShapeSize Number of dimensions to produce.
 * @param dtypeName  Human-readable dtype name used only in diagnostics.
 * @return true when outShape was populated (or marked unknown), false when HandleShapeTensor failed.
 */
template <typename T>
static bool FillShapeFromTensorData(
    const gert::Tensor* inTensor, gert::Shape& outShape, size_t xShapeSize, const char* dtypeName)
{
    const T* dimValues = inTensor->GetData<T>();
    if (dimValues == nullptr) {
        // The message states what the code actually does: the rank is kept and every dim is unknown.
        std::cerr << "[WARN] Empty " << dtypeName << " shape tensor, set unknown-dim output of rank "
                  << xShapeSize << '\n';
        Ops::Base::SetUnknownShape(xShapeSize, outShape);
        return true;
    }
    if (HandleShapeTensor<T>(outShape, xShapeSize, dimValues) != ge::GRAPH_SUCCESS) {
        std::cerr << "[ERROR] Failed to build output shape from " << dtypeName << " shape tensor" << '\n';
        return false;
    }
    return true;
}

/**
 * @brief Derive the output shape from the values stored in a shape tensor.
 *
 * Only ge::DT_INT32 and ge::DT_INT64 tensors are accepted.
 *
 * @param inTensor   Tensor whose data holds the dimension values.
 * @param outShape   Shape to populate.
 * @param xShapeSize Number of dimension values to read from the tensor.
 * @return true when outShape was populated or marked unknown; false for a null tensor, an
 *         unsupported dtype, or a failure while copying the dimensions.
 */
bool DependencyMode(const gert::Tensor* inTensor, gert::Shape& outShape, size_t xShapeSize)
{
    if (inTensor == nullptr) {
        std::cerr << "[ERROR] Shape tensor is nullptr" << '\n';
        return false;
    }

    const ge::DataType shapeDtype = inTensor->GetDataType();
    if (shapeDtype == ge::DT_INT32) {
        return FillShapeFromTensorData<int32_t>(inTensor, outShape, xShapeSize, "DT_INT32");
    }
    if (shapeDtype == ge::DT_INT64) {
        return FillShapeFromTensorData<int64_t>(inTensor, outShape, xShapeSize, "DT_INT64");
    }

    // Reached only for a dtype that really is unsupported.
    std::cerr << "[ERROR] Unsupported dtype: " << static_cast<int>(shapeDtype) << '\n';
    return false;
}
/**
 * @brief Verify that all required inputs and all outputs are available, and locate optional outputs.
 *
 * Every index in requiredInputMap must yield a non-null shape from GetRequiredInputShape, and every
 * index in outputMap must yield a non-null shape from GetOutputShape. Optional inputs are only
 * probed and logged; their absence is not an error.
 *
 * @param context          Inference context to query.
 * @param requiredInputMap Name -> input index of the required inputs.
 * @param outputMap        Name -> output index of the outputs. The entries named "mask" and "offset",
 *                         when present, are reported through maskIndex / offsetIndex.
 * @param maskIndex        Receives the index of the "mask" output; left untouched if there is none.
 * @param offsetIndex      Receives the index of the "offset" output; left untouched if there is none.
 * @param optionalInputMap Name -> input index of the optional inputs.
 * @return true when every required input and every output is present, false otherwise.
 */
bool InputAndOutputCheck(
    gert::InferShapeContext* context, const std::unordered_map<std::string, size_t>& requiredInputMap,
    const std::unordered_map<std::string, size_t>& outputMap, int64_t& maskIndex, int64_t& offsetIndex,
    const std::unordered_map<std::string, size_t>& optionalInputMap)
{
    if (context == nullptr) {
        return false;
    }
    OP_LOGD(context->GetNodeName(), "InputAndOutputCheck start");

    // NOTE(review): OP_CHECK_NULL_WITH_CONTEXT is deliberately not used in this function. In CANN it
    // returns a non-zero ge::graphStatus on failure, which converts to `true` in a bool-returning
    // function and would report a missing tensor as a passed check. Confirm against the macro
    // definition used by this build.
    for (const auto& item : requiredInputMap) {
        const size_t inputIndex = item.second;
        if (context->GetRequiredInputShape(inputIndex) == nullptr) {
            OP_LOGE(
                context->GetNodeName(), "Required input %s (index %zu) shape is nullptr!", item.first.c_str(),
                inputIndex);
            return false;
        }
    }

    for (const auto& item : optionalInputMap) {
        const size_t inputIndex = item.second;
        if (context->GetOptionalInputShape(inputIndex) != nullptr) {
            OP_LOGD(context->GetNodeName(), "Optional input %zu is provided", inputIndex);
        }
    }

    for (const auto& item : outputMap) {
        const std::string& outputName = item.first;
        const size_t outputIndex = item.second;
        if (context->GetOutputShape(outputIndex) == nullptr) {
            OP_LOGE(
                context->GetNodeName(), "Output %s (index %zu) shape is nullptr!", outputName.c_str(), outputIndex);
            return false;
        }
        if (outputName == "mask") {
            maskIndex = static_cast<int64_t>(outputIndex);
        }
        if (outputName == "offset") {
            offsetIndex = static_cast<int64_t>(outputIndex);
        }
    }

    OP_LOGD(
        context->GetNodeName(), "InputAndOutputCheck end, maskIndex = %ld, offsetIndex = %ld", maskIndex, offsetIndex);
    return true;
}
/**
 * @brief Shared InferShape entry point: validates inputs/outputs and fills the output shapes.
 *
 * Steps, in order:
 *  1. Reject a null context and run InputAndOutputCheck.
 *  2. If input 0 has an unknown rank/shape, let InferShapeForUnknow mark the outputs and succeed.
 *  3. If a "mask" output exists, make it 1-D with the input element count rounded up to a multiple
 *     of 128 and then divided by 8.
 *  4. If an "offset" output exists, make it 1-D with a single element.
 *  5. MODE_NO_DEPENDENCY: output 0 takes the shape of input 0.
 *     MODE_DEPENDENCY: output 0 is built from the values of input tensor 0 via DependencyMode.
 *
 * @param context          Inference context.
 * @param requiredInputMap Name -> index of the required inputs.
 * @param outputMap        Name -> index of the outputs.
 * @param mode             MODE_NO_DEPENDENCY or MODE_DEPENDENCY; any other value fails.
 * @param optionalInputMap Name -> index of the optional inputs.
 * @return ge::GRAPH_SUCCESS when the output shapes were inferred, ge::GRAPH_FAILED otherwise.
 */
ge::graphStatus CommonInferShape(
    gert::InferShapeContext* context, const std::unordered_map<std::string, size_t>& requiredInputMap,
    const std::unordered_map<std::string, size_t>& outputMap, int32_t mode,
    const std::unordered_map<std::string, size_t>& optionalInputMap)
{
    if (context == nullptr) {
        return ge::GRAPH_FAILED;
    }

    int64_t maskIndex = -1;
    int64_t offsetIndex = -1;
    const bool ioReady =
        InputAndOutputCheck(context, requiredInputMap, outputMap, maskIndex, offsetIndex, optionalInputMap);
    if (!ioReady) {
        return ge::GRAPH_FAILED;
    }

    // The pointer names are kept as-is: the null-check macro may embed them in its log output.
    const gert::Shape* inShape = context->GetInputShape(0);
    OP_CHECK_NULL_WITH_CONTEXT(context, inShape);
    gert::Shape* outShape = context->GetOutputShape(0);
    OP_CHECK_NULL_WITH_CONTEXT(context, outShape);

    const int64_t elementCount = inShape->GetShapeSize();

    if (InferShapeForUnknow(context, *inShape, *outShape, maskIndex, offsetIndex)) {
        OP_LOGI(context->GetNodeName(), "Success unknown shape");
        return ge::GRAPH_SUCCESS;
    }

    if (maskIndex >= 0) {
        constexpr int64_t maskAlignSize = 128;
        constexpr int64_t maskBitToUint8 = 8;
        // Round up to the alignment first, then convert to the mask element count.
        const int64_t alignedCount = (elementCount + maskAlignSize - 1) / maskAlignSize * maskAlignSize;
        gert::Shape* maskShape = context->GetOutputShape(maskIndex);
        maskShape->SetDimNum(1);
        maskShape->SetDim(0, alignedCount / maskBitToUint8);
    }

    if (offsetIndex >= 0) {
        gert::Shape* offsetShape = context->GetOutputShape(offsetIndex);
        offsetShape->SetDimNum(1);
        offsetShape->SetDim(0, 1);
    }

    if (mode == MODE_NO_DEPENDENCY) {
        *outShape = *inShape;
        OP_LOGI(context->GetNodeName(), "Success no dependency Mode");
        return ge::GRAPH_SUCCESS;
    }

    if (mode == MODE_DEPENDENCY) {
        const gert::Tensor* inTensor = context->GetInputTensor(0);
        OP_CHECK_NULL_WITH_CONTEXT(context, inTensor);
        const bool shapeBuilt = DependencyMode(inTensor, *outShape, static_cast<size_t>(elementCount));
        if (shapeBuilt) {
            return ge::GRAPH_SUCCESS;
        }
    }

    OP_LOGE(context->GetNodeName(), "Failed to infer shape! mode = %d, xShapeSize=%ld", mode, elementCount);
    return ge::GRAPH_FAILED;
}
}
}