/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
 * This source file is part of the Cangjie project, licensed under Apache-2.0
 * with Runtime Library Exception.
 *
 * See https://cangjie-lang.cn/pages/LICENSE for license information.
 */

package stdx.net.tls

import stdx.net.tls.common.*
import stdx.crypto.x509.*
import stdx.crypto.common.*

/**
 * Handshake configuration variants.
 *
 * Each constructor pairs a role-specific configuration object with an optional
 * session value (presumably a previously established session used for
 * resumption — TODO confirm against the handshake code).
 */
enum HandshakeConfig {
    // Client-side handshake: client configuration plus an optional client session.
    | Client(TlsClientConfig, ?TlsClientSession) // cjlint-ignore !G.ENU.01
    // Server-side handshake: server configuration plus an optional server session.
    | Server(TlsServerConfig, ?TlsServerSession) // cjlint-ignore !G.ENU.01
    // Server-side handshake using a keyless server configuration plus an optional server session.
    | KeylessServer(KeylessTlsServerConfig, ?TlsServerSession)
}

extend TlsContext {

    /**
     * Installs `chain` on the native context for keyless operation.
     *
     * Runs the private overload of the same name inside `withContext`, forwarding the
     * native context pointer and the exception slot that `withContext` supplies.
     */
    func setCertificateChainAndCallback(chain: Array<X509Certificate>): Unit {
        withContext<Unit> {
            ctxPtr, excPtr =>
                setCertificateChainAndCallback(ctxPtr, chain, excPtr)
        }
    }
    /**
     * Installs `chain` together with its private `key` on the native context.
     *
     * Runs the private overload of the same name inside `withContext`, forwarding the
     * native context pointer and the exception slot that `withContext` supplies.
     */
    func setCertificateChainAndPrivateKey(chain: Array<X509Certificate>, key: PrivateKey): Unit {
        withContext<Unit> {
            ctxPtr, excPtr =>
                setCertificateChainAndPrivateKey(ctxPtr, chain, key, excPtr)
        }
    }

    /**
     * Applies the optional Diffie-Hellman parameters `key` to the native context.
     *
     * Runs the private overload of the same name inside `withContext`; a `None` value is
     * forwarded as-is and handled by that overload.
     */
    func setDHParam(key: ?DHParameters): Unit {
        withContext<Unit> {
            ctxPtr, excPtr =>
                setDHParam(ctxPtr, key, excPtr)
        }
    }

    /**
     * Applies the certificate verification mode `verifyMode` to the native context.
     *
     * Runs the private overload of the same name inside `withContext`.
     */
    func configureVerifyMode(verifyMode: CertificateVerifyMode): Unit {
        withContext<Unit> {
            ctxPtr, excPtr =>
                configureVerifyMode(ctxPtr, verifyMode, excPtr)
        }
    }

    /**
     * Applies the client identification mode `verifyMode` to the native context.
     *
     * Runs `configureClientVerifyMode` inside `withContext`; the exception slot offered by
     * `withContext` is not used by this operation.
     */
    func configureClientIdentification(verifyMode: TlsClientIdentificationMode) {
        withContext<Unit> {
            ctxPtr, _ =>
                configureClientVerifyMode(ctxPtr, verifyMode)
        }
    }

    /**
     * Sets the security level of the native context to `level`.
     *
     * Throws TlsException when the native call reports failure (returns 0).
     */
    func configureSecurityLevel(level: Int32): Unit {
        withContext<Unit> {
            ctxPtr, _ =>
                let status = unsafe { CJ_TLS_SetSecurityLevel(ctxPtr, level) }
                // A zero status is the failure indicator of the native call.
                if (status == 0) {
                    throw TlsException("Failed to set tls socket security level.")
                }
        }
    }

    /**
     * Builds the native-facing signing callback used for keyless TLS.
     *
     * The returned callback:
     *  1. looks up the sign function registered in `TlsSocket.keylessCallback` under `key`;
     *  2. copies `digestLen` bytes from `digestPtr` into a managed array;
     *  3. invokes the sign function on that digest;
     *  4. copies the signature into a buffer obtained from `LibC.malloc`, stores the
     *     signature length through `written` and returns the buffer.
     *
     * The `alg` argument is accepted but not used here. The returned buffer is not freed in
     * this function (presumably the native caller releases it — TODO confirm).
     *
     * Throws TlsException when no callback is registered for `key`, or when the output
     * buffer cannot be allocated.
     */
    private func signCallbackWrapper(): CKeylessSignCallback {
        return {
            key: CString, alg: CString, digestPtr: CPointer<Byte>, digestLen: Int64, written: CPointer<Int64> =>
                let (signCb, _) = TlsSocket.keylessCallback.get(key.toString()) ?? throw TlsException("Missing keyless sign callback.")

                // Copy the native digest into managed memory before handing it to user code.
                let digest = Array<Byte>(digestLen, repeat: 0)
                for (i in 0..digestLen) {
                    unsafe { digest[i] = digestPtr.read(i) }
                }

                let signature: Array<Byte> = signCb(digest)
                let sigPtr = unsafe { LibC.malloc<Byte>(count: signature.size) }
                // Writing through a null pointer would crash the process; fail with an exception
                // instead. An empty signature is exempt because nothing is written in that case.
                if (sigPtr.isNull() && signature.size > 0) {
                    throw TlsException("Failed to allocate memory for keyless signature.")
                }

                unsafe {
                    for (i in 0..signature.size) {
                        sigPtr.write(i, signature[i])
                    }

                    written.write(signature.size)
                }
                return sigPtr
        }
    }

    /**
     * Builds the native-facing decryption callback used for keyless TLS.
     *
     * The returned callback:
     *  1. looks up the callbacks registered in `TlsSocket.keylessCallback` under `key` and
     *     requires the optional decrypt function to be present;
     *  2. copies `ciphertextLen` bytes from `ciphertextPtr` into a managed array;
     *  3. invokes the decrypt function on that ciphertext;
     *  4. copies the plaintext into a buffer obtained from `LibC.malloc`, stores the plaintext
     *     length through `written`, zeroes the managed plaintext copy and returns the buffer.
     *
     * The returned buffer is not freed in this function (presumably the native caller
     * releases it — TODO confirm).
     *
     * Throws TlsException when no callbacks are registered for `key`, when the decrypt
     * function was not provided, or when the output buffer cannot be allocated.
     */
    private func decryptCallbackWrapper(): CKeylessDecryptCallback {
        return {
            key: CString, ciphertextPtr: CPointer<Byte>, ciphertextLen: Int64, written: CPointer<Int64> =>
                let (_, decryptCbOpt) = TlsSocket.keylessCallback.get(key.toString()) ?? throw TlsException("Missing keyless decrypt callback.")
                let decryptCb = decryptCbOpt ?? throw TlsException("Decrypt callback is not provided.")

                // Copy the native ciphertext into managed memory before handing it to user code.
                let ciphertext = Array<Byte>(ciphertextLen, repeat: 0)
                for (i in 0..ciphertextLen) {
                    unsafe { ciphertext[i] = ciphertextPtr.read(i) }
                }
                let plain: Array<Byte> = decryptCb(ciphertext)
                let plainPtr = unsafe { LibC.malloc<Byte>(count: plain.size) }
                // Writing through a null pointer would crash the process; fail with an exception
                // instead. An empty plaintext is exempt because nothing is written in that case.
                if (plainPtr.isNull() && plain.size > 0) {
                    // Do not leave the decrypted secret behind in managed memory on the error path.
                    plain.fill(0)
                    throw TlsException("Failed to allocate memory for keyless decrypt result.")
                }

                unsafe {
                    for (i in 0..plain.size) {
                        plainPtr.write(i, plain[i])
                    }
                    written.write(plain.size)
                }
                // Wipe the managed copy of the plaintext once it has been handed to native memory.
                plain.fill(0)
                return plainPtr

        }
    }

    /**
     * Installs every certificate of `chain` on `context`, then switches the context to
     * keyless operation by setting the keyless private key and the sign/decrypt callbacks.
     *
     * The first certificate is installed via the "use" path, all following ones via "add".
     */
    private func setCertificateChainAndCallback(
        context: CPointer<Ctx>,
        chain: Array<X509Certificate>, // expected to be non-empty; not checked here
        exception: CPointer<ExceptionData>
    ): Unit {
        for (index in 0..chain.size) {
            let pemEntry = PemEntry(PemEntry.LABEL_CERTIFICATE, chain[index].encodeToDer())
            // Only certificates after the first one are appended.
            useAddCertImpl(context, pemEntry, add: index > 0, exception: exception)
        }
        setKeylessPrivateKeyImpl(context, exception)
        setKeylessCallbacks(signCallbackWrapper(), decryptCallbackWrapper(), exception)
    }

    /**
     * Installs the private `key` and every certificate of `chain` on `context`, then asks
     * the native layer to confirm that key and certificate match.
     *
     * Throws IllegalArgumentException when `chain` is empty and TlsException when the
     * native key check does not succeed.
     */
    private func setCertificateChainAndPrivateKey(
        context: CPointer<Ctx>,
        chain: Array<X509Certificate>,
        key: PrivateKey,
        exception: CPointer<ExceptionData>
    ): Unit {
        if (chain.isEmpty()) {
            throw IllegalArgumentException("Certificate chain is empty: at least one certificate required.")
        }

        // The key is installed before any certificate.
        setPrivateKeyImpl(context, key.encodeToDer(), exception)

        for (index in 0..chain.size) {
            let pemEntry = PemEntry(PemEntry.LABEL_CERTIFICATE, chain[index].encodeToDer())
            // First certificate goes through "use", the remaining ones through "add".
            useAddCertImpl(context, pemEntry, add: index > 0, exception: exception)
        }

        let checkStatus = unsafe { CJ_TLS_CheckPrivateKey(context) }
        if (checkStatus <= 0) {
            throw TlsException("The Certificate chain doesn't match the private key.")
        }
    }

    /**
     * Sets the keyless private key on `context` through the native layer.
     *
     * When the native call returns 0, the error stored in `exception` is thrown, using the
     * given fallback message.
     */
    private func setKeylessPrivateKeyImpl(
        context: CPointer<Ctx>,
        exception: CPointer<ExceptionData>
    ) {
        let status = unsafe { CJ_TLS_SetKeylessPrivateKey(context, exception) }
        if (status == 0) {
            let failure = unsafe { exception.read() }
            failure.throwException(fallback: "Failed to set keyless private key.")
        }
    }

    /**
     * Hands the DER body of `key` to the native layer as the private key of `context`.
     *
     * The byte array is pinned for the duration of the native call and always released
     * afterwards. When the native call returns 0, the error stored in `exception` is thrown,
     * using the given fallback message.
     */
    private func setPrivateKeyImpl(
        context: CPointer<Ctx>,
        key: DerBlob,
        exception: CPointer<ExceptionData>
    ) {
        let derBytes = key.body
        let derLength = UIntNative(derBytes.size)
        let pinned = unsafe { acquireArrayRawData(derBytes) }
        let status = try {
            unsafe { CJ_TLS_SetPrivateKey(context, pinned.pointer, derLength, exception) }
        } finally {
            // Unpin the array whether or not the native call succeeded.
            unsafe { releaseArrayRawData(pinned) }
        }
        if (status == 0) {
            unsafe { exception.read().throwException(fallback: "Failed to set private key.") }
        }
    }

    /**
     * Passes the PEM encoding of `entry` to the native layer as a certificate of `context`.
     *
     * `add` selects the native entry point: `CJ_TLS_Add_Cert` when true, `CJ_TLS_Use_Cert`
     * when false. When the native call returns 0, the error stored in `exception` is thrown,
     * using the given fallback message.
     */
    private func useAddCertImpl(
        context: CPointer<Ctx>,
        entry: PemEntry,
        add!: Bool,
        exception!: CPointer<ExceptionData>
    ) {
        let status = unsafe {
            withPemImpl<Int32>(entry) {
                pemPtr, pemLen =>
                    if (add) {
                        CJ_TLS_Add_Cert(context, pemPtr, pemLen, exception)
                    } else {
                        CJ_TLS_Use_Cert(context, pemPtr, pemLen, exception)
                    }
            }
        }

        if (status == 0) {
            unsafe { exception.read().throwException(fallback: "Failed to add certificate.") }
        }
    }

    /**
     * Applies the optional Diffie-Hellman parameters `key` to `context`.
     *
     * With a value present, its PEM encoding is passed to the native layer; with `None`, a
     * default-constructed pointer and length 0 are passed instead. When the native call
     * returns 0, the error stored in `exception` is thrown, using the given fallback message.
     */
    private func setDHParam(
        context: CPointer<Ctx>,
        key: ?DHParameters,
        exception: CPointer<ExceptionData>
    ): Unit {
        let status = match (key) {
            case Some(params) =>
                let pemEntry = PemEntry(PemEntry.LABEL_DH_PARAMETERS, params.encodeToDer())
                unsafe {
                    withPemImpl<Int32>(pemEntry) {
                        pemPtr, pemLen => CJ_TLS_SetDHParam(context, pemPtr, pemLen, exception)
                    }
                }
            case None =>
                unsafe { CJ_TLS_SetDHParam(context, CPointer<Unit>(), 0, exception) }
        }
        if (status == 0) {
            unsafe { exception.read().throwException(fallback: "Failed to set DH Parameters.") }
        }
    }

    /**
     * Encodes `entry` as PEM text, pins the resulting byte array and invokes `block` with a
     * pointer to the bytes and their length.
     *
     * The array is released after `block` finishes, including when it throws. Returns
     * whatever `block` returns.
     */
    private func withPemImpl<R>(entry: PemEntry, block: (CPointer<Unit>, UIntNative) -> R): R {
        let encoded = Pem([entry]).encode().toArray()
        let encodedLength = UIntNative(encoded.size)
        let pinned = unsafe { acquireArrayRawData(encoded) }

        try {
            block(CPointer(pinned.pointer), encodedLength)
        } finally {
            // The pointer handed to `block` is only valid while the array stays pinned.
            unsafe { releaseArrayRawData(pinned) }
        }
    }

    /**
     * Applies `verifyMode` to the native `context`.
     *
     *  - `Default`: registers the system root certificates as CAs.
     *  - `TrustAll`: switches the context to trust-all mode.
     *  - `CustomCA(items)`: registers the supplied certificates as CAs, followed by the
     *    system root certificates. Every item must be an `X509Certificate`.
     *  - `CustomVerify(callback)`: installs the custom verification hook.
     *
     * Throws TlsException when a `CustomCA` item is not an `X509Certificate`, when the mode
     * is not one of the cases above, or when one of the delegated steps fails.
     */
    private func configureVerifyMode(
        context: CPointer<Ctx>,
        verifyMode: CertificateVerifyMode,
        exception: CPointer<ExceptionData>
    ): Unit {
        match (verifyMode) {
            case Default => configureCustomCA(context, X509Certificate.systemRootCerts(), exception)
            case TrustAll => setContextTrustAll(context)
            case CustomCA(items) =>
                let ca = items.map({c => c as X509Certificate ??
                    throw TlsException("Only certificates of type `X509Certificate` are allowed.")})

                configureCustomCA(context, ca, exception)
                configureCustomCA(context, X509Certificate.systemRootCerts(), exception)
            case CustomVerify(callback) => setContextCustomVerify(callback)
            // The exception must actually be thrown: merely constructing it would silently
            // accept an unsupported mode and leave verification unconfigured.
            case _ => throw TlsException("Unsupported verify mode.")
        }
    }

    /**
     * Registers every certificate in `items` as a CA certificate on `context`.
     *
     * Each certificate is PEM-encoded, pinned for the native call and released afterwards.
     * Throws TlsException as soon as one native call returns 0.
     *
     * NOTE(review): unlike the other helpers in this extension, the failure path throws a
     * fixed message and does not read the error stored in `exception`.
     */
    private func configureCustomCA(
        context: CPointer<Ctx>,
        items: Array<X509Certificate>,
        exception: CPointer<ExceptionData>
    ) {
        for (authority in items) {
            let pemBytes = authority.encodeToPem().encode().toArray()
            let pinned = unsafe { acquireArrayRawData(pemBytes) }
            try {
                let status = unsafe {
                    CJ_TLS_Add_CA(context, CPointer(pinned.pointer), UIntNative(pemBytes.size), exception)
                }
                if (status == 0) {
                    throw TlsException("Failed to assign CA certificate.")
                }
            } finally {
                // Unpin the PEM bytes whether or not the certificate was accepted.
                unsafe { releaseArrayRawData(pinned) }
            }
        }
    }

    /**
     * Switches `context` to trust-all verification through the native layer.
     *
     * Throws TlsException when the native call returns 0.
     */
    private func setContextTrustAll(context: CPointer<Ctx>) {
        let status = unsafe { CJ_TLS_SetTrustAll(context) }
        if (status == 0) {
            throw TlsException("Failed to set tls socket verify mode to trust all.")
        }
    }

    /**
     * Translates `verifyMode` into the (required, verify) flag pair expected by the native
     * layer and applies it to `context`.
     *
     * Throws TlsException when the native call returns 0.
     */
    private func configureClientVerifyMode(context: CPointer<Ctx>, verifyMode: TlsClientIdentificationMode) {
        let (required, verify) = match (verifyMode) {
            case Disabled => (Int32(0), Int32(0)) // not required, not verify
            case Optional => (Int32(0), Int32(1)) // not required, do verify
            case Required => (Int32(1), Int32(1)) // required, do verify
        }
        let status = unsafe { CJ_TLS_SetClientVerifyMode(context, required, verify) }

        if (status == 0) {
            throw TlsException("Failed to configure TlsClientIdentificationMode.")
        }
    }

    /**
     * Installs `customVerifyCallback` as the native verification hook of this context.
     *
     * The `callback` argument itself is not referenced here; the installed hook is always
     * the `customVerifyCallback` function. Any message the native call leaves in the
     * `DynMsg` is handed to `checkDynMsg`.
     */
    private func setContextCustomVerify(callback: CertificateVerifyCallbackFunction): Unit { // cjlint-ignore !G.FUN.02
        withContext<Unit> {
            ctxPtr, _ =>
                var nativeMsg = DynMsg()
                unsafe { CJ_TLS_DYN_SetCustomVerifyMode(ctxPtr, customVerifyCallback, inout nativeMsg) }
                checkDynMsg(nativeMsg)
        }
    }
}

// Foreign declarations of the native `CJ_TLS_DYN_*` entry points used in this file.
// Every function takes a trailing `CPointer<DynMsg>`; callers in this file pass a fresh
// `DynMsg` and hand it to `checkDynMsg` after the call. The callers in this file treat an
// `Int32` result of 0 as failure.
foreign {
    // Adds the CA certificate bytes `ca` (`length` bytes) to `context`.
    func CJ_TLS_DYN_Add_CA(
        context: CPointer<Ctx>,
        ca: CPointer<Unit>,
        length: UIntNative,
        exception: CPointer<ExceptionData>,
        dynMsg: CPointer<DynMsg>
    ): Int32

    // Adds the PEM certificate bytes `pem` (`length` bytes) to `context`; used for every
    // certificate of a chain after the first one.
    func CJ_TLS_DYN_Add_Cert(
        context: CPointer<Ctx>,
        pem: CPointer<Unit>,
        length: UIntNative,
        exception: CPointer<ExceptionData>,
        dynMsg: CPointer<DynMsg>
    ): Int32

    // Installs the PEM certificate bytes `pem` (`length` bytes) on `context`; used for the
    // first certificate of a chain.
    func CJ_TLS_DYN_Use_Cert(
        context: CPointer<Ctx>,
        pem: CPointer<Unit>,
        length: UIntNative,
        exception: CPointer<ExceptionData>,
        dynMsg: CPointer<DynMsg>
    ): Int32

    // Sets the private key of `context` from the `length` key bytes at `key`.
    func CJ_TLS_DYN_SetPrivateKey(
        context: CPointer<Ctx>,
        key: CPointer<Byte>,
        length: UIntNative,
        exception: CPointer<ExceptionData>,
        dynMsg: CPointer<DynMsg>
    ): Int32

    // Sets the keyless private key on `context` (no key material is passed).
    func CJ_TLS_DYN_SetKeylessPrivateKey(
        context: CPointer<Ctx>,
        exception: CPointer<ExceptionData>,
        dynMsg: CPointer<DynMsg>
    ): Int32

    // Sets the DH parameters of `context`; callers pass a default-constructed pointer and
    // length 0 when no parameters are supplied.
    func CJ_TLS_DYN_SetDHParam(
        context: CPointer<Ctx>,
        key: CPointer<Unit>,
        length: UIntNative,
        exception: CPointer<ExceptionData>,
        dynMsgPtr: CPointer<DynMsg>
    ): Int32

    // Switches `context` to trust-all verification.
    func CJ_TLS_DYN_SetTrustAll(context: CPointer<Ctx>, dynMsg: CPointer<DynMsg>): Int32

    // Sets the client verification flags (`required`, `verify`) on `context`.
    func CJ_TLS_DYN_SetClientVerifyMode(context: CPointer<Ctx>, required: Int32, verify: Int32, dynMsg: CPointer<DynMsg>): Int32

    // Sets the security level of `context`.
    func CJ_TLS_DYN_SetSecurityLevel(context: CPointer<Ctx>, level: Int32, dynMsg: CPointer<DynMsg>): Int32

    // Installs `verifyCallback` as the certificate verification hook of `context`.
    func CJ_TLS_DYN_SetCustomVerifyMode(
        context: CPointer<Ctx>,
        verifyCallback: CFunc<(context: CPointer<Ctx>, chain: CPointer<CertChainItem>, count: Int32) -> Int32>,
        dynMsg: CPointer<DynMsg>): Int32
}

/**
 * Calls the foreign `CJ_TLS_DYN_Add_CA` with a fresh `DynMsg`, passes that message to
 * `checkDynMsg`, and returns the native result.
 */
func CJ_TLS_Add_CA(
    context: CPointer<Ctx>,
    ca: CPointer<Unit>,
    length: UIntNative,
    exception: CPointer<ExceptionData>
): Int32 {
    var nativeMsg = DynMsg()
    let status = unsafe { CJ_TLS_DYN_Add_CA(context, ca, length, exception, inout nativeMsg) }
    checkDynMsg(nativeMsg)
    return status
}

/**
 * Calls the foreign `CJ_TLS_DYN_Add_Cert` with a fresh `DynMsg`, passes that message to
 * `checkDynMsg`, and returns the native result.
 */
func CJ_TLS_Add_Cert(
    context: CPointer<Ctx>,
    pem: CPointer<Unit>,
    length: UIntNative,
    exception: CPointer<ExceptionData>
): Int32 {
    var nativeMsg = DynMsg()
    let status = unsafe { CJ_TLS_DYN_Add_Cert(context, pem, length, exception, inout nativeMsg) }
    checkDynMsg(nativeMsg)
    return status
}

/**
 * Calls the foreign `CJ_TLS_DYN_SetPrivateKey` with a fresh `DynMsg`, passes that message
 * to `checkDynMsg`, and returns the native result.
 */
func CJ_TLS_SetPrivateKey(
    context: CPointer<Ctx>,
    key: CPointer<Byte>,
    length: UIntNative,
    exception: CPointer<ExceptionData>
): Int32 {
    var nativeMsg = DynMsg()
    let status = unsafe { CJ_TLS_DYN_SetPrivateKey(context, key, length, exception, inout nativeMsg) }
    checkDynMsg(nativeMsg)
    return status
}

/**
 * Calls the foreign `CJ_TLS_DYN_SetKeylessPrivateKey` with a fresh `DynMsg`, passes that
 * message to `checkDynMsg`, and returns the native result.
 */
func CJ_TLS_SetKeylessPrivateKey(
    context: CPointer<Ctx>,
    exception: CPointer<ExceptionData>
): Int32 {
    var nativeMsg = DynMsg()
    let status = unsafe { CJ_TLS_DYN_SetKeylessPrivateKey(context, exception, inout nativeMsg) }
    checkDynMsg(nativeMsg)
    return status
}

/**
 * Calls the foreign `CJ_TLS_DYN_SetDHParam` with a fresh `DynMsg`, passes that message to
 * `checkDynMsg`, and returns the native result.
 */
func CJ_TLS_SetDHParam(
    context: CPointer<Ctx>,
    key: CPointer<Unit>,
    length: UIntNative,
    exception: CPointer<ExceptionData>
): Int32 {
    var nativeMsg = DynMsg()
    let status = unsafe { CJ_TLS_DYN_SetDHParam(context, key, length, exception, inout nativeMsg) }
    checkDynMsg(nativeMsg)
    return status
}

/**
 * Calls the foreign `CJ_TLS_DYN_SetTrustAll` with a fresh `DynMsg`, passes that message to
 * `checkDynMsg`, and returns the native result.
 */
func CJ_TLS_SetTrustAll(context: CPointer<Ctx>): Int32 {
    var nativeMsg = DynMsg()
    let status = unsafe { CJ_TLS_DYN_SetTrustAll(context, inout nativeMsg) }
    checkDynMsg(nativeMsg)
    return status
}

/**
 * Calls the foreign `CJ_TLS_DYN_SetClientVerifyMode` with a fresh `DynMsg`, passes that
 * message to `checkDynMsg`, and returns the native result.
 */
func CJ_TLS_SetClientVerifyMode(context: CPointer<Ctx>, required: Int32, verify: Int32): Int32 {
    var nativeMsg = DynMsg()
    let status = unsafe { CJ_TLS_DYN_SetClientVerifyMode(context, required, verify, inout nativeMsg) }
    checkDynMsg(nativeMsg)
    return status
}

/**
 * Calls the foreign `CJ_TLS_DYN_SetSecurityLevel` with a fresh `DynMsg`, passes that
 * message to `checkDynMsg`, and returns the native result.
 */
func CJ_TLS_SetSecurityLevel(context: CPointer<Ctx>, level: Int32): Int32 {
    var nativeMsg = DynMsg()
    let status = unsafe { CJ_TLS_DYN_SetSecurityLevel(context, level, inout nativeMsg) }
    checkDynMsg(nativeMsg)
    return status
}

/**
 * Calls the foreign `CJ_TLS_DYN_Use_Cert` with a fresh `DynMsg`, passes that message to
 * `checkDynMsg`, and returns the native result.
 */
func CJ_TLS_Use_Cert(
    context: CPointer<Ctx>,
    pem: CPointer<Unit>,
    length: UIntNative,
    exception: CPointer<ExceptionData>
): Int32 {
    var nativeMsg = DynMsg()
    let status = unsafe { CJ_TLS_DYN_Use_Cert(context, pem, length, exception, inout nativeMsg) }
    checkDynMsg(nativeMsg)
    return status
}

/**
 * C-callable verification hook installed by `setContextCustomVerify`.
 *
 * Converts the native certificate `chain` (`count` items) into managed certificates,
 * frees the native chain, and asks the verify callback of the bridge registered for
 * `context` to judge them. A null `chain` or a `count` of 0 yields an empty list, and the
 * native chain is not freed in that case.
 *
 * Returns 1 when the callback accepts the certificates and 0 when it rejects them.
 * NOTE(review): also returns 1 when no bridge or no verify callback is found for
 * `context`, i.e. verification passes in that situation — confirm this is intended.
 */
@C
func customVerifyCallback(context: CPointer<Ctx>, chain: CPointer<CertChainItem>, count: Int32): Int32 {
    let peerCertificates: Array<Certificate>
    if (!chain.isNull() && count != 0) {
        try {
            let converted = unsafe { TlsRawSocket.convertNativeChain(chain, Int64(count)) } ?? []
            // Re-map so the element type becomes the general `Certificate`.
            peerCertificates = converted.map({item => item})
        } finally {
            // The native chain is released even when conversion throws.
            unsafe { CJ_TLS_DYN_CertChainFree(chain, Int64(count), CPointer()) }
        }
    } else {
        peerCertificates = []
    }

    if (let Some(bridge) <- Bridge.findByContext(context)) {
        if (let Some(verify) <- bridge.certificateVerifyCallback) {
            if (verify(peerCertificates)) {
                return 1
            }
            return 0
        }
    }
    return 1
}