"""
-------------------------------------------------------------------------
This file is part of the MindStudio project.
Copyright (c) 2026 Huawei Technologies Co.,Ltd.

MindStudio is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:

         http://license.coscl.org.cn/MulanPSL2

THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""
import os.path
import sqlite3
from typing import Any, Dict, List, Optional, Type, Union, Iterable, get_origin, get_args

# 支持的 Python 类型映射
_PY_TYPE_TO_SQLITE = {
    int: "INTEGER",
    float: "REAL",
    str: "TEXT",
    bool: "INTEGER",  # SQLite 没有 BOOLEAN,用 0/1
    bytes: "BLOB",
}


def _map_py_type_to_sqlite(py_type: Type) -> str:
    """将 Python 类型转换为 SQLite 类型"""
    origin = get_origin(py_type) or py_type
    if origin in _PY_TYPE_TO_SQLITE:
        return _PY_TYPE_TO_SQLITE[origin]
    # 处理 Optional[T] => T(Optional 是 Union[T, None])
    if origin is Union:
        args = get_args(py_type)
        non_none = [t for t in args if t is not type(None)]
        if len(non_none) == 1:
            return _map_py_type_to_sqlite(non_none[0])
    # 默认 fallback
    return "TEXT"


def _sqlite_type_to_py_type(sqlite_type: str) -> type:
    """将 SQLite 声明类型映射为最可能的 Python 类型"""
    if not sqlite_type:
        return str  # 无类型默认为 TEXT
    upper = sqlite_type.upper()
    if "INT" in upper:
        return int
    elif "CHAR" in upper or "CLOB" in upper or "TEXT" in upper:
        return str
    elif "BLOB" in upper:
        return bytes
    elif "REAL" in upper or "FLOA" in upper or "DOUB" in upper:
        return float
    else:
        # 兜底:可能是 NUMERIC 或自定义类型,按 TEXT 处理
        return str


import ast


def _parse_default_value(dflt_str: str) -> Any:
    """将 PRAGMA 返回的默认值字符串转为 Python 对象"""
    if dflt_str is None:
        return None

    # 尝试解析为字面量(支持数字、字符串、布尔、None)
    try:
        # 处理 SQLite 中的布尔:'1'/'0' 或 'true'/'false'(但 SQLite 实际存整数)
        if dflt_str == '1':
            return True
        elif dflt_str == '0':
            return False
        elif dflt_str.lower() in ('true', 'false'):
            return dflt_str.lower() == 'true'
        # 尝试用 ast.literal_eval 安全解析
        return ast.literal_eval(dflt_str)
    except (ValueError, SyntaxError):
        # 如果不是合法字面量,当作字符串处理(去掉外层引号)
        if dflt_str.startswith("'") and dflt_str.endswith("'"):
            return dflt_str[1:-1].replace("''", "'")
        elif dflt_str.startswith('"') and dflt_str.endswith('"'):
            return dflt_str[1:-1].replace('""', '"')
        else:
            return dflt_str


class SqliteColumn:
    def __init__(
            self,
            name: str,
            data_type: Type = str,
            primary_key: bool = False,  # 是否主键
            autoincrement: bool = False,  # 是否自增
            not_null: bool = False,  # 是否不可为空
            unique: bool = False,  # 是否唯一
            default: Optional[Any] = None  # 缺省值
    ):
        if autoincrement and not primary_key:
            raise ValueError("autoincrement requires primary_key=True")
        if autoincrement and data_type is not int:
            raise ValueError("autoincrement only supported for INTEGER type")
        self.name = name
        self.data_type = data_type
        self.primary_key = primary_key
        self.autoincrement = autoincrement
        self.not_null = not_null
        self.unique = unique
        self.default = default

    def _format_default(self) -> str:
        """格式化默认值为 SQL 字面量"""
        val = self.default
        if val is None:
            return "NULL"
        elif isinstance(val, bool):
            return "1" if val else "0"
        elif isinstance(val, str):
            # 转义单引号(简单处理)
            escaped = val.replace("'", "''")
            return f"'{escaped}'"
        elif isinstance(val, (int, float)):
            return str(val)
        else:
            # 兜底:转为字符串并加引号
            return f"'{str(val)}'"

    def to_sql_def(self) -> str:
        parts = [f"`{self.name}`", _map_py_type_to_sqlite(self.data_type)]

        if self.primary_key:
            parts.append("PRIMARY KEY")
        if self.autoincrement:
            parts.append("AUTOINCREMENT")
        if self.not_null:
            parts.append("NOT NULL")
        if self.unique:
            parts.append("UNIQUE")
        if self.default is not None:
            parts.append(f"DEFAULT {self._format_default()}")

        return " ".join(parts)


class SqliteTable:
    name: str
    column_dict: Dict[str, SqliteColumn]

    def __init__(self, table_name: str, columns: Iterable[SqliteColumn] = None):
        self.name = table_name
        self.column_dict = {}
        if columns:
            for column in columns:
                self.column_dict[column.name] = column

    def to_sql_def(self, delete_if_exists: bool = False) -> str:
        """
        生成创建表的 SQL 语句
        :param delete_if_exists: 是否先 DROP TABLE IF EXISTS
        """
        column_defs = [col.to_sql_def() for _, col in self.column_dict.items()]
        create_sql = f"CREATE TABLE {self.name} ({', '.join(column_defs)});"

        if delete_if_exists:
            drop_sql = f"DROP TABLE IF EXISTS {self.name};"
            return f"{drop_sql}\n{create_sql}"
        else:
            # 使用 IF NOT EXISTS 更安全
            create_sql = create_sql.replace("CREATE TABLE", "CREATE TABLE IF NOT EXISTS", 1)
            return create_sql

    def create_table(self, conn: sqlite3.Connection, delete_if_exists: bool = False):
        """
        在数据库中创建表
        :param conn: sqlite3.Connection 对象
        :param delete_if_exists: 是否先删除已存在的表
        """
        sql = self.to_sql_def(delete_if_exists=delete_if_exists)
        conn.executescript(sql)  # 支持多条 SQL(如 DROP + CREATE)
        conn.commit()

    def insert_record(self, conn: sqlite3.Connection, record: Dict[str, Any]):
        """插入单条记录"""
        columns = SqliteTable.get_insert_columns_by_record(record)
        placeholders = SqliteTable.get_insert_placeholder_by_record(record)
        sql = f"INSERT INTO {self.name} ({', '.join(columns)}) VALUES ({placeholders})"
        conn.execute(sql, tuple(record.values()))
        conn.commit()

    def insert_records(self, conn: sqlite3.Connection, records: List[Dict[str, Any]]):
        """批量插入多条记录"""
        if not records:
            return
        columns = SqliteTable.get_insert_columns_by_record(records[0])
        placeholders = SqliteTable.get_insert_placeholder_by_record(records[0])
        sql = f"INSERT INTO {self.name} ({', '.join(columns)}) VALUES ({placeholders})"
        values = SqliteTable.get_insert_values_by_records(records)
        conn.executemany(sql, values)
        conn.commit()

    @staticmethod
    def get_insert_columns_by_record(record: Dict[str, Any]):
        return [f"`{key}`" for key in record.keys()]

    @staticmethod
    def get_insert_placeholder_by_record(record: Dict[str, Any]):
        return ', '.join(['?' for _ in record.keys()])

    @staticmethod
    def get_insert_values_by_records(records: List[Dict[str, Any]]):
        if not records:
            return []
        return [tuple(r[k] for k in records[0].keys()) for r in records]


class SqliteDB:
    path: str
    conn: sqlite3.Connection
    table_cache: Dict[str, SqliteTable]

    def __init__(self, path: str, auto_create: bool = True):
        self.path = os.path.realpath(path)
        if not os.path.exists(self.path):
            if not auto_create:
                raise FileNotFoundError(f"Db file not found: {self.path}.")
            dir_path = os.path.dirname(self.path)
            if not os.path.exists(dir_path):
                os.makedirs(dir_path, exist_ok=True)

        self.conn = sqlite3.connect(self.path)
        self.table_cache = {}

    def create_table(self, table: SqliteTable, delete_if_exists: bool = True):
        table.create_table(self.conn, delete_if_exists)
        self.table_cache[table.name] = table

    def is_table_exists(self, table_name: str) -> bool:
        if table_name in self.table_cache:
            return True
        cursor = self.conn.cursor()
        cursor.execute("""
                       SELECT name
                       FROM sqlite_master
                       WHERE type = 'table'
                         AND name = ?
                       """, (table_name,))
        exist = cursor.fetchone() is not None
        if exist:
            self.table_cache[table_name] = self.get_table_by_name(table_name)
        return exist

    def get_table_by_name(self, table_name: str) -> SqliteTable:
        """
            从数据库中读取表结构,还原为 SqliteTable 对象
        """
        if table_name in self.table_cache:
            return self.table_cache[table_name]
        # 获取列信息
        cur = self.conn.execute(f"PRAGMA table_info({table_name});")
        rows = cur.fetchall()

        if not rows:
            raise ValueError(f"Table '{table_name}' does not exist.")
        table = SqliteTable(table_name)
        for row in rows:
            cid, name, type_affinity, notnull, dflt_value, pk = row

            py_type = _sqlite_type_to_py_type(type_affinity)
            default_val = _parse_default_value(dflt_value)

            # 检测是否为自增(仅当 INTEGER 主键且有 sqlite_sequence 记录)
            autoincrement = False
            if pk and py_type is int:
                # 检查是否存在 sqlite_sequence 表且包含该表
                seq_cur = self.conn.execute(
                    "SELECT 1 FROM sqlite_master WHERE type='table' AND name='sqlite_sequence';"
                )
                if seq_cur.fetchone():
                    seq_cur = self.conn.execute(
                        "SELECT 1 FROM sqlite_sequence WHERE name = ?;", (table_name,)
                    )
                    autoincrement = seq_cur.fetchone() is not None

            column = SqliteColumn(
                name=name,
                data_type=py_type,
                primary_key=bool(pk),
                autoincrement=autoincrement,
                not_null=bool(notnull),
                default=default_val,
                # 注意:UNIQUE、COLLATE 无法从 table_info 获取,需解析 CREATE SQL
            )
            table.column_dict[column.name] = column
        self.table_cache[table_name] = table
        return table