import abc
import operator as op_module
from collections import defaultdict
from copy import deepcopy
from itertools import groupby
from .utils import get_table_names, get_columns
# Parser keyword -> callable evaluating that comparison / arithmetic operator.
OPERATOR = {
    'gt': op_module.gt,
    'gte': op_module.ge,
    'lt': op_module.lt,
    'lte': op_module.le,
    'sub': op_module.sub,
    'div': op_module.truediv,
    'add': op_module.add,
    'mul': op_module.mul,
}

# Arithmetic keyword -> keyword of its inverse operation.
OPERATOR_PAIR = {
    'div': 'mul',
    'mul': 'div',
    'sub': 'add',
    'add': 'sub',
}

# Comparison keyword -> keyword to use once both sides are swapped
# (or once both sides are multiplied by a negative number).
COMPARE_PAIR = {
    'gt': 'lt',
    'gte': 'lte',
    'lt': 'gt',
    'lte': 'gte',
    'eq': 'eq',
    'neq': 'neq',
}

# Arithmetic keywords whose operands can be exchanged freely.
COMMUTATIVE_OP = {'add', 'mul'}

# Arithmetic keywords for which a negative constant reverses an inequality.
SIGN_CONVERSION_OP = {'mul', 'div'}

# Aggregate function keywords recognised by the rules in this module.
AGG_FUNCTIONS = {'sum', 'count', 'min', 'max', 'avg'}
class Rule:
    """Common entry point shared by every rewrite rule.

    Subclasses override ``_check_and_format`` to inspect (and rewrite in
    place) a parsed SQL structure and to return how many times they fired.
    """

    def __init__(self):
        # Populated on every ``check_and_format`` call.
        self.tableinfo = None

    @abc.abstractmethod
    def _check_and_format(self, parsed_sql, count):
        """Rewrite ``parsed_sql`` in place and return the updated hit count."""

    def check_and_format(self, parsed_sql, tableinfo):
        """Run the rule on ``parsed_sql``.

        :param parsed_sql: parsed SQL structure, modified in place.
        :param tableinfo: table metadata stored on ``self.tableinfo`` for the
            duration of the check.
        :return: ``(rule class name, parsed_sql)`` when the rule reported a
            truthy count, otherwise ``(None, parsed_sql)``.
        """
        self.tableinfo = tableinfo
        hits = self._check_and_format(parsed_sql, 0)
        rule_name = self.__class__.__name__ if hits else None
        return rule_name, parsed_sql
class Delete2Truncate(Rule):
    """It is recommended that the DELETE
    without the WHERE condition be changed to TRUNCATE."""

    def _check_and_format(self, parsed_sql, count):
        """Rename a lone ``delete`` key to ``truncate``.

        Only fires when ``delete`` is the single key of ``parsed_sql``.
        Returns 1 when the statement was rewritten, otherwise 0.
        """
        if len(parsed_sql) != 1 or 'delete' not in parsed_sql:
            return 0
        parsed_sql['truncate'] = parsed_sql.pop('delete')
        return 1
class In2Exists(Rule):
    """Use (not) exists instead of (not) in if the associated field does not have a NULL value."""

    @staticmethod
    def _to_exists(condition, table1, notnull_columns):
        """Build the (NOT) EXISTS replacement for one ``in``/``nin`` condition.

        :param condition: one entry of the WHERE condition list.
        :param table1: name of the outer table (schema prefix removed).
        :param notnull_columns: container of NOT NULL column names of ``table1``.
        :return: the replacement condition, or ``None`` when the condition
            cannot be rewritten safely.
        """
        if not isinstance(condition, dict):
            return None
        if 'in' in condition:
            key = 'in'
        elif 'nin' in condition:
            key = 'nin'
        else:
            return None
        operands = condition[key]
        if not isinstance(operands, (list, tuple)) or len(operands) != 2:
            return None
        left, right = operands
        if not isinstance(left, str):
            return None
        column1 = left.split('.')[-1]
        if column1 not in notnull_columns:
            return None
        # The right side must be a plain sub-select. A value list or a literal
        # dict is skipped here instead of failing on right['from'] below.
        if not isinstance(right, dict) or 'select' not in right or 'from' not in right:
            return None
        # Clauses other than WHERE (group by, limit, ...) cannot be carried
        # into the EXISTS form without changing the result.
        if right.keys() - {'select', 'from', 'where'}:
            return None
        if not isinstance(right['from'], str):
            return None
        target = right['select']
        if not (isinstance(target, dict) and isinstance(target.get('value'), str)):
            return None
        table2 = right['from'].split('.')[-1]
        if table2 == table1:
            # "t.a = t.b" inside the sub-select would no longer refer to the outer table.
            return None
        column2 = target['value'].split('.')[-1]
        correlation = {'eq': [f"{table1}.{column1}", f"{table2}.{column2}"]}
        sub_where = right.get('where')
        if sub_where is None:
            exists_where = correlation
        elif isinstance(sub_where, dict) and 'and' in sub_where:
            exists_where = {'and': [correlation] + list(sub_where['and'])}
        else:
            exists_where = {'and': [correlation, sub_where]}
        new_condition = {'exists': {'select': '*', 'from': table2, 'where': exists_where}}
        if key == 'nin':
            # NOTE(review): NOT IN also depends on the sub-select column being
            # NOT NULL; only the outer column is checked here.
            new_condition = {'not': new_condition}
        return new_condition

    def _check_and_format(self, parsed_sql, in_count):
        """Rewrite ``col (NOT) IN (SELECT ...)`` conditions of a single-table query.

        :param parsed_sql: parsed SQL (dict) or list of parsed SQL, modified in place.
        :param in_count: number of rewrites done so far.
        :return: updated number of rewrites.
        """
        if isinstance(parsed_sql, list):
            for sub_parsed_sql in parsed_sql:
                in_count = self._check_and_format(sub_parsed_sql, in_count)
            return in_count
        if not isinstance(parsed_sql, dict):
            return in_count
        if 'union' in parsed_sql or 'union_all' in parsed_sql:
            return self._check_and_format(list(parsed_sql.values())[0], in_count)
        if 'where' not in parsed_sql or 'from' not in parsed_sql or 'select' not in parsed_sql:
            return in_count
        if not isinstance(parsed_sql['from'], str) or not isinstance(parsed_sql['where'], dict):
            return in_count
        table1 = parsed_sql['from'].split('.')[-1]
        if 'or' not in parsed_sql['where'] and 'and' not in parsed_sql['where']:
            parsed_sql['where'] = {'and': [parsed_sql['where']]}
        where_key = 'and' if parsed_sql['where'].get('and') else 'or'
        conditions = parsed_sql['where'].get(where_key)
        if not isinstance(conditions, list):
            return in_count
        # A table without metadata has no known NOT NULL columns.
        notnull_columns = self.tableinfo.table_notnull_columns.get(table1) or ()
        for index, _sub in enumerate(conditions):
            new_condition = self._to_exists(_sub, table1, notnull_columns)
            if new_condition is not None:
                conditions[index] = new_condition
                in_count += 1
        return in_count
class Group2Hash(Rule):
    """Enables the executor to select hashagg instead of groupagg.

    An ``agg(DISTINCT col) ... GROUP BY keys`` query is turned into a plain
    ``agg(col)`` over a sub-select grouped by ``keys + col``.
    """

    def _check_and_format(self, parsed_sql, groupagg_count):
        """Apply the rewrite to ``parsed_sql`` (dict or list) in place.

        The query is left untouched unless every select target is either a
        plain column or a single DISTINCT aggregate over a plain column, and
        all DISTINCT aggregates use the same column.

        :return: updated number of rewrites.
        """
        if isinstance(parsed_sql, list):
            for sub_parsed_sql in parsed_sql:
                groupagg_count = self._check_and_format(sub_parsed_sql, groupagg_count)
            return groupagg_count
        if not isinstance(parsed_sql, dict):
            return groupagg_count
        if 'union' in parsed_sql or 'union_all' in parsed_sql:
            return self._check_and_format(list(parsed_sql.values())[0], groupagg_count)
        if 'from' in parsed_sql:
            if not isinstance(parsed_sql['from'], list):
                parsed_sql['from'] = [parsed_sql['from']]
            groupagg_count = self._check_and_format(parsed_sql['from'], groupagg_count)
        if not parsed_sql.get('groupby') or not parsed_sql.get('select'):
            return groupagg_count
        if 'from' not in parsed_sql:
            return groupagg_count
        # HAVING filters the outer groups; pushing it into the sub-select
        # (grouped by more columns) would filter different groups.
        if 'having' in parsed_sql:
            return groupagg_count
        if not isinstance(parsed_sql.get('select'), list):
            parsed_sql['select'] = [parsed_sql['select']]
        if not isinstance(parsed_sql.get('groupby'), list):
            parsed_sql['groupby'] = [parsed_sql['groupby']]
        distinct_columns = []
        targetlist = []
        sub_targetlist = []
        for target in parsed_sql['select']:
            if not isinstance(target, dict):
                # e.g. '*': cannot be carried through the rewrite.
                return groupagg_count
            value = target.get('value')
            if isinstance(value, str):
                targetlist.append(target)
                sub_targetlist.append(target)
                continue
            # Only DISTINCT aggregates survive de-duplication in the
            # sub-select; rewriting e.g. a plain sum() would change its result.
            if not isinstance(value, dict) or not value.get('distinct'):
                return groupagg_count
            agg_functions = list(value.keys() - {'distinct'})
            if len(agg_functions) != 1 or agg_functions[0] not in AGG_FUNCTIONS:
                return groupagg_count
            agg_function = agg_functions[0]
            if not isinstance(value[agg_function], str):
                return groupagg_count
            distinct_column = {'value': value[agg_function]}
            if distinct_column not in distinct_columns:
                distinct_columns.append(distinct_column)
            new_target = {'value': {agg_function: distinct_column}}
            if target.get('name'):
                new_target['name'] = target.get('name')
            targetlist.append(new_target)
            if distinct_column not in sub_targetlist:
                sub_targetlist.append(distinct_column)
        # With several different DISTINCT columns the combined sub-GROUP BY
        # no longer de-duplicates each column on its own.
        if len(distinct_columns) != 1:
            return groupagg_count
        parsed_sql['select'] = targetlist
        sub_groupby_list = parsed_sql['groupby'] + distinct_columns
        sub_parsed_sql = {'select': sub_targetlist, 'from': parsed_sql['from'], 'groupby': sub_groupby_list}
        # Everything except the outer projection/grouping/ordering (e.g. WHERE)
        # belongs to the inner query.
        for other_key in parsed_sql.keys() - {'select', 'from', 'limit', 'groupby', 'orderby'}:
            sub_parsed_sql[other_key] = parsed_sql.pop(other_key)
        parsed_sql['from'] = sub_parsed_sql
        groupagg_count += 1
        return groupagg_count
class Star2Columns(Rule):
    """SELECT * type is not advised."""

    def _check_and_format(self, parsed_sql, star_count):
        """Replace ``*`` in ``select`` / ``select_distinct`` by explicit columns.

        The column list comes from ``get_columns`` fed with
        ``self.tableinfo.table_columns``; when it returns nothing the query is
        left unchanged. ``from`` entries are visited recursively afterwards.

        :return: updated number of rewrites.
        """
        if isinstance(parsed_sql, dict):
            if 'union' in parsed_sql or 'union_all' in parsed_sql:
                first_value = list(parsed_sql.values())[0]
                return self._check_and_format(first_value, star_count)
            selects_star = parsed_sql.get('select') == '*' or parsed_sql.get('select_distinct') == '*'
            if selects_star:
                expanded = get_columns(self.tableinfo.table_columns, parsed_sql)
                if expanded:
                    target_key = 'select' if parsed_sql.get('select') else 'select_distinct'
                    parsed_sql[target_key] = expanded
                    star_count += 1
            if 'from' in parsed_sql:
                from_clause = parsed_sql['from']
                if not isinstance(from_clause, list):
                    from_clause = [from_clause]
                    parsed_sql['from'] = from_clause
                star_count = self._check_and_format(from_clause, star_count)
        elif isinstance(parsed_sql, list):
            for entry in parsed_sql:
                star_count = self._check_and_format(entry, star_count)
        return star_count
class Having2Where(Rule):
    """ Having clause is not advised. """

    @staticmethod
    def merge_where(where_clause, having_clause):
        """Combine WHERE and HAVING conditions.

        :return: ``having_clause`` unchanged when ``where_clause`` is None,
            otherwise the flat list of conditions of both clauses
            (a top-level ``and`` on either side is unpacked).
        """
        if where_clause is None:
            return having_clause
        if isinstance(where_clause, dict) and 'and' in where_clause:
            where_conditions = where_clause['and']
        else:
            where_conditions = [where_clause]
        if isinstance(having_clause, dict) and 'and' in having_clause:
            having_conditions = having_clause['and']
        else:
            having_conditions = [having_clause]
        return where_conditions + having_conditions

    @staticmethod
    def _contains_aggregate(clause):
        """Return True when an aggregate function keyword occurs anywhere in ``clause``."""
        if isinstance(clause, dict):
            return any(key in AGG_FUNCTIONS or Having2Where._contains_aggregate(value)
                       for key, value in clause.items())
        if isinstance(clause, list):
            return any(Having2Where._contains_aggregate(item) for item in clause)
        return False

    def _check_and_format(self, parsed_sql, having_count):
        """Move aggregate-free HAVING conditions into WHERE, in place.

        Conditions that use an aggregate function stay in HAVING because they
        cannot be evaluated before grouping.

        :return: updated number of rewrites.
        """
        if isinstance(parsed_sql, list):
            for sub_parsed_sql in parsed_sql:
                having_count = self._check_and_format(sub_parsed_sql, having_count)
        elif isinstance(parsed_sql, dict):
            if 'union' in parsed_sql or 'union_all' in parsed_sql:
                having_count = self._check_and_format(list(parsed_sql.values())[0], having_count)
            if 'having' in parsed_sql:
                having = parsed_sql['having']
                if isinstance(having, dict) and 'and' in having:
                    conditions = having['and']
                else:
                    conditions = [having]
                movable = [cond for cond in conditions if not self._contains_aggregate(cond)]
                remaining = [cond for cond in conditions if self._contains_aggregate(cond)]
                if movable:
                    moved = movable[0] if len(movable) == 1 else {'and': movable}
                    where = parsed_sql.get('where')
                    merged = self.merge_where(where, moved)
                    # A fresh dict is built: assigning into the existing WHERE
                    # would nest it inside itself.
                    parsed_sql['where'] = merged if where is None else {'and': merged}
                    if not remaining:
                        parsed_sql.pop('having')
                    elif len(remaining) == 1:
                        parsed_sql['having'] = remaining[0]
                    else:
                        parsed_sql['having'] = {'and': remaining}
                    having_count += 1
        return having_count
class AlwaysTrue(Rule):
    """Remove useless where clause."""

    @staticmethod
    def _constant_truth(expr):
        """Return True when dict ``expr`` is a constant expression that is always true.

        Recognised forms: ``eq`` / ``neq`` between two numbers or two
        ``{'literal': ...}`` dicts, and any ``OPERATOR`` keyword applied to
        two numbers whose result is non-zero.
        """
        numbers = (int, float, bool)
        for key, expect_equal in (('eq', True), ('neq', False)):
            if key not in expr:
                continue
            operands = expr[key]
            if not isinstance(operands, (list, tuple)) or len(operands) != 2:
                return False
            left, right = operands
            if isinstance(left, numbers) and isinstance(right, numbers):
                return (left == right) is expect_equal
            # Both sides must be literals: a missing 'literal' key on one side
            # means that side is an expression, not a constant.
            if isinstance(left, dict) and isinstance(right, dict) \
                    and 'literal' in left and 'literal' in right:
                return (left['literal'] == right['literal']) is expect_equal
            return False
        operator = next((key for key in expr if key in OPERATOR), None)
        if operator is None:
            return False
        operands = expr[operator]
        if not isinstance(operands, (list, tuple)) or len(operands) != 2:
            return False
        if not all(isinstance(operand, (int, float)) for operand in operands):
            return False
        try:
            return OPERATOR[operator](operands[0], operands[1]) != 0
        except ZeroDivisionError:
            # Leave a constant division by zero for the database to report.
            return False

    @staticmethod
    def rm_true_expr(where_clause, index, true_count):
        """Remove ``where_clause[index]`` when it is always true.

        ``where_clause`` may be a dict (``index`` is a key) or a list
        (``index`` is a position). ``and`` / ``or`` groups are processed
        recursively: an ``and`` group is dropped once it is empty, an ``or``
        group is dropped as soon as one of its members was removed.

        :return: ``true_count`` plus the number of removed constant conditions.
        """
        expr = where_clause[index]
        if isinstance(expr, (int, float, bool)):
            if (expr != 0) and (expr is not False):
                where_clause.pop(index)
                true_count += 1
            return true_count
        if not isinstance(expr, dict):
            return true_count
        if AlwaysTrue._constant_truth(expr):
            where_clause.pop(index)
            true_count += 1
        elif 'and' in expr:
            members = expr['and']
            # Walk backwards so that removals do not shift pending positions.
            for position in range(len(members) - 1, -1, -1):
                true_count = AlwaysTrue.rm_true_expr(members, position, true_count)
            if not members:
                where_clause.pop(index)
        elif 'or' in expr:
            members = expr['or']
            length = len(members)
            for position in range(length - 1, -1, -1):
                true_count = AlwaysTrue.rm_true_expr(members, position, true_count)
            if len(members) < length:
                where_clause.pop(index)
        return true_count

    def _check_and_format(self, parsed_sql, true_count):
        """Strip always-true conditions from the WHERE clauses of ``parsed_sql``.

        :return: updated number of removed conditions.
        """
        if isinstance(parsed_sql, list):
            for sub_parsed_sql in parsed_sql:
                true_count = self._check_and_format(sub_parsed_sql, true_count)
        elif isinstance(parsed_sql, dict):
            if 'union' in parsed_sql or 'union_all' in parsed_sql:
                return self._check_and_format(list(parsed_sql.values())[0], true_count)
            if 'where' in parsed_sql:
                true_count += AlwaysTrue.rm_true_expr(parsed_sql, 'where', 0)
            if 'from' in parsed_sql:
                if not isinstance(parsed_sql['from'], list):
                    parsed_sql['from'] = [parsed_sql['from']]
                true_count = self._check_and_format(parsed_sql['from'], true_count)
        return true_count
class DistinctStar(Rule):
    """Distinct * is not meaningful for primary keys."""

    def _check_and_format(self, parsed_sql, distinctstar_count):
        """Turn ``select_distinct *`` into ``select *`` when rows are already unique.

        The rewrite is applied only when every table reported by
        ``get_table_names`` for the FROM clause has a primary key according to
        ``self.tableinfo.table_exists_primary``: a joined table without a
        primary key may still contribute duplicate rows.

        :return: updated number of rewrites.
        """
        if isinstance(parsed_sql, list):
            for sub_parsed_sql in parsed_sql:
                distinctstar_count = self._check_and_format(sub_parsed_sql, distinctstar_count)
        elif isinstance(parsed_sql, dict):
            if 'union' in parsed_sql or 'union_all' in parsed_sql:
                return self._check_and_format(list(parsed_sql.values())[0], distinctstar_count)
            if 'from' in parsed_sql:
                if not isinstance(parsed_sql['from'], list):
                    parsed_sql['from'] = [parsed_sql['from']]
                distinctstar_count = self._check_and_format(parsed_sql['from'], distinctstar_count)
            if parsed_sql.get('select_distinct') == '*' and 'from' in parsed_sql:
                table_names = list(get_table_names(parsed_sql['from']))
                has_primary = self.tableinfo.table_exists_primary
                if table_names and all(has_primary.get(table_name) for table_name in table_names):
                    parsed_sql['select'] = parsed_sql.pop('select_distinct')
                    distinctstar_count += 1
        return distinctstar_count
class UnionAll(Rule):
    """Change Union to Union All."""

    def _check_and_format(self, parsed_sql, union_count):
        """Rename every ``union`` key to ``union_all`` in place.

        FROM sub-queries and the members of each union are visited
        recursively. The key keeps its position inside the dict, because the
        other rules of this module read a union through
        ``list(parsed_sql.values())[0]``.

        :return: updated number of renamed unions.
        """
        if isinstance(parsed_sql, list):
            for sub_parsed_sql in parsed_sql:
                union_count = self._check_and_format(sub_parsed_sql, union_count)
        elif isinstance(parsed_sql, dict):
            if 'from' in parsed_sql:
                if not isinstance(parsed_sql['from'], list):
                    parsed_sql['from'] = [parsed_sql['from']]
                union_count = self._check_and_format(parsed_sql['from'], union_count)
            if 'union' in parsed_sql:
                union_count += 1
                renamed = {('union_all' if key == 'union' else key): value
                           for key, value in parsed_sql.items()}
                parsed_sql.clear()
                parsed_sql.update(renamed)
            if 'union_all' in parsed_sql:
                # Nested unions inside the members are converted as well.
                union_count = self._check_and_format(parsed_sql['union_all'], union_count)
        return union_count
class OrderbyConst(Rule):
    """Transform constant in ORDER BY or GROUP BY to column name.
    Example: "select id from test where id=1 order by 1.
    """

    @staticmethod
    def replace_const_by_column(parsed_sql, index, columns, checked=False):
        """Replace positional references in ``parsed_sql[index]`` by select columns.

        :param parsed_sql: container (dict or list) holding the clause.
        :param index: key or position of the clause inside ``parsed_sql``.
        :param columns: select-list names/expressions; position ``n`` maps to
            ``columns[n - 1]``.
        :param checked: value returned when nothing is replaced.
        :return: True when at least one position was replaced, else ``checked``.
            Positions outside ``1..len(columns)`` are left untouched.
        """
        entry = parsed_sql[index]
        if isinstance(entry, list):
            for secondary_index in range(len(entry)):
                if OrderbyConst.replace_const_by_column(entry, secondary_index, columns):
                    checked = True
            return checked
        if isinstance(entry, dict):
            position = entry.get('value')
            if isinstance(position, int) and not isinstance(position, bool) \
                    and 1 <= position <= len(columns):
                entry['value'] = columns[position - 1]
                return True
        return checked

    def _check_and_format(self, parsed_sql, orderby_count):
        """Replace ORDER BY / GROUP BY positions by the matching select column.

        Nothing is done when a select target is not a dict with a ``value``
        (e.g. ``*``), since positions cannot be resolved then.

        :return: updated number of rewritten clauses.
        """
        if isinstance(parsed_sql, list):
            for sub_parsed_sql in parsed_sql:
                orderby_count = self._check_and_format(sub_parsed_sql, orderby_count)
        elif isinstance(parsed_sql, dict):
            if 'union' in parsed_sql or 'union_all' in parsed_sql:
                return self._check_and_format(list(parsed_sql.values())[0], orderby_count)
            if 'from' in parsed_sql:
                if not isinstance(parsed_sql['from'], list):
                    parsed_sql['from'] = [parsed_sql['from']]
                orderby_count = self._check_and_format(parsed_sql['from'], orderby_count)
            if 'select' in parsed_sql:
                select_key = 'select'
            elif 'select_distinct' in parsed_sql:
                select_key = 'select_distinct'
            else:
                return orderby_count
            if not isinstance(parsed_sql[select_key], list):
                parsed_sql[select_key] = [parsed_sql[select_key]]
            columns = []
            for _column in parsed_sql[select_key]:
                if isinstance(_column, dict) and 'value' in _column:
                    # Prefer the alias, which is what ORDER BY may refer to.
                    columns.append(_column.get('name', _column['value']))
                else:
                    columns = []
                    break
            if not columns:
                return orderby_count
            if 'groupby' in parsed_sql:
                orderby_count += self.replace_const_by_column(parsed_sql, 'groupby', columns)
            if 'orderby' in parsed_sql:
                orderby_count += self.replace_const_by_column(parsed_sql, 'orderby', columns)
        return orderby_count
class Or2In(Rule):
    """Transform the OR query with different conditions in the same column to the IN query."""

    def _check_and_format(self, parsed_sql, or_count):
        """Visit FROM sub-queries and WHERE / AND / OR nodes of ``parsed_sql``.

        :return: updated number of rewritten OR groups.
        """
        if isinstance(parsed_sql, list):
            for sub_parsed_sql in parsed_sql:
                or_count = self._check_and_format(sub_parsed_sql, or_count)
        elif isinstance(parsed_sql, dict):
            if 'union' in parsed_sql or 'union_all' in parsed_sql:
                return Or2In()._check_and_format(list(parsed_sql.values())[0], or_count)
            if 'from' in parsed_sql:
                if not isinstance(parsed_sql['from'], list):
                    parsed_sql['from'] = [parsed_sql['from']]
                or_count = Or2In()._check_and_format(parsed_sql['from'], or_count)
            if 'where' in parsed_sql:
                or_count = Or2In()._check_and_format(parsed_sql['where'], or_count)
            elif 'or' in parsed_sql:
                or_count += Or2In.or2in(parsed_sql)
            elif 'and' in parsed_sql:
                or_count = Or2In()._check_and_format(parsed_sql['and'], or_count)
        return or_count

    @staticmethod
    def or2in(parsed_sql):
        """Merge ``col = v1 OR col = v2 ...`` members of ``parsed_sql['or']`` into IN.

        Only columns compared against at least two values are merged; a column
        that appears once keeps its equality. Merged IN conditions come first
        (ordered by column name), followed by the untouched members in their
        original order. When a single IN is all that is left, the ``or`` key is
        replaced by an ``in`` key.

        :return: True when an IN was produced or a nested AND group was
            rewritten, otherwise False.
        """
        def column_first(operand):
            return 0 if isinstance(operand, str) else 1

        values_by_column = defaultdict(list)
        classified = []  # (column or None, original member)
        or_count = 0
        for sub_parsed_sql in parsed_sql['or']:
            column = None
            if isinstance(sub_parsed_sql, dict) and 'eq' in sub_parsed_sql:
                operands = sub_parsed_sql['eq']
                if isinstance(operands, (list, tuple)) and len(operands) == 2:
                    # A sorted copy is used so the member is not altered when
                    # it ends up not being merged.
                    ordered = sorted(operands, key=column_first)
                    if isinstance(ordered[0], str):
                        column = ordered[0]
                        values_by_column[column].append(ordered[1])
            elif isinstance(sub_parsed_sql, dict) and 'and' in sub_parsed_sql:
                or_count += Or2In()._check_and_format(sub_parsed_sql['and'], 0)
            classified.append((column, sub_parsed_sql))
        mergeable = {column for column, values in values_by_column.items() if len(values) > 1}
        if not mergeable:
            return bool(or_count)
        in_list = [{'in': [column, values_by_column[column]]} for column in sorted(mergeable)]
        other_expr = [member for column, member in classified if column not in mergeable]
        if len(in_list) + len(other_expr) == 1:
            parsed_sql['in'] = in_list[0]['in']
            parsed_sql.pop('or')
        else:
            parsed_sql['or'] = in_list + other_expr
        return True
class OrderbyConstColumns(Rule):
    """Delete useless conditions in ORDER BY or GROUP BY.
    Example: "select id from test where id=1 order by id
    """

    @staticmethod
    def get_columns(whereclause, columns=None):
        """Collect the columns that WHERE pins to a constant.

        Only ``col = <number>`` / ``col = <literal>`` conditions are
        considered, either at the top level or as members of a top-level
        ``and`` (recursively).

        :param whereclause: parsed WHERE clause.
        :param columns: accumulator used by the recursion.
        :return: list of column names ([] when ``whereclause`` is not a dict).
        """
        if columns is None:
            columns = []
        if not isinstance(whereclause, dict):
            return []
        if 'and' in whereclause:
            for sub_clause in whereclause['and']:
                OrderbyConstColumns.get_columns(sub_clause, columns)
        elif 'eq' in whereclause:
            operands = whereclause['eq']
            if isinstance(operands, (list, tuple)) and len(operands) == 2:
                # Column first, constant second, without touching the clause.
                column, constant = sorted(operands, key=lambda x: 0 if isinstance(x, str) else 1)
                if isinstance(column, str) and (isinstance(constant, (int, float)) or (
                        isinstance(constant, dict) and 'literal' in constant)):
                    columns.append(column)
        return columns

    @staticmethod
    def filter_columns(parsed_sql, index, columns, keep_last=False):
        """Drop entries of ``parsed_sql[index]`` whose column is in ``columns``.

        :param parsed_sql: parsed query holding the clause.
        :param index: clause key, e.g. ``'orderby'`` or ``'groupby'``.
        :param columns: names of the constant columns.
        :param keep_last: when True the clause is never emptied; the first
            entry is kept if all of them are constant.
        :return: True when at least one entry was removed. The clause key is
            removed from ``parsed_sql`` once it has no entry left.
        """
        def is_constant(entry):
            return isinstance(entry, dict) and isinstance(entry.get('value'), str) \
                and entry['value'] in columns

        clause = parsed_sql[index]
        if isinstance(clause, dict):
            if is_constant(clause) and not keep_last:
                parsed_sql.pop(index)
                return True
            return False
        if isinstance(clause, list):
            removable = [position for position, entry in enumerate(clause) if is_constant(entry)]
            if keep_last and removable and len(removable) == len(clause):
                removable = removable[1:]
            for position in reversed(removable):
                clause.pop(position)
            if not clause:
                parsed_sql.pop(index)
            return bool(removable)
        return False

    def _check_and_format(self, parsed_sql, orderby_count):
        """Remove constant columns from ORDER BY and GROUP BY, in place.

        GROUP BY always keeps at least one key: dropping the whole clause
        would change how many rows the query returns.

        :return: updated number of rewritten clauses.
        """
        if isinstance(parsed_sql, list):
            for sub_parsed_sql in parsed_sql:
                orderby_count = self._check_and_format(sub_parsed_sql, orderby_count)
        elif isinstance(parsed_sql, dict):
            if 'union' in parsed_sql or 'union_all' in parsed_sql:
                return self._check_and_format(list(parsed_sql.values())[0], orderby_count)
            if 'from' in parsed_sql:
                if not isinstance(parsed_sql['from'], list):
                    parsed_sql['from'] = [parsed_sql['from']]
                orderby_count = self._check_and_format(parsed_sql['from'], orderby_count)
            if isinstance(parsed_sql.get('where'), dict):
                columns = OrderbyConstColumns.get_columns(parsed_sql.get('where'))
                if isinstance(parsed_sql.get('orderby'), (dict, list)):
                    orderby_count += self.filter_columns(parsed_sql, 'orderby', columns)
                if isinstance(parsed_sql.get('groupby'), (dict, list)):
                    # NOTE(review): a removed GROUP BY key that is still listed in
                    # SELECT may be rejected by the database - confirm.
                    orderby_count += self.filter_columns(parsed_sql, 'groupby', columns, keep_last=True)
        return orderby_count
class ImplicitConversion(Rule):
    """Expression transformation.
    SQL: select * from table1 where col1 + 1 < 2 -> select * from table1 where col1 < 1
    """

    @staticmethod
    def _exists_int_or_float(left_right):
        """Check if there exists a column with int or float.

        True only for exactly two operands, one a column name (str) and the
        other a number.
        """
        if not isinstance(left_right, (list, tuple)) or len(left_right) != 2:
            return False
        left, right = left_right
        if isinstance(left, str) and isinstance(right, (int, float)):
            return True
        if isinstance(right, str) and isinstance(left, (int, float)):
            return True
        return False

    @staticmethod
    def _is_column_arithmetic(expr):
        """Return True for ``{<arithmetic op>: [column, number]}`` (either order)."""
        if not isinstance(expr, dict) or not expr:
            return False
        operator = next(iter(expr))
        return operator in OPERATOR_PAIR and ImplicitConversion._exists_int_or_float(expr[operator])

    @staticmethod
    def _check_implicit(where_clause):
        """Return checked bool, compare_op, operator

        When the arithmetic expression is on the right-hand side, the
        comparison is mirrored in place so that the expression ends up on the
        left; ``compare_op`` is then the mirrored keyword.
        """
        for operator in COMPARE_PAIR:
            if operator not in where_clause:
                continue
            operands = where_clause[operator]
            if not isinstance(operands, (list, tuple)) or len(operands) != 2:
                continue
            left, right = operands
            if isinstance(left, (int, float)) and ImplicitConversion._is_column_arithmetic(right):
                compare_op = COMPARE_PAIR[operator]
                where_clause.pop(operator)
                where_clause[compare_op] = [right, left]
                return True, compare_op, next(iter(right))
            if isinstance(right, (int, float)) and ImplicitConversion._is_column_arithmetic(left):
                return True, operator, next(iter(left))
        return False, None, None

    @staticmethod
    def _format_implicit(where_clause, index, implicit_count):
        """Isolate the column of ``where_clause[index]`` on the left-hand side.

        ``and`` / ``or`` groups are processed member by member. A comparison
        whose column cancels out (``col * 0``, ``0 / col``, ``c / col = 0``)
        is replaced by the boolean it reduces to.

        :param where_clause: dict or list containing the condition.
        :param index: key or position of the condition.
        :param implicit_count: number of rewrites so far.
        :return: updated number of rewrites.
        :raises ZeroDivisionError: for ``col / 0``.
        """
        clause = where_clause[index]
        if not isinstance(clause, dict):
            return implicit_count
        checked, compare_op, operator = ImplicitConversion._check_implicit(clause)
        if not checked:
            if 'and' in clause or 'or' in clause:
                key = 'and' if 'and' in clause else 'or'
                for subindex in range(len(clause[key])):
                    implicit_count = ImplicitConversion._format_implicit(clause[key], subindex,
                                                                         implicit_count)
            return implicit_count

        operands = clause[compare_op][0][operator]
        column_first = isinstance(operands[0], str)
        column = operands[0] if column_first else operands[1]
        constant = operands[1] if column_first else operands[0]
        right_value = clause[compare_op][1]
        if operator == 'div' and column_first and constant == 0:
            raise ZeroDivisionError
        is_equality = compare_op in ('eq', 'neq')
        if operator == 'div' and not column_first and not is_equality:
            # "c / col < r" depends on the sign of the column, so it cannot be
            # solved for the column.
            return implicit_count
        implicit_count += 1

        if (operator == 'mul' and constant == 0) or (
                operator == 'div' and not column_first and (constant == 0 or right_value == 0)):
            # The column no longer influences the outcome: fold to a constant.
            compare = {'eq': op_module.eq, 'neq': op_module.ne}.get(compare_op) or OPERATOR[compare_op]
            where_clause[index] = bool(compare(constant, right_value))
            return implicit_count

        if operator in COMMUTATIVE_OP or column_first:
            # col <op> c  ->  col  <cmp>  r <inverse op> c
            new_right = OPERATOR[OPERATOR_PAIR[operator]](right_value, constant)
        else:
            # c - col / c / col  ->  col  <cmp>  c <op> r
            new_right = OPERATOR[operator](constant, right_value)
        # Multiplying/dividing by a negative number, or negating the column
        # (c - col), reverses an inequality.
        flips = (operator in SIGN_CONVERSION_OP and constant < 0) or (
            operator == 'sub' and not column_first)
        result_op = COMPARE_PAIR[compare_op] if flips else compare_op
        clause.pop(compare_op)
        clause[result_op] = [column, new_right]
        return implicit_count

    def _check_and_format(self, parsed_sql, implicit_count):
        """Simplify arithmetic comparisons in the WHERE clauses of ``parsed_sql``.

        :return: updated number of rewrites.
        """
        if isinstance(parsed_sql, list):
            for sub_parsed_sql in parsed_sql:
                implicit_count = self._check_and_format(sub_parsed_sql, implicit_count)
        elif isinstance(parsed_sql, dict):
            if 'union' in parsed_sql or 'union_all' in parsed_sql:
                return self._check_and_format(list(parsed_sql.values())[0], implicit_count)
            if 'from' in parsed_sql:
                if not isinstance(parsed_sql['from'], list):
                    parsed_sql['from'] = [parsed_sql['from']]
                implicit_count = self._check_and_format(parsed_sql['from'], implicit_count)
            if 'where' in parsed_sql:
                implicit_count += ImplicitConversion._format_implicit(parsed_sql, 'where', 0)
        return implicit_count
class SelfJoin(Rule):
    """Transform Self join to Union all.
    Original SQL: select a.c_id from bmsql_customer a, bmsql_customer b where a.c_id - b.c_id <= 20 and a.c_id > b.c_id.
    Rewritten SQL: SELECT * FROM
    (SELECT a.c_id FROM bmsql_customer AS a, bmsql_customer AS b
    WHERE TRUNC((a.c_id) / 20) = TRUNC(b.c_id / 20) AND a.c_id > b.c_id
    UNION ALL
    SELECT a.c_id FROM bmsql_customer AS a, bmsql_customer AS b
    WHERE TRUNC((a.c_id) / 20) = TRUNC(b.c_id / 20 + 1) AND a.c_id - b.c_id <= 20);'
    """

    @staticmethod
    def is_selfjoin(from_clause):
        """Return True when FROM lists exactly two entries with the same ``value``."""
        if isinstance(from_clause, list) and len(from_clause) == 2:
            first, second = from_clause
            return (isinstance(first, dict) and isinstance(second, dict)
                    and 'value' in first and 'value' in second
                    and first['value'] == second['value'])
        return False

    @staticmethod
    def format_gt_lt(cond):
        """Normalise an inequality to ``{op: [{'sub': [col_a, col_b]}, number]}``.

        Accepted inputs (either side order): ``a - b OP n``, ``a - n OP b``,
        ``a OP b + n`` and ``a OP b``. The two columns are put in sorted order
        and must carry different table prefixes.

        :param cond: one WHERE condition; it is not modified.
        :return: ``(True, normalised condition)`` or ``(False, None)``.
        """
        if not isinstance(cond, dict) or not cond:
            return False, None
        # Work on a copy: the steps below reorder operands in place and must
        # not alter the caller's WHERE clause.
        cond = deepcopy(cond)
        op_ = list(cond.keys())[0]
        if op_ not in {'gte', 'gt', 'lte', 'lt'}:
            return False, None
        if not isinstance(cond[op_], (list, tuple)) or len(cond[op_]) != 2:
            return False, None
        left, right = cond[op_]
        number = (int, float)

        def has(operand, key):
            return isinstance(operand, dict) and key in operand

        if (isinstance(right, number) and has(left, 'sub')) or (
                isinstance(left, number) and has(right, 'sub')):
            # a - b OP n
            if not isinstance(left, dict):
                left, right = right, left
                op_ = COMPARE_PAIR[op_]
        elif (isinstance(right, str) and has(left, 'sub')) or (
                isinstance(left, str) and has(right, 'sub')):
            # a - n OP b  ->  a - b OP n
            if not isinstance(left, dict):
                left, right = right, left
                op_ = COMPARE_PAIR[op_]
            if not isinstance(left['sub'], list) or len(left['sub']) != 2:
                return False, None
            left['sub'][1], right = right, left['sub'][1]
        elif (isinstance(right, str) and has(left, 'add')) or (
                isinstance(left, str) and has(right, 'add')):
            # a OP b + n  ->  a - b OP n
            if not isinstance(left, str):
                left, right = right, left
                op_ = COMPARE_PAIR[op_]
            if not isinstance(right['add'], list) or len(right['add']) != 2:
                return False, None
            right['add'].sort(key=lambda x: 0 if isinstance(x, str) else 1)
            left, right = {'sub': [left, right['add'][0]]}, right['add'][1]
        elif isinstance(left, str) and isinstance(right, str):
            # a OP b  ->  a - b OP 0
            left, right = {'sub': [left, right]}, 0
        else:
            return False, None

        columns = left['sub']
        if not isinstance(columns, list) or len(columns) != 2 \
                or not all(isinstance(column, str) for column in columns) \
                or not isinstance(right, number):
            return False, None
        op_, right = SelfJoin._sort_left_right(op_, left, right)
        first, second = left['sub']
        if '.' in first and '.' in second and first.split('.')[0] != second.split('.')[0]:
            return True, {op_: [left, right]}
        return False, None

    @staticmethod
    def _sort_left_right(op, left, right):
        """Sort the left_column and right_column for consistent order of all conditions."""
        left_column, right_column = left['sub']
        left['sub'].sort()
        if [left_column, right_column] != left['sub']:
            # b - a OP n  ==  a - b  mirrored(OP)  -n
            op = COMPARE_PAIR[op]
            right = -right
        return op, right

    @staticmethod
    def _generate_subselect_conditions(conds):
        """Build the WHERE additions of the two UNION ALL branches.

        :param conds: ``[lower bound condition, upper bound condition]`` as
            returned by ``_check_region_conds``.
        :return: ``(conds1, conds2)``; each is a TRUNC equality plus one of
            the two original bounds.
        """
        left_cond, right_cond = conds
        left_op = list(left_cond.keys())[0]
        right_op = list(right_cond.keys())[0]
        left_value = left_cond[left_op][1]
        right_value = right_cond[right_op][1]
        final_right_value = right_value - left_value
        left_column, right_column = left_cond[left_op][0]['sub']
        lower_is_zero = left_value == 0
        final_left_cond = {left_op: [left_column, right_column]} if lower_is_zero else left_cond

        def shifted_bucket():
            # TRUNC((left_column - left_value) / width); a fresh dict per use.
            shifted = [left_column] if lower_is_zero else [left_column, -left_value]
            return {'trunc': {'div': [{'add': shifted}, final_right_value]}}

        conds1 = [
            {'eq': [shifted_bucket(),
                    {'trunc': {'div': [right_column, final_right_value]}}]},
            final_left_cond]
        conds2 = [
            {'eq': [shifted_bucket(),
                    {'trunc': {'add': [{'div': [right_column, final_right_value]}, 1]}}]},
            right_cond]
        return conds1, conds2

    @staticmethod
    def _check_region_conds(new_conds, indexes):
        """Check if there exists a region column like 1 < table1.col1 - table2.col1 < 10.

        :param new_conds: conditions normalised by ``format_gt_lt``.
        :param indexes: position of each condition in the WHERE list.
        :return: ``(True, [lower cond, upper cond], [lower index, upper index])``
            for the first column pair bounded on both sides with
            upper > lower, else ``(False, None, None)``.
        """
        column_region = defaultdict(dict)
        for cond, index in zip(new_conds, indexes):
            op = list(cond.keys())[0]
            left_column, right_column = cond[op][0]['sub']
            bounds = column_region[(left_column, right_column)]
            # Only the first upper bound and the first lower bound are kept.
            if op in ('lt', 'lte'):
                if 'lt' not in bounds and 'lte' not in bounds:
                    bounds[op] = (cond, index)
            else:
                if 'gt' not in bounds and 'gte' not in bounds:
                    bounds[op] = (cond, index)
        for bounds in column_region.values():
            if len(bounds) != 2:
                continue
            region_conds = [None, None]
            related_indexes = [None, None]
            left_value = right_value = None
            for cond, index in bounds.values():
                op = list(cond.keys())[0]
                if op in ('lt', 'lte'):
                    right_value = cond[op][1]
                    region_conds[1] = cond
                    related_indexes[1] = index
                else:
                    left_value = cond[op][1]
                    region_conds[0] = cond
                    related_indexes[0] = index
            if right_value > left_value:
                return True, region_conds, related_indexes
        return False, None, None

    def selfjoin2unionall(self, parsed_sql):
        """Rewrite a range self join into ``SELECT * FROM (... UNION ALL ...)``.

        ``parsed_sql`` is only modified when a bounded column difference is
        found in its WHERE conditions (ORDER BY / GROUP BY entries may already
        have been converted to positions when a later entry cannot be mapped).

        :return: True when the query was rewritten, otherwise False.
        """
        where = parsed_sql['where']
        if not isinstance(where, dict):
            return False
        conditions = where['and'] if 'and' in where else [where]
        new_conds = []
        indexes = []
        for index, cond in enumerate(conditions):
            if not isinstance(cond, dict):
                continue
            checked, new_cond = SelfJoin.format_gt_lt(cond)
            if checked:
                new_conds.append(new_cond)
                indexes.append(index)
        region_checked, region_conds, related_indexes = SelfJoin._check_region_conds(new_conds, indexes)
        if not region_checked:
            return False

        select_key = 'select' if 'select' in parsed_sql else 'select_distinct'
        if isinstance(parsed_sql[select_key], dict):
            parsed_sql[select_key] = [parsed_sql[select_key]]
        if 'orderby' in parsed_sql or 'groupby' in parsed_sql:
            # NOTE(review): GROUP BY stays on the outer query while aggregates
            # stay inside the branches - confirm this is intended.
            select_items = parsed_sql[select_key]
            if not isinstance(select_items, list) or not all(
                    isinstance(column, dict) and 'value' in column for column in select_items):
                # e.g. "select *": outer clauses cannot be mapped to positions.
                return False
            table_prefix_columns_alias = [column['name'] if 'name' in column else column['value']
                                          for column in select_items]
            table_prefix_columns = [column['value'] for column in select_items]
            for clause in ('orderby', 'groupby'):
                if clause in parsed_sql and not self._column_to_index(
                        parsed_sql, table_prefix_columns, table_prefix_columns_alias, clause):
                    return False

        conds1, conds2 = SelfJoin._generate_subselect_conditions(region_conds)
        where_values = deepcopy(conditions)
        for index in sorted(related_indexes, reverse=True):
            where_values.pop(index)
        where1 = where_values + conds1
        where2 = where_values + conds2
        old_select = deepcopy(parsed_sql[select_key])
        old_from = deepcopy(parsed_sql['from'])
        parsed_sql.pop(select_key)
        # Keep DISTINCT on the outer query: the two branches may project the
        # same values.
        parsed_sql[select_key] = '*'
        parsed_sql.pop('where')
        left_select = {select_key: old_select, 'from': old_from, 'where': {'and': where1}}
        right_select = {select_key: old_select, 'from': old_from, 'where': {'and': where2}}
        parsed_sql['from'] = {'union_all': [left_select, right_select]}
        return True

    @staticmethod
    def _column_to_index(parsed_sql, table_prefix_columns, table_prefix_columns_alias, type):
        """Replace column names in ``parsed_sql[type]`` by 1-based select positions.

        :param table_prefix_columns: select values, e.g. ``a.c_id``.
        :param table_prefix_columns_alias: select aliases (value when no alias).
        :param type: clause key, ``'orderby'`` or ``'groupby'``.
        :return: False as soon as an entry cannot be mapped to a select
            position, True otherwise.
        """
        columns = [column.split('.')[-1] if isinstance(column, str) else None
                   for column in table_prefix_columns]
        if isinstance(parsed_sql[type], dict):
            parsed_sql[type] = [parsed_sql[type]]
        for _column in parsed_sql[type]:
            value = _column['value']
            if isinstance(value, int):
                continue
            if not isinstance(value, str):
                # An expression cannot be resolved outside the sub-select.
                return False
            for candidates in (columns, table_prefix_columns_alias, table_prefix_columns):
                if value in candidates:
                    _column['value'] = candidates.index(value) + 1
                    break
            else:
                return False
        return True

    @staticmethod
    def _check_feasibility(parsed_sql):
        """Return True when the query has select (or select_distinct), from and
        where, and no key other than those plus orderby / groupby."""
        required_fields1 = {'select', 'where', 'from'}
        required_fields2 = {'select_distinct', 'where', 'from'}
        useful_fields = {'select', 'where', 'from', 'orderby', 'groupby', 'select_distinct'}
        keys = set(parsed_sql.keys())
        has_required = required_fields1 <= keys or required_fields2 <= keys
        return has_required and not (keys - useful_fields)

    def _check_and_format(self, parsed_sql, selfjoin_count):
        """Apply the self-join rewrite to ``parsed_sql`` and its FROM sub-queries.

        :return: updated number of rewrites.
        """
        if isinstance(parsed_sql, list):
            for sub_parsed_sql in parsed_sql:
                selfjoin_count = self._check_and_format(sub_parsed_sql, selfjoin_count)
        elif isinstance(parsed_sql, dict):
            if 'union' in parsed_sql or 'union_all' in parsed_sql:
                return self._check_and_format(list(parsed_sql.values())[0], selfjoin_count)
            if 'from' in parsed_sql:
                if not isinstance(parsed_sql['from'], list):
                    parsed_sql['from'] = [parsed_sql['from']]
                selfjoin_count = self._check_and_format(parsed_sql['from'], selfjoin_count)
            if not self._check_feasibility(parsed_sql):
                return selfjoin_count
            # FROM sub-queries were already visited above.
            if not SelfJoin.is_selfjoin(parsed_sql['from']):
                return selfjoin_count
            selfjoin_count += self.selfjoin2unionall(parsed_sql)
        return selfjoin_count