"""
Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
"""
import re
class TableExtractor:
    """Regex-based helper that extracts table references from SQL text.

    This is a heuristic scanner, not a SQL parser: it looks for backtick-quoted
    identifiers and for identifiers that follow ``FROM`` / ``JOIN`` / ``INTO``.
    """

    # Any back-quoted identifier is treated as a table reference.
    BACKTICK_PATTERN = re.compile(r'`([^`]+)`')

    # ``FROM|JOIN|INTO [schema.]table``. The leading ``\b`` keeps identifiers
    # such as ``valid_from`` from being mistaken for the FROM keyword.
    FROM_JOIN_PATTERN = re.compile(
        r'\b(?:FROM|JOIN|INTO)\s+(?:([a-zA-Z0-9_]+)\.)?([a-zA-Z0-9_]+)',
        re.IGNORECASE
    )

    # ``FROM|JOIN|INTO ( SELECT ...`` -- the start of a derived table.
    _SUBQUERY_START_PATTERN = re.compile(
        r'\b(?:FROM|JOIN|INTO)\s*\(\s*SELECT\s+',
        re.IGNORECASE | re.DOTALL
    )

    # ``FROM|JOIN|INTO name`` -- a candidate CTE reference.
    _CTE_REF_PATTERN = re.compile(
        r'\b(?:FROM|JOIN|INTO)\s+([a-zA-Z_][a-zA-Z0-9_]*)',
        re.IGNORECASE
    )

    # Optional whitespace, optional ``AS``, then the alias identifier.
    _ALIAS_PATTERN = re.compile(
        r'\s*(?:AS\s+)?([a-zA-Z_][a-zA-Z0-9_]*)',
        re.IGNORECASE
    )

    # Characters that open/close a quoted region in which parentheses
    # must not be counted.
    _QUOTE_CHARS = ("'", '"', "`")

    # Words that can never be a table name.
    _SQL_KEYWORDS = frozenset({
        'SELECT', 'FROM', 'WHERE', 'JOIN', 'LEFT', 'RIGHT', 'INNER', 'OUTER',
        'FULL', 'CROSS', 'ON', 'AND', 'OR', 'NOT', 'IN', 'IS', 'NULL', 'AS',
        'GROUP', 'BY', 'ORDER', 'HAVING', 'LIMIT', 'OFFSET', 'UNION', 'ALL',
        'DISTINCT', 'CASE', 'WHEN', 'THEN', 'ELSE', 'END', 'OVER', 'PARTITION',
        'ROWS', 'RANGE', 'PRECEDING', 'FOLLOWING', 'CURRENT', 'ROW',
        'UNBOUNDED', 'INSERT', 'OVERWRITE', 'TABLE', 'VALUES', 'SET', 'WITH',
        'RECURSIVE', 'LATERAL', 'VIEW', 'TEMPORARY', 'TEMP', 'FUNCTION',
        'CAST', 'COALESCE', 'IF', 'NULLIF', 'GREATEST', 'LEAST',
        'RANK', 'DENSE_RANK', 'ROW_NUMBER', 'LAG', 'LEAD', 'FIRST_VALUE',
        'LAST_VALUE', 'NTH_VALUE', 'NTILE', 'CUME_DIST', 'PERCENT_RANK',
        'IFNULL', 'NVL', 'NVL2', 'DECODE', 'IIF',
        'ASC', 'DESC', 'NULLS', 'FIRST', 'LAST', 'USING', 'NATURAL', 'BETWEEN',
        'EXISTS', 'LIKE', 'REGEXP', 'RLIKE', 'CONCAT', 'SUBSTRING', 'TRIM',
        'LENGTH', 'UPPER', 'LOWER', 'ROUND', 'FLOOR', 'CEIL', 'ABS', 'MOD',
        'COUNT', 'SUM', 'AVG', 'MIN', 'MAX', 'STDDEV', 'VARIANCE',
    })

    # Words that may directly follow a table or a derived table and therefore
    # must not be interpreted as its alias.
    _ALIAS_STOP_WORDS = frozenset({
        'WHERE', 'ON', 'AND', 'OR', 'GROUP', 'ORDER', 'HAVING', 'LIMIT',
        'UNION', 'SET', 'AS', 'IN', 'EXISTS', 'BETWEEN', 'LIKE', 'IS',
        'NULL', 'NOT', 'INTO', 'FROM', 'JOIN', 'LEFT', 'RIGHT', 'INNER',
        'OUTER', 'FULL', 'CROSS', 'NATURAL', 'USING', 'PARTITION',
        'DISTINCT', 'ALL', 'ANY', 'SOME', 'CASE', 'WHEN', 'THEN', 'ELSE',
        'END', 'OVER', 'WINDOW', 'ROWS', 'RANGE', 'PRECEDING',
        'FOLLOWING', 'CURRENT', 'UNBOUNDED', 'FIRST', 'LAST', 'NULLS',
        'LATERAL', 'PIVOT', 'UNPIVOT', 'EXCEPT', 'INTERSECT', 'MINUS',
    })

    def __init__(self):
        # Names declared as CTEs; consulted by is_cte / is_physical_table.
        self.cte_names = set()

    def set_cte_names(self, cte_names):
        """Replace the known CTE names.

        :param cte_names: iterable of CTE names; ``None`` is treated as empty
        """
        self.cte_names = set(cte_names or ())

    def extract_tables(self, sql: str) -> list:
        """Extract every table reference found in a SQL fragment.

        Collects back-quoted identifiers and the identifiers that follow
        ``FROM`` / ``JOIN`` / ``INTO`` (the schema prefix of ``schema.table``
        is dropped), including those inside parenthesised groups, and removes
        SQL keywords.

        :param sql: SQL fragment
        :return: sorted list of distinct table names; empty for empty input
        """
        if not sql:
            return []
        tables = set(self.BACKTICK_PATTERN.findall(sql))
        # findall yields (schema, table); only the table part is reported.
        tables.update(
            table for _schema, table in self.FROM_JOIN_PATTERN.findall(sql)
        )
        tables.update(self._extract_subquery_tables(sql))
        # Sorted so the result does not depend on set iteration order.
        return sorted(self._filter_sql_keywords(tables))

    def _extract_subquery_tables(self, sql: str) -> set:
        """Collect table references inside each top-level parenthesised group.

        :param sql: SQL fragment
        :return: set of table names found inside balanced ``( ... )`` groups
        """
        tables = set()
        depth = 0
        group_start = -1
        for pos, char in enumerate(sql):
            if char == "(":
                if depth == 0:
                    group_start = pos
                depth += 1
            elif char == ")" and depth > 0:
                # A ")" seen at depth 0 is unmatched and is ignored;
                # it does not close any group.
                depth -= 1
                if depth == 0:
                    inner = sql[group_start + 1:pos]
                    tables.update(self.extract_tables(inner))
        return tables

    def _filter_sql_keywords(self, tables) -> list:
        """Drop SQL keywords and reserved words (case-insensitive).

        :param tables: iterable of candidate table names
        :return: list of the names that are not SQL keywords
        """
        return [
            name for name in tables
            if name.upper() not in self._SQL_KEYWORDS
        ]

    def is_cte(self, name) -> bool:
        """Return True if *name* is one of the registered CTE names."""
        return name in self.cte_names

    def is_physical_table(self, name) -> bool:
        """Return True if *name* is not a registered CTE name."""
        return not self.is_cte(name)

    def extract_subquery_aliases(self, sql: str) -> dict:
        """Map each derived-table alias to the SQL text of its subquery.

        Recognises ``FROM|JOIN|INTO ( SELECT ... ) [AS] alias``. A subquery
        without an alias (i.e. followed by a keyword such as ``WHERE``) or
        with an unbalanced parenthesis is skipped.

        :param sql: SQL fragment
        :return: dict of ``{alias: subquery_sql}``
        """
        if not sql:
            return {}
        aliases = {}
        for match in self._SUBQUERY_START_PATTERN.finditer(sql):
            open_pos = sql.find('(', match.start())
            close_pos = self._find_matching_paren(sql, open_pos)
            if close_pos is None:
                continue
            alias_match = self._ALIAS_PATTERN.match(sql, close_pos + 1)
            if alias_match is None:
                continue
            alias = alias_match.group(1)
            # The word after ")" may be the next clause, not an alias.
            if alias.upper() in self._ALIAS_STOP_WORDS:
                continue
            aliases[alias] = sql[open_pos + 1:close_pos]
        return aliases

    def _find_matching_paren(self, sql: str, open_pos: int):
        """Find the ``)`` that closes the ``(`` at *open_pos*.

        Parentheses inside single-quoted, double-quoted or back-quoted
        regions are ignored. A quote character preceded by a backslash is
        treated as escaped.

        :param sql: SQL string
        :param open_pos: index of the opening ``(``
        :return: index of the matching ``)``, or None if *open_pos* does not
            point at ``(`` or no matching ``)`` exists
        """
        if not 0 <= open_pos < len(sql) or sql[open_pos] != '(':
            return None
        depth = 1
        quote = None  # the quote character currently open, if any
        for pos in range(open_pos + 1, len(sql)):
            char = sql[pos]
            if char in self._QUOTE_CHARS and sql[pos - 1] != "\\":
                if quote is None:
                    quote = char
                    continue
                if char == quote:
                    quote = None
                    continue
            if quote is not None:
                continue
            if char == "(":
                depth += 1
            elif char == ")":
                depth -= 1
                if depth == 0:
                    return pos
        return None

    def extract_cte_aliases(self, sql: str) -> dict:
        """Map each alias of a CTE reference to the CTE name.

        Recognises ``FROM|JOIN|INTO cte [AS] alias`` where ``cte`` is one of
        the registered CTE names. An alias equal to the CTE name
        (case-insensitive) or equal to a SQL keyword is ignored.

        :param sql: SQL fragment
        :return: dict of ``{alias: cte_name}``
        """
        if not sql:
            return {}
        aliases = {}
        for match in self._CTE_REF_PATTERN.finditer(sql):
            cte_name = match.group(1)
            if not self.is_cte(cte_name):
                continue
            alias_match = self._ALIAS_PATTERN.match(sql, match.end())
            if alias_match is None:
                continue
            alias = alias_match.group(1)
            upper_alias = alias.upper()
            if upper_alias == cte_name.upper():
                continue
            if upper_alias in self._ALIAS_STOP_WORDS:
                continue
            aliases[alias] = cte_name
        return aliases