import logging
import operator as op_module
from dataclasses import dataclass, field
from functools import reduce
from itertools import groupby, product
from typing import Any, List, Optional, Union

import intervals as intervals_module
from mo_sql_parsing import parse
@dataclass
class Interval:
    """A set of admissible values for a single column.

    ``interval`` is presumably an ``intervals_module`` interval object (TODO
    confirm for all callers); this class only relies on it supporting ``-``,
    ``|`` and ``is_empty()``.
    """

    interval: Any
    column: Optional[str] = None  # name of the constrained column

    def __sub__(self, other):
        """Remove *other*'s values from this interval.

        Returns:
            - a new ``Interval`` on the same column when both operands refer
              to the same column and the difference is non-empty;
            - ``False`` when the difference is empty;
            - ``True`` when the columns differ (the subtraction imposes no
              constraint).
        """
        if self.column != other.column:
            return True
        # Compute the difference once instead of twice.
        difference = self.interval - other.interval
        if difference.is_empty():
            return False
        return Interval(difference, self.column)

    def __or__(self, other):
        """Return the union of both intervals, keeping this interval's column.

        NOTE(review): ``other.column`` is not checked; callers presumably only
        union intervals on the same column.
        """
        return Interval(self.interval | other.interval, self.column)
def groupby_column(intervals: List["Interval"]):
    """Yield ``(column, group)`` pairs from *intervals*, grouped by ``.column``.

    Groups are produced in ascending column order. Each ``group`` is the lazy
    iterator returned by :func:`itertools.groupby` and is only valid until the
    next pair is requested.

    The caller's list is not modified: a sorted copy is grouped, rather than
    sorting the argument in place.
    """
    by_column = op_module.attrgetter("column")
    yield from groupby(sorted(intervals, key=by_column), key=by_column)
@dataclass
class Intervals:
    """A conjunction of per-column constraints.

    Each element of ``intervals`` is one of:
    - an ``Interval``;
    - ``True``: a conjunct that imposes no usable constraint;
    - ``False``: a conjunct known to be unsatisfiable.
    """

    intervals: List[Union["Interval", bool]] = field(default_factory=list)

    def __and__(self, other):
        """Intersect two conjunctions column by column.

        Returns ``Intervals([False])`` in either of these cases:
        - either operand already contains ``False``;
        - some column's intersection is empty.

        Otherwise returns one merged ``Interval`` per distinct column, in
        column order. ``True`` entries are dropped.
        """
        all_intervals = self.intervals + other.intervals
        # Check both operands. A False coming from ``other`` must short-circuit
        # too; otherwise it reaches groupby_column, which reads ``.column`` and
        # fails on a bool.
        if False in all_intervals:
            return Intervals([False])
        constrained = [iv for iv in all_intervals if iv is not True]
        results = Intervals()
        for column, group in groupby_column(constrained):
            members = list(group)
            merged = members[0]
            for member in members[1:]:
                merged = Interval(op_module.and_(merged.interval, member.interval), column)
            if merged.interval.is_empty():
                return Intervals([False])
            results.append_interval(merged)
        return results

    def append_interval(self, interval):
        """Append one conjunct (an ``Interval``, ``True`` or ``False``)."""
        self.intervals.append(interval)

    def is_empty(self):
        """Return True when some conjunct is ``False`` (known unsatisfiable)."""
        return False in self.intervals
def _gte(left, right):
    """Translate ``left >= right`` into interval constraints.

    ``col >= n`` becomes ``[n, +inf)`` on ``col`` and ``n >= col`` becomes
    ``(-inf, n]`` on ``col``. Any other operand pair yields a lone ``True``
    (no usable constraint).
    """
    numeric = (float, int)
    inf = intervals_module.inf
    result = Intervals()
    if isinstance(left, str) and isinstance(right, numeric):
        column, lower, upper = left, right, inf      # col >= n
    elif isinstance(right, str) and isinstance(left, numeric):
        column, lower, upper = right, -inf, left     # n >= col
    else:
        result.append_interval(True)
        return result
    result.append_interval(Interval(intervals_module.closed(lower, upper), column))
    return result
def _gt(left, right):
    """Translate ``left > right`` into interval constraints.

    ``col > n`` becomes ``(n, +inf)`` on ``col`` and ``n > col`` becomes
    ``(-inf, n)`` on ``col``. Any other operand pair yields a lone ``True``
    (no usable constraint).
    """
    numeric = (float, int)
    inf = intervals_module.inf
    result = Intervals()
    if isinstance(left, str) and isinstance(right, numeric):
        column, lower, upper = left, right, inf      # col > n
    elif isinstance(right, str) and isinstance(left, numeric):
        column, lower, upper = right, -inf, left     # n > col
    else:
        result.append_interval(True)
        return result
    result.append_interval(Interval(intervals_module.open(lower, upper), column))
    return result
def _lte(left, right):
    """Translate ``left <= right`` into interval constraints.

    ``col <= n`` becomes ``(-inf, n]`` on ``col`` and ``n <= col`` becomes
    ``[n, +inf)`` on ``col``. Any other operand pair yields a lone ``True``
    (no usable constraint).
    """
    numeric = (float, int)
    inf = intervals_module.inf
    result = Intervals()
    if isinstance(left, str) and isinstance(right, numeric):
        column, lower, upper = left, -inf, right     # col <= n
    elif isinstance(right, str) and isinstance(left, numeric):
        column, lower, upper = right, left, inf      # n <= col
    else:
        result.append_interval(True)
        return result
    result.append_interval(Interval(intervals_module.closed(lower, upper), column))
    return result
def _lt(left, right):
    """Translate ``left < right`` into interval constraints.

    ``col < n`` becomes ``(-inf, n)`` on ``col`` and ``n < col`` becomes
    ``(n, +inf)`` on ``col``. Any other operand pair yields a lone ``True``
    (no usable constraint).
    """
    numeric = (float, int)
    inf = intervals_module.inf
    result = Intervals()
    if isinstance(left, str) and isinstance(right, numeric):
        column, lower, upper = left, -inf, right     # col < n
    elif isinstance(right, str) and isinstance(left, numeric):
        column, lower, upper = right, left, inf      # n < col
    else:
        result.append_interval(True)
        return result
    result.append_interval(Interval(intervals_module.open(lower, upper), column))
    return result
def _eq(left, right):
    """Translate ``left = right`` into an interval constraint.

    The operands are order-insensitive: ``col = n`` and ``n = col`` both give
    the singleton ``{n}`` on ``col``. Any other operand pair yields a lone
    ``True`` (no usable constraint).
    """
    # Stable sort that moves a str (column name) operand to the front.
    column, value = sorted((left, right), key=lambda operand: not isinstance(operand, str))
    result = Intervals()
    if isinstance(column, str) and isinstance(value, (float, int)):
        result.append_interval(Interval(intervals_module.singleton(value), column))
    else:
        result.append_interval(True)
    return result
def _between(column, left, right):
    """Translate ``column BETWEEN left AND right`` into interval constraints.

    The result is the intersection of ``column >= left`` and
    ``column <= right``. Operands that are not (column name, number, number)
    yield a lone ``True`` (no usable constraint).
    """
    numeric = (float, int)
    operands_ok = (
        isinstance(column, str)
        and isinstance(left, numeric)
        and isinstance(right, numeric)
    )
    if operands_ok:
        return _gte(column, left) & _lte(column, right)
    unconstrained = Intervals()
    unconstrained.append_interval(True)
    return unconstrained
def _in(left, right):
    """Translate ``left IN right`` into an interval constraint.

    When *left* is a column name and *right* is a non-empty list of numbers,
    the result holds the union of the singletons for those numbers on that
    column. A bare number for *right* is treated as a one-element list.

    Any other shape yields a lone ``True`` (no usable constraint). This
    covers non-numeric members, an empty list, and a left operand that is
    not a column name.
    """
    numeric = (float, int)
    if isinstance(left, str) and isinstance(right, numeric):
        # NOTE(review): presumably how the parser encodes a one-element list
        # such as ``x IN (5)``; ``x IN 5`` means ``x = 5`` either way.
        right = [right]
    intervals = Intervals()
    if (isinstance(left, str) and isinstance(right, list) and right
            and all(isinstance(x, numeric) for x in right)):
        # ``right`` is checked non-empty: an empty list used to raise
        # IndexError on ``right[0]``.
        singletons = (_eq(left, x).intervals[0] for x in right)
        intervals.append_interval(reduce(op_module.or_, singletons))
    else:
        intervals.append_interval(True)
    return intervals
def _nin(left, right):
    """Translate ``left NOT IN right`` into an interval constraint.

    When *left* is a column name and *right* is a list of numbers, the result
    is the whole real line on that column minus each listed number. A bare
    number for *right* is treated as a one-element list.

    Any other shape yields a lone ``True`` (no usable constraint), like the
    other leaf helpers. Previously an unrecognised operand pair returned an
    empty ``Intervals()`` instead.
    """
    numeric = (float, int)
    if isinstance(left, str) and isinstance(right, numeric):
        # NOTE(review): presumably how the parser encodes a one-element list
        # such as ``x NOT IN (5)``.
        right = [right]
    intervals = Intervals()
    if (isinstance(left, str) and isinstance(right, list)
            and all(isinstance(x, numeric) for x in right)):
        interval = Interval(
            intervals_module.open(-intervals_module.inf, intervals_module.inf), left)
        for value in right:
            # Both sides are on the same column, so __sub__ returns an Interval
            # here: removing finitely many points never empties the real line.
            interval = interval - _eq(left, value).intervals[0]
        intervals.append_interval(interval)
    else:
        intervals.append_interval(True)
    return intervals
def to_intervals(data):
    """Convert one parsed predicate (or a nested AND of them) into Intervals.

    Leaf comparisons (gt, gte, lt, lte, eq, in, nin, between) are dispatched
    to the matching helper, which receives the node's operands. An ``and``
    node intersects its children's results with ``&``.

    Anything else yields a lone ``True`` (no usable constraint). That includes
    unknown operators and non-dict nodes such as a bare column name or a
    literal.
    """
    if not isinstance(data, dict):
        # Non-dict nodes arrive here unchanged from split_parsed_sql.
        # Without this guard, a str node turns ``op in data`` into a substring
        # test (e.g. 'eq' in 'is_required'), and ``data[op]`` then raises.
        # Other scalars such as True make ``in`` itself raise.
        unconstrained = Intervals()
        unconstrained.append_interval(True)
        return unconstrained
    leaf_op_func = {'gt': _gt, 'gte': _gte, 'lt': _lt, 'lte': _lte, 'eq': _eq, 'in': _in, 'nin': _nin,
                    'between': _between}
    for op, func in leaf_op_func.items():
        if op in data:
            return func(*data[op])
    if 'and' in data:
        return reduce(op_module.and_, [to_intervals(subdata) for subdata in data['and']])
    unconstrained = Intervals()
    unconstrained.append_interval(True)
    return unconstrained
def combine(iterable, cur_element, total_list):
    """Append the Cartesian product of the sequences in *iterable* to *total_list*.

    Each appended entry is a new list made of:
    - the current contents of *cur_element* (a prefix, left unmodified);
    - followed by one element from each sequence, in order.

    For example, ['AB', 'CD'] with an empty prefix appends
    [['A', 'C'], ['A', 'D'], ['B', 'C'], ['B', 'D']].

    An empty *iterable*, or one containing an empty sequence, appends nothing.
    Returns *total_list* in every case; previously ``None`` was returned
    unless *iterable* was empty.
    """
    if not iterable:
        # product() of nothing yields one empty tuple; the original appended
        # nothing here, so keep that.
        return total_list
    prefix = list(cur_element)
    # Iterative product: no recursion-depth limit on long conjunctions.
    total_list.extend(prefix + list(combo) for combo in product(*iterable))
    return total_list
def split_parsed_sql(where_clause):
    """Expand a parsed WHERE clause into a list of OR-free alternatives.

    ``and`` is distributed over ``or``, so each returned item is either a
    conjunction ``{'and': [...]}`` or a single predicate. The clause holds
    iff at least one returned item holds.

    Example: (a OR b) AND c -> [a AND c, b AND c].

    Nodes that are neither ``and`` nor ``or`` dicts (other operators,
    literals, bare column names) are returned as-is, wrapped in a list.
    """
    if isinstance(where_clause, dict) and 'and' in where_clause:
        total_list = []
        # One alternative per choice of branch from every conjunct.
        combine([split_parsed_sql(subdata) for subdata in where_clause['and']], [], total_list)
        return [{'and': subdata} for subdata in total_list]
    if isinstance(where_clause, dict) and 'or' in where_clause:
        return [branch
                for subdata in where_clause['or']
                for branch in split_parsed_sql(subdata)]
    return [where_clause]
def has_valid_intervals(sql: str):
    """Report whether the WHERE clause of *sql* might be satisfiable.

    The clause is expanded into OR-free branches. The result is True when at
    least one branch's Intervals is not known to be empty, and False when
    every branch is.

    Returns None when *sql* fails to parse (the error is logged) or when the
    parsed query has no top-level ``where`` key.
    """
    try:
        parsed_sql = parse(sql)
    except Exception as err:
        logging.error(err)
        return None
    if 'where' in parsed_sql:
        branch_intervals = [to_intervals(branch)
                            for branch in split_parsed_sql(parsed_sql['where'])]
        return not all(branch.is_empty() for branch in branch_intervals)
    return None