# -*- coding: UTF-8 -*-
import logging
import pymemcache
from typing import List, Tuple
from django.core.checks.security.base import check_secret_key
from . import EngineBase
from .models import ResultSet, ReviewSet, ReviewResult
from sql.models import SqlWorkflow
# Module-level logger, obtained by the explicit name "default".
logger = logging.getLogger(name="default")
class MemcachedEngine(EngineBase):
    """Engine that talks to Memcached through pymemcache.

    ``instance.host`` may hold several hosts separated by ``,``. Each host is
    exposed as a pseudo "database" named ``"Node - <index>"`` so that queries
    and workflows can target a single node. Commands are plain memcached text
    commands (``get``, ``set`` ...) dispatched by the module-level
    ``_handle_cmd``.
    """

    test_query = "stats"
    name = "Memcached"
    info = "Memcached engine"

    # Commands accepted by query_check (read-only) and execute_check (writes).
    _QUERY_COMMANDS = ("version", "get", "gets")
    _EXECUTE_COMMANDS = ("set", "delete", "incr", "decr", "touch")

    def __init__(self, instance=None):
        super().__init__(instance=instance)
        # Node registry: db_name ("Node - <index>") -> host. Only host strings
        # are stored; get_connection() builds a fresh client for every call.
        self.nodes = {}
        if not instance:
            return
        for i, host in enumerate(instance.host.split(",")):
            host = host.strip()
            # Skip blank entries (e.g. a trailing comma) but keep the position
            # index so node names still match their place in instance.host.
            if host:
                self.nodes[f"Node - {i}"] = host

    @staticmethod
    def _to_text(value):
        """Return *value* as ``str``, decoding bytes (pymemcache replies are bytes)."""
        if isinstance(value, bytes):
            return value.decode("utf-8", errors="replace")
        return str(value)

    @staticmethod
    def _close_quietly(conn):
        """Close a pymemcache client if one was created; log instead of raising."""
        if conn is None:
            return
        try:
            conn.close()
        except Exception as e:
            logger.warning(f"关闭Memcached连接失败: {str(e)}")

    def get_connection(self, db_name=None):
        """Create a pymemcache client for one node.

        :param db_name: node name such as ``"Node - 0"``; defaults to the first
            registered node.
        :return: a new ``pymemcache.Client``; the caller must close it.
        :raises Exception: if the node is unknown or the client cannot be built.
        """
        db_name = db_name or next(iter(self.nodes), "Node - 0")
        if db_name not in self.nodes:
            logger.warning(f"Memcached节点 {db_name} 不存在")
            raise Exception(f"Memcached节点 {db_name} 不存在")
        node_host = self.nodes[db_name]
        try:
            return pymemcache.Client(
                server=(node_host, self.port), connect_timeout=10.0, timeout=10.0
            )
        except Exception as e:
            raise Exception(f"连接Memcached节点 {node_host} 失败: {str(e)}") from e

    def test_connection(self):
        """Check that the default node answers.

        :return: ResultSet with a single status row that includes the server
            version.
        :raises Exception: if connecting or the ``version`` command fails.
        """
        conn = None
        try:
            conn = self.get_connection(None)
            # `version` is a cheap round-trip that proves the server responds.
            version = self._to_text(conn.version())
            return ResultSet(
                rows=[[f"连接成功,版本: {version}"]], column_list=["状态"]
            )
        except Exception as e:
            logger.error(f"测试连接失败: {str(e)}")
            raise Exception(f"测试连接失败: {str(e)}") from e
        finally:
            self._close_quietly(conn)

    def get_all_databases(self):
        """List the configured nodes, each presented as a "database" row."""
        return ResultSet(column_list=["节点"], rows=[[name] for name in self.nodes])

    def get_all_tables(self, db_name, **kwargs):
        """Memcached has no tables; always return an empty ResultSet."""
        return ResultSet(rows=[])

    def query(
        self,
        db_name=None,
        sql="",
        limit_num=0,
        close_conn=True,
        parameters=None,
        **kwargs,
    ):
        """Run one read command against a node.

        :param db_name: node name (``"Node - <index>"``); defaults to the first node.
        :param sql: a single command line, e.g. ``get <key>``.
        :param limit_num: accepted for interface compatibility; not applied.
        :param close_conn: when true, also reset ``self.conn``.
        :param parameters: accepted for interface compatibility; not used.
        :return: ResultSet from the command handler; on failure ``error`` is set
            and a single error row is returned instead of raising.
        """
        result_set = ResultSet(full_sql=sql)
        conn = None
        try:
            conn = self.get_connection(db_name)
            result_set = _handle_cmd(conn, sql)
        except Exception as e:
            logger.error(f"查询执行失败: {str(e)}")
            result_set.error = str(e)
            result_set.rows = [[f"错误: {str(e)}"]]
        finally:
            # The client is local to this call and never cached, so it can
            # never be reused; always close it to avoid leaking the socket.
            self._close_quietly(conn)
            if close_conn and self.conn:
                self.conn = None
        return result_set

    def query_check(self, db_name=None, sql=""):
        """Allow only read-only commands (version, get, gets) in queries.

        :return: dict ``{'bad_query': bool, 'filtered_sql': str}``, plus
            ``'msg'`` when the query is rejected.
        """
        cmd, _ = _parse_cmd_args(sql.strip())
        if cmd not in self._QUERY_COMMANDS:
            return {
                "bad_query": True,
                "filtered_sql": sql,
                "msg": "仅支持 (version, get, gets) 命令",
            }
        return {"bad_query": False, "filtered_sql": sql}

    def execute(self, db_name=None, sql="", **kwargs):
        """Execute one write command and wrap the outcome in a ReviewSet.

        :param db_name: node name; defaults to the first node.
        :param sql: a single command line, e.g. ``set <key> <value> [expiry]``.
        :return: ReviewSet with exactly one ReviewResult. ``error`` is set when
            the command raises or its handler reports ``"FAIL"``.
        """
        execute_result = ReviewSet(full_sql=sql)
        conn = None
        try:
            conn = self.get_connection(db_name)
            cmd_result = _handle_cmd(conn, sql)
            # Write handlers return exactly one status cell. Check explicitly
            # rather than with `assert`, which is stripped under `python -O`.
            if len(cmd_result.rows) != 1:
                raise Exception("命令执行结果行数不是1")
            if len(cmd_result.rows[0]) != 1:
                raise Exception("命令执行结果列数不是1")
            failed = cmd_result.rows[0][0] == "FAIL"
            execute_result.affected_rows = cmd_result.affected_rows
            execute_result.error = cmd_result.error
            if failed and not execute_result.error:
                # Surface the failure on the set itself, like the exception path.
                execute_result.error = f"命令执行失败: {sql}"
            execute_result.rows.append(
                ReviewResult(
                    id=1,
                    affected_rows=0 if failed else 1,
                    sql=sql,
                    stage="Execute",
                    stagestatus="Fail" if failed else "Success",
                    errlevel=2 if failed else 0,
                    errormessage=execute_result.error if failed else None,
                )
            )
        except Exception as e:
            logger.error(f"执行语句失败: {str(e)}")
            execute_result.error = str(e)
            # Keep rows homogeneous: always ReviewResult, never a bare dict.
            execute_result.rows = [
                ReviewResult(
                    id=1,
                    affected_rows=0,
                    sql=sql,
                    stage="Execute",
                    stagestatus="Fail",
                    errlevel=2,
                    errormessage=str(e),
                )
            ]
        finally:
            self._close_quietly(conn)
        return execute_result

    def execute_check(self, db_name=None, sql=""):
        """Check that *sql* is a supported write command.

        :return: ReviewSet with one ReviewResult. An unsupported command bumps
            ``error_count``, sets ``error`` and produces an errlevel-2 row.
        """
        check_result = ReviewSet(full_sql=sql)
        cmd, _ = _parse_cmd_args(sql.strip())
        if cmd not in self._EXECUTE_COMMANDS:
            message = f"不支持的命令: {cmd}"
            check_result.error_count += 1
            check_result.error = message
            check_result.rows = [
                ReviewResult(
                    id=1,
                    affected_rows=0,
                    sql=sql,
                    stage="Check",
                    stagestatus="Fail",
                    errlevel=2,
                    errormessage=message,
                )
            ]
        else:
            check_result.rows = [
                ReviewResult(
                    id=1,
                    affected_rows=1,
                    sql=sql,
                    stage="Check",
                    stagestatus="Success",
                )
            ]
        check_result.checked = True
        return check_result

    def execute_workflow(self, workflow: SqlWorkflow):
        """Execute a workflow's content and return its ReviewSet."""
        return self.execute(
            db_name=workflow.db_name, sql=workflow.sqlworkflowcontent.sql_content
        )

    def get_execute_percentage(self):
        """Execution is a single command, so progress is always reported as 100."""
        return 100

    @property
    def server_version(self):
        """Server version as a tuple of up to three ints, e.g. ``(1, 6, 9)``.

        Non-numeric components become 0; an empty tuple is returned on error.
        """
        conn = None
        try:
            conn = self.get_connection()
            # pymemcache returns the version as bytes (b"1.6.9"); str() on bytes
            # would give "b'1.6.9'" and corrupt the first/last components.
            parts = self._to_text(conn.version()).strip().split(".")
            return tuple(int(part) if part.isdigit() else 0 for part in parts[:3])
        except Exception as e:
            logger.error(f"获取Memcached版本失败: {str(e)}")
            return tuple()
        finally:
            self._close_quietly(conn)
# 命令处理函数
def _handle_get(conn: pymemcache.Client, sql: str, cmd_args: List[str]):
    """
    Handle ``get <key>``: fetch a single key.

    Returns a one-column ResultSet ("值") whose only row holds the stored
    value, or the string ``"None"`` when the key is missing. Client errors are
    re-raised as ``Exception("get命令执行失败: ...")``.
    """
    if not cmd_args:
        raise Exception("get命令格式错误")
    result_set = ResultSet(full_sql=sql)
    try:
        fetched = conn.get(cmd_args[0].strip())
        cell = "None" if fetched is None else fetched
        result_set.column_list = ["值"]
        result_set.rows = [[cell]]
    except Exception as err:
        raise Exception(f"get命令执行失败: {str(err)}")
    result_set.affected_rows = len(result_set.rows)
    return result_set
def _handle_set(conn: pymemcache.Client, sql: str, cmd_args: List[str]):
    """
    Handle ``set <key> <value> [expiry]``.

    :param conn: pymemcache client for the target node.
    :param sql: the full command line, recorded as ``full_sql``.
    :param cmd_args: ``[key, value]`` or ``[key, value, expiry]``; expiry is an
        integer number of seconds (0 means no expiry). Extra tokens are ignored.
    :return: ResultSet with one "状态" cell: ``"OK"`` if the server stored the
        value, ``"FAIL"`` otherwise.
    :raises Exception: on missing arguments, a non-integer expiry, or a
        client/server error.
    """
    result_set = ResultSet(full_sql=sql)
    if len(cmd_args) < 2:
        raise Exception("set命令格式错误")
    try:
        key = cmd_args[0].strip()
        value = cmd_args[1].strip()
        expiry = int(cmd_args[2].strip()) if len(cmd_args) > 2 else 0
        # noreply=False: wait for the server's reply. pymemcache clients default
        # to noreply, in which case set() returns True unconditionally and a
        # failed store would be reported as "OK".
        ok = conn.set(key, value, expire=expiry, noreply=False)
        result_set.rows = [["OK"] if ok else ["FAIL"]]
        result_set.column_list = ["状态"]
    except Exception as e:
        raise Exception(f"set命令执行失败: {str(e)}") from e
    result_set.affected_rows = len(result_set.rows)
    return result_set
def _handle_delete(conn: pymemcache.Client, sql: str, cmd_args: List[str]):
    """
    Handle ``delete <key>``.

    :param conn: pymemcache client for the target node.
    :param sql: the full command line, recorded as ``full_sql``.
    :param cmd_args: ``[key]``; extra tokens are ignored.
    :return: ResultSet with one "状态" cell: ``"OK"`` if the key was deleted,
        ``"FAIL"`` if it did not exist.
    :raises Exception: on a missing key argument or a client/server error.
    """
    result_set = ResultSet(full_sql=sql)
    if len(cmd_args) < 1:
        raise Exception("delete命令格式错误")
    try:
        key = cmd_args[0].strip()
        # noreply=False so the server reply (DELETED / NOT_FOUND) is read; with
        # pymemcache's default noreply, delete() always returns True.
        ok = conn.delete(key, noreply=False)
        result_set.rows = [["OK"] if ok else ["FAIL"]]
        result_set.column_list = ["状态"]
    except Exception as e:
        raise Exception(f"delete命令执行失败: {str(e)}") from e
    result_set.affected_rows = len(result_set.rows)
    return result_set
def _handle_version(conn: pymemcache.Client, sql: str, cmd_args: List[str]):
    """
    Handle ``version``: report the server version.

    Any arguments are ignored. The raw value returned by ``conn.version()`` is
    placed in a single "版本" cell; client errors propagate unchanged.
    """
    result_set = ResultSet(full_sql=sql)
    result_set.rows = [[conn.version()]]
    result_set.column_list = ["版本"]
    result_set.affected_rows = 1
    return result_set
def _handle_gets(conn: pymemcache.Client, sql: str, cmd_args: List[str]):
    """
    Handle ``gets <key1> [<key2> ...]``: fetch several keys with CAS tokens.

    Produces one ("键", "值", "CAS") row per key returned by the server;
    missing keys produce no row. Client errors are re-raised as
    ``Exception("gets命令执行失败: ...")``.
    """
    if not cmd_args:
        raise Exception("gets命令格式错误")
    result_set = ResultSet(full_sql=sql)
    try:
        cas_values = conn.gets_many([arg.strip() for arg in cmd_args])
        result_set.column_list = ["键", "值", "CAS"]
        result_set.rows.extend(
            [key, "None" if value is None else value, cas]
            for key, (value, cas) in cas_values.items()
        )
    except Exception as err:
        raise Exception(f"gets命令执行失败: {str(err)}")
    result_set.affected_rows = len(result_set.rows)
    return result_set
def _handle_incr(conn: pymemcache.Client, sql: str, cmd_args: List[str]):
    """
    Handle ``incr <key> [delta]``: increment a numeric value (delta defaults to 1).

    The "结果" cell holds the new value as a string, or ``"FAIL"`` when the
    client returns None. Errors (including a non-integer delta) are re-raised
    as ``Exception("incr命令执行失败: ...")``.
    """
    if not cmd_args:
        raise Exception("incr命令格式错误")
    result_set = ResultSet(full_sql=sql)
    try:
        key, *rest = cmd_args
        delta = int(rest[0].strip()) if rest else 1
        new_value = conn.incr(key.strip(), delta)
        outcome = "FAIL" if new_value is None else str(new_value)
        result_set.rows = [[outcome]]
        result_set.column_list = ["结果"]
    except Exception as err:
        raise Exception(f"incr命令执行失败: {str(err)}")
    result_set.affected_rows = 1
    return result_set
def _handle_decr(conn: pymemcache.Client, sql: str, cmd_args: List[str]):
    """
    Handle ``decr <key> [delta]``: decrement a numeric value (delta defaults to 1).

    The "结果" cell holds the new value as a string, or ``"FAIL"`` when the
    client returns None. Errors (including a non-integer delta) are re-raised
    as ``Exception("decr命令执行失败: ...")``.
    """
    if not cmd_args:
        raise Exception("decr命令格式错误")
    result_set = ResultSet(full_sql=sql)
    try:
        key = cmd_args[0].strip()
        step = 1
        if len(cmd_args) > 1:
            step = int(cmd_args[1].strip())
        remaining = conn.decr(key, step)
        if remaining is None:
            result_set.rows = [["FAIL"]]
        else:
            result_set.rows = [[str(remaining)]]
        result_set.column_list = ["结果"]
    except Exception as err:
        raise Exception(f"decr命令执行失败: {str(err)}")
    result_set.affected_rows = 1
    return result_set
def _handle_touch(conn: pymemcache.Client, sql: str, cmd_args: List[str]):
    """
    Handle ``touch <key> <expiry>``: update a key's expiration time.

    :param conn: pymemcache client for the target node.
    :param sql: the full command line, recorded as ``full_sql``.
    :param cmd_args: ``[key, expiry]``; expiry is an integer number of seconds.
    :return: ResultSet with one "状态" cell: ``"OK"`` if the expiry was
        updated, ``"FAIL"`` if the key was not found.
    :raises Exception: on missing arguments, a non-integer expiry, or a
        client/server error.
    """
    result_set = ResultSet(full_sql=sql)
    if len(cmd_args) < 2:
        raise Exception("touch命令格式错误")
    try:
        key = cmd_args[0].strip()
        expiry = int(cmd_args[1].strip())
        # noreply=False so the reply (TOUCHED / NOT_FOUND) is read; with
        # pymemcache's default noreply, touch() always returns True.
        ok = conn.touch(key, expire=expiry, noreply=False)
        result_set.rows = [["OK"] if ok else ["FAIL"]]
        result_set.column_list = ["状态"]
    except Exception as e:
        raise Exception(f"touch命令执行失败: {str(e)}") from e
    result_set.affected_rows = 1
    return result_set
# 命令处理函数映射表
# Dispatch table: lower-cased command name -> handler(conn, sql, cmd_args).
cmd_handlers = dict(
    get=_handle_get,
    set=_handle_set,
    delete=_handle_delete,
    version=_handle_version,
    gets=_handle_gets,
    incr=_handle_incr,
    decr=_handle_decr,
    touch=_handle_touch,
)
def _parse_cmd_args(sql: str) -> Tuple[str, List[str]]:
"""
解析命令参数
"""
cmd = sql.split(" ")[0].strip().lower()
cmd_args = sql.split(" ")[1:]
return cmd, cmd_args
def _handle_cmd(conn: pymemcache.Client, sql: str):
    """
    Dispatch a single command line to its handler in ``cmd_handlers``.

    Only the command name is matched case-insensitively. Keys and values are
    passed through unchanged, because memcached keys are case-sensitive and
    lower-casing them would read or write the wrong entries.

    :param conn: pymemcache client for the target node.
    :param sql: command line such as ``get <key>`` or ``set <key> <value>``.
    :return: the ResultSet produced by the matching handler.
    :raises Exception: on an empty line or an unsupported command.
    """
    sql = sql.strip()
    if not sql:
        raise Exception("空SQL语句")
    cmd, cmd_args = _parse_cmd_args(sql)
    handler = cmd_handlers.get(cmd)
    if handler is None:
        raise Exception(f"不支持的命令: {cmd}")
    return handler(conn, sql, cmd_args)