# -*- coding: UTF-8 -*-
import logging
import re
import traceback
import MySQLdb
import pymysql
import simplejson as json
from common.config import SysConfig
from sql.models import AliyunRdsConfig
from sql.utils.sql_utils import get_syntax_type
from . import EngineBase
from .models import ResultSet, ReviewSet, ReviewResult
# Module logger. "default" is presumably a logger configured in the project's
# logging settings — NOTE(review): verify against the settings module.
logger = logging.getLogger("default")
class GoInceptionEngine(EngineBase):
    """Engine that talks to a goInception server over the MySQL protocol.

    The target instance's connection info is passed to goInception inside a
    leading ``/*--user=...;--host=...;*/`` comment, and the SQL to review or
    execute is wrapped between ``inception_magic_start;`` and
    ``inception_magic_commit;``.
    """

    test_query = "INCEPTION GET VARIABLES"
    name = "GoInception"
    info = "GoInception engine"

    def get_connection(self, db_name=None):
        """Return the cached connection (``self.conn``), opening one if needed.

        If the engine has an ``instance`` attribute, connect to
        ``self.host``/``self.port`` using the instance charset (default
        ``utf8mb4``). Otherwise read the goInception host, port, user and
        password from ``SysConfig``. ``db_name`` is accepted for interface
        compatibility and is not used.
        """
        if self.conn:
            return self.conn
        if hasattr(self, "instance"):
            self.conn = MySQLdb.connect(
                host=self.host,
                port=self.port,
                charset=self.instance.charset or "utf8mb4",
                connect_timeout=10,
            )
            return self.conn
        archer_config = SysConfig()
        go_inception_host = archer_config.get("go_inception_host")
        go_inception_port = int(archer_config.get("go_inception_port", 4000))
        go_inception_user = archer_config.get("go_inception_user", "")
        go_inception_password = archer_config.get("go_inception_password", "")
        self.conn = MySQLdb.connect(
            host=go_inception_host,
            port=go_inception_port,
            user=go_inception_user,
            passwd=go_inception_password,
            charset="utf8mb4",
            connect_timeout=10,
        )
        return self.conn

    @staticmethod
    def get_backup_connection():
        """Open a new autocommit connection to the remote backup database.

        Host, port (default 3306), user and password come from the
        ``inception_remote_backup_*`` keys in ``SysConfig``. The caller owns
        the connection and must close it.
        """
        archer_config = SysConfig()
        backup_host = archer_config.get("inception_remote_backup_host")
        backup_port = int(archer_config.get("inception_remote_backup_port", 3306))
        backup_user = archer_config.get("inception_remote_backup_user")
        backup_password = archer_config.get("inception_remote_backup_password", "")
        return MySQLdb.connect(
            host=backup_host,
            port=backup_port,
            user=backup_user,
            passwd=backup_password,
            charset="utf8mb4",
            autocommit=True,
        )

    def escape_string(self, value: str) -> str:
        """Escape ``value`` for embedding inside a single-quoted SQL literal.

        Uses ``pymysql.converters.escape_string``, which exists in both
        PyMySQL 0.x and 1.x. The top-level ``pymysql.escape_string`` alias
        was removed in PyMySQL 1.0.
        """
        from pymysql.converters import escape_string

        return escape_string(value)

    def execute_check(self, instance=None, db_name=None, sql=""):
        """Ask goInception to review ``sql`` against ``instance`` (``--check=1``).

        Returns a ``ReviewSet`` with:
        - one ``ReviewResult`` per goInception result row;
        - warning and error counters;
        - ``syntax_type`` set to 1 (DDL) if any reviewed statement is DDL,
          otherwise 2 (DML).
        """
        # Resolve connection info, going through a tunnel if one is configured.
        host, port, user, password = self.remote_instance_conn(instance)
        check_result = ReviewSet(full_sql=sql)
        check_result.rows = []
        variables, set_session_sql = get_session_variables(instance)
        # Optional real_row_count review option from system config.
        real_row_count = SysConfig().get("real_row_count", False)
        real_row_count_option = "--real_row_count=true;" if real_row_count else ""
        inception_sql = f"""/*--user='{user}';--password='{password}';--host='{host}';--port={port};--check=1;{real_row_count_option}*/
inception_magic_start;
{set_session_sql}
use `{db_name}`;
{sql.rstrip(';')};
inception_magic_commit;"""
        inception_result = self.query(sql=inception_sql)
        # TODO: workflow type 0 = other, 1 = DDL, 2 = DML; only valid for MySQL,
        # needs adjustment. Starts as DML and is promoted to DDL below.
        check_result.syntax_type = 2
        for r in inception_result.rows:
            check_result.rows += [ReviewResult(inception_result=r)]
            # r[2] is the error level: 1 = warning, 2 = error.
            if r[2] == 1:
                check_result.warning_count += 1
            elif r[2] == 2:
                check_result.error_count += 1
            # Only keep classifying until the first DDL statement (r[5] = SQL).
            if check_result.syntax_type == 2:
                if get_syntax_type(r[5], parser=False, db_type="mysql") == "DDL":
                    check_result.syntax_type = 1
        check_result.column_list = inception_result.column_list
        check_result.checked = True
        check_result.error = inception_result.error
        check_result.warning = inception_result.warning
        return check_result

    def execute(self, workflow=None):
        """Execute a workflow's SQL through goInception (``--execute=1``).

        Backup is enabled when ``workflow.is_backup`` is true. If goInception
        itself fails (crash, or the connection drops mid-execution), a single
        "Execute failed" row is returned. Otherwise ``error`` is set from the
        first row that has errlevel 1 or 2 and whose stagestatus does not
        contain "Execute Successfully".
        """
        instance = workflow.instance
        # Resolve connection info, going through a tunnel if one is configured.
        host, port, user, password = self.remote_instance_conn(instance)
        execute_result = ReviewSet(full_sql=workflow.sqlworkflowcontent.sql_content)
        if workflow.is_backup:
            str_backup = "--backup=1"
        else:
            str_backup = "--backup=0"
        # Submit to goInception for execution.
        variables, set_session_sql = get_session_variables(instance)
        sql_execute = f"""/*--user='{user}';--password='{password}';--host='{host}';--port={port};--execute=1;--ignore-warnings=1;{str_backup};--sleep=200;--sleep_rows=100*/
inception_magic_start;
{set_session_sql}
use `{workflow.db_name}`;
{workflow.sqlworkflowcontent.sql_content.rstrip(';')};
inception_magic_commit;"""
        inception_result = self.query(sql=sql_execute)
        # Execution error: goInception crashed or the connection failed mid-run.
        if inception_result.error and not execute_result.rows:
            execute_result.error = inception_result.error
            execute_result.rows = [
                ReviewResult(
                    stage="Execute failed",
                    errlevel=2,
                    stagestatus="异常终止",
                    errormessage=f"goInception Error: {inception_result.error}",
                    sql=workflow.sqlworkflowcontent.sql_content,
                )
            ]
            return execute_result
        # Convert the goInception rows into ReviewResults.
        for r in inception_result.rows:
            execute_result.rows += [ReviewResult(inception_result=r)]
        # Any row with errlevel 1/2 whose status lacks "Execute Successfully"
        # marks the whole execution as failed.
        for r in execute_result.rows:
            if r.errlevel in (1, 2) and not re.search(
                r"Execute Successfully", r.stagestatus
            ):
                execute_result.error = "Line {0} has error/warning: {1}".format(
                    r.id, r.errormessage
                )
                break
        return execute_result

    def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs):
        """Run ``sql`` on the goInception connection and return a ``ResultSet``.

        At most ``limit_num`` rows are fetched when it is positive. Exceptions
        are logged and stored in ``ResultSet.error`` rather than raised. The
        connection is closed afterwards unless ``close_conn`` is false.
        """
        result_set = ResultSet(full_sql=sql)
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            effect_row = cursor.execute(sql)
            if int(limit_num) > 0:
                rows = cursor.fetchmany(size=int(limit_num))
            else:
                rows = cursor.fetchall()
            fields = cursor.description

            result_set.column_list = [i[0] for i in fields] if fields else []
            result_set.rows = rows
            result_set.affected_rows = effect_row
        except Exception as e:
            logger.warning(f"goInception语句执行报错,错误信息{traceback.format_exc()}")
            result_set.error = str(e)
        if close_conn:
            self.close()
        return result_set

    def query_print(self, instance, db_name=None, sql=""):
        """Return goInception's printed syntax tree for ``sql`` (``--enable-query-print``).

        Returns the second result row as a dict. Row 0 presumably corresponds
        to the ``use`` statement — NOTE(review): verify.

        :raises RuntimeError: if the goInception query fails or the print
            result carries an ``errmsg``.
        """
        # Resolve connection info, going through a tunnel if one is configured.
        host, port, user, password = self.remote_instance_conn(instance)
        sql = f"""/*--user='{user}';--password='{password}';--host='{host}';--port={port};--enable-query-print;*/
inception_magic_start;
use `{db_name}`;
{sql.rstrip(';')};
inception_magic_commit;"""
        query_result = self.query(db_name=db_name, sql=sql)
        # Fail loudly instead of hitting an IndexError on an empty result.
        if query_result.error:
            raise RuntimeError(f"Inception Error: {query_result.error}")
        print_info = query_result.to_dict()[1]
        if print_info.get("errmsg"):
            raise RuntimeError(print_info.get("errmsg"))
        return print_info

    def query_data_masking(self, instance, db_name=None, sql=""):
        """Have goInception print the syntax tree of ``sql`` to get its select list.

        Uses the ``--masking`` option; see
        https://github.com/hanchuanchuan/goInception/pull/355

        :return: the parsed ``query_tree`` JSON.
        :raises RuntimeError: on query errors, empty results, review-style
            results, or a non-zero errlevel / non-empty errmsg.
        """
        # Resolve connection info, going through a tunnel if one is configured.
        host, port, user, password = self.remote_instance_conn(instance)
        # Values are quoted like every other method here, so ';' in a
        # password does not break goInception's parameter parsing.
        sql = f"""/*--user='{user}';--password='{password}';--host='{host}';--port={port};--masking=1;*/
inception_magic_start;
use `{db_name}`;
{sql}
inception_magic_commit;"""
        query_result = self.query(db_name=db_name, sql=sql)
        # Raise explicitly on errors.
        if query_result.error:
            raise RuntimeError(f"Inception Error: {query_result.error}")
        if not query_result.rows:
            raise RuntimeError(f"Inception Error: 未获取到语法信息")
        # In some failure cases goInception returns review results instead of
        # a print result: https://github.com/hhyo/Archery/issues/1826
        print_info = query_result.to_dict()[0]
        if "error_level" in print_info:
            raise RuntimeError(f'Inception Error: {print_info.get("error_message")}')
        if print_info.get("errlevel") == 0 and print_info.get("errmsg") is None:
            return json.loads(print_info["query_tree"])
        else:
            raise RuntimeError(f'Inception Error: {print_info.get("errmsg")}')

    def get_rollback(self, workflow):
        """Fetch rollback statements for an executed workflow, newest first.

        Reads the workflow's stored execute result. Rows may be in the legacy
        list format or the newer dict format. For each backed-up statement,
        looks up its rollback SQL in the remote backup database.

        The backup connection is always closed, including when a lookup fails.

        :return: ``[[source_sql, rollback_sql], ...]`` in reverse execution order.
        :raises Exception: re-raised (after logging) if any lookup fails.
        """
        list_execute_result = json.loads(
            workflow.sqlworkflowcontent.execute_result or "[]"
        )
        # Show rollback statements in reverse execution order.
        list_execute_result.reverse()
        list_backup_sql = []
        conn = self.get_backup_connection()
        try:
            cur = conn.cursor()
            for row in list_execute_result:
                try:
                    # Legacy data: each row is a list ('[[]]' format).
                    if isinstance(row, list):
                        if row[8] == "None":
                            continue
                        backup_db_name = row[8]
                        sequence = row[7]
                        sql = row[5]
                    # Newer data: each row is a dict.
                    else:
                        if row.get("backup_dbname") in ("None", ""):
                            continue
                        backup_db_name = row.get("backup_dbname")
                        sequence = row.get("sequence")
                        sql = row.get("sql")
                    # Find which backup table holds this statement's rollback.
                    opid_time = sequence.replace("'", "")
                    sql_table = f"""select tablename
from {backup_db_name}.$_$Inception_backup_information$_$
where opid_time='{opid_time}';"""
                    cur.execute(sql_table)
                    list_tables = cur.fetchall()
                    if list_tables:
                        # Fetch the rollback statements from that table.
                        table_name = list_tables[0][0]
                        sql_back = f"""select rollback_statement
from {backup_db_name}.{table_name}
where opid_time='{opid_time}'"""
                        cur.execute(sql_back)
                        list_backup = cur.fetchall()
                        # One ['source SQL', 'rollback SQL'] pair per statement.
                        list_backup_sql.append(
                            [sql, "\n".join([back_info[0] for back_info in list_backup])]
                        )
                except Exception as e:
                    logger.error(f"获取回滚语句报错,异常信息{traceback.format_exc()}")
                    raise Exception(e) from e
        finally:
            # Always release the backup connection, even when a lookup fails.
            conn.close()
        return list_backup_sql

    def get_variables(self, variables=None):
        """Return goInception variables.

        Only ``variables[0]`` is used, as a LIKE pattern, when ``variables``
        is given; otherwise all variables are returned.
        """
        if variables:
            sql = f"inception get variables like '{variables[0]}';"
        else:
            sql = "inception get variables;"
        return self.query(sql=sql)

    def set_variable(self, variable_name, variable_value):
        """Set a goInception variable via ``inception set name=value``."""
        sql = f"""inception set {variable_name}={variable_value};"""
        return self.query(sql=sql)

    def osc_control(self, **kwargs):
        """Control an online schema change (get progress, kill, pause, resume, ...).

        ``command == "get"`` queries ``osc_percent``. Any other command is
        sent as ``inception <command> osc '<sqlsha1>'``.
        """
        sqlsha1 = self.escape_string(kwargs.get("sqlsha1", ""))
        command = self.escape_string(kwargs.get("command", ""))
        if command == "get":
            sql = f"inception get osc_percent '{sqlsha1}';"
        else:
            sql = f"inception {command} osc '{sqlsha1}';"
        return self.query(sql=sql)

    @staticmethod
    def get_table_ref(query_tree, db_name=None):
        """Extract Inception-style table references from a goInception syntax tree.

        Searches for the smallest ``TableRefs`` subtrees (there may be several)
        by repeatedly expanding the largest ``TableRefs``/``Left``/``Right``
        subtrees until none remain. ``Source`` nodes in what is left give the
        referenced tables.

        Per the original author, this approach was arrived at empirically by
        guessing goInception's syntax-tree layout. Searching for minimal
        subtrees directly might be more efficient, but this works.

        :param query_tree: syntax tree dict as printed by goInception.
        :param db_name: schema used when a reference has no explicit schema.
        :return: list of ``{"schema": ..., "name": ...}`` dicts.
        """
        __author__ = "xxlrr"
        table_ref = []
        find_queue = [query_tree]
        for tree in find_queue:
            tree = DictTree(tree)
            nodes = tree.find_max_tree("TableRefs", "Left", "Right")
            if nodes:
                # Each match is {key: subtree}. Queue the non-empty subtrees
                # (expected to be dicts) for another round of expansion.
                find_queue.extend([v for node in nodes for v in node.values() if v])
            else:
                snodes = tree.find_max_tree("Source")
                if snodes:
                    table_ref.extend(
                        [
                            {
                                "schema": snode["Source"].get("Schema", {}).get("O")
                                or db_name,
                                "name": snode["Source"].get("Name", {}).get("O", ""),
                            }
                            for snode in snodes
                        ]
                    )
                # NOTE: a Source node is expected whenever a TableRefs node
                # exists. A missing one is silently ignored, not raised.
        return table_ref

    def close(self):
        """Close and forget the cached goInception connection, if any."""
        if self.conn:
            self.conn.close()
            self.conn = None
class DictTree(dict):
    """A ``dict`` wrapper over a goInception syntax tree that supports subtree search."""

    def find_max_tree(self, *keys):
        """Breadth-first search for the outermost subtrees keyed by any of ``keys``.

        Walks nested dicts (and dicts found inside lists) level by level. When
        a key in ``keys`` is found, ``{key: value}`` is collected and that
        value is NOT descended into, so only the largest (outermost) matching
        subtrees are returned. Leaf values and non-dict list items are never
        searched.

        :param keys: node names to look for, e.g. ``"TableRefs"``, ``"Source"``.
        :return: list of single-entry dicts ``{key: subtree}`` in BFS order.
        """
        __author__ = "xxlrr"
        fit = []
        find_queue = [self]
        # Appending to the list while iterating it is intentional: the loop
        # keeps consuming newly queued nodes, giving a breadth-first traversal.
        for tree in find_queue:
            for k, v in tree.items():
                if k in keys:
                    fit.append({k: v})
                elif isinstance(v, dict):
                    find_queue.append(v)
                elif isinstance(v, list):
                    find_queue.extend([n for n in v if isinstance(n, dict)])
        return fit
def get_session_variables(instance):
    """Build goInception session variables for the target ``instance``.

    Lets review and execution settings be tuned per target instance, e.g. to
    customise review rules per business group. Currently, when an enabled
    ``AliyunRdsConfig`` exists for the instance, the ``ghost_*`` options are
    switched on (presumably gh-ost online-DDL settings — verify against the
    goInception docs).

    :return: ``(variables, set_session_sql)`` — the variable dict and the
        matching ``inception set session ...;`` statements, one per line.
    """
    variables = {}
    uses_aliyun_rds = AliyunRdsConfig.objects.filter(
        instance=instance, is_enable=True
    ).exists()
    if uses_aliyun_rds:
        variables["ghost_aliyun_rds"] = "on"
        variables["ghost_allow_on_master"] = "true"
        variables["ghost_assume_rbr"] = "true"
    # Render each variable as its own SET SESSION statement.
    set_session_sql = "".join(
        f"inception set session {name} = '{value}';\n"
        for name, value in variables.items()
    )
    return variables, set_session_sql