/*
* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* This source file is part of the Cangjie project, licensed under Apache-2.0
* with Runtime Library Exception.
*
* See https://cangjie-lang.cn/pages/LICENSE for license information.
*/
package stdx.net.tls
import std.sync.*
import std.time.DateTime
import std.collection.concurrent.{ConcurrentHashMap, ArrayBlockingQueue}
import stdx.crypto.x509.*
import stdx.encoding.hex.toHexString
import stdx.net.tls.common.*
/**
 * A TLS session negotiated on the client side. The type is opaque and its internals are
 * implementation specific. Users never construct instances themselves: obtain one from a
 * successfully negotiated TlsSocket after handshake() instead.
 */
public class TlsClientSession <: TlsSession & Equatable<TlsClientSession> & ToString & Hashable {
    private TlsClientSession(
        let holder: SessionHolder,
        let id: Array<Byte>
    ) {}

    /**
     * Wraps a native session pointer. Not meant to be called by users: use the
     * TlsSocket.session property to get a session instance.
     */
    init(pointer: CPointer<NativeSession>) {
        // the id must be read first, before the holder is created
        id = unsafe { getSessionId(pointer) }
        holder = SessionHolder(pointer)
    }

    // Two client sessions are equal when their session id bytes are equal.
    public override operator func ==(other: TlsClientSession): Bool {
        return other.id == id
    }

    public override operator func !=(other: TlsClientSession): Bool {
        return !(this == other)
    }

    // Renders the session id as a hex string.
    public override func toString(): String {
        return "TlsClientSession(${toHexString(id)})"
    }

    // Hash of the session id bytes; equal sessions (same id) produce equal hashes.
    public override func hashCode(): Int64 {
        return SessionKey.hashOf(id)
    }

    // Copies the session id out of the native session. Returns an empty array when the
    // native side reports no id (NULL data pointer or zero length).
    private static unsafe func getSessionId(session: CPointer<NativeSession>): Array<Byte> {
        var sessionId = Array<Byte>()
        try (dataOut = malloc<CPointer<Byte>>(initial: CPointer()), sizeOut = malloc<UIntNative>(initial: 0)) {
            CJ_TLS_GetSessionId(session, dataOut.pointer, sizeOut.pointer)
            let data = dataOut.value
            let size = sizeOut.value
            if (size > 0 && !data.isNull()) {
                sessionId = toArray(data, size)
            }
        }
        return sessionId
    }
}
extend TlsRawSocket {
    /**
     * Attaches a previously obtained client session to this socket's native SSL object
     * (presumably so the next handshake can try to resume it — confirm with callers).
     * Any TlsException from the native call propagates through otherNonIO.
     */
    func setSession(session: TlsClientSession): Unit {
        let holder = session.holder
        otherNonIO<Unit> {
            nativeSsl, _ => holder.setSessoinAt(nativeSsl)
        }
    }
}
/**
 * Keeps one native reference to a session (SSL_SESSION*) for the lifetime of this object.
 *
 * The constructor increments the native use-count. That reference is released exactly once:
 * - by the finalizer, when no native operation is in flight, or
 * - otherwise by the last in-flight withNativeSession call.
 * The two sides are coordinated through the atomic RefCounter, because finalizers must not use locks.
 */
class SessionHolder {
    private let pointer: CPointer<NativeSession>
    private static let counter = AtomicInt64(1)
    private let refCount = RefCounter() // secondary refcount for ~init
    // certificates recorded for this session via setCerts, if any
    private var certsOp: ?Array<X509Certificate> = None
    // sequential holder number (unique within the process); unrelated to the TLS session id
    let id = counter.fetchAdd(1)

    /**
     * Takes a new native reference to `pointer`.
     * Throws IllegalArgumentException if `pointer` is NULL.
     */
    init(pointer: CPointer<NativeSession>) {
        if (pointer.isNull()) {
            throw IllegalArgumentException("Session pointer shouldn't be NULL.")
        }
        // the increment should be before assigning to the field
        // but after checking the pointer for NULL
        unsafe { CJ_TLS_IncrementUse(pointer) }
        this.pointer = pointer
    }

    ~init() {
        // this may be invoked right in the middle of setSessoinAt (or any withNativeSession call)
        // so we need special care here:
        // we can't use any sync primitives like mutexes in finalizers
        // and can only rely on atomics
        // that are used here for manual refcounting
        if (refCount.shutdown()) {
            // no concurrent withNativeSession invocations were detected
            free(pointer)
        }
    }

    func setCerts(certs: Array<X509Certificate>): Unit {
        certsOp = certs
    }

    func getCerts(): ?Array<X509Certificate> {
        certsOp
    }

    /**
     * Puts this session into the given SSL* object via CJ_TLS_SetSession, which increments the
     * session's internal use-count (NOT our refCount).
     * Throws TlsException if the native call reports failure.
     * NOTE: the misspelled name is kept because callers (e.g. TlsRawSocket.setSession) use it.
     */
    func setSessoinAt(stream: CPointer<Ssl>): Unit {
        withNativeSession<Unit> {
            pointer => unsafe {
                if (CJ_TLS_SetSession(stream, pointer) != 0) {
                    throw TlsException("Failed to set session to context.")
                }
            }
        }
    }

    /**
     * Runs `block` with the native pointer while protecting it from a concurrent finalizer.
     * If the finalizer ran meanwhile, the native reference is released here after `block`
     * finishes (normally or by exception).
     */
    func withNativeSession<R>(block: (CPointer<NativeSession>) -> R): R {
        unsafe {
            // we increment refcount BEFORE reading "pointer" from "this"
            refCount.enter()
            // here we have "pointer" field access AFTER the barrier (inside of enter())
            // so ~init can't be triggered before enter() invocation
            let pointer = pointer
            // here we touched "this" last time: from this point ~init may happen at any time
            // however the secondary refCount will prevent it from destroying the session
            // so if ~init happens while we are here, freeing the session is postponed
            // until we do leave() in the finally block
            try {
                block(pointer)
                // we would use a "blackhole" intrinsic to avoid GC of "this"
                // with concurrent ~init invocation while we are running this function
                // so we could avoid all this atomic stuff and replace everything with just
                // a single line fix like this:
                // @blackhole(this) // this should be done AFTER using session pointer
            } finally {
                if (refCount.leave()) {
                    // the finalizer has been already triggered
                    // and we are the last concurrent user
                    // so we are doing free here
                    free(pointer)
                    // note that this is unlikely to cause session disposal (see the comment in free)
                }
            }
        }
    }

    /**
     * Returns the SSL_SESSION* pointer with its native use-count incremented; the caller owns
     * that extra reference.
     * Goes through withNativeSession so the RefCounter is always left again, even when the
     * increment throws (previously an exception here left the counter entered forever, so
     * the native reference was never released).
     */
    func getIncremented(): CPointer<NativeSession> {
        withNativeSession<CPointer<NativeSession>> {
            pointer =>
                unsafe { CJ_TLS_IncrementUse(pointer) }
                pointer
        }
    }

    private static func free(pointer: CPointer<NativeSession>) {
        unsafe {
            // remember that this doesn't necessarily mean actual session disposal
            // because there is a builtin ref-counter inside of the session object
            // so here we are just decrementing the counter that MAY lead to disposal
            CJ_TLS_DeleteSession(pointer)
        }
    }
}
/*
 * Sets the session ID context of the server SSL context `serverCtx` from `sessionId`.
 * @throws TlsException if sessionId is longer than 32 bytes (String.size),
 * or if the native call reports failure (a return value <= 0).
 */
func setServerSessionId(serverCtx: CPointer<Ctx>, sessionId: String): Unit {
    let byteLength = sessionId.size
    if (byteLength > 32) {
        throw TlsException("The length of the session ID cannot exceed 32 bytes.")
    }
    // the C copy is released in `finally`, whether or not the native call succeeds
    let nativeId = unsafe { LibC.mallocCString(sessionId) }
    try {
        let status = unsafe { CJ_TLS_SetSessionIdContext(serverCtx, nativeId, UInt32(byteLength)) }
        if (status > 0) {
            return
        }
        throw TlsException("Failed to set tls socket server session ID.")
    } finally {
        unsafe { LibC.free(nativeId) }
    }
}
// Opaque marker type for the native session object; it is only ever handled
// through CPointer<NativeSession> and has no fields on the Cangjie side.
@C
struct NativeSession {}
// Native session API. Every function takes a trailing dynMsgPtr out-parameter;
// the CJ_TLS_* wrappers below pass the filled DynMsg to checkDynMsg after each call
// (presumably reporting dynamic-library/lookup failures — confirm in checkDynMsg).
foreign {
// installs the put/remove/find/assign session callbacks
func CJ_TLS_DYN_SetSessionCallback(
put: CFunc<(CPointer<Ssl>, CPointer<Byte>, UIntNative, CPointer<NativeSession>) -> Unit>,
remove: CFunc<(CPointer<Ctx>, CPointer<Byte>, UIntNative, CPointer<NativeSession>) -> Unit>,
find: CFunc<(CPointer<Ssl>, CPointer<Byte>, UIntNative) -> CPointer<NativeSession>>,
assign: CFunc<(CPointer<Ssl>, CPointer<NativeSession>) -> Unit>,
dynMsgPtr: CPointer<DynMsg>
): Unit
// sets the session id context of ctx; callers treat a result <= 0 as failure
func CJ_TLS_DYN_SetSessionIdContext(ctx: CPointer<Ctx>, sidCtx: CString, sidCtxLen: UInt32,
dynMsgPtr: CPointer<DynMsg>): Int32
// decrements the session's native use-count (may dispose it)
func CJ_TLS_DYN_DeleteSession(pointer: CPointer<NativeSession>, dynMsgPtr: CPointer<DynMsg>): Unit
// attaches session to the SSL object; callers treat a non-zero result as failure
func CJ_TLS_DYN_SetSession(context: CPointer<Ssl>, session: CPointer<NativeSession>, dynMsgPtr: CPointer<DynMsg>): Int32
// increments the session's native use-count
func CJ_TLS_DYN_IncrementUse(session: CPointer<NativeSession>, dynMsgPtr: CPointer<DynMsg>): Unit
// writes the session id buffer pointer and length into data/length
func CJ_TLS_DYN_GetSessionId(
session: CPointer<NativeSession>,
data: CPointer<CPointer<Byte>>,
length: CPointer<UIntNative>,
dynMsgPtr: CPointer<DynMsg>
): Unit
// adds a session to ctx; result semantics not visible here — TODO confirm
func CJ_TLS_DYN_AddSession(context: CPointer<Ctx>, session: CPointer<NativeSession>, dynMsgPtr: CPointer<DynMsg>): Int32
}
/**
 * Installs the put/remove/find/assign session callbacks through CJ_TLS_DYN_SetSessionCallback.
 * The filled DynMsg is handed to checkDynMsg afterwards.
 */
func CJ_TLS_SetSessionCallback(
    put: CFunc<(CPointer<Ssl>, CPointer<Byte>, UIntNative, CPointer<NativeSession>) -> Unit>,
    remove: CFunc<(CPointer<Ctx>, CPointer<Byte>, UIntNative, CPointer<NativeSession>) -> Unit>,
    find: CFunc<(CPointer<Ssl>, CPointer<Byte>, UIntNative) -> CPointer<NativeSession>>,
    assign: CFunc<(CPointer<Ssl>, CPointer<NativeSession>) -> Unit>
): Unit {
    var dynMsg = DynMsg()
    unsafe {
        CJ_TLS_DYN_SetSessionCallback(put, remove, find, assign, inout dynMsg)
        checkDynMsg(dynMsg)
    }
}
/**
 * Sets the session id context of `ctx` via CJ_TLS_DYN_SetSessionIdContext and returns
 * the native status code unchanged, after the DynMsg has been passed to checkDynMsg.
 */
func CJ_TLS_SetSessionIdContext(ctx: CPointer<Ctx>, sidCtx: CString, sidCtxLen: UInt32): Int32 {
    var dynMsg = DynMsg()
    let status = unsafe { CJ_TLS_DYN_SetSessionIdContext(ctx, sidCtx, sidCtxLen, inout dynMsg) }
    unsafe { checkDynMsg(dynMsg) }
    return status
}
/**
 * Releases one native reference to the session via CJ_TLS_DYN_DeleteSession,
 * then passes the DynMsg to checkDynMsg.
 */
func CJ_TLS_DeleteSession(pointer: CPointer<NativeSession>): Unit {
    var dynMsg = DynMsg()
    unsafe {
        CJ_TLS_DYN_DeleteSession(pointer, inout dynMsg)
        checkDynMsg(dynMsg)
    }
}
/**
 * Attaches `session` to the SSL object `context` via CJ_TLS_DYN_SetSession and returns
 * the native status code unchanged, after the DynMsg has been passed to checkDynMsg.
 */
func CJ_TLS_SetSession(context: CPointer<Ssl>, session: CPointer<NativeSession>): Int32 {
    var dynMsg = DynMsg()
    let status = unsafe { CJ_TLS_DYN_SetSession(context, session, inout dynMsg) }
    unsafe { checkDynMsg(dynMsg) }
    return status
}
/**
 * Takes one more native reference to the session via CJ_TLS_DYN_IncrementUse,
 * then passes the DynMsg to checkDynMsg.
 */
func CJ_TLS_IncrementUse(session: CPointer<NativeSession>): Unit {
    var dynMsg = DynMsg()
    unsafe {
        CJ_TLS_DYN_IncrementUse(session, inout dynMsg)
        checkDynMsg(dynMsg)
    }
}
/**
 * Writes the session id buffer pointer and its length into `data` and `length`
 * via CJ_TLS_DYN_GetSessionId, then passes the DynMsg to checkDynMsg.
 */
func CJ_TLS_GetSessionId(
    session: CPointer<NativeSession>,
    data: CPointer<CPointer<Byte>>,
    length: CPointer<UIntNative>
): Unit {
    var dynMsg = DynMsg()
    unsafe {
        CJ_TLS_DYN_GetSessionId(session, data, length, inout dynMsg)
        checkDynMsg(dynMsg)
    }
}
/**
 * Adds `session` to the context via CJ_TLS_DYN_AddSession and returns the native
 * status code unchanged, after the DynMsg has been passed to checkDynMsg.
 */
func CJ_TLS_AddSession(context: CPointer<Ctx>, session: CPointer<NativeSession>): Int32 {
    var dynMsg = DynMsg()
    let status = unsafe { CJ_TLS_DYN_AddSession(context, session, inout dynMsg) }
    unsafe { checkDynMsg(dynMsg) }
    return status
}
/**
 * Lock-free counter coordinating a finalizer (shutdown) with concurrent users (enter/leave)
 * of a resource, using only an AtomicInt64 and CAS loops.
 *
 * Encoding of `state`:
 *   0          - idle: no users, not shut down
 *   n > 0      - n active users, not shut down
 *   -n < 0     - shut down while n users were still active
 *   Int64.Min  - terminated: the party that produced this state must release the resource
 */
class RefCounter {
private static let Terminated = Int64.Min
private let state = AtomicInt64(0) // 0 idle, positive - in use refcounter, negative shutdown+in-use
/**
 * Registers one more user: moves the state one step away from zero, keeping its sign.
 * Throws IllegalStateException if the counter is already terminated, or if the new
 * state would equal the Terminated sentinel.
 */
func enter(): Unit {
updateAndGet {
before =>
let after = match {
case before == Terminated => throw IllegalStateException("The session has been already collected.")
case before < 0 => before - 1
case _ => before + 1
}
if (after == Terminated) {
throw IllegalStateException("Too many users.")
}
after
}
}
/**
 * Unregisters one user. Returns true only when this was the last user after shutdown()
 * (state -1 -> Terminated); the caller must then release the resource.
 * NOTE(review): must be balanced with enter(); an unpaired call on an idle counter would
 * move state 0 to -1 and corrupt the encoding.
 */
func leave(): Bool {
let after = updateAndGet {
before => match {
case before == -1 => Terminated
case before < 0 => before + 1
case _ => before - 1
}
}
return after == Terminated
}
/**
 * Marks the counter as shut down. Returns true when there were no active users
 * (state 0 -> Terminated), so the caller must release the resource now. With active users
 * the state is negated and the last leave() reports the release instead. If the state is
 * already negative (including Terminated) nothing changes and false is returned.
 */
func shutdown(): Bool {
let after = updateAndGet {
before => match {
case before == 0 => Terminated
case before < 0 => None
case _ => -before
}
}
return after == Terminated
}
// this could be a part of stdlib (std/sync) but it can't as lambdas are too slow
// here we don't care because it's not a hot-path
// CAS loop: applies `map` to the current state until compareAndSwap succeeds and returns
// the stored value; returns None without touching the state when `map` yields None.
// `map` may run several times on contention, so it must not have side effects.
private func updateAndGet(map: (Int64) -> ?Int64): ?Int64 {
var after: Int64 = 0
do {
let before = state.load()
after = map(before) ?? return None
if (state.compareAndSwap(before, after)) {
break
}
} while (true)
return after
}
}
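// The native session callbacks below are registered once for the whole process (in
// TlsServerSession's static initializer); the session store to use is looked up per
// connection via Bridge.sessionStore. Store capacity/timeout are not configurable yet.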
/**
 * Native callback invoked when a new session has been negotiated on `ssl`.
 * - When the bridge has a session store, it records the new id on the socket and caches the
 *   session under that id.
 * - Otherwise, on a client bridge, it exposes the session as the socket's negotiatedSession.
 * Exceptions must never propagate into the calling C code, so any failure is swallowed:
 * the session is then simply not cached or exposed.
 */
@C
func CJ_TLS_put_session(
    ssl: CPointer<Ssl>,
    id: CPointer<Byte>,
    idLength: UIntNative,
    session: CPointer<NativeSession>
): Unit {
    try {
        if (let Some(bridge) <- Bridge.findByStream(ssl)) {
            if (let Some(sessionStore) <- bridge.sessionStore) {
                bridge.socket.newSessionId = unsafe { toArray(id, idLength) }
                sessionStore.put(SessionKey.copyFrom(id, idLength), session)
            } else if (!bridge.server) {
                bridge.socket.negotiatedSession = TlsClientSession(session)
            }
        }
    } catch (_: Exception) {
        // we should not throw Exception from cj code to c code
    }
}
/**
 * Native callback: on a client bridge, exposes `session` as the socket's negotiatedSession.
 * NOTE(review): this callback returns Unit, so whatever native reference the caller may
 * have pre-incremented is managed by the C shim, not here. TlsClientSession takes its own
 * additional reference through SessionHolder. Confirm against the native side.
 * Exceptions must never propagate into the calling C code, so any failure is swallowed.
 */
@C
func CJ_TLS_assign_session(
    ssl: CPointer<Ssl>,
    session: CPointer<NativeSession>
): Unit {
    try {
        if (let Some(bridge) <- Bridge.findByStream(ssl)) {
            if (!bridge.server) {
                bridge.socket.negotiatedSession = TlsClientSession(session)
            }
        }
    } catch (_: Exception) {
        // we should not throw Exception from cj code to c code
    }
}
/**
 * Native callback: drops the session with the given id from the context's session store.
 * The native session pointer argument is not used.
 * Exceptions must never propagate into the calling C code, so any failure is swallowed.
 * The entry then stays cached until it is evicted or expires.
 */
@C
func CJ_TLS_remove_session(
    ctx: CPointer<Ctx>,
    id: CPointer<Byte>,
    idLength: UIntNative,
    _: CPointer<NativeSession>
): Unit {
    try {
        if (let Some(bridge) <- Bridge.findByContext(ctx)) {
            if (let Some(sessionStore) <- bridge.sessionStore) {
                sessionStore.remove(SessionKey.copyFrom(id, idLength))
            }
        }
    } catch (_: Exception) {
        // we should not throw Exception from cj code to c code
    }
}
/**
 * Native callback: looks up a cached session by id.
 * On a hit, the id is recorded as the socket's resumedSessionId. The pointer returned by
 * the store's find() comes back with its native use-count already incremented.
 * A NULL pointer is returned in these cases (exceptions must never reach C code):
 * - no bridge,
 * - no session store,
 * - no match,
 * - any exception.
 */
@C
func CJ_TLS_find_session(
    ssl: CPointer<Ssl>,
    id: CPointer<Byte>,
    idLength: UIntNative
): CPointer<NativeSession> {
    let notFound = CPointer<NativeSession>()
    try {
        let bridge = Bridge.findByStream(ssl) ?? return notFound
        let sessionStore = bridge.sessionStore ?? return notFound
        if (let Some(nativeSession) <- sessionStore.find(SessionKey.copyFrom(id, idLength))) {
            bridge.socket.resumedSessionId = unsafe { toArray(id, idLength) }
            return nativeSession
        }
        return notFound
    } catch (_: Exception) {
        // we should not throw Exception from cj code to c code
        return notFound
    }
}
// Hashable/Equatable wrapper around session id bytes: needed only because Array<Byte>
// itself is not Hashable and therefore cannot be used as a map key directly.
struct SessionKey <: Hashable & Equatable<SessionKey> {
    SessionKey(let bytes: Array<Byte>) {}

    public override func hashCode(): Int64 {
        return SessionKey.hashOf(bytes)
    }

    public operator func ==(other: SessionKey): Bool {
        return bytes == other.bytes
    }

    public operator func !=(other: SessionKey): Bool {
        return !(this == other)
    }

    // Builds a key from `length` bytes at `ptr`, materialized with toArray.
    static func copyFrom(ptr: CPointer<Byte>, length: UIntNative): SessionKey {
        let copied = unsafe { toArray(ptr, length) }
        return SessionKey(copied)
    }

    // Feeds every byte into a DefaultHasher and returns the resulting hash.
    static func hashOf(bytes: Array<Byte>): Int64 {
        var hasher = DefaultHasher()
        for (i in 0..bytes.size) {
            hasher.write(bytes[i])
        }
        return hasher.finish()
    }
}
/**
 * Expiration-queue record: the key of a stored session plus the wall-clock time
 * (DateTime.nowUTC) at which this entry was created.
 */
struct WaitQueueEntry {
    private let createdAt = DateTime.nowUTC()
    WaitQueueEntry(let key: SessionKey) {}

    // Time elapsed since this entry was created.
    prop age: Duration {
        get() {
            let now = DateTime.nowUTC()
            return now - createdAt
        }
    }

    // Time left until `timeout` is reached; zero or negative once the entry is due.
    func remaining(timeout: Duration): Duration {
        let elapsed = age
        return timeout - elapsed
    }
}
/**
 * Bounded in-memory store of sessions keyed by session id.
 *
 * Insertion order is tracked in a FIFO queue of WaitQueueEntry records:
 * - when the queue is full, the oldest entry's session is evicted;
 * - entries older than `timeout` are expired by one-shot timers (one per put).
 */
struct SessionStore {
    private let sessions: ConcurrentHashMap<SessionKey, SessionHolder>
    private let queue: ArrayBlockingQueue<WaitQueueEntry>

    SessionStore(
        let capacity!: Int64,
        let timeout!: Duration
    ) {
        sessions = ConcurrentHashMap(capacity * 2)
        queue = ArrayBlockingQueue(capacity)
    }

    /**
     * Caches `session` under `key` (the holder takes its own native reference), schedules
     * its expiration and returns the API-level wrapper.
     */
    func put(key: SessionKey, session: CPointer<NativeSession>): TlsClientSession {
        let apiSession = TlsClientSession(session)
        sessions.add(key, apiSession.holder)
        enqueue(key)
        ensureWatchdog()
        return apiSession
    }

    func tryGetHolder(key: SessionKey): ?SessionHolder {
        return sessions.get(key)
    }

    func remove(key: SessionKey) {
        sessions.remove(key)
        // it is probably still remaining in the queue but we don't care
        // as it will be pushed away at some point later by the watchdog
        // or by another put operation
    }

    /*
     * Returns a session with incremented refcount
     * This is important to increment it before, otherwise a session could be released concurrently
     * and the invoker of function MUST handle this extra-increment after applying a session
     */
    func find(key: SessionKey): ?CPointer<NativeSession> {
        sessions.get(key)?.getIncremented()
    }

    // Drops every queued expiration record and every stored session.
    func clear() {
        while (let Some(_) <- queue.tryRemove()) {}
        for ((k, _) in sessions) {
            sessions.remove(k)
        }
    }

    // Appends `key` to the expiration queue, evicting the oldest sessions while the queue is full.
    private func enqueue(key: SessionKey) {
        // created once so its timestamp reflects the put, not the last retry
        let entry = WaitQueueEntry(key)
        while (!queue.tryAdd(entry)) {
            if (let Some(old) <- queue.tryRemove()) {
                sessions.remove(old.key)
            }
        }
    }

    // Schedules a one-shot timer for the entry just enqueued. When it fires, only entries
    // that have really expired are removed from the head of the queue.
    //
    // Previous behavior: every put scheduled a timer for the queue head's remaining time, and
    // each timer removed one entry unconditionally. Several puts inside one timeout window
    // therefore evicted later, still-valid sessions early.
    //
    // Remaining races: a concurrent enqueue/remove may make peek() and tryRemove() see
    // different heads. Entry ages use wall-clock time (DateTime), so a backward clock jump
    // can delay expiry. In both cases sessions are dropped early or late, which is tolerable.
    private func ensureWatchdog() {
        let ttl = timeout
        let pending = queue
        let map = sessions
        Timer.after(ttl) {
            while (let Some(head) <- pending.peek()) {
                if (head.remaining(ttl) > Duration.Zero) {
                    break
                }
                if (let Some(expired) <- pending.tryRemove()) {
                    // remove it from the map too, to avoid a leak
                    map.remove(expired.key)
                }
            }
            return None
        }
    }
}
/**
 * Server-side session context.
 *
 * When a client attempts to resume a session, both peers have to make sure they are resuming
 * with a legitimate counterpart. On the server side this is the job of the session context:
 * it verifies clients and tells clients whether they are still connecting to the same server
 * instance. Stateful sessions are stored directly in this context, so a single instance should
 * be shared between all server TlsSocket instances of one server node.
 */
public class TlsServerSession <: TlsSession & Equatable<TlsServerSession> & ToString {
    let name: String
    // per-context session cache: capacity 600 entries, timeout Duration.hour
    let store = SessionStore(capacity: 600, timeout: Duration.hour)

    init(name: String) {
        this.name = name
    }

    // Registers the native session callbacks from the static initializer.
    static init() {
        unsafe {
            CJ_TLS_SetSessionCallback(
                CJ_TLS_put_session,
                CJ_TLS_remove_session,
                CJ_TLS_find_session,
                CJ_TLS_assign_session
            )
        }
    }

    // Drains this context's session store when the context is collected.
    // NOTE(review): finalizers should avoid locking primitives; confirm that the
    // ArrayBlockingQueue/ConcurrentHashMap operations used by clear() are safe here.
    ~init() {
        store.clear()
    }

    /**
     * Creates a session context named `name` with the default in-memory session cache.
     *
     * The name distinguishes TLS servers, so clients rely on it to avoid accidentally
     * resuming a connection to the wrong server. It does not have to be cryptographically
     * secure, because the underlying implementation does that job.
     *
     * Contexts created from the same name compare equal (equality is by name), but each one
     * owns its own session store, so they are not interchangeable. A server should therefore
     * create a single context for its whole lifetime and use it with every TlsSocket.server()
     * invocation.
     */
    public static func fromName(name: String): TlsServerSession {
        return TlsServerSession(name)
    }

    public override operator func ==(other: TlsServerSession): Bool {
        return name == other.name
    }

    public override operator func !=(other: TlsServerSession): Bool {
        return !(this == other)
    }

    public override func toString(): String {
        return "TlsServerSession(${name})"
    }
}