/*
* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* This source file is part of the Cangjie project, licensed under Apache-2.0
* with Runtime Library Exception.
*
* See https://cangjie-lang.cn/pages/LICENSE for license information.
*/
package stdx.net.tls
import stdx.net.tls.common.*
import stdx.crypto.x509.*
import stdx.crypto.common.*
import std.net.{StreamingSocket, SocketAddress}
import std.sync.*
import std.collection.HashMap
// Native TLS-library entry points used by the keyless (externally held private key) server mode.
foreign {
/* For keyless provider */
// Returns the id string for a certificate given its DER encoding (`data`, `len` bytes).
// TlsSocket.server(...) uses it as the key into `TlsSocket.keylessCallback` and releases
// the returned CString with LibC.free.
func CJ_TLS_GetCertId(data: CPointer<Byte>, len: Int64): CString
// Not called in this file. Presumably returns the keyless key id bound to a native context;
// NOTE(review): confirm whether the result must be released with CJ_TLS_FreeKeylessId.
func CJ_TLS_GetKeylessKeyId(ctx: CPointer<Ctx>): CString
// Not called in this file. Presumably releases an id string obtained from the native keyless API — confirm.
func CJ_TLS_FreeKeylessId(id: CString): Unit
// Not called in this file. Presumably initializes the embedded keyless provider; the meaning
// of the Int8 status code is not visible here — confirm against the native implementation.
func DYN_CJ_TLS_InitEmbeddedKeylessProvider(dynMsg: CPointer<DynMsg>): Int8
// Not called in this file. Presumably registers a native sign callback for `keyId`, reporting
// failures through `exception` — confirm.
func DYN_CJ_TLS_RegisterKeylessSignCallback(keyId: CString, cb: CKeylessSignCallback, exception: CPointer<ExceptionData>, dynMsg: CPointer<DynMsg>): Unit
// Not called in this file. Presumably registers a native decrypt callback for `keyId`, reporting
// failures through `exception` — confirm.
func DYN_CJ_TLS_RegisterKeylessDecryptCallback(keyId: CString, cb: CKeylessDecryptCallback, exception: CPointer<ExceptionData>, dynMsg: CPointer<DynMsg>): Unit
}
/**
 * A TLS connection layered over an existing StreamingSocket.
 *
 * Instances are created with `client(...)` or `server(...)`; `handshake()` must complete before
 * data can be read or written. The lifecycle is kept in an atomically updated state:
 * SocketReady -> SocketInHandshake -> SocketConnected, and close() moves any state to SocketClosed.
 *
 * Equality is reference identity; the hash code comes from a per-instance id drawn from a
 * static counter.
 */
public class TlsSocket <: TlsConnection & Equatable<TlsSocket> & Hashable {
    private static let idCounter = AtomicInt64(1)
    // Keyless sign/decrypt callbacks, keyed by the certificate id returned from CJ_TLS_GetCertId
    // (registered by the keyless server(...) factory).
    // NOTE(review): this shared map is mutated without a lock in server(...); confirm that
    // registrations cannot race with each other or with readers elsewhere.
    static var keylessCallback = HashMap<String, (KeylessSignFunc, ?KeylessDecryptFunc)>()
    // Timeout applied by handshake() when no explicit timeout is given.
    static let defaultTimeout: Duration = Duration.second * 30
    private let state: AtomicReference<TlsSocketState>
    private let id = idCounter.fetchAdd(1) // it is required for hashCode
    // For now its size is 1, but in theory a server may provide more sessions,
    // and it actually does in TLS 1.3.
    private let negotiatedSessions_ = AtomicOptionReference<Box<TlsClientSession>>()
    private var _handshakeResult: ?TlsSocketHandshakeResult = None
    // Session ids read when looking up peer certificates in the server session store.
    // NOTE(review): they are not assigned in this file; presumably set by the native bridge callbacks.
    var resumedSessionId: ?Array<Byte> = None
    var newSessionId: ?Array<Byte> = None

    private init(socket: StreamingSocket, handshake: HandshakeConfig) {
        this.state = AtomicReference<TlsSocketState>(SocketReady(socket, handshake))
    }

    /**
     * Creates a client TLS stream connected to the specified peer.
     * Can be optionally configured to resume a TLS session by TlsClientSession instance.
     * However, having a session doesn't guarantee that it will be successfully resumed,
     * in which case a full TLS handshake will happen and handshake() still may fail for some reason.
     * For example, some servers may reject old sessions after restart or after timeout.
     */
    public static func client(
        socket: StreamingSocket,
        session!: ?TlsClientSession = None,
        clientConfig!: TlsClientConfig = TlsClientConfig()
    ): TlsSocket {
        TlsSocket(socket, HandshakeConfig.Client(clientConfig, session))
    }

    /**
     * Creates a server TLS stream connected to the specified peer.
     * If `session` is unspecified then clients will not be able to resume TLS sessions
     * and always require full TLS negotiation.
     */
    public static func server(
        socket: StreamingSocket,
        session!: ?TlsServerSession = None,
        serverConfig!: TlsServerConfig
    ): TlsSocket {
        TlsSocket(socket, HandshakeConfig.Server(serverConfig, session))
    }

    /**
     * Creates a server TLS stream whose private-key operations are delegated to the sign/decrypt
     * callbacks of `serverConfig` (keyless mode).
     *
     * The callbacks are stored in `keylessCallback` under the id of the first certificate of the
     * configured chain.
     *
     * Note: `session` is currently not used in keyless mode; the handshake is always configured
     * without a server session.
     *
     * @throws TlsException if the configured certificate chain is empty or its id cannot be computed
     */
    public static func server(
        socket: StreamingSocket,
        session!: ?TlsServerSession = None, // cjlint-ignore !G.FUN.02
        serverConfig!: KeylessTlsServerConfig
    ): TlsSocket {
        let (cert, _) = serverConfig.serverCertificate
        if (cert.size == 0) {
            throw TlsException("The server certificate chain is empty.")
        }
        unsafe {
            let data = cert[0].encodeToDer().body
            let cp = acquireArrayRawData(data)
            let keyId = CJ_TLS_GetCertId(cp.pointer, data.size)
            releaseArrayRawData(cp)
            // Never convert a null CString: registering under a bogus key would silently
            // leave the keyless callbacks unreachable.
            if (keyId.isNull()) {
                throw TlsException("Failed to compute the keyless certificate id.")
            }
            try {
                TlsSocket.keylessCallback.add(keyId.toString(), (serverConfig._keylessSignFunc, serverConfig._keylessDecryptFunc))
            } finally {
                LibC.free(keyId)
            }
        }
        TlsSocket(socket, HandshakeConfig.KeylessServer(serverConfig, None))
    }

    /**
     * The underlying streaming socket that has been provided during construction
     * or throws an exception if already closed.
     */
    public prop socket: StreamingSocket {
        get() {
            tryGetSocket(state.load()) ?? SocketClosed.throwAlreadyClosed()
        }
    }

    /**
     * Result of the last successful handshake(), or None before a handshake or after close().
     */
    public prop handshakeResult: ?TlsHandshakeResult {
        get() {
            _handshakeResult ?? None
        }
    }

    /**
     * The latest client session captured for this connection.
     * Sessions can only be replaced, never removed: assigning None throws.
     */
    mut prop negotiatedSession: ?TlsClientSession {
        get() {
            negotiatedSessions_.load()?.value
        }
        set(value) {
            match (value) {
                case Some(value) => negotiatedSessions_.store(Box(value))
                case None => throw Exception("Removing sessions is not allowed.")
            }
        }
    }

    /**
     * The local certificate chain presented to the peer, or an empty array if none was configured.
     *
     * @throws TlsException if the handshake has not completed yet or the socket is closed
     */
    public prop certificate: Array<X509Certificate> {
        get() {
            connected.getOrThrow(SocketClosed.alreadyClosedException).certificate ?? []
        }
    }

    // Returns None if closed, fails if not negotiated yet (too early).
    private prop connected: ?SocketConnected {
        get() {
            match (state.load()) {
                case _: SocketReady => throw TlsException("TLS socket didn't pass handshake yet.")
                case _: SocketInHandshake => throw TlsException("TLS socket didn't pass handshake yet.")
                case connected: SocketConnected => connected
                case _ => None
            }
        }
    }

    /** Read timeout of the underlying socket; throws if the TLS socket is closed. */
    public override mut prop readTimeout: ?Duration {
        get() {
            socket.readTimeout
        }
        set(newTimeout) {
            socket.readTimeout = newTimeout
        }
    }

    /** Write timeout of the underlying socket; throws if the TLS socket is closed. */
    public override mut prop writeTimeout: ?Duration {
        get() {
            socket.writeTimeout
        }
        set(newTimeout) {
            socket.writeTimeout = newTimeout
        }
    }

    /** Local address of the underlying socket; throws if the TLS socket is closed. */
    public override prop localAddress: SocketAddress {
        get() {
            socket.localAddress
        }
    }

    /** Remote address of the underlying socket; throws if the TLS socket is closed. */
    public override prop remoteAddress: SocketAddress {
        get() {
            socket.remoteAddress
        }
    }

    /**
     * Negotiated ALPN protocol name, or None if no protocol was selected.
     */
    prop alpnProtocolName: ?String {
        get() {
            let connected = connected ?? SocketClosed.throwAlreadyClosed()
            connected.stream.getAlpnSelected()
        }
    }

    /**
     * Negotiated TLS version.
     *
     * Works only after a successful handshake().
     */
    prop tlsVersion: TlsVersion {
        get() {
            getVersion()
        }
    }

    /*
     * On the client it can be captured after a successful handshake()
     * and later used to reconnect without a full handshake.
     *
     * This may be None if we are not yet successfully negotiated, session resumption is not enabled
     * or if the feature is not supported by the client. Depending on the underlying implementation
     * and TLS version it may appear in the middle of the handshake or much later,
     * or not appear at all if the server decides not to start/resume a session ticket.
     *
     * If a session is negotiated, it remains unchanged after closing the socket.
     * Later negotiated sessions overwrite previous ones (when the server sends multiple sessions),
     * so the latest session is always observed.
     *
     * On the server it is always None because the server can't initiate a session resumption.
     */
    prop session: ?TlsClientSession {
        get() {
            negotiatedSession
        }
    }

    // Negotiated domain name (SNI).
    prop domain: ?String {
        get() {
            getHostName()
        }
    }

    /**
     * Peer certificate if provided by the peer or `None`.
     *
     * On the server, a resumed connection may finish before the full-handshake connection that
     * stored the certificates, so the session store is consulted again here.
     */
    prop peerCertificate: ?Array<X509Certificate> {
        get() {
            let connState = connected.getOrThrow(SocketClosed.alreadyClosedException)
            if (let Some(v) <- connState.peerCertificate) {
                return v
            }
            if (connState.isClient) {
                return None
            }
            // some resumed conn may be faster than full handshake conn, check again
            if (let Some(id) <- resumedSessionId && let Some(store) <- connState.bridge.sessionStore) {
                if (let Some(holder) <- store.tryGetHolder(SessionKey(id))) {
                    let certs = holder.getCerts()
                    connState.setPeerCerts(certs)
                    return certs
                }
            }
            return None
        }
    }

    /*
     * The cipher suite negotiated after a handshake.
     * A cipher suite contains the encryption algorithm, the hash function for message authentication
     * and the key exchange algorithm.
     *
     * @throws TlsException if handshake has not yet been performed or socket already closed.
     */
    prop cipherSuite: CipherSuite {
        get() {
            let state = connected ?? SocketClosed.throwAlreadyClosed()
            state.stream.getCipherSuite()
        }
    }

    /*
     * Performs the handshake, limited by `timeout` (`defaultTimeout`, 30 seconds, when None).
     * Depending on how the socket was created and configured, this function does
     * either the client or the server handshake.
     * This can be done only once since renegotiating the handshake is not supported.
     *
     * The underlying socket's read/write timeouts are set to the handshake timeout for the
     * duration of the handshake and restored afterwards.
     *
     * On failure the TLS socket is closed. If it had already been closed via close(), a
     * TlsException is thrown instead of the original error.
     *
     * @throws TlsException if a handshake is already in progress or complete, or the socket is closed
     * @throws Exception any error raised by the handshake itself (e.g. TlsException, SocketException)
     */
    public func handshake(timeout!: ?Duration = None): TlsHandshakeResult {
        let (started, handshake) = tryStartHandshake()
        let success = try {
            match (handshake) {
                case HandshakeConfig.Client(config, session) =>
                    connect(started.socket, timeout ?? defaultTimeout, config, session)
                case HandshakeConfig.Server(config, context) =>
                    handleAccepted(started.socket, timeout ?? defaultTimeout, config, context)
                case HandshakeConfig.KeylessServer(config, context) =>
                    handleAccepted(started.socket, timeout ?? defaultTimeout, config, context)
            }
        } catch (e: Exception) {
            if (isClosed()) {
                throw TlsException("TLS socket closed during handshake.")
            }
            close(selfClosed: true)
            throw e
        }
        // The socket may have been closed while the handshake was running: discard the connection.
        if (!state.compareAndSwap(started, success)) {
            success.close()
            SocketClosed.throwAlreadyClosed()
        }
        let result = TlsSocketHandshakeResult(this)
        _handshakeResult = result
        result
    }

    /**
     * Returns the number of bytes read from the TLS socket, or 0 if the peer reached EOF or the
     * socket was closed internally (after a failed handshake).
     *
     * Bypasses exceptions from the underlying socket.
     * Note that due to the TLS protocol nature, this may both read and write with the underlying socket.
     * In particular it means that the read() function may invoke socket.write(), which is legal.
     *
     * This is concurrent safe, but invoking close() during read() may cause an exception thrown
     * from read() in some cases. Because the TLS stream is a composed stream based on another socket,
     * the particular kind of exception could be different depending on where exactly the close error
     * has been detected: it may be either SocketException or TlsException.
     *
     * @throws TlsException if `buffer` is empty
     * @throws TlsException if closed via close() (especially concurrently) or the TLS stream is corrupted
     * @throws SocketException or other from the underlying socket, including concurrent close
     * @throws TlsException if reading data fails, TLS socket is not ready (no handshake yet)
     */
    public override func read(buffer: Array<Byte>): Int64 {
        if (buffer.size == 0) {
            throw TlsException("The buffer is empty.")
        }
        let state = connected ?? return readClosed()
        state.stream.read(buffer)
    }

    // Read result for a closed socket: 0 after an internal close, otherwise "already closed".
    private func readClosed(): Int64 {
        match (state.load()) {
            case closed: SocketClosed where closed.selfClosed => 0 // self-closed stream returns 0
            case _ => SocketClosed.throwAlreadyClosed()
        }
    }

    /*
     * Writes the whole buffer to the TLS stream; an empty buffer is a no-op.
     *
     * @throws TlsException if tls socket is closed
     * @throws TlsException if tls socket is not connected
     * @throws TlsException if writing data fails
     */
    public func write(buffer: Array<Byte>): Unit {
        let state = connected ?? SocketClosed.throwAlreadyClosed()
        if (buffer.size == 0) {
            return
        }
        state.stream.write(buffer)
    }

    /**
     * Terminates the TLS connection, trying to shut it down properly if possible.
     * If invoked concurrently with read(), write() or handshake(),
     * these operations may be aborted with or without an exception.
     *
     * This is reentrant and concurrent-safe.
     * Closing a socket that is already closed has no effect.
     *
     * An exception thrown from this function doesn't "cancel" its termination: there is no need
     * to invoke it again. isClosed() will report `true` in this case.
     *
     * @throws SocketException or other exceptions from the underlying socket
     */
    public func close(): Unit {
        close(selfClosed: false)
    }

    // Moves the state to SocketClosed and then closes the previous state.
    // `selfClosed: true` marks an internal close (failed handshake). A later user close still
    // replaces it so that isClosed() reports true; after a user close this is a no-op.
    private func close(selfClosed!: Bool) {
        _handshakeResult = None
        do {
            match (state.load()) {
                case closed: SocketClosed where !closed.selfClosed => return
                case currentState =>
                    if (state.compareAndSwap(currentState, SocketClosed.getInstance(selfClosed))) {
                        currentState.close()
                        return
                    }
            }
        } while (true)
    }

    /**
     * Whether the TLS socket has been explicitly closed via a close() invocation.
     * An internal close after a failed handshake does not count.
     * If invoked concurrently, it may report `true` while the shutdown is still in progress,
     * even if there are still running operations, e.g. read() is shutting down at the moment.
     */
    public func isClosed(): Bool {
        match (state.load()) {
            case closed: SocketClosed where !closed.selfClosed => true
            case _ => false
        }
    }

    /** Debug representation including the current lifecycle state. */
    public func toString(): String {
        return "TlsSocket(${state.load()})"
    }

    /** Reference equality. */
    public override operator func ==(other: TlsSocket): Bool {
        refEq(this, other)
    }

    /** Reference inequality. */
    public override operator func !=(other: TlsSocket): Bool {
        !(refEq(this, other))
    }

    /** Hash of the per-instance id, consistent with reference equality. */
    public override func hashCode(): Int64 {
        id.hashCode()
    }

    // Maps the native version string to TlsVersion; only TLS 1.2 and 1.3 are recognized.
    private func getVersion(): TlsVersion {
        let state = connected ?? SocketClosed.throwAlreadyClosed()
        let versionString = state.withStream<CString> {stream => unsafe { CJ_TLS_GetVersion(stream) }}
        if (versionString.isNull()) {
            throw TlsException("Unknown TLS version.")
        }
        // note: we don't own this CString so no need to invoke free()
        match (versionString.toString()) {
            case "TLSv1.2" => TlsVersion.V1_2
            case "TLSv1.3" => TlsVersion.V1_3
            case _ => throw TlsException("Unknown TLS version.")
        }
    }

    // Reads the SNI host name from the native stream, or None if there is none.
    private func getHostName(): ?String {
        let state = connected ?? SocketClosed.throwAlreadyClosed()
        state.withStream<?String> {
            stream =>
            let s = unsafe { CJ_TLS_GetHostName(stream) }
            if (s.isNull()) {
                None
            } else {
                s.toString()
            }
            // note: we don't own this CString so no need to invoke free()
            // note: we convert it inside of withStream otherwise the native placement may disappear
            // and it may segfault here
        }
    }

    // Atomically moves SocketReady -> SocketInHandshake and returns the new state with its config.
    // Any other state is rejected with a TlsException describing why.
    private func tryStartHandshake(): (SocketInHandshake, HandshakeConfig) {
        while (true) {
            match (state.load()) {
                case s: SocketReady =>
                    let started = SocketInHandshake(s.socket)
                    if (state.compareAndSwap(s, started)) {
                        return (started, s.config)
                    }
                case _: SocketInHandshake => throw TlsException("TLS handshake is already in progress.")
                case _: SocketConnected => throw TlsException("TLS socket handshake is already complete.")
                case _ => throw TlsException("TLS socket is already closed.")
            }
        }
        throw TlsException("Shouldn't reach here.")
    }

    // Underlying socket of any non-closed state, or None once closed.
    private static func tryGetSocket(state: TlsSocketState): ?StreamingSocket {
        match (state) {
            case s: SocketReady => s.socket
            case s: SocketInHandshake => s.socket
            case s: SocketConnected => s.socket
            case _ => None
        }
    }

    /*
     * Runs the client handshake on `socket` and returns the connected state.
     * The socket's read/write timeouts are set to `timeout` during the handshake and restored
     * afterwards (unless the socket got closed).
     *
     * @throws TlsException while context is null, or get the tls socket stream failed,
     * or tls socket client verify mode is set failed, or the CA file is set failed, or the CA file is empty,
     * or the certificate chain file is set failed, or the private key file is set failed,
     * or the certificate chain does not match the private key,
     * or the minVersion and maxVersion are not set together, or proto versions is set failed,
     * @throws TlsException while cipher suites of TLS1.2 is set failed,
     * or the TLS versions does not contain 1.2 if cipherSuitesV1_2 is not empty.
     * @throws TlsException while cipher suites of TLS1.3 is set failed,
     * or the TLS versions does not contain 1.3 if cipherSuitesV1_3 is not empty.
     * @throws TlsException if it was not possible to select the ALPN protocol.
     * @throws SocketException while socket is closed or already connected,
     * or address's size is not match ipv4 or ipv6.
     * or some system errors are happened.
     * @throws SocketTimeoutException while socket connect timeout.
     */
    private func connect(
        socket: StreamingSocket,
        timeout: ?Duration,
        cfg: TlsClientConfig,
        session: ?TlsClientSession
    ): SocketConnected {
        let timeoutsBefore = (socket.readTimeout, socket.writeTimeout)
        socket.writeTimeout = timeout
        socket.readTimeout = timeout
        let context = TlsContext(enableKeylog: cfg.keylogCallback.isSome())
        try {
            context.configureClient(cfg, session)
            let stream = context.createStream(socket)
            let certificateVerifyCallback: ?CertificateVerifyCallbackFunction = match (cfg.verifyMode) {
                case CustomVerify(callback) => callback
                case _ => None
            }
            let bridge = stream.createBridge(this, sessionStore: None, keylogCallback: cfg.keylogCallback,
                certificateVerifyCallback: certificateVerifyCallback)
            try {
                Bridge.register(bridge)
                if (let Some(session) <- session) {
                    stream.setSession(session)
                }
                if (let Some(hostname) <- cfg.serverName) {
                    if (!hostname.isEmpty()) {
                        stream.setRequestedHostName(hostname)
                        match (cfg.verifyMode) {
                            case TrustAll => ()
                            case _ => stream.setHostNameForVerify(hostname)
                        }
                    }
                }
                stream.handshake()
                if (negotiatedSession.isNone()) {
                    // In TLS 1.3 sessions are negotiated after the handshake
                    // and could be significantly delayed
                    // or even not sent at all,
                    // so here is a workaround for this - this is not exactly correct
                    // but we keep it for now before we get the proper TLS 1.3
                    // implementation for sessions and session tickets
                    if (let Some(session) <- session) {
                        negotiatedSession = session
                    }
                }
            } catch (e: Exception) {
                Bridge.remove(bridge)
                stream.close()
                throw e
            }
            let myCertificate = cfg.clientCertificate?[0]
            SocketConnected(stream, socket, myCertificate, true, bridge)
        } finally {
            context.close()
            // Restore on the socket we were given, like handleAccepted does. Going through
            // this.readTimeout would resolve the socket via `state`, which throws if the TLS socket
            // was closed concurrently and would mask the real outcome of the handshake.
            if (!socket.isClosed()) {
                socket.readTimeout = timeoutsBefore[0]
                socket.writeTimeout = timeoutsBefore[1]
            }
        }
    }

    // Runs the server handshake for a regular (key-holding) configuration and returns the
    // connected state. Timeouts are handled as in connect(). When a session store is present,
    // peer certificates are looked up / recorded via getAndStorePeerCertChain().
    private func handleAccepted(
        socket: StreamingSocket,
        timeout: Duration,
        cfg: TlsServerConfig,
        sessionContext: ?TlsServerSession
    ): SocketConnected {
        unsafe {
            let timeoutsBefore = (socket.readTimeout, socket.writeTimeout)
            socket.writeTimeout = timeout
            socket.readTimeout = timeout
            let context = TlsContext(server: true, enableKeylog: cfg.keylogCallback.isSome())
            try {
                context.configureServer(cfg, sessionContext)
                let stream = context.createStream(socket)
                let certificateVerifyCallback: ?CertificateVerifyCallbackFunction = match (cfg.verifyMode) {
                    case CustomVerify(callback) => callback
                    case _ => None
                }
                let bridge = stream.createBridge(
                    this,
                    sessionStore: sessionContext?.store,
                    keylogCallback: cfg.keylogCallback,
                    certificateVerifyCallback: certificateVerifyCallback
                )
                try {
                    Bridge.register(bridge)
                    stream.handshake()
                    // The server certificate is not supposed to be null
                    let myCertificate = cfg.serverCertificate[0]
                    let socketConnected = SocketConnected(stream, socket, myCertificate, false, bridge)
                    let store = bridge.sessionStore ?? return socketConnected
                    let certsOp = getAndStorePeerCertChain(store, stream)
                    socketConnected.setPeerCerts(certsOp)
                    return socketConnected
                } catch (e: Exception) {
                    Bridge.remove(bridge)
                    stream.close()
                    throw e
                }
            } finally {
                context.close()
                if (!socket.isClosed()) {
                    socket.readTimeout = timeoutsBefore[0]
                    socket.writeTimeout = timeoutsBefore[1]
                }
            }
        }
    }

    // Keyless counterpart of the handleAccepted overload above: identical flow, but the context
    // is configured from a KeylessTlsServerConfig.
    private func handleAccepted(socket: StreamingSocket, timeout: Duration,
        cfg: KeylessTlsServerConfig, sessionContext: ?TlsServerSession): SocketConnected {
        unsafe {
            let timeoutsBefore = (socket.readTimeout, socket.writeTimeout)
            socket.writeTimeout = timeout
            socket.readTimeout = timeout
            let context = TlsContext(server: true, enableKeylog: cfg.keylogCallback.isSome())
            try {
                context.configureServer(cfg, sessionContext)
                let stream = context.createStream(socket)
                let certificateVerifyCallback: ?CertificateVerifyCallbackFunction = match (cfg.verifyMode) {
                    case CustomVerify(callback) => callback
                    case _ => None
                }
                let bridge = stream.createBridge(
                    this,
                    sessionStore: sessionContext?.store,
                    keylogCallback: cfg.keylogCallback,
                    certificateVerifyCallback: certificateVerifyCallback
                )
                try {
                    Bridge.register(bridge)
                    stream.handshake()
                    // The server certificate is not supposed to be null
                    let myCertificate = cfg.serverCertificate[0]
                    let socketConnected = SocketConnected(stream, socket, myCertificate, false, bridge)
                    let store = bridge.sessionStore ?? return socketConnected
                    let certsOp = getAndStorePeerCertChain(store, stream)
                    socketConnected.setPeerCerts(certsOp)
                    return socketConnected
                } catch (e: Exception) {
                    Bridge.remove(bridge)
                    stream.close()
                    throw e
                }
            } finally {
                context.close()
                if (!socket.isClosed()) {
                    socket.readTimeout = timeoutsBefore[0]
                    socket.writeTimeout = timeoutsBefore[1]
                }
            }
        }
    }

    // Resolves the peer certificate chain for a server connection and records it in `store`
    // under the newly created session (if any). Returns None when no chain is known.
    private func getAndStorePeerCertChain(store: SessionStore, stream: TlsRawSocket): ?Array<X509Certificate> {
        var certsOp: ?Array<X509Certificate> = None
        // If the session is resumed, then we have stored the peer certs (given that the original connection
        // required a cert chain from the client during establishment).
        // So we can get the peer certs from the session store.
        if (let Some(id) <- resumedSessionId) {
            if (let Some(holder) <- store.tryGetHolder(SessionKey(id))) {
                certsOp = holder.getCerts()
            }
        }
        // If the connection is established with a full handshake, a new session is constructed,
        // and we store the peer certs, which are received during the handshake, in the session store,
        // with the session id as its key.
        // If the connection is established with a resumed session, and a new session is constructed at the
        // same time (happens in TLS 1.3), we bind the peer certs, which were bound to the resumed session,
        // to the new session.
        if (let Some(id) <- newSessionId) {
            if (let Some(holder) <- store.tryGetHolder(SessionKey(id))) {
                certsOp = certsOp ?? stream.getPeerCertificate() // cjlint-ignore !G.EXP.03
                if (let Some(certs) <- certsOp) {
                    holder.setCerts(certs)
                }
            }
        }
        return certsOp
    }
}
/**
 * TlsHandshakeResult view backed by a live TlsSocket.
 *
 * Nothing is cached: every property reads the wrapped socket when accessed, so values follow
 * the socket's current state and any exception raised by the socket's accessor is propagated.
 */
class TlsSocketHandshakeResult <: TlsHandshakeResult {
    let _socket: TlsSocket

    init(socket: TlsSocket) {
        this._socket = socket
    }

    /** Negotiated protocol version, as reported by the socket. */
    public prop version: TlsVersion {
        get() {
            return this._socket.tlsVersion
        }
    }

    /** Textual form of the negotiated cipher suite. */
    public prop cipherSuite: String {
        get() {
            let suite = this._socket.cipherSuite
            return suite.toString()
        }
    }

    /** Peer chain viewed as generic certificates; empty when the peer presented none. */
    public prop peerCertificate: Array<Certificate> {
        get() {
            let converted: ?Array<Certificate> = this._socket.peerCertificate?.map({cert => (cert as Certificate).getOrThrow()})
            return converted ?? []
        }
    }

    /** Client session captured by the socket, if any. */
    public prop session: ?TlsSession {
        get() {
            // `?? None` presumably re-types ?TlsClientSession as ?TlsSession; kept as-is.
            return this._socket.session ?? None
        }
    }

    /** Selected ALPN protocol, or "" when none was negotiated. */
    public prop alpnProtocol: String {
        get() {
            match (this._socket.alpnProtocolName) {
                case Some(name) => name
                case None => ""
            }
        }
    }

    /** SNI host name, or "" when none is known. */
    public prop serverName: String {
        get() {
            match (this._socket.domain) {
                case Some(host) => host
                case None => ""
            }
        }
    }
}