import argparse
import getpass
import logging
import re
import sys
from copy import deepcopy
from functools import partial
import sqlparse
from mo_sql_parsing import parse, format
from prettytable import PrettyTable
from sql_metadata.compat import get_query_tables
from sqlparse.sql import Parenthesis, IdentifierList, Values
from sqlparse.tokens import Punctuation, Whitespace
from dbmind import global_vars
from dbmind.common.parser.sql_parsing import get_generate_prepare_sqls_function
from dbmind.common.utils.checking import CheckWordValid, path_type, positive_int_type
from dbmind.common.utils.cli import read_input_from_pipe
from dbmind.constants import __version__
from .executor import Executor
from .rules import (AlwaysTrue, DistinctStar, OrderbyConst, Star2Columns, UnionAll, Delete2Truncate,
Or2In, In2Exists,
OrderbyConstColumns,
ImplicitConversion,
SelfJoin, Group2Hash)
from .rules import Rule
from .utils import get_table_names
def get_all_involved_tables(sql, table_names=None):
    """Collect the table names referenced by a SQL statement.

    On the top-level call *table_names* is omitted and *sql* is SQL text, which is
    parsed with ``mo_sql_parsing.parse``. Recursive calls pass an already-parsed
    fragment (list or dict) together with the accumulator list.

    :param sql: SQL text (top-level call) or a parsed fragment (recursive calls).
    :param table_names: accumulator list that string table names are appended to.
    :return: the accumulator list.
    """
    if table_names is None:
        table_names = []
        parsed_sql = parse(sql)
    else:
        parsed_sql = sql
    if isinstance(parsed_sql, list):
        # Walk every element; results land in the shared accumulator.
        for sub_parsed_sql in parsed_sql:
            get_all_involved_tables(sub_parsed_sql, table_names)
    elif isinstance(parsed_sql, dict):
        if 'union' in parsed_sql or 'union_all' in parsed_sql:
            # Descends into the value of the dict's first key only.
            # NOTE(review): assumes the union key is the first key of the dict -- confirm.
            return get_all_involved_tables(list(parsed_sql.values())[0], table_names)
        if parsed_sql.get('from'):
            # NOTE(review): the second operand repeats the enclosing truthiness test, so
            # this is effectively `not isinstance(parsed_sql['from'], dict)`.
            if not (isinstance(parsed_sql['from'], dict) and parsed_sql.get('from')):
                cur_table_names = get_table_names(parsed_sql['from'])
                for table_name in cur_table_names:
                    # Only plain string names are collected; other entries are skipped.
                    if isinstance(table_name, str):
                        table_names.append(table_name)
            else:
                # A dict FROM clause is searched recursively as a nested statement.
                # NOTE(review): a dict that describes an aliased table rather than a
                # subquery would contribute no names here -- confirm the parser's output.
                table_names.extend(get_all_involved_tables(parsed_sql['from'], []))
    return table_names
# Matches "INSERT INTO <table> VALUES" at the start of a statement, i.e. an insert that
# names no target columns. Compiled once instead of on every call.
_NO_COLUMN_INSERT_PATTERN = re.compile(r'\s*insert\s+into\s+\w+\s+values', re.IGNORECASE)


def is_no_column_insert_sql(sql):
    """Tell whether *sql* is an ``INSERT INTO <table> VALUES ...`` without a column list.

    The table name must be a single word (``\\w+``); schema-qualified or quoted names
    are not recognised.

    :param sql: SQL statement text.
    :return: ``True`` for a column-less INSERT, otherwise ``False``.
    """
    return _NO_COLUMN_INSERT_PATTERN.match(sql) is not None
def get_insert_value_number(sql):
    """Count the items in the first value tuple of an ``INSERT ... VALUES`` statement.

    :param sql: an INSERT statement.
    :return: number of comma-separated items in the first non-empty parenthesised
        tuple of the VALUES clause (1 when it holds a single item), or ``None`` when
        no such tuple is found.
    """
    def _is_filler(token):
        # ``ttype in <TokenType>`` is a hierarchical test in sqlparse, so sub-types such
        # as Whitespace.Newline are skipped too; a plain tuple membership test
        # (``ttype in (Punctuation, Whitespace)``) misses them and miscounts tuples
        # that span several lines. Grouped tokens have ttype None and are never filler.
        return token.ttype in Punctuation or token.ttype in Whitespace

    for token in sqlparse.parse(sql)[0].tokens:
        if not isinstance(token, Values):
            continue
        for value_group in token:
            if not isinstance(value_group, Parenthesis):
                continue
            for item in value_group.tokens:
                if _is_filler(item):
                    continue
                if isinstance(item, IdentifierList):
                    return len([part for part in item.tokens if not _is_filler(part)])
                return 1
    return None
def rewrite_no_column_insert_sql(sql, columns):
    """Add an explicit column list to an ``INSERT INTO <table> VALUES`` statement.

    Only the statement head is touched: the text is not lowercased, so string
    literals keep their case, and a ``values`` appearing inside a literal is never
    rewritten.

    :param sql: INSERT statement without a column list.
    :param columns: ordered column names of the target table.
    :return: the statement with ``(<col>,...)`` inserted before ``VALUES``, or *sql*
        unchanged when it does not start with ``INSERT INTO <table> VALUES``.
    """
    # Use only as many columns as the first value tuple supplies; None keeps them all.
    column_number = get_insert_value_number(sql)
    column_list = ','.join(columns[:column_number])
    # A callable replacement keeps backslashes in column names from being treated as
    # regex escapes; count=1 and the ^ anchor limit the edit to the statement head.
    return re.sub(
        r'^(\s*insert\s+into\s+\w+)(\s+values)',
        lambda match: f'{match.group(1)} ({column_list}){match.group(2)}',
        sql,
        count=1,
        flags=re.IGNORECASE,
    )
def get_table_columns(sql, executor: Executor):
    """Return ``{table_name: columns}`` for every table referenced by *sql*.

    Table names come from ``get_all_involved_tables``; each table's columns are
    whatever ``executor.get_table_columns`` returns for it.
    """
    return {
        name: executor.get_table_columns(name)
        for name in get_all_involved_tables(sql)
    }
def exists_primary_key(tables, executor: Executor):
    """Return ``{table_name: executor.exists_primary_key(table_name)}`` for *tables*."""
    return {name: executor.exists_primary_key(name) for name in tables}
def get_notnull_columns(tables, executor: Executor):
    """Return ``{table_name: executor.get_notnull_columns(table_name)}`` for *tables*."""
    return {name: executor.get_notnull_columns(name) for name in tables}
class TableInfo:
    """Holder for the per-table metadata handed to the rewrite rules.

    Every attribute starts as ``None``; the callers in this module replace each one
    with a dict keyed by table name.
    """

    def __init__(self):
        # table_columns: column names per table.
        # table_exists_primary: per-table primary/unique-key indicator.
        # table_notnull_columns: NOT NULL column names per table.
        self.table_columns = self.table_exists_primary = self.table_notnull_columns = None
def singleton(cls):
    """Class decorator: replace *cls* with a zero-argument factory for one shared instance.

    The instance is created on the first call and that same object is returned on
    every later call.
    """
    instances = {}

    def inner():
        if cls in instances:
            return instances[cls]
        instance = cls()
        instances[cls] = instance
        return instance

    return inner
def get_offline_rewriter():
    """Return a private ``SQLRewriter`` restricted to a reduced rule set.

    ``SQLRewriter()`` yields the shared singleton, so it is deep-copied before its
    rules are replaced; the shared instance keeps its full rule list.
    NOTE(review): the name suggests these are the rules usable without a database
    connection -- confirm against the rule implementations.
    """
    rewriter = deepcopy(SQLRewriter())
    rewriter.clear_rules()
    offline_rules = (
        ImplicitConversion,
        OrderbyConstColumns,
        AlwaysTrue,
        UnionAll,
        Delete2Truncate,
        Or2In,
        SelfJoin,
        Group2Hash,
    )
    for rule in offline_rules:
        rewriter.add_rule(rule)
    return rewriter
@singleton
class SQLRewriter:
    """Apply an ordered list of rewrite rules to SQL statements.

    The class is wrapped by ``singleton``, so every ``SQLRewriter()`` call returns the
    same shared instance; ``get_offline_rewriter`` deep-copies it before changing its
    rules.
    """

    def __init__(self):
        # Default rule set; ``_apply_rules`` runs the rules in this order.
        self.rules = []
        for rule in (In2Exists, DistinctStar, Star2Columns, ImplicitConversion,
                     OrderbyConst, OrderbyConstColumns, AlwaysTrue, UnionAll,
                     Delete2Truncate, Or2In, SelfJoin, Group2Hash):
            self.add_rule(rule)

    def rewrite(self, sql, tableinfo=None, if_format=True):
        """Rewrite *sql* with every rule that applies.

        :param sql: SQL text to rewrite; parse errors from ``mo_sql_parsing.parse``
            propagate to the caller.
        :param tableinfo: ``TableInfo`` with the schema metadata the rules consult; a
            fresh empty one is used when omitted (a ``TableInfo()`` default would be
            built once and shared by every call).
        :param if_format: pretty-print the result with ``sqlparse`` (reindented,
            upper-case keywords).
        :return: ``(changed, sql_text)``. *changed* is True when at least one rule
            fired. If a rule raises, the error is logged and ``(False, sql)`` is
            returned with a trailing ``;`` ensured.
        """
        if tableinfo is None:
            tableinfo = TableInfo()
        parsed_sql = parse(sql)
        try:
            checked_rules, parsed_sql = self._apply_rules(parsed_sql, tableinfo)
        except Exception as e:
            # Best effort: a failing rule leaves the statement as it was.
            logging.warning(e)
            return False, (sql if sql.endswith(';') else sql + ';')
        sql_string = format(parsed_sql) + ';'
        # Turn quoted positional placeholders such as "$1" back into bare $1.
        sql_string = re.sub(r'"\$(\d+)"', r'$\1', sql_string)
        if Delete2Truncate.__name__ in checked_rules:
            # The table name is taken from between the first pair of parentheses.
            return True, 'TRUNCATE TABLE ' + sql_string.split('(')[1].split(')')[0] + ';'
        if if_format:
            sql_string = sqlparse.format(sql_string, reindent=True, keyword_case='upper')
        return bool(checked_rules), sql_string

    def add_rule(self, rule):
        """Append *rule* (a ``Rule`` subclass) to the rule list.

        :raises NotImplementedError: if *rule* is a class that is not a ``Rule``
            subclass (``issubclass`` itself raises TypeError for non-classes).
        """
        if not issubclass(rule, Rule):
            raise NotImplementedError()
        self.rules.append(rule)

    def clear_rules(self):
        """Remove every rule."""
        self.rules = []

    def _apply_rules(self, parsed_sql, tableinfo):
        """Run each rule in order, threading the parsed statement through them.

        :return: ``(checked_rules, parsed_sql)``: the truthy results reported by the
            rules that fired, and the final parsed statement.
        """
        checked_rules = []
        if not parsed_sql.get('delete'):
            # Non-DELETE statements are round-tripped through format/parse first.
            # NOTE(review): presumably to normalise the parse tree -- confirm.
            parsed_sql = parse(format(parsed_sql))
        for rule in self.rules:
            res, parsed_sql = rule().check_and_format(parsed_sql, tableinfo)
            if res:
                checked_rules.append(res)
        return checked_rules, parsed_sql
def get_password():
    """Return the database password.

    The password is read through ``read_input_from_pipe`` when that yields a value;
    otherwise the user is prompted interactively.

    :raises ValueError: if no non-empty password is obtained.
    """
    password = read_input_from_pipe()
    if not password:
        password = getpass.getpass("Password for database user:")
    else:
        logging.warning("Read password from pipe.")
    if password:
        return password
    raise ValueError('Please input the password')
def canbe_parsed(sql):
    """Report whether ``mo_sql_parsing.parse`` accepts *sql*.

    Any exception raised by the parser counts as "cannot be parsed".
    """
    try:
        parse(sql)
    except Exception:
        return False
    else:
        return True
def _sql_string_literal(value):
    """Return *value* as a single-quoted SQL string literal, doubling embedded quotes."""
    return "'" + str(value).replace("'", "''") + "'"


def _query_table_info(executor, involved_tables):
    """Look up catalog metadata for *involved_tables* through *executor*.

    :param executor: callable accepting ``stmt`` and ``return_tuples`` keyword
        arguments and returning result rows as tuples.
    :param involved_tables: table names referenced by the statement being rewritten.
    :return: ``TableInfo`` whose dicts are keyed by table name and hold the column
        names ordered by position, the count of PRIMARY KEY/UNIQUE constraints, and
        the NOT NULL column names.
    """
    table_columns = {}
    table_exists_primary = {}
    table_notnull_columns = {}
    for table_name in involved_tables:
        # Names come from user-supplied SQL: embed them as escaped literals so a quote
        # inside a name cannot break out of the catalog queries.
        name_literal = _sql_string_literal(table_name)

        columns_stmt = "select column_name, ordinal_position " \
                       "from information_schema.columns where table_name=%s;" % name_literal
        column_rows = sorted(executor(stmt=columns_stmt, return_tuples=True), key=lambda row: row[1])
        table_columns[table_name] = [row[0] for row in column_rows]

        exists_primary_stmt = "SELECT count(*) FROM information_schema.table_constraints WHERE " \
                              "constraint_type in ('PRIMARY KEY', 'UNIQUE') AND table_name = %s" % name_literal
        table_exists_primary[table_name] = executor(stmt=exists_primary_stmt, return_tuples=True)[0][0]

        # NOTE(review): relname is not schema-qualified; if several schemas hold a table
        # with this name the sub-select can return more than one row -- confirm.
        notnull_columns_stmt = f"SELECT attname from pg_attribute where attrelid=(select oid from pg_class " \
                               f"where relname={name_literal}) and attnotnull=true"
        table_notnull_columns[table_name] = [row[0] for row in
                                             executor(stmt=notnull_columns_stmt, return_tuples=True)]

    tableinfo = TableInfo()
    tableinfo.table_columns = table_columns
    tableinfo.table_exists_primary = table_exists_primary
    tableinfo.table_notnull_columns = table_notnull_columns
    return tableinfo


def rewrite_sql_api(database, sqls, rewritten_flags=None, if_format=True, driver=None):
    """Rewrite every statement in *sqls* using the live catalog of *database*.

    :param database: database whose catalog is queried.
    :param sqls: one or more SQL statements separated by ``;``.
    :param rewritten_flags: optional list; one bool per statement is appended to it,
        telling whether that statement was rewritten.
    :param if_format: pretty-print rewritten statements.
    :param driver: when given, ``driver.query(..., force_connection_db=database)`` runs
        the queries; otherwise ``global_vars.agent_proxy.call`` with
        ``funcname='query_in_database'`` is used.
    :return: the rewritten (or unchanged) statements joined by newlines.
    """
    rewritten_sqls = []
    get_prepare_sqls = get_generate_prepare_sqls_function()
    if rewritten_flags is None:
        rewritten_flags = []
    if driver is not None:
        executor = partial(driver.query, force_connection_db=database)
    else:
        executor = partial(global_vars.agent_proxy.call, funcname='query_in_database', database=database)
    schemas_results = executor(stmt='select distinct(table_schema) from information_schema.tables;',
                               return_tuples=True)
    # Put every reported schema on the search path for the check; default to 'public'.
    schemas = ','.join([res[0] for res in schemas_results]) if schemas_results else 'public'
    # sqlparse.split, unlike str.split(';'), does not cut a statement at a ';' that sits
    # inside a string literal or a comment.
    for statement in sqlparse.split(sqls):
        statement = statement.strip()
        if not statement:
            continue
        sql = statement if statement.endswith(';') else statement + ';'
        formatted_sql = sqlparse.format(sql, keyword_case='lower', identifier_case='lower', strip_comments=True)
        prepare_sqls = get_prepare_sqls(formatted_sql)
        sql_checking_stmt = f'set current_schema={schemas};{";".join(prepare_sqls)}'
        checking_results = executor(stmt=sql_checking_stmt, return_tuples=False, fetch_all=True)
        if not checking_results:
            # The check returned nothing: hand the statement back untouched.
            rewritten_sqls.append(sql)
            rewritten_flags.append(False)
            continue
        tableinfo = _query_table_info(executor, get_query_tables(formatted_sql))
        rewritten_flag, output_sql = SQLRewriter().rewrite(formatted_sql, tableinfo, if_format)
        rewritten_flags.append(rewritten_flag)
        rewritten_sqls.append(output_sql)
    return '\n'.join(rewritten_sqls)
def _collect_table_info(formatted_sql, executor):
    """Gather column, key and NOT NULL metadata for the tables in *formatted_sql*.

    :param formatted_sql: statement whose referenced tables are inspected.
    :param executor: ``Executor`` used for the catalog lookups.
    :return: a populated ``TableInfo``; the keys of ``table_columns`` are the tables.
    """
    tableinfo = TableInfo()
    tableinfo.table_columns = get_table_columns(formatted_sql, executor)
    tables = tableinfo.table_columns.keys()
    tableinfo.table_exists_primary = exists_primary_key(tables, executor)
    tableinfo.table_notnull_columns = get_notnull_columns(tables, executor)
    return tableinfo


def main(argv):
    """Command-line entry point: rewrite each statement of a SQL file and print a table.

    The table has one row per statement: the raw SQL and its rewrite, or an empty
    rewrite cell when the statement was rejected or could not be rewritten.

    :param argv: command-line arguments without the program name.
    """
    arg_parser = argparse.ArgumentParser(
        description='SQL Rewriter')
    arg_parser.add_argument("db_port", help="Port for database", type=positive_int_type)
    arg_parser.add_argument("database", help="Name for database", action=CheckWordValid)
    arg_parser.add_argument("file", type=path_type, help="File containing SQL statements which need to rewrite")
    arg_parser.add_argument(
        "--db-host", help="Host for database", action=CheckWordValid)
    arg_parser.add_argument(
        "--db-user", help="Username for database log-in", action=CheckWordValid)
    arg_parser.add_argument(
        "--schema", help="Schema name for the current business data", default='public', action=CheckWordValid)
    arg_parser.add_argument('-v', '--version', action='version', version=__version__)
    args = arg_parser.parse_args(argv)
    args.W = get_password()
    executor = Executor(args.database, args.db_user, args.W, args.db_host, args.db_port, args.schema)

    output_table = PrettyTable()
    output_table.field_names = ('Raw SQL', 'Rewritten SQL')
    output_table.align = "l"

    with open(args.file) as file_h:
        content = file_h.read()
    for statement in sqlparse.split(content):
        statement = statement.strip()
        if not statement:
            continue
        sql = statement if statement.endswith(';') else statement + ';'
        # Statements rejected by the database check or by the parser get no rewrite.
        if not executor.syntax_check(sql) or not canbe_parsed(sql):
            output_table.add_row([sql, ''])
            continue
        formatted_sql = sqlparse.format(sql, keyword_case='lower', identifier_case='lower', strip_comments=True)
        tableinfo = _collect_table_info(formatted_sql, executor)
        tables = list(tableinfo.table_columns)
        if is_no_column_insert_sql(formatted_sql):
            # The column list can only be filled in when exactly one table is involved.
            if len(tables) == 1:
                res = True
                rewritten_sql = rewrite_no_column_insert_sql(sql, tableinfo.table_columns[tables[0]])
            else:
                res, rewritten_sql = False, formatted_sql
        else:
            res, rewritten_sql = SQLRewriter().rewrite(formatted_sql, tableinfo)
        # Test `res` first so statements that were not rewritten skip the extra
        # syntax_check call.
        if res and executor.syntax_check(rewritten_sql):
            output_table.add_row([sql, rewritten_sql])
        else:
            output_table.add_row([sql, ''])
    print(output_table)
if __name__ == '__main__':
    # Pass the command-line arguments, minus the program name, to the CLI entry point.
    main(argv=sys.argv[1:])