* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License")
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file broadcast_to_infershape.cpp
* \brief Shape inference for the BroadcastTo operator.
*/
#include "log/log.h"
#include "register/op_impl_registry.h"
#include "op_util.h"
using namespace ge;
namespace ops {
// Index of the operator input that carries the target shape; it is passed to
// GetInputTensor() and declared as a data-dependent input at registration.
constexpr size_t INPUT_INDEX_SHAPE = 1;
/**
 * \brief Fill `shape` with the first `size` elements of `tensor`'s data.
 *
 * The rank of `shape` is set to `size`, then element i of the tensor data
 * (read as type T) becomes dimension i.
 *
 * \tparam T      Element type used to read the tensor data.
 * \param tensor  Tensor whose data holds the dimension values; the data
 *                pointer is dereferenced without a null check, so the caller
 *                must guarantee at least `size` readable elements.
 * \param size    Number of dimensions to copy.
 * \param shape   Destination shape, overwritten in place.
 */
template <typename T>
static void GetConstValueToShape(const gert::Tensor* tensor, size_t size, gert::Shape* shape) {
    const T* const dims = tensor->GetData<T>();
    shape->SetDimNum(size);
    size_t idx = 0;
    while (idx < size) {
        shape->SetDim(idx, dims[idx]);
        ++idx;
    }
}
/**
 * \brief Produce the output shape directly from an int64 list attribute.
 *
 * Every element of `shape_attr` is copied verbatim into `out_shape`; the
 * input shape `x_shape` is only used for debug logging here.
 *
 * \param context     Inference context, used for the node name in logs.
 * \param x_shape     Shape of input x (logged only).
 * \param shape_attr  Attribute vector whose data is read as int64_t values.
 * \param out_shape   Destination shape, overwritten in place.
 * \return ge::GRAPH_SUCCESS once the shape is written; the null-check macro
 *         exits early if the attribute data pointer is null.
 */
static ge::graphStatus BroadcastToInferShapeWithShapeValues(const gert::InferShapeContext* context,
                                                            const gert::Shape* x_shape,
                                                            const gert::ContinuousVector* shape_attr,
                                                            gert::Shape* out_shape) {
    OP_LOGD(context->GetNodeName(), "Begin to do BroadcastToInfershape.");

    // Keep the local name `shape_value`: the null-check macro may print it.
    const auto* shape_value = static_cast<const int64_t*>(shape_attr->GetData());
    OP_CHECK_NULL_WITH_CONTEXT(context, shape_value);

    const size_t rank = shape_attr->GetSize();
    OP_LOGD(context->GetNodeName(), "input x = %s", Ops::Base::ToString(*x_shape).c_str());

    out_shape->SetDimNum(rank);
    size_t axis = 0;
    while (axis < rank) {
        out_shape->SetDim(axis, shape_value[axis]);
        ++axis;
    }

    OP_LOGD(context->GetNodeName(), "output y = %s", Ops::Base::ToString(*out_shape).c_str());
    OP_LOGD(context->GetNodeName(), "End to do BroadcastToInfershape.");
    return ge::GRAPH_SUCCESS;
}
/**
 * \brief Infer the output shape of BroadcastTo.
 *
 * Two sources for the target shape are supported:
 *  - If input tensor INPUT_INDEX_SHAPE is not available (null), the shape is
 *    taken verbatim from attribute 0 via BroadcastToInferShapeWithShapeValues.
 *  - Otherwise the tensor's data (int32 or int64) is read as the target shape.
 *    The target rank must be >= the rank of x. x is right-aligned against the
 *    target: an entry of -1 is replaced by the matching x dimension, or by 1
 *    for leading dimensions that have no x counterpart. Each aligned pair must
 *    then be equal, or the x dimension must be 1.
 *
 * \param context  Inference context providing input 0 (x), the shape tensor /
 *                 attributes, and output 0.
 * \return ge::GRAPH_SUCCESS when the output shape was written;
 *         ge::GRAPH_FAILED on a null input/output/attribute, a negative shape
 *         tensor size, a shape tensor without data, a target rank smaller
 *         than x's rank, an unsupported shape dtype, or incompatible dims.
 */
static ge::graphStatus InferShape4BroadcastTo(gert::InferShapeContext* context) {
    auto x_shape = context->GetInputShape(0);
    OP_CHECK_NULL_WITH_CONTEXT(context, x_shape);
    auto out_shape = context->GetOutputShape(0);
    OP_CHECK_NULL_WITH_CONTEXT(context, out_shape);

    auto shape_tensor = context->GetInputTensor(INPUT_INDEX_SHAPE);
    if (shape_tensor == nullptr) {
        // No shape tensor: fall back to the shape stored in attribute 0.
        auto attrs = context->GetAttrs();
        OP_CHECK_NULL_WITH_CONTEXT(context, attrs);
        const gert::ContinuousVector* shape_attr = attrs->GetAttrPointer<gert::ContinuousVector>(0);
        OP_CHECK_NULL_WITH_CONTEXT(context, shape_attr);
        return BroadcastToInferShapeWithShapeValues(context, x_shape, shape_attr, out_shape);
    }

    // Validate the element count while it is still signed. Casting a negative
    // value to size_t would wrap to a huge rank, slip past the rank check
    // below and make the copy loop read far beyond the tensor data.
    const int64_t raw_shape_size = shape_tensor->GetShapeSize();
    OP_CHECK_IF(
        raw_shape_size < 0,
        OP_LOGE(context->GetNodeName(), "%s",
                ConcatString("shape size ", raw_shape_size, " cannot be negative, error!").c_str()),
        return ge::GRAPH_FAILED);
    auto shape_size = static_cast<size_t>(raw_shape_size);

    const size_t x_rank = x_shape->GetDimNum();
    OP_CHECK_IF(
        shape_size < x_rank,
        OP_LOGE(context->GetNodeName(), "%s",
                ConcatString("shape size ", shape_size, " cannot be less than x size ",
                             x_rank, ", error!").c_str()),
        return ge::GRAPH_FAILED);
    out_shape->SetDimNum(shape_size);
    OP_LOGD(context->GetNodeName(), "shape_size is %zu", shape_size);

    DataType data_type = shape_tensor->GetDataType();
    OP_CHECK_IF(
        (data_type != DT_INT32) && (data_type != DT_INT64),
        OP_LOGE(
            context->GetNodeName(), "%s",
            ConcatString("shape's dtype ", Ops::Base::ToString(data_type), " must be in (int32,int64)!").c_str()),
        return ge::GRAPH_FAILED);

    // The helper dereferences the data pointer unconditionally, so refuse a
    // non-empty shape tensor that carries no data instead of crashing.
    const bool has_data = (data_type == DT_INT32) ? (shape_tensor->GetData<int32_t>() != nullptr)
                                                  : (shape_tensor->GetData<int64_t>() != nullptr);
    OP_CHECK_IF(
        (shape_size > 0) && !has_data,
        OP_LOGE(context->GetNodeName(), "%s", "shape tensor data is null, error!"),
        return ge::GRAPH_FAILED);

    if (data_type == DT_INT32) {
        GetConstValueToShape<int32_t>(shape_tensor, shape_size, out_shape);
    } else {
        GetConstValueToShape<int64_t>(shape_tensor, shape_size, out_shape);
    }

    // Number of leading output dims that have no counterpart in x.
    const size_t diff = shape_size - x_rank;
    for (size_t i = 0; i < shape_size; i++) {
        if (out_shape->GetDim(i) == -1) {
            // -1 means "take it from x" (or 1 when x has no such dimension).
            if (i >= diff) {
                out_shape->SetDim(i, x_shape->GetDim(i - diff));
            } else {
                out_shape->SetDim(i, 1);
            }
        }
        if (i < diff) {
            continue;
        }
        // Aligned dims must match unless the x dim is 1.
        OP_CHECK_IF(
            (out_shape->GetDim(i) != x_shape->GetDim(i - diff)) && (1 != x_shape->GetDim(i - diff)),
            OP_LOGE(
                context->GetNodeName(), "%s",
                ConcatString(Ops::Base::ToString(*x_shape).c_str(), " can not broadcast to ",
                             Ops::Base::ToString(*out_shape).c_str()).c_str()),
            return ge::GRAPH_FAILED);
    }
    return GRAPH_SUCCESS;
}
// Register shape inference for the BroadcastTo operator:
//  - InferShape4BroadcastTo is the inference entry point;
//  - input INPUT_INDEX_SHAPE is declared data-dependent, since the inference
//    function reads that tensor's data rather than only its shape;
//  - a private attribute "_shape_values" is declared with an empty int64
//    vector as default (presumably the attribute read at index 0 when the
//    shape tensor is absent — TODO confirm attribute ordering).
IMPL_OP_INFERSHAPE(BroadcastTo)
.InferShape(InferShape4BroadcastTo)
.InputsDataDependency({INPUT_INDEX_SHAPE})
.PrivateAttr("_shape_values", std::vector<int64_t>{});
}