/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
 * This source file is part of the Cangjie project, licensed under Apache-2.0
 * with Runtime Library Exception.
 *
 * See https://cangjie-lang.cn/pages/LICENSE for license information.
 */

package stdx.net.tls

import stdx.net.tls.common.*
import stdx.crypto.x509.*
import stdx.crypto.common.*
import std.net.{StreamingSocket, SocketAddress}
import std.sync.*
import std.collection.HashMap

// Native entry points used by the keyless (delegated private key) TLS server support.
foreign {
    /* For keyless provider */
    // Returns an id string for the DER-encoded certificate at `data` (`len` bytes).
    // In this file the result is used as the key of TlsSocket.keylessCallback and released with LibC.free.
    func CJ_TLS_GetCertId(data: CPointer<Byte>, len: Int64): CString
    // Presumably returns the keyless key id associated with a native context — not called in this chunk, TODO confirm.
    func CJ_TLS_GetKeylessKeyId(ctx: CPointer<Ctx>): CString
    // Presumably releases an id returned by CJ_TLS_GetKeylessKeyId — not called in this chunk, TODO confirm.
    func CJ_TLS_FreeKeylessId(id: CString): Unit
    // Dynamically-loaded initializer for the embedded keyless provider; Int8 status semantics not visible here.
    func DYN_CJ_TLS_InitEmbeddedKeylessProvider(dynMsg: CPointer<DynMsg>): Int8
    // Registers a native sign callback for `keyId`; errors are reported through `exception` / `dynMsg` (assumed).
    func DYN_CJ_TLS_RegisterKeylessSignCallback(keyId: CString, cb: CKeylessSignCallback, exception: CPointer<ExceptionData>, dynMsg: CPointer<DynMsg>): Unit
    // Registers a native decrypt callback for `keyId`; errors are reported through `exception` / `dynMsg` (assumed).
    func DYN_CJ_TLS_RegisterKeylessDecryptCallback(keyId: CString, cb: CKeylessDecryptCallback, exception: CPointer<ExceptionData>, dynMsg: CPointer<DynMsg>): Unit
}

/**
 * A TLS connection layered on top of a StreamingSocket.
 *
 * The connection state is kept in an AtomicReference and moves from SocketReady
 * to SocketInHandshake to SocketConnected; close() may move it to SocketClosed from any state.
 * All state transitions are performed with compareAndSwap.
 * Equality is reference identity; hashCode() uses a per-instance id drawn from a static counter.
 */
public class TlsSocket <: TlsConnection & Equatable<TlsSocket> & Hashable {
    private static let idCounter = AtomicInt64(1)

    // Keyless sign/decrypt callbacks keyed by the certificate id obtained from CJ_TLS_GetCertId.
    // NOTE(review): this plain HashMap is mutated by server(...) without synchronization;
    // confirm that keyless server sockets are never created concurrently.
    static var keylessCallback = HashMap<String, (KeylessSignFunc, ?KeylessDecryptFunc)>()

    // Timeout applied by handshake() when no explicit timeout is given.
    static let defaultTimeout: Duration = Duration.second * 30

    private let state: AtomicReference<TlsSocketState>
    private let id = idCounter.fetchAdd(1) // it is required for hashCode

    // For now its size is 1, but in theory a server may provide more sessions,
    // and it actually does in TLS 1.3.
    private let negotiatedSessions_ = AtomicOptionReference<Box<TlsClientSession>>()
    private var _handshakeResult: ?TlsSocketHandshakeResult = None

    // Session ids used to look up peer certificates in the server session store.
    // NOTE(review): these are not assigned in this file; presumably set by the native bridge callbacks.
    var resumedSessionId: ?Array<Byte> = None
    var newSessionId: ?Array<Byte> = None

    private init(socket: StreamingSocket, handshake: HandshakeConfig) {
        this.state = AtomicReference<TlsSocketState>(SocketReady(socket, handshake))
    }

    /**
     * Creates a client TLS stream connected to the specified peer.
     * Can be optionally configured to resume a TLS session by TlsClientSession instance.
     * However having a session doesn't guarantee that it will be successfully resumed,
     * in which case a full TLS handshake will happen and handshake() still may fail for some reason.
     * For example some servers may reject old sessions after restart or after timeout.
     */
    public static func client(
        socket: StreamingSocket,
        session!: ?TlsClientSession = None,
        clientConfig!: TlsClientConfig = TlsClientConfig()
    ): TlsSocket {
        TlsSocket(socket, HandshakeConfig.Client(clientConfig, session))
    }

    /**
     * Creates a server TLS stream connected to the specified peer
     * If sessionContext is unspecified then clients will not be able to resume TLS sessions
     * and always require full TLS negotiation.
     */
    public static func server(
        socket: StreamingSocket,
        session!: ?TlsServerSession = None,
        serverConfig!: TlsServerConfig
    ): TlsSocket {
        TlsSocket(socket, HandshakeConfig.Server(serverConfig, session))
    }

    /**
     * Creates a server TLS stream that uses a keyless configuration, where private key
     * operations are delegated to the sign/decrypt callbacks of `serverConfig`.
     *
     * The callbacks are registered in `keylessCallback` under the id computed from the
     * DER encoding of the first certificate of `serverConfig.serverCertificate`.
     *
     * NOTE(review): `session` is currently ignored — the handshake is configured with `None`,
     * so session resumption is not enabled for keyless servers.
     */
    public static func server(
        socket: StreamingSocket,
        session!: ?TlsServerSession = None, // cjlint-ignore !G.FUN.02
        serverConfig!: KeylessTlsServerConfig
    ): TlsSocket {
        let (cert, _) = serverConfig.serverCertificate
        unsafe {
            let data = cert[0].encodeToDer().body
            let cp = acquireArrayRawData(data)
            let keyId = CJ_TLS_GetCertId(cp.pointer, data.size)
            releaseArrayRawData(cp)
            // Release the native id even if the registration below throws.
            try {
                TlsSocket.keylessCallback.add(keyId.toString(), (serverConfig._keylessSignFunc, serverConfig._keylessDecryptFunc))
            } finally {
                LibC.free(keyId)
            }
        }

        TlsSocket(socket, HandshakeConfig.KeylessServer(serverConfig, None))
    }

    /**
     * The underlying streaming socket that has been provided during construction
     * or throws an exception if already closed
     */
    public prop socket: StreamingSocket {
        get() {
            tryGetSocket(state.load()) ?? SocketClosed.throwAlreadyClosed()
        }
    }

    /**
     * Result of the last successful handshake(), or None before the handshake or after close().
     */
    public prop handshakeResult: ?TlsHandshakeResult {
        get() {
            _handshakeResult ?? None
        }
    }

    // The most recently negotiated client session; it can be replaced but never removed.
    mut prop negotiatedSession: ?TlsClientSession {
        get() {
            negotiatedSessions_.load()?.value
        }
        set(value) {
            match (value) {
                case Some(value) => negotiatedSessions_.store(Box(value))
                case None => throw Exception("Removing sessions is not allowed.")
            }
        }
    }

    /**
     * Certificate chain presented by this side of the connection, or an empty array if none.
     * Throws if the socket is closed or the handshake has not completed yet.
     */
    public prop certificate: Array<X509Certificate> {
        get() {
            connected.getOrThrow(SocketClosed.alreadyClosedException).certificate ?? []
        }
    }

    // returns None if closed, fails if not negotiated yet (too early)
    private prop connected: ?SocketConnected {
        get() {
            match (state.load()) {
                case _: SocketReady => throw TlsException("TLS socket didn't pass handshake yet.")
                case _: SocketInHandshake => throw TlsException("TLS socket didn't pass handshake yet.")
                case connected: SocketConnected => connected
                case _ => None
            }
        }
    }

    // Read timeout of the underlying socket; throws if the TLS socket is closed.
    public override mut prop readTimeout: ?Duration {
        get() {
            socket.readTimeout
        }
        set(newTimeout) {
            socket.readTimeout = newTimeout
        }
    }

    // Write timeout of the underlying socket; throws if the TLS socket is closed.
    public override mut prop writeTimeout: ?Duration {
        get() {
            socket.writeTimeout
        }
        set(newTimeout) {
            socket.writeTimeout = newTimeout
        }
    }

    public override prop localAddress: SocketAddress {
        get() {
            socket.localAddress
        }
    }

    public override prop remoteAddress: SocketAddress {
        get() {
            socket.remoteAddress
        }
    }

    /**
     * Negotiated ALPN protocol name
     */
    prop alpnProtocolName: ?String {
        get() {
            let connected = connected ?? SocketClosed.throwAlreadyClosed()
            connected.stream.getAlpnSelected()
        }
    }

    /**
     * Negotiated TLS version.
     *
     * Works only after successful handshake()
     */
    prop tlsVersion: TlsVersion {
        get() {
            getVersion()
        }
    }

    /*
     * On client it can be captured after a successful handshake()
     * and later used to reconnect without a full handshake.
     *
     * This may be None if we are not yet successfully negotiated, session resumption is not enabled
     * or if the feature is not supported by client. Depending on the underlying implementation
     * and TLS version it may appear in the middle of handshake or much later after,
     * or not appear at all if server decides to not start/resume session ticket.
     *
     * If session is negotiated, it remains unchanged after closing socket.
     * Later negotiated sessions overwrite previous (when server send multiple sessions) so the latest
     * session is always observed.
     *
     * On server it is always None because server can't initiate a session resumption.
     */
    prop session: ?TlsClientSession {
        get() {
            negotiatedSession
        }
    }

    // negotiated domain name (SNI)
    prop domain: ?String {
        get() {
            getHostName()
        }
    }

    /**
     * Peer certificate if provided by the peer or `None`
     */
    prop peerCertificate: ?Array<X509Certificate> {
        get() {
            let connState = connected.getOrThrow(SocketClosed.alreadyClosedException)
            if (let Some(v) <- connState.peerCertificate) {
                return v
            }

            if (connState.isClient) {
                return None
            }

            // some resumed conn may be faster than full handshake conn, check again
            if (let Some(id) <- resumedSessionId && let Some(store) <- connState.bridge.sessionStore) {
                if (let Some(holder) <- store.tryGetHolder(SessionKey(id))) {
                    let certs = holder.getCerts()
                    connState.setPeerCerts(certs)
                    return certs
                }
            }

            return None
        }
    }

    /*
     * The cipher suite negotiated after a handshake.
     * Cipher suite contains encryption algorithm, hash function for message authentication, key exchange algorithm.
     *
     * @throws TlsException if handshake has not yet been performed or socket already closed.
     */
    prop cipherSuite: CipherSuite {
        get() {
            let state = connected ?? SocketClosed.throwAlreadyClosed()
            state.stream.getCipherSuite()
        }
    }

    /*
     * Does handshake optionally limited by a timeout duration
     * depending on how the socket was created and configured, this function does
     * either client or server handshake
     * this could be done only once since renegotiating handshake is not supported
     *
     * On failure the socket is closed (self-closed) and the exception is rethrown;
     * if the socket was closed concurrently, a TlsException is thrown instead.
     */
    public func handshake(timeout!: ?Duration = None): TlsHandshakeResult {
        let (started, handshake) = tryStartHandshake()

        let success = try {
            match (handshake) {
                case HandshakeConfig.Client(config, session) =>
                    connect(started.socket, timeout ?? defaultTimeout, config, session)
                case HandshakeConfig.Server(config, context) =>
                    handleAccepted(started.socket, timeout ?? defaultTimeout, config, context)
                case HandshakeConfig.KeylessServer(config, context) =>
                    handleAccepted(started.socket, timeout ?? defaultTimeout, config, context)
            }
        } catch (e: Exception) {
            if (isClosed()) {
                throw TlsException("TLS socket closed during handshake.")
            }
            close(selfClosed: true)
            throw e
        }

        // A concurrent close() may have replaced the in-handshake state: discard the new connection.
        if (!state.compareAndSwap(started, success)) {
            success.close()
            SocketClosed.throwAlreadyClosed()
        }

        let result = TlsSocketHandshakeResult(this)
        _handshakeResult = result
        result
    }

    /**
     * Returns the number of bytes read from TLS socket or 0 if closed or reached EOF
     *
     * Bypass exceptions from the underlying socket.
     * Note that due to the TLS protocol nature, this may both read and write with the underlying socket.
     * In particular it means that read() function may invoke socket.write() that is legal.
     *
     * This is concurrent safe but invoking close() during read() may cause exception thrown
     * from read() in some cases. Because the TLS stream is a composed stream based on another socket,
     * the particular kind of exception could be different depending on where exactly close error
     * has been detected: it may be either SocketException or TlsException.
     *
     * @throws TlsException if the buffer is empty
     * @throws TlsException if closed (especially concurrently) or the TLS stream is corrupted
     * @throws SocketException or other from the underlying socket, including concurrent close
     * @throws TlsException if reading data fails, TLS socket is not ready (no handshake yet)
     */
    public override func read(buffer: Array<Byte>): Int64 {
        if (buffer.size == 0) {
            throw TlsException("The buffer is empty.")
        }
        let state = connected ?? return readClosed()
        state.stream.read(buffer)
    }

    // Result of read() on a closed socket: 0 when self-closed (e.g. after a failed handshake), otherwise throws.
    private func readClosed(): Int64 {
        match (state.load()) {
            case closed: SocketClosed where closed.selfClosed => 0 // self-closed stream returns 0
            case _ => SocketClosed.throwAlreadyClosed()
        }
    }

    /*
     * Writes the whole buffer to the TLS stream; an empty buffer is a no-op once connected.
     *
     * @throws TlsException if tls socket is closed
     * @throws TlsException if tls socket is not connected
     * @throws TlsException if writing data fails
     */
    public func write(buffer: Array<Byte>): Unit {
        let state = connected ?? SocketClosed.throwAlreadyClosed()

        if (buffer.size == 0) {
            return
        }

        state.stream.write(buffer)
    }

    /**
     * Terminates TLS connection, trying to shutdown it properly if possible.
     * If invoked concurrently with read(), write() or handshake()
     * these operations may be aborted with exception or without.
     *
     * This is reentrant and concurrent-safe.
     * Closing a socket that is already closed makes no effect.
     *
     * An exception thrown from this function doesn't "cancel" its termination: no need
     * to invoke it again. isClosed() will report `true` in this case.
     *
     * @throws SocketException or other exceptions from the underlying socket
     */
    public func close(): Unit {
        close(selfClosed: false)
    }

    // Swaps the current state for SocketClosed and closes the previous state.
    // An explicitly closed socket is left untouched; a self-closed one may be closed again
    // (e.g. upgraded to an explicit close by close()).
    private func close(selfClosed!: Bool) {
        _handshakeResult = None
        do {
            match (state.load()) {
                case closed: SocketClosed where !closed.selfClosed => return
                case currentState =>
                    if (state.compareAndSwap(currentState, SocketClosed.getInstance(selfClosed))) {
                        currentState.close()
                        return
                    }
            }
        } while (true)
    }

    /**
     * Whether TLS socket has been explicitly closed via close() invocation.
     * If invoked concurrently, it may report `true` while the shutdown is in progress yet
     * even if there are still running operations, e.g. read() is shutting down at the moment.
     */
    public func isClosed(): Bool {
        match (state.load()) {
            case closed: SocketClosed where !closed.selfClosed => true
            case _ => false
        }
    }

    public func toString(): String {
        return "TlsSocket(${state.load()})"
    }

    public override operator func ==(other: TlsSocket): Bool {
        refEq(this, other)
    }

    public override operator func !=(other: TlsSocket): Bool {
        !(refEq(this, other))
    }

    public override func hashCode(): Int64 {
        id.hashCode()
    }

    // Maps the native version string of the connected stream to TlsVersion; other versions are rejected.
    private func getVersion(): TlsVersion {
        let state = connected ?? SocketClosed.throwAlreadyClosed()
        let versionString = state.withStream<CString> {stream => unsafe { CJ_TLS_GetVersion(stream) }}
        if (versionString.isNull()) {
            throw TlsException("Unknown TLS version.")
        }
        // note: we don't own this CString so no need to invoke free()

        match (versionString.toString()) {
            case "TLSv1.2" => TlsVersion.V1_2
            case "TLSv1.3" => TlsVersion.V1_3
            case _ => throw TlsException("Unknown TLS version.")
        }
    }

    // Host name (SNI) known to the native stream, or None if it reports none.
    private func getHostName(): ?String {
        let state = connected ?? SocketClosed.throwAlreadyClosed()
        state.withStream<?String> {
            stream =>
            let s = unsafe { CJ_TLS_GetHostName(stream) }
            if (s.isNull()) {
                None
            } else {
                s.toString()
            }
            // note: we don't own this CString so no need to invoke free()
            // note: we convert it inside of withStream otherwise the native placement may disappear
            // and it may segfault here
        }
    }

    // Atomically moves SocketReady to SocketInHandshake, returning the new state and the handshake config.
    // Throws if the handshake is running, already done, or the socket is closed.
    private func tryStartHandshake(): (SocketInHandshake, HandshakeConfig) {
        while (true) {
            match (state.load()) {
                case s: SocketReady =>
                    let started = SocketInHandshake(s.socket)
                    if (state.compareAndSwap(s, started)) {
                        return (started, s.config)
                    }
                case _: SocketInHandshake => throw TlsException("TLS handshake is already in progress.")
                case _: SocketConnected => throw TlsException("TLS socket handshake is already complete.")
                case _ => throw TlsException("TLS socket is already closed.")
            }
        }

        throw TlsException("Shouldn't reach here.")
    }

    // Underlying socket for every non-closed state, None otherwise.
    private static func tryGetSocket(state: TlsSocketState): ?StreamingSocket {
        match (state) {
            case s: SocketReady => s.socket
            case s: SocketInHandshake => s.socket
            case s: SocketConnected => s.socket
            case _ => None
        }
    }

    /*
     * Performs the client-side handshake over `socket`, temporarily applying `timeout`
     * to its read/write timeouts and restoring the previous values afterwards.
     *
     * @throws TlsException while context is null, or get the tls socket stream failed,
     * or tls socket client verify mode is set failed, or the CA file is set failed, or the CA file is empty,
     * or the certificate chain file is set failed, or the private key file is set failed,
     * or the certificate chain does not match the private key,
     * or the minVersion and maxVersion are not set together, or proto versions is set failed,
     * @throws TlsException while cipher suites of TLS1.2 is set failed,
     * or the TLS versions does not contain 1.2 if cipherSuitesV1_2 is not empty.
     * @throws TlsException while cipher suites of TLS1.3 is set failed,
     * or the TLS versions does not contain 1.3 if cipherSuitesV1_3 is not empty.
     * @throws TlsException if it was not possible to select the ALPN protocol.
     * @throws SocketException while socket is closed or already connected,
     * or address's size is not match ipv4 or ipv6.
     * or some system errors are happened.
     * @throws SocketTimeoutException while socket connect timeout.
     */
    private func connect(
        socket: StreamingSocket,
        timeout: ?Duration,
        cfg: TlsClientConfig,
        session: ?TlsClientSession
    ): SocketConnected {
        let timeoutsBefore = (socket.readTimeout, socket.writeTimeout)
        socket.writeTimeout = timeout
        socket.readTimeout = timeout

        let context = TlsContext(enableKeylog: cfg.keylogCallback.isSome())
        try {
            context.configureClient(cfg, session)

            let stream = context.createStream(socket)
            let certificateVerifyCallback: ?CertificateVerifyCallbackFunction = match (cfg.verifyMode) {
                case CustomVerify(callback) => callback
                case _ => None
            }
            let bridge = stream.createBridge(this, sessionStore: None, keylogCallback: cfg.keylogCallback,
                certificateVerifyCallback: certificateVerifyCallback)
            try {
                Bridge.register(bridge)
                if (let Some(session) <- session) {
                    stream.setSession(session)
                }
                if (let Some(hostname) <- cfg.serverName) {
                    if (!hostname.isEmpty()) {
                        stream.setRequestedHostName(hostname)
                        match (cfg.verifyMode) {
                            case TrustAll => ()
                            case _ => stream.setHostNameForVerify(hostname)
                        }
                    }
                }

                stream.handshake()

                if (negotiatedSession.isNone()) {
                    // In TLS 1.3 sessions are negotiated after the handshake
                    // and could be significantly delayed
                    // or even not sent at all
                    // so here is a workaround for this - this is not exactly correct
                    // but we keep it for now before we get the proper TLS 1.3
                    // implementation for sessions and session tickets
                    if (let Some(session) <- session) {
                        negotiatedSession = session
                    }
                }
            } catch (e: Exception) {
                Bridge.remove(bridge)
                stream.close()
                throw e
            }

            let myCertificate = cfg.clientCertificate?[0]

            SocketConnected(stream, socket, myCertificate, true, bridge)
        } finally {
            context.close()
            // Restore on the socket we were given (as handleAccepted does). Going through this.readTimeout
            // would re-resolve the socket from `state` and throw if the TLS socket was closed concurrently,
            // masking the original exception.
            if (!socket.isClosed()) {
                socket.readTimeout = timeoutsBefore[0]
                socket.writeTimeout = timeoutsBefore[1]
            }
        }
    }

    /*
     * Performs the server-side handshake over `socket` for a TlsServerConfig, temporarily applying
     * `timeout` to its read/write timeouts and restoring the previous values afterwards.
     * When a session store is configured, the peer certificate chain is looked up / stored there.
     */
    private func handleAccepted(
        socket: StreamingSocket,
        timeout: Duration,
        cfg: TlsServerConfig,
        sessionContext: ?TlsServerSession
    ): SocketConnected {
        unsafe {
            let timeoutsBefore = (socket.readTimeout, socket.writeTimeout)
            socket.writeTimeout = timeout
            socket.readTimeout = timeout

            let context = TlsContext(server: true, enableKeylog: cfg.keylogCallback.isSome())
            try {
                context.configureServer(cfg, sessionContext)

                let stream = context.createStream(socket)
                let certificateVerifyCallback: ?CertificateVerifyCallbackFunction = match (cfg.verifyMode) {
                    case CustomVerify(callback) => callback
                    case _ => None
                }
                let bridge = stream.createBridge(
                    this,
                    sessionStore: sessionContext?.store,
                    keylogCallback: cfg.keylogCallback,
                    certificateVerifyCallback: certificateVerifyCallback
                )
                try {
                    Bridge.register(bridge)
                    stream.handshake()
                    // The server certificate is not supposed to be null
                    let myCertificate = cfg.serverCertificate[0]
                    let socketConnected = SocketConnected(stream, socket, myCertificate, false, bridge)

                    let store = bridge.sessionStore ?? return socketConnected
                    let certsOp = getAndStorePeerCertChain(store, stream)
                    socketConnected.setPeerCerts(certsOp)

                    return socketConnected
                } catch (e: Exception) {
                    Bridge.remove(bridge)
                    stream.close()
                    throw e
                }
            } finally {
                context.close()
                if (!socket.isClosed()) {
                    socket.readTimeout = timeoutsBefore[0]
                    socket.writeTimeout = timeoutsBefore[1]
                }
            }
        }
    }

    /*
     * Same as the TlsServerConfig overload, but for a KeylessTlsServerConfig.
     */
    private func handleAccepted(socket: StreamingSocket, timeout: Duration,
        cfg: KeylessTlsServerConfig, sessionContext: ?TlsServerSession): SocketConnected {
        unsafe {
            let timeoutsBefore = (socket.readTimeout, socket.writeTimeout)
            socket.writeTimeout = timeout
            socket.readTimeout = timeout

            let context = TlsContext(server: true, enableKeylog: cfg.keylogCallback.isSome())
            try {
                context.configureServer(cfg, sessionContext)

                let stream = context.createStream(socket)
                let certificateVerifyCallback: ?CertificateVerifyCallbackFunction = match (cfg.verifyMode) {
                    case CustomVerify(callback) => callback
                    case _ => None
                }
                let bridge = stream.createBridge(
                    this,
                    sessionStore: sessionContext?.store,
                    keylogCallback: cfg.keylogCallback,
                    certificateVerifyCallback: certificateVerifyCallback
                )
                try {
                    Bridge.register(bridge)

                    stream.handshake()
                    // The server certificate is not supposed to be null
                    let myCertificate = cfg.serverCertificate[0]
                    let socketConnected = SocketConnected(stream, socket, myCertificate, false, bridge)

                    let store = bridge.sessionStore ?? return socketConnected
                    let certsOp = getAndStorePeerCertChain(store, stream)
                    socketConnected.setPeerCerts(certsOp)

                    return socketConnected
                } catch (e: Exception) {
                    Bridge.remove(bridge)
                    stream.close()
                    throw e
                }
            } finally {
                context.close()
                if (!socket.isClosed()) {
                    socket.readTimeout = timeoutsBefore[0]
                    socket.writeTimeout = timeoutsBefore[1]
                }
            }
        }
    }

    // Returns the peer certificate chain for a server connection and keeps the session store in sync with it.
    private func getAndStorePeerCertChain(store: SessionStore, stream: TlsRawSocket): ?Array<X509Certificate> {
        var certsOp: ?Array<X509Certificate> = None
        // If the session is resumed, then we have stored the peer certs (given that the original connection
        // required a cert chain from the client during establishment).
        // So we can get the peer certs from the session store.
        if (let Some(id) <- resumedSessionId) {
            if (let Some(holder) <- store.tryGetHolder(SessionKey(id))) {
                certsOp = holder.getCerts()
            }
        }
        // If the connection is established with a full handshake, a new session is constructed,
        // and we store the peer certs, which are received during handshake, in the session store, with the session id as its key.
        // If the connection is established with a resumed session, and a new session is constructed at the same time (happens in TLS 1.3),
        // we bind the peer certs, which were bound to the resumed session, to the new session.
        if (let Some(id) <- newSessionId) {
            if (let Some(holder) <- store.tryGetHolder(SessionKey(id))) {
                certsOp = certsOp ?? stream.getPeerCertificate() // cjlint-ignore !G.EXP.03
                if (let Some(certs) <- certsOp) {
                    holder.setCerts(certs)
                }
            }
        }
        return certsOp
    }
}

/**
 * TlsHandshakeResult view backed by a TlsSocket.
 * Nothing is cached: every property reads the socket when accessed.
 */
class TlsSocketHandshakeResult <: TlsHandshakeResult {
    let _socket: TlsSocket

    init(socket: TlsSocket) {
        this._socket = socket
    }

    // Negotiated protocol version.
    public prop version: TlsVersion {
        get() {
            this._socket.tlsVersion
        }
    }

    // Negotiated cipher suite rendered as a string.
    public prop cipherSuite: String {
        get() {
            let suite = this._socket.cipherSuite
            suite.toString()
        }
    }

    // Peer certificate chain as generic certificates; empty when the peer sent none.
    public prop peerCertificate: Array<Certificate> {
        get() {
            if (let Some(certs) <- this._socket.peerCertificate) {
                return certs.map({cert => (cert as Certificate).getOrThrow()})
            }
            return []
        }
    }

    // Client session negotiated on the socket, if any.
    public prop session: ?TlsSession {
        get() {
            this._socket.session ?? None
        }
    }

    // Selected ALPN protocol, or an empty string when none was negotiated.
    public prop alpnProtocol: String {
        get() {
            match (this._socket.alpnProtocolName) {
                case Some(name) => name
                case None => ""
            }
        }
    }

    // Negotiated server name (SNI), or an empty string when unknown.
    public prop serverName: String {
        get() {
            match (this._socket.domain) {
                case Some(name) => name
                case None => ""
            }
        }
    }
}