/*
* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* This source file is part of the Cangjie project, licensed under Apache-2.0
* with Runtime Library Exception.
*
* See https://cangjie-lang.cn/pages/LICENSE for license information.
*/
package stdx.net.tls
import std.io.IOStream
import std.sync.*
import stdx.net.tls.common.TlsException
/**
 * TLS engine that runs a native SSL object on top of an arbitrary IOStream.
 *
 * Encrypted bytes read from the underlying socket are accumulated in readBuffer and fed to the native code.
 * Encrypted bytes produced by the native code are accumulated in writeBuffer and are written to the
 * underlying socket by flushSilent(). User data buffers are handed to the native read/write calls directly.
 *
 * Locks must be acquired in their declaration order (readLock, writeLock, fillLock, flushLock, sslLock).
 */
class TlsRawSocket <: IOStream & Resource & ToString {
    private static const BUFFER_SIZE = 16384
    private static let DEFAULT_CLOSE_TIMEOUT = Duration.second
    // the native read/write functions take the user data size as Int32,
    // so at most this many bytes are passed to them in one call
    private static const MAX_NATIVE_DATA_SIZE = 0x7FFF_FFFF

    // should be only accessed under sslLock but its freeSpace buffer content can be accessed under fillLock
    // reading from freeSpace should be also done under sslLock as free space boundaries computation
    // should be done in sync with native code running under sslLock
    private let readBuffer = InputBuffer(size: BUFFER_SIZE)
    // should be only accessed under sslLock but its data buffer content can be accessed under flushLock
    // reading from data property should be also done under sslLock as data space boundaries computation
    // should be done in sync with native code running under sslLock
    private let writeBuffer = OutputBuffer(size: BUFFER_SIZE)

    // touch only under sslLock {{{
    private let ssl: CPointer<Ssl>
    private let context: CPointer<Ctx>
    // set once by closeImpl(); after that all native resources below are released
    private var disposed = false
    // set by tryStartClose()/closeImpl(); blocks new writes/handshakes but still allows reads and flushes
    private var shutdownStarted = false
    // number of outstanding requests for more raw input; fill() reads from the socket while this is non-zero
    private var pendingRead = 0
    private let exceptionData: CPointer<ExceptionData>
    private let bytesProcessed: CPointer<UIntNative> // data bytes read/written
    private let bytesConsumed: CPointer<UIntNative> // raw input bytes consumed
    private let bytesProduced: CPointer<UIntNative> // raw output bytes produced
    // }}} touch only under sslLock

    // the following locks are defined in the hierarchy order so lock in the declaration order
    // otherwise deadlock may occur
    // do not reorder the following lock declarations
    private let readLock = Mutex()
    private let writeLock = Mutex()
    private let fillLock = Mutex()
    private let flushLock = Mutex()
    private let sslLock = Mutex()

    // this should be accessed/used WITHOUT sslLock but possibly with other locks, e.g. flushLock
    private let socket: IOStream
    // selects the handshake failure message and is forwarded to Bridge
    private let server: Bool

    /**
     * Creates the raw TLS socket on top of an already created native SSL object.
     *
     * @param ssl native SSL object, released by close()
     * @param context native SSL context (not released by this class)
     * @param socket the underlying transport stream
     * @param server whether this side acts as the TLS server
     *
     * @throws IllegalArgumentException if ssl or context is null
     * @throws IllegalMemoryException if the native counters cannot be allocated
     */
    init(
        ssl: CPointer<Ssl>,
        context: CPointer<Ctx>,
        socket: IOStream,
        server!: Bool
    ) {
        if (ssl.isNull() || context.isNull()) {
            throw IllegalArgumentException("SSL pointer shouldn't be null.")
        }
        let exData = ExceptionData.create()
        unsafe {
            // a single allocation holds all three counters, it is released through bytesProcessed in closeImpl()
            let counters = LibC.malloc<UIntNative>(count: 3)
            if (counters.isNull()) {
                // close() is never reachable for a half-constructed object, so release what we already own
                ExceptionData.free(exData)
                throw IllegalMemoryException("Malloc memory failed.")
            }
            this.bytesProcessed = counters
            this.bytesConsumed = counters + 1
            this.bytesProduced = counters + 2
        }
        this.exceptionData = exData
        this.ssl = ssl
        this.context = context
        this.socket = socket
        this.server = server
    }

    /**
     * Creates a Bridge for the given TlsSocket using this socket's native SSL object, context and role.
     */
    func createBridge(
        tlsSocket: TlsSocket,
        sessionStore!: ?SessionStore,
        keylogCallback!: ?KeylogCallbackFunction,
        certificateVerifyCallback!: ?CertificateVerifyCallbackFunction
    ): Bridge {
        Bridge(tlsSocket, ssl, context, sessionStore, keylogCallback, certificateVerifyCallback, server)
    }

    /**
     * Read at least one byte to the buffer or fail if already closed. Returns 0 when EOF reached.
     *
     * Bypass exceptions from the underlying buffer.
     * Note that due to the TLS protocol nature, this may both read and write with the underlying socket.
     * In particular it means that TlsRawSocket.read() function may invoke socket.write() that is legal.
     *
     * This is concurrent safe but invoking close() during read() may cause exception thrown from read() in some cases
     * Because the TLS stream is a composed stream based on another socket, the particular kind of exception could
     * be different depending on where exactly close error has been detected: it may be either
     * SocketException or TlsException.
     *
     * @throws TlsException if closed (especially concurrently) or the TLS stream is corrupted
     * @throws SocketException or other from the underlying socket, including concurrent close
     */
    public override func read(buffer: Array<Byte>): Int64 {
        synchronized(readLock) {
            var bytesRead = 0
            // non-positive results other than EOF (CJTLS_NEED_XXX) mean "retry after I/O"
            while (bytesRead == 0) {
                let result = tryRead(buffer)
                if (result == CJTLS_EOF) {
                    return 0
                }
                if (result > 0) {
                    bytesRead += Int64(result)
                }
            }
            return bytesRead
        }
    }

    /**
     * Write all the bytes from buffer or fail if already closed or shutdown.
     *
     * Bypass exceptions from the underlying buffer.
     * Note that due to the TLS protocol nature, this may both read and write with the underlying socket.
     * In particular it means that TlsRawSocket.write() function may invoke socket.read() that is legal.
     *
     * This is concurrent safe but invoking close() during write() may cause exception thrown from write()
     * if observed in the middle of write operation.
     *
     * @throws TlsException if closed/shutdown or the TLS stream is corrupted
     * @throws SocketException or other from the underlying socket
     */
    public override func write(buffer: Array<Byte>): Unit {
        synchronized(writeLock) {
            var written = 0
            var current = buffer
            while (written < buffer.size) {
                let bytesWritten = tryWrite(current)
                // on CJTLS_NEED_XXX the same slice is retried after the I/O done inside tryWrite()
                if (bytesWritten > 0) {
                    written += Int64(bytesWritten)
                    current = buffer[written..]
                }
            }
        }
    }

    /**
     * Does negotiation or renegotiation.
     *
     * This is concurrent safe but invoking close() during handshake() may cause exception thrown from handshake().
     *
     * @throws TlsException if already closed or shutdown
     * @throws TlsException if handshake fails
     * @throws SocketException or other types from the underlying socket
     */
    func handshake(): Unit {
        synchronized(readLock) {
            synchronized(writeLock) {
                try {
                    while (!tryHandshake()) {
                        flush()
                        fill()
                    }
                } catch (e: Exception) {
                    // try to deliver whatever was produced (e.g. an alert) before reporting the failure
                    try {
                        flush()
                    } catch (_) { /*Nothing to do, just catch Exception*/ }
                    throw e
                }
                // we may still have remaining bytes from handshake that
                // are expected by the peer so it's important to flush them
                flush()
            }
        }
    }

    /**
     * This is useful to implement property accessors that don't require I/O.
     * This is concurrent-safe and can be invoked with any other functions.
     *
     * @param block invoked under sslLock with the native SSL object and exception data
     * @return the value returned by block
     * @throws TlsException if already closed or shutdown
     */
    func otherNonIO<R>(block: (CPointer<Ssl>, CPointer<ExceptionData>) -> R): R {
        synchronized(sslLock) {
            if (disposed || shutdownStarted) {
                throwClosedException()
            }
            block(ssl, exceptionData)
        }
    }

    /**
     * Flush pending outgoing bytes to the underlying socket.
     *
     * This is concurrent-safe and can be invoked with any other functions but close() that may abort
     * writing to the underlying socket.
     *
     * @throws TlsException when already closed
     * @throws Exception or SocketException from the underlying socket
     */
    public override func flush(): Unit {
        synchronized(sslLock) {
            if (disposed) {
                throwClosedException()
            }
            if (shutdownStarted) {
                return
            }
        }
        // if we are lucky enough it's ok to proceed to do flush() during close()
        // unless the underlying socket is already closed due to close-timeout.
        flushSilent()
    }

    /**
     * Checks whether it has been closed explicitly by close()
     */
    public override func isClosed(): Bool {
        synchronized(sslLock) {
            return disposed
        }
    }

    /**
     * Terminates TLS connection, trying to shutdown it properly if possible.
     * If invoked concurrently with read(), write(), handshake(), shutdown() or flush() - these operations
     * may be aborted with exception
     *
     * This is reentrant and concurrent-safe.
     *
     * @throws SocketException or other exceptions from the underlying socket
     */
    public override func close(): Unit {
        if (!tryStartClose()) {
            return
        }
        // this is usually safe to invoke close() twice but generally we have to avoid it
        let resourceClosed = AtomicBool(false)
        func tryCloseSocket() {
            if (resourceClosed.compareAndSwap(false, true)) {
                // this does close the socket and all concurrently running operations on it should/may fail
                closeUnderlyingSocket()
            }
        }
        // if we are not writing anything right now then let's terminate TLS properly
        // otherwise we are trying to abort operation
        if (writeLock.tryLock()) {
            try {
                shutdownWithTimeout(timeout: DEFAULT_CLOSE_TIMEOUT, onTimeout: tryCloseSocket)
            } finally {
                writeLock.unlock()
            }
        }
        // a second graceful shutdown attempt is made only if no read is in progress
        if (readLock.tryLock()) {
            try {
                shutdownWithTimeout(timeout: DEFAULT_CLOSE_TIMEOUT, onTimeout: tryCloseSocket)
            } finally {
                readLock.unlock()
            }
        }
        closeImpl()
        tryCloseSocket()
    }

    /**
     * Delegates to the underlying socket's toString() when available.
     */
    public override func toString(): String {
        match (socket) {
            case s: ToString => s.toString()
            case _ => "socket"
        }
    }

    /**
     * One read attempt: runs the native read under sslLock, then flushes produced bytes and,
     * if more raw input is needed, fills from the socket.
     *
     * @return bytes read | CJTLS_EOF | CJTLS_NEED_XXX
     * @throws TlsException if closed or the native read failed
     */
    private func tryRead(buffer: Array<Byte>): Int32 {
        let result: Int32
        var exception: ?TlsException = None
        synchronized(sslLock) {
            if (disposed) {
                throwClosedException()
            }
            result = tryReadImpl(buffer)
            if (result == CJTLS_FAIL) {
                exception = unsafe { exceptionData.read().getException(fallback: "TLS failed to read data.") }
            }
            if (result == CJTLS_NEED_READ) {
                pendingRead++
            }
        }
        if (let Some(e) <- exception) {
            try {
                flushSilent()
            } catch (_) { /*Nothing to do with this exception*/ }
            throw e
        }
        flushSilent() // we should be able to read() after shutdown() invoked
        if (result == CJTLS_NEED_READ) {
            fill()
        }
        return result
    }

    // should be invoked under sslLock
    // returns: bytesRead | CJTLS_FAIL | CJTLS_EOF | CJTLS_NEED_XXX
    private func tryReadImpl(dataBuffer: Array<Byte>): Int32 {
        let eof: Int32 = if (readBuffer.eof) {
            1
        } else {
            0
        }
        // limit the slice so that Int32(dataHandle.array.size) below can never overflow
        // result <- CJTLS_OK | CJTLS_FAIL | CJTLS_EOF | CJTLS_NEED_XXX
        let result = tryIO<Int32>(limitToNativeSize(dataBuffer)) {
            dataHandle, inputHandle, outputHandle =>
            // do not allocate here, including boxing and string literals usages
            // do not throw exceptions here
            unsafe {
                CJ_TLS_SslRead(ssl, dataHandle.pointer, Int32(dataHandle.array.size), inputHandle.pointer,
                    UIntNative(inputHandle.array.size), eof, outputHandle.pointer, UIntNative(outputHandle.array.size),
                    bytesProcessed, bytesConsumed, bytesProduced, exceptionData)
            }
        }
        if (result == CJTLS_OK) {
            return unsafe { Int32(bytesProcessed.read()) }
        }
        return result
    }

    /**
     * One write attempt: runs the native write under sslLock, then flushes produced bytes and,
     * if more raw input is needed, fills from the socket.
     *
     * @return bytes written | CJTLS_EOF | CJTLS_NEED_XXX
     * @throws TlsException if closed/shutdown or the native write failed
     */
    private func tryWrite(data: Array<Byte>): Int32 {
        let result: Int32
        var exception: ?TlsException = None
        synchronized(sslLock) {
            if (disposed || shutdownStarted) {
                throwClosedException()
            }
            result = tryWriteImpl(data)
            if (result == CJTLS_FAIL) {
                exception = unsafe { exceptionData.read().getException(fallback: "TLS failed to write data.") }
            }
            if (result == CJTLS_NEED_READ) {
                pendingRead++
            }
        }
        if (let Some(e) <- exception) {
            try {
                flush()
            } catch (_) { /*Nothing to do with this exception*/ }
            throw e
        }
        flush()
        if (result == CJTLS_NEED_READ) {
            fill()
        }
        return result
    }

    // should be invoked under sslLock
    // returns: bytesWritten | CJTLS_FAIL | CJTLS_EOF | CJTLS_NEED_XXX
    private func tryWriteImpl(dataBuffer: Array<Byte>): Int32 {
        let eof: Int32 = if (readBuffer.eof) {
            1
        } else {
            0
        }
        // limit the slice so that Int32(dataHandle.array.size) below can never overflow;
        // write() retries with the same slice start, so the limited view is stable across retries
        // result <- CJTLS_OK | CJTLS_FAIL | CJTLS_EOF | CJTLS_NEED_XXX
        let result = tryIO<Int32>(limitToNativeSize(dataBuffer)) {
            dataHandle, inputHandle, outputHandle =>
            // do not allocate here, including boxing and string literals usages
            // do not throw exceptions here
            unsafe {
                CJ_TLS_SslWrite(ssl, dataHandle.pointer, Int32(dataHandle.array.size), inputHandle.pointer,
                    UIntNative(inputHandle.array.size), eof, outputHandle.pointer, UIntNative(outputHandle.array.size),
                    bytesProcessed, bytesConsumed, bytesProduced, exceptionData)
            }
        }
        if (result == CJTLS_OK) {
            return unsafe { Int32(bytesProcessed.read()) }
        }
        return result
    }

    /**
     * One handshake step under sslLock.
     *
     * @return true when the native handshake reports CJTLS_OK, false when more I/O is required
     * @throws TlsException if closed/shutdown or the handshake failed
     */
    private func tryHandshake(): Bool {
        synchronized(sslLock) {
            if (disposed || shutdownStarted) {
                throwClosedException()
            }
            let result = tryHandshakeImpl()
            match {
                case result == CJTLS_FAIL =>
                    let fallbackMessage = if (server) {
                        "TLS handshake failed (server)"
                    } else {
                        "TLS handshake failed (client)"
                    }
                    unsafe { exceptionData.read() }.throwException(fallback: fallbackMessage)
                case result == CJTLS_OK => return true
                case result == CJTLS_NEED_READ => pendingRead++
                case _ => ()
            }
            return false
        }
    }

    // should be invoked under sslLock
    private func tryHandshakeImpl(): Int32 {
        let eof: Int32 = if (readBuffer.eof) {
            1
        } else {
            0
        }
        tryIO<Int32>(Array<Byte>()) {
            _, inputHandle, outputHandle =>
            // do not allocate here, including boxing and string literals usages
            // do not throw exceptions here
            unsafe {
                CJ_TLS_SslHandshake(ssl, inputHandle.pointer, UIntNative(inputHandle.array.size), eof,
                    outputHandle.pointer, UIntNative(outputHandle.array.size), bytesConsumed, bytesProduced,
                    exceptionData)
            }
        }
    }

    /**
     * Reads raw bytes from the underlying socket into readBuffer while pendingRead requests exist.
     * A non-positive socket read marks readBuffer as EOF and stops.
     */
    private func fill(): Unit {
        synchronized(fillLock) {
            while (let Some(freeSpace) <- unsafe { ifReadNeeded() }) {
                let result = socketRead(freeSpace)
                commitRead(result)
                if (result <= 0) {
                    break
                }
            }
        }
    }

    // reads from the underlying socket; on failure re-arms pendingRead (if still open) and rethrows
    private func socketRead(buffer: Array<Byte>): Int64 {
        try {
            socket.read(buffer)
        } catch (e: Exception) {
            socketReadFailed(e)
        }
    }

    private func socketReadFailed(e: Exception): Nothing {
        synchronized(sslLock) {
            if (!disposed && !shutdownStarted) {
                // this is potentially recoverable error
                pendingRead++
            }
        }
        throw e
    }

    /**
     * Try to claim reading job
     *
     * @return read buffer of the corresponding size or None if no free space or no pending read requests
     * @throws TlsException if already closed
     */
    private unsafe func ifReadNeeded(): ?Array<Byte> {
        synchronized(sslLock) {
            if (disposed) {
                throwClosedException()
            }
            if (pendingRead == 0) {
                return None
            }
            pendingRead = 0
            // these compact() and grow() are safe here because we are under sslLock AND fillLock
            // so nobody can look at bytes except us and we can move and copy
            if (readBuffer.mayCompact) {
                readBuffer.compact()
            }
            if (!readBuffer.hasFreeSpace) {
                readBuffer.grow()
            }
            // only array range computation is under sslLock and this is intentional
            return readBuffer.freeSpace
        }
    }

    // publishes the outcome of a socket read into readBuffer under sslLock
    private func commitRead(result: Int64) {
        synchronized(sslLock) {
            if (result <= 0) {
                readBuffer.markEof()
            } else {
                readBuffer.commit(bytesRead: result)
            }
        }
    }

    /**
     * Unlike regular flush() it doesn't fail if this is already closed
     * this is why it's "silent"
     * It however still may fail if the underlying socket is closed
     * This ignores shutdownStarted as well, as it is used as part of the shutdown sequence
     */
    private func flushSilent(): Unit {
        synchronized(flushLock) {
            unsafe {
                var batches = 0
                while (let Some(batch) <- getOutgoingBatch()) {
                    socketWrite(batch)
                    commitWritten(batch.size)
                    batches++
                }
                // only flush the underlying socket if something was actually written to it
                if (batches > 0) {
                    socket.flush()
                }
            }
        }
    }

    // writes to the underlying socket; any failure tears the whole connection down and rethrows
    private func socketWrite(batch: Array<Byte>): Unit {
        try {
            socket.write(batch)
        } catch (e: Exception) {
            // even when it's invoked from close() it is still safe
            closeImpl()
            closeUnderlyingSocket()
            throw e
        }
    }

    /**
     * Steal bytes from the native outgoing buffer
     * It does only steal up to the BUFFER_SIZE bytes, if there are more bytes in the native buffer then
     * it will only copy BUFFER_SIZE bytes and the remaining bytes will be compacted
     * (will remain in the native buffer)
     *
     * @return batch or None if there are no pending outgoing bytes
     */
    private func getOutgoingBatch(): ?Array<Byte> {
        synchronized(sslLock) {
            if (disposed) { // we don't care if shutdownStarted
                // here we return None instead of error
                // this makes flushSilent actually silent
                return None
            }
            if (writeBuffer.isEmpty) {
                return None
            }
            // this is safe because we are under sslLock AND flushLock
            // so nobody is looking at bytes, we can move them with no risk
            if (writeBuffer.mayCompact) {
                writeBuffer.compact()
            }
            // here only array range computation is under the lock
            // and this is intentional
            return writeBuffer.data
        }
    }

    // marks size bytes of writeBuffer as delivered to the socket
    private func commitWritten(size: Int64) {
        synchronized(sslLock) {
            writeBuffer.consumed(size)
        }
    }

    // drives the native shutdown until it reports completion (or the object is disposed), doing I/O in between
    private func shutdownImpl(): Unit {
        unsafe {
            while (!tryShutdown()) {
                flushSilent()
                fill()
            }
            flushSilent()
        }
    }

    /**
     * One shutdown step under sslLock.
     *
     * @return true when done (or already disposed), false when more I/O is required
     * @throws TlsException if the native shutdown failed
     */
    private unsafe func tryShutdown(): Bool {
        synchronized(sslLock) {
            if (disposed) {
                return true
            }
            let result = tryShutdownImpl()
            match {
                case result == CJTLS_FAIL => exceptionData.read().throwException(fallback: "TLS failed to do shutdown.")
                case result == CJTLS_NEED_WRITE => false
                case result == CJTLS_NEED_READ =>
                    pendingRead++
                    false
                case result == CJTLS_OK => true
                case _ => true // other states mean we don't care
            }
        }
    }

    // should be invoked under sslLock
    private func tryShutdownImpl(): Int32 {
        let eof: Int32 = if (readBuffer.eof) {
            1
        } else {
            0
        }
        tryIO<Int32>(Array<Byte>()) {
            _, inputHandle, outputHandle =>
            // do not allocate here, including boxing and string literals usages
            // do not throw exceptions here
            unsafe {
                CJ_TLS_SslShutdown(ssl, inputHandle.pointer, UIntNative(inputHandle.array.size), eof,
                    outputHandle.pointer, UIntNative(outputHandle.array.size), bytesConsumed, bytesProduced,
                    exceptionData)
            }
        }
    }

    // Runs one native step: exposes readBuffer.data as raw input and writeBuffer.freeSpace as raw output,
    // then advances both buffers by the counts the native code stored in bytesConsumed/bytesProduced.
    // this should be invoked under sslLock
    // block should never throw exceptions
    // block should never allocate, box or use string literals
    private func tryIO<R>(
        dataBuffer: Array<Byte>,
        block: (CPointerHandle<Byte>, CPointerHandle<Byte>, CPointerHandle<Byte>) -> R
    ): R where R <: CType {
        // raw - encrypted
        // data - user data
        if (!writeBuffer.hasFreeSpace) {
            // this is safe here because we are under sslLock
            // and the only place we are accessing it without this lock is "socket.write()" in flush()
            // but it is safe because socket.write() will keep using the old copy and
            // therefore we don't care as we do copy, not compacting
            writeBuffer.grow()
        }
        let rawInput = readBuffer.data
        let rawOutput = writeBuffer.freeSpace
        unsafe {
            bytesConsumed.write(0)
            bytesProduced.write(0)
            let dataHandle = acquireArrayRawData(dataBuffer)
            let inputHandle = acquireArrayRawData(rawInput)
            let outputHandle = acquireArrayRawData(rawOutput)
            let result = try {
                block(dataHandle, inputHandle, outputHandle)
            } finally {
                releaseArrayRawData(outputHandle)
                releaseArrayRawData(inputHandle)
                releaseArrayRawData(dataHandle)
            }
            readBuffer.consumed(Int64(bytesConsumed.read()))
            writeBuffer.commit(bytesWritten: Int64(bytesProduced.read()))
            return result
        }
    }

    /**
     * Returns a view of buffer that is at most MAX_NATIVE_DATA_SIZE bytes long, so its size fits
     * the Int32 size parameter of the native read/write functions. The slice refers to the same storage.
     */
    private static func limitToNativeSize(buffer: Array<Byte>): Array<Byte> {
        if (buffer.size > MAX_NATIVE_DATA_SIZE) {
            buffer[0..MAX_NATIVE_DATA_SIZE]
        } else {
            buffer
        }
    }

    // marks the start of close(); returns false if the object is already disposed
    private func tryStartClose(): Bool {
        synchronized(sslLock) {
            if (disposed) {
                return false
            }
            shutdownStarted = true
            return true
        }
    }

    private func shutdownWithTimeout(timeout!: Duration, onTimeout!: () -> Unit) {
        // we are trying to shutdown it gracefully by doing TLS shutdown
        // this is very similar to socket linger
        // if the subsequent shutdown() is unable to complete in time (unable to send close notify frame),
        // we do close the underlying socket
        let timer = Timer.after(timeout) {
            =>
            onTimeout()
            None
        }
        try {
            shutdownImpl()
        } catch (_) {
            // ignore as we don't care during close()
        } finally {
            timer.cancel()
        }
    }

    // releases all native resources exactly once (guarded by disposed under sslLock)
    private func closeImpl(): Unit {
        synchronized(sslLock) {
            if (disposed) {
                return
            }
            disposed = true
            pendingRead = 0
            shutdownStarted = true
            unsafe {
                ExceptionData.free(exceptionData)
                CJ_TLS_FreeSsl(ssl)
                LibC.free(bytesProcessed) // bytesConsumed and bytesProduced are also released here
            }
        }
    }

    // closes the underlying stream if it is a Resource; other streams are left untouched
    private func closeUnderlyingSocket() {
        if (let Some(resource) <- (socket as Resource)) {
            resource.close()
        }
    }

    private static func throwClosedException(): Nothing {
        throw TlsException("closed")
    }
}
// Native TLS entry points. Each takes an extra DynMsg out-parameter; the CJ_TLS_* wrappers below
// allocate it and pass it to checkDynMsg after the call.
foreign {
    // One handshake step over the given raw input/output buffers; byte counts are reported through
    // rawBytesConsumed/rawBytesProduced.
    func CJ_TLS_DYN_SslHandshake(ssl: CPointer<Ssl>, rawInput: CPointer<Byte>, rawInputSize: UIntNative,
        rawInputLast: Int32, rawOutput: CPointer<Byte>, rawOutputSize: UIntNative,
        rawBytesConsumed: CPointer<UIntNative>, rawBytesProduced: CPointer<UIntNative>,
        exception: CPointer<ExceptionData>, dynMsg: CPointer<DynMsg>): Int32
    // Decrypts into dataBuffer; the number of user bytes is reported through dataBytesRead.
    func CJ_TLS_DYN_SslRead(ssl: CPointer<Ssl>, dataBuffer: CPointer<Byte>, dataBufferSize: Int32,
        rawInput: CPointer<Byte>, rawInputSize: UIntNative, rawInputLast: Int32, rawOutput: CPointer<Byte>,
        rawOutputSize: UIntNative, dataBytesRead: CPointer<UIntNative>, rawBytesConsumed: CPointer<UIntNative>,
        rawBytesProduced: CPointer<UIntNative>, exception: CPointer<ExceptionData>, dynMsg: CPointer<DynMsg>): Int32
    // Encrypts from dataBuffer; the number of user bytes is reported through dataBytesWritten.
    func CJ_TLS_DYN_SslWrite(ssl: CPointer<Ssl>, dataBuffer: CPointer<Byte>, dataBufferSize: Int32,
        rawInput: CPointer<Byte>, rawInputSize: UIntNative, rawInputLast: Int32, rawOutput: CPointer<Byte>,
        rawOutputSize: UIntNative, dataBytesWritten: CPointer<UIntNative>, rawBytesConsumed: CPointer<UIntNative>,
        rawBytesProduced: CPointer<UIntNative>, exception: CPointer<ExceptionData>, dynMsg: CPointer<DynMsg>): Int32
    // One shutdown step over the given raw input/output buffers.
    func CJ_TLS_DYN_SslShutdown(ssl: CPointer<Ssl>, rawInput: CPointer<Byte>, rawInputSize: UIntNative,
        rawInputLast: Int32, rawOutput: CPointer<Byte>, rawOutputSize: UIntNative,
        rawBytesConsumed: CPointer<UIntNative>, rawBytesProduced: CPointer<UIntNative>,
        exception: CPointer<ExceptionData>, dynMsg: CPointer<DynMsg>): Int32
}
/**
 * Calls CJ_TLS_DYN_SslHandshake with a fresh DynMsg passed by inout, hands the DynMsg to checkDynMsg,
 * and returns the native status code unchanged.
 */
func CJ_TLS_SslHandshake(ssl: CPointer<Ssl>, rawInput: CPointer<Byte>, rawInputSize: UIntNative, rawInputLast: Int32,
    rawOutput: CPointer<Byte>, rawOutputSize: UIntNative, rawBytesConsumed: CPointer<UIntNative>,
    rawBytesProduced: CPointer<UIntNative>, exception: CPointer<ExceptionData>): Int32 {
    unsafe {
        var msg = DynMsg()
        let status = CJ_TLS_DYN_SslHandshake(
            ssl,
            rawInput, rawInputSize, rawInputLast,
            rawOutput, rawOutputSize,
            rawBytesConsumed, rawBytesProduced,
            exception,
            inout msg
        )
        checkDynMsg(msg)
        status
    }
}
/**
 * Calls CJ_TLS_DYN_SslRead with a fresh DynMsg passed by inout, hands the DynMsg to checkDynMsg,
 * and returns the native status code unchanged.
 */
func CJ_TLS_SslRead(ssl: CPointer<Ssl>, dataBuffer: CPointer<Byte>, dataBufferSize: Int32, rawInput: CPointer<Byte>,
    rawInputSize: UIntNative, rawInputLast: Int32, rawOutput: CPointer<Byte>, rawOutputSize: UIntNative,
    dataBytesRead: CPointer<UIntNative>, rawBytesConsumed: CPointer<UIntNative>, rawBytesProduced: CPointer<UIntNative>,
    exception: CPointer<ExceptionData>): Int32 {
    unsafe {
        var msg = DynMsg()
        let status = CJ_TLS_DYN_SslRead(
            ssl,
            dataBuffer, dataBufferSize,
            rawInput, rawInputSize, rawInputLast,
            rawOutput, rawOutputSize,
            dataBytesRead, rawBytesConsumed, rawBytesProduced,
            exception,
            inout msg
        )
        checkDynMsg(msg)
        status
    }
}
/**
 * Calls CJ_TLS_DYN_SslWrite with a fresh DynMsg passed by inout, hands the DynMsg to checkDynMsg,
 * and returns the native status code unchanged.
 */
func CJ_TLS_SslWrite(ssl: CPointer<Ssl>, dataBuffer: CPointer<Byte>, dataBufferSize: Int32, rawInput: CPointer<Byte>,
    rawInputSize: UIntNative, rawInputLast: Int32, rawOutput: CPointer<Byte>, rawOutputSize: UIntNative,
    dataBytesWritten: CPointer<UIntNative>, rawBytesConsumed: CPointer<UIntNative>,
    rawBytesProduced: CPointer<UIntNative>, exception: CPointer<ExceptionData>): Int32 {
    unsafe {
        var msg = DynMsg()
        let status = CJ_TLS_DYN_SslWrite(
            ssl,
            dataBuffer, dataBufferSize,
            rawInput, rawInputSize, rawInputLast,
            rawOutput, rawOutputSize,
            dataBytesWritten, rawBytesConsumed, rawBytesProduced,
            exception,
            inout msg
        )
        checkDynMsg(msg)
        status
    }
}
/**
 * Calls CJ_TLS_DYN_SslShutdown with a fresh DynMsg passed by inout, hands the DynMsg to checkDynMsg,
 * and returns the native status code unchanged.
 */
func CJ_TLS_SslShutdown(ssl: CPointer<Ssl>, rawInput: CPointer<Byte>, rawInputSize: UIntNative, rawInputLast: Int32,
    rawOutput: CPointer<Byte>, rawOutputSize: UIntNative, rawBytesConsumed: CPointer<UIntNative>,
    rawBytesProduced: CPointer<UIntNative>, exception: CPointer<ExceptionData>): Int32 {
    unsafe {
        var msg = DynMsg()
        let status = CJ_TLS_DYN_SslShutdown(
            ssl,
            rawInput, rawInputSize, rawInputLast,
            rawOutput, rawOutputSize,
            rawBytesConsumed, rawBytesProduced,
            exception,
            inout msg
        )
        checkDynMsg(msg)
        status
    }
}