import re

import logging
import traceback

from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from cassandra.query import tuple_factory
from cassandra.policies import RoundRobinPolicy

import sqlparse

from . import EngineBase
from .models import ResultSet, ReviewSet, ReviewResult

from sql.models import SqlWorkflow

logger = logging.getLogger("default")


def split_sql(db_name=None, sql=""):
    """Strip comments from ``sql`` and split it into individual statements.

    When ``db_name`` is truthy, a ``USE <db_name>`` statement is placed first so
    that the remaining statements run against that keyspace.

    :param db_name: optional keyspace name to switch to before the statements.
    :param sql: raw SQL/CQL text, possibly containing several statements.
    :return: list of statement strings.
    """
    cleaned = sqlparse.format(sql, strip_comments=True)
    leading = [f"""USE {db_name}"""] if db_name else []
    return [*leading, *sqlparse.split(cleaned)]


def dummy_audit(full_sql: str, sql_list) -> ReviewSet:
    """Build a ReviewSet in which every statement passes the audit.

    No real checks are performed: each statement gets ``errlevel=0`` and
    sequential 1-based ids.

    :param full_sql: the complete submitted SQL, stored on the ReviewSet.
    :param sql_list: iterable of individual statements to report on.
    :return: ReviewSet with one ReviewResult per statement.
    """
    review = ReviewSet(full_sql=full_sql)
    for seq, stmt in enumerate(sql_list, start=1):
        row = ReviewResult(
            id=seq,
            errlevel=0,
            stagestatus="Audit completed",
            errormessage="None",
            sql=stmt,
            affected_rows=0,
            execute_time=0,
        )
        review.rows.append(row)
    return review


class CassandraEngine(EngineBase):
    """Archery engine adapter for Cassandra.

    ``self.conn`` caches the driver session returned by ``Cluster.connect``.
    Result rows are plain tuples because ``row_factory`` is set to
    ``tuple_factory``.
    """

    name = "Cassandra"
    info = "Cassandra engine"

    def get_connection(self, db_name=None):
        """Return the cached session, creating it on first use.

        :param db_name: keyspace to use; falls back to ``self.db_name``.
        :return: the driver session stored in ``self.conn``.
        """
        db_name = db_name or self.db_name
        if self.conn:
            # Reuse the cached session and switch keyspace if one was requested.
            if db_name:
                self.conn.execute(f"use {db_name}")
            return self.conn
        auth_provider = PlainTextAuthProvider(
            username=self.user, password=self.password
        )
        hosts = self.host.split(",")
        cluster = Cluster(
            hosts,
            port=self.port,
            auth_provider=auth_provider,
            load_balancing_policy=RoundRobinPolicy(),
            protocol_version=5,
        )
        try:
            session = cluster.connect(keyspace=db_name)
        except Exception:
            # connect() failed: release the cluster's resources instead of
            # leaking them, then let the caller see the original error.
            cluster.shutdown()
            raise
        session.row_factory = tuple_factory
        self.conn = session
        return self.conn

    def close(self):
        """Shut down the cached session and the cluster that owns it."""
        if not self.conn:
            return
        # Drop the reference first so a half-closed session is never reused.
        session, self.conn = self.conn, None
        try:
            session.shutdown()
        finally:
            # Session.shutdown() does not stop the parent Cluster (control
            # connection, executor threads); shut it down explicitly.
            cluster = getattr(session, "cluster", None)
            if cluster is not None:
                cluster.shutdown()

    def test_connection(self):
        """Check connectivity by listing keyspaces, then release the connection."""
        result = self.get_all_databases()
        self.close()
        return result

    def escape_string(self, value: str) -> str:
        """Remove semicolons and spaces from an identifier before it is put into CQL."""
        return re.sub(r"[; ]", "", value)

    def get_all_databases(self, **kwargs):
        """Return all keyspace names.

        :return: ResultSet whose ``rows`` is a flat list of keyspace names.
        """
        result = self.query(sql="SELECT keyspace_name FROM system_schema.keyspaces;")
        result.rows = [x[0] for x in result.rows]
        return result

    def get_all_columns_by_tb(self, db_name, tb_name, **kwargs):
        """Return the column names of ``db_name.tb_name``.

        The names are read from ``system_schema.columns`` using bound parameters.

        :return: ResultSet whose ``rows`` is a flat list of column names.
        """
        sql = "select column_name from columns where keyspace_name=%s and table_name=%s"
        result = self.query(
            db_name="system_schema", sql=sql, parameters=(db_name, tb_name)
        )
        result.rows = [x[0] for x in result.rows]
        return result

    def describe_table(self, db_name, tb_name, **kwargs):
        """Return the ``DESCRIBE TABLE`` output for ``tb_name`` as two columns.

        Each row keeps fields 2 and 3 of the driver's DESCRIBE output, labelled
        "table" and "create table" (presumably name and create statement --
        confirm against the driver's DESCRIBE column layout).
        """
        # tb_name is interpolated into CQL, so sanitise it like other identifiers.
        sql = f"describe table {self.escape_string(tb_name)}"
        result = self.query(db_name=db_name, sql=sql)
        result.column_list = ["table", "create table"]
        result.rows = [(row[2], row[3]) for row in result.rows]
        return result

    def query_check(self, db_name=None, sql="", limit_num: int = 100):
        """Validate an ad-hoc query before it is submitted.

        Comments are stripped and only the first statement is kept. That
        statement must start with SELECT or DESCRIBE.

        :param limit_num: unused; kept for interface compatibility.
        :return: dict with keys ``msg``, ``bad_query``, ``filtered_sql``,
            ``has_star``.
        """
        result = {"msg": "", "bad_query": False, "filtered_sql": sql, "has_star": False}
        try:
            sql = sqlparse.format(sql, strip_comments=True)
            sql = sqlparse.split(sql)[0]
            result["filtered_sql"] = sql.strip()
        except IndexError:
            result["bad_query"] = True
            result["msg"] = "没有有效的SQL语句"
            # Nothing left to inspect. Return now so this message is not
            # overwritten by the syntax check below.
            return result
        if re.match(r"^select|^describe", sql, re.I) is None:
            result["bad_query"] = True
            result["msg"] = "不支持的查询语法类型!"
        if "*" in sql:
            result["has_star"] = True
            # The reason a query is rejected takes priority over the star warning.
            if not result["bad_query"]:
                result["msg"] = "SQL语句中含有 * "
        return result

    def filter_sql(self, sql="", limit_num=0) -> str:
        """Cap a SELECT at ``limit_num`` rows and terminate it with ``;``.

        A trailing ``LIMIT n`` is rewritten to ``limit min(n, limit_num)``.
        Otherwise ``limit <limit_num>`` is appended. Other statements only get
        the terminator. ``LIMIT n, m`` and ``OFFSET`` forms are not rewritten.
        """
        # Strip whitespace before and after removing semicolons so that
        # "select 1 ; " does not become "select 1 ; limit N;".
        sql = sql.strip().rstrip(";").strip()
        if not re.match(r"^select", sql, re.I):
            return f"{sql};"
        limit_n = re.compile(r"limit\s+(\d+)\s*$", re.I)
        match = limit_n.search(sql)
        if match:
            limit_num = min(int(limit_num), int(match.group(1)))
            return limit_n.sub(f"limit {limit_num};", sql)
        return f"{sql} limit {limit_num};"

    def query(
        self,
        db_name=None,
        sql="",
        limit_num=0,
        close_conn=True,
        parameters=None,
        **kwargs,
    ):
        """Run one statement and wrap its output in a ResultSet.

        Errors are logged and stored in ``result_set.error`` instead of being
        raised.

        :param limit_num: when > 0, truncate the returned rows to this many.
        :param close_conn: close the session afterwards (default True).
        :param parameters: bound values passed to ``Session.execute``.
        :return: ResultSet.
        """
        result_set = ResultSet(full_sql=sql)
        try:
            conn = self.get_connection(db_name=db_name)
            rows = conn.execute(sql, parameters=parameters)
            result_set.column_list = rows.column_names
            result_set.rows = rows.all()
            result_set.affected_rows = len(result_set.rows)
            if limit_num > 0:
                result_set.rows = result_set.rows[0:limit_num]
                result_set.affected_rows = min(limit_num, result_set.affected_rows)
        except Exception as e:
            logger.warning(
                f"{self.name} query 错误,语句:{sql}, 错误信息:{traceback.format_exc()}"
            )
            result_set.error = str(e)
        if close_conn:
            self.close()
        return result_set

    def get_all_tables(self, db_name, **kwargs):
        """Return the table names in keyspace ``db_name``.

        :return: ResultSet whose ``rows`` is a flat list of table names.
        """
        sql = "SELECT table_name FROM system_schema.tables WHERE keyspace_name = %s;"
        result = self.query(db_name=db_name, sql=sql, parameters=[db_name])
        result.rows = [row[0] for row in result.rows]
        return result

    def query_masking(self, db_name=None, sql="", resultset=None):
        """Return ``resultset`` unchanged; no data masking is applied."""
        return resultset

    def execute_check(self, db_name=None, sql=""):
        """Pre-execution check for a workflow; every statement passes.

        :return: ReviewSet produced by ``dummy_audit``.
        """
        sql_result = split_sql(db_name, sql)
        return dummy_audit(sql, sql_result)

    def execute(self, db_name=None, sql="", close_conn=True, parameters=None):
        """Execute the statements in ``sql`` in order, stopping at the first failure.

        ``split_sql`` prepends ``USE <db_name>`` when ``db_name`` is given.
        Statements after a failed one are reported as not executed.

        :param parameters: accepted for interface compatibility; not used.
        :return: ReviewSet with one row per statement (1-based ids);
            ``error`` holds the failing statement's exception message.
        """
        execute_result = ReviewSet(full_sql=sql)
        conn = self.get_connection(db_name=db_name)
        sql_result = split_sql(db_name, sql)
        failed = False
        rowid = 0
        for rowid, statement in enumerate(sql_result, start=1):
            try:
                conn.execute(statement)
            except Exception as e:
                logger.warning(
                    f"{self.name} 命令执行报错,语句:{statement}, 错误信息:{traceback.format_exc()}"
                )
                # An exception with an empty message must still mark the
                # result as failed.
                execute_result.error = str(e) or repr(e)
                execute_result.rows.append(
                    ReviewResult(
                        id=rowid,
                        errlevel=2,
                        stagestatus="Execute Failed",
                        errormessage=f"异常信息:{e}",
                        sql=statement,
                        affected_rows=0,
                        execute_time=0,
                    )
                )
                failed = True
                break
            execute_result.rows.append(
                ReviewResult(
                    id=rowid,
                    errlevel=0,
                    stagestatus="Execute Successfully",
                    errormessage="None",
                    sql=statement,
                    affected_rows=0,
                    execute_time=0,
                )
            )
        if failed:
            # rowid is the failed statement's 1-based id, which equals its index
            # + 1, so sql_result[rowid:] holds the statements that never ran.
            # Their ids continue after the failed one.
            for skipped_id, statement in enumerate(sql_result[rowid:], start=rowid + 1):
                execute_result.rows.append(
                    ReviewResult(
                        id=skipped_id,
                        errlevel=2,
                        stagestatus="Execute Failed",
                        errormessage="前序语句失败, 未执行",
                        sql=statement,
                        affected_rows=0,
                        execute_time=0,
                    )
                )
        if close_conn:
            self.close()
        return execute_result

    def execute_workflow(self, workflow: SqlWorkflow):
        """Execute a workflow's SQL content against its database.

        :return: ReviewSet from ``execute``.
        """
        return self.execute(
            db_name=workflow.db_name, sql=workflow.sqlworkflowcontent.sql_content
        )