"""
Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
"""
import json
import os
import re
import hashlib
from collections import defaultdict
from omnihelper.parser.function.function_builder import FunctionBuilder
from omnihelper.parser.type_matcher import TypeMatcher
from omnihelper.util.common_util import CommonUtil
class OpParser:
    """Analyse physical-plan events and report operators that are not supported.

    On construction the parser loads three JSON resources (operator-name
    mapping, operator dictionary, function dictionary) plus an optional
    user-defined-function dictionary. ``parse_event`` then walks the operator
    blocks of one event's physical plan and collects, for every operator that
    ``evaluate_support_status`` rejects, its inputs, outputs and the metrics
    found in the event's node-metrics text.
    """

    MAPPING_PATH = os.path.join(CommonUtil.get_execute_path(), "resources", "omni_opname_mapping_dictionary.json")
    DICTIONARY_PATH = os.path.join(CommonUtil.get_execute_path(), "resources", "omni_op_dictionary.json")
    FUNC_DICTIONARY_PATH = os.path.join(CommonUtil.get_execute_path(), "resources", "omni_function_dictionary.json")
    UDF_DICTIONARY_PATH = os.path.join(CommonUtil.get_execute_path(), "resources", "udf_dictionary.json")

    # Patterns are compiled once here instead of on every loop iteration.
    _NODE_BLOCK_PATTERN = re.compile(r'^id:(\d+)\s+name:([^\s]+).*')
    _PLAN_BLOCK_PATTERN = re.compile(r'^\(\d+\).*')
    _CODEGEN_PATTERN = re.compile(r'WholeStageCodegen\s+\(\d+\)')
    # NOTE: the capture group intentionally includes the opening "(" that
    # follows "SQLPlanMetric"; the duration/size patterns below rely on it.
    _METRIC_PATTERN = re.compile(r'SQLPlanMetric\s*([^)]+)')
    _OUTPUT_ROWS_PATTERN = re.compile(r'number of output rows\s*,(.*?),\s*sum')
    _DURATION_PATTERN = re.compile(r'\(\s*duration\s*,(.*?),\s*timing')
    _FILE_SIZE_PATTERN = re.compile(r'\(\s*size of files read\s*,(.*?),\s*size')
    _CLUSTER_PATTERN = re.compile(r'cluster\s+(\d+)\s*:\s*(.+)')
    _INPUT_PATTERN = re.compile(r'Input\s*\[\d+\]:\s*\[([^\]]+)\]')
    _OUTPUT_PATTERN = re.compile(r'Output\s*\[\d+\]:\s*\[([^\]]+)\]')

    def __init__(self):
        """Load the operator mapping, operator dictionary and function list.

        Raises:
            FileNotFoundError: If the mapping or operator dictionary is missing.
            ValueError: If the mapping or operator dictionary is not valid JSON.
            Exception: For any other failure while loading the resources.
        """
        self.opname_mapping = {}
        self.op_dictionary = {}
        self.omni_ops = {}
        self.function_builder = None
        self._load_op_mapping()
        self._load_op_dictionary()
        self._load_func_list()

    def _load_func_list(self):
        """Build ``self.function_builder`` from the function dictionaries.

        The built-in function dictionary is mandatory. The user-defined
        function dictionary is optional and is only read when the file exists.

        Raises:
            Exception: If either dictionary cannot be opened or parsed.
        """
        try:
            with open(self.FUNC_DICTIONARY_PATH, "r", encoding="utf-8") as f:
                function_list = json.load(f)
        except Exception as e:
            raise Exception("Failed to load the functions list: " + str(e)) from e

        # Default to "no UDFs" so a missing optional file does not leave the
        # name unbound when the function names are collected below.
        udf_list = []
        if os.path.exists(self.UDF_DICTIONARY_PATH):
            try:
                with open(self.UDF_DICTIONARY_PATH, "r", encoding="utf-8") as f:
                    udf_list = json.load(f)
            except Exception as e:
                raise Exception("Failed to load the user defined function: " + str(e)) from e

        omni_functions = [func.get("func_name").lower() for func in function_list]
        user_defined_functions = [func.get("func_name").lower() for func in udf_list]
        all_funcs = omni_functions + user_defined_functions
        # Case-insensitive match of "<known function>(" followed by the rest of the text.
        func_pattern = re.compile("({})\\((.*)".format("|".join(map(re.escape, all_funcs))), re.I)
        self.function_builder = FunctionBuilder(func_pattern, all_funcs)

    def _load_op_mapping(self):
        """Load the operator-name mapping and its ``omni_op_list`` entry.

        Raises:
            FileNotFoundError: If the mapping file does not exist.
            ValueError: If the mapping file is not valid JSON.
            Exception: For any other failure, including a missing
                ``omni_op_list`` key.
        """
        try:
            with open(self.MAPPING_PATH, "r", encoding="utf-8") as f:
                self.opname_mapping = json.load(f)
            self.omni_ops = self.opname_mapping["omni_op_list"]
        except FileNotFoundError as e:
            raise FileNotFoundError(f"Opname mapping file not found: {self.MAPPING_PATH}") from e
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON format in mapping file: {self.MAPPING_PATH}") from e
        except Exception as e:
            raise Exception(f"Unexpected error while loading mapping file: {self.MAPPING_PATH}, error: {e}") from e

    def _load_op_dictionary(self):
        """Load the operator dictionary into ``self.op_dictionary``.

        Raises:
            FileNotFoundError: If the dictionary file does not exist.
            ValueError: If the dictionary file is not valid JSON.
            Exception: For any other failure while reading the file.
        """
        try:
            with open(self.DICTIONARY_PATH, "r", encoding="utf-8") as f:
                self.op_dictionary = json.load(f)
        except FileNotFoundError as e:
            raise FileNotFoundError(f"Op dictionary file not found: {self.DICTIONARY_PATH}") from e
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON format in dictionary file: {self.DICTIONARY_PATH}") from e
        except Exception as e:
            raise Exception(
                f"Unexpected error while loading dictionary file: {self.DICTIONARY_PATH}, error: {e}"
            ) from e

    def _apply_metric_line(self, node, line):
        """Update *node* in place from one line of a node-metrics block.

        Only the first ``SQLPlanMetric`` occurrence on the line is examined.
        """
        codegen_match = self._CODEGEN_PATTERN.search(line)
        if codegen_match:
            # Codegen nodes are identified by their full "WholeStageCodegen (N)" label.
            node['name'] = codegen_match.group(0)

        metric_match = self._METRIC_PATTERN.search(line)
        if not metric_match:
            return
        metric_content = metric_match.group(1)

        if 'number of output rows' in metric_content:
            rows_match = self._OUTPUT_ROWS_PATTERN.search(metric_content)
            if rows_match:
                # Thousands separators are stripped before conversion.
                node['number_of_output_rows'] = int(rows_match.group(1).replace(",", ""))
        elif 'duration' in metric_content:
            duration_match = self._DURATION_PATTERN.search(metric_content)
            if duration_match:
                time_str = duration_match.group(1)
                node['duration'] = time_str
                node['duration_seconds'] = CommonUtil.parse_time_to_seconds(time_str)
        elif 'size of files read' in metric_content:
            size_match = self._FILE_SIZE_PATTERN.search(metric_content)
            if size_match:
                size_str = size_match.group(1)
                node['size'] = size_str
                node['size_mb'] = CommonUtil.parse_size_to_mb(size_str)

    def _process_node_metrics(self, node_metrics):
        """Extract per-node metrics and cluster membership from *node_metrics*.

        The text is expected to hold a ``[PlanMetric]`` section made of
        blank-line separated ``id:<n> name:<op>`` blocks, optionally followed
        by a ``[SubGraph]`` section with ``cluster <id>: <node ids>`` lines.
        A missing section is treated as empty rather than as an error.

        :param node_metrics: string containing the node metric information
        :return: tuple ``(nodes, node_name_mapping)`` where ``nodes`` maps a
            node id to its metric dict and ``node_name_mapping`` maps an
            operator name to the list of node ids carrying that name
        """
        nodes = {}
        node_name_mapping = {}
        if not node_metrics:
            return nodes, node_name_mapping

        # partition() tolerates a missing "[SubGraph]" section, where a
        # two-target unpack of split() would raise ValueError.
        plan_part, _, subgraph_part = node_metrics.partition("\n\n[SubGraph]")
        plan_segments = plan_part.split("[PlanMetric]\n")
        if len(plan_segments) > 1:
            plan_part = plan_segments[1]

        for block in plan_part.split("\n\n"):
            block = block.strip()
            block_match = self._NODE_BLOCK_PATTERN.match(block)
            if not block_match:
                continue
            node_id = int(block_match.group(1))
            raw_name = block_match.group(2).lower()
            # Use the mapped operator name when the mapping has a truthy entry.
            op_name = self.opname_mapping.get(raw_name) or raw_name
            node_name_mapping.setdefault(op_name, []).append(node_id)
            node = {
                'id': node_id,
                'name': op_name,
                'number_of_output_rows': 0,
                'duration': None,
                'duration_seconds': 0,
                'size': None,
                'size_mb': 0,
                'cluster': [],
            }
            nodes[node_id] = node
            for line in block.split('\n'):
                self._apply_metric_line(node, line)

        for line in subgraph_part.strip().split('\n'):
            if 'cluster' not in line:
                continue
            cluster_match = self._CLUSTER_PATTERN.search(line)
            if not cluster_match:
                continue
            cluster_id = int(cluster_match.group(1))
            member_ids = [int(token) for token in cluster_match.group(2).split()]
            for member_id in member_ids:
                member = nodes.get(member_id)
                # A cluster may list a node that had no parsable block; skip it.
                if member is not None:
                    member['cluster'].append(cluster_id)
        return nodes, node_name_mapping

    @staticmethod
    def _summarize_op_metrics(opname, nodes, node_name_mapping):
        """Aggregate rows, size and running time over every node named *opname*.

        :return: tuple ``(output_rows, output_sizes, running_time)`` where
            ``running_time`` is a newline-joined string
        """
        time_str_parts = []
        total_seconds = 0
        output_rows = 0
        output_sizes = 0
        for node_id in node_name_mapping.get(opname, []):
            node_info = nodes.get(node_id)
            if node_info is None:
                continue
            output_rows += node_info['number_of_output_rows']
            output_sizes += node_info['size_mb']
            if node_info['cluster']:
                # Clustered nodes report the duration of the enclosing cluster
                # node (looked up by cluster id), listed once per cluster.
                for cluster_id in node_info['cluster']:
                    cluster_node = nodes.get(cluster_id)
                    if cluster_node is None:
                        continue
                    cluster_time_str = f"{cluster_node['name']}:{cluster_node['duration']}"
                    if cluster_time_str not in time_str_parts:
                        time_str_parts.append(cluster_time_str)
            else:
                total_seconds += node_info['duration_seconds']
        total_seconds = round(total_seconds, 6)
        time_str_parts.append(f"{total_seconds} s")
        return output_rows, round(output_sizes, 9), "\n".join(time_str_parts)

    def parse_event(self, event, column_type):
        """Core expression/function analysis for a single event.

        :param event: mapping read with the keys ``"physical plan"``,
            ``"node metrics"`` and ``"original query"``
        :param column_type: mapping merged into the parameter-type table
            before parsing; ``None`` or empty is accepted
        :return: tuple ``(flag, results)``. ``(True, [])`` as soon as one
            plan block's operator name is listed in ``self.omni_ops``;
            ``(False, [])`` when the event has no physical plan; otherwise
            ``(False, results)`` with the de-duplicated entries for every
            operator rejected by ``evaluate_support_status``.
        """
        param_type_mapping = {}
        if column_type:
            param_type_mapping.update(column_type)

        physical_plan = event.get("physical plan")
        if not physical_plan:
            print("no physical plan")
            return False, []

        nodes = {}
        node_name_mapping = {}
        node_metrics = event.get("node metrics")
        if node_metrics:
            TypeMatcher.extract_param_type(node_metrics, param_type_mapping)
            nodes, node_name_mapping = self._process_node_metrics(node_metrics)

        plan_blocks = self.preprocess_physical_plan(physical_plan)
        # A missing query is hashed as the empty string instead of crashing.
        original_query = event.get("original query") or ""
        sql_hash = hashlib.sha256(original_query.encode("utf-8")).hexdigest()[-6:]

        alias_map = {}
        for line in physical_plan.split("\n"):
            CommonUtil.extract_alias_map(line, alias_map)

        analysis_result = []
        for block in plan_blocks:
            if "ReadSchema" in block:
                TypeMatcher.extract_param_type(block, param_type_mapping)

            # The header line looks like "(<n>) <OpName> ..."; skip blocks
            # whose header carries no operator name.
            header_tokens = block.split("\n")[0].split()
            if len(header_tokens) < 2:
                continue
            opname = header_tokens[1].lower()
            if opname in self.omni_ops:
                return True, []

            input_match = self._INPUT_PATTERN.search(block)
            input_list = TypeMatcher.parse_param_list(input_match, param_type_mapping,
                                                      alias_map, event, self.function_builder)
            if self.evaluate_support_status(opname, input_list):
                continue

            # Only mapped operators reach this point, so the lookup yields the mapped name.
            opname = self.opname_mapping.get(opname)
            output_match = self._OUTPUT_PATTERN.search(block)
            output_list = TypeMatcher.parse_param_list(output_match, param_type_mapping,
                                                       alias_map, event, self.function_builder)
            output_rows, output_sizes, running_time = self._summarize_op_metrics(
                opname, nodes, node_name_mapping)
            analysis_result.append(
                {
                    "op_name": opname,
                    "sql_hash": sql_hash,
                    "input_list": input_list,
                    "output_list": output_list,
                    "output_rows": output_rows,
                    "output_sizes": output_sizes,
                    "running_time": running_time,
                }
            )
        return False, self.count_op_times(analysis_result)

    def preprocess_physical_plan(self, physical_plan):
        """Split *physical_plan* on blank lines and keep the operator blocks.

        :param physical_plan: physical plan text
        :return: list of stripped blocks that start with ``(<digits>)``
        """
        stripped_blocks = (chunk.strip() for chunk in physical_plan.split("\n\n"))
        return [chunk for chunk in stripped_blocks if self._PLAN_BLOCK_PATTERN.match(chunk)]

    def evaluate_support_status(self, opname, input_list):
        """Tell whether *opname* with the given inputs needs no reporting.

        :param opname: lower-cased operator name taken from the plan
        :param input_list: parsed input entries of the operator
        :return: ``True`` when the operator is absent from the name mapping,
            or when every input is in the operator's ``supported_list`` (with
            no inputs: when that list is non-empty); ``False`` otherwise
        """
        if opname not in self.opname_mapping:
            return True
        real_op_name = self.opname_mapping.get(opname)
        supported = self.op_dictionary.get(real_op_name, {}).get("supported_list") or []
        if not input_list:
            return bool(supported)
        return all(item in supported for item in input_list)

    def count_op_times(self, event_result):
        """Collapse identical result entries and count their occurrences.

        :param event_result: list of result dicts built by ``parse_event``
        :return: list of dicts sorted by ``op_name``; each carries the
            original fields (``input_list``/``output_list`` as tuples) plus
            ``times``, the number of identical entries
        """
        counter = defaultdict(int)
        for item in event_result:
            # Lists are converted to tuples so the whole record is hashable.
            key = (item["op_name"],
                   item["sql_hash"],
                   tuple(item["input_list"]),
                   tuple(item["output_list"]),
                   item["running_time"],
                   item["output_rows"],
                   item["output_sizes"])
            counter[key] += 1

        field_names = ("op_name", "sql_hash", "input_list", "output_list",
                       "running_time", "output_rows", "output_sizes")
        update_event_result = []
        for key, times in counter.items():
            record = dict(zip(field_names, key))
            record["times"] = times
            update_event_result.append(record)
        return sorted(update_event_result, key=lambda x: x["op_name"])