"""
Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
"""
import json
import os
import re
import hashlib
from collections import defaultdict
from omnihelper.parser.cte.parser import CTEParser
from omnihelper.parser.function.function_builder import FunctionBuilder
from omnihelper.parser.function.function_checker import FunctionChecker
from omnihelper.parser.type_matcher import TypeMatcher
from omnihelper.util.common_util import CommonUtil
from omnihelper.util.func_util import NOT_SUPPORTED_TYPE
class FunctionParser:
    """Scan an event's physical plan for function calls reported as unsupported.

    The set of known functions is read from two JSON dictionaries: the
    built-in list at ``DICTIONARY_PATH`` (required) and the user-defined list
    at ``UDF_DICTIONARY_PATH`` (optional, only read when the file exists).
    """

    DICTIONARY_PATH = os.path.join(CommonUtil.get_execute_path(), "resources", "omni_function_dictionary.json")
    UDF_DICTIONARY_PATH = os.path.join(CommonUtil.get_execute_path(), "resources", "udf_dictionary.json")

    def __init__(self):
        """Initialise empty state, then load the function dictionaries."""
        self.function_list = []           # raw entries loaded from DICTIONARY_PATH
        self.omni_functions = []          # lower-cased names from function_list
        self.udf_list = []                # raw entries loaded from UDF_DICTIONARY_PATH
        self.user_defined_functions = []  # lower-cased names from udf_list
        self.partial_func_mapping = {}    # func_name -> its "hash_agg_func" value
        self.all_funcs = []               # omni_functions + user_defined_functions
        self.func_pattern = ""            # replaced by a compiled regex in load_func_list
        self.function_builder = None
        self.load_func_list()

    def load_func_list(self):
        """Load both dictionaries and derive the lookup structures from them.

        Populates the name lists, the regex that matches ``<known_func>(``,
        the ``partial_func_mapping`` table and the ``FunctionBuilder``.

        :raises Exception: if a dictionary file cannot be opened or parsed;
            the underlying error is attached as ``__cause__``.
        """
        try:
            with open(self.DICTIONARY_PATH, "r", encoding="utf-8") as f:
                self.function_list = json.load(f)
        except Exception as e:
            raise Exception("Failed to load the functions list: " + str(e)) from e

        # The UDF dictionary is optional; a missing file is not an error.
        if os.path.exists(self.UDF_DICTIONARY_PATH):
            try:
                with open(self.UDF_DICTIONARY_PATH, "r", encoding="utf-8") as f:
                    self.udf_list = json.load(f)
            except Exception as e:
                raise Exception("Failed to load the user defined function: " + str(e)) from e

        self.omni_functions = [func.get("func_name").lower() for func in self.function_list]
        self.user_defined_functions = [func.get("func_name").lower() for func in self.udf_list]
        self.all_funcs = self.omni_functions + self.user_defined_functions

        # Case-insensitive: any known function name followed by "(", with the
        # rest of the line captured as the second group.
        self.func_pattern = re.compile("({})\\((.*)".format("|".join(map(re.escape, self.all_funcs))), re.I)

        for func in self.function_list:
            if func.get("hash_agg_func"):
                self.partial_func_mapping[func["func_name"]] = func["hash_agg_func"]

        self.function_builder = FunctionBuilder(self.func_pattern, self.all_funcs)

    def parse_event(self, event, column_type, table_schema):
        """Core logic for parsing the expressions and functions of one event.

        :param event: dict read via the keys "physical plan", "node metrics"
            and "original query".
        :param column_type: mapping copied into the parameter-type mapping
            before any plan line is inspected.
        :param table_schema: mapping of table name to schema; only the part of
            each table name after the last "." is kept as the key.
        :return: aggregated list produced by ``count_func_times``, or ``[]``
            when no function dictionary is loaded or the event has no
            physical plan.
        """
        if not self.function_list:
            return []

        analysis_result = []
        param_type_mapping = {}
        alias_map = {}
        param_type_mapping.update(column_type)

        physical_plan = event.get("physical plan")
        if not physical_plan:
            return []

        node_metrics = event.get("node metrics")
        if node_metrics:
            TypeMatcher.extract_param_type(node_metrics, param_type_mapping)

        ori_sql = event.get("original query")
        cte_parser = CTEParser()
        # Both values are assigned on the TypeMatcher class itself, so they
        # are visible to every later TypeMatcher call, not just this event.
        TypeMatcher.cte_subquery_table_mapping = cte_parser.parse(ori_sql)
        TypeMatcher.table_schema = {table.rsplit(".", 1)[-1]: schema for table, schema in table_schema.items()}

        update_physical_plan = self.preprocess_physical_plan(physical_plan)
        operator_blocks = self.split_operators(update_physical_plan)
        if not operator_blocks:
            # Fallback: no operator blocks were produced, parse the plan as one unit.
            analysis_result = self.parse_physical_plan(update_physical_plan, event, param_type_mapping, alias_map)
        for block in operator_blocks:
            analysis_result.extend(self.parse_physical_plan(block, event, param_type_mapping, alias_map))
        return self.count_func_times(analysis_result)

    def preprocess_physical_plan(self, physical_plan):
        """Split the plan text into cleaned lines.

        Empty lines and lines whose stripped form starts with "+- ==" are
        dropped. A line starting with "Condition" is reduced to the content of
        its outer parentheses. Every kept line is then passed through
        ``extract_bracket_content``.

        :param physical_plan: the plan as a single newline-separated string.
        :return: list of processed line strings.
        """
        preprocess_phy_plan = []
        for line in physical_plan.split("\n"):
            if not line or line.strip().startswith("+- =="):
                continue
            if line.startswith("Condition"):
                condition_match = re.match(r"Condition\s*:\s*\((.*)\)\s*", line)
                if condition_match:
                    line = condition_match.group(1)
            preprocess_phy_plan.append(self.extract_bracket_content(line))
        return preprocess_phy_plan

    def extract_bracket_content(self, line):
        """Strip an operator-argument label (e.g. Input, Keys, Output) from a line.

        Only the value is kept. Example: ``Input [12]: [sum#167]`` becomes
        ``sum#167``. Lines that do not end with the ``[..]: [..]`` shape are
        returned stripped but otherwise unchanged.
        """
        line = line.strip()
        m = re.search(r"\[[^\[\]]*\]\s*:\s*\[(.*)\]$", line)
        if m:
            return m.group(1).strip()
        return line

    def split_operators(self, physical_plan):
        """Group plan lines into blocks, one per operator header.

        A line whose stripped form starts with ``(<digits>)`` opens a new
        block. Lines that appear before the first such header form a block of
        their own.

        :param physical_plan: list of line strings.
        :return: list of blocks, each a list of lines; ``[]`` for empty input.
        """
        result = []
        current_block = []
        for line in physical_plan:
            if re.match(r"^\(\d+\)", line.strip()):
                if current_block:
                    result.append(current_block)
                current_block = []
            current_block.append(line)
        if current_block:
            result.append(current_block)
        return result

    def parse_physical_plan(self, physical_plan, event, param_type_mapping, alias_map):
        """Collect the unsupported function calls found in a list of plan lines.

        First pass: update ``param_type_mapping`` from lines containing
        "ReadSchema" and ``alias_map`` from every line, so both are complete
        before any function is checked. Second pass: check every function
        expression found in each line.

        :param physical_plan: list of (preprocessed) plan lines.
        :param event: the event dict; "original query" is read from it.
        :param param_type_mapping: mapping updated in place by the first pass.
        :param alias_map: mapping updated in place by the first pass.
        :return: list of dicts built by ``build_not_supported_func``.
        """
        analysis_result = []
        for line in physical_plan:
            if "ReadSchema" in line:
                TypeMatcher.extract_param_type(line, param_type_mapping)
            CommonUtil.extract_alias_map(line, alias_map)

        for line in physical_plan:
            func_pairs = self.function_builder.search_func_expr_pairs(line)
            if not func_pairs:
                continue
            for pair in func_pairs:
                func_name = pair.get("func")
                params = pair.get("params")
                input_type = TypeMatcher.get_input_type(params, param_type_mapping, event, pair,
                                                        self.function_builder, alias_map, 0)
                # NOTE(review): a fresh checker is built for every pair; it
                # could probably be hoisted if FunctionChecker keeps no state
                # between calls - confirm before changing.
                function_checker = FunctionChecker(self.function_list, self.udf_list)
                is_not_supported_func = function_checker.check_support_status(func_name, params, input_type,
                                                                              event.get("original query"))
                if not is_not_supported_func:
                    continue
                analysis_result.append(self.build_not_supported_func(func_name, event, input_type, params, line))
        return analysis_result

    def build_not_supported_func(self, func_name, event, input_type, params, line):
        """Build the result record for one unsupported function occurrence.

        :param func_name: name as found in the plan; replaced by its
            ``partial_func_mapping`` entry when one exists.
        :param event: the event dict; "original query" is hashed.
        :param input_type: sequence of resolved parameter types.
        :param params: sequence of parameter expressions.
        :param line: the plan line in which the function was found.
        :return: dict with keys "func_name", "sql_hash", "input",
            "not_supported_line" and "not_supported_params".
        """
        func_name = self.partial_func_mapping.get(func_name, func_name)
        # Last six hex digits of the query's SHA-256 digest.
        sql_hash = hashlib.sha256(event.get("original query").encode("utf-8")).hexdigest()[-6:]
        # zip() pairs each parameter with its type and stops at the shorter
        # sequence, so a parameter with no resolved type is skipped instead of
        # raising IndexError.
        not_supported_params = [
            param
            for param, param_type in zip(params, input_type)
            if param_type in NOT_SUPPORTED_TYPE
        ]
        has_unsupported_type = any(param_type in NOT_SUPPORTED_TYPE for param_type in input_type)
        return {
            "func_name": func_name,
            "sql_hash": sql_hash,
            "input": input_type,
            "not_supported_line": line if has_unsupported_type else "",
            "not_supported_params": not_supported_params
        }

    def count_func_times(self, event_result):
        """Aggregate occurrence records by (func_name, sql_hash, input types).

        For each distinct key the output contains either a single entry with
        an empty "not_supported_line" (when no occurrence carried a line), or
        one entry per flagged line, numbered from 1 and suffixed with the
        parameter names recorded for that same occurrence.

        :param event_result: list of dicts as built by
            ``build_not_supported_func``.
        :return: list of dicts with keys "func_name", "sql_hash", "input"
            (a tuple), "times", "is_udf", "not_supported_line" and
            "not_supported_params".
        """
        counter = defaultdict(int)
        # Each flagged line is stored together with the params of the same
        # occurrence, so the pairing cannot drift when an occurrence has a
        # line but an empty params list.
        flagged_lines = defaultdict(list)
        not_supported_params = defaultdict(list)
        for item in event_result:
            key = (item["func_name"], item["sql_hash"], tuple(item["input"]))
            counter[key] += 1
            params = item.get("not_supported_params") or []
            if params:
                not_supported_params[key].append(params)
            if item.get("not_supported_line"):
                flagged_lines[key].append((item["not_supported_line"], params))

        update_event_result = []
        for key, times in counter.items():
            func_name, sql_hash, input_type = key
            is_udf = func_name.lower() in self.user_defined_functions
            params_for_key = not_supported_params[key]
            occurrences = flagged_lines[key]
            if not occurrences:
                update_event_result.append({
                    "func_name": func_name,
                    "sql_hash": sql_hash,
                    "input": input_type,
                    "times": times,
                    "is_udf": is_udf,
                    "not_supported_line": '',
                    "not_supported_params": params_for_key
                })
            for index, (flagged_line, params) in enumerate(occurrences, start=1):
                param = ", ".join(params)
                nested_line = f"{index}. {flagged_line} 参数名:{param}"
                update_event_result.append({
                    "func_name": func_name,
                    "sql_hash": sql_hash,
                    "input": input_type,
                    "times": times,
                    "is_udf": is_udf,
                    "not_supported_line": nested_line,
                    "not_supported_params": params_for_key
                })
        return update_event_result