* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "flattenv2_kernel.h"
#include <memory>
#include "host_kernels/kernel_utils.h"
#include "host_kernels/kernel_factory.h"
#include "framework/common/framework_types_internal.h"
namespace ge {
namespace {
// Index of the single data input / output tensor this kernel works on.
constexpr size_t kFirstDataIndex = 0UL;
// FlattenV2 is handled only when the node has exactly this many inputs and outputs.
constexpr size_t kValidSize = 1UL;
void GetAndConvertAxis(const OpDescPtr &op_desc_ptr, int64_t &axis, int64_t &end_axis) {
if (!AttrUtils::GetInt(op_desc_ptr, "axis", axis)) {
axis = 1;
}
if (!AttrUtils::GetInt(op_desc_ptr, "end_axis", end_axis)) {
end_axis = -1;
}
GeTensorDesc x_desc = op_desc_ptr->GetInputDesc("x");
int64_t dim_count = static_cast<int64_t>(x_desc.GetShape().GetDimNum());
if (axis < 0) {
axis += dim_count;
}
if (end_axis < 0) {
end_axis += dim_count;
}
}
bool IsFlattenV2ParamsValid(const OpDescPtr &op_desc_ptr) {
size_t input_size = op_desc_ptr->GetInputsSize();
size_t output_size = op_desc_ptr->GetOutputsSize();
if (input_size != kValidSize || output_size != kValidSize) {
GELOGE(PARAM_INVALID, "input_size or output_size is invalid");
return false;
}
GeTensorDesc x_desc = op_desc_ptr->GetInputDesc("x");
if (KernelUtils::IsUnknownShape(x_desc.GetShape())) {
GELOGE(FAILED, "shape is unknown.");
return false;
}
int64_t axis = 0;
int64_t end_axis = 0;
GetAndConvertAxis(op_desc_ptr, axis, end_axis);
const int64_t dim_num = static_cast<int64_t>(x_desc.GetShape().GetDimNum());
if (axis < 0 || axis >= dim_num) {
GELOGE(PARAM_INVALID, "axis out of range! axis is %ld", axis);
return false;
}
if (end_axis < 0 || end_axis >= dim_num) {
GELOGE(PARAM_INVALID, "end_axis out of range! end_axis is %ld", end_axis);
return false;
}
if (axis > end_axis) {
GELOGE(PARAM_INVALID, "axis after end_axis! axis is %ld, end_axis is %ld", axis, end_axis);
return false;
}
return true;
}
void ComputeShape(const GeTensorDesc &x_desc, const OpDescPtr &op_desc_ptr, GeTensorDesc &y_desc) {
const auto x_shape_dim = x_desc.GetShape().GetDims();
int64_t axis = 0;
int64_t end_axis = 0;
GetAndConvertAxis(op_desc_ptr, axis, end_axis);
std::vector<int64_t> y_shape_dim;
for (int64_t i = 0; i < axis; i++) {
y_shape_dim.emplace_back(x_shape_dim[i]);
}
int64_t dim_val = 1;
for (int64_t i = axis; i < (end_axis + 1); i++) {
dim_val = dim_val * x_shape_dim[i];
}
y_shape_dim.emplace_back(dim_val);
for (int64_t i = (end_axis + 1); i <static_cast<int64_t>(x_shape_dim.size()); i++) {
y_shape_dim.emplace_back(x_shape_dim[i]);
}
GeShape y_shape(y_shape_dim);
y_desc.SetShape(y_shape);
}
}
/**
 * Dimension-kernel entry point for a FlattenV2 node.
 * It only validates the node (I/O count, known input shape, axis range) and produces no tensor.
 * A null node_ptr or op desc is rejected by GE_CHECK_NOTNULL, which presumably returns early
 * with an error status (see the macro's definition).
 * @param node_ptr node to inspect
 * @return SUCCESS when the parameters are valid, NOT_CHANGED when they are not
 */
Status FlattenV2Kernel::Compute(const NodePtr &node_ptr) const {
  GELOGD("FlattenV2 dimension kernel in");
  GE_CHECK_NOTNULL(node_ptr);
  const ge::OpDescPtr op_desc_ptr = node_ptr->GetOpDesc();
  GE_CHECK_NOTNULL(op_desc_ptr);
  const bool params_valid = IsFlattenV2ParamsValid(op_desc_ptr);
  if (params_valid) {
    GELOGD("FlattenV2 dimension kernel success.");
    return SUCCESS;
  }
  GELOGW("Params are invalid");
  return NOT_CHANGED;
}
/**
 * Constant-folds a FlattenV2 node on the host.
 * Flattening only changes the shape. The output tensor takes the output desc with the shape
 * produced by ComputeShape, and its data is set from the first input tensor's data.
 * @param op_desc_ptr op description of the FlattenV2 node
 * @param input constant input tensors; the first entry must exist and be non-null
 * @param v_output [out] the folded output tensor is appended on success
 * @return SUCCESS when folded;
 *         NOT_CHANGED when the node cannot be folded (invalid params, missing or null input,
 *         or allocation failure);
 *         FAILED when setting the output data fails.
 */
Status FlattenV2Kernel::Compute(const ge::OpDescPtr op_desc_ptr, const std::vector<ge::ConstGeTensorPtr> &input,
                                std::vector<ge::GeTensorPtr> &v_output) {
  GELOGD("FlattenV2 folding kernel in.");
  GE_CHECK_NOTNULL(op_desc_ptr);
  if (!IsFlattenV2ParamsValid(op_desc_ptr)) {
    GELOGW("Params are invalid");
    return NOT_CHANGED;
  }
  // input.at() would throw std::out_of_range on an empty vector, and a null entry would be
  // dereferenced below. Treat both as "cannot fold" instead of crashing.
  if (input.empty() || input[kFirstDataIndex] == nullptr) {
    GELOGW("Failed to fold node %s, input tensor is missing or null", op_desc_ptr->GetName().c_str());
    return NOT_CHANGED;
  }

  auto output_tensor_desc = op_desc_ptr->GetOutputDesc(kFirstDataIndex);
  const auto input_desc = op_desc_ptr->GetInputDesc(kFirstDataIndex);
  ComputeShape(input_desc, op_desc_ptr, output_tensor_desc);
  GeTensorPtr output_ptr = MakeShared<GeTensor>(output_tensor_desc);
  if (output_ptr == nullptr) {
    GELOGW("Failed to fold node %s, out of memory", op_desc_ptr->GetName().c_str());
    return NOT_CHANGED;
  }
  GELOGI("FlattenV2 op %s output tensor dim size is %zu", op_desc_ptr->GetName().c_str(),
         output_ptr->GetTensorDesc().GetShape().GetDimNum());

  if (output_ptr->SetData(input[kFirstDataIndex]->GetData()) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Compute: SetData failed");
    return FAILED;
  }
  // Logged after SetData so the reported size reflects the data actually copied in.
  GELOGI("FlattenV2 op %s output tensor data size is %zu", op_desc_ptr->GetName().c_str(),
         output_ptr->GetData().size());

  v_output.emplace_back(output_ptr);
  GELOGD("FlattenV2 folding kernel success.");
  return SUCCESS;
}
// Registers FlattenV2Kernel as the host compute kernel for the FLATTENV2 op type
// (the registration macro presumably comes from host_kernels/kernel_factory.h — confirm there).
REGISTER_COMPUTE_NODE_KERNEL(FLATTENV2, FlattenV2Kernel);
}