/*
* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* This source file is part of the Cangjie project, licensed under Apache-2.0
* with Runtime Library Exception.
*
* See https://cangjie-lang.cn/pages/LICENSE for license information.
*/
package stdx.net.tls
import stdx.net.tls.common.*
// A one-character string holding the NUL (U+0000) character; used to reject config strings containing it.
const NULL_BYTE: String = "\0"
/*
 * Calls the native CJ_TLS_ServerEnableSNI on the given context.
 * Any value returned by the native call is discarded.
 * NOTE(review): if CJ_TLS_ServerEnableSNI reports failure through its return value,
 * that result is currently ignored here — confirm against the native declaration.
 */
func enableSNI(context: CPointer<Ctx>): Unit {
    unsafe { CJ_TLS_ServerEnableSNI(context) }
}
extend TlsVersion {
    /*
     * Maps this TLS version to its native CJTLS_PROTO_VERSION_* constant.
     * V1_2 and V1_3 have dedicated constants; every other version maps to
     * CJTLS_PROTO_VERSION_BUTT (presumably a sentinel/unsupported value — confirm).
     */
    func toNumericConstant(): Int32 {
        return match (this) {
            case V1_2 => CJTLS_PROTO_VERSION_1_2
            case V1_3 => CJTLS_PROTO_VERSION_1_3
            case _ => CJTLS_PROTO_VERSION_BUTT
        }
    }
}
/*
 * Configures the minimum and maximum TLS protocol versions on `ctx` from `supportedVersions`.
 *
 * Only V1_2 and V1_3 are considered; any other entries are ignored. The configured range
 * spans from the lowest to the highest of those two that are present. If neither is
 * present (including when the array is empty), nothing is configured and the function
 * returns without calling into native code.
 *
 * @throws TlsException if the native CJ_TLS_SetProtoVersions call reports failure
 *         (a return value <= 0).
 */
func setProtoVersions(ctx: CPointer<Ctx>, supportedVersions: Array<TlsVersion>): Unit {
    let hasV12 = supportedVersions.contains(V1_2)
    let hasV13 = supportedVersions.contains(V1_3)
    // An empty array falls into the (false, false) case, so no separate emptiness check is needed.
    let (minVersion, maxVersion) = match ((hasV12, hasV13)) {
        case (false, false) => return
        case (true, false) => (V1_2, V1_2)
        case (false, true) => (V1_3, V1_3)
        case (true, true) => (V1_2, V1_3)
    }
    let ret = unsafe { CJ_TLS_SetProtoVersions(ctx, minVersion.toNumericConstant(), maxVersion.toNumericConstant()) }
    if (ret <= 0) {
        throw TlsException("Failed to set proto versions: ${minVersion} ~ ${maxVersion}.")
    }
}
/*
 * Validates that a TLS configuration string contains no NUL character, since such
 * strings are later handed to native code.
 *
 * @param str the value to validate.
 * @param sourceProperty the name of the config property, used only in the error message.
 * @throws IllegalArgumentException if `str` contains the NUL character.
 */
func checkString(str: String, sourceProperty: String): Unit {
    if (!str.contains(NULL_BYTE)) {
        return
    }
    throw IllegalArgumentException("The TLS config property (${sourceProperty}) cannot contain null character!")
}
/*
 * Runs `block` with `resource` and always closes the resource afterwards, even when
 * `block` throws. Unlike a `try (...) { }` try-with-resources expression, this helper
 * returns the block's result.
 *
 * Note: close() is called unconditionally (isClosed() is not consulted), and an
 * exception thrown by close() replaces one thrown by `block`.
 */
func tryWith<R, Ret>(resource: R, block: (R) -> Ret): Ret where R <: Resource {
    try {
        return block(resource)
    } finally {
        resource.close()
    }
}
/*
 * Copies `count` bytes from the native buffer `source` into the beginning of `destination`.
 *
 * The destination array is pinned with acquireArrayRawData for the duration of the copy and
 * released before any error is raised. `destination.size` is passed to memcpy_s as the
 * destination limit, so a `count` larger than the array is expected to be rejected.
 *
 * Copying zero bytes is a no-op: neither pointer is touched. Without this, memcpy_s would see
 * a destination limit of 0 for an empty array, which it may treat as an error.
 *
 * @throws TlsException if memcpy_s reports a non-zero error code.
 */
func copy(source!: CPointer<Byte>, destination!: Array<Byte>, count!: UIntNative): Unit {
    if (count == 0) {
        return
    }
    unsafe {
        let acq = acquireArrayRawData(destination)
        let result = memcpy_s(acq.pointer, UIntNative(destination.size), source, count)
        // Release the pinned array before possibly throwing so it is never left acquired.
        releaseArrayRawData(acq)
        if (result != 0) {
            throw TlsException("Failed to copy ${count} bytes.")
        }
    }
}
extend<T> CPointer<T> {
    /*
     * Applies `block` to this pointer when it is non-null.
     * Returns Some(result of block) for a non-null pointer, or None for a null one.
     */
    func ifNotNull<R>(block: (CPointer<T>) -> R): ?R {
        if (!isNull()) {
            return Some(block(this))
        }
        return None
    }
}
/*
 * Copies `size` bytes starting at `ptr` into a freshly allocated Array<Byte>.
 * A size of 0 yields an empty array and never dereferences `ptr`, which may then be null.
 *
 * @throws NoneValueException if `ptr` is null while `size` is non-zero.
 * @throws TlsException if the underlying copy fails (propagated from copy).
 */
unsafe func toArray(ptr: CPointer<Byte>, size: UIntNative): Array<Byte> {
    if (size != 0 && ptr.isNull()) {
        throw NoneValueException("Native pointer is NULL.")
    }
    let bytes = Array<Byte>(Int64(size), repeat: 0)
    if (size == 0) {
        return bytes
    }
    copy(source: ptr, destination: bytes, count: size)
    bytes
}
/*
 * Allocates native memory for `count` elements of T, optionally filling every element
 * with `initial`. The returned resource owns the memory and frees it on close().
 * Exceptions are those of the NativePointerResource(count, sample) constructor.
 */
func malloc<T>(count!: Int64 = 1, initial!: ?T = None): NativePointerResource<T> where T <: CType {
    return NativePointerResource<T>(count, initial)
}
/*
 * Optional-array overload of mallocCopyOf.
 * Some(array) is copied into native memory by the Array overload; None yields an
 * empty resource that owns no memory (its pointer is null).
 */
func mallocCopyOf<T>(source: ?Array<T>): NativePointerResource<T> where T <: CType {
    if (let Some(elements) <- source) {
        return mallocCopyOf(elements)
    }
    NativePointerResource<T>()
}
/*
 * Allocates native memory sized to `source` and copies each element into it, in order.
 *
 * Note: an empty `source` is rejected by the NativePointerResource constructor, which
 * throws IllegalArgumentException because the element count must be positive.
 */
func mallocCopyOf<T>(source: Array<T>): NativePointerResource<T> where T <: CType {
    let copied = NativePointerResource<T>(source.size, None)
    let target = copied.pointer
    var offset = 0
    for (element in source) {
        unsafe { target.write(offset, element) }
        offset++
    }
    copied
}
/*
 * Owns a block of native memory obtained from LibC.malloc and frees it on close().
 * A resource created with the no-argument constructor owns nothing (null pointer).
 * close() may be called repeatedly: after the first call the pointer is null.
 */
class NativePointerResource<T> <: Resource where T <: CType {
    // Owned native pointer; null when nothing is allocated or after close().
    private var pointer_: CPointer<T>

    // Read-only view of the owned native pointer (null when empty or closed).
    prop pointer: CPointer<T> {
        get() {
            pointer_
        }
    }

    // Reads or writes the element at index 0. No null check is performed, so the caller must
    // ensure the resource holds at least one element and has not been closed.
    mut prop value: T {
        get() {
            unsafe { pointer_.read() }
        }
        set(v) {
            unsafe { pointer_.write(v) }
        }
    }

    // Creates an empty resource that owns no memory.
    init() {
        pointer_ = CPointer()
    }

    /*
     * Allocates space for `count` elements of T via LibC.malloc. When `sample` is Some,
     * every element is initialised to it; otherwise the memory is left as malloc returned it.
     *
     * @throws IllegalArgumentException if count <= 0.
     * @throws TlsException if the allocation yields a null pointer.
     */
    init(count: Int64, sample: ?T) {
        if (count <= 0) {
            throw IllegalArgumentException("Count should be positive: ${count}.")
        }
        unsafe {
            pointer_ = LibC.malloc(count: count)
            if (pointer_.isNull()) {
                throw TlsException("Failed to allocate memory.")
            }
            match (sample) {
                case Some(fill) => for (i in 0..count) {
                    pointer_.write(i, fill)
                }
                case None => ()
            }
        }
    }

    // The resource counts as closed once it no longer holds a pointer.
    public override func isClosed(): Bool {
        pointer_.isNull()
    }

    // Frees the owned memory (if any) and resets the pointer to null.
    public override func close(): Unit {
        let owned = pointer_
        pointer_ = CPointer()
        unsafe { LibC.free(owned) }
    }
}
/*
 * Minimal growable buffer used to accumulate elements before producing a fixed Array<T>.
 * Storage grows to roughly 1.5x its size (plus one) whenever it is full.
 */
class ArrayBuilder<T> {
    // Backing storage; slots at index >= size_ hold None.
    private var items: Array<?T>
    // Number of elements appended so far.
    private var size_ = 0

    // Creates a builder whose storage initially holds `capacity` slots.
    init(capacity: Int64) {
        items = Array<?T>(capacity, repeat: None)
    }

    // Creates a builder with an initial capacity of 10.
    init() {
        this(10)
    }

    // Appends `item`, growing the backing storage first when it is full.
    func append(item: T): Unit {
        if (items.size == size_) {
            grow()
        }
        items[size_] = Some(item)
        size_ += 1
    }

    // Returns a new array containing the appended elements, in insertion order.
    func toArray(): Array<T> {
        let result = Array<T>(size_) {
            i => items[i].getOrThrow()
        }
        return result
    }

    // True when nothing has been appended yet.
    func isEmpty(): Bool {
        return size_ == 0
    }

    // Replaces the storage with a larger copy (old size * 3 / 2 + 1 slots).
    private func grow(): Unit {
        let newCapacity = items.size * 3 / 2 + 1
        let enlarged = Array<?T>(newCapacity, repeat: None)
        items.copyTo(enlarged, 0, 0, items.size)
        items = enlarged
    }
}
/*
 * Verifies that a dynamically loaded OpenSSL symbol was found.
 *
 * @throws TlsException naming the missing function (read from dynMsg.funcName)
 *         when dynMsg.found is false.
 */
func checkDynMsg(dynMsg: DynMsg): Unit {
    if (dynMsg.found) {
        return
    }
    let missingFunc = CString(dynMsg.funcName).toString()
    throw TlsException("Can not load openssl library or function ${missingFunc}.")
}