/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
 * This source file is part of the Cangjie project, licensed under Apache-2.0
 * with Runtime Library Exception.
 *
 * See https://cangjie-lang.cn/pages/LICENSE for license information.
 */

package stdx.net.tls

import std.sync.*
import std.time.DateTime
import std.time.MonoTime
import std.collection.concurrent.{ConcurrentHashMap, ArrayBlockingQueue}
import stdx.crypto.x509.*
import stdx.encoding.hex.toHexString
import stdx.net.tls.common.*

/**
 * A client-side TLS session negotiated by a TlsSocket. The type is opaque: its internals are
 * implementation-specific. Users never construct instances themselves; instead they borrow one
 * from a TlsSocket that has successfully completed handshake().
 *
 * Two sessions are equal when their native session IDs are equal.
 */
public class TlsClientSession <: TlsSession & Equatable<TlsClientSession> & ToString & Hashable {
    private TlsClientSession(
        let holder: SessionHolder,
        let id: Array<Byte>
    ) {}

    /**
     * Wraps a native session pointer. Not meant to be called by users:
     * obtain instances through the TlsSocket.session property instead.
     */
    init(pointer: CPointer<NativeSession>) {
        // the ID must be read first, before the holder is created
        let sessionId = unsafe { getSessionId(pointer) }
        id = sessionId

        holder = SessionHolder(pointer)
    }

    public override operator func ==(other: TlsClientSession): Bool {
        this.id == other.id
    }

    public override operator func !=(other: TlsClientSession): Bool {
        !(this == other)
    }

    public override func toString(): String {
        let hex = toHexString(id)
        "TlsClientSession(${hex})"
    }

    public override func hashCode(): Int64 {
        return SessionKey.hashOf(this.id)
    }

    /**
     * Copies the native session ID into a Cangjie array.
     * Returns an empty array when the native layer reports no ID (NULL data or zero length).
     */
    private static unsafe func getSessionId(session: CPointer<NativeSession>): Array<Byte> {
        var bytes: Array<Byte> = []
        try (dataOut = malloc<CPointer<Byte>>(initial: CPointer()), sizeOut = malloc<UIntNative>(initial: 0)) {
            CJ_TLS_GetSessionId(session, dataOut.pointer, sizeOut.pointer)
            let data = dataOut.value
            if (!data.isNull() && sizeOut.value > 0) {
                bytes = toArray(data, sizeOut.value)
            }
        }
        return bytes
    }
}

extend TlsRawSocket {
    /**
     * Attaches a previously negotiated client session to this socket's native SSL object.
     * The work is run through otherNonIO, whose first lambda argument is the native SSL handle.
     * Throws TlsException (from SessionHolder.setSessoinAt) if the native call fails.
     */
    func setSession(session: TlsClientSession): Unit {
        let sessionHolder = session.holder
        otherNonIO<Unit> {
            nativeSsl, _ => sessionHolder.setSessoinAt(nativeSsl)
        }
    }
}

/**
 * Owns one use-count reference to a native session (SSL_SESSION*). It releases that reference
 * once this holder has been finalized and no call that borrowed the pointer is still in flight.
 *
 * The raw pointer is only used inside withNativeSession (and the helpers built on it). That
 * function brackets every use with the secondary RefCounter, so ~init cannot free the session
 * while a borrower is still using it.
 */
class SessionHolder {
    private let pointer: CPointer<NativeSession>
    private static let counter = AtomicInt64(1)
    private let refCount = RefCounter() // secondary refcount for ~init
    private var certsOp: ?Array<X509Certificate> = None

    // unique, increasing identifier taken from a shared static counter
    let id = counter.fetchAdd(1)

    /**
     * Takes an additional use-count reference on `pointer`; ~init releases it again.
     * Throws IllegalArgumentException if `pointer` is NULL.
     */
    init(pointer: CPointer<NativeSession>) {
        if (pointer.isNull()) {
            throw IllegalArgumentException("Session pointer shouldn't be NULL.")
        }

        // the increment should be before assigning to the field
        // but after checking the pointer for NULL
        unsafe { CJ_TLS_IncrementUse(pointer) }

        this.pointer = pointer
    }

    ~init() {
        // This may be invoked right in the middle of withNativeSession (e.g. setSessoinAt),
        // so it needs special care: sync primitives like mutexes can't be used in finalizers,
        // only atomics, which RefCounter uses for manual refcounting.
        if (refCount.shutdown()) {
            // no concurrent withNativeSession users were detected, so release our reference now;
            // otherwise the last user's leave() in withNativeSession releases it
            free(pointer)
        }
    }

    /** Stores the given certificates alongside this session. */
    func setCerts(certs: Array<X509Certificate>): Unit {
        certsOp = certs
    }

    /** Returns the certificates stored by setCerts, or None if none were stored. */
    func getCerts(): ?Array<X509Certificate> {
        certsOp
    }

    /**
     * Attaches this session to the native SSL object `stream`.
     * Throws TlsException if the native call reports a failure.
     * (The misspelled name is kept because existing callers use it.)
     */
    func setSessoinAt(stream: CPointer<Ssl>): Unit {
        withNativeSession<Unit> {
            pointer => unsafe {
                // put session to the SSL* object
                // and increment its internal ref-counter (NOT our refCount)
                if (CJ_TLS_SetSession(stream, pointer) != 0) {
                    throw TlsException("Failed to set session to context.")
                }
            }
        }
    }

    /**
     * Runs `block` with the raw session pointer while holding the secondary refcount.
     * A concurrent ~init therefore cannot free the session until `block` returns or throws.
     * If the finalizer ran in the meantime, the last caller to leave frees the session here.
     */
    func withNativeSession<R>(block: (CPointer<NativeSession>) -> R): R {
        unsafe {
            // we increment refcount BEFORE reading "pointer" from "this"
            refCount.enter()
            // here we have "pointer" field access AFTER the barrier (inside of enter())
            // so ~init can't be triggered before enter() invocation
            let pointer = pointer
            // here we touched "this" last time: from this point ~init may happen at any time
            // however the secondary refCount will prevent it from destroying the session
            // so if ~init happens while we are here, freeing the session is postponed
            // until we do leave() in the finally block

            try {
                block(pointer)

            // we would use "blackhole" intrinsic to avoid GC of "this"
            // with concurrent ~init invocation while we are running this function
            // so we could avoid all this terrible atomic stuff and replace everything with just
            // a single line fix like this:
            // @blackhole(this) // this should be done AFTER using session pointer
            } finally {
                if (refCount.leave()) {
                    // the finalizer has been already triggered
                    // and we are the last concurrent user
                    // so we are doing free here
                    free(pointer)

                    // note that this is unlikely to cause session disposal (see the comment in free)
                }
            }
        }
    }

    /**
     * Returns the SSL_SESSION* pointer with its use-count incremented; the caller owns that extra reference.
     *
     * This goes through withNativeSession so the secondary refcount is released even when
     * CJ_TLS_IncrementUse throws. A bare enter()/leave() pair would leave the counter stuck,
     * and the finalizer could then never free the session.
     */
    func getIncremented(): CPointer<NativeSession> {
        withNativeSession<CPointer<NativeSession>> {
            session => unsafe {
                CJ_TLS_IncrementUse(session)
                session
            }
        }
    }

    private static func free(pointer: CPointer<NativeSession>) {
        unsafe {
            // remember that this does not necessarily mean actual session disposal
            // because there is a builtin ref-counter inside of the session object
            // so here we are just decrementing the counter that MAY lead to disposal
            CJ_TLS_DeleteSession(pointer)
        }
    }
}

/*
 * Sets the session ID context of the server SSL context `serverCtx` to `sessionId`.
 *
 * @throws TlsException if `sessionId` is longer than 32 bytes (measured by String.size),
 * or if the native call reports a failure.
 */
func setServerSessionId(serverCtx: CPointer<Ctx>, sessionId: String): Unit {
    // upper bound enforced here; presumably mirrors OpenSSL's SSL_MAX_SID_CTX_LENGTH — TODO confirm
    let maxSessionIdSize = 32
    let sessionIdSize = sessionId.size
    if (sessionIdSize > maxSessionIdSize) {
        throw TlsException("The length of the session ID cannot exceed 32 bytes.")
    }
    // the C string is never reassigned; it is released in the finally block below
    let cStr = unsafe { LibC.mallocCString(sessionId) }
    try {
        let ret = unsafe { CJ_TLS_SetSessionIdContext(serverCtx, cStr, UInt32(sessionIdSize)) }
        if (ret <= 0) {
            throw TlsException("Failed to set tls socket server session ID.")
        }
    } finally {
        unsafe { LibC.free(cStr) }
    }
}

// Opaque stand-in for the native session object (SSL_SESSION, per SessionHolder's comments).
// It is only handled through CPointer<NativeSession>; this file never reads its fields.
@C
struct NativeSession {}

// Native entry points of the TLS shim used for session handling.
// Every function takes a trailing DynMsg pointer; the Cangjie wrappers below
// pass a fresh DynMsg by `inout` and hand it to checkDynMsg after the call.
foreign {
    // registers the session-cache callbacks (new, remove, lookup, assign)
    func CJ_TLS_DYN_SetSessionCallback(
        put: CFunc<(CPointer<Ssl>, CPointer<Byte>, UIntNative, CPointer<NativeSession>) -> Unit>,
        remove: CFunc<(CPointer<Ctx>, CPointer<Byte>, UIntNative, CPointer<NativeSession>) -> Unit>,
        find: CFunc<(CPointer<Ssl>, CPointer<Byte>, UIntNative) -> CPointer<NativeSession>>,
        assign: CFunc<(CPointer<Ssl>, CPointer<NativeSession>) -> Unit>,
        dynMsgPtr: CPointer<DynMsg>
    ): Unit

    // sets the session ID context of a CTX; setServerSessionId treats a result <= 0 as failure
    func CJ_TLS_DYN_SetSessionIdContext(ctx: CPointer<Ctx>, sidCtx: CString, sidCtxLen: UInt32,
        dynMsgPtr: CPointer<DynMsg>): Int32

    // drops one use-count reference of a session (see SessionHolder.free)
    func CJ_TLS_DYN_DeleteSession(pointer: CPointer<NativeSession>, dynMsgPtr: CPointer<DynMsg>): Unit

    // attaches a session to an SSL object; SessionHolder.setSessoinAt treats a non-zero result as failure
    func CJ_TLS_DYN_SetSession(context: CPointer<Ssl>, session: CPointer<NativeSession>, dynMsgPtr: CPointer<DynMsg>): Int32

    // adds one use-count reference to a session
    func CJ_TLS_DYN_IncrementUse(session: CPointer<NativeSession>, dynMsgPtr: CPointer<DynMsg>): Unit

    // writes the session ID buffer and its length to the out-parameters
    func CJ_TLS_DYN_GetSessionId(
        session: CPointer<NativeSession>,
        data: CPointer<CPointer<Byte>>,
        length: CPointer<UIntNative>,
        dynMsgPtr: CPointer<DynMsg>
    ): Unit

    // adds a session to a CTX; NOTE(review): result semantics not visible here — confirm with the shim
    func CJ_TLS_DYN_AddSession(context: CPointer<Ctx>, session: CPointer<NativeSession>, dynMsgPtr: CPointer<DynMsg>): Int32
}

// Registers the session-cache callbacks via CJ_TLS_DYN_SetSessionCallback, then passes the
// DynMsg it filled in to checkDynMsg (which presumably throws on a reported error).
func CJ_TLS_SetSessionCallback(
    put: CFunc<(CPointer<Ssl>, CPointer<Byte>, UIntNative, CPointer<NativeSession>) -> Unit>,
    remove: CFunc<(CPointer<Ctx>, CPointer<Byte>, UIntNative, CPointer<NativeSession>) -> Unit>,
    find: CFunc<(CPointer<Ssl>, CPointer<Byte>, UIntNative) -> CPointer<NativeSession>>,
    assign: CFunc<(CPointer<Ssl>, CPointer<NativeSession>) -> Unit>
): Unit {
    unsafe {
        var msg = DynMsg()
        let outcome = CJ_TLS_DYN_SetSessionCallback(put, remove, find, assign, inout msg)
        checkDynMsg(msg)
        outcome
    }
}

// Calls CJ_TLS_DYN_SetSessionIdContext, checks the DynMsg it filled in, and returns the native result code.
func CJ_TLS_SetSessionIdContext(ctx: CPointer<Ctx>, sidCtx: CString, sidCtxLen: UInt32): Int32 {
    unsafe {
        var msg = DynMsg()
        let outcome = CJ_TLS_DYN_SetSessionIdContext(ctx, sidCtx, sidCtxLen, inout msg)
        checkDynMsg(msg)
        outcome
    }
}

// Calls CJ_TLS_DYN_DeleteSession for `pointer`, then checks the DynMsg it filled in.
func CJ_TLS_DeleteSession(pointer: CPointer<NativeSession>): Unit {
    unsafe {
        var msg = DynMsg()
        let outcome = CJ_TLS_DYN_DeleteSession(pointer, inout msg)
        checkDynMsg(msg)
        outcome
    }
}

// Calls CJ_TLS_DYN_SetSession, checks the DynMsg it filled in, and returns the native result code.
func CJ_TLS_SetSession(context: CPointer<Ssl>, session: CPointer<NativeSession>): Int32 {
    unsafe {
        var msg = DynMsg()
        let outcome = CJ_TLS_DYN_SetSession(context, session, inout msg)
        checkDynMsg(msg)
        outcome
    }
}

// Calls CJ_TLS_DYN_IncrementUse for `session`, then checks the DynMsg it filled in.
func CJ_TLS_IncrementUse(session: CPointer<NativeSession>): Unit {
    unsafe {
        var msg = DynMsg()
        let outcome = CJ_TLS_DYN_IncrementUse(session, inout msg)
        checkDynMsg(msg)
        outcome
    }
}

// Calls CJ_TLS_DYN_GetSessionId, which writes the ID buffer and its length into `data`/`length`,
// then checks the DynMsg it filled in.
func CJ_TLS_GetSessionId(
    session: CPointer<NativeSession>,
    data: CPointer<CPointer<Byte>>,
    length: CPointer<UIntNative>
): Unit {
    unsafe {
        var msg = DynMsg()
        let outcome = CJ_TLS_DYN_GetSessionId(session, data, length, inout msg)
        checkDynMsg(msg)
        outcome
    }
}

// Calls CJ_TLS_DYN_AddSession, checks the DynMsg it filled in, and returns the native result code.
func CJ_TLS_AddSession(context: CPointer<Ctx>, session: CPointer<NativeSession>): Int32 {
    unsafe {
        var msg = DynMsg()
        let outcome = CJ_TLS_DYN_AddSession(context, session, inout msg)
        checkDynMsg(msg)
        outcome
    }
}

/**
 * Lock-free counter of in-flight users combined with a "shut down" flag, kept in one AtomicInt64.
 * SessionHolder uses it to defer freeing a native session until the finalizer has run
 * and every concurrent user has left.
 *
 * State encoding:
 *   0           - idle, not shut down
 *   n > 0       - n active users, not shut down
 *   -n (n > 0)  - shutdown() was called while n users were still active
 *   Int64.Min   - terminated; the call that stored this value reports `true` and must free the resource
 */
class RefCounter {
    // sentinel for "terminated"; never a valid user count
    private static let Terminated = Int64.Min
    private let state = AtomicInt64(0) // 0 idle, positive - in use refcounter, negative shutdown+in-use

    /**
     * Registers one more user.
     * Throws IllegalStateException if the counter has already terminated,
     * or if the new count would collide with the Terminated sentinel.
     */
    func enter(): Unit {
        updateAndGet {
            before =>
            let after = match {
                case before == Terminated => throw IllegalStateException("The session has been already collected.")
                // shut down but still in use: grow the magnitude of the negative count
                case before < 0 => before - 1
                case _ => before + 1
            }
            // reaching the sentinel would be indistinguishable from termination
            if (after == Terminated) {
                throw IllegalStateException("Too many users.")
            }
            after
        }
    }

    /**
     * Unregisters one user.
     * Returns true only when this call moved the state from -1 (shut down, one user left) to
     * Terminated. The caller is then responsible for freeing the resource.
     * NOTE(review): must be paired with a preceding enter(); an unbalanced call at state 0
     * would move the state to -1.
     */
    func leave(): Bool {
        let after = updateAndGet {
            before => match {
                case before == -1 => Terminated
                case before < 0 => before + 1
                case _ => before - 1
            }
        }

        return after == Terminated
    }

    /**
     * Marks the counter as shut down.
     * Returns true when there were no active users (state 0). The state becomes Terminated and
     * the caller must free the resource. With n active users the state becomes -n, and the
     * last leave() reports termination instead. A repeated call (negative state) leaves the
     * state unchanged and returns false.
     */
    func shutdown(): Bool {
        let after = updateAndGet {
            before => match {
                case before == 0 => Terminated
                case before < 0 => None
                case _ => -before
            }
        }

        return after == Terminated
    }

    // this could be a part of stdlib (std/sync) but it can't as lambdas are too slow
    // here we don't care because it's not a hot-path
    //
    // CAS loop: reads the state, computes the next value with `map`, and retries until the
    // compare-and-swap succeeds; returns the stored value. If `map` returns None, the state is
    // left unchanged and None is returned. An exception thrown by `map` propagates before any
    // store, so the state is also unchanged in that case.
    private func updateAndGet(map: (Int64) -> ?Int64): ?Int64 {
        var after: Int64 = 0

        do {
            let before = state.load()
            after = map(before) ?? return None

            if (state.compareAndSwap(before, after)) {
                break
            }
        } while (true)

        return after
    }
}

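// The session-cache callbacks below are registered once for the whole process (in
// TlsServerSession's static initializer). Each invocation looks up the Bridge for the given
// SSL/CTX and uses that bridge's own session store, if it has one. Store capacity and timeout
// are currently hard-coded rather than configurable.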

/*
 * Native callback invoked with a newly established session.
 * On a bridge with a session store, it records the session ID on the socket and caches the session.
 * On a client bridge without a store, it exposes the session as the socket's negotiated session.
 */
@C
func CJ_TLS_put_session(
    ssl: CPointer<Ssl>,
    id: CPointer<Byte>,
    idLength: UIntNative,
    session: CPointer<NativeSession>
): Unit {
    // This is called from C code, so no Cangjie exception may escape it (the same rule
    // CJ_TLS_find_session follows). TlsClientSession(session) throws on a NULL pointer,
    // and copying/storing can fail too. Failing to cache or expose a session is not fatal.
    try {
        if (let Some(bridge) <- Bridge.findByStream(ssl)) {
            if (let Some(sessionStore) <- bridge.sessionStore) {
                bridge.socket.newSessionId = unsafe { toArray(id, idLength) }
                sessionStore.put(SessionKey.copyFrom(id, idLength), session)
            } else if (!bridge.server) {
                bridge.socket.negotiatedSession = TlsClientSession(session)
            }
        }
    } catch (_: Exception) {
        // deliberately swallowed: exceptions must not cross into C code
    }
}

/*
 * Native callback that assigns a session to an SSL object.
 * On client bridges it exposes the session as the socket's negotiated session.
 * TlsClientSession takes its own use-count reference (through SessionHolder).
 * NOTE(review): the previous comment here stated the SSL's session refcount is already
 * pre-incremented when this is called; this callback returns nothing, so any release of
 * that reference is up to the native side — confirm with the shim.
 */
@C
func CJ_TLS_assign_session(
    ssl: CPointer<Ssl>,
    session: CPointer<NativeSession>
): Unit {
    // called from C code: no Cangjie exception may escape (TlsClientSession throws on NULL)
    try {
        if (let Some(bridge) <- Bridge.findByStream(ssl)) {
            if (!bridge.server) {
                bridge.socket.negotiatedSession = TlsClientSession(session)
            }
        }
    } catch (_: Exception) {
        // deliberately swallowed: exceptions must not cross into C code
    }
}

/*
 * Native callback asking to drop the session with the given ID from the cache of the
 * bridge that owns `ctx`. The native session pointer argument is not used.
 */
@C
func CJ_TLS_remove_session(
    ctx: CPointer<Ctx>,
    id: CPointer<Byte>,
    idLength: UIntNative,
    _: CPointer<NativeSession>
): Unit {
    // called from C code: no Cangjie exception may escape (copying the ID may fail)
    try {
        if (let Some(bridge) <- Bridge.findByContext(ctx)) {
            if (let Some(sessionStore) <- bridge.sessionStore) {
                sessionStore.remove(SessionKey.copyFrom(id, idLength))
            }
        }
    } catch (_: Exception) {
        // deliberately swallowed: a stale cache entry is harmless and expires on its own
    }
}

/*
 * Native callback looking up a cached session by ID for the bridge owning `ssl`.
 * On a hit, it records the ID as the socket's resumed session ID and returns the session
 * with its use-count already incremented (see SessionStore.find). Otherwise it returns NULL.
 */
@C
func CJ_TLS_find_session(
    ssl: CPointer<Ssl>,
    id: CPointer<Byte>,
    idLength: UIntNative
): CPointer<NativeSession> {
    let notFound = CPointer<NativeSession>()
    try {
        if (let Some(bridge) <- Bridge.findByStream(ssl)) {
            if (let Some(store) <- bridge.sessionStore) {
                if (let Some(found) <- store.find(SessionKey.copyFrom(id, idLength))) {
                    bridge.socket.resumedSessionId = unsafe { toArray(id, idLength) }
                    return found
                }
            }
        }
        return notFound
    } catch (_: Exception) {
        // exceptions must not propagate from Cangjie code into C code
        return notFound
    }
}

// Hashable wrapper around a session ID. It exists only because Array<Byte> is not Hashable
// (and can't be made so), yet session IDs are used as ConcurrentHashMap keys.
struct SessionKey <: Hashable & Equatable<SessionKey> {
    SessionKey(let bytes: Array<Byte>) {}

    /** Copies `length` bytes starting at `ptr` into a new key. */
    static func copyFrom(ptr: CPointer<Byte>, length: UIntNative): SessionKey {
        let copied = unsafe { toArray(ptr, length) }
        SessionKey(copied)
    }

    /** Hashes a byte sequence by feeding every byte, in order, to a DefaultHasher. */
    static func hashOf(bytes: Array<Byte>): Int64 {
        var hasher = DefaultHasher()
        for (octet in bytes) {
            hasher.write(octet)
        }
        hasher.finish()
    }

    public override func hashCode(): Int64 {
        SessionKey.hashOf(this.bytes)
    }

    public operator func ==(other: SessionKey): Bool {
        this.bytes == other.bytes
    }

    public operator func !=(other: SessionKey): Bool {
        this.bytes != other.bytes
    }
}

/**
 * An entry in SessionStore's expiry queue: the key of a cached session plus its creation time.
 *
 * Ages are measured with MonoTime, a monotonic clock. The previous wall-clock
 * (DateTime.nowUTC) measurement could be skewed by system time changes, so entries
 * could expire too early or linger far too long.
 */
struct WaitQueueEntry {
    private let createdAt = MonoTime.now()

    WaitQueueEntry(let key: SessionKey) {}

    /** Time elapsed since this entry was created. */
    prop age: Duration {
        get() {
            MonoTime.now() - createdAt
        }
    }

    /** Time left until `timeout` has elapsed for this entry; zero or negative once it has expired. */
    func remaining(timeout: Duration): Duration {
        timeout - age
    }
}

/**
 * Bounded, time-limited cache of sessions keyed by session ID.
 *
 * `sessions` maps keys to SessionHolder instances. `queue` records insertion order, so the
 * oldest entries are evicted either when the queue is full (enqueue) or when their `timeout`
 * has elapsed (ensureWatchdog).
 */
struct SessionStore {
    private let sessions: ConcurrentHashMap<SessionKey, SessionHolder>
    private let queue: ArrayBlockingQueue<WaitQueueEntry>

    SessionStore(
        let capacity!: Int64,
        let timeout!: Duration
    ) {
        sessions = ConcurrentHashMap(capacity * 2)
        queue = ArrayBlockingQueue(capacity)
    }

    /**
     * Caches `session` under `key` and schedules its expiry.
     * Returns the TlsClientSession whose holder was stored.
     * Throws IllegalArgumentException (from SessionHolder) if `session` is NULL.
     */
    func put(key: SessionKey, session: CPointer<NativeSession>): TlsClientSession {
        let apiSession = TlsClientSession(session)
        sessions.add(key, apiSession.holder)

        enqueue(key)
        ensureWatchdog()

        return apiSession
    }

    /** Returns the cached holder for `key`, if any. */
    func tryGetHolder(key: SessionKey): ?SessionHolder {
        return sessions.get(key)
    }

    /** Removes the cached session for `key`, if any. */
    func remove(key: SessionKey) {
        sessions.remove(key)
        // it is probably still remaining in the queue but we don't care
        // as it will be pushed away at some point later by the watchdog
        // or by another put operation
    }

    /*
     * Returns a session with incremented refcount
     * This is important to increment it before, otherwise a session could be released concurrently
     * and the invoker of function MUST handle this extra-increment after applying a session
     */
    func find(key: SessionKey): ?CPointer<NativeSession> {
        sessions.get(key)?.getIncremented()
    }

    /** Drops every queued expiry entry and every cached session. */
    func clear() {
        // drain the expiry queue
        while (let Some(_) <- queue.tryRemove()) {}

        // drop all cached sessions
        for ((k, _) in sessions) {
            sessions.remove(k)
        }
    }

    /**
     * Appends an expiry entry for `key`. While the queue is full, the oldest entry
     * (and its cached session) is evicted to make room.
     */
    private func enqueue(key: SessionKey) {
        while (!queue.tryAdd(WaitQueueEntry(key))) {
            if (let Some(old) <- queue.tryRemove()) {
                sessions.remove(old.key)
            }
        }
    }

    /**
     * Schedules a one-shot eviction pass for the entry that was just enqueued.
     *
     * The timer is armed with the full `timeout`, i.e. the new entry's own lifetime. Arming it
     * with the queue head's remaining time made every put schedule an eviction at the head's
     * deadline, so newer entries were removed together with the head, far too early. When the
     * timer fires, it evicts only head entries whose lifetime has actually elapsed.
     */
    private func ensureWatchdog() {
        // bind plain locals so the timer task does not depend on this struct value
        let pending = queue
        let cached = sessions
        let lifetime = timeout
        Timer.after(lifetime) {
            // the queue is FIFO by creation time: evict expired heads, stop at the first live one
            while (let Some(head) <- pending.peek()) {
                if (head.remaining(lifetime) > Duration.Zero) {
                    break
                }
                // peek and tryRemove are not atomic together: a concurrent put that evicts on
                // overflow, or another watchdog, may remove the head in between. In that case a
                // different entry may be dropped. This is tolerated: losing a cached session is
                // not a critical failure, and whatever is removed from the queue is also removed
                // from the map, so nothing leaks.
                if (let Some(removed) <- pending.tryRemove()) {
                    cached.remove(removed.key)
                }
            }
            return None
        }
    }
}

/**
 * Server-side session context used for TLS session resumption.
 *
 * When a client tries to resume a session, both peers must make sure they are resuming it with
 * a legitimate counterpart. On the server side this is the job of the session context: it
 * verifies resuming clients, and it lets clients know they are still connecting to the same
 * server instance. Stateful sessions are stored in the context itself (see `store`). One session
 * context should therefore be shared by all server TlsSocket instances within a single server node.
 */
public class TlsServerSession <: TlsSession & Equatable<TlsServerSession> & ToString {
    let name: String
    let store = SessionStore(capacity: 600, timeout: Duration.hour)

    // registers the session-cache callbacks with the native layer when this class is initialized
    static init() {
        unsafe {
            CJ_TLS_SetSessionCallback(CJ_TLS_put_session, CJ_TLS_remove_session, CJ_TLS_find_session,
                CJ_TLS_assign_session)
        }
    }

    init(name: String) {
        this.name = name
    }

    // drop every cached session together with this context
    ~init() {
        store.clear()
    }

    /**
     * Creates a session context that uses `name` as its session ID context, with the default
     * (in-memory) session cache.
     *
     * The name distinguishes TLS servers, so clients rely on it to avoid accidentally resuming
     * a connection with the wrong server. The name does not need to be cryptographically
     * secure; the underlying implementation takes care of that. Two contexts created from the
     * same name are not guaranteed to be interchangeable, so a server should create a single
     * session context for its whole lifetime and pass it to every TlsSocket.server() invocation.
     */
    public static func fromName(name: String): TlsServerSession {
        return TlsServerSession(name)
    }

    public override operator func ==(other: TlsServerSession): Bool {
        name == other.name
    }

    public override operator func !=(other: TlsServerSession): Bool {
        !(this == other)
    }

    public override func toString(): String {
        "TlsServerSession(${name})"
    }
}