/*
* 文件名: enterprise.cj
* 功能: ORM模块企业级增强
* 说明: 提供性能监控、慢查询日志、连接池健康检查等企业级功能
*/
package tybb2026::tycj_orm
import std.collection.*
import std.time.*
import std.sync.*
import std.convert.*
// ============================================================================
// SQL执行统计
// ============================================================================
/**
 * Aggregated execution statistics for a single SQL statement.
 *
 * Tracks execution count, total/min/max/average duration, affected rows,
 * error count and how many executions reached the slow-query threshold.
 * Durations are stored in whatever unit the caller supplies (the monitor
 * in this file passes milliseconds).
 *
 * Instances are not synchronized; callers must serialize access themselves.
 */
public class SqlExecutionStats {
    public var sqlHash: String = ""
    public var sqlPreview: String = ""
    public var executionCount: Int64 = 0
    public var totalExecutionTime: Int64 = 0
    public var maxExecutionTime: Int64 = 0
    // Starts at Int64.Max so that the first recorded duration always becomes the minimum.
    public var minExecutionTime: Int64 = 9223372036854775807
    public var avgExecutionTime: Int64 = 0
    public var totalRowsAffected: Int64 = 0
    public var errorCount: Int64 = 0
    public var lastExecutionTime: Int64 = 0
    public var slowQueryCount: Int64 = 0

    public init() {}

    public init(sqlHash: String, sqlPreview: String) {
        this.sqlHash = sqlHash
        this.sqlPreview = sqlPreview
    }

    /**
     * Records one successful execution.
     *
     * @param duration      elapsed time of this execution
     * @param rowsAffected  number of rows reported for this execution
     * @param slowThreshold executions with duration >= this value count as slow
     */
    public func recordExecution(duration: Int64, rowsAffected: Int64, slowThreshold: Int64): Unit {
        executionCount += 1
        totalExecutionTime += duration
        lastExecutionTime = duration
        totalRowsAffected += rowsAffected
        if (duration > maxExecutionTime) {
            maxExecutionTime = duration
        }
        if (duration < minExecutionTime) {
            minExecutionTime = duration
        }
        // executionCount was incremented above, so it is never zero here.
        avgExecutionTime = totalExecutionTime / executionCount
        if (duration >= slowThreshold) {
            slowQueryCount += 1
        }
    }

    /** Records one failed execution. */
    public func recordError(): Unit {
        errorCount += 1
    }

    /**
     * Serializes the statistics as a JSON object.
     *
     * Both sqlHash and sqlPreview are escaped because either may carry raw
     * SQL text. While no execution has been recorded, minExecutionTime is
     * reported as 0 instead of the internal Int64.Max sentinel.
     */
    public func toJson(): String {
        let reportedMin: Int64 = if (executionCount == 0) { 0 } else { minExecutionTime }
        "{\"sqlHash\":\"${escapeJson(sqlHash)}\",\"sqlPreview\":\"${escapeJson(sqlPreview)}\"," +
        "\"executionCount\":${executionCount},\"avgExecutionTime\":${avgExecutionTime}," +
        "\"maxExecutionTime\":${maxExecutionTime},\"minExecutionTime\":${reportedMin}," +
        "\"totalRowsAffected\":${totalRowsAffected},\"errorCount\":${errorCount}," +
        "\"slowQueryCount\":${slowQueryCount}}"
    }

    // Escapes a string for embedding inside a JSON string literal.
    private func escapeJson(s: String): String {
        var result = s
        // The backslash must be handled first so later escapes are not doubled.
        result = result.replace("\\", "\\\\")
        result = result.replace("\"", "\\\"")
        result = result.replace("\n", "\\n")
        result = result.replace("\r", "\\r")
        result = result.replace("\t", "\\t")
        result = result.replace("\b", "\\b")
        result = result.replace("\f", "\\f")
        // Any remaining byte below 0x20 becomes a \u00XX escape.
        result = escapeControlChars(result)
        result
    }

    // Replaces every remaining control byte (< 0x20) with a \u00XX escape.
    // Works on the raw byte sequence and decodes it once at the end, so
    // multi-byte UTF-8 characters are copied through intact instead of being
    // decoded one byte at a time.
    private func escapeControlChars(s: String): String {
        let bytes = ArrayList<UInt8>()
        for (i in 0..s.size) {
            let c = s[i]
            if (c < 32u8) {
                let escaped = "\\u" + formatHex(c)
                for (k in 0..escaped.size) {
                    bytes.add(escaped[k])
                }
            } else {
                bytes.add(c)
            }
        }
        String.fromUtf8(bytes.toArray())
    }

    // Formats a byte as four hexadecimal digits: "00" followed by the high
    // and low nibble, e.g. 0x1F -> "001F".
    private func formatHex(b: UInt8): String {
        let hexChars = "0123456789ABCDEF"
        let high = Int64((b >> 4) & 0x0F)
        let low = Int64(b & 0x0F)
        String.fromUtf8([hexChars[0], hexChars[0], hexChars[high], hexChars[low]])
    }
}
// ============================================================================
// 慢查询日志
// ============================================================================
/**
 * A single slow-query record: when it was captured, the SQL text, how long
 * it took and the parameters it was executed with.
 */
public class SlowQueryLog {
    // Capture time as milliseconds since the Unix epoch.
    public var timestamp: Int64 = 0
    public var sql: String = ""
    public var duration: Int64 = 0
    public var params: ArrayList<String> = ArrayList<String>()

    public init() {}

    /**
     * Creates a record stamped with the current time.
     * The params list is stored by reference, not copied.
     */
    public init(sql: String, duration: Int64, params: ArrayList<String>) {
        this.timestamp = DateTime.now().toUnixTimeStamp().toMilliseconds()
        this.sql = sql
        this.duration = duration
        this.params = params
    }

    /**
     * Formats the record as one log line.
     *
     * The time shown is derived from the recorded timestamp, so the line
     * reports when the query was captured rather than when it is formatted.
     */
    public func toLogString(): String {
        let recordedAt = DateTime.fromUnixTimeStamp(Duration.millisecond * timestamp).inLocal()
        let timeStr = recordedAt.toString()
        "[SLOW QUERY] ${timeStr} - ${duration}ms - ${sql}"
    }
}
/**
 * Bounded in-memory log of slow queries.
 *
 * Queries whose duration reaches the threshold are appended to the log and
 * echoed to standard output. Once the log exceeds maxLogSize entries the
 * oldest entry is dropped. Access to the log list is guarded by a mutex.
 */
public class SlowQueryLogger {
    private var slowQueryThreshold: Int64 = 1000
    private var maxLogSize: Int64 = 1000
    private var logs: ArrayList<SlowQueryLog> = ArrayList<SlowQueryLog>()
    private var lock: ReentrantMutex = ReentrantMutex()

    public init() {}

    public init(threshold: Int64, maxSize: Int64) {
        this.slowQueryThreshold = threshold
        this.maxLogSize = maxSize
    }

    /**
     * Records the query if its duration is at or above the threshold;
     * faster queries are ignored.
     */
    public func logQuery(sql: String, duration: Int64, params: ArrayList<String>): Unit {
        if (duration < slowQueryThreshold) {
            return
        }
        let log = SlowQueryLog(sql, duration, params)
        lock.lock()
        try {
            logs.add(log)
            if (logs.size > maxLogSize) {
                // Drop the oldest entry to keep the log bounded.
                logs.remove(0..1)
            }
        } finally {
            lock.unlock()
        }
        // Console output happens after the mutex is released so that slow I/O
        // does not block other threads that are logging or reading.
        println(log.toLogString())
    }

    /**
     * Returns a snapshot of the current entries.
     * A new list is allocated; the SlowQueryLog objects themselves are shared.
     */
    public func getLogs(): ArrayList<SlowQueryLog> {
        lock.lock()
        try {
            ArrayList<SlowQueryLog>(logs)
        } finally {
            lock.unlock()
        }
    }

    /** Removes all recorded entries. */
    public func clear(): Unit {
        lock.lock()
        try {
            logs.clear()
        } finally {
            lock.unlock()
        }
    }

    /** Returns the number of entries currently retained. */
    public func getSlowQueryCount(): Int64 {
        lock.lock()
        try {
            logs.size
        } finally {
            lock.unlock()
        }
    }

    /** Changes the duration threshold applied by subsequent logQuery calls. */
    public func setThreshold(threshold: Int64): Unit {
        this.slowQueryThreshold = threshold
    }
}
// ============================================================================
// 连接池健康检查
// ============================================================================
/**
 * Result of one connection-pool health check: connection counters, the
 * check time and the list of detected issues.
 */
public class ConnectionPoolHealth {
    public var healthy: Bool = true
    public var activeConnections: Int64 = 0
    public var idleConnections: Int64 = 0
    public var totalConnections: Int64 = 0
    public var lastCheckTime: Int64 = 0
    public var issues: ArrayList<String> = ArrayList<String>()

    public init() {}

    /** Adds an issue description and marks the pool as unhealthy. */
    public func addIssue(issue: String): Unit {
        issues.add(issue)
        healthy = false
    }

    /**
     * Serializes the health report as a JSON object.
     * Issue texts are escaped so that quotes, backslashes or control
     * characters inside them cannot produce malformed JSON.
     */
    public func toJson(): String {
        let sb = StringBuilder()
        sb.append("{\"healthy\":${healthy},\"activeConnections\":${activeConnections}," +
            "\"idleConnections\":${idleConnections},\"totalConnections\":${totalConnections}," +
            "\"issues\":[")
        var first = true
        for (issue in issues) {
            if (!first) {
                sb.append(",")
            }
            first = false
            sb.append("\"")
            sb.append(escapeJson(issue))
            sb.append("\"")
        }
        sb.append("]}")
        sb.toString()
    }

    // Escapes a string for embedding inside a JSON string literal. Operates on
    // the raw bytes and decodes once at the end, so multi-byte UTF-8
    // characters pass through unchanged.
    private func escapeJson(s: String): String {
        let hexDigits = "0123456789ABCDEF"
        let out = ArrayList<UInt8>()
        for (i in 0..s.size) {
            let c = s[i]
            if (c == 34u8 || c == 92u8) {
                // '"' and '\' are prefixed with a backslash.
                out.add(92u8)
                out.add(c)
            } else if (c == 10u8) {
                out.add(92u8)
                out.add(110u8) // 'n'
            } else if (c == 13u8) {
                out.add(92u8)
                out.add(114u8) // 'r'
            } else if (c == 9u8) {
                out.add(92u8)
                out.add(116u8) // 't'
            } else if (c < 32u8) {
                // Other control bytes use the \u00XX form.
                out.add(92u8)
                out.add(117u8) // 'u'
                out.add(48u8)
                out.add(48u8)
                out.add(hexDigits[Int64((c >> 4) & 0x0F)])
                out.add(hexDigits[Int64(c & 0x0F)])
            } else {
                out.add(c)
            }
        }
        String.fromUtf8(out.toArray())
    }
}
/**
 * Produces ConnectionPoolHealth reports for a connection pool.
 */
public class ConnectionPoolHealthChecker {
    private var pool: ConnectionPool
    // An active/total ratio above this value is reported as an issue.
    private var maxActiveRatio: Float64 = 0.8

    public init(pool: ConnectionPool) {
        this.pool = pool
    }

    /**
     * Reads the pool counters and builds a health report.
     *
     * An issue is added when the pool has no connections at all, or when the
     * share of active connections exceeds maxActiveRatio. The three counters
     * are read with separate calls, so they are not an atomic snapshot.
     */
    public func checkHealth(): ConnectionPoolHealth {
        let health = ConnectionPoolHealth()
        health.lastCheckTime = DateTime.now().toUnixTimeStamp().toMilliseconds()
        health.activeConnections = pool.getActiveConnections()
        health.totalConnections = pool.getTotalConnections()
        health.idleConnections = pool.getIdleConnections()
        if (health.totalConnections > 0) {
            // Direct numeric conversion; no round trip through a decimal string.
            let activeRatio = Float64(health.activeConnections) / Float64(health.totalConnections)
            if (activeRatio > maxActiveRatio) {
                health.addIssue("活跃连接比例过高")
            }
        } else if (health.totalConnections == 0) {
            health.addIssue("连接池为空")
        }
        health
    }
}
// ============================================================================
// SQL注入防护
// ============================================================================
/**
 * Heuristic detector for suspicious SQL text and parameter values.
 *
 * This is an auxiliary check only and does not replace parameterized queries.
 * All scanning is done on raw bytes with ASCII-only case folding, so SQL or
 * parameters containing multi-byte UTF-8 text are handled without decoding
 * individual bytes.
 */
public class SqlInjectionDetector {
    private var dangerousKeywords: HashSet<String> = HashSet<String>()
    private var dangerousPatterns: ArrayList<String> = ArrayList<String>()

    public init() {
        dangerousKeywords.add("DROP")
        dangerousKeywords.add("DELETE")
        dangerousKeywords.add("TRUNCATE")
        dangerousKeywords.add("ALTER")
        dangerousKeywords.add("EXEC")
        dangerousKeywords.add("EXECUTE")
        dangerousKeywords.add("UNION")
        dangerousKeywords.add("INSERT")
        dangerousKeywords.add("UPDATE")
        dangerousPatterns.add("--")
        dangerousPatterns.add("/*")
        dangerousPatterns.add("*/")
        dangerousPatterns.add(";")
    }

    /**
     * Scans a SQL statement for dangerous keywords and patterns.
     *
     * A keyword is reported only when it appears as a whole word outside a
     * single-quoted literal, so identifiers such as `updated_at` or
     * `is_deleted` are not flagged. Patterns are matched anywhere in the text.
     *
     * @return a description of the first finding, or "" when nothing was found
     */
    public func detect(sql: String): String {
        for (keyword in dangerousKeywords) {
            if (hasKeywordOutsideLiteral(sql, keyword)) {
                return "检测到潜在危险关键字: ${keyword}"
            }
        }
        for (pattern in dangerousPatterns) {
            if (sql.contains(pattern)) {
                return "检测到潜在危险模式: ${pattern}"
            }
        }
        ""
    }

    /**
     * Checks a parameter value for a single quote immediately followed by
     * "OR ", "AND ", "--" or ";" (letters compared case-insensitively).
     *
     * NOTE: only sequences directly after the quote are matched; a quote
     * followed by whitespace and then a keyword is not rejected.
     *
     * @return false when such a sequence is found, true otherwise
     */
    public func validateParam(param: String): Bool {
        // An empty parameter is always considered safe.
        if (param.isEmpty()) { return true }
        for (i in 0..param.size) {
            if (param[i] != 39u8) { // not a single quote
                continue
            }
            let next = i + 1
            if (matchesAt(param, next, "OR ") || matchesAt(param, next, "AND ") ||
                matchesAt(param, next, "--") || matchesAt(param, next, ";")) {
                return false
            }
        }
        true
    }

    // Returns true when `keyword` occurs as a whole word outside any
    // single-quoted literal of `sql`.
    private func hasKeywordOutsideLiteral(sql: String, keyword: String): Bool {
        let keywordLen = keyword.size
        if (keywordLen == 0) { return false }
        var inSingleQuote = false
        var i: Int64 = 0
        while (i < sql.size) {
            let c = sql[i]
            if (c == 39u8) { // single quote
                // A doubled quote ('') is skipped without changing the literal state.
                if (i + 1 < sql.size && sql[i + 1] == 39u8) {
                    i += 2
                    continue
                }
                inSingleQuote = !inSingleQuote
            } else if (!inSingleQuote && matchesAt(sql, i, keyword) &&
                isBoundary(sql, i - 1) && isBoundary(sql, i + keywordLen)) {
                return true
            }
            i++
        }
        false
    }

    // Returns true when `word` occurs in `s` starting at byte offset `pos`,
    // comparing ASCII letters case-insensitively.
    private func matchesAt(s: String, pos: Int64, word: String): Bool {
        if (pos < 0 || pos + word.size > s.size) { return false }
        for (j in 0..word.size) {
            if (upperByte(s[pos + j]) != upperByte(word[j])) {
                return false
            }
        }
        true
    }

    // Returns true when `idx` is outside `s` or the byte there cannot be part
    // of an ASCII identifier.
    private func isBoundary(s: String, idx: Int64): Bool {
        if (idx < 0 || idx >= s.size) { return true }
        !isIdentifierByte(s[idx])
    }

    // ASCII letters, digits and underscore.
    private func isIdentifierByte(b: UInt8): Bool {
        (b >= 48u8 && b <= 57u8) || (b >= 65u8 && b <= 90u8) ||
        (b >= 97u8 && b <= 122u8) || b == 95u8
    }

    // Upper-cases an ASCII lowercase letter; every other byte is returned unchanged.
    private func upperByte(b: UInt8): UInt8 {
        if (b >= 97u8 && b <= 122u8) { b - 32u8 } else { b }
    }
}
// ============================================================================
// ORM异常
// ============================================================================
/**
 * Exception type raised by the ORM module.
 * Carries only a message, which is forwarded to the base Exception.
 */
public class OrmException <: Exception {
    /** Creates an exception with the given message. */
    public init(message: String) {
        super(message)
    }
}
// ============================================================================
// SQL执行监控器
// ============================================================================
/**
 * Coordinates per-statement statistics, slow-query logging and the
 * injection heuristics around SQL execution.
 *
 * Typical use: call startMonitoring before executing a statement,
 * endMonitoring after it succeeds, and recordExecutionError when it fails.
 * The statistics map is guarded by a mutex.
 */
public class SqlExecutionMonitor {
    private var sqlStats: HashMap<String, SqlExecutionStats> = HashMap<String, SqlExecutionStats>()
    private var slowQueryLogger: SlowQueryLogger = SlowQueryLogger()
    private var injectionDetector: SqlInjectionDetector = SqlInjectionDetector()
    private var lock: ReentrantMutex = ReentrantMutex()
    private var slowQueryThreshold: Int64 = 1000
    private var enableStats: Bool = true
    private var enableSlowQueryLog: Bool = true
    private var enableInjectionCheck: Bool = true

    public init() {}

    /**
     * Runs the injection checks (when enabled) and returns a context stamped
     * with the current time in milliseconds.
     *
     * NOTE(review): the detector's keyword list includes INSERT, UPDATE and
     * DELETE, so with the check enabled such statements are rejected here;
     * confirm that this is intended for the calling code.
     *
     * Throws OrmException when the SQL text or any parameter fails a check.
     */
    public func startMonitoring(sql: String, params: ArrayList<String>): MonitoringContext {
        if (enableInjectionCheck) {
            let checkResult = injectionDetector.detect(sql)
            if (!checkResult.isEmpty()) {
                throw OrmException("SQL安全检查失败: ${checkResult}")
            }
            for (param in params) {
                if (!injectionDetector.validateParam(param)) {
                    throw OrmException("参数安全检查失败: ${param}")
                }
            }
        }
        MonitoringContext(sql, params, DateTime.now().toUnixTimeStamp().toMilliseconds())
    }

    /**
     * Completes monitoring for a statement: updates the statistics and, when
     * the elapsed time reaches the slow-query threshold, writes a slow-query
     * log entry.
     */
    public func endMonitoring(context: MonitoringContext, rowsAffected: Int64): Unit {
        let elapsed = DateTime.now().toUnixTimeStamp().toMilliseconds() - context.startTime
        // The wall clock can move backwards between start and end; a negative
        // value would corrupt the min/avg statistics, so clamp it to zero.
        let duration: Int64 = if (elapsed < 0) { 0 } else { elapsed }
        if (enableStats) {
            recordExecution(context.sql, duration, rowsAffected)
        }
        if (enableSlowQueryLog && duration >= slowQueryThreshold) {
            slowQueryLogger.logQuery(context.sql, duration, context.params)
        }
    }

    /** Counts one failed execution for the given statement. */
    public func recordExecutionError(sql: String): Unit {
        recordError(sql)
    }

    /**
     * Returns a copy of the statistics map. The map is new, but the
     * SqlExecutionStats values are the live objects held by the monitor.
     */
    public func getSqlStats(): HashMap<String, SqlExecutionStats> {
        lock.lock()
        try {
            let copy = HashMap<String, SqlExecutionStats>()
            for ((key, value) in sqlStats) {
                copy[key] = value
            }
            copy
        } finally {
            lock.unlock()
        }
    }

    /** Returns a snapshot of the slow-query log entries. */
    public func getSlowQueryLogs(): ArrayList<SlowQueryLog> {
        slowQueryLogger.getLogs()
    }

    /** Sets the slow-query threshold on both the monitor and its logger. */
    public func setSlowQueryThreshold(threshold: Int64): Unit {
        this.slowQueryThreshold = threshold
        slowQueryLogger.setThreshold(threshold)
    }

    public func setEnableStats(enable: Bool): Unit {
        this.enableStats = enable
    }

    public func setEnableSlowQueryLog(enable: Bool): Unit {
        this.enableSlowQueryLog = enable
    }

    public func setEnableInjectionCheck(enable: Bool): Unit {
        this.enableInjectionCheck = enable
    }

    // Adds one successful execution to the statistics of `sql`.
    private func recordExecution(sql: String, duration: Int64, rowsAffected: Int64): Unit {
        lock.lock()
        try {
            statsFor(sql).recordExecution(duration, rowsAffected, slowQueryThreshold)
        } finally {
            lock.unlock()
        }
    }

    // Adds one failed execution to the statistics of `sql`.
    private func recordError(sql: String): Unit {
        lock.lock()
        try {
            statsFor(sql).recordError()
        } finally {
            lock.unlock()
        }
    }

    // Returns the statistics entry for `sql`, creating and registering it on
    // first use. Must be called with `lock` held.
    private func statsFor(sql: String): SqlExecutionStats {
        let sqlHash = hashSql(sql)
        if (let Some(existing) <- sqlStats.get(sqlHash)) {
            return existing
        }
        let sqlPreview = if (sql.size > 100) { truncateUtf8(sql, 100) + "..." } else { sql }
        let created = SqlExecutionStats(sqlHash, sqlPreview)
        sqlStats[sqlHash] = created
        created
    }

    // Derives the statistics key: the statement itself, cut to at most 50 bytes.
    // NOTE(review): this is a prefix, not a hash, so statements sharing their
    // first 50 bytes are aggregated under the same key.
    private func hashSql(sql: String): String {
        truncateUtf8(sql, 50)
    }

    // Returns `s` cut to at most `maxBytes` bytes without splitting a
    // multi-byte UTF-8 character: when the cut would land on a continuation
    // byte (10xxxxxx), it is moved back to the start of that character.
    private func truncateUtf8(s: String, maxBytes: Int64): String {
        if (s.size <= maxBytes) {
            return s
        }
        var end = maxBytes
        while (end > 0 && s[end] >= 128u8 && s[end] < 192u8) {
            end -= 1
        }
        s[0..end]
    }
}
/**
 * Per-execution state handed from startMonitoring to endMonitoring:
 * the SQL text, its parameters and the start time.
 */
public class MonitoringContext {
    // Start time; the monitor in this file stores milliseconds since the Unix epoch.
    public var startTime: Int64
    // SQL text being executed.
    public var sql: String
    // Parameters of the statement, kept by reference.
    public var params: ArrayList<String>

    /** Stores the three values as given; nothing is copied or validated. */
    public init(sql: String, params: ArrayList<String>, startTime: Int64) {
        this.startTime = startTime
        this.params = params
        this.sql = sql
    }
}