* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "register/op_impl_registry.h"
#include "common/checker.h"
#include "graph/utils/type_utils.h"
using namespace gert;
namespace ops {
// Returns true when dim_value is a legal axis index for a shape of rank shape_size, i.e. it lies in
// the closed range [-shape_size, shape_size - 1] (negative indices count from the last dimension).
// Both arguments are widened to int64_t before comparing so mixed signed/unsigned inputs compare safely.
template <typename T1, typename T2>
bool IsDimValid(const T1 shape_size, const T2 dim_value) {
  const auto rank = static_cast<int64_t>(shape_size);
  const auto axis = static_cast<int64_t>(dim_value);
  const int64_t lower_bound = -rank;
  const int64_t upper_bound = rank - 1;
  return !(axis < lower_bound || axis > upper_bound);
}
/**
 * Computes the reduced output shape when keep_dims is true: the input shape is copied and every
 * dimension named in axes_dims is set to 1, so the rank is preserved.
 * @param x_shape      shape of the tensor being reduced
 * @param axes_dims    axes_size axis indices; a negative index counts from the last dimension
 * @param axes_size    number of entries in axes_dims
 * @param output_shape receives the reduced shape
 * @return ge::GRAPH_SUCCESS, or ge::PARAM_INVALID if any axis lies outside [-rank, rank - 1]
 */
template <typename T>
ge::graphStatus ReduceDimsWithKeepDims(const gert::Shape* x_shape, const T* axes_dims, int32_t axes_size,
                                       gert::Shape* output_shape) {
  const T dim_num = static_cast<T>(x_shape->GetDimNum());
  *output_shape = *x_shape;
  for (int32_t i = 0; i < axes_size; i++) {
    // Reject out-of-range axes before they are normalized and used as an index into output_shape.
    if (!IsDimValid(dim_num, axes_dims[i])) {
      GELOGE(ge::PARAM_INVALID, "axes_dims[%d] is invalid", i);
      return ge::PARAM_INVALID;
    }
    // Map a negative axis onto its non-negative equivalent.
    const T dim = axes_dims[i] < 0 ? axes_dims[i] + dim_num : axes_dims[i];
    output_shape->SetDim(dim, 1);
  }
  GELOGD("ReduceDimsWithKeepDims is SUCCESS");
  return ge::GRAPH_SUCCESS;
}
/**
 * Computes the reduced output shape when keep_dims is false: every input dimension that is named in
 * axes_dims is dropped and the remaining dimensions are kept in their original order.
 * @param x_shape      shape of the tensor being reduced
 * @param axes_dims    axes_size axis indices; a negative index counts from the last dimension
 * @param axes_size    number of entries in axes_dims
 * @param output_shape receives the reduced shape
 * @return ge::GRAPH_SUCCESS, or ge::PARAM_INVALID if any axis lies outside [-rank, rank - 1]
 */
template <typename T>
ge::graphStatus ReduceDimsWithoutKeepDims(const gert::Shape* x_shape, const T* axes_dims, int32_t axes_size,
                                          gert::Shape* output_shape) {
  const T dim_num = static_cast<T>(x_shape->GetDimNum());
  // Validate every axis once, before touching output_shape. Doing it up front (rather than inside the
  // per-dimension loop) checks each axis exactly once and still rejects bad axes when the input is a
  // scalar (dim_num == 0), matching the keep_dims path.
  for (int32_t i = 0; i < axes_size; i++) {
    if (!IsDimValid(dim_num, axes_dims[i])) {
      GELOGE(ge::PARAM_INVALID, "axes_dims[%d] is invalid", i);
      return ge::PARAM_INVALID;
    }
  }
  output_shape->SetDimNum(0);
  for (T j = 0; j < dim_num; j++) {
    bool reduce_flag = false;
    for (int32_t i = 0; i < axes_size; i++) {
      // Map a negative axis onto its non-negative equivalent before comparing with j.
      const T dim = axes_dims[i] < 0 ? axes_dims[i] + dim_num : axes_dims[i];
      if (dim == j) {
        reduce_flag = true;
        break;
      }
    }
    if (!reduce_flag) {
      output_shape->AppendDim(x_shape->GetDim(j));
    }
  }
  GELOGD("ReduceDimsWithoutKeepDims is SUCCESS");
  return ge::GRAPH_SUCCESS;
}
/**
 * Reads the axis indices (element type T) from axes_tensor and dispatches to the keep_dims or
 * non-keep_dims shape reduction.
 * @param x_shape      shape of the tensor being reduced
 * @param axes_tensor  tensor holding axes_size axis indices of type T
 * @param axes_size    number of axis indices to read
 * @param keep_dims    true to keep reduced dimensions as size 1, false to drop them
 * @param output_shape receives the reduced shape
 * @return status of the selected reduction, or a failure if axes data is missing while axes_size > 0
 */
template <typename T>
ge::graphStatus ReduceDims(const gert::Shape* x_shape, const gert::Tensor* axes_tensor, int32_t axes_size,
                           const bool keep_dims, gert::Shape* output_shape) {
  const T* axes_dims = axes_tensor->GetData<T>();
  // Both helpers dereference axes_dims[0 .. axes_size); only demand a data pointer when at least one
  // axis will actually be read.
  if (axes_size > 0) {
    GE_ASSERT_NOTNULL(axes_dims);
  }
  if (keep_dims) {
    return ReduceDimsWithKeepDims<T>(x_shape, axes_dims, axes_size, output_shape);
  }
  return ReduceDimsWithoutKeepDims<T>(x_shape, axes_dims, axes_size, output_shape);
}
/**
 * Shared InferShape implementation for reduce operators (registered below for ReduceSum).
 * Input 0 is the tensor being reduced, input 1 is the axes tensor (int32 or int64 values), and
 * attribute 0 is the keep_dims flag. The inferred shape is written to output 0.
 * @return ge::GRAPH_SUCCESS on success; a failure status on a missing input/output/attribute, an
 *         unsupported axes dtype, or an out-of-range axis.
 */
ge::graphStatus InferShape4ReduceCommon(InferShapeContext* context) {
  auto in_shape = context->GetInputShape(0);
  GE_ASSERT_NOTNULL(in_shape);
  auto axes_tensor = context->GetInputTensor(1);
  GE_ASSERT_NOTNULL(axes_tensor);
  auto out_shape = context->GetOutputShape(0);
  GE_ASSERT_NOTNULL(out_shape);
  auto attrs = context->GetAttrs();
  GE_ASSERT_NOTNULL(attrs);
  const bool* keep_dims = attrs->GetAttrPointer<bool>(0);
  GE_ASSERT_NOTNULL(keep_dims);

  const auto axes_size = static_cast<int32_t>(axes_tensor->GetShapeSize());
  const auto dtype = axes_tensor->GetDataType();
  // NOTE(review): GE_ASSERT's message arguments are assumed to be a printf-style format plus values
  // (as with GELOGE elsewhere in this file) — confirm against common/checker.h. Under that contract the
  // dtype name needs a %s placeholder and a C string; passing the std::string itself after a format
  // without %s never prints it and pushes a non-trivial object through a variadic call.
  GE_ASSERT(dtype == ge::DT_INT32 || dtype == ge::DT_INT64, "axes datatype %s must in (int32, int64)",
            ge::TypeUtils::DataTypeToSerialString(dtype).c_str());
  if (dtype == ge::DT_INT32) {
    return ReduceDims<int32_t>(in_shape, axes_tensor, axes_size, *keep_dims, out_shape);
  }
  return ReduceDims<int64_t>(in_shape, axes_tensor, axes_size, *keep_dims, out_shape);
}
// Tiling stub for ReduceSum: it does no tiling work, ignores its context, and always reports success.
ge::graphStatus TilingForReduceSum(TilingContext *context) {
  static_cast<void>(context);  // deliberately unused by this stub
  return ge::GRAPH_SUCCESS;
}
// Placeholder type supplied as the template argument of TilingParse in the ReduceSum registration.
// It holds eight 64-bit words; nothing in the visible code reads or writes them.
// NOTE(review): presumably only its size/layout matters to the framework — keep it a plain POD.
struct StubReduceSumTilingData {
uint64_t tiling_data[8];
};
// TilingParse stub for ReduceSum: it parses nothing and always reports success. The parameter is
// explicitly discarded, as in TilingForReduceSum, so -Wunused-parameter builds stay clean.
ge::graphStatus TilingParseForReduceSum(KernelContext *context) {
  (void) context;
  return ge::GRAPH_SUCCESS;
}
// Registers ReduceSum: shape inference via InferShape4ReduceCommon, a value dependency on input 1
// (the axes tensor read during InferShape), and the stub tiling / tiling-parse hooks.
IMPL_OP(ReduceSum)
    .InferShape(InferShape4ReduceCommon)
    .InputsDataDependency({1})
    .Tiling(TilingForReduceSum)
    .TilingParse<StubReduceSumTilingData>(TilingParseForReduceSum);
}