# -*- coding: UTF-8 -*-
"""
@author: feiazifeiazi
@license: Apache Licence
@file: xx.py
@time: 2024-08-01
"""
__author__ = "feiazifeiazi"

import logging
import os
import re
import traceback
from opensearchpy import OpenSearch
import simplejson as json
import sqlparse

from common.utils.timer import FuncTimer
from . import EngineBase
from .models import ResultSet, ReviewSet, ReviewResult
from common.config import SysConfig
import logging

from elasticsearch import Elasticsearch
from elasticsearch.exceptions import TransportError


# Module-wide logger. "default" is presumably a logger name configured in the
# project's logging settings — TODO confirm.
logger = logging.getLogger("default")


class QueryParamsSearch:
    """Parsed form of a read-only search request.

    Attributes:
        index: target index or index pattern ("" when not applicable).
        path: request path including any query string ("" when not applicable).
        params: query-string parameters as a dict ({} when none).
        method: HTTP method as written by the user ("" when not applicable).
        size: default result size.
        sql: SQL text for SELECT queries ("" otherwise).
        query_body: JSON request body as a dict ({} when none).
    """

    def __init__(
        self,
        index: str = None,
        path: str = None,
        params: dict = None,
        method: str = None,
        size: int = 100,
        sql: str = None,
        query_body: dict = None,
    ):
        self.index = index if index is not None else ""
        self.path = path if path is not None else ""
        self.method = method if method is not None else ""
        # Normalise to a fresh empty dict like the other optional fields, so callers
        # can use `in` / item assignment without a None check. A new dict per
        # instance avoids sharing state between instances.
        self.params = params if params is not None else {}
        self.size = size
        self.sql = sql if sql is not None else ""
        self.query_body = query_body if query_body is not None else {}


class ElasticsearchDocument:
    """One statement of a workflow, in parsed form (an ES "doc" object).

    Attributes:
        sql: the original statement text.
        method: HTTP method, upper-cased (None when not given).
        index_name: target index name.
        api_endpoint: API endpoint such as ``_doc`` or ``_update``, lower-cased;
            "" when absent.
        doc_id: document id, if any.
        doc_data_body: request body.
    """

    def __init__(
        self,
        sql: str = None,
        method: str = None,
        index_name: str = None,
        api_endpoint: str = "",
        doc_id: str = None,
        doc_data_body: str = None,
    ):
        self.sql = sql
        self.method = None if method is None else method.upper()
        self.index_name = index_name
        # Normalised so comparisons against "" / "_doc" never need a None check.
        self.api_endpoint = "" if api_endpoint is None else api_endpoint.lower()
        self.doc_id = doc_id
        self.doc_data_body = doc_data_body

    def describe(self) -> str:
        """Return a one-line summary of the parsed fields, used in review messages."""
        return f"[index_name:{self.index_name}, method:{self.method}, api_endpoint:{self.api_endpoint}, doc_id:{self.doc_id}]"


class ElasticsearchEngineBase(EngineBase):
    """
    Elasticsearch、OpenSearch等Search父类实现
    如果2者方法差异不大,可以在父类用if else实现。如果差异大,建议在子类实现。
    """

    def __init__(self, instance=None):
        """Set shared engine state and reject unsupported engine names.

        :param instance: forwarded unchanged to ``EngineBase.__init__``.
        :raises ValueError: when ``self.name`` is not listed in ``self.search_name``.
        """
        self.conn = None  # type: Elasticsearch  # type comment gives editors the client type
        # Separator between the "database" prefix and the rest of an index name.
        self.db_separator = "__"
        # Only these two concrete subclasses are supported.
        self.search_name = ["Elasticsearch", "OpenSearch"]
        is_supported = self.name in self.search_name
        if not is_supported:
            raise ValueError(
                f"Invalid name: {self.name}. Must be one of {self.search_name}."
            )
        super().__init__(instance=instance)

    def get_connection(self, db_name=None):
        """Return a connection instance.

        This base implementation has no body and therefore returns None. Methods
        in this class call it before using ``self.conn``, and ``execute_workflow``
        uses its return value as the client. Subclasses presumably set
        ``self.conn`` and return it — TODO confirm against the subclasses.

        :param db_name: optional database name (unused here).
        """

    def test_connection(self):
        """Check that the instance is reachable.

        Delegates to ``get_all_databases``: on success its ResultSet is returned,
        and on failure the exception raised there propagates.
        """
        return self.get_all_databases()

    # Engine identifier. __init__ rejects any value not in self.search_name
    # ("Elasticsearch" / "OpenSearch"), so subclasses must override it.
    name: str = "SearchBase"
    # Human-readable engine description ("SearchBase engine").
    info: str = "SearchBase 引擎"

    def get_all_databases(self):
        """List pseudo-database names derived from index names.

        An index named ``DB__TABLE`` (``self.db_separator`` is "__") contributes
        ``DB``. Two fixed names are always included: "system" (system API pseudo
        tables) and "other" (indices without the separator).

        :return: ResultSet whose rows are the sorted database names.
        :raises Exception: wrapping any error; the traceback is logged first.
        """
        try:
            self.get_connection()
            # Keys of the get_alias response are treated as index names.
            indices = self.conn.indices.get_alias(index=self.db_name)
            names = {"system", "other"}
            names.update(
                index_name.split(self.db_separator)[0]
                for index_name in indices.keys()
                if self.db_separator in index_name
            )
            return ResultSet(rows=sorted(names))
        except Exception as e:
            logger.error(f"获取数据库时出错:{e}{traceback.format_exc()}")
            raise Exception(f"获取数据库时出错: {str(e)}")

    def get_all_tables(self, db_name, **kwargs):
        """Return the table names that belong to pseudo-database ``db_name``.

        Index names starting with "." are system indices that can no longer be
        queried, so they are never returned by the "system"/"other" listing.

        * Regular database: every index named ``DB__...`` (prefix ``db_name`` + separator).
        * "system": the fixed API paths ``/_cat/indices/`` + ``self.db_name``,
          ``/_cat/nodes``, ``/_security/role`` and ``/_security/user``; one
          ``/_cat/indices/PREFIX`` entry per separator prefix; and every index whose
          name starts with "system" (which also adds ``/_cat/indices/system``).
        * "other": indices without the separator, plus every index whose name
          starts with "other".

        :return: ResultSet whose rows are the sorted table names.
        :raises Exception: wrapping (and chained to) any error. The traceback is
            logged first, as in ``get_all_databases``.
        """
        try:
            self.get_connection()
            indices = self.conn.indices.get_alias(index=self.db_name)
            tables = set()

            if db_name not in ("system", "other"):
                # Database derived from the "DB__TABLE" naming convention.
                index_prefix = db_name.rstrip(self.db_separator) + self.db_separator
                tables.update(
                    index for index in indices.keys() if index.startswith(index_prefix)
                )
            else:
                if db_name == "system":
                    # Expose system APIs as pseudo tables.
                    tables.add("/_cat/indices/" + self.db_name)
                    tables.add("/_cat/nodes")
                    tables.add("/_security/role")
                    tables.add("/_security/user")

                for index_name in indices.keys():
                    if index_name.startswith("."):
                        continue
                    if index_name.startswith(db_name):
                        tables.add(index_name)
                        if db_name == "system":
                            tables.add("/_cat/indices/" + db_name)
                    elif self.db_separator in index_name:
                        if db_name == "system":
                            separator_db_name = index_name.split(self.db_separator)[0]
                            tables.add("/_cat/indices/" + separator_db_name)
                    elif db_name == "other":
                        tables.add(index_name)

            return ResultSet(rows=sorted(tables))
        except Exception as e:
            logger.error(f"获取表列表时出错:{e}{traceback.format_exc()}")
            raise Exception(f"获取表列表时出错: {str(e)}") from e

    def get_all_columns_by_tb(self, db_name, tb_name, **kwargs):
        """Return the top-level field names of index ``tb_name``.

        Names starting with "/" or "_" are API-path pseudo tables (such as
        "/_cat/nodes") and get an empty ResultSet. A mapping without properties
        yields the single placeholder row "无" ("none").

        :return: ResultSet with ``column_list == ["column_name"]`` and the field
            names as rows.
        :raises Exception: wrapping any error from the mapping lookup.
        """
        result_set = ResultSet(full_sql=f"{tb_name}/_mapping")
        if tb_name.startswith(("/", "_")):
            return result_set
        try:
            self.get_connection()
            mapping = self.conn.indices.get_mapping(index=tb_name)
            index_mapping = mapping.get(tb_name, {})
            properties = index_mapping.get("mappings", {}).get("properties", None)
            result_set.column_list = ["column_name"]
            result_set.rows = ["无"] if properties is None else list(properties.keys())
            return result_set
        except Exception as e:
            raise Exception(f"获取字段时出错: {str(e)}")

    def describe_table(self, db_name, tb_name, **kwargs):
        """Describe the fields of index ``tb_name``.

        Each row is ``(column_name, type, fields)``, where ``fields`` is the
        property's ``fields`` entry serialised as JSON (``{}`` when absent).
        API-path pseudo tables (names starting with "/" or "_") get an empty
        ResultSet. A mapping without properties yields one placeholder row of "无".

        :raises Exception: wrapping any error from the mapping lookup.
        """
        result_set = ResultSet(full_sql=f"{tb_name}/_mapping")
        if tb_name.startswith(("/", "_")):
            return result_set
        try:
            self.get_connection()
            mapping = self.conn.indices.get_mapping(index=tb_name)
            index_mapping = mapping.get(tb_name, {})
            properties = index_mapping.get("mappings", {}).get("properties", None)
            result_set.column_list = ["column_name", "type", "fields"]
            if properties is None:
                result_set.rows = [("无", "无", "无")]
            else:
                described = []
                for column, details in properties.items():
                    sub_fields = json.dumps(details.get("fields", {}))
                    described.append((column, details.get("type"), sub_fields))
                result_set.rows = described
            return result_set
        except Exception as e:
            raise Exception(f"获取字段时出错: {str(e)}")

    def query_check(self, db_name=None, sql=""):
        """Check that a statement is an allowed read-only query.

        Accepted statements start, case-insensitively, with ``get`` (a REST
        request) or ``select`` (SQL). Surrounding whitespace and trailing ``;`` are
        removed. For SQL, comments are stripped and only the first statement is kept.

        :param db_name: unused.
        :param sql: statement text.
        :return: dict with keys ``msg``, ``bad_query``, ``filtered_sql`` and ``has_star``.
        """
        result = {
            "msg": "语句检查通过。",
            "bad_query": False,
            "filtered_sql": sql,
            "has_star": False,
        }
        # Strip whitespace before removing ";" so that "...; " is cleaned as well.
        sql = sql.strip().rstrip(";").strip()
        result["filtered_sql"] = sql
        if re.match(r"^select", sql, re.I):
            try:
                sql = sqlparse.format(sql, strip_comments=True)
                sql = sqlparse.split(sql)[0]
                result["filtered_sql"] = sql.strip()
            except IndexError:
                result["bad_query"] = True
                result["msg"] = "没有有效的SQL语句。"
        elif not re.match(r"^get", sql, re.I):
            result["msg"] = (
                "语句检查失败:语句必须以 'get' 或 'select' 开头。示例查询:GET /dmp__iv/_search、select * from dmp__iv limit 10;"
            )
            result["bad_query"] = True
        return result

    def filter_sql(self, sql="", limit_num=0):
        """Normalise a statement and cap the row count of SELECT queries.

        A trailing ``LIMIT n``, ``LIMIT n OFFSET m`` or ``LIMIT m,n`` clause of a
        SELECT is rewritten so that its row count becomes ``min(n, limit_num)``. A
        SELECT without LIMIT gets ``limit LIMIT_NUM`` appended. The SQL part of
        this logic is copied from the MySQL engine.

        :param sql: statement text; surrounding whitespace and a trailing ``;`` are removed.
        :param limit_num: maximum number of rows.
        :return: GET requests are returned without a trailing ``;``; every other
            statement is returned terminated by ``;``.
        """
        # Strip whitespace first so "...; " also loses its semicolon. Otherwise the
        # end-anchored LIMIT patterns miss and a second LIMIT is appended.
        sql = sql.strip().rstrip(";").strip()
        if re.match(r"^get", sql, re.I):
            return sql
        if not re.match(r"^select", sql, re.I):
            return f"{sql};"

        # LIMIT N
        limit_n = re.compile(r"limit\s+(\d+)\s*$", re.I)
        # LIMIT M OFFSET N
        limit_offset = re.compile(r"limit\s+(\d+)\s+offset\s+(\d+)\s*$", re.I)
        # LIMIT M,N  (offset first, then row count)
        offset_comma_limit = re.compile(r"limit\s+(\d+)\s*,\s*(\d+)\s*$", re.I)

        match = limit_n.search(sql)
        if match:
            capped = min(int(limit_num), int(match.group(1)))
            return limit_n.sub(f"limit {capped};", sql)
        match = limit_offset.search(sql)
        if match:
            capped = min(int(limit_num), int(match.group(1)))
            return limit_offset.sub(f"limit {capped} offset {match.group(2)};", sql)
        match = offset_comma_limit.search(sql)
        if match:
            capped = min(int(limit_num), int(match.group(2)))
            return offset_comma_limit.sub(f"limit {match.group(1)},{capped};", sql)
        return f"{sql} limit {limit_num};"

    def query(
        self,
        db_name=None,
        sql="",
        limit_num=0,
        close_conn=True,
        parameters=None,
        **kwargs,
    ):
        """Execute a read-only request and return the result as a ResultSet.

        ``sql`` is parsed by ``parse_es_select_query_to_query_params`` and then
        dispatched on its shape:

        * ``GET /_cat/indices...``: ``cat.indices``; the text table becomes columns/rows.
        * ``GET /_security/role...`` / ``GET /_security/user...``: delegated to
          ``_security_role`` / ``_security_user``.
        * ``SELECT ...``: the engine's SQL endpoint (``sql.query`` for
          Elasticsearch, ``POST /_opendistro/_sql`` for OpenSearch).
        * anything else: ``search`` on the parsed index with the parsed body.

        ``db_name``, ``close_conn`` and ``parameters`` are accepted for interface
        compatibility and are not used here.

        :param limit_num: forwarded to the parser, which uses it as the default
            ``size`` of search requests.
        :return: ResultSet with ``column_list``, ``rows`` and ``affected_rows``
            (the number of rows).
        :raises Exception: wrapping (and chained to) any parsing or cluster error.
        """
        try:
            result_set = ResultSet(full_sql=sql)
            query_params = self.parse_es_select_query_to_query_params(sql, limit_num)
            self.get_connection()
            path = query_params.path
            if path.startswith("/_cat/indices"):
                self._fill_result_from_cat_indices(result_set, query_params)
            elif path.startswith("/_security/role"):
                result_set = self._security_role(sql, query_params)
            elif path.startswith("/_security/user"):
                result_set = self._security_user(sql, query_params)
            elif query_params.sql and self.name == "Elasticsearch":
                response = self.conn.sql.query(body={"query": query_params.sql})
                self._fill_result_from_sql_response(
                    result_set, response.get("columns", []), response.get("rows", [])
                )
            elif query_params.sql and self.name == "OpenSearch":
                response = self.conn.transport.perform_request(
                    method="POST",
                    url="/_opendistro/_sql",
                    body={"query": query_params.sql},
                )
                self._fill_result_from_sql_response(
                    result_set,
                    response.get("schema", []),
                    response.get("datarows", []),
                )
            else:
                response = self.conn.search(
                    index=query_params.index,
                    body=query_params.query_body,
                    params=query_params.params,
                )
                self._fill_result_from_search_response(result_set, response)
            result_set.affected_rows = len(result_set.rows)
            return result_set
        except Exception as e:
            raise Exception(f"执行查询时出错: {str(e)}") from e

    def _fill_result_from_cat_indices(self, result_set, query_params):
        """Run ``cat.indices`` for ``query_params`` and copy its table into ``result_set``."""
        # `v` makes the API print a header line; OpenSearch needs the string "true".
        if "v" not in query_params.params:
            query_params.params["v"] = "true"
        response = self.conn.cat.indices(
            index=query_params.index, params=query_params.params
        )
        # Depending on the client, the response is either the text itself or an
        # object exposing it as `.body`.
        response_body = response if isinstance(response, str) else response.body
        response_data = self.parse_cat_indices_response(response_body)
        if response_data:
            result_set.column_list = list(response_data[0].keys())
            result_set.rows = [tuple(row.values()) for row in response_data]
        else:
            result_set.column_list = []
            result_set.rows = []
            result_set.affected_rows = 0

    @staticmethod
    def _fill_result_from_sql_response(result_set, columns, rows):
        """Copy an SQL endpoint response (column descriptors + value rows) into ``result_set``.

        List/dict cell values are serialised to JSON strings. Column names may
        repeat, so rows stay positional lists, truncated to the number of columns.
        """
        column_list = [col["name"] for col in columns]
        result_set.rows = [
            [
                json.dumps(value) if isinstance(value, (list, dict)) else value
                for _, value in zip(column_list, row)
            ]
            for row in rows
        ]
        result_set.column_list = column_list

    @staticmethod
    def _fill_result_from_search_response(result_set, response):
        """Flatten ``_search`` hits into rows of ``_id`` plus every ``_source`` field.

        Columns are ``_id`` followed by source keys in order of first appearance.
        Missing fields become None, and list/dict values are serialised to JSON.
        """
        hits = response.get("hits", {}).get("hits", [])
        # A dict used as an insertion-ordered set of column names.
        all_search_keys = {"_id": None}
        rows = []
        for hit in hits:
            source_data = hit.get("_source", {})
            for key, value in source_data.items():
                all_search_keys[key] = None
                if isinstance(value, (list, dict)):
                    source_data[key] = json.dumps(value)
            rows.append({"_id": hit.get("_id"), **source_data})
        column_list = list(all_search_keys)
        result_set.rows = [tuple(row.get(key) for key in column_list) for row in rows]
        result_set.column_list = column_list

    def _security_role(self, sql, query_params: QueryParamsSearch):
        """Handle ``GET /_security/role...`` queries; subclasses must implement this.

        :param sql: the original request text.
        :param query_params: the parsed request.
        :return: a ResultSet (``query`` reads its ``rows``).
        :raises NotImplementedError: always, in this base class.
        """
        # Raise explicitly. Returning None made `query` fail later with an opaque
        # "'NoneType' object has no attribute 'rows'" error.
        raise NotImplementedError(
            f"{type(self).__name__} must implement _security_role"
        )

    def _security_user(self, sql, query_params: QueryParamsSearch):
        """Handle ``GET /_security/user...`` queries; subclasses must implement this.

        :param sql: the original request text.
        :param query_params: the parsed request.
        :return: a ResultSet (``query`` reads its ``rows``).
        :raises NotImplementedError: always, in this base class.
        """
        # Raise explicitly. Returning None made `query` fail later with an opaque
        # "'NoneType' object has no attribute 'rows'" error.
        raise NotImplementedError(
            f"{type(self).__name__} must implement _security_user"
        )

    def parse_cat_indices_response(self, response_text):
        """Parse the text table returned by ``_cat/indices`` (requested with ``v``).

        The first non-blank line holds whitespace-separated column headers. Each
        following line is split into at most ``len(headers)`` fields, so the last
        column keeps any embedded whitespace.

        :param response_text: raw response text; None or blank text yields [].
        :return: list of dicts mapping header to value, one per data line.
        """
        # An empty body would otherwise raise IndexError on the header line, and
        # blank lines would become empty-dict rows.
        lines = [line.strip() for line in (response_text or "").strip().splitlines()]
        lines = [line for line in lines if line]
        if not lines:
            return []
        headers = lines[0].split()
        max_split = len(headers) - 1
        return [
            dict(zip(headers, line.split(maxsplit=max_split))) for line in lines[1:]
        ]

    def parse_es_select_query_to_query_params(
        self, search_query_str: str, limit_num: int
    ) -> QueryParamsSearch:
        """Parse a query string into a QueryParamsSearch.

        Supported inputs:

        * ``GET PATH[?k=v&...]`` on the first line, optionally followed by a JSON
          object body on the next lines (defaults to ``match_all``). The index is
          taken from ``/_cat/indices/INDEX`` ("*" if absent), ``/_security/role``
          and ``/_security/user`` ("*"), or the first path segment of a
          ``.../_search`` path.
        * ``SELECT ...``, returned as ``QueryParamsSearch(sql=...)``.

        Anything else yields an empty QueryParamsSearch.

        :param search_query_str: raw request text; whitespace and a trailing ``;``
            are ignored.
        :param limit_num: ``size`` added to the body when it has none; None or
            values <= 0 fall back to 100.
        :raises ValueError: if the body is not a JSON object or the first line has
            no path.
        :raises Exception: if no index name can be determined from the path.
        """
        query_params = QueryParamsSearch()
        # Strip whitespace first so a trailing "; " is removed as well.
        sql = search_query_str.strip().rstrip(";").strip()
        if re.match(r"^get", sql, re.I):
            lines = sql.splitlines()
            method_line = lines[0].strip()

            query_body = "\n".join(lines[1:]).strip()
            # An empty body means "match everything".
            if not query_body:
                query_body = json.dumps({"query": {"match_all": {}}})

            try:
                json_body = json.loads(query_body)
            except json.JSONDecodeError as json_err:
                raise ValueError(
                    f"无法转为Json格式。{json_err}。query_body:{query_body}。"
                ) from json_err
            # Arrays/scalars are not valid search bodies and cannot take "size".
            if not isinstance(json_body, dict):
                raise ValueError(f"请求体必须是Json对象。query_body:{query_body}。")

            # The first line is "METHOD PATH[?QUERY]"; a bare "GET" has no path.
            method_parts = method_line.split(maxsplit=1)
            if len(method_parts) != 2:
                raise ValueError(
                    f"请求缺少路径。示例查询:GET /dmp__iv/_search。语句:{method_line}"
                )
            method, path_with_params = method_parts
            if not path_with_params.startswith("/"):
                path_with_params = "/" + path_with_params

            # Split off the query string; a pair without "=" maps to "".
            path, _, params_str = path_with_params.partition("?")
            params = {}
            if params_str:
                for pair in params_str.split("&"):
                    key, _, value = pair.partition("=")
                    params[key] = value

            index_pattern = ""
            if path.startswith("/_cat/indices"):
                path_parts = path.split("/")
                if len(path_parts) > 3:
                    index_pattern = path_parts[3]
                if not index_pattern:
                    index_pattern = "*"
            elif path.startswith(("/_security/role", "/_security/user")):
                index_pattern = "*"
            elif "/_search" in path:
                # path always starts with "/", so segment 1 exists.
                first_segment = path.split("/")[1]
                # "GET /_search" names no index; don't treat "_search" as one.
                if first_segment != "_search":
                    index_pattern = first_segment

            if not index_pattern:
                raise Exception("未找到索引名称。")

            size = limit_num if limit_num and limit_num > 0 else 100
            # Keep a user-supplied size; otherwise apply the default.
            if "size" not in json_body:
                json_body["size"] = size
            query_params = QueryParamsSearch(
                index=index_pattern,
                path=path_with_params,
                params=params,
                method=method,
                size=size,
                query_body=json_body,
            )
        elif re.match(r"^select", sql, re.I):
            query_params = QueryParamsSearch(sql=sql)
        return query_params

    def execute_check(self, db_name=None, sql=""):
        """Review a workflow's statements before execution.

        Rules applied per parsed statement:

        * GET / SELECT are rejected (queries belong to the SQL query feature).
        * Lines starting with ``#`` are comments and pass.
        * A statement without an index name is rejected.
        * DELETE requires a document id (deleting a whole index is forbidden).
        * PUT without an API endpoint creates an index; its body must be empty or
          contain mappings or settings. POST without an endpoint is rejected.
        * ``_doc``: PUT requires a document id; POST passes with or without one.
        * ``_search`` is a query and is rejected.
        * ``_update`` only accepts POST and requires a document id.
        * ``_update_by_query`` only accepts POST.
        * Any other API endpoint is rejected.

        :return: ReviewSet with one ReviewResult per statement and
            ``warning_count`` / ``error_count`` updated.
        """
        check_result = ReviewSet(full_sql=sql)
        documents = self.__split_sql(sql)
        for rowid, doc in enumerate(documents, start=1):
            is_pass = False
            doc_desc = doc.describe()
            if re.match(r"^get|^select", doc.sql, re.I):
                result = ReviewResult(
                    id=rowid,
                    errlevel=2,
                    stagestatus="驳回不支持语句",
                    errormessage="仅支持PUT,POST,DELETE等API方法,GET,SELECT查询语句请使用SQL查询功能!",
                    sql=doc.sql,
                )
            elif re.match(r"^#", doc.sql, re.I):
                result = ReviewResult(
                    id=rowid,
                    errlevel=0,
                    stagestatus="Audit completed",
                    errormessage="此为注释信息。",
                    sql=doc.sql,
                    affected_rows=0,
                    execute_time=0,
                )
            elif not doc.index_name:
                result = ReviewResult(
                    id=rowid,
                    errlevel=2,
                    stagestatus="驳回不支持语句",
                    errormessage=f"请求必须包含索引名称或无法解析。解析结果:{doc_desc}",
                    sql=doc.sql,
                )
            elif doc.method == "DELETE":
                # Deleting a whole index is not allowed; a document id is required.
                if doc.doc_id:
                    is_pass = True
                else:
                    result = ReviewResult(
                        id=rowid,
                        errlevel=2,
                        stagestatus="驳回不支持语句",
                        errormessage="删除操作必须包含id条件。",
                        sql=doc.sql,
                    )
            elif not doc.api_endpoint:
                # No endpoint: only PUT (index creation) is meaningful.
                if doc.method == "PUT":
                    if not doc.doc_data_body or (
                        "mappings" in doc.doc_data_body
                        or "settings" in doc.doc_data_body
                    ):
                        result = ReviewResult(
                            id=rowid,
                            errlevel=0,
                            stagestatus="Audit completed",
                            errormessage=f"审核通过。解析结果:创建表:[index_name:{doc.index_name}]",
                            sql=doc.sql,
                        )
                    else:
                        result = ReviewResult(
                            id=rowid,
                            errlevel=2,
                            stagestatus="驳回不支持语句",
                            errormessage="PUT请求创建索引时请求体可以为空或需要包含mappings或settings。",
                            sql=doc.sql,
                        )
                elif doc.method == "POST":
                    result = ReviewResult(
                        id=rowid,
                        errlevel=2,
                        stagestatus="驳回不支持语句",
                        errormessage=f"POST请求必须指定API端点,例如_doc。解析结果:{doc_desc}",
                        sql=doc.sql,
                    )
                else:
                    result = ReviewResult(
                        id=rowid,
                        errlevel=2,
                        stagestatus="驳回不支持语句",
                        errormessage=f"不支持此操作。解析结果:{doc_desc}",
                        sql=doc.sql,
                        affected_rows=0,
                        execute_time=0,
                    )
            elif doc.api_endpoint == "_doc":
                if doc.method == "PUT" and not doc.doc_id:
                    result = ReviewResult(
                        id=rowid,
                        errlevel=2,
                        stagestatus="驳回不支持语句",
                        errormessage="PUT请求必须包含文档Id。",
                        sql=doc.sql,
                    )
                elif doc.method in ("PUT", "POST"):
                    is_pass = True
                else:
                    result = ReviewResult(
                        id=rowid,
                        errlevel=2,
                        stagestatus="驳回不支持语句",
                        errormessage=f"不支持此操作。解析结果:{doc_desc}",
                        sql=doc.sql,
                        affected_rows=0,
                        execute_time=0,
                    )
            elif doc.api_endpoint == "_search":
                result = ReviewResult(
                    id=rowid,
                    errlevel=2,
                    stagestatus="驳回不支持语句",
                    errormessage="_search属于查询方法。",
                    sql=doc.sql,
                )
            elif doc.api_endpoint == "_update":
                if doc.method != "POST":
                    result = ReviewResult(
                        id=rowid,
                        errlevel=2,
                        stagestatus="驳回不支持语句",
                        errormessage=f"不支持此操作,{doc.api_endpoint}需要使用POST方法。解析结果:{doc_desc}",
                        sql=doc.sql,
                        affected_rows=0,
                        execute_time=0,
                    )
                elif not doc.doc_id:
                    result = ReviewResult(
                        id=rowid,
                        errlevel=2,
                        stagestatus="驳回不支持语句",
                        errormessage=f"POST请求{doc.api_endpoint}时必须包含文档Id。",
                        sql=doc.sql,
                    )
                else:
                    is_pass = True
            elif doc.api_endpoint == "_update_by_query":
                if doc.method == "POST":
                    is_pass = True
                else:
                    result = ReviewResult(
                        id=rowid,
                        errlevel=2,
                        stagestatus="驳回不支持语句",
                        errormessage=f"不支持此操作,{doc.api_endpoint}需要使用POST方法。解析结果:{doc_desc}",
                        sql=doc.sql,
                        affected_rows=0,
                        execute_time=0,
                    )
            else:
                # Every remaining endpoint is non-empty and not one of those
                # handled above, so it is unsupported.
                result = ReviewResult(
                    id=rowid,
                    errlevel=2,
                    stagestatus="驳回不支持语句",
                    errormessage="API操作端点(API Endpoint)仅支持: 空, _doc、_update、_update_by_query。",
                    sql=doc.sql,
                )
            # Shared "passed review" row for the branches that only set is_pass.
            if is_pass:
                result = ReviewResult(
                    id=rowid,
                    errlevel=0,
                    stagestatus="Audit completed",
                    errormessage=f"审核通过。解析结果:{doc_desc}",
                    sql=doc.sql,
                    affected_rows=0,
                    execute_time=0,
                )
            check_result.rows.append(result)

        # Tally warnings (errlevel 1) and errors (errlevel 2).
        for r in check_result.rows:
            if r.errlevel == 1:
                check_result.warning_count += 1
            if r.errlevel == 2:
                check_result.error_count += 1
        return check_result

    def execute_workflow(self, workflow):
        """Execute a workflow's statements in order and return a ReviewSet.

        Statements come from ``self.__split_sql`` and are dispatched on their
        parsed form:

        * ``#`` comment lines are skipped.
        * DELETE goes to ``__delete_data``.
        * An empty API endpoint goes to ``__create_index``.
        * ``_update``, ``_update_by_query`` and ``_doc`` go to their handlers.
        * Any other endpoint raises.

        Execution stops at the first error. A failed row is recorded, and the
        remaining statements are listed as not executed.

        :param workflow: provides ``sqlworkflowcontent.sql_content`` (the
            statements) and ``db_name`` (passed to ``get_connection``).
        :return: ReviewSet with one row per statement; ``error`` is set on failure.
        """
        sql = workflow.sqlworkflowcontent.sql_content
        docs = self.__split_sql(sql)
        execute_result = ReviewSet(full_sql=sql)
        line = 0
        failed = False
        # Pre-bound so the except block can tell whether any statement was reached.
        # A failure in get_connection happens before the loop assigns `doc`.
        doc = None
        try:
            conn = self.get_connection(db_name=workflow.db_name)
            for doc in docs:
                line += 1
                if re.match(r"^#", doc.sql, re.I):
                    review_result = ReviewResult(
                        id=line,
                        errlevel=0,
                        stagestatus="Execute Successfully",
                        errormessage="注释信息不需要执行。",
                        sql=doc.sql,
                        affected_rows=0,
                        execute_time=0,
                    )
                elif doc.method == "DELETE":
                    review_result = self.__delete_data(conn, doc)
                elif doc.api_endpoint == "":
                    # No endpoint: create the index.
                    review_result = self.__create_index(conn, doc)
                elif doc.api_endpoint == "_update":
                    review_result = self.__update(conn, doc)
                elif doc.api_endpoint == "_update_by_query":
                    review_result = self.__update_by_query(conn, doc)
                elif doc.api_endpoint == "_doc":
                    review_result = self.__add_or_update(conn, doc)
                else:
                    raise Exception(f"不支持的API类型:{doc.api_endpoint}")
                review_result.id = line
                execute_result.rows.append(review_result)
        except Exception as e:
            failed = True
            failed_sql = doc.sql if doc is not None else sql
            logger.warning(
                f"ES命令执行报错,语句:{failed_sql}, 错误信息:{traceback.format_exc()}"
            )
            # Record the failing statement. Fall back to repr() so that an exception
            # with an empty message still leaves `error` truthy.
            execute_result.error = str(e) or repr(e)
            execute_result.rows.append(
                ReviewResult(
                    id=line,
                    errlevel=2,
                    stagestatus="Execute Failed",
                    errormessage=f"异常信息:{e}",
                    sql=failed_sql,
                    affected_rows=0,
                    execute_time=0,
                )
            )
        if failed:
            # List every statement after the failing one as not executed.
            for remaining in docs[line:]:
                line += 1
                execute_result.rows.append(
                    ReviewResult(
                        id=line,
                        errlevel=0,
                        stagestatus="Audit completed",
                        errormessage="前序语句失败, 未执行",
                        sql=remaining.sql,
                        affected_rows=0,
                        execute_time=0,
                    )
                )
        return execute_result

    def __update(self, conn, doc):
        """Apply a partial update (``_update``) to one document.

        An error whose text contains "NotFoundError" (document missing) is
        downgraded to a warning row with errlevel 1 and 0 affected rows. Any other
        error propagates.

        :return: ReviewResult carrying the stringified response (or error) and,
            as affected_rows, the response's ``_shards.successful`` count.
        """
        level = 0
        with FuncTimer() as t:
            try:
                response = conn.update(
                    index=doc.index_name,
                    id=doc.doc_id,
                    body=doc.doc_data_body,
                )
                shards = response.get("_shards", {})
                affected = shards.get("successful", None)
                message = str(response)
            except Exception as e:
                error_text = str(e)
                if "NotFoundError" not in error_text:
                    raise
                message = "document missing: " + error_text
                affected = 0
                level = 1
        return ReviewResult(
            errlevel=level,
            stagestatus="Execute Successfully",
            errormessage=message,
            sql=doc.sql,
            affected_rows=affected,
            execute_time=t.cost,
        )

    def __add_or_update(self, conn, doc):
        """
        Index (create or overwrite) a document through the ``_doc`` endpoint.

        Args:
            conn: the search client; ``conn.index(index=..., id=..., body=...)`` is called.
            doc: parsed document; only ``api_endpoint == "_doc"`` is accepted.

        Returns:
            ReviewResult (errlevel 0) with the raw response text and the
            ``_shards.successful`` count as affected rows.

        Raises:
            Exception: for any endpoint other than ``_doc``; client errors propagate.
        """
        with FuncTimer() as timer:
            if doc.api_endpoint != "_doc":
                raise Exception(f"不支持的API类型:{doc.api_endpoint}")
            resp = conn.index(
                index=doc.index_name,
                id=doc.doc_id,
                body=doc.doc_data_body,
            )
            shard_info = resp.get("_shards", {})
            affected = shard_info.get("successful", None)
            summary = str(resp)
        return ReviewResult(
            errlevel=0,
            stagestatus="Execute Successfully",
            errormessage=summary,
            sql=doc.sql,
            affected_rows=affected,
            execute_time=timer.cost,
        )

    def __update_by_query(self, conn, doc):
        """
        Run ``update_by_query`` on ``doc.index_name`` using ``doc.doc_data_body``.

        Args:
            conn: the search client; ``conn.update_by_query(index=..., body=...)`` is called.
            doc: parsed document supplying the index name and request body.

        Returns:
            ReviewResult (errlevel 0) whose affected_rows is the response's
            ``total`` field (0 when absent) and whose message is the raw response.

        Raises:
            Any exception from the client propagates unchanged.
        """
        with FuncTimer() as t:
            response = conn.update_by_query(
                index=doc.index_name, body=doc.doc_data_body
            )
            successful_count = response.get("total", 0)
            response_str = str(response)
        return ReviewResult(
            errlevel=0,
            stagestatus="Execute Successfully",
            errormessage=response_str,
            sql=doc.sql,
            affected_rows=successful_count,
            execute_time=t.cost,
        )

    def __create_index(self, conn, doc):
        """
        Create the index ``doc.index_name`` with ``doc.doc_data_body`` as its definition.

        Args:
            conn: the search client; ``conn.indices.create(index=..., body=...)`` is called.
            doc: parsed document supplying the index name and request body.

        Returns:
            ReviewResult with 0 affected rows and the raw response text. If the
            client error text contains "already_exists", the result is a warning
            (errlevel 1) instead of an error.

        Raises:
            Any other client error is re-raised unchanged.
        """
        level = 0
        affected = 0
        with FuncTimer() as timer:
            try:
                resp = conn.indices.create(index=doc.index_name, body=doc.doc_data_body)
                summary = str(resp)
            except Exception as exc:
                message = str(exc)
                # An existing index is reported, not treated as a failure.
                if "already_exists" not in message:
                    raise
                summary = "index already exists: " + message
                level = 1

        return ReviewResult(
            errlevel=level,
            stagestatus="Execute Successfully",
            errormessage=summary,
            sql=doc.sql,
            affected_rows=affected,
            execute_time=timer.cost,
        )

    def __delete_data(self, conn, doc):
        """
        Delete a single document by id.

        Args:
            conn: the search client; ``conn.delete(index=..., id=...)`` is called.
            doc: parsed document supplying ``index_name`` and ``doc_id``.

        Returns:
            ReviewResult with the raw response text and the ``_shards.successful``
            count. If the client error text contains "NotFoundError", the result is
            downgraded to a warning (errlevel 1) with 0 affected rows.

        Raises:
            ValueError: ``doc.doc_id`` is empty. Deletes without an id are refused
                before anything is sent to the cluster.
            Exception: any other client error is re-raised unchanged.
        """
        if not doc.doc_id:
            # Refuse id-less deletes up front instead of calling conn.delete(id=None).
            raise ValueError("删除操作必须包含id条件。")
        errlevel = 0
        with FuncTimer() as t:
            try:
                response = conn.delete(index=doc.index_name, id=doc.doc_id)
                successful_count = response.get("_shards", {}).get("successful", None)
                response_str = str(response)
            except Exception as e:
                error_message = str(e)
                if "NotFoundError" in error_message:
                    response_str = "Document not found: " + error_message
                    successful_count = 0
                    errlevel = 1
                else:
                    raise
        return ReviewResult(
            errlevel=errlevel,
            stagestatus="Execute Successfully",
            errormessage=response_str,
            sql=doc.sql,
            affected_rows=successful_count,
            execute_time=t.cost,
        )

    def __get_document_from_sql(self, sql):
        """
        Parse one command into an ElasticsearchDocument.

        Recognised form: ``POST|PUT|DELETE <path> [json body]``. The method is
        matched case-insensitively and must be followed by a space. The path is
        split on ``/`` as ``/index_name/api_endpoint/doc_id``, and missing trailing
        segments become ``None``.

        A document holding only ``sql`` is returned in two cases: the command
        does not start with one of those methods, or the text after the method
        starts with ``{`` (a body with no path).

        Raises:
            ValueError: a body is present but is not valid JSON.
        """
        result = ElasticsearchDocument(sql=sql)
        if re.match(r"^POST |^PUT |^DELETE ", sql, re.I):

            # Split off the HTTP method; the rest is "<path> [body]".
            method, path_with_params = sql.split(maxsplit=1)
            if path_with_params.startswith("{"):
                # Starts with "{": there is no path segment.
                return result
            # Normalise the path so it always starts with "/".
            if not path_with_params.startswith("/"):
                path_with_params = "/" + path_with_params

            parts = path_with_params.split(maxsplit=1)
            path = parts[0]
            doc_data_body = parts[1].strip() if len(parts) > 1 else None

            # "/idx/_doc/1".split("/") -> ["", "idx", "_doc", "1"]
            path_parts = path.split("/")
            index_name = path_parts[1] if len(path_parts) > 1 else None
            api_endpoint = path_parts[2] if len(path_parts) > 2 else None
            doc_id = path_parts[3] if len(path_parts) > 3 else None
            doc_data_json = None
            if doc_data_body:
                try:
                    doc_data_json = json.loads(doc_data_body)
                except json.JSONDecodeError as json_err:
                    # Chain the decode error so its position details stay in the traceback.
                    raise ValueError(
                        f"无法转为Json格式。{json_err}。doc_data_body:{doc_data_body}。"
                    ) from json_err
            result = ElasticsearchDocument(
                sql=sql,
                method=method,
                index_name=index_name,
                api_endpoint=api_endpoint,
                doc_id=doc_id,
                doc_data_body=doc_data_json,
            )
        return result

    def __split_sql(self, sql):
        """
        Split a multi-line command script into a list of ElasticsearchDocument objects.

        Blank lines and lines starting with ``#`` are ignored. Consecutive lines
        are accumulated until the running count of ``{`` minus ``}`` is zero, and
        each such group becomes one command. A command that starts with ``{`` is
        then appended to the previous command, which handles a body written on the
        lines after its request line. Each command is parsed by
        ``__get_document_from_sql``.

        Note: braces are counted naively, so a brace inside a JSON string value
        also affects the grouping.
        """
        lines = sql.strip().splitlines()
        commands = []
        current_command = []
        brace_level = 0

        for line in lines:
            stripped_line = line.strip()
            if not stripped_line or stripped_line.startswith("#"):
                continue

            brace_level += stripped_line.count("{")
            brace_level -= stripped_line.count("}")
            current_command.append(stripped_line)

            if brace_level == 0:
                commands.append(os.linesep.join(current_command))
                current_command = []

        # Unbalanced braces leave text behind. Keep it instead of silently dropping
        # the statement, so the malformed command still reaches the parser.
        if current_command:
            commands.append(os.linesep.join(current_command))

        merged_commands = []
        for command in commands:
            if command.startswith("{") and merged_commands:
                # A bare body belongs to the request line before it.
                merged_commands[-1] += os.linesep + command
            else:
                merged_commands.append(command)

        documents = []
        for command in merged_commands:
            doc = self.__get_document_from_sql(command)
            if doc:
                documents.append(doc)
        return documents


class ElasticsearchEngine(ElasticsearchEngineBase):
    """Elasticsearch engine implementation."""

    def __init__(self, instance=None):
        super().__init__(instance=instance)

    name: str = "Elasticsearch"
    info: str = "Elasticsearch 引擎"

    def get_connection(self, db_name=None):
        """Return the Elasticsearch client, creating it on first use.

        The client is cached on ``self.conn``. When it has to be created:
        - host, port and credentials come from this engine;
        - the SSL flags come from ``self.instance``;
        - ``self.db_name`` gets ``*`` appended.

        Args:
            db_name: accepted for interface compatibility; not used here.

        Raises:
            Exception: the client constructor failed (the original error is
                chained as ``__cause__``), or no client exists and none could be
                created (e.g. no instance).
        """
        if self.conn:
            return self.conn
        if self.instance:
            scheme = "https" if self.instance.is_ssl else "http"
            hosts = [
                {
                    "host": self.host,
                    "port": self.port,
                    "scheme": scheme,
                    "use_ssl": self.instance.is_ssl,
                }
            ]
            # Basic auth only when both user and password are configured.
            http_auth = (
                (self.user, self.password) if self.user and self.password else None
            )
            # NOTE(review): appended on every client creation. If self.conn is ever
            # reset and this runs again, the name becomes "name**" - confirm intent.
            self.db_name = (self.db_name or "") + "*"
            try:
                # Newer client versions also accept basic_auth; http_auth is used here.
                self.conn = Elasticsearch(
                    hosts=hosts,
                    http_auth=http_auth,
                    verify_certs=self.instance.verify_ssl,
                )
            except Exception as e:
                raise Exception(f"Elasticsearch 连接建立失败: {str(e)}") from e
        if not self.conn:
            raise Exception("Elasticsearch 连接无法建立。")
        return self.conn

    def _security_role(self, sql, query_params: QueryParamsSearch):
        """Role listing; not implemented for Elasticsearch yet."""
        raise NotImplementedError("此方法暂未实现。")

    def _security_user(self, sql, query_params: QueryParamsSearch):
        """User listing; not implemented for Elasticsearch yet."""
        raise NotImplementedError("此方法暂未实现。")


class OpenSearchEngine(ElasticsearchEngineBase):
    """OpenSearch engine implementation."""

    def __init__(self, instance=None):
        self.conn = None  # type: OpenSearch  # type comment: hints the client type for self.conn
        super().__init__(instance=instance)

    name: str = "OpenSearch"
    info: str = "OpenSearch 引擎"

    def get_connection(self, db_name=None):
        """Return the OpenSearch client, creating it on first use.

        The client is cached on ``self.conn``. When it has to be created:
        - host, port and credentials come from this engine;
        - the SSL flags come from ``self.instance``;
        - ``self.db_name`` gets ``*`` appended.

        Args:
            db_name: accepted for interface compatibility; not used here.

        Raises:
            Exception: the client constructor failed (the original error is
                chained as ``__cause__``), or no client exists and none could be
                created (e.g. no instance).
        """
        if self.conn:
            return self.conn
        if self.instance:
            scheme = "https" if self.instance.is_ssl else "http"
            hosts = [
                {
                    "host": self.host,
                    "port": self.port,
                    "scheme": scheme,
                    "use_ssl": self.instance.is_ssl,
                }
            ]
            # Basic auth only when both user and password are configured.
            http_auth = (
                (self.user, self.password) if self.user and self.password else None
            )
            # NOTE(review): appended on every client creation. If self.conn is ever
            # reset and this runs again, the name becomes "name**" - confirm intent.
            self.db_name = (self.db_name or "") + "*"

            try:
                self.conn = OpenSearch(
                    hosts=hosts,
                    http_auth=http_auth,
                    verify_certs=self.instance.verify_ssl,
                )
            except Exception as e:
                raise Exception(f"OpenSearch 连接建立失败: {str(e)}") from e
        if not self.conn:
            raise Exception("OpenSearch 连接无法建立。")
        return self.conn

    def _fetch_security_entities(self, sql, url, key_column):
        """GET a security-plugin listing endpoint and tabulate it into a ResultSet.

        Args:
            sql: the original statement, stored on the ResultSet as ``full_sql``.
            url: endpoint path passed to ``transport.perform_request("GET", ...)``.
            key_column: header of the first column, which holds each entry's name.

        Returns:
            ResultSet with one row ``[name, value, ...]`` per entry.
            - The response is consumed as a dict of name -> attribute dict
              (assumed from how the endpoints are used - confirm against the
              plugin's API). Non-dict entry values are treated as empty.
            - Columns are the union of all entries' attribute keys in first-seen
              order, so an attribute absent from the first entry is still shown.
            - A missing attribute is ``None``; list/dict values are JSON-encoded.
            - A falsy or non-dict response leaves rows/columns unset.

        Raises:
            Exception: wraps any failure as "执行查询时出错: ..." with the cause chained.
        """
        result_set = ResultSet(full_sql=sql)
        try:
            conn = self.get_connection()
            response = conn.transport.perform_request("GET", url, body={})
            if response and isinstance(response, dict):
                entries = {
                    entry_name: (info if isinstance(info, dict) else {})
                    for entry_name, info in response.items()
                }
                # Ordered union of attribute keys across all entries.
                attribute_keys = list(
                    dict.fromkeys(key for info in entries.values() for key in info)
                )
                formatted_rows = []
                for entry_name, info in entries.items():
                    row = [entry_name]
                    for column in attribute_keys:
                        value = info.get(column, None)
                        # Nested structures are shown as JSON text.
                        if isinstance(value, (list, dict)):
                            row.append(json.dumps(value))
                        else:
                            row.append(value)
                    formatted_rows.append(row)
                result_set.rows = formatted_rows
                result_set.column_list = [key_column] + attribute_keys
        except Exception as e:
            raise Exception(f"执行查询时出错: {str(e)}") from e
        return result_set

    def _security_role(self, sql, query_params: QueryParamsSearch):
        """List security roles; the first column is ``role_name``."""
        return self._fetch_security_entities(
            sql, "/_opendistro/_security/api/roles", "role_name"
        )

    def _security_user(self, sql, query_params: QueryParamsSearch):
        """List security users; the first column is ``user_name``."""
        return self._fetch_security_entities(
            sql, "/_opendistro/_security/api/user", "user_name"
        )