/*
* driver.cj - 数据库驱动模块
*
* 提供数据库驱动抽象层:
* - Driver 接口定义
* - DriverManager 驱动管理
* - 连接工厂
* - SPI 加载机制
*
* 设计参考 JDBC Driver 架构。
*/
package tybb2026::tycj_orm
import std.collection.*
import std.sync.*
// ============================================================================
// 数据库驱动接口
// ============================================================================
/**
 * Driver - contract implemented by every database driver.
 *
 * DriverManager records name/version/type at registration time and asks each
 * registered driver in turn to connect.
 */
public interface Driver {
    /** Human-readable name of the driver; DriverManager also matches on it when deregistering. */
    func getName(): String

    /** Version string of the driver. */
    func getVersion(): String

    /** The database type this driver reports supporting. */
    func getSupportedDatabaseType(): DatabaseType

    /**
     * Tries to open a connection described by `config`.
     * Returning None means "no connection from this driver"; DriverManager then
     * moves on to the next registered driver.
     */
    func connect(config: DatabaseConfig): Option<Connection>

    /** Reports whether this driver claims the given connection URL. */
    func acceptsURL(url: String): Bool
}
/**
 * DriverInfo - descriptive record for a registered driver.
 *
 * All fields are public and mutable; DriverManager.registerDriver fills them
 * from the driver's own getters and stores the driver itself in `driver`.
 */
public class DriverInfo {
    // Driver name (empty until set).
    public var name: String = ""
    // Driver version (empty until set).
    public var version: String = ""
    // Supported database type; defaults to MySql.
    public var dbType: DatabaseType = DatabaseType.MySql
    // The driver instance this record describes, if any.
    public var driver: Option<Driver> = None

    // Creates a record with default values; fields are expected to be assigned afterwards.
    public init() {}

    // Creates a record with the given descriptive fields; `driver` stays None.
    public init(name: String, version: String, dbType: DatabaseType) {
        this.dbType = dbType
        this.version = version
        this.name = name
    }
}
// ============================================================================
// 驱动管理器
// ============================================================================
/**
 * DriverManager - process-wide registry of database drivers.
 *
 * Registered drivers are kept in registration order behind a single mutex.
 * Connections are obtained by asking each registered driver, in order, to
 * connect; the first one that returns Some wins.
 */
public class DriverManager {
    // Registered drivers in registration order. Every access goes through `lock`.
    private static var registeredDrivers: ArrayList<DriverInfo> = ArrayList<DriverInfo>()
    // Guards `registeredDrivers`; never reassigned.
    private static let lock: ReentrantMutex = ReentrantMutex()

    /**
     * Registers `driver`, recording its name, version and supported database type.
     * No de-duplication is done: registering the same driver twice adds two entries.
     */
    public static func registerDriver(driver: Driver): Unit {
        lock.lock()
        try {
            let info = DriverInfo()
            info.name = driver.getName()
            info.version = driver.getVersion()
            info.dbType = driver.getSupportedDatabaseType()
            info.driver = Some(driver)
            registeredDrivers.add(info)
        } finally {
            lock.unlock()
        }
        println("[tycj_orm] 注册驱动: ${driver.getName()} v${driver.getVersion()}")
    }

    /**
     * Removes every entry whose driver reports the same name as `driver`.
     * Matching is by name, not object identity, so all same-named drivers go.
     * Entries that carry no driver are dropped as well.
     */
    public static func deregisterDriver(driver: Driver): Unit {
        lock.lock()
        try {
            let newDrivers = ArrayList<DriverInfo>()
            for (info in registeredDrivers) {
                match (info.driver) {
                    case Some(d) =>
                        if (d.getName() != driver.getName()) {
                            newDrivers.add(info)
                        }
                    case None => ()
                }
            }
            registeredDrivers = newDrivers
        } finally {
            lock.unlock()
        }
        println("[tycj_orm] 注销驱动: ${driver.getName()}")
    }

    /**
     * Returns a connection from the first registered driver (in registration
     * order) whose connect(config) yields Some, or None if no driver does.
     *
     * The driver list is copied under the lock and the connect attempts run
     * after the lock is released. Connecting is potentially slow I/O, and
     * holding the global mutex across it would serialize every connection
     * attempt and stall register/deregister/getDrivers in the meantime.
     * Exceptions thrown by a driver's connect propagate to the caller.
     *
     * NOTE(review): Driver.acceptsURL is not consulted here; every driver is
     * tried. Confirm whether URL-based driver selection was intended.
     */
    public static func getConnection(config: DatabaseConfig): Option<Connection> {
        for (driver in snapshotDrivers()) {
            match (driver.connect(config)) {
                case Some(conn) => return Some(conn)
                case None => ()
            }
        }
        None
    }

    /**
     * Returns a new list containing the currently registered DriverInfo
     * records (the records themselves are shared, not copied).
     */
    public static func getDrivers(): ArrayList<DriverInfo> {
        lock.lock()
        try {
            let result = ArrayList<DriverInfo>()
            for (info in registeredDrivers) {
                result.add(info)
            }
            result
        } finally {
            lock.unlock()
        }
    }

    /** Removes all registered drivers. */
    public static func clear(): Unit {
        lock.lock()
        try {
            registeredDrivers.clear()
        } finally {
            lock.unlock()
        }
    }

    // Copies the drivers registered right now, in order, while holding the lock.
    private static func snapshotDrivers(): ArrayList<Driver> {
        lock.lock()
        try {
            let drivers = ArrayList<Driver>()
            for (info in registeredDrivers) {
                match (info.driver) {
                    case Some(d) => drivers.add(d)
                    case None => ()
                }
            }
            drivers
        } finally {
            lock.unlock()
        }
    }
}
// ============================================================================
// Mock 驱动实现
// ============================================================================
/**
 * MockDriver - stand-in driver for tests and development.
 *
 * Always hands out a fresh MockConnection and never touches a real database.
 */
public class MockDriver <: Driver {
    public func getName(): String {
        return "MockDriver"
    }

    public func getVersion(): String {
        return "1.0.0"
    }

    // Reports MySql as its database type.
    public func getSupportedDatabaseType(): DatabaseType {
        return DatabaseType.MySql
    }

    // Ignores `config` and always succeeds with a new MockConnection.
    public func connect(config: DatabaseConfig): Option<Connection> {
        let mock: Connection = MockConnection()
        return Some(mock)
    }

    // Claims only URLs that use the "mock:" scheme prefix.
    public func acceptsURL(url: String): Bool {
        let scheme = "mock:"
        return url.startsWith(scheme)
    }
}
// ============================================================================
// 连接工厂
// ============================================================================
/**
 * ConnectionFactory - produces and sanity-checks connections.
 */
public interface ConnectionFactory {
    /** Creates a new connection; None when no connection could be obtained. */
    func createConnection(): Option<Connection>

    /** Returns whether `conn` appears usable. */
    func validateConnection(conn: Connection): Bool
}
/**
 * DefaultConnectionFactory - ConnectionFactory bound to one DatabaseConfig.
 *
 * Uses an explicitly assigned driver when one was set via setDriver,
 * otherwise falls back to DriverManager.getConnection.
 */
public class DefaultConnectionFactory <: ConnectionFactory {
    // Connection settings passed to whichever driver ends up connecting.
    private let config: DatabaseConfig
    // Optional driver that takes precedence over DriverManager.
    private var driver: Option<Driver> = None

    public init(config: DatabaseConfig) {
        this.config = config
    }

    // Pins this factory to `driver`; returns the factory for call chaining.
    public func setDriver(driver: Driver): DefaultConnectionFactory {
        this.driver = Some(driver)
        this
    }

    public func createConnection(): Option<Connection> {
        if (let Some(chosen) <- driver) {
            return chosen.connect(config)
        }
        return DriverManager.getConnection(config)
    }

    /**
     * Lightweight check: prepares "SELECT 1" and closes the statement.
     * Any Exception raised along the way makes the connection count as invalid.
     * The statement is only prepared, never executed.
     */
    public func validateConnection(conn: Connection): Bool {
        try {
            conn.prepareStatement("SELECT 1").close()
            true
        } catch (_: Exception) {
            false
        }
    }
}
// ============================================================================
// SQL 方言
// ============================================================================
/**
 * Dialect - per-database SQL syntax differences.
 */
public interface Dialect {
    /** Database type this dialect targets. */
    func getDatabaseType(): DatabaseType

    /** Wraps `sql` so that it returns at most `limit` rows, skipping the first `offset`. */
    func getPaginationSql(sql: String, offset: Int64, limit: Int64): String

    /** SQL that reads back the most recently generated auto-increment key. */
    func getIdentitySelectSql(): String

    /** Whether the database has sequence objects. */
    func supportsSequences(): Bool

    /** SQL fetching the next value of `sequenceName`; built-in dialects without sequences return "". */
    func getSequenceNextValSql(sequenceName: String): String

    /** Quotes a table or column name using the database's identifier quoting. */
    func quoteIdentifier(identifier: String): String
}
/**
 * MySqlDialect - SQL dialect for MySQL.
 */
public class MySqlDialect <: Dialect {
    public func getDatabaseType(): DatabaseType {
        DatabaseType.MySql
    }

    // Appends "LIMIT <limit> OFFSET <offset>"; values are interpolated as given.
    public func getPaginationSql(sql: String, offset: Int64, limit: Int64): String {
        "${sql} LIMIT ${limit} OFFSET ${offset}"
    }

    // LAST_INSERT_ID() is per-connection in MySQL.
    public func getIdentitySelectSql(): String {
        "SELECT LAST_INSERT_ID()"
    }

    public func supportsSequences(): Bool {
        false
    }

    // MySQL has no sequences; always returns an empty string.
    public func getSequenceNextValSql(sequenceName: String): String {
        ""
    }

    /**
     * Wraps `identifier` in backticks.
     * An embedded backtick is doubled (MySQL's escape inside a quoted
     * identifier) so it cannot close the quoting early and inject SQL.
     */
    public func quoteIdentifier(identifier: String): String {
        let escaped = identifier.replace("`", "``")
        "`${escaped}`"
    }
}
/**
 * PostgreSqlDialect - SQL dialect for PostgreSQL.
 */
public class PostgreSqlDialect <: Dialect {
    public func getDatabaseType(): DatabaseType {
        DatabaseType.PostgreSQL
    }

    // Appends "LIMIT <limit> OFFSET <offset>"; values are interpolated as given.
    public func getPaginationSql(sql: String, offset: Int64, limit: Int64): String {
        "${sql} LIMIT ${limit} OFFSET ${offset}"
    }

    /**
     * Returns the value most recently produced by nextval() in the current
     * session, e.g. the key a SERIAL column just generated on insert.
     * The previous query used pg_get_serial_sequence('table', 'id'), a
     * placeholder that only worked for a table literally named "table".
     * PostgreSQL raises an error if nextval() has not yet run in the session.
     */
    public func getIdentitySelectSql(): String {
        "SELECT lastval()"
    }

    public func supportsSequences(): Bool {
        true
    }

    /**
     * SQL advancing `sequenceName`. The name is embedded in a string literal,
     * so embedded single quotes are doubled to keep the literal well-formed.
     */
    public func getSequenceNextValSql(sequenceName: String): String {
        let literal = sequenceName.replace("'", "''")
        "SELECT nextval('${literal}')"
    }

    /**
     * Wraps `identifier` in double quotes, doubling any embedded double quote
     * (standard SQL escaping) so the identifier cannot break out of the quoting.
     */
    public func quoteIdentifier(identifier: String): String {
        let escaped = identifier.replace("\"", "\"\"")
        "\"${escaped}\""
    }
}
/**
 * SqliteDialect - SQL dialect for SQLite.
 */
public class SqliteDialect <: Dialect {
    public func getDatabaseType(): DatabaseType {
        DatabaseType.Sqlite
    }

    // Appends "LIMIT <limit> OFFSET <offset>"; values are interpolated as given.
    public func getPaginationSql(sql: String, offset: Int64, limit: Int64): String {
        "${sql} LIMIT ${limit} OFFSET ${offset}"
    }

    // last_insert_rowid() reports the rowid of the latest insert on this connection.
    public func getIdentitySelectSql(): String {
        "SELECT last_insert_rowid()"
    }

    public func supportsSequences(): Bool {
        false
    }

    // SQLite has no sequences; always returns an empty string.
    public func getSequenceNextValSql(sequenceName: String): String {
        ""
    }

    /**
     * Wraps `identifier` in double quotes, doubling any embedded double quote
     * so the identifier cannot terminate the quoting early.
     */
    public func quoteIdentifier(identifier: String): String {
        let escaped = identifier.replace("\"", "\"\"")
        "\"${escaped}\""
    }
}
/**
 * DialectRegistry - process-wide registry with one Dialect slot per DatabaseType.
 *
 * MySQL, PostgreSQL and SQLite start with the built-in dialects. Oracle and
 * SQL Server have no built-in dialect and return None until one is registered.
 * All slots are read and written under a single mutex, mirroring DriverManager.
 */
public class DialectRegistry {
    // Guards every slot below.
    private static let lock: ReentrantMutex = ReentrantMutex()
    // Defaults are installed eagerly. The former lazy, unsynchronized
    // initialization could re-install a default after a concurrent
    // registerDialect call and silently overwrite the custom dialect.
    private static var mySqlDialect: Option<Dialect> = Some(MySqlDialect())
    private static var postgreSqlDialect: Option<Dialect> = Some(PostgreSqlDialect())
    private static var sqliteDialect: Option<Dialect> = Some(SqliteDialect())
    // No built-in dialects for these; filled only by registerDialect.
    private static var oracleDialect: Option<Dialect> = None
    private static var sqlServerDialect: Option<Dialect> = None

    /** Returns the dialect currently registered for `dbType`, or None if there is none. */
    public static func getDialect(dbType: DatabaseType): Option<Dialect> {
        lock.lock()
        try {
            match (dbType) {
                case DatabaseType.MySql => mySqlDialect
                case DatabaseType.PostgreSQL => postgreSqlDialect
                case DatabaseType.Sqlite => sqliteDialect
                case DatabaseType.Oracle => oracleDialect
                case DatabaseType.SqlServer => sqlServerDialect
            }
        } finally {
            lock.unlock()
        }
    }

    /**
     * Installs `dialect` for `dbType`, replacing any previous one.
     * Oracle and SQL Server registrations are now stored; previously they were
     * silently discarded.
     */
    public static func registerDialect(dbType: DatabaseType, dialect: Dialect): Unit {
        lock.lock()
        try {
            match (dbType) {
                case DatabaseType.MySql => mySqlDialect = Some(dialect)
                case DatabaseType.PostgreSQL => postgreSqlDialect = Some(dialect)
                case DatabaseType.Sqlite => sqliteDialect = Some(dialect)
                case DatabaseType.Oracle => oracleDialect = Some(dialect)
                case DatabaseType.SqlServer => sqlServerDialect = Some(dialect)
            }
        } finally {
            lock.unlock()
        }
    }
}