"""
-------------------------------------------------------------------------
This file is part of the MindStudio project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.
MindStudio is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""
import copy
from abc import abstractmethod
from itertools import chain
from queue import Queue
from typing import List, Any, Optional, Type, Generator, Dict, Union
from ascend_utils.common.security import check_element_type
from ascend_utils.common.security import check_type
from ascend_utils.common.utils import FullPermutation
from ascend_utils.core.dag.dag_node import DagNode
class DirectedAcyclicGraph:
    """Directed acyclic graph of ``DagNode`` objects built around a network.

    Offers simple node look-ups (by wrapped class, op type, exact name or name
    prefix) and a pattern search that finds every occurrence of a small
    single-input / single-output sub graph.

    This class never fills ``_dag_node_list`` itself; presumably a subclass
    builds it from ``network`` -- confirm against the concrete subclasses.
    """

    def __init__(self, network: Any):
        # The wrapped model/network, stored verbatim.
        self._network = network
        # Flat node list searched by every query method; starts out empty.
        self._dag_node_list: List[DagNode] = []

    @property
    def network(self):
        """Return the network object this graph was constructed with."""
        return self._network

    @property
    def dag_node_list(self) -> List[DagNode]:
        """Return the live (not copied) list of graph nodes."""
        return self._dag_node_list

    @staticmethod
    def _clone_sub_graph(node_list_in_calc_order: List[DagNode]) -> List[DagNode]:
        """Rebuild a pattern sub graph out of brand-new ``DagNode`` objects.

        Each clone keeps its source's ``op_types`` and ``name``. Edges are
        re-created between the clones by matching successor names, so the
        original pattern nodes are left untouched.

        Args:
            node_list_in_calc_order: pattern nodes to copy. Anything that is
                not a ``list`` (including ``None``) yields ``[]``.

        Returns:
            The clones, in the same order as the input list.
        """
        if not isinstance(node_list_in_calc_order, list):
            return []
        clones = [DagNode(op_type=src.op_types, name=src.name) for src in node_list_in_calc_order]
        clone_by_name = {clone.name: clone for clone in clones}
        for src in node_list_in_calc_order:
            target = clone_by_name[src.name]
            # Successors are resolved by name, which relies on names being
            # unique inside the pattern.
            target.add_next_node(clone_by_name[succ.name] for succ in src.output_nodes)
        return clones

    @staticmethod
    def _get_node_input_node(sub_graph):
        """Validate a pattern sub graph and return its single source node.

        The checks run in this order:
        1. At most one node has no inputs, and at most one node has no outputs.
        2. No node links to a node outside ``sub_graph``.
        3. Node names are unique.
        4. A source and a sink both exist.

        Raises:
            ValueError: when any of the checks above fails.
        """
        source = None
        sink = None
        for candidate in sub_graph:
            if len(candidate.inputs) == 0:
                if source is not None:
                    raise ValueError("There can only be one input node.")
                source = candidate
            if len(candidate.outputs) == 0:
                if sink is not None:
                    raise ValueError("There can only be one output node.")
                sink = candidate

        # The pattern must be closed: every neighbour has to be one of its members.
        if any(neighbour not in sub_graph
               for member in sub_graph
               for neighbour in chain(member.input_nodes, member.output_nodes)):
            raise ValueError("There can only be one input/output node.")

        seen_names = set()
        for member in sub_graph:
            if member.name in seen_names:
                raise ValueError("The node name must be different.")
            seen_names.add(member.name)

        if source is None or sink is None:
            raise ValueError("You must have an input and an output.")
        return source

    @staticmethod
    def _get_node_list_in_calc_order(sub_graph, input_node):
        """Order pattern nodes so that each one follows all of its input nodes.

        This is a breadth-first walk from ``input_node`` over a FIFO queue. A
        successor is queued once every one of its input nodes has already been
        emitted. The walk stops when the queue drains or when ``len(sub_graph)``
        nodes have been emitted, so unreachable nodes are simply left out.
        """
        ordered = []
        frontier: Queue[DagNode] = Queue()
        frontier.put(input_node)
        limit = len(sub_graph)
        while not frontier.empty() and len(ordered) < limit:
            current = frontier.get()
            ordered.append(current)
            for successor in current.output_nodes:
                ready = all(pred in ordered for pred in successor.input_nodes)
                if ready:
                    frontier.put(successor)
        return ordered

    @staticmethod
    def _get_proper_combination_of_outputs(ori_graph_outputs: List[DagNode],
                                           sub_graph_outputs: List[DagNode]):
        """Yield every op-type-compatible pairing of pattern outputs to graph outputs.

        Nothing is yielded unless both lists have the same length. Each
        permutation comes from ``FullPermutation().get_all_permutations``;
        presumably these are permutations of ``range(n)`` (verify against
        ascend_utils). Position ``i`` of a permutation pairs
        ``sub_graph_outputs[i]`` with ``ori_graph_outputs[perm[i]]``. A pairing
        is kept only when every graph node's ``op_type`` is among its pattern
        node's ``op_types``.

        Yields:
            Dicts mapping pattern-node name -> graph node.
        """
        if len(ori_graph_outputs) != len(sub_graph_outputs):
            return
        for perm in FullPermutation().get_all_permutations(len(sub_graph_outputs)):
            compatible = all(
                ori_graph_outputs[ori_pos].op_type in sub_graph_outputs[sub_pos].op_types
                for sub_pos, ori_pos in enumerate(perm)
            )
            if compatible:
                yield {
                    sub_graph_outputs[sub_pos].name: ori_graph_outputs[ori_pos]
                    for sub_pos, ori_pos in enumerate(perm)
                }

    def search_nodes_by_class(self, cls: Type) -> Generator[DagNode, None, None]:
        """Lazily yield every node whose wrapped ``node`` object is an instance of ``cls``.

        ``cls`` is validated with ``check_type`` when the generator is first advanced.
        """
        check_type(cls, type, param_name="cls")
        yield from (dag_node for dag_node in self._dag_node_list if isinstance(dag_node.node, cls))

    def search_nodes_by_op_type(self, op_types: Union[str, List[str]]) -> Generator[DagNode, None, None]:
        """Lazily yield nodes whose ``op_type`` matches ``op_types``.

        Args:
            op_types: a single op type string, or a list of op type strings.
                A non-str argument is validated with ``check_element_type``.
        """
        if isinstance(op_types, str):
            wanted = [op_types]
        else:
            check_element_type(op_types, element_type=str, value_type=list, param_name="op_types")
            wanted = op_types
        yield from (dag_node for dag_node in self._dag_node_list if dag_node.op_type in wanted)

    def get_node_by_name(self, name: str) -> Optional[DagNode]:
        """Return the first node whose ``name`` equals ``name``, or ``None`` if there is none."""
        check_type(name, str, param_name="name")
        return next((dag_node for dag_node in self._dag_node_list if dag_node.name == name), None)

    def get_nodes_by_name_prefix(self, name_prefix: str) -> Generator[DagNode, None, None]:
        """Lazily yield every node whose ``name`` starts with ``name_prefix``."""
        check_type(name_prefix, str, param_name="name_prefix")
        yield from (dag_node for dag_node in self._dag_node_list if dag_node.name.startswith(name_prefix))

    def search_sub_graph(self, sub_graph: List[DagNode]) -> Generator[Dict[str, DagNode], None, None]:
        """
        Find every occurrence of a pattern sub graph in this graph.

        Args:
            sub_graph: the pattern's nodes. The pattern must satisfy all of the following:
                - exactly one node has no inputs, and exactly one node has no outputs;
                - no node links to a node outside the pattern;
                - every node name is unique.
                Nodes whose ``op_types`` contain ``None`` are optional: the search
                runs both with and without them. Such nodes may have at most one
                input node and at most one output node.
        Returns:
            Generator of dicts keyed by pattern node name; each value is the matched
            node of this graph. An empty ``sub_graph`` yields nothing.
        Raises:
            TypeError: if ``sub_graph`` is not a list.
            ValueError: if the pattern violates the constraints above.
            Both are raised lazily, when the generator is first advanced.
        Examples:
            >>> conv = DagNode(op_type="Conv2d", name="c")
            >>> bn = DagNode(op_type="BatchNorm2d", name="b")
            >>> conv >> bn
            >>> dag.search_sub_graph([conv, bn])
        """
        if not isinstance(sub_graph, list):
            raise TypeError("sub graph must be list of DagNode.")
        if not sub_graph:
            return
        for input_node, node_list_in_calc_order in self._parse_sub_graph(sub_graph):
            for candidate in self.dag_node_list:
                # Only graph nodes compatible with the pattern's first node can
                # anchor a match.
                if candidate.op_type in input_node.op_types:
                    seed = {input_node.name: candidate}
                    yield from self._search_by_calc_order(seed, node_list_in_calc_order, 0)

    @abstractmethod
    def get_params(self) -> int:
        """Declared abstract for subclasses to override; the base version returns ``None``."""
        # NOTE(review): this class does not inherit from abc.ABC (no ABCMeta), so
        # @abstractmethod is not enforced when the class is instantiated.
        pass

    def _parse_sub_graph(self, sub_graph: List[DagNode]):
        """Validate a pattern and yield each concrete variant to search for.

        Yields:
            ``(first_node, node_list_in_calc_order)`` for each non-empty variant
            produced by ``_get_possible_sub_graph``.

        Raises:
            ValueError: in any of these cases:
                - an optional node (``None`` in ``op_types``) has more than one
                  input node or more than one output node;
                - the pattern or a variant fails ``_get_node_input_node``;
                - a variant's calculation order comes out empty.
        """
        for pattern_node in sub_graph:
            optional = None in pattern_node.op_types
            # Optional nodes are later spliced out by wiring their single input
            # straight to their consumers, which only suits pass-through nodes.
            if optional and (len(list(pattern_node.input_nodes)) > 1 or len(list(pattern_node.output_nodes)) > 1):
                raise ValueError("The node whose op type is None must have only one input and one output now.")
        root = self._get_node_input_node(sub_graph)
        full_order = self._get_node_list_in_calc_order(sub_graph, root)
        for variant in self._get_possible_sub_graph(full_order):
            if len(variant) == 0:
                continue
            variant_root = self._get_node_input_node(variant)
            variant_order = self._get_node_list_in_calc_order(variant, variant_root)
            if len(variant_order) == 0:
                raise ValueError("graph must has node.")
            yield variant_order[0], variant_order

    def _get_possible_sub_graph(self, node_list_in_calc_order: List[DagNode]) -> Generator[List[DagNode], None, None]:
        """Yield every variant of the pattern obtained by dropping optional nodes.

        Optional nodes are those whose ``op_types`` contain ``None``.
        - If there are none, the input list itself is yielded once.
        - Otherwise one rebuilt sub graph is yielded per entry of
          ``FullPermutation.get_all_combinations([2] * k)``, where ``k`` is the
          number of optional nodes. These entries are presumably all 0/1 tuples
          of length ``k`` (verify). A 0 at position ``i`` drops the i-th
          optional node.
        """
        optional_names = [pattern_node.name for pattern_node in node_list_in_calc_order
                          if None in pattern_node.op_types]
        if optional_names:
            for flags in FullPermutation.get_all_combinations([2] * len(optional_names)):
                dropped = [name for pos, name in enumerate(optional_names) if flags[pos] == 0]
                yield self._create_new_sub_graph(node_list_in_calc_order, dropped)
        else:
            yield node_list_in_calc_order

    def _create_new_sub_graph(self, node_list_in_calc_order: List[DagNode], none_names: List[str]) -> List[DagNode]:
        """Clone the pattern and splice out the nodes named in ``none_names``.

        A removed node with a single input is bypassed: that input object is
        connected directly to the removed node's consumers. A removed node
        with no inputs is only dropped from the list; its consumers are not
        rewired.

        Raises:
            ValueError: if a node to remove has more than one input.
        """
        cloned = self._clone_sub_graph(node_list_in_calc_order)
        # Walk backwards over the live list so removing the current element does
        # not disturb the positions still to be visited.
        for victim in reversed(cloned):
            if victim.name not in none_names:
                continue
            cloned.remove(victim)
            fan_in = len(victim.inputs)
            if fan_in == 0:
                # NOTE(review): consumers still reference the dropped node here;
                # later validation presumably rejects such a variant -- confirm.
                continue
            if fan_in > 1:
                raise ValueError("The node whose op type is None must have only one input.")
            feeder = victim.inputs[0]
            # Point the feeder past the removed node...
            feeder.dag_nodes_to.remove(victim)
            feeder.dag_nodes_to.extend(list(victim.output_nodes))
            # ...and make every consumer read from the feeder instead.
            for produced in victim.outputs:
                for consumer in produced.dag_nodes_to:
                    consumer.inputs.remove(produced)
                    consumer.inputs.append(feeder)
        return cloned

    def _search_by_calc_order(self, matched_nodes: Dict[str, DagNode], node_list_in_calc_order: List[DagNode],
                              index: int):
        """Recursively extend a partial match along the pattern's calculation order.

        Args:
            matched_nodes: pattern-node name -> graph node mapping found so far.
                Mutated in place as output combinations are tried.
            node_list_in_calc_order: pattern nodes, each after all of its inputs.
            index: position in ``node_list_in_calc_order`` being verified.

        Yields:
            A shallow copy of ``matched_nodes`` for every complete match.
        """
        if index >= len(node_list_in_calc_order):
            return
        pattern_node = node_list_in_calc_order[index]
        graph_node = matched_nodes.get(pattern_node.name, None)
        if graph_node is None or graph_node.op_type not in pattern_node.op_types:
            return

        if index != 0:
            # The matched images of the pattern node's inputs must equal the graph
            # node's input nodes. Sorting by name makes the comparison
            # independent of edge order.
            expected_inputs = sorted(
                (matched_nodes.get(pred.name, None) for pred in pattern_node.input_nodes),
                key=lambda graph_pred: graph_pred.name,
            )
            actual_inputs = sorted(graph_node.input_nodes, key=lambda graph_pred: graph_pred.name)
            if expected_inputs != actual_inputs:
                return

        if index == len(node_list_in_calc_order) - 1:
            # Last pattern node verified. Hand out a snapshot, because the dict
            # keeps being mutated by the rest of the search.
            yield copy.copy(matched_nodes)
            return

        graph_outputs = list(graph_node.output_nodes)
        pattern_outputs = list(pattern_node.output_nodes)
        for combination in self._get_proper_combination_of_outputs(graph_outputs, pattern_outputs):
            matched_nodes.update(combination)
            yield from self._search_by_calc_order(matched_nodes, node_list_in_calc_order, index + 1)

    def _remove_one_node(self, node: DagNode):
        """Detach ``node`` from this graph and reconnect its neighbours around it.

        The steps are:
        1. Remove ``node`` from ``dag_node_list`` if present.
        2. For every input object, pass each output object's ``dag_nodes_to``
           to ``add_node_to`` (presumably linking those consumers -- confirm),
           then drop ``node`` from the input object's ``dag_nodes_to``.
        3. In every consumer's ``inputs``, replace ``node``'s output objects
           with ``node``'s own ``inputs``.
        """
        if node in self.dag_node_list:
            self.dag_node_list.remove(node)

        for feeder in node.inputs:
            for produced in node.outputs:
                feeder.add_node_to(produced.dag_nodes_to)
            if node in feeder.dag_nodes_to:
                feeder.dag_nodes_to.remove(node)

        for consumer in node.output_nodes:
            for produced in node.outputs:
                if produced in consumer.inputs:
                    consumer.inputs.remove(produced)
            consumer.inputs.extend(node.inputs)