(******************************************************************************************

  The Tls_unix module below was ported from the `eio` subpackage of `ocaml-tls`, specifically from:
    * https://github.com/mirleft/ocaml-tls/blob/main/eio/tls_eio.ml
    * https://github.com/mirleft/ocaml-tls/blob/main/eio/x509_eio.ml

  under this license:

    Copyright (c) 2014, David Kaloper and Hannes Mehnert
    All rights reserved.

    Redistribution and use in source and binary forms, with or without modification,
    are permitted provided that the following conditions are met:

    * Redistributions of source code must retain the above copyright notice, this
      list of conditions and the following disclaimer.

    * Redistributions in binary form must reproduce the above copyright notice, this
      list of conditions and the following disclaimer in the documentation and/or
      other materials provided with the distribution.

    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
    ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
    WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
    DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
    ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
    (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
    LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
    ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
    (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

  *******************************************************************************************)

(* Logging for this file goes under the riot/net/ssl namespace; this open is
   what brings [trace] into scope below. *)
open Logger.Make (struct
  let namespace = "riot" :: "net" :: [ "ssl" ]
end)

(* Short alias for the Rio I/O layer used throughout this file. *)
module IO = Rio

(* Binding operator over [result], used by [negotiated_protocol]. *)
let ( let* ) res k = Result.bind res k

(** State of one TLS session running over a Rio reader/writer pair. *)
type 'src t = {
  (* Transport that outgoing TLS protocol bytes are written to. *)
  writer : 'src IO.Writer.t;
  (* Transport that incoming TLS protocol bytes are read from. *)
  reader : 'src IO.Reader.t;
  (* [`Active] while the session is usable; [`Eof] once the TLS engine has
     reported end of stream; [`Error e] after a failure, [e] then being
     re-raised by later reads and writes. *)
  mutable state : [ `Active of Tls.Engine.state | `Eof | `Error of exn ];
  (* Decrypted application data already received but not yet handed to a
     reader. *)
  mutable linger : string option;
  (* Scratch buffer that raw transport reads land in. *)
  recv_buf : bytes;
}

(** Raised when handling incoming TLS data fails with an alert. *)
exception Tls_alert of Tls.Packet.alert_type

(** Raised when handling incoming TLS data fails for any other reason. *)
exception Tls_failure of Tls.Engine.failure

module Tls_unix = struct
  (** Raised when the transport reader reports an error other than end of
      stream. *)
  exception Read_error of Rio.io_error

  (** Raised when the transport writer reports an error or stops making
      progress. *)
  exception Write_error of Rio.io_error

  (** Size of the scratch buffer used for each raw read from the transport. *)
  let recv_buffer_size = 4_096

  let err_to_str err = Format.asprintf "%a" Rio.pp_err err

  (* Record [exn] as the session's error and raise it. A session that has
     already ended (earlier error or EOF) keeps its existing state. *)
  let fail_with t exn =
    (match t.state with
    | `Error _ | `Eof -> ()
    | `Active _ -> t.state <- `Error exn);
    raise exn

  (** [read_t t dst] reads raw bytes from the transport directly into the
      front of [dst] and returns how many were read.
      @raise End_of_file if the transport is closed, reports [`Eof], or
      returns a zero-length read into a non-empty [dst].
      @raise Read_error on any other transport error (also recorded in
      [t.state]). *)
  let read_t t dst =
    let capacity = Bytes.length dst in
    let eof () =
      trace (fun f -> f "read_t: 0/%d" capacity);
      raise End_of_file
    in
    match IO.read t.reader dst with
    (* A zero-length read into a non-empty buffer is end of stream (the same
       convention [to_reader] below uses). Handing "" to the TLS engine makes
       no progress, so [single_read]/[drain_handshake] would loop forever. *)
    | Ok 0 when capacity > 0 -> eof ()
    | Ok len ->
        trace (fun f -> f "read_t: %d/%d" len capacity);
        len
    | Error (`Closed | `Eof) -> eof ()
    | Error err ->
        trace (fun f -> f "read_t: error: %s" (err_to_str err));
        fail_with t (Read_error err)

  (** [write_t t data] writes all of [data] to the transport. If the writer
      accepts only part of it, the remainder is written with further calls.
      @raise Write_error if the writer reports an error, or accepts zero bytes
      while data is still pending (reported as [`Closed]); the error is also
      recorded in [t.state]. *)
  let write_t t data =
    let total = String.length data in
    let rec write_from off =
      let pending = total - off in
      let chunk = if off = 0 then data else String.sub data off pending in
      match
        IO.write_owned_vectored t.writer ~bufs:(IO.Iovec.from_string chunk)
      with
      | Ok n when n >= pending ->
          trace (fun f -> f "write_t: %d/%d" (off + n) total)
      | Ok 0 ->
          (* No progress: fail instead of spinning or silently dropping the
             rest of the record. *)
          trace (fun f -> f "write_t: error: no progress at %d/%d" off total);
          fail_with t (Write_error `Closed)
      | Ok n ->
          trace (fun f -> f "write_t: %d/%d" (off + n) total);
          write_from (off + n)
      | Error err ->
          trace (fun f -> f "write_t: error: %s" (err_to_str err));
          fail_with t (Write_error err)
    in
    write_from 0

  (* Best-effort [write_t]: any exception is traced and then discarded. *)
  let try_write_t t data =
    try write_t t data
    with exn ->
      trace (fun f -> f "try_write_t failed: %s" (Printexc.to_string exn))

  (* Swap in a new engine state while the session is still active; a session
     that already ended keeps its terminal state. *)
  let inject_state tls = function
    | `Active _ -> `Active tls
    | `Eof -> `Eof
    | `Error _ as e -> e

  (** [read_react t] reads one chunk of raw bytes and feeds it to the TLS
      engine. It sends any response the engine produces (best effort) and
      returns the decrypted application data, if any.
      @raise End_of_file if the session or the transport has ended.
      @raise Tls_alert / Tls_failure if the engine rejects the input.
      @raise Read_error on transport read errors. *)
  let read_react t =
    trace (fun f -> f "tls.read_react");
    let handle tls cs =
      match Tls.Engine.handle_tls tls cs with
      | Ok (state', eof, `Response resp, `Data data) ->
          trace (fun f -> f "tls.read_react->ok");
          t.state <-
            (match eof with
            | Some `Eof -> `Eof
            | _ -> inject_state state' t.state);
          Option.iter (try_write_t t) resp;
          data
      | Error (fail, `Response resp) ->
          let exn =
            match fail with
            | `Alert alert ->
                trace (fun f -> f "tls.read_react->alert");
                Tls_alert alert
            | other ->
                trace (fun f -> f "tls.read_react->error");
                Tls_failure other
          in
          t.state <- `Error exn;
          (* Best effort: if the response cannot be sent, still report the TLS
             failure instead of masking it with a transport write error. *)
          try_write_t t resp;
          raise exn
    in
    match t.state with
    | `Error e -> raise e
    | `Eof -> raise End_of_file
    | `Active _ -> (
        let n = read_t t t.recv_buf in
        match t.state with
        | `Active tls -> handle tls (Bytes.sub_string t.recv_buf 0 n)
        | `Error e -> raise e
        | `Eof -> raise End_of_file)

  (** [single_read t dst] copies decrypted application data into the front of
      [dst] and returns the number of bytes copied. Data that does not fit is
      kept in [t.linger] for the next call. It keeps reading from the
      transport until a non-empty chunk is available. Empty chunks are
      skipped, because a 0 result is interpreted as end of stream by
      [to_reader] callers.
      @raise End_of_file and the errors of [read_react]. *)
  let rec single_read t (dst : bytes) =
    let writeout data =
      let available = String.length data in
      let n = min (Bytes.length dst) available in
      Bytes.blit_string data 0 dst 0 n;
      t.linger <-
        (if n < available then Some (String.sub data n (available - n))
         else None);
      n
    in
    match t.linger with
    | Some "" ->
        t.linger <- None;
        single_read t dst
    | Some pending -> writeout pending
    | None -> (
        match read_react t with
        | None | Some "" -> single_read t dst
        | Some data -> writeout data)

  (** Raised by [writev] when the session has reached end of stream. *)
  exception Tls_socket_closed

  (** [writev t data] encrypts the strings in [data] and writes the result to
      the transport.
      @raise Tls_socket_closed if the session has reached end of stream.
      @raise Invalid_argument if the engine cannot send application data yet.
      @raise Write_error on transport errors, or the session's recorded
      error if it has already failed. *)
  let writev t data =
    match t.state with
    | `Error err ->
        trace (fun f -> f "writev: failed");
        raise err
    | `Eof -> raise Tls_socket_closed
    | `Active tls -> (
        match Tls.Engine.send_application_data tls data with
        | Some (tls, tlsdata) ->
            t.state <- `Active tls;
            write_t t tlsdata
        | None -> invalid_arg "tls: write: socket not ready")

  (** Encrypt and write [src]; returns [Ok (String.length src)]. *)
  let single_write t src =
    writev t [ src ];
    let written = String.length src in
    Ok written

  (** Keep reading until the engine reports that the handshake is finished,
      then return [t]. Any application data that arrives during the handshake
      is appended to [t.linger]. *)
  let rec drain_handshake t =
    match t.state with
    | `Active tls when not (Tls.Engine.handshake_in_progress tls) -> t
    | _ ->
        (match read_react t with
        | None | Some "" -> ()
        | Some data ->
            t.linger <-
              Some (match t.linger with None -> data | Some l -> l ^ data));
        drain_handshake t

  (** Session epoch data, or [`No_session_data] if the engine has none, or
      [`Inactive_tls_engine] if the session is not active. *)
  let epoch t =
    match t.state with
    | `Active tls ->
        Tls.Engine.epoch tls |> Result.map_error (fun () -> `No_session_data)
    | _ -> Error `Inactive_tls_engine

  (* Fresh session record over [reader]/[writer] starting in [state]. *)
  let create ~reader ~writer state =
    {
      state;
      writer;
      reader;
      linger = None;
      recv_buf = Bytes.create recv_buffer_size;
    }

  (** Start a client session: send the engine's initial handshake data, then
      complete the handshake. [host], when given, is set as the peer name in
      [config] via [Tls.Config.peer]. *)
  let make_client ?host ~reader ~writer config =
    let config =
      match host with
      | None -> config
      | Some host -> Tls.Config.peer config host
    in
    let tls, init = Tls.Engine.client config in
    let t = create ~reader ~writer (`Active tls) in
    write_t t init;
    drain_handshake t

  (** Start a server session and complete the handshake. *)
  let make_server ~reader ~writer config =
    let t = create ~reader ~writer (`Active (Tls.Engine.server config)) in
    drain_handshake t

  (** View the session as a Rio reader of decrypted data. End of stream is
      reported as [Ok 0]; other failures propagate as exceptions. The read
      timeout argument is ignored. *)
  let to_reader : type src. src t -> src t IO.Reader.t =
   fun t ->
    let module Read = struct
      type nonrec t = src t

      let read t ?timeout:_ dst =
        match single_read t dst with
        | exception End_of_file -> Ok 0
        | len -> Ok len

      (* NOTE(review): vectored reads are not implemented and always report
         0 bytes, i.e. end of stream; callers must use [read]. *)
      let read_vectored _t _bufs = Ok 0
    end in
    IO.Reader.of_read_src (module Read) t

  (** View the session as a Rio writer that encrypts what it is given. *)
  let to_writer : type src. src t -> src t IO.Writer.t =
   fun t ->
    let module Write = struct
      type nonrec t = src t

      let write t ~buf = single_write t buf

      (* TODO: This seems like not what we want *)
      (* The vectors are flattened into one string and sent as a single
         application-data write. *)
      let write_owned_vectored t ~bufs =
        single_write t (IO.Iovec.into_string bufs)
      (* single_write t bufs *)

      let flush _t = Ok ()
    end in
    IO.Writer.of_write_src (module Write) t
end

(** The [alpn_protocol] recorded in the session's epoch. Errors from
    [Tls_unix.epoch] are passed through unchanged. *)
let negotiated_protocol t =
  Tls_unix.epoch t |> Result.map (fun epoch -> epoch.Tls.Core.alpn_protocol)

(* Expose a TLS session as a Rio reader / writer. *)
let to_reader t = Tls_unix.to_reader t
let to_writer t = Tls_unix.to_writer t

(** Run a server-side TLS handshake over the TCP stream [sock] and return the
    session. [read_timeout] and [send_timeout] are handed to the stream's
    reader and writer respectively. *)
let of_server_socket ?read_timeout ?send_timeout ~config sock =
  let writer = Net.Tcp_stream.to_writer ?timeout:send_timeout sock in
  let reader = Net.Tcp_stream.to_reader ?timeout:read_timeout sock in
  Tls_unix.make_server ~reader ~writer config

(** Run a client-side TLS handshake over the TCP stream [sock] and return the
    session. [host] is forwarded to [Tls_unix.make_client]; [read_timeout]
    and [send_timeout] are handed to the stream's reader and writer. *)
let of_client_socket ?read_timeout ?send_timeout ?host ~config sock =
  let writer = Net.Tcp_stream.to_writer ?timeout:send_timeout sock in
  let reader = Net.Tcp_stream.to_reader ?timeout:read_timeout sock in
  Tls_unix.make_client ?host ~reader ~writer config