/*
 * driver.cj - 数据库驱动模块
 *
 * 提供数据库驱动抽象层:
 * - Driver 接口定义
 * - DriverManager 驱动管理
 * - 连接工厂
 * - SQL 方言(Dialect)与方言注册表
 * - SPI 加载机制
 *
 * 设计参考 JDBC Driver 架构。
 */

package tybb2026::tycj_orm

import std.collection.*
import std.sync.*

// ============================================================================
// 数据库驱动接口
// ============================================================================

/**
 * Driver - database driver interface.
 *
 * Every database driver must implement this interface so that it can be
 * registered with DriverManager.
 */
public interface Driver {
    // Returns the driver's name (DriverManager uses it as the driver's identity).
    func getName(): String
    
    // Returns the driver's version string.
    func getVersion(): String
    
    // Returns the database type this driver supports.
    func getSupportedDatabaseType(): DatabaseType
    
    // Opens a connection described by `config`. Returning None lets
    // DriverManager.getConnection fall through to the next registered driver.
    func connect(config: DatabaseConfig): Option<Connection>
    
    // Returns whether this driver accepts the given connection URL.
    // NOTE(review): DriverManager.getConnection in this file does not consult
    // acceptsURL; it simply tries connect() on each driver.
    func acceptsURL(url: String): Bool
}

/**
 * DriverInfo - metadata record for one driver.
 *
 * Carries the name, version and database type reported by a driver, plus an
 * optional handle to the driver instance itself.
 */
public class DriverInfo {
    // Human-readable driver name; empty until set.
    public var name: String = ""
    // Driver version string; empty until set.
    public var version: String = ""
    // Database type served by the driver; defaults to MySql.
    public var dbType: DatabaseType = DatabaseType.MySql
    // The driver instance, if one has been attached.
    public var driver: Option<Driver> = None
    
    // Builds a blank record that relies entirely on the field defaults above.
    public init() {}
    
    // Builds a record from explicit metadata; no driver instance is attached.
    public init(name: String, version: String, dbType: DatabaseType) {
        this.dbType = dbType
        this.version = version
        this.name = name
    }
}

// ============================================================================
// 驱动管理器
// ============================================================================

/**
 * DriverManager - registry of database drivers.
 *
 * Keeps the list of registered drivers, guarded by a ReentrantMutex, and
 * obtains connections by asking each registered driver in registration order.
 * A driver's identity is its getName() value.
 */
public class DriverManager {
    private static var registeredDrivers: ArrayList<DriverInfo> = ArrayList<DriverInfo>()
    private static var lock: ReentrantMutex = ReentrantMutex()
    
    /**
     * Registers `driver`.
     *
     * If a driver with the same name is already registered, its entry is
     * replaced in place rather than duplicated. This matches deregisterDriver,
     * which also identifies drivers by name.
     */
    public static func registerDriver(driver: Driver): Unit {
        // Query the driver once, and do it outside the lock.
        let name = driver.getName()
        let version = driver.getVersion()
        let info = DriverInfo(name, version, driver.getSupportedDatabaseType())
        info.driver = Some(driver)
        lock.lock()
        try {
            var replaced = false
            for (i in 0..registeredDrivers.size) {
                if (registeredDrivers[i].name == name) {
                    registeredDrivers[i] = info
                    replaced = true
                    break
                }
            }
            if (!replaced) {
                registeredDrivers.add(info)
            }
        } finally {
            lock.unlock()
        }
        println("[tycj_orm] 注册驱动: ${name} v${version}")
    }
    
    /**
     * Removes every registered entry whose driver has the same name as `driver`.
     * Entries without an attached driver are dropped as well.
     */
    public static func deregisterDriver(driver: Driver): Unit {
        let target = driver.getName()
        lock.lock()
        try {
            let newDrivers = ArrayList<DriverInfo>()
            for (info in registeredDrivers) {
                match (info.driver) {
                    case Some(d) =>
                        if (d.getName() != target) {
                            newDrivers.add(info)
                        }
                    case None => ()
                }
            }
            registeredDrivers = newDrivers
        } finally {
            lock.unlock()
        }
        println("[tycj_orm] 注销驱动: ${target}")
    }
    
    /**
     * Returns the first connection produced by a registered driver, trying the
     * drivers in registration order, or None if no driver produces one.
     *
     * The driver list is snapshotted under the lock, and connect() is called
     * after the lock is released. Connecting may be slow (network I/O), and
     * holding the global registry lock during it would serialize every
     * connection attempt in the process.
     */
    public static func getConnection(config: DatabaseConfig): Option<Connection> {
        let candidates = ArrayList<Driver>()
        lock.lock()
        try {
            for (info in registeredDrivers) {
                match (info.driver) {
                    case Some(d) => candidates.add(d)
                    case None => ()
                }
            }
        } finally {
            lock.unlock()
        }
        for (driver in candidates) {
            match (driver.connect(config)) {
                case Some(conn) => return Some(conn)
                case None => ()
            }
        }
        return None
    }
    
    /**
     * Returns a new list containing the registered entries. The list is a
     * copy, but the DriverInfo objects in it are shared with the registry.
     */
    public static func getDrivers(): ArrayList<DriverInfo> {
        lock.lock()
        try {
            let result = ArrayList<DriverInfo>()
            for (info in registeredDrivers) {
                result.add(info)
            }
            result
        } finally {
            lock.unlock()
        }
    }
    
    // Removes all registered drivers.
    public static func clear(): Unit {
        lock.lock()
        try {
            registeredDrivers.clear()
        } finally {
            lock.unlock()
        }
    }
}

// ============================================================================
// Mock 驱动实现
// ============================================================================

/**
 * MockDriver - mock database driver.
 *
 * Intended for tests and development. connect() always hands back a new
 * MockConnection and never talks to a real database.
 */
public class MockDriver <: Driver {
    // Fixed driver name.
    public func getName(): String {
        return "MockDriver"
    }
    
    // Fixed driver version.
    public func getVersion(): String {
        return "1.0.0"
    }
    
    // The mock advertises itself as a MySql driver.
    public func getSupportedDatabaseType(): DatabaseType {
        return DatabaseType.MySql
    }
    
    // `config` is ignored; every call yields a fresh mock connection.
    public func connect(config: DatabaseConfig): Option<Connection> {
        let mock: Connection = MockConnection()
        return Some(mock)
    }
    
    // Accepts exactly the URLs that carry the "mock:" scheme prefix.
    public func acceptsURL(url: String): Bool {
        return url.startsWith("mock:")
    }
}

// ============================================================================
// 连接工厂
// ============================================================================

/**
 * ConnectionFactory - connection factory interface.
 *
 * Produces new connections and checks whether an existing one is usable.
 */
public interface ConnectionFactory {
    // Creates a new connection, or None if one could not be created.
    func createConnection(): Option<Connection>
    // Returns whether `conn` appears usable.
    func validateConnection(conn: Connection): Bool
}

/**
 * DefaultConnectionFactory - default connection factory.
 *
 * Uses an explicitly assigned driver when one is set. Otherwise it delegates
 * to DriverManager.getConnection with the same configuration.
 */
public class DefaultConnectionFactory <: ConnectionFactory {
    // Configuration handed to whichever driver opens the connection.
    private let config: DatabaseConfig
    // Optional driver override; None means "ask DriverManager".
    private var driver: Option<Driver> = None
    
    public init(config: DatabaseConfig) {
        this.config = config
    }
    
    // Pins this factory to `driver` and returns the factory for chaining.
    public func setDriver(driver: Driver): DefaultConnectionFactory {
        this.driver = Some(driver)
        this
    }
    
    // Opens a connection via the pinned driver, falling back to DriverManager.
    public func createConnection(): Option<Connection> {
        return match (driver) {
            case Some(chosen) => chosen.connect(config)
            case None => DriverManager.getConnection(config)
        }
    }
    
    // Lightweight check: the connection is considered valid if a trivial
    // statement can be prepared and closed without throwing.
    public func validateConnection(conn: Connection): Bool {
        var healthy = false
        try {
            let probe = conn.prepareStatement("SELECT 1")
            probe.close()
            healthy = true
        } catch (_: Exception) {
            healthy = false
        }
        return healthy
    }
}

// ============================================================================
// SQL 方言
// ============================================================================

/**
 * Dialect - SQL dialect interface.
 *
 * Encapsulates SQL syntax that differs between database engines.
 */
public interface Dialect {
    // Returns the database type this dialect targets.
    func getDatabaseType(): DatabaseType
    
    // Wraps `sql` so that it returns at most `limit` rows, skipping `offset` rows.
    func getPaginationSql(sql: String, offset: Int64, limit: Int64): String
    
    // Returns the SQL that fetches the most recently generated auto-increment key.
    func getIdentitySelectSql(): String
    
    // Returns whether the database supports sequences.
    func supportsSequences(): Bool
    
    // Returns the SQL that fetches the next value of `sequenceName`.
    // Implementations in this file return "" when supportsSequences() is false.
    func getSequenceNextValSql(sequenceName: String): String
    
    // Quotes an identifier (table name, column name) for this dialect.
    func quoteIdentifier(identifier: String): String
}

/**
 * MySqlDialect - MySQL dialect.
 */
public class MySqlDialect <: Dialect {
    public func getDatabaseType(): DatabaseType {
        DatabaseType.MySql
    }
    
    // MySQL accepts the `LIMIT n OFFSET m` form.
    public func getPaginationSql(sql: String, offset: Int64, limit: Int64): String {
        "${sql} LIMIT ${limit} OFFSET ${offset}"
    }
    
    // LAST_INSERT_ID() returns the last AUTO_INCREMENT value generated on this connection.
    public func getIdentitySelectSql(): String {
        "SELECT LAST_INSERT_ID()"
    }
    
    // MySQL has no sequence objects.
    public func supportsSequences(): Bool {
        false
    }
    
    // Sequences are unsupported, so there is no SQL to return.
    public func getSequenceNextValSql(sequenceName: String): String {
        ""
    }
    
    /**
     * Quotes `identifier` with backticks.
     *
     * An embedded backtick is doubled, which is MySQL's escape inside a quoted
     * identifier. Without that, a name containing "`" would terminate the
     * quoting early and produce broken or injectable SQL.
     */
    public func quoteIdentifier(identifier: String): String {
        let escaped = identifier.replace("`", "``")
        "`${escaped}`"
    }
}

/**
 * PostgreSqlDialect - PostgreSQL dialect.
 */
public class PostgreSqlDialect <: Dialect {
    public func getDatabaseType(): DatabaseType {
        DatabaseType.PostgreSQL
    }
    
    // PostgreSQL accepts the `LIMIT n OFFSET m` form.
    public func getPaginationSql(sql: String, offset: Int64, limit: Int64): String {
        "${sql} LIMIT ${limit} OFFSET ${offset}"
    }
    
    /**
     * Returns the most recently generated key in the current session.
     *
     * The previous implementation queried
     * `pg_get_serial_sequence('table', 'id')`, which only works for a table
     * literally named "table". lastval() returns the value most recently
     * produced by nextval() in this session, which is what a serial or identity
     * INSERT uses.
     * NOTE: if a trigger calls nextval() on another sequence after the INSERT,
     * lastval() reports that sequence's value instead.
     */
    public func getIdentitySelectSql(): String {
        "SELECT lastval()"
    }
    
    // PostgreSQL supports sequence objects.
    public func supportsSequences(): Bool {
        true
    }
    
    // The sequence name is embedded in a single-quoted string literal, so
    // embedded single quotes are doubled to keep the literal well-formed.
    public func getSequenceNextValSql(sequenceName: String): String {
        let escaped = sequenceName.replace("'", "''")
        "SELECT nextval('${escaped}')"
    }
    
    // Quotes with double quotes; an embedded `"` is doubled, which is the SQL-standard escape.
    public func quoteIdentifier(identifier: String): String {
        let escaped = identifier.replace("\"", "\"\"")
        "\"${escaped}\""
    }
}

/**
 * SqliteDialect - SQLite dialect.
 */
public class SqliteDialect <: Dialect {
    public func getDatabaseType(): DatabaseType {
        DatabaseType.Sqlite
    }
    
    // SQLite accepts the `LIMIT n OFFSET m` form.
    public func getPaginationSql(sql: String, offset: Int64, limit: Int64): String {
        "${sql} LIMIT ${limit} OFFSET ${offset}"
    }
    
    // last_insert_rowid() returns the rowid of the most recent successful INSERT on this connection.
    public func getIdentitySelectSql(): String {
        "SELECT last_insert_rowid()"
    }
    
    // SQLite has no sequence objects.
    public func supportsSequences(): Bool {
        false
    }
    
    // Sequences are unsupported, so there is no SQL to return.
    public func getSequenceNextValSql(sequenceName: String): String {
        ""
    }
    
    // Quotes with double quotes; an embedded `"` is doubled so the name cannot
    // terminate the quoting early.
    public func quoteIdentifier(identifier: String): String {
        let escaped = identifier.replace("\"", "\"\"")
        "\"${escaped}\""
    }
}

/**
 * DialectRegistry - registry of SQL dialects, one per DatabaseType.
 *
 * MySQL, PostgreSQL and SQLite have built-in defaults. Oracle and SQL Server
 * have no built-in dialect, so getDialect returns None for them until one is
 * registered. Previously, registrations for those two types were silently
 * discarded.
 *
 * The slots are initialized eagerly, which replaces the racy lazy-init flag.
 * Access is guarded by a mutex, mirroring DriverManager.
 */
public class DialectRegistry {
    private static var mySqlDialect: Option<Dialect> = Some(MySqlDialect())
    private static var postgreSqlDialect: Option<Dialect> = Some(PostgreSqlDialect())
    private static var sqliteDialect: Option<Dialect> = Some(SqliteDialect())
    private static var oracleDialect: Option<Dialect> = None
    private static var sqlServerDialect: Option<Dialect> = None
    private static var lock: ReentrantMutex = ReentrantMutex()
    
    // Returns the dialect registered for `dbType`, or None if there is none.
    public static func getDialect(dbType: DatabaseType): Option<Dialect> {
        lock.lock()
        try {
            match (dbType) {
                case DatabaseType.MySql => mySqlDialect
                case DatabaseType.PostgreSQL => postgreSqlDialect
                case DatabaseType.Sqlite => sqliteDialect
                case DatabaseType.Oracle => oracleDialect
                case DatabaseType.SqlServer => sqlServerDialect
            }
        } finally {
            lock.unlock()
        }
    }
    
    // Registers `dialect` for `dbType`, replacing any previous dialect for that type.
    public static func registerDialect(dbType: DatabaseType, dialect: Dialect): Unit {
        lock.lock()
        try {
            match (dbType) {
                case DatabaseType.MySql => mySqlDialect = Some(dialect)
                case DatabaseType.PostgreSQL => postgreSqlDialect = Some(dialect)
                case DatabaseType.Sqlite => sqliteDialect = Some(dialect)
                case DatabaseType.Oracle => oracleDialect = Some(dialect)
                case DatabaseType.SqlServer => sqlServerDialect = Some(dialect)
            }
        } finally {
            lock.unlock()
        }
    }
}