# -*- coding: UTF-8 -*-
from clickhouse_driver import connect
from clickhouse_driver.util.escape import escape_chars_map
from sql.utils.sql_utils import get_syntax_type
from .models import ResultSet, ReviewResult, ReviewSet
from common.utils.timer import FuncTimer
from common.config import SysConfig
from . import EngineBase
import sqlparse
import logging
import re
logger = logging.getLogger("default")
class ClickHouseEngine(EngineBase):
test_query = "SELECT 1"
def __init__(self, instance=None):
super(ClickHouseEngine, self).__init__(instance=instance)
self.config = SysConfig()
def get_connection(self, db_name=None):
if self.conn:
return self.conn
if db_name:
self.conn = connect(
host=self.host,
port=self.port,
user=self.user,
password=self.password,
database=db_name,
connect_timeout=10,
)
else:
self.conn = connect(
host=self.host,
port=self.port,
user=self.user,
password=self.password,
connect_timeout=10,
)
return self.conn
name = "ClickHouse"
info = "ClickHouse engine"
def escape_string(self, value: str) -> str:
"""字符串参数转义"""
return "%s" % "".join(escape_chars_map.get(c, c) for c in value)
@property
def auto_backup(self):
"""是否支持备份"""
return False
@property
def server_version(self):
sql = "select value from system.build_options where name = 'VERSION_FULL';"
result = self.query(sql=sql)
version = result.rows[0][0].split(" ")[1]
return tuple([int(n) for n in version.split(".")[:3]])
def get_table_engine(self, tb_name):
"""获取某个table的engine type"""
db, tb = tb_name.split(".")
sql = f"""select engine from system.tables where database=%(db)s and name=%(tb)s"""
query_result = self.query(sql=sql, parameters={"db": db, "tb": tb})
if query_result.rows:
result = {"status": 1, "engine": query_result.rows[0][0]}
else:
result = {"status": 0, "engine": "None"}
return result
def get_all_databases(self):
"""获取数据库列表, 返回一个ResultSet"""
sql = "show databases"
result = self.query(sql=sql)
db_list = [
row[0]
for row in result.rows
if row[0]
not in ("system", "INFORMATION_SCHEMA", "information_schema", "datasets")
]
result.rows = db_list
return result
def get_all_tables(self, db_name, **kwargs):
"""获取table 列表, 返回一个ResultSet"""
sql = "show tables"
result = self.query(db_name=db_name, sql=sql)
tb_list = [row[0] for row in result.rows]
result.rows = tb_list
return result
def get_all_columns_by_tb(self, db_name, tb_name, **kwargs):
"""获取所有字段, 返回一个ResultSet"""
sql = f"""select
name,
type,
comment
from
system.columns
where
database = %(db_name)s
and table = %(tb_name)s;"""
result = self.query(
db_name=db_name,
sql=sql,
parameters={"db_name": db_name, "tb_name": tb_name},
)
column_list = [row[0] for row in result.rows]
result.rows = column_list
return result
def describe_table(self, db_name, tb_name, **kwargs):
"""return ResultSet 类似查询"""
tb_name = self.escape_string(tb_name)
sql = f"show create table `{tb_name}`;"
result = self.query(db_name=db_name, sql=sql)
result.rows[0] = (tb_name,) + (
result.rows[0][0].replace("(", "(\n ").replace(",", ",\n "),
)
return result
def query(
self,
db_name=None,
sql="",
limit_num=0,
close_conn=True,
parameters=None,
**kwargs,
):
"""返回 ResultSet"""
result_set = ResultSet(full_sql=sql)
try:
conn = self.get_connection(db_name=db_name)
cursor = conn.cursor()
cursor.execute(sql, parameters)
if int(limit_num) > 0:
rows = cursor.fetchmany(size=int(limit_num))
else:
rows = cursor.fetchall()
fields = cursor.description
result_set.column_list = [i[0] for i in fields] if fields else []
result_set.rows = rows
result_set.affected_rows = len(rows)
except Exception as e:
logger.warning(f"ClickHouse语句执行报错,语句:{sql},错误信息{e}")
result_set.error = str(e).split("Stack trace")[0]
finally:
if close_conn:
self.close()
return result_set
def query_check(self, db_name=None, sql=""):
# 查询语句的检查、注释去除、切分
result = {"msg": "", "bad_query": False, "filtered_sql": sql, "has_star": False}
# 删除注释语句,进行语法判断,执行第一条有效sql
try:
sql = sqlparse.format(sql, strip_comments=True)
sql = sqlparse.split(sql)[0]
result["filtered_sql"] = sql.strip()
except IndexError:
result["bad_query"] = True
result["msg"] = "没有有效的SQL语句"
if re.match(r"^select|^show|^explain|^with", sql, re.I) is None:
result["bad_query"] = True
result["msg"] = "不支持的查询语法类型!"
if "*" in sql:
result["has_star"] = True
result["msg"] = "SQL语句中含有 * "
# clickhouse 20.6.3版本开始正式支持explain语法
if re.match(r"^explain", sql, re.I) and self.server_version < (20, 6, 3):
result["bad_query"] = True
result["msg"] = f"当前ClickHouse实例版本低于20.6.3,不支持explain!"
# select语句先使用Explain判断语法是否正确
if re.match(r"^select", sql, re.I) and self.server_version >= (20, 6, 3):
explain_result = self.query(db_name=db_name, sql=f"explain {sql}")
if explain_result.error:
result["bad_query"] = True
result["msg"] = explain_result.error
return result
def filter_sql(self, sql="", limit_num=0):
# 对查询sql增加limit限制,limit n 或 limit n,n 或 limit n offset n统一改写成limit n
sql = sql.rstrip(";").strip()
if re.match(r"^select", sql, re.I):
# LIMIT N
limit_n = re.compile(r"limit\s+(\d+)\s*$", re.I)
# LIMIT M OFFSET N
limit_offset = re.compile(r"limit\s+(\d+)\s+offset\s+(\d+)\s*$", re.I)
# LIMIT M,N
offset_comma_limit = re.compile(r"limit\s+(\d+)\s*,\s*(\d+)\s*$", re.I)
if limit_n.search(sql):
sql_limit = limit_n.search(sql).group(1)
limit_num = min(int(limit_num), int(sql_limit))
sql = limit_n.sub(f"limit {limit_num};", sql)
elif limit_offset.search(sql):
sql_limit = limit_offset.search(sql).group(1)
sql_offset = limit_offset.search(sql).group(2)
limit_num = min(int(limit_num), int(sql_limit))
sql = limit_offset.sub(f"limit {limit_num} offset {sql_offset};", sql)
elif offset_comma_limit.search(sql):
sql_offset = offset_comma_limit.search(sql).group(1)
sql_limit = offset_comma_limit.search(sql).group(2)
limit_num = min(int(limit_num), int(sql_limit))
sql = offset_comma_limit.sub(f"limit {sql_offset},{limit_num};", sql)
else:
sql = f"{sql} limit {limit_num};"
else:
sql = f"{sql};"
return sql
def explain_check(self, check_result, db_name=None, line=0, statement=""):
"""使用explain ast检查sql语法, 返回Review set"""
result = ReviewResult(
id=line,
errlevel=0,
stagestatus="Audit completed",
errormessage="None",
sql=statement,
affected_rows=0,
execute_time=0,
)
# clickhouse版本>=21.1.2 explain ast才支持非select语句检查
if self.server_version >= (21, 1, 2):
explain_result = self.query(db_name=db_name, sql=f"explain ast {statement}")
if explain_result.error:
result = ReviewResult(
id=line,
errlevel=2,
stagestatus="驳回未通过检查SQL",
errormessage=f"explain语法检查错误:{explain_result.error}",
sql=statement,
)
return result
def execute_check(self, db_name=None, sql=""):
"""上线单执行前的检查, 返回Review set"""
sql = sqlparse.format(sql, strip_comments=True)
sql_list = sqlparse.split(sql)
# 禁用/高危语句检查
check_result = ReviewSet(full_sql=sql)
line = 1
critical_ddl_regex = self.config.get("critical_ddl_regex", "")
p = re.compile(critical_ddl_regex)
check_result.syntax_type = 2 # TODO 工单类型 0、其他 1、DDL,2、DML
for statement in sql_list:
statement = statement.rstrip(";")
# 禁用语句
if re.match(r"^select|^show", statement, re.M | re.IGNORECASE):
result = ReviewResult(
id=line,
errlevel=2,
stagestatus="驳回不支持语句",
errormessage="仅支持DML和DDL语句,查询语句请使用SQL查询功能!",
sql=statement,
)
# 高危语句
elif critical_ddl_regex and p.match(statement.strip().lower()):
result = ReviewResult(
id=line,
errlevel=2,
stagestatus="驳回高危SQL",
errormessage="禁止提交匹配" + critical_ddl_regex + "条件的语句!",
sql=statement,
)
# alter语句
elif re.match(r"^alter", statement, re.M | re.IGNORECASE):
# alter table语句
if re.match(
r"^alter\s+table\s+(.+?)\s+", statement, re.M | re.IGNORECASE
):
table_name = re.match(
r"^alter\s+table\s+(.+?)\s+", statement, re.M | re.IGNORECASE
).group(1)
if "." not in table_name:
table_name = f"{db_name}.{table_name}"
table_engine = self.get_table_engine(table_name)["engine"]
table_exist = self.get_table_engine(table_name)["status"]
if table_exist == 1:
if not table_engine.endswith(
"MergeTree"
) and table_engine not in ("Merge", "Distributed"):
result = ReviewResult(
id=line,
errlevel=2,
stagestatus="驳回不支持SQL",
errormessage="ALTER TABLE仅支持*MergeTree,Merge以及Distributed等引擎表!",
sql=statement,
)
else:
# delete与update语句,实际是alter语句的变种
if re.match(
r"^alter\s+table\s+(.+?)\s+(delete|update)\s+",
statement,
re.M | re.IGNORECASE,
):
if not table_engine.endswith("MergeTree"):
result = ReviewResult(
id=line,
errlevel=2,
stagestatus="驳回不支持SQL",
errormessage="DELETE与UPDATE仅支持*MergeTree引擎表!",
sql=statement,
)
else:
result = self.explain_check(
check_result, db_name, line, statement
)
else:
result = self.explain_check(
check_result, db_name, line, statement
)
else:
result = ReviewResult(
id=line,
errlevel=2,
stagestatus="表不存在",
errormessage=f"表 {table_name} 不存在!",
sql=statement,
)
# 其他alter语句
else:
result = self.explain_check(check_result, db_name, line, statement)
# truncate语句
elif re.match(
r"^truncate\s+table\s+(.+?)(\s|$)", statement, re.M | re.IGNORECASE
):
table_name = re.match(
r"^truncate\s+table\s+(.+?)(\s|$)", statement, re.M | re.IGNORECASE
).group(1)
if "." not in table_name:
table_name = f"{db_name}.{table_name}"
table_engine = self.get_table_engine(table_name)["engine"]
table_exist = self.get_table_engine(table_name)["status"]
if table_exist == 1:
if table_engine in ("View", "File", "URL", "Buffer", "Null"):
result = ReviewResult(
id=line,
errlevel=2,
stagestatus="驳回不支持SQL",
errormessage="TRUNCATE不支持View,File,URL,Buffer和Null表引擎!",
sql=statement,
)
else:
result = self.explain_check(
check_result, db_name, line, statement
)
else:
result = ReviewResult(
id=line,
errlevel=2,
stagestatus="表不存在",
errormessage=f"表 {table_name} 不存在!",
sql=statement,
)
# insert语句,explain无法正确判断,暂时只做表存在性检查与简单关键字匹配
elif re.match(r"^insert", statement, re.M | re.IGNORECASE):
if re.match(
r"^insert\s+into\s+([a-zA-Z_][0-9a-zA-Z_.]+)([\w\W]*?)(values|format|select)(\s+|\()",
statement,
re.M | re.IGNORECASE,
):
table_name = re.match(
r"^insert\s+into\s+([a-zA-Z_][0-9a-zA-Z_.]+)([\w\W]*?)(values|format|select)(\s+|\()",
statement,
re.M | re.IGNORECASE,
).group(1)
if "." not in table_name:
table_name = f"{db_name}.{table_name}"
table_exist = self.get_table_engine(table_name)["status"]
if table_exist == 1:
result = ReviewResult(
id=line,
errlevel=0,
stagestatus="Audit completed",
errormessage="None",
sql=statement,
affected_rows=0,
execute_time=0,
)
else:
result = ReviewResult(
id=line,
errlevel=2,
stagestatus="表不存在",
errormessage=f"表 {table_name} 不存在!",
sql=statement,
)
else:
result = ReviewResult(
id=line,
errlevel=2,
stagestatus="驳回不支持SQL",
errormessage="INSERT语法不正确!",
sql=statement,
)
# 其他语句使用explain ast简单检查
else:
result = self.explain_check(check_result, db_name, line, statement)
# 没有找出DDL语句的才继续执行此判断
if check_result.syntax_type == 2:
if get_syntax_type(statement, parser=False, db_type="mysql") == "DDL":
check_result.syntax_type = 1
check_result.rows += [result]
line += 1
# 统计警告和错误数量
for r in check_result.rows:
if r.errlevel == 1:
check_result.warning_count += 1
if r.errlevel == 2:
check_result.error_count += 1
return check_result
def execute_workflow(self, workflow):
"""执行上线单,返回Review set"""
sql = workflow.sqlworkflowcontent.sql_content
execute_result = ReviewSet(full_sql=sql)
sqls = sqlparse.format(sql, strip_comments=True)
sql_list = sqlparse.split(sqls)
line = 1
for statement in sql_list:
with FuncTimer() as t:
result = self.execute(
db_name=workflow.db_name, sql=statement, close_conn=True
)
if not result.error:
execute_result.rows.append(
ReviewResult(
id=line,
errlevel=0,
stagestatus="Execute Successfully",
errormessage="None",
sql=statement,
affected_rows=0,
execute_time=t.cost,
)
)
line += 1
else:
# 追加当前报错语句信息到执行结果中
execute_result.error = result.error
execute_result.rows.append(
ReviewResult(
id=line,
errlevel=2,
stagestatus="Execute Failed",
errormessage=f"异常信息:{result.error}",
sql=statement,
affected_rows=0,
execute_time=0,
)
)
line += 1
# 报错语句后面的语句标记为审核通过、未执行,追加到执行结果中
for statement in sql_list[line - 1 :]:
execute_result.rows.append(
ReviewResult(
id=line,
errlevel=0,
stagestatus="Audit completed",
errormessage=f"前序语句失败, 未执行",
sql=statement,
affected_rows=0,
execute_time=0,
)
)
line += 1
break
return execute_result
def execute(self, db_name=None, sql="", close_conn=True, parameters=None):
"""原生执行语句"""
result = ResultSet(full_sql=sql)
conn = self.get_connection(db_name=db_name)
try:
cursor = conn.cursor()
for statement in sqlparse.split(sql):
cursor.execute(statement, parameters)
cursor.close()
except Exception as e:
logger.warning(f"ClickHouse语句执行报错,语句:{sql},错误信息{e}")
result.error = str(e).split("Stack trace")[0]
if close_conn:
self.close()
return result
def close(self):
if self.conn:
self.conn.close()
self.conn = None