/*
 * 文件名: enterprise.cj
 * 功能: ORM模块企业级增强
 * 说明: 提供性能监控、慢查询日志、连接池健康检查等企业级功能
 */

package tybb2026::tycj_orm

import std.collection.*
import std.time.*
import std.sync.*
import std.convert.*

// ============================================================================
// SQL执行统计
// ============================================================================

/**
 * SQL执行统计信息
 * 企业级监控:记录执行次数、耗时、影响行数等
 */
public class SqlExecutionStats {
    // Identifier of the statement this entry aggregates (map key used by the monitor).
    public var sqlHash: String = ""
    // Human-readable (possibly truncated) SQL text.
    public var sqlPreview: String = ""
    public var executionCount: Int64 = 0
    public var totalExecutionTime: Int64 = 0
    public var maxExecutionTime: Int64 = 0
    // Starts at Int64.Max so the first recorded duration always becomes the minimum.
    public var minExecutionTime: Int64 = 9223372036854775807
    public var avgExecutionTime: Int64 = 0
    public var totalRowsAffected: Int64 = 0
    public var errorCount: Int64 = 0
    // Duration of the most recent successful execution.
    public var lastExecutionTime: Int64 = 0
    public var slowQueryCount: Int64 = 0

    public init() {}

    public init(sqlHash: String, sqlPreview: String) {
        this.sqlHash = sqlHash
        this.sqlPreview = sqlPreview
    }

    /**
     * Records one successful execution.
     *
     * Durations are plain Int64 values in whatever unit the caller uses
     * (SqlExecutionMonitor passes milliseconds). This class is not synchronized;
     * callers must serialize access.
     *
     * @param duration      elapsed time of this execution
     * @param rowsAffected  rows affected by this execution
     * @param slowThreshold executions with duration >= this value count as slow
     */
    public func recordExecution(duration: Int64, rowsAffected: Int64, slowThreshold: Int64): Unit {
        executionCount += 1
        totalExecutionTime += duration
        lastExecutionTime = duration
        totalRowsAffected += rowsAffected

        if (duration > maxExecutionTime) {
            maxExecutionTime = duration
        }
        if (duration < minExecutionTime) {
            minExecutionTime = duration
        }

        // executionCount >= 1 here, so the division is safe.
        avgExecutionTime = totalExecutionTime / executionCount

        if (duration >= slowThreshold) {
            slowQueryCount += 1
        }
    }

    /** Records one failed execution. */
    public func recordError(): Unit {
        errorCount += 1
    }

    /**
     * Serializes the statistics as a single JSON object.
     *
     * Both string fields are JSON-escaped. minExecutionTime is reported as 0
     * while no successful execution has been recorded, for example when only
     * errors were seen. This avoids leaking the Int64.Max sentinel.
     */
    public func toJson(): String {
        let minOut: Int64 = if (executionCount == 0) { 0 } else { minExecutionTime }
        "{\"sqlHash\":\"${escapeJson(sqlHash)}\",\"sqlPreview\":\"${escapeJson(sqlPreview)}\"," +
        "\"executionCount\":${executionCount},\"avgExecutionTime\":${avgExecutionTime}," +
        "\"maxExecutionTime\":${maxExecutionTime},\"minExecutionTime\":${minOut}," +
        "\"totalRowsAffected\":${totalRowsAffected},\"errorCount\":${errorCount}," +
        "\"slowQueryCount\":${slowQueryCount}}"
    }

    // Escapes a string for use inside a JSON string literal (RFC 8259).
    // It scans UTF-8 bytes, and only ASCII bytes (< 0x80) are ever rewritten.
    // Unchanged runs are copied as whole slices. Every cut point is adjacent to
    // an ASCII byte, so it lies on a character boundary and multi-byte
    // characters (e.g. Chinese text) pass through intact.
    private func escapeJson(s: String): String {
        let sb = StringBuilder()
        var runStart: Int64 = 0
        for (i in 0..s.size) {
            let replacement = escapeByte(s[i])
            if (!replacement.isEmpty()) {
                if (runStart < i) {
                    sb.append(s[runStart..i])
                }
                sb.append(replacement)
                runStart = i + 1
            }
        }
        if (runStart < s.size) {
            sb.append(s[runStart..s.size])
        }
        sb.toString()
    }

    // Returns the JSON escape sequence for byte b, or "" if b needs no escaping.
    private func escapeByte(b: UInt8): String {
        if (b == 92u8) {          // backslash
            "\\\\"
        } else if (b == 34u8) {   // double quote
            "\\\""
        } else if (b == 10u8) {   // newline
            "\\n"
        } else if (b == 13u8) {   // carriage return
            "\\r"
        } else if (b == 9u8) {    // tab
            "\\t"
        } else if (b == 8u8) {    // backspace
            "\\b"
        } else if (b == 12u8) {   // form feed
            "\\f"
        } else if (b < 32u8) {    // remaining control characters 0x00-0x1F
            "\\u" + formatHex(b)
        } else {
            ""
        }
    }

    // Formats a byte as the four hex digits of a \uXXXX escape, e.g. 0x1F -> "001F".
    private func formatHex(b: UInt8): String {
        let hexChars = "0123456789ABCDEF"
        let high = Int64((b >> 4) & 0x0F)
        let low = Int64(b & 0x0F)
        String.fromUtf8([hexChars[0], hexChars[0], hexChars[high], hexChars[low]])
    }
}

// ============================================================================
// 慢查询日志
// ============================================================================

/**
 * A single slow-query log entry.
 * timestamp is the Unix time in milliseconds at which the entry was created
 * (0 for an entry built with the no-argument constructor).
 */
public class SlowQueryLog {
    public var timestamp: Int64 = 0
    public var sql: String = ""
    public var duration: Int64 = 0
    public var params: ArrayList<String> = ArrayList<String>()

    public init() {}

    /**
     * @param sql      SQL text of the slow statement
     * @param duration execution time, in the caller's unit (SqlExecutionMonitor uses ms)
     * @param params   bound parameters. They are copied so later mutation of the
     *                 caller's list does not alter this recorded entry.
     */
    public init(sql: String, duration: Int64, params: ArrayList<String>) {
        this.timestamp = DateTime.now().toUnixTimeStamp().toMilliseconds()
        this.sql = sql
        this.duration = duration
        this.params = ArrayList<String>(params)
    }

    /**
     * Formats the entry as one log line.
     * It uses the entry's own timestamp, rendered in the local time zone, rather
     * than the time of formatting. Entries read back later therefore show when
     * the query actually ran.
     */
    public func toLogString(): String {
        let recordedAt = DateTime.fromUnixTimeStamp(Duration.millisecond * timestamp).inLocal()
        "[SLOW QUERY] ${recordedAt.toString()} - ${duration}ms - ${sql}"
    }
}

/**
 * Bounded, lock-protected in-memory store of slow queries.
 * Queries faster than the threshold are ignored. Once more than maxLogSize
 * entries are held, the oldest entries are discarded.
 */
public class SlowQueryLogger {
    private var slowQueryThreshold: Int64 = 1000
    private var maxLogSize: Int64 = 1000
    private var logs: ArrayList<SlowQueryLog> = ArrayList<SlowQueryLog>()
    private var lock: ReentrantMutex = ReentrantMutex()

    public init() {}

    /**
     * @param threshold minimum duration for a query to be logged
     * @param maxSize   maximum number of retained entries
     */
    public init(threshold: Int64, maxSize: Int64) {
        this.slowQueryThreshold = threshold
        this.maxLogSize = maxSize
    }

    /**
     * Stores and prints the query if duration >= the current threshold.
     * The threshold is read under the same lock as setThreshold writes it.
     * Printing happens after the lock is released, so console I/O does not
     * block other threads.
     */
    public func logQuery(sql: String, duration: Int64, params: ArrayList<String>): Unit {
        var line = ""
        lock.lock()
        try {
            if (duration < slowQueryThreshold) {
                return
            }
            let log = SlowQueryLog(sql, duration, params)
            logs.add(log)
            // Drop oldest entries. A loop (rather than a single removal) keeps the
            // bound correct even when maxLogSize <= 0.
            while (logs.size > maxLogSize && !logs.isEmpty()) {
                logs.remove(0..1)
            }
            line = log.toLogString()
        } finally {
            lock.unlock()
        }
        println(line)
    }

    /** Returns a snapshot copy of the retained entries (the entries themselves are shared). */
    public func getLogs(): ArrayList<SlowQueryLog> {
        lock.lock()
        try {
            ArrayList<SlowQueryLog>(logs)
        } finally {
            lock.unlock()
        }
    }

    /** Removes all retained entries. */
    public func clear(): Unit {
        lock.lock()
        try {
            logs.clear()
        } finally {
            lock.unlock()
        }
    }

    /** Returns the number of currently retained entries (bounded by maxLogSize). */
    public func getSlowQueryCount(): Int64 {
        lock.lock()
        try {
            logs.size
        } finally {
            lock.unlock()
        }
    }

    /** Updates the minimum duration for a query to be logged. */
    public func setThreshold(threshold: Int64): Unit {
        lock.lock()
        try {
            this.slowQueryThreshold = threshold
        } finally {
            lock.unlock()
        }
    }
}

// ============================================================================
// 连接池健康检查
// ============================================================================

/**
 * Result of one connection-pool health check.
 * healthy stays true until the first issue is added.
 */
public class ConnectionPoolHealth {
    public var healthy: Bool = true
    public var activeConnections: Int64 = 0
    public var idleConnections: Int64 = 0
    public var totalConnections: Int64 = 0
    // Unix time in milliseconds of the check, as set by the checker.
    public var lastCheckTime: Int64 = 0
    public var issues: ArrayList<String> = ArrayList<String>()

    public init() {}

    /** Records an issue and marks the pool as unhealthy. */
    public func addIssue(issue: String): Unit {
        issues.add(issue)
        healthy = false
    }

    /**
     * Serializes the report as a JSON object (lastCheckTime is not included).
     * Issue texts are JSON-escaped, so quotes, backslashes or control characters
     * in an issue cannot produce invalid JSON.
     */
    public func toJson(): String {
        let sb = StringBuilder()
        sb.append("{\"healthy\":${healthy},\"activeConnections\":${activeConnections},")
        sb.append("\"idleConnections\":${idleConnections},\"totalConnections\":${totalConnections},")
        sb.append("\"issues\":[")
        var first = true
        for (issue in issues) {
            if (!first) {
                sb.append(",")
            }
            first = false
            sb.append("\"")
            sb.append(escapeJson(issue))
            sb.append("\"")
        }
        sb.append("]}")
        sb.toString()
    }

    // Escapes a string for use inside a JSON string literal. Only ASCII bytes are
    // rewritten, and untouched runs are copied as slices whose edges sit next to
    // ASCII bytes (valid character boundaries). Non-ASCII text is preserved.
    private func escapeJson(s: String): String {
        let sb = StringBuilder()
        var runStart: Int64 = 0
        for (i in 0..s.size) {
            let replacement = jsonEscapeFor(s[i])
            if (!replacement.isEmpty()) {
                if (runStart < i) {
                    sb.append(s[runStart..i])
                }
                sb.append(replacement)
                runStart = i + 1
            }
        }
        if (runStart < s.size) {
            sb.append(s[runStart..s.size])
        }
        sb.toString()
    }

    // Returns the JSON escape for byte b, or "" when b can be emitted as-is.
    private func jsonEscapeFor(b: UInt8): String {
        if (b == 34u8) {
            "\\\""
        } else if (b == 92u8) {
            "\\\\"
        } else if (b == 10u8) {
            "\\n"
        } else if (b == 13u8) {
            "\\r"
        } else if (b == 9u8) {
            "\\t"
        } else if (b < 32u8) {
            let hex = "0123456789ABCDEF"
            "\\u00" + String.fromUtf8([hex[Int64(b >> 4)], hex[Int64(b & 0x0F)]])
        } else {
            ""
        }
    }
}

/**
 * Produces ConnectionPoolHealth reports for a ConnectionPool.
 */
public class ConnectionPoolHealthChecker {
    private let pool: ConnectionPool
    private var maxActiveRatio: Float64 = 0.8

    /**
     * @param pool           pool to inspect
     * @param maxActiveRatio active/total ratio above which an issue is reported (default 0.8)
     */
    public init(pool: ConnectionPool, maxActiveRatio!: Float64 = 0.8) {
        this.pool = pool
        this.maxActiveRatio = maxActiveRatio
    }

    /**
     * Snapshots the pool counters and flags two issues: a pool with zero
     * connections, or an active/total ratio above maxActiveRatio.
     */
    public func checkHealth(): ConnectionPoolHealth {
        let health = ConnectionPoolHealth()
        health.lastCheckTime = DateTime.now().toUnixTimeStamp().toMilliseconds()

        health.activeConnections = pool.getActiveConnections()
        health.totalConnections = pool.getTotalConnections()
        health.idleConnections = pool.getIdleConnections()

        if (health.totalConnections > 0) {
            // Direct numeric conversion; no need to round-trip through strings.
            let activeRatio = Float64(health.activeConnections) / Float64(health.totalConnections)
            if (activeRatio > maxActiveRatio) {
                health.addIssue("活跃连接比例过高")
            }
        } else if (health.totalConnections == 0) {
            health.addIssue("连接池为空")
        }

        health
    }
}

// ============================================================================
// SQL注入防护
// ============================================================================

/**
 * Heuristic SQL-injection detector, for logging/alerting support only. It does
 * NOT replace parameterized queries.
 *
 * All scanning works on UTF-8 bytes with ASCII-only case folding. Non-ASCII
 * text such as Chinese literals is handled safely, with no conversion of
 * single bytes to String (which throws on multi-byte characters).
 */
public class SqlInjectionDetector {
    private var dangerousKeywords: HashSet<String> = HashSet<String>()
    private var dangerousPatterns: ArrayList<String> = ArrayList<String>()

    public init() {
        dangerousKeywords.add("DROP")
        dangerousKeywords.add("DELETE")
        dangerousKeywords.add("TRUNCATE")
        dangerousKeywords.add("ALTER")
        dangerousKeywords.add("EXEC")
        dangerousKeywords.add("EXECUTE")
        dangerousKeywords.add("UNION")
        dangerousKeywords.add("INSERT")
        dangerousKeywords.add("UPDATE")

        dangerousPatterns.add("--")
        dangerousPatterns.add("/*")
        dangerousPatterns.add("*/")
        dangerousPatterns.add(";")
    }

    /**
     * Scans `sql` in two passes: first for dangerous keywords occurring outside
     * single-quoted literals, then for dangerous patterns anywhere.
     *
     * Keywords match case-insensitively and only as whole words. Identifiers
     * such as `updated_at` or `is_deleted` are therefore not reported.
     *
     * @return a description of the first finding, or "" when nothing is found.
     *         When several keywords match, which one is reported depends on
     *         HashSet iteration order.
     */
    public func detect(sql: String): String {
        for (keyword in dangerousKeywords) {
            if (!isInString(sql, keyword)) {
                return "检测到潜在危险关键字: ${keyword}"
            }
        }

        for (pattern in dangerousPatterns) {
            if (sql.contains(pattern)) {
                return "检测到潜在危险模式: ${pattern}"
            }
        }

        ""
    }

    /**
     * Heuristic parameter check; real protection must come from parameterized
     * queries.
     *
     * A parameter is rejected when a single quote is followed, after optional
     * spaces or tabs, by `OR `, `AND `, `--` or `;` (case-insensitive). Both
     * `'OR 1=1` and the classic `' OR '1'='1` are rejected. Runs in O(n) with
     * no allocation per byte.
     *
     * @return true when the parameter looks safe (an empty parameter is always safe)
     */
    public func validateParam(param: String): Bool {
        if (param.isEmpty()) { return true }

        for (i in 0..param.size) {
            if (param[i] != 39u8) { // not '\''
                continue
            }
            var j = i + 1
            while (j < param.size && (param[j] == 32u8 || param[j] == 9u8)) {
                j++
            }
            if (matchesAtIgnoreCase(param, j, "OR ") || matchesAtIgnoreCase(param, j, "AND ") ||
                matchesAtIgnoreCase(param, j, "--") || matchesAtIgnoreCase(param, j, ";")) {
                return false
            }
        }
        true
    }

    // Upper-cases an ASCII letter byte; all other bytes (including UTF-8 multi-byte parts) are unchanged.
    private func asciiUpper(b: UInt8): UInt8 {
        if (b >= 97u8 && b <= 122u8) { b - 32u8 } else { b }
    }

    // True when `s` contains `prefix` at byte offset `offset`, ignoring ASCII case.
    private func matchesAtIgnoreCase(s: String, offset: Int64, prefix: String): Bool {
        if (offset < 0 || offset + prefix.size > s.size) {
            return false
        }
        for (k in 0..prefix.size) {
            if (asciiUpper(s[offset + k]) != asciiUpper(prefix[k])) {
                return false
            }
        }
        true
    }

    // ASCII identifier byte: [A-Za-z0-9_].
    private func isIdentifierByte(b: UInt8): Bool {
        (b >= 48u8 && b <= 57u8) || (b >= 65u8 && b <= 90u8) ||
            (b >= 97u8 && b <= 122u8) || b == 95u8
    }

    // True when s[start..start+len] is not glued to identifier bytes on either side.
    private func isWholeWordAt(s: String, start: Int64, len: Int64): Bool {
        let end = start + len
        let startsWord = start == 0 || !isIdentifierByte(s[start - 1])
        let endsWord = end >= s.size || !isIdentifierByte(s[end])
        startsWord && endsWord
    }

    // Returns false when `keyword` occurs as a whole word outside single-quoted
    // literals, which needs attention. Returns true when every occurrence is
    // inside a literal or there is none, which is treated as safe.
    private func isInString(sql: String, keyword: String): Bool {
        if (keyword.isEmpty()) { return true }
        var inSingleQuote = false
        var i: Int64 = 0

        while (i < sql.size) {
            if (sql[i] == 39u8) { // '\''
                // '' is an escaped quote: consume both bytes, quote state unchanged.
                if (i + 1 < sql.size && sql[i + 1] == 39u8) {
                    i += 2
                    continue
                }
                inSingleQuote = !inSingleQuote
            } else if (!inSingleQuote && matchesAtIgnoreCase(sql, i, keyword) &&
                       isWholeWordAt(sql, i, keyword.size)) {
                return false
            }
            i++
        }
        true
    }
}

// ============================================================================
// ORM异常
// ============================================================================

/**
 * Exception raised by the ORM layer, e.g. when a SQL or parameter safety check fails.
 */
public class OrmException <: Exception {
    /** @param message human-readable description of the failure */
    public init(message: String) { super(message) }
}

// ============================================================================
// SQL执行监控器
// ============================================================================

/**
 * Central SQL execution monitor. It combines optional injection checks,
 * per-statement statistics and slow-query logging.
 */
public class SqlExecutionMonitor {
    private var sqlStats: HashMap<String, SqlExecutionStats> = HashMap<String, SqlExecutionStats>()
    private var slowQueryLogger: SlowQueryLogger = SlowQueryLogger()
    private var injectionDetector: SqlInjectionDetector = SqlInjectionDetector()
    private var lock: ReentrantMutex = ReentrantMutex()

    // Milliseconds; also forwarded to the slow-query logger by setSlowQueryThreshold.
    private var slowQueryThreshold: Int64 = 1000
    private var enableStats: Bool = true
    private var enableSlowQueryLog: Bool = true
    private var enableInjectionCheck: Bool = true

    public init() {}

    /**
     * Runs the enabled safety checks and returns a context holding the start
     * time in Unix milliseconds.
     *
     * @throws OrmException if injection checking is enabled and either the SQL
     *         or a parameter fails the check.
     *
     * NOTE(review): SqlInjectionDetector treats INSERT/UPDATE/DELETE as
     * dangerous keywords. With injection checking enabled (the default),
     * ordinary write statements are therefore rejected here, while the detector
     * itself is documented as logging/alerting only. Confirm whether throwing
     * is intended.
     */
    public func startMonitoring(sql: String, params: ArrayList<String>): MonitoringContext {
        if (enableInjectionCheck) {
            let checkResult = injectionDetector.detect(sql)
            if (!checkResult.isEmpty()) {
                throw OrmException("SQL安全检查失败: ${checkResult}")
            }

            for (param in params) {
                if (!injectionDetector.validateParam(param)) {
                    throw OrmException("参数安全检查失败: ${param}")
                }
            }
        }

        MonitoringContext(sql, params, DateTime.now().toUnixTimeStamp().toMilliseconds())
    }

    /**
     * Finishes monitoring one execution. It records statistics and logs the
     * query when it is slow.
     *
     * The duration is measured with wall-clock time. A negative value (the
     * clock moved backwards) is clamped to 0 so it cannot corrupt min/avg
     * statistics.
     */
    public func endMonitoring(context: MonitoringContext, rowsAffected: Int64): Unit {
        let elapsed = DateTime.now().toUnixTimeStamp().toMilliseconds() - context.startTime
        let duration: Int64 = if (elapsed < 0) { 0 } else { elapsed }

        if (enableStats) {
            recordExecution(context.sql, duration, rowsAffected)
        }

        if (enableSlowQueryLog && duration >= slowQueryThreshold) {
            slowQueryLogger.logQuery(context.sql, duration, context.params)
        }
    }

    /** Counts a failed execution of `sql`. */
    public func recordExecutionError(sql: String): Unit {
        recordError(sql)
    }

    /**
     * Returns a snapshot copy of the statistics map. The SqlExecutionStats
     * values are shared with the monitor, not deep-copied.
     */
    public func getSqlStats(): HashMap<String, SqlExecutionStats> {
        lock.lock()
        try {
            let copy = HashMap<String, SqlExecutionStats>()
            for ((key, value) in sqlStats) {
                copy[key] = value
            }
            copy
        } finally {
            lock.unlock()
        }
    }

    public func getSlowQueryLogs(): ArrayList<SlowQueryLog> {
        slowQueryLogger.getLogs()
    }

    /** Sets the slow threshold for both statistics and the slow-query logger. */
    public func setSlowQueryThreshold(threshold: Int64): Unit {
        this.slowQueryThreshold = threshold
        slowQueryLogger.setThreshold(threshold)
    }

    public func setEnableStats(enable: Bool): Unit {
        this.enableStats = enable
    }

    public func setEnableSlowQueryLog(enable: Bool): Unit {
        this.enableSlowQueryLog = enable
    }

    public func setEnableInjectionCheck(enable: Bool): Unit {
        this.enableInjectionCheck = enable
    }

    private func recordExecution(sql: String, duration: Int64, rowsAffected: Int64): Unit {
        lock.lock()
        try {
            getOrCreateStats(sql).recordExecution(duration, rowsAffected, slowQueryThreshold)
        } finally {
            lock.unlock()
        }
    }

    private func recordError(sql: String): Unit {
        lock.lock()
        try {
            getOrCreateStats(sql).recordError()
        } finally {
            lock.unlock()
        }
    }

    // Looks up the stats entry for `sql`, creating it on first use. Caller must hold `lock`.
    private func getOrCreateStats(sql: String): SqlExecutionStats {
        let sqlHash = hashSql(sql)
        if (let Some(existing) <- sqlStats.get(sqlHash)) {
            return existing
        }
        let created = SqlExecutionStats(sqlHash, previewSql(sql))
        sqlStats[sqlHash] = created
        created
    }

    // Grouping key derived from the FULL SQL text (hash plus byte length).
    // The old 50-byte prefix merged distinct statements that shared a prefix,
    // which is common with ORM-generated column lists. It could also throw
    // when byte 50 fell inside a multi-byte UTF-8 character.
    private func hashSql(sql: String): String {
        "${sql.hashCode()}_${sql.size}"
    }

    // Returns sql unchanged when it is at most 100 bytes. Otherwise returns its
    // first <= 100 bytes plus "...", cutting only on a UTF-8 character boundary.
    private func previewSql(sql: String): String {
        let maxBytes: Int64 = 100
        if (sql.size <= maxBytes) {
            return sql
        }
        var cut = maxBytes
        // Back up over UTF-8 continuation bytes (bit pattern 10xxxxxx: byte & 192 == 128).
        while (cut > 0 && (sql[cut] & 192u8) == 128u8) {
            cut--
        }
        sql[0..cut] + "..."
    }
}

/**
 * Per-execution state handed from startMonitoring to endMonitoring.
 * startTime holds the value supplied by the creator; SqlExecutionMonitor
 * passes Unix time in milliseconds.
 */
public class MonitoringContext {
    public var sql: String
    public var params: ArrayList<String>
    public var startTime: Int64

    public init(sql: String, params: ArrayList<String>, startTime: Int64) {
        this.startTime = startTime
        this.params = params
        this.sql = sql
    }
}