from msprof_analyze.prof_common.logger import get_logger
from typing import Dict, List, Tuple, Callable, Any, Optional, Union
import networkx as nx
from msprof_analyze.advisor.common.graph.graph_parser import HostGraphNode, QueryGraphNode
logger = get_logger()
class Graph:
"""
Graph Struct
"""
def __init__(self,
nodes: Dict[str, Optional[Union[HostGraphNode, QueryGraphNode]]] = None,
edges: List[Tuple[Optional[Union[HostGraphNode, QueryGraphNode]],
Optional[Union[HostGraphNode, QueryGraphNode]]]] = None,
name: str = None):
self.name = name
self.graph = nx.DiGraph(name=name)
self.nodes = nodes if nodes is not None else {}
self.edges = edges if edges is not None else list()
def build(self):
for _, node in self.nodes.items():
self.add_node(node,
op_type=node.op_type
)
for edge in self.edges:
self.add_edge(*edge)
return self.graph
def get_size(self) -> Dict[str, int]:
if not hasattr(self.graph, "nodes"):
return {"edges": 0, "nodes": 0}
return {"edges": len(self.graph.edges),
"nodes": len(self.graph.nodes)}
def add_node(self, node: HostGraphNode, **kwargs):
if node is None:
return
self.graph.add_node(node, **kwargs)
def add_edge(self, pre_node: HostGraphNode, next_node: HostGraphNode):
if pre_node is None or next_node is None:
return
if pre_node not in self.graph or \
next_node not in self.graph:
logger.error("Nodes between edge should be both exists.")
return
self.graph.add_edge(pre_node, next_node)
def add_node_with_edge(self, node, adj_nodes: List[HostGraphNode]):
self.add_node(node)
for adj in adj_nodes:
self.add_edge(node, adj)
def remove_node(self, node: HostGraphNode = None) -> None:
if node is None:
return
self.graph.remove_node(node)
def remove_edge(self, pre_node: HostGraphNode = None, next_node: HostGraphNode = None) -> None:
if pre_node is None or next_node is None:
raise ValueError(f"Invalid edge from {pre_node} to {pre_node}.")
self.graph.remove_edge(pre_node, next_node)
def get_subgraph(self, nodes: List[HostGraphNode]) -> nx.DiGraph:
nodes = list(set(nodes))
for node in nodes:
if not self.is_node_exists(node):
raise ValueError(f"Failed to subtract subgraph because {node.op_name} is not in the graph.")
return self.graph.subgraph(nodes)
def highlight_subgraph(self, subgraph: nx.DiGraph = None) -> None:
pass
def get_node(self, node: HostGraphNode):
return self.graph[node] if node in self.graph else None
def get_node_by_name(self, node_name: str):
return self.nodes.get(node_name, None)
def is_node_exists(self, node: HostGraphNode):
return node in self.graph