/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
 * This source file is part of the Cangjie project, licensed under Apache-2.0
 * with Runtime Library Exception.
 *
 * See https://cangjie-lang.cn/pages/LICENSE for license information.
 */

package stdx.net.tls

import stdx.net.tls.common.*

/*
 * A TLS cipher suite, identified by its name.
 *
 * Two values compare equal when their names are equal, and toString() yields the name.
 */
public struct CipherSuite <: ToString & Equatable<CipherSuite> {
    /*
     * @param name the cipher suite name.
     * @throws IllegalArgumentException if `name` is empty.
     */
    CipherSuite(let name: String) {
        if (name.isEmpty()) {
            throw IllegalArgumentException("Cipher suite name cannot be empty.")
        }
    }

    /* Returns the cipher suite name. */
    public func toString(): String {
        name
    }

    /* Equality is defined by the name only. */
    public operator func ==(that: CipherSuite): Bool {
        this.name == that.name
    }

    /* Negation of `==`. */
    public operator func !=(that: CipherSuite): Bool {
        !(this == that)
    }

    /*
     * Cipher suites reported by the native TLS layer.
     *
     * The native call returns an array of pointers that is read here until the first
     * null entry. Every entry, and the array itself, is released before this getter
     * returns, including when an exception interrupts the iteration.
     *
     * @throws TlsException if the native layer returns a null array.
     * @throws IllegalArgumentException if the native layer reports an empty name.
     */
    public static prop allSupported: Array<CipherSuite> {
        get() {
            let suites = unsafe { CJ_TLS_GetAllCipherSuites() }
            if (suites.isNull()) {
                throw TlsException("Unable to get supported cipher suites.")
            }

            let result = ArrayBuilder<CipherSuite>()

            // Index of the next entry that has not been handed to freeTlsCipherSuite yet.
            var idx: Int64 = 0
            try {
                while (true) {
                    let suite: CPointer<TlsCipherSuite> = unsafe { suites.read(idx) }
                    if (suite.isNull()) {
                        break
                    }
                    // Advance before processing: the entry is released by the inner
                    // finally, so cleanup after a failure must start at the next one.
                    idx++

                    try {
                        let cipherSuiteName = unsafe { suite.read().name.toString() }
                        result.append(CipherSuite(cipherSuiteName))
                    } finally {
                        freeTlsCipherSuite(suite)
                    }
                }
            } finally {
                try {
                    // After a normal exit idx sits on the null terminator and this loop
                    // does nothing; after a failure it releases the unvisited entries.
                    while (true) {
                        let leftover: CPointer<TlsCipherSuite> = unsafe { suites.read(idx) }
                        if (leftover.isNull()) {
                            break
                        }
                        idx++
                        freeTlsCipherSuite(leftover)
                    }
                } finally {
                    unsafe { LibC.free(suites) }
                }
            }

            return result.toArray()
        }
    }
}

// Integer identifiers for TLS protocol versions.
// NOTE(review): presumably these mirror an enum in the native TLS layer — confirm against its header.
const CJTLS_PROTO_VERSION_1_2: Int32 = 0 /* TLS1.2 */
const CJTLS_PROTO_VERSION_1_3: Int32 = 1 /* TLS1.3 */
// NOTE(review): looks like an end-of-range marker rather than a real protocol version — confirm.
const CJTLS_PROTO_VERSION_BUTT: Int32 = 2

/*
 * C-compatible record describing one cipher suite as exchanged with the native TLS layer.
 * Instances are obtained through CJ_TLS_GetCipherSuite / CJ_TLS_GetAllCipherSuites and
 * released with freeTlsCipherSuite.
 */
@C
struct TlsCipherSuite {
    // Cipher suite name as a C string; defaults to a CString wrapping a default-constructed CPointer.
    let name: CString = CString(CPointer())
}

/*
 * Releases a native TlsCipherSuite: its name buffer via CRYPTO_free, then the struct
 * itself via LibC.free.
 *
 * The struct is freed even when releasing the name throws, so a failure in the first
 * step does not leak the second allocation. A null `suite` is ignored.
 *
 * @param suite pointer to the native cipher suite record to release.
 */
func freeTlsCipherSuite(suite: CPointer<TlsCipherSuite>): Unit {
    if (suite.isNull()) {
        // Nothing to release; reading through a null pointer must be avoided.
        return
    }
    try {
        unsafe { CRYPTO_free(suite.read().name.getChars()) }
    } finally {
        unsafe { LibC.free(suite) }
    }
}

/*
 * Maps a cipher suite name to the form returned by OPENSSL_cipher_name.
 *
 * Names that do not start with "TLS_" are returned unchanged. For the others the
 * native lookup is consulted; when it answers "(NONE)" an empty string is returned.
 *
 * @param stdName cipher suite name to convert.
 * @return the converted name, `stdName` itself, or "" as described above.
 */
func cipherToOpenSslFormat(stdName: String): String {
    if (!stdName.startsWith("TLS_")) {
        return stdName
    }

    var openSslName: String = ""
    unsafe {
        // The temporary C copy of stdName is released when the try-with-resources block exits.
        try (nativeName = LibC.mallocCString(stdName).asResource()) {
            // The returned CString is not freed here (see the original "should not free" note).
            let converted = OPENSSL_cipher_name(nativeName.value).toString()
            if (converted != "(NONE)") {
                openSslName = converted
            }
        }
    }
    return openSslName
}

extend TlsRawSocket {
    /*
     * Asks the native layer for the cipher suite of this socket's stream and returns
     * it as a CipherSuite. The native record is released before returning, whether or
     * not building the result succeeds.
     *
     * @throws TlsException if the native layer returns a null record.
     */
    func getCipherSuite(): CipherSuite {
        otherNonIO<CipherSuite> {
            nativeStream, _ =>
            let rawSuite: CPointer<TlsCipherSuite> = unsafe { CJ_TLS_GetCipherSuite(nativeStream) }
            if (rawSuite.isNull()) {
                throw TlsException("Unable to get cipherSuite.")
            }

            try {
                // Copy the name into a managed String before the native record is freed.
                let suiteName = unsafe { rawSuite.read().name.toString() }
                return CipherSuite(suiteName)
            } finally {
                freeTlsCipherSuite(rawSuite)
            }
        }
    }
}

/*
 * Applies the given cipher suites to `ctx` through CJ_TLS_SetCipherList.
 *
 * Each name is first passed through cipherToOpenSslFormat, then the names are joined
 * with ":". An empty `cipherSuites` array leaves `ctx` untouched.
 *
 * @throws TlsException if setting the TLS 1.2 cipher suites fails, or if the
 * configured TLS versions do not include 1.2.
 */
func setCipherSuitesV1_2(ctx: CPointer<Ctx>, cipherSuites: Array<String>): Unit {
    if (cipherSuites.isEmpty()) {
        return
    }

    let openSslNames = cipherSuites.map(cipherToOpenSslFormat)
    let cipherList = String.join(openSslNames, delimiter: ":")
    unsafe {
        // The native copy of the list is released when the try-with-resources block exits.
        try (nativeList = LibC.mallocCString(cipherList).asResource()) {
            if (CJ_TLS_SetCipherList(ctx, nativeList.value) <= 0) {
                throw TlsException("Failed to set cipher suites of TLS1.2.")
            }
        }
    }
}

/*
 * Applies the given cipher suites to `ctx` through CJ_TLS_SetCipherSuites.
 *
 * The names are joined with ":" as given, without conversion. An empty `cipherSuites`
 * array leaves `ctx` untouched.
 *
 * @throws TlsException if setting the TLS 1.3 cipher suites fails, or if the
 * configured TLS versions do not include 1.3.
 */
func setCipherSuitesV1_3(ctx: CPointer<Ctx>, cipherSuites: Array<String>): Unit {
    if (cipherSuites.isEmpty()) {
        return
    }

    let suiteList = String.join(cipherSuites, delimiter: ":")
    unsafe {
        // The native copy of the list is released when the try-with-resources block exits.
        try (nativeList = LibC.mallocCString(suiteList).asResource()) {
            if (CJ_TLS_SetCipherSuites(ctx, nativeList.value) <= 0) {
                throw TlsException("Failed to set cipher suites of TLS1.3.")
            }
        }
    }
}

/*
 * Native entry points used by this file. Every function takes a trailing DynMsg
 * pointer; the Cangjie wrappers defined below pass a fresh DynMsg and hand it to
 * checkDynMsg after the call.
 */
foreign {
    // Returns the cipher suite record for `stream`; callers in this file treat null as failure.
    func CJ_TLS_DYN_GetCipherSuite(stream: CPointer<Ssl>, dynMsgPtr: CPointer<DynMsg>): CPointer<TlsCipherSuite>

    // Returns an array of cipher suite records; this file reads it up to the first null entry.
    func CJ_TLS_DYN_GetAllCipherSuites(dynMsgPtr: CPointer<DynMsg>): CPointer<CPointer<TlsCipherSuite>>

    // Sets the ":"-separated cipher list used for TLS 1.2; callers treat a result <= 0 as failure.
    func CJ_TLS_DYN_SetCipherList(ctx: CPointer<Ctx>, str: CString, dynMsgPtr: CPointer<DynMsg>): Int32

    // Sets the ":"-separated cipher suites used for TLS 1.3; callers treat a result <= 0 as failure.
    func CJ_TLS_DYN_SetCipherSuites(ctx: CPointer<Ctx>, str: CString, dynMsgPtr: CPointer<DynMsg>): Int32

    // Releases `ptr`; used in this file for the name buffer of a TlsCipherSuite.
    func CJ_TLS_DYN_CRYPTO_free(ptr: CPointer<Byte>, dynMsgPtr: CPointer<DynMsg>): Unit

    // Looks up the name for `stdname`; this file treats the result "(NONE)" as "no match".
    func DYN_OPENSSL_cipher_name(stdname: CString, dynMsgPtr: CPointer<DynMsg>): CString
}

/*
 * Wrapper over the foreign CJ_TLS_DYN_CRYPTO_free: forwards `ptr` together with a
 * fresh DynMsg, then passes that DynMsg to checkDynMsg.
 */
func CRYPTO_free(ptr: CPointer<Byte>): Unit {
    unsafe {
        var status = DynMsg()
        CJ_TLS_DYN_CRYPTO_free(ptr, inout status)
        checkDynMsg(status)
    }
}

/*
 * Wrapper over the foreign DYN_OPENSSL_cipher_name: forwards `stdname` together with
 * a fresh DynMsg, passes that DynMsg to checkDynMsg, and returns the native result.
 */
func OPENSSL_cipher_name(stdname: CString): CString {
    unsafe {
        var status = DynMsg()
        let openSslName = DYN_OPENSSL_cipher_name(stdname, inout status)
        checkDynMsg(status)
        return openSslName
    }
}

/*
 * Wrapper over the foreign CJ_TLS_DYN_GetCipherSuite: forwards `stream` together with
 * a fresh DynMsg, passes that DynMsg to checkDynMsg, and returns the native pointer.
 */
func CJ_TLS_GetCipherSuite(stream: CPointer<Ssl>): CPointer<TlsCipherSuite> {
    unsafe {
        var status = DynMsg()
        let suitePtr = CJ_TLS_DYN_GetCipherSuite(stream, inout status)
        checkDynMsg(status)
        return suitePtr
    }
}

/*
 * Wrapper over the foreign CJ_TLS_DYN_GetAllCipherSuites: calls it with a fresh
 * DynMsg, passes that DynMsg to checkDynMsg, and returns the native array pointer.
 */
func CJ_TLS_GetAllCipherSuites(): CPointer<CPointer<TlsCipherSuite>> {
    unsafe {
        var status = DynMsg()
        let suiteArray = CJ_TLS_DYN_GetAllCipherSuites(inout status)
        checkDynMsg(status)
        return suiteArray
    }
}

/*
 * Wrapper over the foreign CJ_TLS_DYN_SetCipherList: forwards `ctx` and `str` together
 * with a fresh DynMsg, passes that DynMsg to checkDynMsg, and returns the native
 * return code.
 */
func CJ_TLS_SetCipherList(ctx: CPointer<Ctx>, str: CString): Int32 {
    unsafe {
        var status = DynMsg()
        let code = CJ_TLS_DYN_SetCipherList(ctx, str, inout status)
        checkDynMsg(status)
        return code
    }
}

/*
 * Wrapper over the foreign CJ_TLS_DYN_SetCipherSuites: forwards `ctx` and `str`
 * together with a fresh DynMsg, passes that DynMsg to checkDynMsg, and returns the
 * native return code.
 */
func CJ_TLS_SetCipherSuites(ctx: CPointer<Ctx>, str: CString): Int32 {
    unsafe {
        var status = DynMsg()
        let code = CJ_TLS_DYN_SetCipherSuites(ctx, str, inout status)
        checkDynMsg(status)
        return code
    }
}