/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
 * This source file is part of the Cangjie project, licensed under Apache-2.0
 * with Runtime Library Exception.
 *
 * See https://cangjie-lang.cn/pages/LICENSE for license information.
 */

package stdx.net.tls

import stdx.net.tls.common.*

// A string holding a single NUL character; checkString rejects any config value containing it.
const NULL_BYTE: String = "\0"

// Enables SNI handling on the given native server context by delegating to CJ_TLS_ServerEnableSNI.
// NOTE(review): any value returned by the native call is not inspected here — confirm that is intended.
func enableSNI(context: CPointer<Ctx>): Unit {
    unsafe { CJ_TLS_ServerEnableSNI(context) }
}

extend TlsVersion {
    // Translates this version into the native CJTLS_PROTO_VERSION_* constant.
    // Only V1_2 and V1_3 have dedicated constants; every other version maps to
    // CJTLS_PROTO_VERSION_BUTT.
    func toNumericConstant(): Int32 {
        let nativeConstant: Int32 = match (this) {
            case TlsVersion.V1_2 => CJTLS_PROTO_VERSION_1_2
            case TlsVersion.V1_3 => CJTLS_PROTO_VERSION_1_3
            case _ => CJTLS_PROTO_VERSION_BUTT
        }
        return nativeConstant
    }
}

/*
 * Restricts the protocol versions of the native context `ctx` to the range spanned by the
 * entries of `supportedVersions` that have native constants (V1_2 and/or V1_3):
 *   only V1_2 -> [V1_2, V1_2], only V1_3 -> [V1_3, V1_3], both -> [V1_2, V1_3].
 * Does nothing when `supportedVersions` is empty or contains neither V1_2 nor V1_3.
 *
 * @throws TlsException if CJ_TLS_SetProtoVersions returns a non-positive value.
 */
func setProtoVersions(ctx: CPointer<Ctx>, supportedVersions: Array<TlsVersion>): Unit {
    if (supportedVersions.isEmpty()) {
        return
    }
    let (minVersion, maxVersion) = match ((supportedVersions.contains(V1_2), supportedVersions.contains(V1_3))) {
        // No version with a native constant requested: leave the context unchanged.
        case (false, false) => return
        case (true, false) => (V1_2, V1_2)
        case (false, true) => (V1_3, V1_3)
        case (true, true) => (V1_2, V1_3)
    }
    // The result is only read, never reassigned.
    let ret = unsafe { CJ_TLS_SetProtoVersions(ctx, minVersion.toNumericConstant(), maxVersion.toNumericConstant()) }
    if (ret <= 0) {
        throw TlsException("Failed to set proto versions: ${minVersion} ~ ${maxVersion}.")
    }
}

/*
 * Validates that a TLS config string value contains no NUL character.
 * `sourceProperty` names the property being checked and is used only in the error message.
 *
 * @throws IllegalArgumentException if `str` contains a NUL character.
 */
func checkString(str: String, sourceProperty: String): Unit {
    let hasNul = str.contains(NULL_BYTE)
    if (!hasNul) {
        return
    }
    throw IllegalArgumentException("The TLS config property (${sourceProperty}) cannot contain null character!")
}

// try-with-resources having proper return type:
// runs `block` with `resource`, returns whatever `block` returns, and closes `resource`
// afterwards whether `block` returns normally or throws.
// Like the built-in try-with-resources expression, close() is only invoked when the
// resource does not already report itself as closed. The block may have closed the
// resource (or handed it off) itself, and close() is not guaranteed to be idempotent
// for every Resource implementation.
func tryWith<R, Ret>(resource: R, block: (R) -> Ret): Ret where R <: Resource {
    try {
        block(resource)
    } finally {
        if (!resource.isClosed()) {
            resource.close()
        }
    }
}

// Copies `count` bytes from the native buffer `source` into the start of `destination`.
// memcpy_s is given destination.size as the capacity, so a `count` larger than the
// destination is expected to be rejected with a non-zero status instead of overflowing.
// Throws TlsException when memcpy_s reports a non-zero status.
func copy(source!: CPointer<Byte>, destination!: Array<Byte>, count!: UIntNative) {
    let status = unsafe {
        let raw = acquireArrayRawData(destination)
        let code = memcpy_s(raw.pointer, UIntNative(destination.size), source, count)
        releaseArrayRawData(raw)
        code
    }
    if (status != 0) {
        throw TlsException("Failed to copy ${count} bytes.")
    }
}

extend<T> CPointer<T> {
    // Calls `block` with this pointer only if it is non-null, returning Some(result).
    // For a null pointer `block` is not invoked and None is returned.
    func ifNotNull<R>(block: (CPointer<T>) -> R): ?R {
        if (!isNull()) {
            return Some(block(this))
        }
        return None
    }
}

// Copies `size` bytes starting at `ptr` into a freshly allocated Cangjie array.
// A zero size always yields an empty array, even when `ptr` is null.
// Throws NoneValueException when `ptr` is null and `size` is non-zero.
unsafe func toArray(ptr: CPointer<Byte>, size: UIntNative): Array<Byte> {
    if (size == 0) {
        return Array<Byte>(0, repeat: 0)
    }
    if (ptr.isNull()) {
        throw NoneValueException("Native pointer is NULL.")
    }
    let bytes = Array<Byte>(Int64(size), repeat: 0)
    copy(source: ptr, destination: bytes, count: size)
    return bytes
}

// Allocates native storage for `count` elements of T wrapped in a closable resource.
// When `initial` is Some, every element is initialised to that value.
func malloc<T>(count!: Int64 = 1, initial!: ?T = None): NativePointerResource<T> where T <: CType {
    let allocation = NativePointerResource<T>(count, initial)
    return allocation
}

// Optional-array overload: copies `source` into native memory when present,
// otherwise returns an empty resource holding a null pointer.
func mallocCopyOf<T>(source: ?Array<T>): NativePointerResource<T> where T <: CType {
    if (let Some(elements) <- source) {
        return mallocCopyOf(elements)
    }
    return NativePointerResource<T>()
}

// Allocates native storage sized to `source` and copies every element into it in order.
// The allocation size is source.size; NativePointerResource rejects a non-positive count,
// so an empty array is reported there as an IllegalArgumentException.
func mallocCopyOf<T>(source: Array<T>): NativePointerResource<T> where T <: CType {
    let copyResource = NativePointerResource<T>(source.size, None)
    let target = copyResource.pointer
    var offset = 0
    for (element in source) {
        unsafe { target.write(offset, element) }
        offset++
    }
    return copyResource
}

// Owns a block of native memory obtained from LibC.malloc and frees it on close().
// The resource counts as closed whenever its pointer is null, which is the case after
// close() and for instances created with the no-argument constructor.
class NativePointerResource<T> <: Resource where T <: CType {
    private var pointer_: CPointer<T>

    // The raw native pointer; null once the resource is closed.
    prop pointer: CPointer<T> {
        get() {
            pointer_
        }
    }

    // Reads or writes the first element of the allocation.
    // Throws IllegalStateException instead of dereferencing a null pointer when the
    // resource is closed or was never allocated.
    mut prop value: T {
        get() {
            checkOpen()
            unsafe { pointer_.read() }
        }
        set(newValue) {
            checkOpen()
            unsafe { pointer_.write(newValue) }
        }
    }

    // Creates an empty resource that holds a null pointer (already closed).
    init() {
        pointer_ = CPointer()
    }

    // Allocates `count` elements; when `sample` is Some, every element is set to it.
    // Throws IllegalArgumentException for a non-positive count and TlsException when
    // LibC.malloc returns a null pointer.
    init(count: Int64, sample: ?T) {
        unsafe {
            if (count <= 0) {
                throw IllegalArgumentException("Count should be positive: ${count}.")
            }
            pointer_ = LibC.malloc(count: count)
            if (pointer_.isNull()) {
                throw TlsException("Failed to allocate memory.")
            }
            if (let Some(s) <- sample) {
                for (index in 0..count) {
                    pointer_.write(index, s)
                }
            }
        }
    }

    // Guards element access: reading or writing through a null pointer would crash natively.
    private func checkOpen(): Unit {
        if (pointer_.isNull()) {
            throw IllegalStateException("Native pointer resource is closed or was never allocated.")
        }
    }

    public override func isClosed(): Bool {
        pointer_.isNull()
    }

    // Frees the native memory and resets the pointer to null; a no-op when already closed.
    public override func close(): Unit {
        if (pointer_.isNull()) {
            return
        }
        unsafe { LibC.free(pointer_) }
        pointer_ = CPointer()
    }
}

// Growable collector that accumulates items one at a time and produces a compact Array<T>.
// The first size_ slots of `items` hold Some(item); any remaining slots are None.
class ArrayBuilder<T> {
    private var items: Array<?T>
    private var size_ = 0

    // Creates a builder with room for `capacity` items before the first growth.
    init(capacity: Int64) {
        items = Array<?T>(capacity, repeat: None)
    }

    // Creates a builder with an initial capacity of 10.
    init() {
        this(10)
    }

    // Appends `item`, enlarging the backing storage first when every slot is in use.
    func append(item: T): Unit {
        if (items.size == size_) {
            grow()
        }
        items[size_] = Some(item)
        size_ += 1
    }

    // Returns a new array with the appended items in insertion order.
    func toArray(): Array<T> {
        Array<T>(size_) { i => items[i].getOrThrow() }
    }

    // True when nothing has been appended yet.
    func isEmpty(): Bool {
        return size_ == 0
    }

    // Replaces the storage with one of roughly 1.5x the size (plus one, so a zero
    // capacity still grows) and carries over every existing slot.
    private func grow(): Unit {
        let enlarged = Array<?T>(items.size * 3 / 2 + 1, repeat: None)
        for (i in 0..items.size) {
            enlarged[i] = items[i]
        }
        items = enlarged
    }
}

// Verifies that a dynamically loaded OpenSSL symbol was found.
// When it was not, throws TlsException naming the missing function, read from the
// C string in dynMsg.funcName.
func checkDynMsg(dynMsg: DynMsg): Unit {
    if (dynMsg.found) {
        return
    }
    let missingFunction = CString(dynMsg.funcName).toString()
    throw TlsException("Can not load openssl library or function ${missingFunction}.")
}