* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "split_to_sequence.h"
#include "common/hyper_status.h"
#include "engine/aicpu/graph_builder/bg_launch.h"
#include "framework/common/ge_types.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/node_utils.h"
#include "graph_builder/bg_infer_shape.h"
#include "graph_builder/bg_memory.h"
#include "graph_builder/bg_rt_session.h"
#include "graph_builder/bg_tensor.h"
#include "graph_builder/converter_checker.h"
namespace {
// Number of outputs the lowered node is expected to produce.
constexpr size_t kOutputNum{1U};
// Minimum number of inputs (shapes and addresses) the lowering requires.
constexpr size_t kInputNum{1U};
}  // namespace
namespace gert {
/**
 * Creates the single-output "SplitToSequenceDoCompute" holder.
 *
 * Inputs handed to the kernel, in order: axis, shape of input 0, address of input 0,
 * split shapes, inner split addresses.
 *
 * @param axis_holder        holder carrying the split axis
 * @param split_shapes       holder carrying the shapes of the split pieces
 * @param inner_split_addrs  holder carrying the memory for the split pieces
 * @param lower_input        lowering inputs; index 0 of input_shapes and input_addrs is read,
 *                           so both must be non-empty
 * @return the holder returned by CreateSingleDataOutput
 */
bg::ValueHolderPtr SplitToSequenceCompute(bg::ValueHolderPtr &axis_holder,
                                          bg::ValueHolderPtr &split_shapes,
                                          bg::ValueHolderPtr &inner_split_addrs,
                                          const LowerInput &lower_input) {
  const std::vector<bg::ValueHolderPtr> kernel_inputs = {
      axis_holder, lower_input.input_shapes[0], lower_input.input_addrs[0], split_shapes, inner_split_addrs};
  return bg::ValueHolder::CreateSingleDataOutput("SplitToSequenceDoCompute", kernel_inputs);
}
/**
 * Creates the single-output "GetSplitVecWithInput" holder.
 *
 * The kernel always receives axis, input-count and the shape of input 0. When input_num is
 * anything other than 1, the shape and address of input 1 are appended as well, so the caller
 * must guarantee that index 1 exists in that case.
 *
 * @param input_num         number of inputs of the node being lowered
 * @param axis_holder       holder carrying the split axis
 * @param input_num_holder  holder carrying the input count
 * @param lower_input       lowering inputs (shapes/addresses are read from here)
 * @return the holder returned by CreateSingleDataOutput
 */
bg::ValueHolderPtr InferSplitOutputShape(uint32_t input_num, bg::ValueHolderPtr &axis_holder,
                                         bg::ValueHolderPtr &input_num_holder, const LowerInput &lower_input) {
  std::vector<bg::ValueHolderPtr> kernel_inputs = {axis_holder, input_num_holder, lower_input.input_shapes[0]};
  const bool has_second_input = (input_num != 1);
  if (has_second_input) {
    kernel_inputs.emplace_back(lower_input.input_shapes[1]);
    kernel_inputs.emplace_back(lower_input.input_addrs[1]);
  }
  return bg::ValueHolder::CreateSingleDataOutput("GetSplitVecWithInput", kernel_inputs);
}
/**
 * Creates the single-output "CalcTensorSizeFromSpecifiedShape" holder, fed with the data type
 * of the node's input 0 (as a const holder) and the given split shapes.
 *
 * @param node          node being lowered; node and its op desc are dereferenced without a
 *                      null check, so the caller must have validated them
 * @param split_shapes  holder carrying the shapes of the split pieces
 * @return the created holder, or nullptr when the node has no tensor desc for input 0
 */
bg::ValueHolderPtr CalcSpecifiedShapeSize(const ge::NodePtr &node, bg::ValueHolderPtr &split_shapes) {
  const auto first_input_desc = node->GetOpDescBarePtr()->GetInputDescPtr(0);
  if (first_input_desc != nullptr) {
    ge::DataType element_type = first_input_desc->GetDataType();
    auto element_type_holder = bg::ValueHolder::CreateConst(&element_type, sizeof(element_type));
    return bg::ValueHolder::CreateSingleDataOutput("CalcTensorSizeFromSpecifiedShape",
                                                   {element_type_holder, split_shapes});
  }
  return nullptr;
}
/**
 * Creates the "AllocSpecifiedMem" holder and attaches a "FreeSpecifiedMem" guarder to it.
 *
 * @param placement        placement used to look up (or create) the allocator
 * @param specified_sizes  holder carrying the sizes to allocate
 * @param global_data      lowering global data that owns the allocator holders
 * @return the holder produced by "AllocSpecifiedMem"
 */
bg::ValueHolderPtr AllocSpecifedShapeMemory(TensorPlacement placement,
                                            const bg::ValueHolderPtr &specified_sizes,
                                            LoweringGlobalData &global_data) {
  auto allocator = global_data.GetOrCreateAllocator({placement, AllocatorUsage::kAllocNodeOutput});
  auto specified_mem = bg::ValueHolder::CreateSingleDataOutput("AllocSpecifiedMem", {allocator, specified_sizes});
  // The guarder is registered for its effect on the graph only; its own holder is not needed here.
  (void) bg::ValueHolder::CreateVoidGuarder("FreeSpecifiedMem", specified_mem, {});
  return specified_mem;
}
/**
 * Creates the single-output "StoreSplitTensorToSequence" holder.
 *
 * The kernel inputs are taken from storeArg in this fixed order: session id, container id,
 * inner tensor addresses, inner tensor shapes, input count, keep-dims, axis, shape of input x,
 * output tensor.
 *
 * @param storeArg  bundle of holders to pass to the kernel
 * @return the holder returned by CreateSingleDataOutput
 */
bg::ValueHolderPtr StoreSplitTensorToSequence(bg::StoreSequenceArg storeArg) {
  const std::vector<bg::ValueHolderPtr> kernel_inputs = {
      storeArg.session_id,
      storeArg.contanier_id,
      storeArg.inner_tensor_addrs,
      storeArg.inner_tensor_shapes,
      storeArg.input_num,
      storeArg.keep_dims_holder,
      storeArg.axis_holder,
      storeArg.input_x_shape,
      storeArg.output_tensor};
  return bg::ValueHolder::CreateSingleDataOutput("StoreSplitTensorToSequence", kernel_inputs);
}
/**
 * Creates the input-less, single-output "GetSequenceHandleShape" holder.
 *
 * @return the holder returned by CreateSingleDataOutput
 */
bg::ValueHolderPtr GetHandleShape() {
  auto handle_shape = bg::ValueHolder::CreateSingleDataOutput("GetSequenceHandleShape", {});
  return handle_shape;
}
/**
 * Lowers a "SplitToSequence" node into runtime holders.
 *
 * Steps, in order:
 *   1. Validate the node, its op desc and the lower input.
 *   2. Build the output: shape from GetHandleShape(), memory allocated with kOnHost.
 *   3. Read the "axis" (default 0) and "keepdims" (default 1) attributes; a missing attribute
 *      silently keeps the default because the GetInt result is ignored.
 *   4. Infer the split shapes, compute their sizes, allocate kOnDeviceHbm memory for them and
 *      create the compute holder.
 *   5. Create the store holder and add a dependency between the compute and store holders.
 *
 * @param node         node to lower; must be non-null with a non-null op desc
 * @param lower_input  input shapes/addresses and global data for the lowering
 * @return on success: {Success, {compute_holder, store_holder}, output_shapes, output_addrs};
 *         otherwise a LowerResult carrying an error status
 */
LowerResult LoweringSplitToSeuqnece(const ge::NodePtr &node, const LowerInput &lower_input) {
  if ((node == nullptr) || (node->GetOpDescBarePtr() == nullptr)) {
    GELOGE(ge::PARAM_INVALID, "[Check][Op]Can not find op.");
    REPORT_INNER_ERR_MSG("E19999", "Can not find op.");
    return {HyperStatus::ErrorStatus(static_cast<const char*>("Can not find op")), {}, {}, {}};
  }
  auto ret = CheckLowerInput(lower_input);
  if (!ret.IsSuccess()) {
    GELOGE(ge::PARAM_INVALID, "[Check][LowerInput]Op %s type %s lower_input is invalid.", node->GetName().c_str(),
           ge::NodeUtils::GetNodeType(node).c_str());
    REPORT_INNER_ERR_MSG("E19999", "Op %s type %s lower_input is invalid.", node->GetName().c_str(),
                         ge::NodeUtils::GetNodeType(node).c_str());
    return {ret, {}, {}, {}};
  }
  auto input_num = lower_input.input_addrs.size();
  if (input_num < 1) {
    GELOGE(ge::PARAM_INVALID, "[Check][Op]Input num err, it is at least 1.");
    REPORT_INNER_ERR_MSG("E19999", "Input num err, it is at least 1.");
    return {HyperStatus::ErrorStatus(static_cast<const char*>("Input num err, it is at least 1")), {}, {}, {}};
  }
  auto input_num_holder = bg::ValueHolder::CreateConst(&input_num, sizeof(input_num));

  // Output of the node: its shape comes from GetHandleShape() and its memory is allocated with kOnHost.
  auto output_shape = GetHandleShape();
  std::vector<bg::ValueHolderPtr> output_shapes;
  output_shapes.push_back(output_shape);
  auto output_sizes = bg::CalcTensorSize(node, output_shapes);
  auto output_addrs = bg::AllocOutputMemory(kOnHost, node, output_sizes, *(lower_input.global_data));
  CONVERTER_CHECK_HOLDERS_ALL_OK(output_addrs, kOutputNum);
  LOWER_REQUIRE(lower_input.input_shapes.size() >= kInputNum);
  LOWER_REQUIRE(lower_input.input_addrs.size() >= kInputNum);
  // InferSplitOutputShape reads input_shapes[1] and input_addrs[1] whenever there is more than
  // one input. input_num is derived from input_addrs only, so make sure input_shapes is long
  // enough too before indexing it.
  LOWER_REQUIRE((input_num == kInputNum) || (lower_input.input_shapes.size() > kInputNum));

  int32_t axis = 0;
  (void) ge::AttrUtils::GetInt(node->GetOpDescBarePtr(), "axis", axis);
  auto axis_holder = bg::ValueHolder::CreateConst(&axis, sizeof(axis));
  auto split_shapes = InferSplitOutputShape(input_num, axis_holder, input_num_holder, lower_input);
  auto specified_size = CalcSpecifiedShapeSize(node, split_shapes);
  // CalcSpecifiedShapeSize returns nullptr when the node has no tensor desc for input 0; fail the
  // lowering here instead of feeding a null holder into the allocation and reporting success.
  LOWER_REQUIRE(specified_size != nullptr);
  auto inner_split_addrs = AllocSpecifedShapeMemory(kOnDeviceHbm, specified_size, *(lower_input.global_data));
  auto compute_holder = SplitToSequenceCompute(axis_holder, split_shapes, inner_split_addrs, lower_input);

  int32_t keep_dims = 1;
  (void) ge::AttrUtils::GetInt(node->GetOpDescBarePtr(), "keepdims", keep_dims);
  auto keep_dims_holder = bg::ValueHolder::CreateConst(&keep_dims, sizeof(keep_dims));
  auto output_tensor = bg::BuildTensor(node, 0, kOnHost, output_shapes[0], output_addrs[0]);
  auto session_id = bg::GetSessionId(*lower_input.global_data);
  auto contain_id_holder = bg::GetContainerIdHolder(lower_input);
  auto store_holder = StoreSplitTensorToSequence({session_id, contain_id_holder, inner_split_addrs, split_shapes,
                                                  input_num_holder, keep_dims_holder, axis_holder,
                                                  lower_input.input_shapes[0], output_tensor});
  // Link the compute and store holders so their relative execution order is fixed.
  bg::ValueHolder::AddDependency(compute_holder, store_holder);
  return {HyperStatus::Success(), {compute_holder, store_holder}, output_shapes, output_addrs};
}
// Registers LoweringSplitToSeuqnece as the node converter for the "SplitToSequence" op type,
// passing kOnDeviceHbm as the placement argument. The "Seuqnece" spelling matches the function
// definition in this file and must stay in sync with it.
REGISTER_NODE_CONVERTER_PLACEMENT("SplitToSequence", kOnDeviceHbm, LoweringSplitToSeuqnece);
}