/*
* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* This source file is part of the Cangjie project, licensed under Apache-2.0
* with Runtime Library Exception.
*
* See https://cangjie-lang.cn/pages/LICENSE for license information.
*/
package stdx.net.tls
import stdx.net.tls.common.*
import stdx.crypto.x509.*
import stdx.crypto.common.*
// Handshake role plus its configuration, used to set up a TLS handshake.
// Each constructor pairs a configuration object with an optional session value
// (presumably a previously established session to resume; None for a fresh handshake -- confirm at call sites).
enum HandshakeConfig {
// Client side: client configuration and optional client session.
| Client(TlsClientConfig, ?TlsClientSession) // cjlint-ignore !G.ENU.01
// Server side: server configuration and optional server session.
| Server(TlsServerConfig, ?TlsServerSession) // cjlint-ignore !G.ENU.01
// Server side using a keyless configuration (private-key operations presumably delegated to callbacks).
| KeylessServer(KeylessTlsServerConfig, ?TlsServerSession)
}
/**
 * Configuration helpers for the native TLS context.
 *
 * Each public-to-package entry point obtains the native context handle and an
 * exception buffer through `withContext` (defined elsewhere) and delegates to a
 * private implementation that talks to the native layer.
 */
extend TlsContext {
    /**
     * Installs a certificate chain for keyless operation.
     *
     * The certificates are loaded into the native context, a keyless private key
     * is set, and the sign/decrypt callbacks are registered so that private-key
     * operations are answered through `TlsSocket.keylessCallback`.
     *
     * @param chain certificates to install. The first is loaded with
     *              `CJ_TLS_Use_Cert` and the rest with `CJ_TLS_Add_Cert`
     *              (presumably leaf first, then intermediates). Must not be empty.
     */
    func setCertificateChainAndCallback(chain: Array<X509Certificate>): Unit {
        withContext<Unit> {
            nativeContext, exception => setCertificateChainAndCallback(nativeContext, chain, exception)
        }
    }

    /**
     * Installs a certificate chain together with its private key and checks that they match.
     *
     * @param chain certificates to install; must contain at least one certificate.
     * @param key the private key matching the first certificate of the chain.
     * @throws IllegalArgumentException if `chain` is empty.
     * @throws TlsException if the native layer rejects the key/certificates or they do not match.
     */
    func setCertificateChainAndPrivateKey(chain: Array<X509Certificate>, key: PrivateKey): Unit {
        withContext<Unit> {
            nativeContext, exception => setCertificateChainAndPrivateKey(nativeContext, chain, key, exception)
        }
    }

    /**
     * Sets the Diffie-Hellman parameters of the context.
     *
     * @param key the DH parameters, or None to pass an empty parameter buffer to the native layer.
     */
    func setDHParam(key: ?DHParameters): Unit {
        withContext<Unit> {
            nativeContext, exception => setDHParam(nativeContext, key, exception)
        }
    }

    /**
     * Applies the peer-certificate verification mode to the context.
     *
     * @throws TlsException if the mode is unsupported or the native configuration fails.
     */
    func configureVerifyMode(verifyMode: CertificateVerifyMode): Unit {
        withContext<Unit> {
            nativeContext, exception => configureVerifyMode(nativeContext, verifyMode, exception)
        }
    }

    /**
     * Applies the client-certificate requirement (server side) to the context.
     *
     * @throws TlsException if the native configuration fails.
     */
    func configureClientIdentification(verifyMode: TlsClientIdentificationMode): Unit {
        withContext<Unit> {
            nativeContext, _ => configureClientVerifyMode(nativeContext, verifyMode)
        }
    }

    /**
     * Sets the native security level of the context.
     *
     * @throws TlsException if the native call reports failure (returns 0).
     */
    func configureSecurityLevel(level: Int32): Unit {
        withContext<Unit> {
            nativeContext, _ =>
            let ret = unsafe { CJ_TLS_SetSecurityLevel(nativeContext, level) }
            if (ret == 0) {
                throw TlsException("Failed to set tls socket security level.")
            }
        }
    }

    /**
     * Builds the native keyless *sign* callback.
     *
     * The returned lambda looks up the user sign callback registered under `key`,
     * copies `digestLen` bytes from `digestPtr`, signs them, and returns a freshly
     * `malloc`-ed buffer holding the signature, with its length stored in `written`.
     * Ownership of the returned buffer passes to the caller of the callback.
     */
    private func signCallbackWrapper(): CKeylessSignCallback {
        return {
            key: CString, alg: CString, digestPtr: CPointer<Byte>, digestLen: Int64, written: CPointer<Int64> =>
            let (signCb, _) = TlsSocket.keylessCallback.get(key.toString()) ?? throw TlsException("Missing keyless sign callback.")
            let digest = Array<Byte>(digestLen, repeat: 0)
            for (i in 0..digestLen) {
                unsafe { digest[i] = digestPtr.read(i) }
            }
            let signature: Array<Byte> = signCb(digest)
            let sigPtr = unsafe { LibC.malloc<Byte>(count: signature.size) }
            // Never write through a null pointer when the allocation failed.
            if (sigPtr.isNull() && signature.size > 0) {
                throw TlsException("Failed to allocate memory for keyless signature.")
            }
            unsafe {
                for (i in 0..signature.size) {
                    sigPtr.write(i, signature[i])
                }
                written.write(signature.size)
            }
            return sigPtr
        }
    }

    /**
     * Builds the native keyless *decrypt* callback.
     *
     * The returned lambda looks up the user decrypt callback registered under `key`
     * (which must be present), copies `ciphertextLen` bytes from `ciphertextPtr`,
     * decrypts them, and returns a freshly `malloc`-ed buffer holding the plaintext,
     * with its length stored in `written`. The managed plaintext copy is zeroed on
     * every exit path.
     */
    private func decryptCallbackWrapper(): CKeylessDecryptCallback {
        return {
            key: CString, ciphertextPtr: CPointer<Byte>, ciphertextLen: Int64, written: CPointer<Int64> =>
            let (_, decryptCbOpt) = TlsSocket.keylessCallback.get(key.toString()) ?? throw TlsException("Missing keyless decrypt callback.")
            let decryptCb = decryptCbOpt ?? throw TlsException("Decrypt callback is not provided.")
            let ciphertext = Array<Byte>(ciphertextLen, repeat: 0)
            for (i in 0..ciphertextLen) {
                unsafe { ciphertext[i] = ciphertextPtr.read(i) }
            }
            let plain: Array<Byte> = decryptCb(ciphertext)
            try {
                let plainPtr = unsafe { LibC.malloc<Byte>(count: plain.size) }
                // Never write through a null pointer when the allocation failed.
                if (plainPtr.isNull() && plain.size > 0) {
                    throw TlsException("Failed to allocate memory for keyless decryption result.")
                }
                unsafe {
                    for (i in 0..plain.size) {
                        plainPtr.write(i, plain[i])
                    }
                    written.write(plain.size)
                }
                return plainPtr
            } finally {
                // Wipe the managed plaintext copy even if allocation failed.
                plain.fill(0)
            }
        }
    }

    private func setCertificateChainAndCallback(
        context: CPointer<Ctx>,
        chain: Array<X509Certificate>, // should not be empty, no need to check
        exception: CPointer<ExceptionData>
    ): Unit {
        // The first certificate is "used"; every following one is "added".
        var firstAdded = false
        for (cert in chain) {
            let entry = PemEntry(PemEntry.LABEL_CERTIFICATE, cert.encodeToDer())
            useAddCertImpl(context, entry, add: firstAdded, exception: exception)
            firstAdded = true
        }
        setKeylessPrivateKeyImpl(context, exception)
        setKeylessCallbacks(signCallbackWrapper(), decryptCallbackWrapper(), exception)
    }

    private func setCertificateChainAndPrivateKey(
        context: CPointer<Ctx>,
        chain: Array<X509Certificate>,
        key: PrivateKey,
        exception: CPointer<ExceptionData>
    ): Unit {
        if (chain.isEmpty()) {
            throw IllegalArgumentException("Certificate chain is empty: at least one certificate required.")
        }
        setPrivateKeyImpl(context, key.encodeToDer(), exception)
        // The first certificate is "used"; every following one is "added".
        var ifNotFirst = false
        for (cert in chain) {
            let entry = PemEntry(PemEntry.LABEL_CERTIFICATE, cert.encodeToDer())
            useAddCertImpl(context, entry, add: ifNotFirst, exception: exception)
            ifNotFirst = true
        }
        unsafe {
            let ret = CJ_TLS_CheckPrivateKey(context)
            if (ret <= 0) {
                throw TlsException("The Certificate chain doesn't match the private key.")
            }
        }
    }

    /** Sets a keyless private key in the native context; throws from `exception` on failure. */
    private func setKeylessPrivateKeyImpl(
        context: CPointer<Ctx>,
        exception: CPointer<ExceptionData>
    ) {
        let setKeyResult = unsafe { CJ_TLS_SetKeylessPrivateKey(context, exception) }
        if (setKeyResult == 0) {
            unsafe { exception.read().throwException(fallback: "Failed to set keyless private key.") }
        }
    }

    /** Passes the DER-encoded private key to the native context; throws from `exception` on failure. */
    private func setPrivateKeyImpl(
        context: CPointer<Ctx>,
        key: DerBlob,
        exception: CPointer<ExceptionData>
    ) {
        let rawContent = key.body
        let rawContentSize = UIntNative(rawContent.size)
        let setKeyResult = unsafe {
            let pinned = acquireArrayRawData(rawContent)
            try {
                CJ_TLS_SetPrivateKey(
                    context,
                    pinned.pointer,
                    rawContentSize,
                    exception
                )
            } finally {
                releaseArrayRawData(pinned)
            }
        }
        if (setKeyResult == 0) {
            unsafe { exception.read().throwException(fallback: "Failed to set private key.") }
        }
    }

    /**
     * PEM-encodes `entry` and hands it to `CJ_TLS_Add_Cert` (when `add` is true)
     * or `CJ_TLS_Use_Cert` (when `add` is false); throws from `exception` on failure.
     */
    private func useAddCertImpl(
        context: CPointer<Ctx>,
        entry: PemEntry,
        add!: Bool,
        exception!: CPointer<ExceptionData>
    ) {
        let useCertResult = unsafe {
            withPemImpl<Int32>(entry) {
                pointer, size => match (add) {
                    case true => CJ_TLS_Add_Cert(context, pointer, size, exception)
                    case false => CJ_TLS_Use_Cert(context, pointer, size, exception)
                }
            }
        }
        if (useCertResult == 0) {
            unsafe { exception.read().throwException(fallback: "Failed to add certificate.") }
        }
    }

    /**
     * Passes PEM-encoded DH parameters to the native context, or a null pointer with
     * length 0 when `key` is None; throws from `exception` on failure.
     */
    private func setDHParam(
        context: CPointer<Ctx>,
        key: ?DHParameters,
        exception: CPointer<ExceptionData>
    ): Unit {
        let result = if (let Some(v) <- key) {
            let entry = PemEntry(PemEntry.LABEL_DH_PARAMETERS, v.encodeToDer())
            unsafe {
                withPemImpl<Int32>(entry) {
                    pointer, size => CJ_TLS_SetDHParam(context, pointer, size, exception)
                }
            }
        } else {
            unsafe { CJ_TLS_SetDHParam(context, CPointer<Unit>(), 0, exception) }
        }
        if (result == 0) {
            unsafe { exception.read().throwException(fallback: "Failed to set DH Parameters.") }
        }
    }

    /**
     * Encodes `entry` as PEM text, pins the bytes, and calls `block` with a raw
     * pointer to them and their length. The pin is released on every exit path.
     */
    private func withPemImpl<R>(entry: PemEntry, block: (CPointer<Unit>, UIntNative) -> R): R {
        let pemText = Pem([entry]).encode().toArray()
        let pemSize = UIntNative(pemText.size)
        unsafe {
            let pinned = acquireArrayRawData(pemText)
            try {
                block(CPointer(pinned.pointer), pemSize)
            } finally {
                releaseArrayRawData(pinned)
            }
        }
    }

    /**
     * Native part of `configureVerifyMode`.
     *
     * - Default: trust the system root certificates.
     * - TrustAll: disable peer verification.
     * - CustomCA: trust the given X509 certificates and the system roots.
     * - CustomVerify: install `customVerifyCallback`.
     *
     * @throws TlsException for non-X509 CA entries or an unsupported mode.
     */
    private func configureVerifyMode(
        context: CPointer<Ctx>,
        verifyMode: CertificateVerifyMode,
        exception: CPointer<ExceptionData>
    ): Unit {
        match (verifyMode) {
            case Default => configureCustomCA(context, X509Certificate.systemRootCerts(), exception)
            case TrustAll => setContextTrustAll(context)
            case CustomCA(items) =>
                let ca = items.map({c => c as X509Certificate ??
                    throw TlsException("Only certificates of type `X509Certificate` are allowed.")})
                configureCustomCA(context, ca, exception)
                configureCustomCA(context, X509Certificate.systemRootCerts(), exception)
            case CustomVerify(callback) => setContextCustomVerify(callback)
            // Previously the exception was constructed but never thrown, so an
            // unsupported mode was silently accepted.
            case _ => throw TlsException("Unsupported verify mode.")
        }
    }

    /** Adds each certificate as a trusted CA, PEM-encoded; throws `TlsException` if the native call fails. */
    private func configureCustomCA(
        context: CPointer<Ctx>,
        items: Array<X509Certificate>,
        exception: CPointer<ExceptionData>
    ) {
        for (ca in items) {
            let pem = ca.encodeToPem().encode().toArray()
            let bytes = unsafe { acquireArrayRawData(pem) }
            try {
                let addResult = unsafe {
                    CJ_TLS_Add_CA(context, CPointer(bytes.pointer), UIntNative(pem.size), exception)
                }
                if (addResult == 0) {
                    throw TlsException("Failed to assign CA certificate.")
                }
            } finally {
                unsafe { releaseArrayRawData(bytes) }
            }
        }
    }

    /** Switches the native context to trust-all verification; throws `TlsException` on failure. */
    private func setContextTrustAll(context: CPointer<Ctx>) {
        let ret = unsafe { CJ_TLS_SetTrustAll(context) }
        if (ret == 0) {
            throw TlsException("Failed to set tls socket verify mode to trust all.")
        }
    }

    /** Maps the client identification mode to native (required, verify) flags; throws `TlsException` on failure. */
    private func configureClientVerifyMode(context: CPointer<Ctx>, verifyMode: TlsClientIdentificationMode) {
        let result = unsafe {
            match (verifyMode) {
                case Disabled => CJ_TLS_SetClientVerifyMode(context, 0, 0) // not required, not verify
                case Optional => CJ_TLS_SetClientVerifyMode(context, 0, 1) // not required, do verify
                case Required => CJ_TLS_SetClientVerifyMode(context, 1, 1) // required, do verify
            }
        }
        if (result == 0) {
            throw TlsException("Failed to configure TlsClientIdentificationMode.")
        }
    }

    /**
     * Installs `customVerifyCallback` as the native verification hook.
     *
     * `callback` itself is not used here; `customVerifyCallback` looks the user
     * callback up through `Bridge.findByContext` at verification time.
     * NOTE(review): this re-enters `withContext` although it is reached from inside
     * `configureVerifyMode`'s `withContext`; confirm `withContext` is reentrant.
     */
    private func setContextCustomVerify(callback: CertificateVerifyCallbackFunction): Unit { // cjlint-ignore !G.FUN.02
        withContext<Unit> {
            nativeContext, _ =>
            var dynMsg = DynMsg()
            unsafe { CJ_TLS_DYN_SetCustomVerifyMode(nativeContext, customVerifyCallback, inout dynMsg) }
            checkDynMsg(dynMsg)
        }
    }
}
// Raw native entry points used by the safe wrappers in this file.
// Every function takes a trailing `CPointer<DynMsg>` out-parameter. The callers in this
// file pass a local `DynMsg` via `inout` and hand it to `checkDynMsg` afterwards
// (presumably to report failures of the dynamically loaded native library -- confirm).
// The Int32 results are treated by callers as 0 == failure.
foreign {
// Adds a PEM-encoded CA certificate (`ca`, `length` bytes) to the trust store.
func CJ_TLS_DYN_Add_CA(
context: CPointer<Ctx>,
ca: CPointer<Unit>,
length: UIntNative,
exception: CPointer<ExceptionData>,
dynMsg: CPointer<DynMsg>
): Int32
// Adds a further PEM-encoded certificate to the context's chain (used for all but the first certificate).
func CJ_TLS_DYN_Add_Cert(
context: CPointer<Ctx>,
pem: CPointer<Unit>,
length: UIntNative,
exception: CPointer<ExceptionData>,
dynMsg: CPointer<DynMsg>
): Int32
// Installs the first PEM-encoded certificate of the context's chain.
func CJ_TLS_DYN_Use_Cert(
context: CPointer<Ctx>,
pem: CPointer<Unit>,
length: UIntNative,
exception: CPointer<ExceptionData>,
dynMsg: CPointer<DynMsg>
): Int32
// Sets the private key from a DER buffer of `length` bytes.
func CJ_TLS_DYN_SetPrivateKey(
context: CPointer<Ctx>,
key: CPointer<Byte>,
length: UIntNative,
exception: CPointer<ExceptionData>,
dynMsg: CPointer<DynMsg>
): Int32
// Sets a keyless private key, whose operations are served by the registered keyless callbacks.
func CJ_TLS_DYN_SetKeylessPrivateKey(
context: CPointer<Ctx>,
exception: CPointer<ExceptionData>,
dynMsg: CPointer<DynMsg>
): Int32
// Sets PEM-encoded DH parameters; callers pass a null pointer with length 0 for "no parameters".
func CJ_TLS_DYN_SetDHParam(
context: CPointer<Ctx>,
key: CPointer<Unit>,
length: UIntNative,
exception: CPointer<ExceptionData>,
dynMsgPtr: CPointer<DynMsg>
): Int32
// Switches peer verification to trust-all.
func CJ_TLS_DYN_SetTrustAll(context: CPointer<Ctx>, dynMsg: CPointer<DynMsg>): Int32
// Configures client-certificate handling with 0/1 flags for "required" and "verify".
func CJ_TLS_DYN_SetClientVerifyMode(context: CPointer<Ctx>, required: Int32, verify: Int32, dynMsg: CPointer<DynMsg>): Int32
// Sets the native security level.
func CJ_TLS_DYN_SetSecurityLevel(context: CPointer<Ctx>, level: Int32, dynMsg: CPointer<DynMsg>): Int32
// Installs a C verification callback receiving the peer chain (`chain`, `count` items); it returns 1 to accept and 0 to reject.
func CJ_TLS_DYN_SetCustomVerifyMode(
context: CPointer<Ctx>,
verifyCallback: CFunc<(context: CPointer<Ctx>, chain: CPointer<CertChainItem>, count: Int32) -> Int32>,
dynMsg: CPointer<DynMsg>): Int32
}
/**
 * Calls `CJ_TLS_DYN_Add_CA` with a fresh `DynMsg`, runs `checkDynMsg` on it,
 * and returns the native status code unchanged.
 */
func CJ_TLS_Add_CA(
    context: CPointer<Ctx>,
    ca: CPointer<Unit>,
    length: UIntNative,
    exception: CPointer<ExceptionData>
): Int32 {
    var msg = DynMsg()
    let status = unsafe { CJ_TLS_DYN_Add_CA(context, ca, length, exception, inout msg) }
    checkDynMsg(msg)
    status
}
/**
 * Calls `CJ_TLS_DYN_Add_Cert` with a fresh `DynMsg`, runs `checkDynMsg` on it,
 * and returns the native status code unchanged.
 */
func CJ_TLS_Add_Cert(
    context: CPointer<Ctx>,
    pem: CPointer<Unit>,
    length: UIntNative,
    exception: CPointer<ExceptionData>
): Int32 {
    var msg = DynMsg()
    let status = unsafe { CJ_TLS_DYN_Add_Cert(context, pem, length, exception, inout msg) }
    checkDynMsg(msg)
    status
}
/**
 * Calls `CJ_TLS_DYN_SetPrivateKey` with a fresh `DynMsg`, runs `checkDynMsg` on it,
 * and returns the native status code unchanged.
 */
func CJ_TLS_SetPrivateKey(
    context: CPointer<Ctx>,
    key: CPointer<Byte>,
    length: UIntNative,
    exception: CPointer<ExceptionData>
): Int32 {
    var msg = DynMsg()
    let status = unsafe { CJ_TLS_DYN_SetPrivateKey(context, key, length, exception, inout msg) }
    checkDynMsg(msg)
    status
}
/**
 * Calls `CJ_TLS_DYN_SetKeylessPrivateKey` with a fresh `DynMsg`, runs `checkDynMsg`
 * on it, and returns the native status code unchanged.
 */
func CJ_TLS_SetKeylessPrivateKey(
    context: CPointer<Ctx>,
    exception: CPointer<ExceptionData>
): Int32 {
    var msg = DynMsg()
    let status = unsafe { CJ_TLS_DYN_SetKeylessPrivateKey(context, exception, inout msg) }
    checkDynMsg(msg)
    status
}
/**
 * Calls `CJ_TLS_DYN_SetDHParam` with a fresh `DynMsg`, runs `checkDynMsg` on it,
 * and returns the native status code unchanged.
 */
func CJ_TLS_SetDHParam(
    context: CPointer<Ctx>,
    key: CPointer<Unit>,
    length: UIntNative,
    exception: CPointer<ExceptionData>
): Int32 {
    var msg = DynMsg()
    let status = unsafe { CJ_TLS_DYN_SetDHParam(context, key, length, exception, inout msg) }
    checkDynMsg(msg)
    status
}
/**
 * Calls `CJ_TLS_DYN_SetTrustAll` with a fresh `DynMsg`, runs `checkDynMsg` on it,
 * and returns the native status code unchanged.
 */
func CJ_TLS_SetTrustAll(context: CPointer<Ctx>): Int32 {
    var msg = DynMsg()
    let status = unsafe { CJ_TLS_DYN_SetTrustAll(context, inout msg) }
    checkDynMsg(msg)
    status
}
/**
 * Calls `CJ_TLS_DYN_SetClientVerifyMode` with a fresh `DynMsg`, runs `checkDynMsg`
 * on it, and returns the native status code unchanged.
 */
func CJ_TLS_SetClientVerifyMode(context: CPointer<Ctx>, required: Int32, verify: Int32): Int32 {
    var msg = DynMsg()
    let status = unsafe { CJ_TLS_DYN_SetClientVerifyMode(context, required, verify, inout msg) }
    checkDynMsg(msg)
    status
}
/**
 * Calls `CJ_TLS_DYN_SetSecurityLevel` with a fresh `DynMsg`, runs `checkDynMsg`
 * on it, and returns the native status code unchanged.
 */
func CJ_TLS_SetSecurityLevel(context: CPointer<Ctx>, level: Int32): Int32 {
    var msg = DynMsg()
    let status = unsafe { CJ_TLS_DYN_SetSecurityLevel(context, level, inout msg) }
    checkDynMsg(msg)
    status
}
/**
 * Calls `CJ_TLS_DYN_Use_Cert` with a fresh `DynMsg`, runs `checkDynMsg` on it,
 * and returns the native status code unchanged.
 */
func CJ_TLS_Use_Cert(
    context: CPointer<Ctx>,
    pem: CPointer<Unit>,
    length: UIntNative,
    exception: CPointer<ExceptionData>
): Int32 {
    var msg = DynMsg()
    let status = unsafe { CJ_TLS_DYN_Use_Cert(context, pem, length, exception, inout msg) }
    checkDynMsg(msg)
    status
}
/**
 * Native certificate-verification hook installed by `CJ_TLS_DYN_SetCustomVerifyMode`.
 *
 * When `chain` is non-null and `count` is non-zero, the native chain is converted
 * into Cangjie certificates and then freed. The user `certificateVerifyCallback` of
 * the bridge registered for `context` then decides the result.
 *
 * Returns 1 to accept and 0 to reject. If no bridge or no callback is found, the chain is accepted.
 * NOTE(review): that fallback is fail-open -- confirm it cannot be reached while a
 * custom verify mode is active.
 * NOTE(review): exceptions from the conversion or the user callback are not caught and
 * would leave this @C function into native code -- confirm the native caller tolerates that.
 */
@C
func customVerifyCallback(context: CPointer<Ctx>, chain: CPointer<CertChainItem>, count: Int32): Int32 {
    var peerChain: Array<Certificate> = []
    let hasNativeChain = !chain.isNull() && count != 0
    if (hasNativeChain) {
        try {
            let converted = unsafe { TlsRawSocket.convertNativeChain(chain, Int64(count)) } ?? []
            peerChain = converted.map({cert => cert})
        } finally {
            // Always release the native chain, even if conversion throws.
            unsafe { CJ_TLS_DYN_CertChainFree(chain, Int64(count), CPointer()) }
        }
    }
    match (Bridge.findByContext(context)) {
        case Some(bridge) => if (let Some(verify) <- bridge.certificateVerifyCallback) {
            let accepted = verify(peerChain)
            return if (accepted) { 1 } else { 0 }
        }
        case None => ()
    }
    return 1
}