/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef AIR_NODE_PASSES_SPLIT_VARIABLE_INTO_SUBGRAPH_PASS_H
#define AIR_NODE_PASSES_SPLIT_VARIABLE_INTO_SUBGRAPH_PASS_H
#include "graph/passes/base_pass.h"
namespace ge {
* (1)
* When variable connect to a node with subgraph(if/case/paritioncall)
* we need split variable into subgraph to make sure variable function ok in subgraph
* Before:
* +-------------------+
* | data |
* variable | | |
* / \ | opB |
* opA PartitionedCall +-------------------+
*
*
* After:
* +-------------------+
* | variable |
* variable | | |
* / \ | opB |
* opA PartitionedCall +-------------------+
*(2)
* When variable connect to a while node
* This pass will copy variable into subgraph, consider while node as normal op, after while_node executed,
* opB can access variable.
*
* variable
* / \
* opA while
* |
* opB
*
*
* variable
* / | \c
* opA | while
* \ /c
* opB
*
* Attention:
* This pass should be executed after DataPass which mark parent_node_input_index on inner data of subgraph.
*/
class SplitVariableIntoSubgraphPass : public BaseNodePass {
public:
Status Run(NodePtr &node) override;
private:
Status RefreshMultiDataRefDataNodeShape(const NodePtr &data_node, NodePtr &ref_data_node) const;
Status CopyVarToSubgraph(const NodePtr &var, int32_t parent_input_index, const ComputeGraphPtr &subgraph) const;
bool IsMultiBatchSubgraph(const ComputeGraphPtr &subgraph) const;
};
}
#endif