# -*- coding: UTF-8 -*-
import re
import time
import pymongo
import logging
import traceback
import subprocess
import simplejson as json
import datetime
import tempfile
from bson.son import SON
from bson import json_util
from pymongo.errors import OperationFailure
from dateutil.parser import parse
from bson.objectid import ObjectId
from bson.int64 import Int64
from sql.utils.data_masking import data_masking
from . import EngineBase
from .models import ResultSet, ReviewSet, ReviewResult
from common.config import SysConfig
logger = logging.getLogger("default")
# mongo客户端安装在本机的位置
mongo = "mongo"
# 自定义异常
class mongo_error(Exception):
def __init__(self, error_info):
super().__init__(self)
self.error_info = error_info
def __str__(self):
return self.error_info
class JsonDecoder:
"""处理传入mongodb语句中的条件,并转换成pymongo可识别的字典格式"""
def __init__(self):
pass
def __json_object(self, tokener):
# obj = collections.OrderedDict()
obj = {}
if tokener.cur_token() != "{":
raise Exception('Json must start with "{"')
while True:
tokener.next()
tk_temp = tokener.cur_token()
if tk_temp == "}":
return {}
# 限制key的格式
if not isinstance(
tk_temp, str
): # or (not tk_temp.isidentifier() and not tk_temp.startswith("$"))
raise Exception("invalid key %s" % tk_temp)
key = tk_temp.strip()
tokener.next()
if tokener.cur_token() != ":":
raise Exception('expect ":" after "%s"' % key)
tokener.next()
val = tokener.cur_token()
if val == "[":
val = self.__json_array(tokener)
elif val == "{":
val = self.__json_object(tokener)
obj[key] = val
tokener.next()
tk_split = tokener.cur_token()
if tk_split == ",":
continue
elif tk_split == "}":
break
else:
if tk_split is None:
raise Exception('missing "}" at at the end of object')
raise Exception('unexpected token "%s" at key "%s"' % (tk_split, key))
return obj
def __json_array(self, tokener):
if tokener.cur_token() != "[":
raise Exception('Json array must start with "["')
arr = []
while True:
tokener.next()
tk_temp = tokener.cur_token()
if tk_temp == "]":
return []
if tk_temp == "{":
val = self.__json_object(tokener)
elif tk_temp == "[":
val = self.__json_array(tokener)
elif tk_temp in (",", ":", "}"):
raise Exception('unexpected token "%s"' % tk_temp)
else:
val = tk_temp
arr.append(val)
tokener.next()
tk_end = tokener.cur_token()
if tk_end == ",":
continue
if tk_end == "]":
break
else:
if tk_end is None:
raise Exception('missing "]" at the end of array')
return arr
def decode(self, json_str):
tokener = JsonDecoder.__Tokener(json_str)
if not tokener.next():
return None
first_token = tokener.cur_token()
if first_token == "{":
decode_val = self.__json_object(tokener)
elif first_token == "[":
decode_val = self.__json_array(tokener)
else:
raise Exception('Json must start with "{"')
if tokener.next():
raise Exception('unexpected token "%s"' % tokener.cur_token())
return decode_val
class __Tokener: # Tokener 作为一个内部类
def __init__(self, json_str):
self.__str = json_str
self.__i = 0
self.__cur_token = None
def __cur_char(self):
if self.__i < len(self.__str):
return self.__str[self.__i]
return ""
def __previous_char(self):
if self.__i < len(self.__str):
return self.__str[self.__i - 1]
def __remain_str(self):
if self.__i < len(self.__str):
return self.__str[self.__i :]
def __move_i(self, step=1):
if self.__i < len(self.__str):
self.__i += step
def __next_string(self):
"""当出现了"和'后就进入这个方法解析,直到出现与之对应的结束字符"""
outstr = ""
trans_flag = False
start_ch = ""
self.__move_i()
while self.__cur_char() != "":
ch = self.__cur_char()
if start_ch == "":
start_ch = self.__previous_char()
if ch == '\\"': # 判断是否是转义
trans_flag = True
else:
if not trans_flag:
if (ch == '"' and start_ch == '"') or (
ch == "'" and start_ch == "'"
):
break
else:
trans_flag = False
outstr += ch
self.__move_i()
return outstr
def __next_number(self):
expr = ""
while self.__cur_char().isdigit() or self.__cur_char() in (".", "+", "-"):
expr += self.__cur_char()
self.__move_i()
self.__move_i(-1)
if "." in expr:
return float(expr)
else:
return int(expr)
def __next_const(self):
"""处理没有被''和""包含的字符,如true和ObjectId"""
outstr = ""
data_type = ""
while self.__cur_char().isalpha() or self.__cur_char() in ("$", "_", " "):
outstr += self.__cur_char()
self.__move_i()
if outstr.replace(" ", "") in (
"ObjectId",
"newDate",
"ISODate",
"newISODate",
"NumberLong",
): # ======类似的类型比较多还需单独处理,如int()等
data_type = outstr
for c in self.__remain_str():
outstr += c
self.__move_i()
if c == ")":
break
self.__move_i(-1)
if outstr in ("true", "false", "null"):
return {"true": True, "false": False, "null": None}[outstr]
elif data_type == "ObjectId":
ojStr = re.findall(r"ObjectId\(.*?\)", outstr) # 单独处理ObjectId
if len(ojStr) > 0:
# return eval(ojStr[0])
id_str = re.findall(r"\(.*?\)", ojStr[0])
oid = id_str[0].replace(" ", "")[2:-2]
return ObjectId(oid)
elif data_type.replace(" ", "") in (
"newDate",
"ISODate",
"newISODate",
): # 处理时间格式
tmp_type = "%s()" % data_type
if outstr.replace(" ", "") == tmp_type.replace(" ", ""):
return datetime.datetime.now() + datetime.timedelta(
hours=-8
) # mongodb默认时区为utc
date_regex = re.compile(r'%s\("(.*)"\)' % data_type, re.IGNORECASE)
date_content = date_regex.findall(outstr)
if len(date_content) > 0:
return parse(date_content[0], yearfirst=True)
elif data_type.replace(" ", "") in ("NumberLong",):
nuStr = re.findall(r"NumberLong\(.*?\)", outstr) # 单独处理NumberLong
if len(nuStr) > 0:
id_str = re.findall(r"\(.*?\)", nuStr[0])
nlong = id_str[0].replace(" ", "")[2:-2]
return Int64(nlong)
elif outstr:
return outstr
raise Exception('Invalid symbol "%s"' % outstr)
def next(self):
is_white_space = lambda a_char: a_char in (
"\x20",
"\n",
"\r",
"\t",
) # 定义一个匿名函数
while is_white_space(self.__cur_char()):
self.__move_i()
ch = self.__cur_char()
if ch == "":
cur_token = None
elif ch in ("{", "}", "[", "]", ",", ":"):
cur_token = ch
elif ch in ('"', "'"): # 当字符为" '
cur_token = self.__next_string()
elif ch.isalpha() or ch in ("$", "_"): # 字符串是否只由字母和"$","_"组成
cur_token = self.__next_const()
elif ch.isdigit() or ch in (".", "-", "+"): # 检测字符串是否只由数字组成
cur_token = self.__next_number()
else:
raise Exception('Invalid symbol "%s"' % ch)
self.__move_i()
self.__cur_token = cur_token
return cur_token is not None
def cur_token(self):
return self.__cur_token
class MongoEngine(EngineBase):
error = None
warning = None
methodStr = None
def test_connection(self):
return self.get_all_databases()
def exec_cmd(self, sql, db_name=None, slave_ok=""):
"""审核时执行的语句"""
if self.port and self.host:
msg = ""
auth_db = self.instance.db_name or "admin"
sql_len = len(sql)
is_load = False # 默认不使用load方法执行mongodb sql语句
try:
if not sql.startswith("var host=") and sql_len > 4000:
# 在master节点执行的情况,如果sql长度大于4000,就采取load js的方法
# 因为用mongo load方法执行js脚本,所以需要重新改写一下sql,以便回显js执行结果
sql = "var result = " + sql + "\nprintjson(result);"
# 因为要知道具体的临时文件位置,所以用了NamedTemporaryFile模块
fp = tempfile.NamedTemporaryFile(
suffix=".js", prefix="mongo_", dir="/tmp/", delete=True
)
fp.write(sql.encode("utf-8"))
fp.seek(0) # 把文件指针指向开始,这样写的sql内容才能落到磁盘文件上
cmd = self._build_cmd(
db_name, auth_db, slave_ok, fp.name, is_load=True
)
is_load = True # 标记使用了load方法,用来在finally里面判断是否需要强制删除临时文件
elif (
not sql.startswith("var host=") and sql_len < 4000
): # 在master节点执行的情况, 如果sql长度小于4000,就直接用mongo shell执行,减少磁盘交换,节省性能
cmd = self._build_cmd(db_name, auth_db, slave_ok, sql=sql)
else:
cmd = self._build_cmd(
db_name, auth_db, sql=sql, slave_ok="rs.slaveOk();"
)
p = subprocess.Popen(
cmd,
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
universal_newlines=True,
)
re_msg = []
for line in iter(p.stdout.read, ""):
re_msg.append(line)
# 因为返回的line中也有可能带有换行符,因此需要先全部转换成字符串
__msg = "\n".join(re_msg)
_re_msg = []
for _line in __msg.split("\n"):
if not _re_msg and re.match("WARNING.*", _line):
# 第一行可能是WARNING语句,因此跳过
continue
_re_msg.append(_line)
msg = "\n".join(_re_msg)
msg = msg.replace("true\n", "")
except Exception as e:
logger.warning(
f"mongo语句执行报错,语句:{sql},{e}错误信息{traceback.format_exc()}"
)
finally:
if is_load:
fp.close()
return msg
# 用来进行判断是否有用户名与密码以及是否需要临时文件的情况,进而返回要执行的mongo命令
def _build_cmd(
self, db_name, auth_db, slave_ok="", tempfile_=None, sql=None, is_load=False
):
# 提取公共参数
common_params = {
"mongo": "mongo",
"host": self.host,
"port": self.port,
"db_name": db_name,
"auth_db": auth_db,
"slave_ok": slave_ok,
}
if is_load:
cmd_template = (
"{mongo} --quiet {auth_options} {host}:{port}/{auth_db} <<\\EOF\n"
"db=db.getSiblingDB('{db_name}');{slave_ok}load('{tempfile_}')\nEOF"
)
# 长度超限使用loadjs的方式运行,使用临时文件
common_params["tempfile_"] = tempfile_
else:
cmd_template = (
"{mongo} --quiet {auth_options} {host}:{port}/{auth_db} <<\\EOF\n"
"db=db.getSiblingDB('{db_name}');{slave_ok}{sql}\nEOF"
)
# 长度不超限直接mongo shell,无需临时文件
common_params["sql"] = sql
# 如果有账号密码,则添加选项
if self.user and self.password:
common_params["auth_options"] = "-u {uname} -p '{password}'".format(
uname=self.user, password=self.password
)
else:
common_params["auth_options"] = ""
return cmd_template.format(**common_params)
def get_master(self):
"""获得主节点的port和host"""
sql = "rs.isMaster().primary"
master = self.exec_cmd(sql)
if master != "undefined":
sp_host = master.replace('"', "").split(":")
self.host = sp_host[0]
self.port = int(sp_host[1])
# return master
def get_slave(self):
"""获得从节点的port和host"""
sql = """var host=""; rs.status().members.forEach(function(item) {i=1; if (item.stateStr =="SECONDARY") \
{host=item.name } }); print(host);"""
slave_msg = self.exec_cmd(sql, db_name=self.db_name)
# 如果是阿里云的云mongodb,会获取不到备节点真实的ip和端口,那就干脆不获取,直接用主节点来执行sql
# 如果是自建mongodb,获取到备节点的ip是192.168.1.33:27019这样的值;但如果是阿里云mongodb,获取到的备节点ip是SECONDARY、hiddenNode这样的值
# 所以,为了使代码更加通用,通过有无冒号来判断自建Mongod还是阿里云mongdb;没有冒号就判定为阿里云mongodb,直接返回false;
if ":" not in slave_msg:
return False
if slave_msg.lower().find("undefined") < 0:
sp_host = slave_msg.replace('"', "").split(":")
self.host = sp_host[0]
self.port = int(sp_host[1])
return True
else:
return False
def get_table_conut(self, table_name, db_name):
try:
count_sql = f"db.{table_name}.count()"
status = self.get_slave() # 查询总数据要求在slave节点执行
if self.host and self.port and status:
count = int(self.exec_cmd(count_sql, db_name, slave_ok="rs.slaveOk();"))
else:
count = int(self.exec_cmd(count_sql, db_name))
return count
except Exception as e:
logger.debug("get_table_conut:" + str(e))
return 0
def execute_workflow(self, workflow):
"""执行上线单,返回Review set"""
return self.execute(
db_name=workflow.db_name, sql=workflow.sqlworkflowcontent.sql_content
)
def execute(self, db_name=None, sql=""):
"""mongo命令执行语句"""
self.get_master()
execute_result = ReviewSet(full_sql=sql)
sql = sql.strip()
# 以;切分语句,逐句执行
sp_sql = sql.split(";")
line = 0
for exec_sql in sp_sql:
if not exec_sql == "":
exec_sql = exec_sql.strip()
try:
# DeprecationWarning: time.clock has been deprecated in Python 3.3 and will be removed from Python 3.8: use time.perf_counter or time.process_time instead
start = time.perf_counter()
r = self.exec_cmd(exec_sql, db_name)
end = time.perf_counter()
line += 1
logger.debug("执行结果:" + r)
# 如果执行中有错误
rz = r.replace(" ", "").replace('"', "")
tr = 1
if (
r.lower().find("syntaxerror") >= 0
or rz.find("ok:0") >= 0
or rz.find("error:invalid") >= 0
or rz.find("ReferenceError") >= 0
or rz.find("getErrorWithCode") >= 0
or rz.find("failedtoconnect") >= 0
or rz.find("Error:") >= 0
):
tr = 0
if (rz.find("errmsg") >= 0 or tr == 0) and (
r.lower().find("already exist") < 0
):
execute_result.error = r
result = ReviewResult(
id=line,
stage="Execute failed",
errlevel=2,
stagestatus="异常终止",
errormessage=f"mongo语句执行报错: {r}",
sql=exec_sql,
)
else:
try:
r = json.loads(r)
except Exception as e:
logger.info(str(e))
finally:
methodStr = exec_sql.split(").")[-1].split("(")[0].strip()
if "." in methodStr:
methodStr = methodStr.split(".")[-1]
if methodStr == "insert":
m = re.search(r'"nInserted"\s*:\s*(\d+)', r)
actual_affected_rows = int(m.group(1))
elif methodStr in ("insertOne", "insertMany"):
if isinstance(r, dict):
# mongosh / driver JSON formats
if "nInserted" in r: # BulkWriteResult style
actual_affected_rows = r["nInserted"]
elif (
"insertedIds" in r
): # CLI acknowledged + insertedIds
actual_affected_rows = len(r["insertedIds"])
elif "insertedId" in r: # insertOne single id
actual_affected_rows = 1
else:
actual_affected_rows = 0
elif isinstance(r, str):
# mongo 4.x CLI string outputs
m = re.search(r'"nInserted"\s*:\s*(\d+)', r)
actual_affected_rows = (
int(m.group(1)) if m else r.count("ObjectId")
)
actual_affected_rows = r.count("ObjectId")
else:
actual_affected_rows = 0
elif methodStr == "update":
m = re.search(
r'(?:"modifiedCount"|"nModified")\s*:\s*(\d+)',
r,
)
actual_affected_rows = int(m.group(1))
elif methodStr in ("updateOne", "updateMany"):
if isinstance(r, dict):
actual_affected_rows = r.get(
"modifiedCount", r.get("nModified", 0)
)
elif isinstance(r, str):
m = re.search(
r'(?:"modifiedCount"|"nModified")\s*:\s*(\d+)',
r,
)
actual_affected_rows = int(m.group(1)) if m else 0
else:
actual_affected_rows = 0
elif methodStr in ("deleteOne", "deleteMany"):
actual_affected_rows = r.get("deletedCount", 0)
elif methodStr == "remove":
actual_affected_rows = r.get("nRemoved", 0)
else:
actual_affected_rows = 0
# 把结果转换为ReviewSet
result = ReviewResult(
id=line,
errlevel=0,
stagestatus="执行结束",
errormessage=str(r),
execute_time=round(end - start, 6),
affected_rows=actual_affected_rows,
sql=exec_sql,
)
execute_result.rows += [result]
except Exception as e:
logger.warning(
f"mongo语句执行报错,语句:{exec_sql},错误信息{traceback.format_exc()}"
)
execute_result.error = str(e)
# result_set.column_list = [i[0] for i in fields] if fields else []
return execute_result
def execute_check(self, db_name=None, sql=""):
"""上线单执行前的检查, 返回Review set"""
line = 1
count = 0
check_result = ReviewSet(full_sql=sql)
# 获取real_row_count参数选项
real_row_count = SysConfig().get("real_row_count", False)
sql = sql.strip()
# sql 检查过滤注释语句
sql = re.sub(r"^\s*//.*$", "", sql, flags=re.MULTILINE)
if sql.find(";") < 0:
raise Exception("提交的语句请以分号结尾")
# 以;切分语句,逐句执行
sp_sql = sql.split(";")
# 执行语句
for check_sql in sp_sql:
alert = "" # 警告信息
check_sql = check_sql.strip()
if not check_sql == "" and check_sql != "\n":
# check_sql = f'''{check_sql}'''
# check_sql = check_sql.replace('\n', '') #处理成一行
# 支持的命令列表
supportMethodList = [
"explain",
"bulkWrite",
"convertToCapped",
"createIndex",
"createIndexes",
"deleteOne",
"deleteMany",
"drop",
"dropIndex",
"dropIndexes",
"ensureIndex",
"insert",
"insertOne",
"insertMany",
"remove",
"replaceOne",
"renameCollection",
"update",
"updateOne",
"updateMany",
"createCollection",
"renameCollection",
]
# 需要有表存在为前提的操作
is_exist_premise_method = [
"convertToCapped",
"deleteOne",
"deleteMany",
"drop",
"dropIndex",
"dropIndexes",
"remove",
"replaceOne",
"renameCollection",
"update",
"updateOne",
"updateMany",
"renameCollection",
]
pattern = re.compile(
r"""^db\.createCollection\(([\s\S]*)\)$|^db\.([\w\.-]+)\.(?:[A-Za-z]+)(?:\([\s\S]*\)$)|^db\.getCollection\((?:\s*)(?:'|")([\w\.-]+)('|")(\s*)\)\.([A-Za-z]+)(\([\s\S]*\)$)"""
)
m = pattern.match(check_sql)
if (
m is not None
and (re.search(re.compile(r"}(?:\s*){"), check_sql) is None)
and check_sql.count("{") == check_sql.count("}")
and check_sql.count("(") == check_sql.count(")")
):
sql_str = m.group()
table_name = (
m.group(1) or m.group(2) or m.group(3)
).strip() # 通过正则的组拿到表名
table_name = table_name.replace('"', "").replace("'", "")
table_names = self.get_all_tables(db_name).rows
is_in = table_name in table_names # 检查表是否存在
if not is_in:
alert = f"\n提示:{table_name}文档不存在!"
if sql_str:
count = 0
if (
sql_str.find("createCollection") > 0
): # 如果是db.createCollection()
methodStr = "createCollection"
alert = ""
if is_in:
check_result.error = "文档已经存在"
result = ReviewResult(
id=line,
errlevel=2,
stagestatus="文档已经存在",
errormessage="文档已经存在!",
affected_rows=count,
sql=check_sql,
)
check_result.rows += [result]
continue
else:
methodStr = sql_str.split(").")[-1].split("(")[0].strip()
if "." in methodStr:
methodStr = methodStr.split(".")[-1]
if methodStr in is_exist_premise_method and not is_in:
check_result.error = "文档不存在"
result = ReviewResult(
id=line,
errlevel=2,
stagestatus="文档不存在",
errormessage=f"文档不存在,不能进行{methodStr}操作!",
sql=check_sql,
)
check_result.rows += [result]
continue
if methodStr in supportMethodList: # 检查方法是否支持
if (
methodStr == "createIndex"
or methodStr == "createIndexes"
or methodStr == "ensureIndex"
): # 判断是否创建索引,如果大于500万,提醒不能在高峰期创建
p_back = re.compile(
r"""(['"])(?:(?!\1)background)\1(?:\s*):(?:\s*)true|background\s*:\s*true|(['"])(?:(?!\1)background)\1(?:\s*):(?:\s*)(['"])(?:(?!\2)true)\2""",
re.M,
)
m_back = re.search(p_back, check_sql)
if m_back is None:
count = 5555555
check_result.warning = "创建索引请加background:true"
check_result.warning_count += 1
result = ReviewResult(
id=line,
errlevel=2,
stagestatus="后台创建索引",
errormessage="创建索引没有加 background:true"
+ alert,
sql=check_sql,
)
elif not is_in:
count = 0
else:
count = self.get_table_conut(
table_name, db_name
) # 获得表的总条数
if count >= 5000000:
check_result.warning = (
alert
+ "大于500万条,请在业务低谷期创建索引"
)
check_result.warning_count += 1
result = ReviewResult(
id=line,
errlevel=1,
stagestatus="大表创建索引",
errormessage="大于500万条,请在业务低谷期创建索引!",
affected_rows=count,
sql=check_sql,
)
if count < 5000000:
# 检测通过
affected_all_row_method = [
"drop",
"dropIndex",
"dropIndexes",
"createIndex",
"createIndexes",
"ensureIndex",
]
if methodStr not in affected_all_row_method:
count = 0
else:
count = self.get_table_conut(
table_name, db_name
) # 获得表的总条数
result = ReviewResult(
id=line,
errlevel=0,
stagestatus="Audit completed",
errormessage="检测通过",
affected_rows=count,
sql=check_sql,
execute_time=0,
)
if real_row_count:
if methodStr == "insertOne":
count = 1
elif methodStr in ("insert", "insertMany"):
insert_str = re.search(
rf"{methodStr}\((.*)\)", sql_str, re.S
).group(1)
first_char = insert_str.replace(" ", "").replace(
"\n", ""
)[0]
if first_char == "{":
count = 1
elif first_char == "[":
insert_values = re.search(
r"\[(.*?)\]", insert_str, re.S
).group(0)
de = JsonDecoder()
insert_values = de.decode(insert_values)
count = len(insert_values)
else:
count = 0
elif methodStr in (
"update",
"updateOne",
"updateMany",
"deleteOne",
"deleteMany",
"remove",
):
if sql_str.find("find(") > 0:
count_sql = sql_str.replace(methodStr, "count")
else:
count_sql = (
sql_str.replace(methodStr, "find")
+ ".count()"
)
query_dict = self.parse_query_sentence(count_sql)
count_sql = f"""db.getCollection("{query_dict["collection"]}").find({query_dict["condition"]}).count()"""
query_result = self.query(db_name, count_sql)
count = json.loads(query_result.rows[0][0]).get(
"count", 0
)
if (
methodStr == "update"
and "multi:true"
not in sql_str.replace(" ", "")
.replace('"', "")
.replace("'", "")
.replace("\n", "")
) or methodStr in ("deleteOne", "updateOne"):
count = 1 if count > 0 else 0
if methodStr in (
"insertOne",
"insert",
"insertMany",
"update",
"updateOne",
"updateMany",
"deleteOne",
"deleteMany",
"remove",
):
result = ReviewResult(
id=line,
errlevel=0,
stagestatus="Audit completed",
errormessage="检测通过",
affected_rows=count,
sql=check_sql,
execute_time=0,
)
else:
result = ReviewResult(
id=line,
errlevel=2,
stagestatus="驳回不支持语句",
errormessage="仅支持DML和DDL语句,如需查询请使用数据库查询功能!",
sql=check_sql,
)
else:
check_result.error = "语法错误"
result = ReviewResult(
id=line,
errlevel=2,
stagestatus="语法错误",
errormessage="请检查语句的正确性或(){} },{是否正确匹配!",
sql=check_sql,
)
check_result.rows += [result]
line += 1
count = 0
check_result.column_list = ["Result"] # 审核结果的列名
check_result.checked = True
check_result.warning = self.warning
# 统计警告和错误数量
for r in check_result.rows:
if r.errlevel == 1:
check_result.warning_count += 1
if r.errlevel == 2:
check_result.error_count += 1
return check_result
def get_connection(self, db_name=None):
self.db_name = db_name or self.instance.db_name or "admin"
auth_db = self.instance.db_name or "admin"
options = {
"host": self.host,
"port": self.port,
"username": self.user,
"password": self.password,
"authSource": auth_db,
"connect": True,
"connectTimeoutMS": 10000,
}
# only set TLS options while the instance enabled the TLS, to avoid
# tlsInsecure option being set but the instance is not enabled the TLS
# which would cause pymongo.ConfigurationError
if self.instance.is_ssl:
options["tls"] = True
options["tlsInsecure"] = not self.instance.verify_ssl
if self.user and self.password:
self.conn = pymongo.MongoClient(**options)
else:
self.conn = pymongo.MongoClient(**options)
return self.conn
def close(self):
if self.conn:
self.conn.close()
self.conn = None
name = "Mongo"
info = "Mongo engine"
def get_roles(self):
sql_get_roles = "db.system.roles.find({},{_id:1})"
result_set = self.query("admin", sql_get_roles)
rows = ["read", "readWrite", "userAdminAnyDatabase"]
for row in result_set.rows:
rows.append(row[1])
result_set.rows = rows
return result_set
def get_all_databases(self):
result = ResultSet()
conn = self.get_connection()
try:
db_list = conn.list_database_names()
except OperationFailure:
db_list = [self.db_name]
result.rows = db_list
return result
def get_all_tables(self, db_name, **kwargs):
result = ResultSet()
conn = self.get_connection()
db = conn[db_name]
result.rows = db.list_collection_names()
return result
def get_all_columns_by_tb(self, db_name, tb_name, **kwargs):
"""获取所有字段, 返回一个ResultSet"""
# https://github.com/getredash/redash/blob/master/redash/query_runner/mongodb.py
result = ResultSet()
db = self.get_connection()[db_name]
collection_name = tb_name
documents_sample = []
if "viewOn" in db[collection_name].options():
for d in db[collection_name].find().limit(2):
documents_sample.append(d)
else:
for d in db[collection_name].find().sort([("_id", 1)]).limit(1):
documents_sample.append(d)
for d in db[collection_name].find().sort([("_id", -1)]).limit(1):
documents_sample.append(d)
columns = []
# _merge_property_names
for document in documents_sample:
for prop in document:
if prop not in columns:
columns.append(prop)
result.column_list = ["COLUMN_NAME"]
result.rows = columns
return result
def describe_table(self, db_name, tb_name, **kwargs):
"""return ResultSet 类似查询"""
result = self.get_all_columns_by_tb(db_name=db_name, tb_name=tb_name)
result.rows = [
[
[r],
]
for r in result.rows
]
return result
@staticmethod
def dispose_str(parse_sql, start_flag, index):
"""解析处理字符串"""
stop_flag = ""
while index < len(parse_sql):
if parse_sql[index] == stop_flag and parse_sql[index - 1] != "\\":
return index
index += 1
stop_flag = start_flag
raise Exception("near column %s,' or \" has no close" % index)
def dispose_pair(self, parse_sql, index, begin, end):
"""解析处理需要配对的字符{}[]() 检索一个左括号计数器加1,右括号计数器减1"""
start_pos = -1
stop_pos = 0
count = 0
while index < len(parse_sql):
char = parse_sql[index]
if char == begin:
count += 1
if start_pos == -1:
start_pos = index
if char == end:
count -= 1
if count == 0:
stop_pos = index + 1
break
if char in ("'", '"'): # 避免字符串中带括号的情况,如{key:"{dd"}
index = self.dispose_str(parse_sql, char, index)
index += 1
if count > 0:
raise Exception(
"near column %s, The symbol %s has no closed" % (index, begin)
)
re_char = parse_sql[start_pos:stop_pos] # 截取
return index, re_char
def parse_query_sentence(self, parse_sql):
"""解析mongodb的查询语句,返回一个字典"""
index = 0
query_dict = {}
# 开始解析查询语句
while index < len(parse_sql):
char = parse_sql[index]
if char == "(":
# 获得语句中的方法名
head_sql = parse_sql[:index]
method = parse_sql[:index].split(".")[-1].strip()
index, re_char = self.dispose_pair(parse_sql, index, "(", ")")
re_char = re_char.lstrip("(").rstrip(")")
# 获得表名
if method and "collection" not in query_dict:
collection = head_sql.replace("." + method, "").replace("db.", "")
query_dict["collection"] = collection
# 分割查询条件和投影(返回字段)
if method == "find":
p_index, condition = self.dispose_pair(re_char, 0, "{", "}")
query_dict["condition"] = condition
query_dict["method"] = method
# 获取查询返回字段
projection = re_char[p_index:].strip()[2:]
if projection:
query_dict["projection"] = projection
# 聚合查询
elif method == "aggregate":
pipeline = []
agg_index = 0
while agg_index < len(re_char):
p_index, condition = self.dispose_pair(
re_char, agg_index, "{", "}"
)
agg_index = p_index + 1
if condition:
de = JsonDecoder()
step = de.decode(condition)
if "$sort" in step:
sort_list = []
for name, direction in step["$sort"].items():
sort_list.append((name, direction))
step["$sort"] = SON(sort_list)
pipeline.append(step)
query_dict["condition"] = pipeline
query_dict["method"] = method
elif method.lower() == "getcollection": # 获得表名
collection = re_char.strip().replace("'", "").replace('"', "")
query_dict["collection"] = collection
elif method.lower() == "getindexes":
query_dict["method"] = "index_information"
else:
query_dict[method] = re_char
index += 1
logger.debug(query_dict)
if query_dict:
return query_dict
def filter_sql(self, sql="", limit_num=0):
"""给查询语句改写语句, 返回修改后的语句"""
sql = sql.split(";")[0].strip()
# 执行计划
if sql.startswith("explain"):
sql = sql.replace("explain", "") + ".explain()"
return sql.strip()
def query_check(self, db_name=None, sql=""):
"""提交查询前的检查"""
sql = sql.strip()
sql = re.sub(r"^\s*//.*$", "", sql, flags=re.MULTILINE)
if sql.startswith("explain"):
sql = sql[7:] + ".explain()"
sql = re.sub("[;\s]*.explain\(\)$", ".explain()", sql).strip()
result = {"msg": "", "bad_query": False, "filtered_sql": sql, "has_star": False}
pattern = re.compile(
r"""^db\.(\w+\.?)+(?:\([\s\S]*\)(\s*;*)$)|^db\.getCollection\((?:\s*)(?:'|")(\w+\.?)+('|")(\s*)\)\.([A-Za-z]+)(\([\s\S]*\)(\s*;*)$)"""
)
m = pattern.match(sql)
if m is not None:
logger.debug(sql)
query_dict = self.parse_query_sentence(sql)
if "method" not in query_dict:
result["msg"] += "错误:对不起,只支持查询相关方法"
result["bad_query"] = True
return result
collection_name = query_dict["collection"]
collection_names = self.get_all_tables(db_name).rows
is_in = collection_name in collection_names # 检查表是否存在
if not is_in:
result["msg"] += f"\n错误: {collection_name} 文档不存在!"
result["bad_query"] = True
return result
else:
result["msg"] += "请检查语句的正确性! 请使用原生查询语句"
result["bad_query"] = True
return result
def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs):
"""执行查询"""
result_set = ResultSet(full_sql=sql)
find_cmd = ""
# 提取命令中()中的内容
query_dict = self.parse_query_sentence(sql)
# 创建一个解析对象
de = JsonDecoder()
collection_name = query_dict["collection"]
if "method" in query_dict and query_dict["method"]:
method = query_dict["method"]
find_cmd = "collection." + method
if method == "index_information":
find_cmd += "()"
if "condition" in query_dict:
if method == "aggregate":
condition = query_dict["condition"]
# 给aggregate查询加limit行数限制,防止返回结果过多导致archery挂掉
condition.append({"$limit": limit_num})
if method == "find":
condition = de.decode(query_dict["condition"])
find_cmd += "(condition)"
if "projection" in query_dict and query_dict["projection"]:
projection = de.decode(query_dict["projection"])
find_cmd = find_cmd[:-1] + ",projection)"
if "sort" in query_dict and query_dict["sort"]:
sorting = []
for k, v in de.decode(query_dict["sort"]).items():
sorting.append((k, v))
find_cmd += ".sort(sorting)"
if (
method == "find"
and "limit" not in query_dict
and "explain" not in query_dict
):
find_cmd += ".limit(limit_num)"
if "limit" in query_dict and query_dict["limit"]:
query_limit = int(query_dict["limit"])
limit = min(limit_num, query_limit) if query_limit else limit_num
find_cmd += f".limit({limit})"
if "skip" in query_dict and query_dict["skip"]:
query_skip = int(query_dict["skip"])
find_cmd += f".skip({query_skip})"
if "count" in query_dict:
if condition:
find_cmd = "collection.count_documents(condition)"
else:
find_cmd = "collection.count_documents({})"
if "explain" in query_dict:
find_cmd += ".explain()"
try:
conn = self.get_connection()
db = conn[db_name]
collection = db[collection_name]
# 执行语句
logger.debug(find_cmd)
cursor = eval(find_cmd)
columns = []
rows = []
if "count" in query_dict:
columns.append("count")
rows.append({"count": cursor})
elif "explain" in query_dict: # 生成执行计划数据
columns.append("explain")
cursor = json.loads(json_util.dumps(cursor)) # bson转换成json
for k, v in cursor.items():
if k not in ("serverInfo", "ok"):
rows.append({k: v})
elif method == "index_information": # 生成返回索引数据
columns.append("index_list")
for k, v in cursor.items():
rows.append({k: v})
elif method == "aggregate" and sql.find("$group") >= 0: # 生成聚合数据
row = []
columns.insert(0, "mongodballdata")
for ro in cursor:
json_col = json.dumps(
ro, ensure_ascii=False, indent=2, separators=(",", ":")
)
row.insert(0, json_col)
for k, v in ro.items():
if k not in columns:
columns.append(k)
row.append(v)
rows.append(tuple(row))
row.clear()
rows = tuple(rows)
result_set.rows = rows
else:
cursor = json.loads(json_util.dumps(cursor))
cols = projection if "projection" in dir() else None
rows, columns = self.parse_tuple(cursor, db_name, collection_name, cols)
result_set.rows = rows
result_set.column_list = columns
result_set.affected_rows = len(rows)
if isinstance(rows, list):
logger.debug(rows)
result_set.rows = tuple(
[json.dumps(x, ensure_ascii=False, indent=2, separators=(",", ":"))]
for x in rows
)
except Exception as e:
logger.warning(
f"Mongo命令执行报错,语句:{sql}, 错误信息:{traceback.format_exc()}"
)
result_set.error = str(e)
finally:
if close_conn:
self.close()
return result_set
def parse_tuple(self, cursor, db_name, tb_name, projection=None):
"""前端bootstrap-table显示,需要转化mongo查询结果为tuple((),())的格式"""
columns = []
rows = []
row = []
if projection:
for k in projection.keys():
columns.append(k)
else:
result = self.get_all_columns_by_tb(db_name=db_name, tb_name=tb_name)
columns = result.rows
columns.insert(0, "mongodballdata") # 隐藏JSON结果列
columns = self.fill_query_columns(cursor, columns)
for ro in cursor:
json_col = json.dumps(
ro, ensure_ascii=False, indent=2, separators=(",", ":")
)
row.insert(0, json_col)
for key in columns[1:]:
if key in ro:
value = ro[key]
if isinstance(value, list):
value = "(array) %d Elements" % len(value)
re_oid = re.compile(r"{\'\$oid\': \'[0-9a-f]{24}\'}")
re_date = re.compile(r"{\'\$date\': [0-9]{13}}")
# 转换$oid
ff = re.findall(re_oid, str(value))
for ii in ff:
value = str(value).replace(
ii, "ObjectId(" + ii.split(":")[1].strip()[:-1] + ")"
)
# 转换时间戳$date
dd = re.findall(re_date, str(value))
for d in dd:
t = int(d.split(":")[1].strip()[:-1])
e = datetime.datetime.fromtimestamp(t / 1000)
value = str(value).replace(
d, e.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
)
row.append(str(value))
else:
row.append("(N/A)")
rows.append(tuple(row))
row.clear()
return tuple(rows), columns
@staticmethod
def fill_query_columns(cursor, columns):
"""补充结果集中`get_all_columns_by_tb`未获取的字段"""
cols = columns
for ro in cursor:
for key in ro.keys():
if key not in cols:
cols.append(key)
return cols
def processlist(self, command_type, **kwargs):
"""
获取当前连接信息
command_type:
Full 包含活跃与不活跃的连接,包含内部的连接,即全部的连接状态
All 包含活跃与不活跃的连接,不包含内部的连接
Active 包含活跃
Inner 内部连接
"""
result_set = ResultSet(
full_sql='db.aggregate([{"$currentOp": {"allUsers":true, "idleConnections":true}}])'
)
try:
conn = self.get_connection()
processlists = []
if not command_type:
command_type = "Active"
if command_type in ["Full", "All", "Inner"]:
idle_connections = True
else:
idle_connections = False
# conn.admin.current_op() 这个方法已经被pymongo废除,但mongodb3.6+才支持aggregate
with conn.admin.aggregate(
[
{
"$currentOp": {
"allUsers": True,
"idleConnections": idle_connections,
}
}
]
) as cursor:
for operation in cursor:
# 对sharding集群的特殊处理
if "client" not in operation and operation.get(
"clientMetadata", {}
).get("mongos", {}).get("client", {}):
operation["client"] = operation["clientMetadata"]["mongos"][
"client"
]
# 获取此会话的用户名
effective_users_key = "effectiveUsers_user"
effective_users = operation.get("effectiveUsers", [])
if isinstance(effective_users, list) and effective_users:
first_user = effective_users[0]
if isinstance(first_user, dict):
operation[effective_users_key] = first_user.get("user", [])
else:
operation[effective_users_key] = None
else:
operation[effective_users_key] = None
# client_s 只是处理的mongos,并不是实际客户端
# client 在sharding获取不到?
if command_type in ["Full"]:
processlists.append(operation)
elif command_type in ["All", "Active"]:
if "clientMetadata" in operation:
processlists.append(operation)
elif command_type in ["Inner"]:
if not "clientMetadata" in operation:
processlists.append(operation)
result_set.rows = processlists
except Exception as e:
logger.warning(f"mongodb获取连接信息错误,错误信息{traceback.format_exc()}")
result_set.error = str(e)
return result_set
def get_kill_command(self, opids):
"""由传入的opid列表生成kill字符串"""
conn = self.get_connection()
active_opid = []
with conn.admin.aggregate(
[{"$currentOp": {"allUsers": True, "idleConnections": False}}]
) as cursor:
for operation in cursor:
if "opid" in operation and operation["opid"] in opids:
active_opid.append(operation["opid"])
kill_command = ""
for opid in active_opid:
if isinstance(opid, int):
kill_command = kill_command + "db.killOp({});".format(opid)
else:
kill_command = kill_command + 'db.killOp("{}");'.format(opid)
return kill_command
def kill_op(self, opids):
"""kill"""
result = ResultSet()
try:
conn = self.get_connection()
except Exception as e:
logger.error(f"{self.name} 连接失败, error: {str(e)}")
result.error = str(e)
return result
for opid in opids:
try:
conn.admin.command({"killOp": 1, "op": opid})
except Exception as e:
sql = {"killOp": 1, "op": opid}
logger.warning(
f"{self.name}语句执行killOp报错,语句:db.runCommand({sql}) ,错误信息{traceback.format_exc()}"
)
result.error = str(e)
return result
def get_all_databases_summary(self):
"""实例数据库管理功能,获取实例所有的数据库描述信息"""
query_result = self.get_all_databases()
if not query_result.error:
dbs = query_result.rows
conn = self.get_connection()
# 获取数据库用户信息
rows = []
for db_name in dbs:
# 执行语句
listing = conn[db_name].command(command="usersInfo")
grantees = []
for user_obj in listing["users"]:
grantees.append(
{"user": user_obj["user"], "roles": user_obj["roles"]}.__str__()
)
row = {
"db_name": db_name,
"grantees": grantees,
"saved": False,
}
rows.append(row)
query_result.rows = rows
return query_result
def get_instance_users_summary(self):
"""实例账号管理功能,获取实例所有账号信息"""
query_result = self.get_all_databases()
if not query_result.error:
dbs = query_result.rows
conn = self.get_connection()
# 获取数据库用户信息
rows = []
for db_name in dbs:
# 执行语句
listing = conn[db_name].command(command="usersInfo")
for user_obj in listing["users"]:
rows.append(
{
"db_name_user": f"{db_name}.{user_obj['user']}",
"db_name": db_name,
"user": user_obj["user"],
"roles": [role["role"] for role in user_obj["roles"]],
"saved": False,
}
)
query_result.rows = rows
return query_result
def create_instance_user(self, **kwargs):
"""实例账号管理功能,创建实例账号"""
exec_result = ResultSet()
db_name = kwargs.get("db_name", "")
user = kwargs.get("user", "")
password1 = kwargs.get("password1", "")
remark = kwargs.get("remark", "")
try:
conn = self.get_connection()
conn[db_name].command("createUser", user, pwd=password1, roles=[])
exec_result.rows = [
{
"instance": self.instance,
"db_name": db_name,
"user": user,
"password": password1,
"remark": remark,
}
]
except Exception as e:
exec_result.error = str(e)
return exec_result
def drop_instance_user(self, db_name_user: str, **kwarg):
"""实例账号管理功能,删除实例账号"""
arr = db_name_user.split(".")
db_name = arr[0]
user = arr[1]
exec_result = ResultSet()
try:
conn = self.get_connection()
conn[db_name].command("dropUser", user)
except Exception as e:
exec_result.error = str(e)
return exec_result
def reset_instance_user_pwd(self, db_name_user: str, reset_pwd: str, **kwargs):
"""实例账号管理功能,重置实例账号密码"""
arr = db_name_user.split(".")
db_name = arr[0]
user = arr[1]
exec_result = ResultSet()
try:
conn = self.get_connection()
conn[db_name].command("updateUser", user, pwd=reset_pwd)
except Exception as e:
exec_result.error = str(e)
return exec_result
def query_masking(self, db_name=None, sql="", resultset=None):
"""传入 sql语句, db名, 结果集,
返回一个脱敏后的结果集"""
mask_result = data_masking(self.instance, db_name, sql, resultset)
return mask_result