open Riot
(* Seed mirage-crypto's global random number generator (Fortuna) before any
   TLS session is created below; ocaml-tls draws its handshake randomness
   from this generator, which must be initialised before first use. *)
let () = Mirage_crypto_rng_unix.initialize (module Mirage_crypto_rng.Fortuna)
(* Message sent by the client process to the main process, carrying the
   payload the client read back from the echo server. *)
type Message.t += Received of string
(* Reads the whole file at [path] into a string.

   The file is closed even if reading raises. The read is performed as an
   ordinary expression (not inside an [assert]), so it still happens when
   the program is compiled with [-noassert].

   @raise Invalid_argument if the read returns an error (via
   [Result.get_ok]). *)
let read_fixture path =
  let file = File.open_read path in
  Fun.protect
    ~finally:(fun () -> File.close file)
    (fun () ->
      let buf = IO.Buffer.with_capacity 4_096 in
      let _len =
        IO.read_to_end ~buf (File.to_reader file) |> Result.get_ok
      in
      IO.Buffer.contents buf)

(* Loads the fixture certificate chain and private key, in the shape
   expected by [Tls.Config.server ~certificates].

   @raise Invalid_argument if a file cannot be read or decoded. *)
let load_certificates () =
  let crt =
    read_fixture "fixtures/tls.crt"
    |> X509.Certificate.decode_pem_multiple
    |> Result.get_ok
  in
  let pk =
    read_fixture "fixtures/tls.key"
    |> X509.Private_key.decode_pem
    |> Result.get_ok
  in
  `Single (crt, pk)

(* Rudimentary TLS echo server.

   Accepts a single client on [socket], wraps the connection in TLS using the
   fixture certificate, and writes every chunk it reads back to the client.
   It stops (closing the TCP connection) on end-of-stream or on the first
   read/write error. [port] is only used for logging. *)
let server port socket =
  Logger.debug (fun f -> f "Started server on %d" port);
  process_flag (Trap_exit true);
  let conn, addr = Net.Tcp_listener.accept socket |> Result.get_ok in
  Logger.debug (fun f ->
      f "Accepted client %a (%a)" Net.Addr.pp addr Net.Socket.pp conn);
  let close () =
    Net.Tcp_stream.close conn;
    Logger.debug (fun f ->
        f "Closed client %a (%a)" Net.Addr.pp addr Net.Socket.pp conn)
  in
  let certificates = load_certificates () in
  let config = Tls.Config.server ~certificates () |> Result.get_ok in
  let ssl = SSL.of_server_socket ~config conn in
  let reader, writer = SSL.(to_reader ssl, to_writer ssl) in
  let buf = IO.Bytes.with_capacity 1024 in
  let rec echo () =
    Logger.debug (fun f ->
        f "Reading from client %a (%a)" Net.Addr.pp addr Net.Socket.pp conn);
    match IO.read reader buf with
    | Ok 0 ->
        (* Zero bytes read is treated as end-of-stream (the usual read
           convention): there is nothing to echo, and recursing would only
           spin on empty reads and writes. *)
        close ()
    | Ok len -> (
        Logger.debug (fun f -> f "Server received %d bytes" len);
        let bufs = IO.Iovec.(of_bytes buf |> sub ~len) in
        match IO.write_owned_vectored ~bufs writer with
        | Ok bytes ->
            Logger.debug (fun f -> f "Server sent %d bytes" bytes);
            echo ()
        | Error (`Closed | `Timeout | `Process_down) -> close ()
        | Error err ->
            Logger.error (fun f -> f "error %a" IO.pp_err err);
            close ())
    | Error err ->
        Logger.error (fun f -> f "error %a" IO.pp_err err);
        close ()
  in
  echo ()
(* TLS echo client.

   Connects to the echo server on loopback [server_port], sends
   "hello world", performs one read of the reply and forwards it to the
   process [main] as a [Received] message. It then closes the TCP
   connection. Server certificate verification is disabled (see [null]
   below). *)
let client server_port main =
  let addr = Net.Addr.(tcp loopback server_port) in
  let conn = Net.Tcp_stream.connect addr |> Result.get_ok in
  Logger.debug (fun f -> f "Connected to server on %d" server_port);
  let host = Domain_name.host_exn (Domain_name.of_string_exn "localhost") in
  (* Authenticator that accepts any certificate: the test exercises the TLS
     transport, not certificate validation. Never use outside tests. *)
  let null ?ip:_ ~host:_ _ = Ok None in
  let config = Tls.Config.client ~authenticator:null () |> Result.get_ok in
  let ssl = SSL.of_client_socket ~host ~config conn in
  let reader, writer = SSL.(to_reader ssl, to_writer ssl) in
  let bufs = IO.Iovec.of_bytes (IO.Bytes.of_string "hello world") in
  (* Every failed attempt consumes one unit of [retries_left], so the loop is
     bounded even if the socket never becomes writable (e.g. EPIPE, which
     never recovers). *)
  let rec send_loop retries_left =
    if retries_left = 0 then
      Logger.error (fun f -> f "client retried too many times")
    else
      match IO.write_owned_vectored ~bufs writer with
      | Ok bytes -> Logger.debug (fun f -> f "Client sent %d bytes" bytes)
      | Error (`Timeout | `Process_down | `Closed) ->
          Logger.debug (fun f -> f "connection closed")
      | Error (`Unix_error (ENOTCONN | EPIPE)) ->
          (* Give other processes (e.g. the server) a chance to run before
             retrying instead of busy-looping on the scheduler. *)
          yield ();
          send_loop (retries_left - 1)
      | Error err ->
          Logger.error (fun f -> f "error %a" IO.pp_err err);
          send_loop (retries_left - 1)
  in
  send_loop 10_000;
  let buf = IO.Bytes.with_capacity 1024 in
  (* A single read: returns the number of bytes received, or 0 on error. *)
  let recv () =
    match IO.read reader buf with
    | Ok bytes ->
        Logger.debug (fun f -> f "Client received %d bytes" bytes);
        bytes
    | Error err ->
        Logger.error (fun f -> f "Error: %a" IO.pp_err err);
        0
  in
  let len = recv () in
  let reply =
    if len = 0 then "empty payload"
    else IO.Bytes.to_string (IO.Bytes.sub buf ~pos:0 ~len)
  in
  send main (Received reply);
  (* Release the socket; the result was already reported to [main]. *)
  Net.Tcp_stream.close conn
(* Test driver.

   Starts the logger, spawns the TLS echo server and client processes,
   monitors both, and waits for a single message, with the wait bounded by
   [receive_any ~after:500_000L]. The test passes only if the client reports
   the payload "hello world". A wrong payload, a monitored process going
   down, or any other message logs an error and exits with status 1. *)
let () =
  Riot.run @@ fun () ->
  let _ = Logger.start () |> Result.get_ok in
  Logger.set_log_level (Some Info);
  let socket, port = Port_finder.next_open_port () in
  let me = self () in
  (* Each process logs a TLS failure instead of crashing on it. *)
  let server_pid =
    spawn (fun () ->
        match server port socket with
        | () -> ()
        | exception SSL.Tls_failure failure ->
            Logger.error (fun f ->
                f "server error: %a" Tls.Engine.pp_failure failure))
  in
  let client_pid =
    spawn (fun () ->
        match client port me with
        | () -> ()
        | exception SSL.Tls_failure failure ->
            Logger.error (fun f ->
                f "client error: %a" Tls.Engine.pp_failure failure))
  in
  monitor server_pid;
  monitor client_pid;
  (* Log at error level, then terminate the whole test with status 1. *)
  let fail log =
    Logger.error log;
    Stdlib.exit 1
  in
  let msg = receive_any ~after:500_000L () in
  match msg with
  | Received "hello world" -> Logger.info (fun f -> f "ssl_test: OK")
  | Received other -> fail (fun f -> f "ssl_test: bad payload: %S" other)
  | Process.Messages.Monitor (Process_down pid) ->
      let who = if Pid.equal pid server_pid then "server" else "client" in
      fail (fun f -> f "ssl_test: %s(%a) died unexpectedly" who Pid.pp pid)
  | _ -> fail (fun f -> f "ssl_test: unexpected message")