import itertools
import os
from collections import deque
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

from msprof_analyze.advisor.utils.file import FileOpen
from msprof_analyze.prof_common.file_manager import FileManager
from msprof_analyze.prof_common.logger import get_logger
logger = get_logger()
@dataclass
class Tensor:
    """
    Description of one input or output tensor of a graph node.

    The attributes are declared as real dataclass fields so that the generated
    ``__eq__`` and ``__repr__`` reflect the tensor content. Every instance gets
    its own list objects.
    """

    shape: List = field(default_factory=list)  # the parser appends one entry per "dim" line
    origin_shape: List = field(default_factory=list)
    shape_range: List = field(default_factory=list)
    origin_shape_range: List = field(default_factory=list)
    dtype: str = ""
    origin_data_type: str = ""
    format: str = ""  # the parser stores the value of "layout" lines here
    origin_format: List = field(default_factory=list)
@dataclass
class Attr:
    """
    One ``attr`` entry of the graph file: a key and the values parsed for it.

    Declared with real dataclass fields so that equality and ``repr`` reflect
    the content; every instance gets its own ``value`` list.
    """

    key: str = ""
    value: List = field(default_factory=list)
class HostGraphNode:
    """
    A single operator of a host graph.

    All attributes start out empty and are expected to be filled in by
    whoever builds the node.
    """

    def __init__(self):
        super().__init__()
        # Identity of the operator.
        self.graph_name, self.op_name, self.op_type = "", "", ""
        # ``inputs``/``outputs`` and ``input``/``output`` are separate lists;
        # each attribute gets its own fresh list object.
        self.inputs, self.input, self.outputs, self.output = [], [], [], []
        # Operator attributes.
        self.strides, self.pads = [], []
        self.groups = ""
        self.dilations = []
        self.kernelname = ""
        # NOTE(review): nothing in this class reads or exposes ``_attrs``.
        self._attrs = []

    def __repr__(self):
        return f"<node {self.op_name}>"
@dataclass
class HostGraph:
    """
    One graph of the parsed file together with its nodes.

    Declared with real dataclass fields so that equality and ``repr`` reflect
    the graph content; every instance gets its own containers.
    """

    name: str = ""
    nodes: Dict[str, "HostGraphNode"] = field(default_factory=dict)  # op name -> node
    inputs: List = field(default_factory=list)
    edges: List = field(default_factory=list)
    model_name: Optional[str] = None
    file_path: Optional[str] = None

    def build(self):
        """
        Link the nodes of this graph.

        For every node, its name is appended to the ``outputs`` list of each
        of its ``inputs`` that is a node of this graph; inputs that are not in
        ``nodes`` are skipped. Calling this more than once appends the same
        names again.
        """
        for consumer_name, consumer in self.nodes.items():
            for producer_name in consumer.inputs:
                producer = self.nodes.get(producer_name)
                if producer is not None:
                    producer.outputs.append(consumer_name)
class HostGraphParser:
    """
    Parse graph metadata from text file

    The file is read as nested ``key {`` ... ``}`` structures that contain
    ``key: value`` lines. All parsing happens in the constructor; afterwards
    the instance exposes ``graphs``, ``nodes`` (keyed by op name) and
    ``edges`` (``(producer, consumer)`` node pairs).
    """

    def __init__(self, file_path):
        """
        Parse ``file_path`` and collect its graphs, nodes and edges.

        :param file_path: Path of the graph text file.
        """
        # Raw lines of the structures that are still open; dumped to the debug
        # log when parsing fails. Bounded to the last 100 lines.
        self.buffer = deque(maxlen=100)
        self.line_no = 0
        self._file_path = file_path
        self.edges: List[Tuple[HostGraphNode, HostGraphNode]] = []
        self.nodes: Dict[str, HostGraphNode] = {}
        self.graphs = self._parse(self._file_path)
        self._get_node_dict()
        self._get_edges_list()
        # The first parsed graph is dropped once its nodes and edges have been
        # collected. Guarded because parsing may yield no graph at all.
        if self.graphs:
            del self.graphs[0]

    @staticmethod
    def _get_key_value(line):
        """
        Split a ``key: value`` line at the first colon.

        :param line: Line to split.
        :return: Tuple of the stripped key and the stripped value with
            surrounding double quotes removed.
        :raises IndexError: If the line contains no colon.
        """
        res = line.split(':', 1)
        return res[0].strip(), res[1].strip().strip('"')

    @staticmethod
    def _parse_attr(key, value, obj):
        """
        Store one ``key: value`` pair on the object currently being parsed.

        :param key: Key read from the line.
        :param value: Value read from the line.
        :param obj: Parse target; a list collects values, any other falsy
            target is ignored.
        """
        if isinstance(obj, list):
            # Lists only collect values. This is handled before the generic
            # hasattr()/setattr() fallback so that a key which happens to
            # match a list attribute name (e.g. "index") is appended instead
            # of being passed to setattr(), which a list rejects.
            if key != "val_type":
                obj.append(value)
            return
        if not obj:
            return
        if key == "dim" and hasattr(obj, "shape"):
            obj.shape.append(value)
        elif key == "name" and hasattr(obj, "op_name"):
            obj.op_name = value
        elif key == "name" and hasattr(obj, "name"):
            obj.name = value
        elif key == "dtype" and hasattr(obj, "dtype"):
            obj.dtype = value
        elif key == "layout" and hasattr(obj, "format"):
            obj.format = value
        elif key == "type" and hasattr(obj, "op_type"):
            obj.op_type = value
        elif key == "input" and hasattr(obj, "input"):
            # Only the part before the first ':' is kept as the input name.
            obj.inputs.append(value.strip('"').split(':')[0])
        elif key == "key" and hasattr(obj, "key"):
            obj.key = value
        elif hasattr(obj, key):
            setattr(obj, key, value)

    def _parse_struct(self, in_file, key, in_obj):
        """
        Parse the body of the structure that was just opened by ``key {``.

        Every handler consumes the lines up to the closing brace of the
        structure, so the caller continues after it.

        :param in_file: File object the lines are read from.
        :param key: Name in front of the opening brace; selects the handler.
        :param in_obj: Object that is being parsed in the enclosing structure.
        """

        def parse_shape(file, obj):
            # Shape content is written straight onto the enclosing object.
            self._parse_line(file, obj)

        def parse_input_desc(file, obj):
            tensor = self._parse_line(file, Tensor())
            if obj and hasattr(obj, "input"):
                obj.input.append(tensor)

        def parse_out_desc(file, obj):
            tensor = self._parse_line(file, Tensor())
            if obj and hasattr(obj, "output"):
                obj.output.append(tensor)

        def parse_op(file, obj):
            node = self._parse_line(file, HostGraphNode())
            if hasattr(obj, "name"):
                node.graph_name = obj.name
            # Nodes without a name are not registered.
            if obj and hasattr(obj, "nodes") and node.op_name:
                obj.nodes[node.op_name] = node

        def parse_graph(file, obj):
            graph = self._parse_line(file, HostGraph())
            obj.append(graph)

        def parse_attr(file, obj):
            attr = self._parse_line(file, Attr())
            if obj is None or isinstance(obj, list):
                # Nothing to attach the attribute to; setattr() on such a
                # target would raise. The structure itself has been consumed.
                return
            if hasattr(obj, attr.key):
                if attr.key not in ['format']:
                    setattr(obj, attr.key, attr.value)
            elif attr.key.endswith("_kernelname"):
                setattr(obj, "kernelname", attr.value)
            if obj and hasattr(obj, "get_attrs"):
                obj.get_attrs().append(attr)

        def parse_list(file, obj):
            value = []
            self._parse_line(file, value)
            if isinstance(obj, list):
                obj.append(value)

        def parse_value(file, obj):
            if hasattr(obj, "value"):
                obj.value = self._parse_line(file, obj.value)
            else:
                # Still consume the structure; otherwise its lines and its
                # closing brace would be attributed to the enclosing one.
                self._parse_line(file, None)

        def parse_default(file, _obj=None):
            """function with unused argument"""
            self._parse_line(file, None)

        parse_methods = {
            "shape": parse_shape,
            "input_desc": parse_input_desc,
            "output_desc": parse_out_desc,
            "op": parse_op,
            "graph": parse_graph,
            "attr": parse_attr,
            "list_list_int": parse_list,
            "list_list_i": parse_list,
            "list": parse_list,
            "value": parse_value,
        }
        parse_methods.get(key, parse_default)(in_file, in_obj)

    def _read_line(self, file):
        """
        Read the next line and keep the debug buffer in sync.

        A line ending with ``}`` removes buffered lines back to and including
        the most recent one ending with ``{``; any other line is buffered.

        :param file: File object to read from.
        :return: The line with surrounding whitespace stripped; empty at end
            of file.
        """
        self.line_no += 1
        line = file.readline()
        if line.strip().endswith('}'):
            end_line = ""
            while self.buffer and not end_line.strip().endswith("{"):
                end_line = self.buffer.pop()
        else:
            self.buffer.append(line)
        return line.strip()

    def _parse_line(self, file, obj=None):
        """
        Parse lines into ``obj`` until the current structure is closed.

        Reading stops at an empty line (which is also what end of file
        yields) or at a line ending with ``}``.

        :param file: File object to read from.
        :param obj: Parse target that receives the values.
        :return: ``obj``.
        """
        line = self._read_line(file)
        try:
            while line and not line.endswith("}"):
                if line.endswith('{'):
                    key = line.rstrip('{').strip()
                    self._parse_struct(file, key, obj)
                else:
                    key, value = self._get_key_value(line)
                    self._parse_attr(key, value, obj)
                line = self._read_line(file)
        except Exception:
            # Dump the buffered context once. The buffer is empty afterwards,
            # so the enclosing recursion levels re-raise without logging again.
            if self.buffer:
                logger.debug("***********************graph content**************************")
                while self.buffer:
                    line = self.buffer.popleft()
                    logger.debug(line)
                logger.debug("***********************graph content**************************")
            raise
        return obj

    def _parse(self, graph_file):
        """
        Parse all graphs of ``graph_file``.

        A parse failure is logged and whatever was parsed up to that point is
        kept.

        :param graph_file: Path of the graph text file.
        :return: List of built ``HostGraph`` objects; every graph carries the
            name of the first graph as ``model_name``.
        """
        graph_list = []
        with FileOpen(graph_file, "r") as file:
            try:
                graph_list = self._parse_line(file.file_reader, graph_list)
            except Exception:
                logger.error(
                    "Parse line %s of file %s failed, make sure the format is correct.", self.line_no, graph_file
                )
        graphs = [graph for graph in graph_list if isinstance(graph, HostGraph)]
        for graph in graphs:
            graph.model_name = graphs[0].name
            graph.file_path = self._file_path
            graph.build()
        return graphs

    def _get_edges_list(self) -> None:
        """
        Fill ``self.edges`` with unique ``(producer, consumer)`` node pairs.

        Edges are taken from both the ``inputs`` and the ``outputs`` of every
        node in ``self.nodes``; names that are not in ``self.nodes`` are
        skipped. Two edges are duplicates when both op names match.
        """
        if not self.graphs:
            return
        # Op-name pairs already present, for constant-time duplicate checks.
        seen = {(src.op_name, dst.op_name) for src, dst in self.edges}

        def add_edge(src, dst):
            edge_key = (src.op_name, dst.op_name)
            if edge_key not in seen:
                seen.add(edge_key)
                self.edges.append((src, dst))

        for node in self.nodes.values():
            for input_node_name in node.inputs:
                input_node = self.nodes.get(input_node_name)
                if input_node is not None:
                    add_edge(input_node, node)
            for output_node_name in node.outputs:
                output_node = self.nodes.get(output_node_name)
                if output_node is not None:
                    add_edge(node, output_node)

    def _get_node_dict(self) -> None:
        """
        Collect the nodes of all graphs into ``self.nodes`` keyed by op name.

        When several graphs contain a node with the same op name, the one
        from the later graph wins.
        """
        if not self.graphs:
            self.nodes = {}
            return
        self.nodes = {
            node.op_name: node
            for graph in self.graphs
            for node in graph.nodes.values()
        }
class QueryGraphNode:
    """
    Graph Node

    Every instance receives a running id from the class-level counter
    ``_ID``; the id together with the op type identifies the node.
    """
    _ID = 0

    def __init__(self, op_type: str, op_pass: str):
        """
        :param op_type: Operator type of the node.
        :param op_pass: Value returned by the ``op_pass`` property.
        """
        self._op_type = op_type
        self._id = QueryGraphNode._ID
        self._op_pass = op_pass
        QueryGraphNode._ID += 1

    def __eq__(self, other):
        # Let Python fall back to its default handling for foreign types
        # instead of failing on a missing attribute.
        if not isinstance(other, QueryGraphNode):
            return NotImplemented
        return self._op_type == other._op_type and self._id == other._id

    def __hash__(self):
        # Only the id is hashed: it never changes, whereas the op type can be
        # reassigned through the setter while the node sits in a set or dict.
        # Equal nodes always share the id, so this stays consistent with
        # __eq__.
        return hash(self._id)

    @property
    def op_type(self):
        return self._op_type

    @op_type.setter
    def op_type(self, op_type):
        self._op_type = op_type

    @property
    def op_name(self):
        return self._op_type + "_id_" + str(self._id)

    @property
    def op_pass(self):
        return self._op_pass

    @staticmethod
    def trim_string(string: str, length: int = -1):
        """
        Trim string to target length
        :param string: Original string
        :param length: Target length of string, -1 indicates original string.
        :return: Trimmed string
        :raises TypeError: If ``string`` is not a ``str``.
        """
        if not isinstance(string, str):
            raise TypeError(f"Param string must be a string type but got {type(string)}.")
        if length <= -1 or len(string) <= length:
            return string
        return string[:length]

    def get_property(self, name):
        """
        get property
        :param name: Attribute name to look up on this node.
        :return: The attribute value, or a callable returning ``None`` when
            the node has no attribute of that name.
        """
        return getattr(self, name, lambda: None)
class QueryGraphParser:
    """
    Build query graphs from a YAML fusion rule database.

    ``fusion_rules`` maps every rule name to a list of
    ``(nodes, edges, graph_name)`` tuples.
    """

    def __init__(self, rule_database_path: str):
        """
        :param rule_database_path: Path of the YAML rule database; a relative
            path is resolved as described in ``load_database``.
        """
        self._fusion_rules: Dict[str, List[Tuple]] = dict()
        self.load_database(rule_database_path)
        self.num_rules = sum(len(v) for v in self._fusion_rules.values())

    @property
    def fusion_rules(self):
        return self._fusion_rules

    @staticmethod
    def build_query_graph_v0(graph_name: str, graph_struct: List[str]) -> List[Tuple]:
        """
        Build a single chain-shaped query graph.

        :param graph_name: Name of the rule.
        :param graph_struct: Op types in chain order; consecutive entries are
            connected by an edge.
        :return: One-element list holding ``(nodes, edges, graph_name)``.
        """
        chain = [QueryGraphNode(op_type, graph_name) for op_type in graph_struct]
        nodes = {node.op_name: node for node in chain}
        edges = list(zip(chain, chain[1:]))
        return [(nodes, edges, graph_name,)]

    @staticmethod
    def build_query_graph_v1(graph_name: str,
                             nodes_list: List[Dict],
                             edges_list: List[List[str]]) -> List[Tuple]:
        """
        Build one query graph per combination of candidate op types.

        :param graph_name: Name of the rule; graphs after the first one are
            named ``<graph_name>#<index>``.
        :param nodes_list: Single-entry dicts mapping a node name to one op
            type or to a list of alternative op types.
        :param edges_list: Pairs of node names, ``[pre_node, next_node]``.
        :return: List of ``(nodes, edges, graph_name)`` tuples.
        """
        graphs = []
        node_index = dict()
        multi_node_list = []
        for index, node in enumerate(nodes_list):
            (node_name, op_type), = node.items()
            if isinstance(op_type, str):
                op_type = [op_type]
            multi_node_list.append([QueryGraphNode(op, graph_name) for op in op_type])
            node_index[node_name] = index
        # One sub graph per element of the cartesian product of candidates.
        for index, sub_nodes in enumerate(itertools.product(*multi_node_list)):
            sub_graph_name = graph_name if index == 0 else f"{graph_name}#{index}"
            sub_node = {node.op_name: node for node in sub_nodes}
            # An edge that names a node missing from nodes_list resolves to
            # index None and therefore fails with TypeError.
            sub_edge = [
                (sub_nodes[node_index.get(pre_node)], sub_nodes[node_index.get(next_node)])
                for pre_node, next_node in edges_list
            ]
            graphs.append((sub_node, sub_edge, sub_graph_name,))
        return graphs

    def load_database(self, rule_database):
        """
        Read the YAML rule database and parse its rules.

        :param rule_database: Path of the database; a relative path is
            resolved against the directory two levels above this file.
        :raises FileNotFoundError: If the resolved path does not exist.
        """
        if not os.path.isabs(rule_database):
            rule_database = os.path.join(os.path.dirname(__file__),
                                         "../", "../",
                                         rule_database)

        if not os.path.exists(rule_database):
            raise FileNotFoundError(f"Path {rule_database} does not exist.")
        database = FileManager.read_yaml_file(rule_database)
        self.parse_yaml(database)

    def parse_yaml(self, yaml_database):
        """
        Register the rules of the ``GraphFusion`` and ``UBFusion`` sections.

        Entries that are not dicts and rules whose ``version`` is neither 0
        nor 1 are skipped. The passed database is not modified.

        :param yaml_database: Loaded YAML content; an empty or missing
            database registers nothing.
        """
        if not yaml_database:
            return
        # Build a fresh list so the caller's "GraphFusion" list is left
        # untouched; "or []" also covers sections that are present but null.
        fusion_strategy_list = [
            *(yaml_database.get("GraphFusion") or []),
            *(yaml_database.get("UBFusion") or []),
        ]
        for fusion_strategy in fusion_strategy_list:
            if not isinstance(fusion_strategy, dict):
                continue
            (fusion_name, strategy), = fusion_strategy.items()
            version = strategy.get("version", 0)
            if version in (0, "0"):
                self._fusion_rules[fusion_name] = self.build_query_graph_v0(fusion_name,
                                                                            strategy.get('struct', []))
            elif version in (1, "1"):
                self._fusion_rules[fusion_name] = self.build_query_graph_v1(fusion_name,
                                                                            strategy.get('nodes', []),
                                                                            strategy.get('edges', []))