* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "graph_builder/bg_infer_shape.h"
#include "graph_builder/bg_memory.h"
#include "engine/aicpu/graph_builder/bg_launch.h"
#include "framework/common/ge_types.h"
#include "common/hyper_status.h"
#include "graph/debug/ge_attr_define.h"
#include "engine/aicpu/graph_builder/bg_aicpu_arg.h"
#include "graph/utils/node_utils.h"
#include "graph_builder/converter_checker.h"
#include "graph_builder/bg_tensor.h"
#include "sequence_ops.h"
#include "graph_builder/bg_rt_session.h"
namespace {
constexpr size_t kOutputNum = 1U;
}
namespace gert {
bg::ValueHolderPtr SequenceConstructCompute(bg::ValueHolderPtr session_id, bg::ValueHolderPtr container_id,
bg::ValueHolderPtr input_num_holder, const LowerInput &lower_input,
bg::ValueHolderPtr output_tensor) {
std::vector<bg::ValueHolderPtr> inputs;
inputs.emplace_back(session_id);
inputs.emplace_back(container_id);
inputs.emplace_back(input_num_holder);
for (size_t i = 0; i < lower_input.input_addrs.size(); i++) {
inputs.emplace_back(lower_input.input_addrs[i]);
inputs.emplace_back(lower_input.input_shapes[i]);
}
inputs.emplace_back(output_tensor);
auto compute_holder = bg::ValueHolder::CreateSingleDataOutput("SequenceConstructCompute", inputs);
return compute_holder;
}
LowerResult LoweringSequenceConstruct(const ge::NodePtr &node, const LowerInput &lower_input) {
if ((node == nullptr) || (node->GetOpDescBarePtr() == nullptr)) {
GELOGE(ge::PARAM_INVALID, "[Check][Op]Can not find op.");
REPORT_INNER_ERR_MSG("E39999", "Can not find op.");
return {HyperStatus::ErrorStatus(static_cast<const char *>("Can not find op")), {}, {}, {}};
}
auto ret = CheckLowerInput(lower_input);
if (!ret.IsSuccess()) {
GELOGE(ge::PARAM_INVALID, "[Check][LowerInput]Op %s type %s lower_input is invalid.", node->GetName().c_str(),
ge::NodeUtils::GetNodeType(node).c_str());
REPORT_INNER_ERR_MSG("E39999", "Op %s type %s lower_input is invalid.", node->GetName().c_str(),
ge::NodeUtils::GetNodeType(node).c_str());
return {ret, {}, {}, {}};
}
auto input_num = lower_input.input_addrs.size();
if (input_num < 1) {
GELOGE(ge::PARAM_INVALID, "[Check][Op]Input num err, it is at least 1.");
REPORT_INNER_ERR_MSG("E39999", "Input num err, it is at least 1.");
return {HyperStatus::ErrorStatus(static_cast<const char *>("Input num err, it is at least 1")), {}, {}, {}};
}
auto input_num_holder = bg::ValueHolder::CreateConst(&input_num, sizeof(input_num));
auto output_shape = bg::ValueHolder::CreateSingleDataOutput("GetSequenceHandleShape", {});
std::vector<bg::ValueHolderPtr> output_shapes;
output_shapes.push_back(output_shape);
auto output_sizes = bg::CalcTensorSize(node, output_shapes);
auto output_addrs = bg::AllocOutputMemory(kOnHost, node, output_sizes, *(lower_input.global_data));
CONVERTER_CHECK_HOLDERS_ALL_OK(output_addrs, kOutputNum);
auto output_tensor = bg::BuildTensor(node, 0, kOnHost, output_shapes[0], output_addrs[0]);
auto session_id = bg::GetSessionId(*lower_input.global_data);
bg::ValueHolderPtr container_id_holder = bg::GetContainerIdHolder(lower_input);
auto compute_holder =
SequenceConstructCompute(session_id, container_id_holder, input_num_holder, lower_input, output_tensor);
return {HyperStatus::Success(), {compute_holder}, output_shapes, output_addrs};
}
REGISTER_NODE_CONVERTER("SequenceConstruct", LoweringSequenceConstruct);
}