/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
 * This source file is part of the Cangjie project, licensed under Apache-2.0
 * with Runtime Library Exception.
 *
 * See https://cangjie-lang.cn/pages/LICENSE for license information.
 */

package stdx.net.tls

import std.collection.{Map, HashMap}
import stdx.net.tls.common.*
import stdx.crypto.x509.*
import stdx.crypto.common.*
import stdx.crypto.keys.RSAPrivateKey

/*
 * Native helper exported by the TLS C layer.
 * NOTE(review): not referenced anywhere in this file. Judging only by its name it produces a
 * placeholder private key of `bits` bits; confirm the returned CString's format and who owns/frees it.
 */
foreign func CJ_TLS_GenerateFakePrivateKey(bits: Int64): CString

/*
 * Status codes returned by DYN_CJ_TLS_InitEmbeddedKeylessProvider when the embedded keyless
 * provider is initialized while configuring a KeylessTlsServerConfig.
 */
const KEYLESS_LOAD_SUCCESS = 0i8
const KEYLESS_PROVIDER_ADD_FAILED = 1i8
const KEYLESS_PROVIDER_LOAD_FAILED = 2i8

/**
 * TLS server configuration backed by a certificate chain and the matching private key.
 *
 * Setters validate their input: ALPN protocol names and cipher suite names are checked with
 * `checkString`, and `securityLevel` is restricted to the range 0..5.
 */
public struct TlsServerConfig <: TlsConfig {
    private var _supportedAlpnProtocols: Array<String> = []
    private var _certificate: (Array<X509Certificate>, PrivateKey)
    private var _dhParameters: ?DHParameters = None
    /* OpenSSL level 2: DH keys shorter than 2048 bits and ECC keys shorter than 224 bits are rejected */
    private var _securityLevel: Int32 = 2

    /* Client certificate verify mode */
    private var _verifyMode: CertificateVerifyMode = CertificateVerifyMode.Default
    /* Supported TLS versions; empty by default */
    private var _supportedVersions: Array<TlsVersion> = []
    /* Cipher suite names keyed by the TLS version they apply to */
    private var _supportedCipherSuites: Map<TlsVersion, Array<String>> = HashMap<TlsVersion, Array<String>>()
    /* Whether we require the client to send a certificate; disabled by default */
    private var _clientIdentityRequired: TlsClientIdentificationMode = Disabled

    /*
     * Callback that is invoked for every handshake providing TLS initial
     * key data that is useful for debugging and decrypting a recorded
     * network dump.
     */
    public var keylogCallback: ?(TlsSocket, String) -> Unit = None

    /**
     * Creates a server configuration.
     *
     * @param certChain server certificate chain.
     * @param certKey private key corresponding to the server certificate.
     */
    public init(
        certChain: Array<X509Certificate>,
        certKey: PrivateKey
    ) {
        this._certificate = (certChain, certKey)
    }

    /**
     * Server certificate chain and the corresponding private key, typed with the concrete
     * `X509Certificate` element type for internal use.
     */
    mut prop serverCertificate: (Array<X509Certificate>, PrivateKey) {
        get() {
            _certificate
        }
        set(value) {
            _certificate = value
        }
    }

    /**
     * Client certificate verification mode. Defaults to `CertificateVerifyMode.Default`.
     */
    public mut prop verifyMode: CertificateVerifyMode {
        get() {
            _verifyMode
        }
        set(v) {
            _verifyMode = v
        }
    }

    /**
     * A list of supported ALPN protocol names. If a client tries to negotiate ALPN by
     * providing its list of protocols, the server TLS socket intersects these lists and
     * negotiates a matching protocol (a protocol matches if it exists in both the server supported
     * list and the client requested list). Once negotiated, the resulting protocol name is available
     * in the TlsSocket instance.
     *
     * Clients that don't negotiate ALPN connect as usual and this list is ignored.
     *
     * @throws IllegalArgumentException if any protocol name contains a null character.
     */
    public mut prop supportedAlpnProtocols: Array<String> {
        get() {
            _supportedAlpnProtocols
        }
        set(v) {
            for (s in v) {
                checkString(s, "supportedAlpnProtocols")
            }
            _supportedAlpnProtocols = v
        }
    }

    /**
     * Supported TLS protocol versions. Empty by default.
     */
    public mut prop supportedVersions: Array<TlsVersion> {
        get() {
            _supportedVersions
        }
        set(v) {
            _supportedVersions = v
        }
    }

    /**
     * Cipher suite names per TLS version. An entry is only applied when its version is also
     * listed in `supportedVersions`. Every cipher suite name is validated with `checkString`.
     */
    public mut prop supportedCipherSuites: Map<TlsVersion, Array<String>> {
        get() {
            _supportedCipherSuites
        }
        set(v) {
            for ((_, cipherSuites) in v) {
                for (cipherSuite in cipherSuites) {
                    checkString(cipherSuite, "supportedCipherSuites")
                }
            }
            _supportedCipherSuites = v
        }
    }

    /**
     * Server certificate and the corresponding private key.
     *
     * @throws TlsException if set to `None`, or if the chain contains a certificate that is not
     * an `X509Certificate`.
     */
    public mut prop certificate: ?(Array<Certificate>, PrivateKey) {
        get() {
            // Identity map upcasts Array<X509Certificate> to Array<Certificate>.
            (_certificate[0].map({c => c}), _certificate[1])
        }
        set(v) {
            let cert = v ?? throw TlsException("The server certificate cannot be null.")
            _certificate = (cert[0].map({c => c as X509Certificate ??
                throw TlsException("Only certificates of type `X509Certificate` are allowed.")}), cert[1])
        }
    }

    /**
     * Whether the server requires the client to send a certificate. Defaults to `Disabled`.
     */
    public mut prop clientIdentityRequired: TlsClientIdentificationMode {
        get() {
            _clientIdentityRequired
        }
        set(v) {
            _clientIdentityRequired = v
        }
    }

    /**
     * Self-generated DH parameters for DH/DHE/ECDH/ECDHE ciphers.
     * When it's None, OpenSSL auto-generated DH parameters are used.
     */
    public mut prop dhParameters: ?DHParameters {
        get() {
            _dhParameters
        }
        set(value) {
            _dhParameters = value
        }
    }

    /**
     * Security level in the range 0..5, see OpenSSL SSL_CTX_set_security_level. Defaults to 2.
     *
     * @throws IllegalArgumentException if the value is outside 0..5.
     */
    public mut prop securityLevel: Int32 {
        get() {
            _securityLevel
        }
        set(value) {
            if (value < 0 || value > 5) {
                throw IllegalArgumentException("SecurityLevel should be from 0 to 5.")
            }
            _securityLevel = value
        }
    }
}

extend TlsContext {
    /**
     * Applies a certificate/private-key based server configuration to the native context.
     *
     * @param cfg server configuration to apply.
     * @param session optional server session; its name (or "" when absent) is used as the server session id.
     */
    func configureServer(cfg: TlsServerConfig, session: ?TlsServerSession): Unit {
        withContext<Unit> {
            nativeContext, _ => configureServerContext(nativeContext, cfg, session)
        }
    }

    /**
     * Applies a keyless server configuration to the native context.
     *
     * @param cfg keyless server configuration to apply.
     * @param session optional server session; its name (or "" when absent) is used as the server session id.
     */
    func configureServer(cfg: KeylessTlsServerConfig, session: ?TlsServerSession): Unit {
        withContext<Unit> {
            nativeContext, _ => configureServerContext(nativeContext, cfg, session)
        }
    }

    private func configureServerContext(
        context: CPointer<Ctx>,
        cfg: TlsServerConfig,
        session: ?TlsServerSession
    ): Unit {
        configureVerifyMode(cfg.verifyMode)
        configureSecurityLevel(cfg.securityLevel)
        configureClientIdentification(cfg.clientIdentityRequired)

        let (cert, key) = cfg.serverCertificate // The server certificate is not supposed to be null
        setCertificateChainAndPrivateKey(cert, key)

        setDHParam(cfg.dhParameters)

        configureServerContextProtocols(context, cfg)

        let sessionId = session?.name ?? ""
        setServerSessionId(context, sessionId)
    }

    private func configureServerContext(
        context: CPointer<Ctx>,
        cfg: KeylessTlsServerConfig,
        session: ?TlsServerSession
    ): Unit {
        configureVerifyMode(cfg.verifyMode)
        configureSecurityLevel(cfg.securityLevel)
        configureClientIdentification(cfg.clientIdentityRequired)

        loadKeylessProvider()

        // Only the chain is installed: the private key held by a keyless config is a dummy.
        let (cert, _) = cfg.serverCertificate // The server certificate is not supposed to be null
        setCertificateChainAndCallback(cert)

        setDHParam(cfg.dhParameters)

        // SNI is enabled inside configureServerContextProtocols, so no separate call is needed here.
        configureServerContextProtocols(context, cfg)

        let sessionId = session?.name ?? ""
        setServerSessionId(context, sessionId)
    }

    /*
     * Initializes the embedded keyless provider through its dynamically resolved native entry point.
     *
     * The DynMsg is checked before the status is interpreted: it presumably reports whether the
     * native symbol could be resolved, in which case the returned status carries no meaning.
     * Any status other than KEYLESS_LOAD_SUCCESS is rejected, including unknown codes.
     */
    private func loadKeylessProvider(): Unit {
        var dynMsg = DynMsg()
        let status = unsafe { DYN_CJ_TLS_InitEmbeddedKeylessProvider(inout dynMsg) }
        checkDynMsg(dynMsg)
        if (status == KEYLESS_PROVIDER_ADD_FAILED) {
            throw TlsException("Failed to add builtin keyless provider.")
        } else if (status == KEYLESS_PROVIDER_LOAD_FAILED) {
            throw TlsException("Failed to load keyless provider.")
        } else if (status != KEYLESS_LOAD_SUCCESS) {
            throw TlsException("Unexpected keyless provider load status: ${status}.")
        }
    }

    /*
     * Applies the settings shared by both server configuration kinds: SNI, protocol versions,
     * per-version cipher suites (only for versions that are enabled) and ALPN protocols.
     */
    private func configureServerContextProtocols(context: CPointer<Ctx>, cfg: TlsConfig): Unit {
        enableSNI(context)

        setProtoVersions(context, cfg.supportedVersions)

        if (cfg.supportedVersions.contains(V1_2) &&
            let Some(cipherSuite) <- cfg.supportedCipherSuites.entryView(V1_2).value) {
            setCipherSuitesV1_2(context, cipherSuite)
        }

        if (cfg.supportedVersions.contains(V1_3) &&
            let Some(cipherSuite) <- cfg.supportedCipherSuites.entryView(V1_3).value) {
            setCipherSuitesV1_3(context, cipherSuite)
        }

        let alpnList = cfg.supportedAlpnProtocols
        if (!alpnList.isEmpty()) {
            setServerAlpnProtos(context, alpnList)
        }
    }

    /*
     * Registers the native keyless sign callback (and the decrypt callback when provided) under
     * the keyless key id derived from the native context. The key id is always freed afterwards.
     *
     * Throws TlsException if the key id cannot be derived; each dynamic registration call is
     * validated with checkDynMsg.
     */
    internal func setKeylessCallbacks(
        signCallback: CKeylessSignCallback,
        decryptCallback: ?CKeylessDecryptCallback,
        exception: CPointer<ExceptionData>
    ): Unit {
        withContext<Unit> {
            nativeContext, _ => unsafe {
                let keyId = CJ_TLS_GetKeylessKeyId(nativeContext)
                if (keyId.isNull()) {
                    throw TlsException("Failed to derive keyless key identifier.")
                }
                try {
                    var signMsg = DynMsg()
                    DYN_CJ_TLS_RegisterKeylessSignCallback(keyId, signCallback, exception, inout signMsg)
                    checkDynMsg(signMsg)
                    if (let Some(cb) <- decryptCallback) {
                        // Same dynamic lookup as the sign registration, so it must be validated too.
                        var decryptMsg = DynMsg()
                        DYN_CJ_TLS_RegisterKeylessDecryptCallback(keyId, cb, exception, inout decryptMsg)
                        checkDynMsg(decryptMsg)
                    }
                } finally {
                    CJ_TLS_FreeKeylessId(keyId)
                }
            }
        }
    }
}


/*
 * Keyless TLS server configuration.
 *
 * In keyless mode the configuration holds only a dummy private key; private-key operations
 * are delegated to the user-supplied callbacks below.
 */

/*
 * User callback for keyless signing: receives `hashValue` and returns the resulting bytes
 * (presumably the signature produced by the external key holder; confirm against callers).
 */
public type KeylessSignFunc = (hashValue: Array<Byte>) -> Array<Byte>
/*
 * Optional user callback for keyless decryption: receives `cipherText` and returns the
 * resulting bytes (presumably the plaintext; confirm against callers).
 */
public type KeylessDecryptFunc = (cipherText: Array<Byte>) -> Array<Byte>

/*
 * C-callable counterparts registered with the native layer by TlsContext.setKeylessCallbacks.
 * NOTE(review): parameter meaning is not visible here; presumably string identifiers, an input
 * buffer with its length, and an out-pointer receiving the output length, returning the output buffer.
 */
type CKeylessSignCallback = CFunc<(CString, CString, CPointer<Byte>, Int64, CPointer<Int64>) -> CPointer<Byte>>
type CKeylessDecryptCallback = CFunc<(CString, CPointer<Byte>, Int64, CPointer<Int64>) -> CPointer<Byte>>

/**
 * TLS server configuration for keyless operation: the server holds the certificate chain,
 * while private-key operations go through the user-supplied sign (and optional decrypt) callbacks.
 *
 * Setters validate their input: ALPN protocol names and cipher suite names are checked with
 * `checkString`, `securityLevel` is restricted to 0..5, and the certificate chain must not be empty.
 */
public class KeylessTlsServerConfig <: TlsConfig {
    private var _supportedAlpnProtocols: Array<String> = []
    private var _certificate: (Array<X509Certificate>, PrivateKey)
    private var _dhParameters: ?DHParameters = None
    /* OpenSSL level 2: DH keys shorter than 2048 bits and ECC keys shorter than 224 bits are rejected */
    private var _securityLevel: Int32 = 2

    /* Client certificate verify mode */
    private var _verifyMode: CertificateVerifyMode = CertificateVerifyMode.Default
    /* Supported TLS versions; empty by default */
    private var _supportedVersions: Array<TlsVersion> = []
    /* Cipher suite names keyed by the TLS version they apply to */
    private var _supportedCipherSuites: Map<TlsVersion, Array<String>> = HashMap<TlsVersion, Array<String>>()
    /* Whether we require the client to send a certificate; disabled by default */
    private var _clientIdentityRequired: TlsClientIdentificationMode = Disabled

    /* User callback used for keyless signing */
    internal var _keylessSignFunc: KeylessSignFunc
    /* Optional user callback used for keyless decryption */
    internal var _keylessDecryptFunc: ?KeylessDecryptFunc = None<KeylessDecryptFunc>

    /*
     * Callback that is invoked for every handshake providing TLS initial
     * key data that is useful for debugging and decrypting a recorded
     * network dump.
     */
    public var keylogCallback: ?(TlsSocket, String) -> Unit = None

    /**
     * Server certificate chain and the (dummy) private key, typed with the concrete
     * `X509Certificate` element type for internal use.
     */
    mut prop serverCertificate: (Array<X509Certificate>, PrivateKey) {
        get() {
            _certificate
        }
        set(value) {
            _certificate = value
        }
    }

    /**
     * Client certificate verification mode. Defaults to `CertificateVerifyMode.Default`.
     */
    public mut prop verifyMode: CertificateVerifyMode {
        get() {
            _verifyMode
        }
        set(v) {
            _verifyMode = v
        }
    }

    /**
     * A list of supported ALPN protocol names. If a client tries to negotiate ALPN by
     * providing its list of protocols, the server TLS socket intersects these lists and
     * negotiates a matching protocol (a protocol matches if it exists in both the server supported
     * list and the client requested list). Once negotiated, the resulting protocol name is available
     * in the TlsSocket instance.
     *
     * Clients that don't negotiate ALPN connect as usual and this list is ignored.
     *
     * @throws IllegalArgumentException if any protocol name contains a null character.
     */
    public mut prop supportedAlpnProtocols: Array<String> {
        get() {
            _supportedAlpnProtocols
        }
        set(v) {
            for (s in v) {
                checkString(s, "supportedAlpnProtocols")
            }
            _supportedAlpnProtocols = v
        }
    }

    /**
     * Supported TLS protocol versions. Empty by default.
     */
    public mut prop supportedVersions: Array<TlsVersion> {
        get() {
            _supportedVersions
        }
        set(v) {
            _supportedVersions = v
        }
    }

    /**
     * Cipher suite names per TLS version. An entry is only applied when its version is also
     * listed in `supportedVersions`. Every cipher suite name is validated with `checkString`.
     */
    public mut prop supportedCipherSuites: Map<TlsVersion, Array<String>> {
        get() {
            _supportedCipherSuites
        }
        set(v) {
            for ((_, cipherSuites) in v) {
                for (cipherSuite in cipherSuites) {
                    checkString(cipherSuite, "supportedCipherSuites")
                }
            }
            _supportedCipherSuites = v
        }
    }

    /**
     * Server certificate and the corresponding (dummy) private key.
     *
     * @throws TlsException if set to `None`, or if the chain contains a certificate that is not
     * an `X509Certificate`.
     * @throws IllegalArgumentException if the certificate chain is empty.
     */
    public mut prop certificate: ?(Array<Certificate>, PrivateKey) {
        get() {
            // Identity map upcasts Array<X509Certificate> to Array<Certificate>.
            (_certificate[0].map({c => c}), _certificate[1])
        }
        set(v) {
            let cert = v ?? throw TlsException("The server certificate cannot be null.")
            // Keep the invariant established by init: a keyless config always has a non-empty chain.
            if (cert[0].isEmpty()) {
                throw IllegalArgumentException("The server certificate cannot be empty.")
            }
            _certificate = (cert[0].map({c => c as X509Certificate ??
                throw TlsException("Only certificates of type `X509Certificate` are allowed.")}), cert[1])
        }
    }

    /**
     * Whether the server requires the client to send a certificate. Defaults to `Disabled`.
     */
    public mut prop clientIdentityRequired: TlsClientIdentificationMode {
        get() {
            _clientIdentityRequired
        }
        set(v) {
            _clientIdentityRequired = v
        }
    }

    /**
     * Self-generated DH parameters for DH/DHE/ECDH/ECDHE ciphers.
     * When it's None, OpenSSL auto-generated DH parameters are used.
     */
    public mut prop dhParameters: ?DHParameters {
        get() {
            _dhParameters
        }
        set(value) {
            _dhParameters = value
        }
    }

    /**
     * Security level in the range 0..5, see OpenSSL SSL_CTX_set_security_level. Defaults to 2.
     *
     * @throws IllegalArgumentException if the value is outside 0..5.
     */
    public mut prop securityLevel: Int32 {
        get() {
            _securityLevel
        }
        set(value) {
            if (value < 0 || value > 5) {
                throw IllegalArgumentException("SecurityLevel should be from 0 to 5.")
            }
            _securityLevel = value
        }
    }

    /**
     * Creates a keyless server configuration.
     *
     * @param certChain server certificate chain; must not be empty.
     * @param signCallback callback performing the keyless signing operation.
     * @param decryptCallback optional callback performing the keyless decryption operation.
     *
     * @throws IllegalArgumentException if `certChain` is empty.
     */
    public init(certChain: Array<X509Certificate>, signCallback: KeylessSignFunc, decryptCallback!: ?KeylessDecryptFunc = None<KeylessDecryptFunc>) {
        if (certChain.isEmpty()) {
            throw IllegalArgumentException("The server certificate cannot be empty.")
        }
        this._keylessSignFunc = signCallback
        this._keylessDecryptFunc = decryptCallback
        // Dummy private key, not used in keyless TLS.
        // NOTE(review): RSAPrivateKey(2048) presumably generates a real 2048-bit key, which is costly
        // for a placeholder; consider a cheaper stand-in if the API allows it.
        this._certificate = (certChain, RSAPrivateKey(2048))
    }
}