import ast
import re
from datetime import datetime
from itertools import count
import logging
import sqlparse
from sqlparse.sql import Identifier, IdentifierList
from sqlparse.sql import Where, Comparison, Function, Parenthesis
from sqlparse.tokens import Keyword, DML, Token
from sqlparse.tokens import Name
# Comparison-operator mnemonics, presumably less-than / less-or-equal / greater-than /
# greater-or-equal / equal / not-equal. Not referenced by the functions in this section.
OPERATOR = (
    'lt', 'lte',
    'gt', 'gte',
    'eq', 'neq',
)
def analyze_column(column, where_clause):
    """
    Add the left-hand identifiers of the WHERE clause's top-level comparisons to `column`.

    Only direct children of `where_clause` are inspected; comparisons nested in
    parentheses are not visited.

    :param column: set-like object with an ``add`` method; updated in place.
    :param where_clause: sqlparse ``Where`` group.
    """
    for tok in where_clause.tokens:
        if not isinstance(tok, Comparison):
            continue
        lhs = tok.left
        if isinstance(lhs, Identifier):
            column.add(lhs.value)
def get_columns(sql):
    """
    Return the distinct column names compared in the WHERE clause(s) of the first statement.

    :param sql: str, SQL text; only the first parsed statement is examined.
    :return: list of column names (order follows set iteration, i.e. unspecified).
    """
    found = set()
    statement = sqlparse.parse(sql)[0]
    where_groups = (tok for tok in statement if isinstance(tok, Where))
    for where in where_groups:
        analyze_column(found, where)
    return list(found)
def get_indexes(dbagent, sql, timestamp):
    """
    Get the indexed columns recorded in the ``wdr`` table for a query at a timestamp.

    :param dbagent: obj, interface for sqlite3; must provide ``fetch_all_result(str)``.
    :param sql: str, query text as stored in the ``query`` column.
    :param timestamp: value matched against the ``timestamp`` column.
    :return: list, for each index definition the first column name inside its first
        pair of parentheses. If the stored value cannot be parsed, the error is logged
        and the columns collected so far are returned.
    """
    def _sqlite_literal(value):
        # Single-quoted SQLite string literal with embedded quotes doubled, so query text
        # containing quotes can neither break nor inject into the lookup statement.
        return "'%s'" % str(value).replace("'", "''")

    indexes = []
    statement = "SELECT indexes from wdr where timestamp == {timestamp} and query == {query}".format(
        timestamp=_sqlite_literal(timestamp), query=_sqlite_literal(sql))
    rows = dbagent.fetch_all_result(statement)
    if not rows:  # also covers an agent that returns None
        return indexes
    value_in_bracket = re.compile(r'[(](.*?)[)]', re.S)
    try:
        # NOTE(review): expected shape is the repr of {table: {index_name: index_def, ...}}
        # -- only the first table's definitions are used; confirm against the writer.
        indexes_dict = ast.literal_eval(rows[0][0])
        indexes_def_list = list(list(indexes_dict.values())[0].values())
        for sql_index in indexes_def_list:
            indexes.append(value_in_bracket.findall(sql_index)[0].split(',')[0])
    except Exception as e:  # best-effort: malformed stored data must not abort the caller
        logging.exception(e)
    return indexes
def wdr_sql_processing(sql):
    """
    Normalize a WDR query for comparison.

    Steps: standardize formatting, drop every ';', collapse ``VALUES (...)`` to ``VALUES``
    and replace each ``$N`` placeholder with ``?``.

    :param sql: str, raw query text.
    :return: str, normalized query.
    """
    standard_sql = standardize_sql(sql)
    standard_sql = standard_sql.replace(';', '')
    # Greedy: removes from the first "(" after "VALUES " to the last ")" ('.' does not cross newlines).
    standard_sql = re.sub(r'VALUES (\(.*\))', r'VALUES', standard_sql)
    # `\d+` (not the lazy `\d+?`) so a multi-digit placeholder like $12 becomes one "?", not "?2".
    standard_sql = re.sub(r'\$\d+', '?', standard_sql)
    return standard_sql
def check_select(parsed_sql):
    """
    Return True if `parsed_sql` is a token group with a SELECT DML keyword among its
    direct children; False for non-group tokens or groups without one.
    """
    if parsed_sql.is_group:
        return any(tok.ttype is DML and tok.value.upper() == 'SELECT'
                   for tok in parsed_sql.tokens)
    return False
def get_table_token_list(parsed_sql, token_list):
    """
    Append to `token_list` the tokens that follow FROM in `parsed_sql`.

    Collection starts after the first FROM keyword and stops at the next token whose
    ttype is exactly ``Keyword`` (sub-types such as DML do not stop it). A group that
    contains its own SELECT is not appended; it is searched recursively instead, and
    collection in this group then continues. All other tokens (whitespace and
    punctuation included) are appended.

    :param parsed_sql: sqlparse token group.
    :param token_list: list, extended in place.
    """
    after_from = False
    for tok in parsed_sql.tokens:
        if not after_from:
            after_from = tok.ttype is Keyword and tok.value.upper() == 'FROM'
            continue
        if check_select(tok):
            get_table_token_list(tok, token_list)
            continue
        if tok.ttype is Keyword:
            return
        token_list.append(tok)
def standardize_sql(sql):
    """
    Reformat SQL into a canonical form: upper-case keywords, lower-case identifiers,
    comments removed, spaces around operators and redundant whitespace stripped.
    """
    options = dict(
        keyword_case='upper',
        identifier_case='lower',
        strip_comments=True,
        use_space_around_operators=True,
        strip_whitespace=True,
    )
    return sqlparse.format(sql, **options)
def is_num(input_str):
    """
    Return True if `input_str` is a str holding an unsigned integer or decimal such as
    '5', '12' or '1.5'; False for any other value (non-str, signs, exponents, '5.', '.5').
    """
    # The fractional part is optional as a whole, so single-digit integers like '5'
    # match too (the former pattern r'^\d+\.?\d+$' required at least two digits).
    return isinstance(input_str, str) and re.match(r'^\d+(?:\.\d+)?$', input_str) is not None
def str2int(input_str):
    """
    Return the integer part of an unsigned integer/decimal string: '12' -> 12, '10.25' -> 10.

    :raises AttributeError: if `input_str` is a str that is not such a number
        (``re.match`` returns None); presumably callers validate the input first.
    """
    # The optional fraction is matched as one group; the former r'^(\d+)\.?\d+$' made the
    # trailing `\d+` steal the last digit of plain integers ('12' -> 1).
    match = re.match(r'^(\d+)(?:\.\d+)?$', input_str)
    return int(match.group(1))
def to_ts(obj):
    """
    Convert `obj` to an integer Unix timestamp in seconds.

    - str: text before the first '.' (e.g. fractional seconds are dropped) is parsed as
      '%Y-%m-%d %H:%M:%S' and interpreted as naive local time by ``datetime.timestamp``;
      on failure the error is logged and 0 is returned.
    - datetime: ``int(obj.timestamp())``.
    - int: returned unchanged.
    - float: truncated to int (NaN/infinity give 0).
    - anything else: 0.
    """
    if isinstance(obj, str):
        obj = obj.split('.')[0]
        try:
            return int(datetime.strptime(obj, '%Y-%m-%d %H:%M:%S').timestamp())
        except (ValueError, OverflowError, OSError) as e:  # bad format / out-of-range date
            logging.exception(e)
            return 0
    if isinstance(obj, datetime):
        return int(obj.timestamp())
    if isinstance(obj, int):
        return obj
    if isinstance(obj, float):
        try:
            return int(obj)
        except (ValueError, OverflowError):  # NaN or infinity
            return 0
    return 0
def fill_value(query_content):
    """
    Substitute logged parameter values into a parameterized statement.

    Only input of the form "<template>; parameters: $1 = v1, $2 = v2, ..." is handled:
    it must contain exactly one ';' and the text 'parameters: $1' (case-insensitive).
    Anything else is returned unchanged. Values are inserted verbatim (no quotes are
    added). Higher-numbered placeholders are replaced first so that $1 cannot clobber
    the prefix of $10.

    case: select id from table where info = $1 and id_d < $2; PARAMETERS: $1 = '1', $2 = 4
    result: select id from table where info = '1' and id_d < 4
    """
    pieces = query_content.split(';')
    if len(pieces) != 2 or 'parameters: $1' not in query_content.lower():
        return query_content
    template, parameter = pieces
    raw_params = re.search(r'parameters: (.*)', parameter, re.IGNORECASE).group(1)
    pairs = [chunk.split('=', 1) for chunk in raw_params.split(', $')]
    pairs.sort(key=lambda pair: int(pair[0].strip(' $')), reverse=True)
    for pair in pairs:
        # After splitting on ', $' only the first entry still carries its '$'.
        name = pair[0].strip()
        placeholder = name if re.match(r'\$', pair[0]) else '$' + name
        template = template.replace(placeholder, pair[1].strip())
    return template
def exists_regular_match(query):
    """
    Find LIKE patterns with a leading and/or trailing wildcard: like '%xxxx', 'xxxx%', '%xxxx%'.

    Matching of the LIKE keyword is case-insensitive, so standardized (upper-case) SQL
    is recognized too. The pattern text must contain no whitespace.

    :return: list of the matched ``like '...'`` substrings (empty if none).
    """
    return re.findall(r"like\s+'%\S+'|like\s+'\S+%'", query, flags=re.IGNORECASE)
def exist_track_parameter(query):
    """Return True if the query text carries a '; parameters: $1 = ' suffix (case-insensitive)."""
    marker = '; parameters: $1 = '
    return marker in query.lower()
def is_query_normalized(query):
    """
    Return True if the first statement contains at least one placeholder token and every
    placeholder starts with ``$<digits>`` or ``?``; otherwise False.
    """
    saw_placeholder = False
    for tok in sqlparse.parse(query)[0].flatten():
        if tok.ttype is not Name.Placeholder:
            continue
        if not re.match(r"\$\d+|\?", tok.value):
            return False
        saw_placeholder = True
    return saw_placeholder
def remove_parameter_part(query):
    """
    Strip the trailing parameter listing that is logged when GUC 'track_parameter' is ON.

    Everything from ';' + optional whitespace + 'parameters: $' (case-insensitive) to the
    end of that line is replaced by a single ';'.

    case: SELECT no_o_id FROM bmsql_new_order WHERE no_w_id = $1 AND no_d_id = $2
          ORDER BY no_o_id ASC; parameters: $1 = '10', $2 = '2'
    result: SELECT no_o_id FROM bmsql_new_order WHERE no_w_id = $1 AND no_d_id = $2
          ORDER BY no_o_id ASC;
    """
    parameter_tail = re.compile(r";\s*parameters: \$.+", re.IGNORECASE)
    return parameter_tail.sub(";", query)
def exists_function(query):
    """
    Collect function calls used directly in WHERE-clause comparisons.

    Every group in the parse tree is walked. For a Comparison whose parent is a Where
    group, its Function children are collected, and its Parenthesis children (e.g.
    subqueries) are searched further. Other groups are searched recursively.

    case1: select * from table where abs(l_quantity) <= 8;
    intended result: ['abs(l_quantity)']
    case2: select col from table2 where id >
           (select max(id2) from table2 where substring(info from 1 for 2) = 'xxx')
    intended result includes 'substring(info from 1 for 2)'

    :return: list of function-call texts.
    """
    found = []

    def walk(node):
        for child in node:
            if not child.is_group:
                continue
            if not (isinstance(child, Comparison) and isinstance(child.parent, Where)):
                walk(child)
                continue
            for part in child.tokens:
                if isinstance(part, Function):
                    found.append(part.value)
                elif isinstance(part, Parenthesis):
                    walk(part)

    walk(sqlparse.parse(query)[0])
    return found
def regular_match(pattern, string, **kwargs):
    """Return True if `pattern` is found anywhere in `string` (kwargs go to ``re.search``)."""
    return re.search(pattern, string, **kwargs) is not None
def remove_bracket(string):
    """
    Delete every parenthesized span, shortest match first (nested parentheses are not balanced).

    case: '"substring"(c1, 2, 4)'
    result: '"substring"'
    """
    paren_span = re.compile(r"\(.*?\)")
    return paren_span.sub('', string)
def exists_bool_clause(query):
    """
    Collect the value lists of parenthesized groups in a WHERE clause that uses NOT IN.

    A Parenthesis whose parent is a Where group is inspected when that WHERE clause's
    keywords (compared case-insensitively) contain 'not' immediately followed by 'in',
    or a single 'not in' keyword. Note: the check looks at all keywords of the WHERE
    clause, not just those adjacent to this parenthesis.

    case: select * from table where col not in (xx, xx, ...);
    result (intended): [['xx', ' xx', ...]] -- one list per value list, split on ','

    :return: list of lists of raw value strings.
    """
    flags = []

    def get_in_clause(parsed):
        for item in parsed:
            if not item.is_group:
                continue
            if isinstance(item, Parenthesis) and isinstance(item.parent, Where):
                # Lower-case so standardized (upper-case) keywords are recognized as well.
                keywords = [subitem.value.lower() for subitem in item.parent.tokens
                            if subitem.ttype == sqlparse.tokens.Token.Keyword]
                # `i > 0` guard: for the first keyword, keywords[i - 1] would wrap to the last one.
                has_not_in = any(i > 0 and keywords[i - 1] == 'not'
                                 for i, word in enumerate(keywords) if word == 'in') \
                    or 'not in' in keywords
                if has_not_in:
                    for sub_item in item.tokens:
                        if isinstance(sub_item, IdentifierList):
                            flags.append(sub_item.value.split(','))
                        elif sub_item.is_group:
                            get_in_clause(sub_item)
            else:
                get_in_clause(item)

    parsed_tree = sqlparse.parse(query)[0]
    get_in_clause(parsed_tree)
    return flags
def exists_subquery(query):
    """
    Find nested SELECT statements.

    Every SELECT DML keyword that is not a direct child of the top-level statement is
    reported. Its parent group is standardized and surrounding parentheses are stripped.

    case: select id from (select id from table2);
    result: [(<standardized 'select id from table2'>, depth)], where depth is the number
            of token groups enclosing that SELECT keyword (not strictly the subquery level).

    :return: list of (subquery_text, depth) tuples.
    """
    found = []

    def visit(node, depth):
        for tok in node:
            if tok.is_group:
                visit(tok, depth + 1)
                continue
            is_select = tok.ttype == DML and tok.value.upper() == "SELECT"
            if is_select and depth != 0:
                found.append((standardize_sql(tok.parent.value).strip("()"), depth))

    visit(sqlparse.parse(query)[0], 0)
    return found
def get_placeholders(query):
    """Return the set of distinct placeholder token texts (e.g. '$1', '?') in the first statement."""
    tokens = sqlparse.parse(query)[0].flatten()
    return {tok.value for tok in tokens if tok.ttype is Name.Placeholder}
def get_generate_prepare_sqls_function():
    """
    Build a generator of PREPARE / EXPLAIN EXECUTE / DEALLOCATE statement triples.

    The returned function names each prepared statement 'prepare_<n>', with n taken from a
    counter private to this closure (starting at 0). It passes one NULL argument per
    distinct placeholder in the statement.
    """
    ids = count(start=0, step=1)

    def get_prepare_sqls(statement):
        statement = statement.strip().strip(';')
        name = f'prepare_{next(ids)}'
        n_params = len(get_placeholders(statement))
        args = '(%s)' % ','.join(['NULL'] * n_params) if n_params else ''
        return [
            f'prepare {name} as {statement}',
            f'explain execute {name}{args}',
            f'deallocate prepare {name}',
        ]

    return get_prepare_sqls
def replace_question_mark_with_value(query):
    """
    PBE does not support the following situations, we can solve it by replacing the '?' with a fixed value.
    1. col >= date ?
    2. interval ? year
    3. fetch first ? row
    4. count(?)
    5. decode(?, xx, xx)
    6. extract(? from o_year) as year
    7. concat(?, col1, col2, ?)

    All substitutions are case-insensitive and applied in the order listed.
    """
    # NOTE(review): `[\s+|\s*,]` is a character class (whitespace, '+', '|', '*', ','), so
    # these constructs only match when preceded by one such character -- never at position 0.
    # The concat pattern is greedy: it replaces every '?' up to the last ')' on the line,
    # not only those inside the concat call.
    substitutions = (
        (r"([\s+|\s*,]date\s+)\?", r"\1'1999-01-01'"),
        (r"(\s+interval\s+)\?", r"\1'1'"),
        (r"(\s+fetch first\s+)\?", r"\g<1>1"),  # \g<1> so the literal 1 is not read as group 11
        (r"([\s+|\s*,]count\(\s*)\?(\s*\)\s+)", r"\g<1>1\g<2>"),
        (r"([\s+|\s*,]decode\(\s*)\?(\s*,)", r"\1'1'\2"),
        (r"([\s+|\s*,]extract\(\s*)\?(\s+from)", r"\1'day'\2"),
        # Previously the only case-sensitive rule; CONCAT(...) is now handled like the others.
        (r"([\s+|\s*,]concat)(\(.+\))", lambda m: m.group(1) + m.group(2).replace('?', '\'1\'')),
    )
    for pattern, replacement in substitutions:
        query = re.sub(pattern, replacement, query, flags=re.IGNORECASE)
    return query
def replace_question_mark_with_dollar(query):
    """
    Number every '?' placeholder as $k, continuing after the highest existing $N.

    input: UPDATE bmsql_customer SET c_balance = c_balance + $1, c_delivery_cnt = c_delivery_cnt + ?
           WHERE c_w_id = $2 AND c_d_id = $3 AND c_id = $4 and c_info = ?;
    output: UPDATE bmsql_customer SET c_balance = c_balance + $1, c_delivery_cnt = c_delivery_cnt + $5
           WHERE c_w_id = $2 AND c_d_id = $3 AND c_id = $4 and c_info = $6;
    note: if track_stmt_parameter is off, all '?' in SQL need to be replaced.
          Every '?' is replaced, including any inside string literals.
    """
    if '?' not in query:
        return query
    existing = [int(token.strip('$')) for token in re.findall(r'(\$\d+)', query)]
    first_new = (max(existing) if existing else 0) + 1
    segments = query.split('?')
    rebuilt = [segments[0]]
    for offset, segment in enumerate(segments[1:]):
        rebuilt.append('$%s' % (first_new + offset))
        rebuilt.append(segment)
    return ''.join(rebuilt)
def exists_count_operation(query):
    """
    Return True if `query` contains ``count(*)`` or ``count(<integer literal>)`` such as
    ``count(1)``, ``count(12)`` or ``count(-1)`` (case-insensitive).

    The call must be preceded by whitespace or one of the characters ``+ | * ,``.
    """
    # `(?:-?\d+|\*)` is an alternation; the former `[\d+|\*]` was a one-character class,
    # so multi-digit arguments such as count(12) were missed.
    return re.search(r"[\s+|\s*,]count\((?:-?\d+|\*)\)", query, flags=re.IGNORECASE) is not None