/*
* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* This source file is part of the Cangjie project, licensed under Apache-2.0
* with Runtime Library Exception.
*
* See https://cangjie-lang.cn/pages/LICENSE for license information.
*/
package stdx.net.tls
import stdx.net.tls.common.TlsException
extend TlsRawSocket {
    /**
     * Queries the application protocol selected through ALPN on this connection.
     *
     * @return the selected protocol name, or None when the native layer reports
     *         no selection (null pointer or zero length).
     * @throws TlsException while tls socket is not connected.
     * @throws IllegalMemoryException if malloc failed.
     */
    func getAlpnSelected(): ?String {
        unsafe {
            // Two heap-allocated out-parameters for the native call:
            // a pointer to the protocol bytes, and the number of those bytes.
            let protoOut = LibC.malloc<CPointer<Byte>>(count: 1)
            if (protoOut.isNull()) {
                throw IllegalMemoryException("Failed to allocate memory.")
            }
            let lenOut = LibC.malloc<UInt32>(count: 1)
            if (lenOut.isNull()) {
                // Release the first allocation before bailing out.
                LibC.free(protoOut)
                throw IllegalMemoryException("Failed to allocate memory.")
            }
            // Start from an "empty" answer so an untouched output reads as no selection.
            protoOut.write(CPointer())
            lenOut.write(0)
            try {
                otherNonIO<Unit> {
                    stream, _ => CJ_TLS_GetAlpnSelected(stream, protoOut, lenOut)
                }
                let selected = protoOut.read(0)
                let selectedLen = Int64(lenOut.read(0))
                if (selected.isNull() || selectedLen == 0) {
                    return None
                }
                // The protocol bytes are not NUL-terminated, so CString cannot be used;
                // copy exactly `selectedLen` bytes into a managed array instead.
                let bytes = Array<Byte>(selectedLen, repeat: 0)
                copy(source: selected, destination: bytes, count: UIntNative(selectedLen))
                return String.fromUtf8(bytes)
            } finally {
                // Only the out-parameter slots are freed; the protocol bytes themselves
                // are not allocated here.
                LibC.free(CPointer<Unit>(protoOut))
                LibC.free(CPointer<Unit>(lenOut))
            }
        }
    }
}
/*
 * Configures the ALPN protocol list that a client context offers to servers.
 * An empty list leaves the context unchanged.
 *
 * @throws TlsException if a protocol name cannot be encoded, or if the native
 *         call reports failure (return value -1).
 */
func setClientAlpnProtos(clientCtx: CPointer<Ctx>, alpnList: Array<String>): Unit {
    if (alpnList.isEmpty()) {
        return
    }
    // Length-prefixed wire encoding of the protocol names.
    let wire = alpnListToByteArray(alpnList)
    unsafe {
        // Pin the managed array so its address can be handed to native code.
        let pinned = acquireArrayRawData(wire)
        try {
            let status = CJ_TLS_SetClientAlpnProtocols(clientCtx, pinned.pointer, UInt32(wire.size))
            if (status == -1) {
                throw TlsException("Failed to set client alpn protocol.")
            }
        } finally {
            releaseArrayRawData(pinned)
        }
    }
}
/*
 * Configures the ALPN protocol list that a server context accepts from clients.
 * An empty list leaves the context unchanged.
 *
 * @throws TlsException if a protocol name cannot be encoded, or if the native
 *         call returns a non-positive value.
 */
func setServerAlpnProtos(serverCtx: CPointer<Ctx>, alpnList: Array<String>): Unit {
    if (alpnList.isEmpty()) {
        return
    }
    // Length-prefixed wire encoding of the protocol names.
    let wire = alpnListToByteArray(alpnList)
    unsafe {
        // Pin the managed array so its address can be handed to native code.
        let pinned = acquireArrayRawData(wire)
        try {
            let status = CJ_TLS_SetServerAlpnProtos(serverCtx, pinned.pointer, UInt32(wire.size))
            // NOTE(review): the success convention here (> 0) differs from the client
            // path (!= -1); presumably the native wrappers differ — confirm in the C code.
            if (status <= 0) {
                throw TlsException("Failed to set server alpn protocol.")
            }
        } finally {
            releaseArrayRawData(pinned)
        }
    }
}
/*
 * Encodes ALPN protocol names into the length-prefixed wire format: each name is
 * written as a one-byte length followed by its UTF-8 bytes, for example
 * ["h2", "http/1.1"] -> [2, 'h', '2', 8, 'h', 't', 't', 'p', '/', '1', '.', '1'].
 *
 * @throws TlsException if any protocol name is empty or longer than 255 bytes.
 *         RFC 7301 forbids empty names, and a one-byte prefix cannot encode more than 255.
 */
func alpnListToByteArray(alpnList: Array<String>): Array<Byte> {
    // First pass: validate every name and compute the total encoded size.
    var bufLen = 0
    for (s in alpnList) {
        let ss = s.size
        if (ss == 0 || ss > 255) {
            throw TlsException("The length of the alpn protocol must be between 1 and 255 bytes.")
        }
        bufLen += 1 + ss
    }
    // Second pass: write each length prefix followed by the raw UTF-8 bytes of the name.
    let buf = Array<Byte>(bufLen, repeat: 0)
    var bufIndex = 0
    for (s in alpnList) {
        let view = unsafe { s.rawData() }
        let ss = view.size
        buf[bufIndex] = UInt8(ss)
        bufIndex++
        view.copyTo(buf, 0, bufIndex, ss)
        bufIndex += ss
    }
    return buf
}
/*
 * Status record passed by pointer to each CJ_TLS_DYN_* foreign call and inspected
 * with checkDynMsg right after the call returns.
 * `found` defaults to true and `funcName` defaults to a null pointer.
 * NOTE(review): presumably the native side clears `found` and points `funcName` at the
 * missing symbol's name when a dynamically resolved function is unavailable — confirm
 * against the C implementation.
 * Because this is an @C struct, field order defines the C layout and must match the
 * native definition.
 */
@C
struct DynMsg {
var found = true
var funcName = CPointer<UInt8>()
}
// Native entry points for ALPN handling. Each one takes a trailing DynMsg pointer that the
// Cangjie wrappers below inspect with checkDynMsg after the call.
foreign {
// Writes the selected protocol's byte pointer into `proto` and its length into `len`.
func CJ_TLS_DYN_GetAlpnSelected(aStream: CPointer<Ssl>, proto: CPointer<CPointer<Byte>>, len: CPointer<UInt32>,
dynMsgPtr: CPointer<DynMsg>): Unit
// Sets the client-side protocol list (length-prefixed encoding); setClientAlpnProtos treats -1 as failure.
func CJ_TLS_DYN_SetClientAlpnProtocols(ctx: CPointer<Ctx>, protos: CPointer<Byte>, protosLen: UInt32,
dynMsgPtr: CPointer<DynMsg>): Int32
// Sets the server-side protocol list (length-prefixed encoding); setServerAlpnProtos treats <= 0 as failure.
func CJ_TLS_DYN_SetServerAlpnProtos(ctx: CPointer<Ctx>, protos: CPointer<Byte>, protosLen: UInt32,
dynMsgPtr: CPointer<DynMsg>): Int32
}
/*
 * Forwards to the native CJ_TLS_DYN_SetClientAlpnProtocols and validates the
 * accompanying DynMsg with checkDynMsg.
 * Returns the native status code unchanged.
 */
func CJ_TLS_SetClientAlpnProtocols(ctx: CPointer<Ctx>, protos: CPointer<Byte>, protosLen: UInt32): Int32 {
    unsafe {
        var msg = DynMsg()
        let status: Int32 = CJ_TLS_DYN_SetClientAlpnProtocols(ctx, protos, protosLen, inout msg)
        checkDynMsg(msg)
        status
    }
}
/*
 * Forwards to the native CJ_TLS_DYN_GetAlpnSelected, which fills `proto` and `len`,
 * and then validates the accompanying DynMsg with checkDynMsg.
 */
func CJ_TLS_GetAlpnSelected(aStream: CPointer<Ssl>, proto: CPointer<CPointer<Byte>>, len: CPointer<UInt32>): Unit {
    unsafe {
        var msg: DynMsg = DynMsg()
        CJ_TLS_DYN_GetAlpnSelected(aStream, proto, len, inout msg)
        checkDynMsg(msg)
    }
}
/*
 * Forwards to the native CJ_TLS_DYN_SetServerAlpnProtos and validates the
 * accompanying DynMsg with checkDynMsg.
 * Returns the native status code unchanged.
 */
func CJ_TLS_SetServerAlpnProtos(ctx: CPointer<Ctx>, protos: CPointer<Byte>, protosLen: UInt32): Int32 {
    unsafe {
        var msg: DynMsg = DynMsg()
        let status: Int32 = CJ_TLS_DYN_SetServerAlpnProtos(ctx, protos, protosLen, inout msg)
        checkDynMsg(msg)
        status
    }
}