open Global

open Logger.Make (struct
  let namespace = [ "riot"; "dynamic_supervisor" ]
end)

(** Internal state threaded through the supervisor's receive loop.

    - [max_children]: configured limit compared against the child count when
      a start request is handled.
    - [curr_children]: running count of started children; incremented when a
      child is started and decremented (never below zero) when a monitor
      message is received.
    - [name]: name under which the supervisor process registers itself on
      startup. *)
type state = { max_children : int; curr_children : int; name : string }

(** Messages exchanged between clients and the dynamic supervisor. *)
type Message.t +=
  | Start_child of Pid.t * Supervisor.child_spec
      (** Request to start a child from the given spec; the [Pid.t] is the
          process the reply is sent to. *)
  | Started_child of Pid.t
      (** Reply carrying the pid of the child that was started. *)
  | Max_children
      (** Reply sent when the request was refused because the child limit was
          reached. *)

(** [loop state] is the supervisor's receive loop. Every handled message ends
    in another call to [loop], so it does not return on its own.

    - [Process.Messages.Monitor _] is treated as one of the monitored children
      having finished: the child count is decremented, clamped at zero.
    - [Start_child (reply, spec)] is delegated to [handle_start_child].
    - Any other message is ignored. *)
let rec loop state =
  match receive_any () with
  | Process.Messages.Monitor _ ->
      trace (fun f -> f "child finished");
      loop { state with curr_children = Int.max 0 (state.curr_children - 1) }
  | Start_child (reply, spec) -> handle_start_child state reply spec
  | _ -> loop state

(** [handle_start_child state reply child_spec] starts a child from
    [child_spec] if there is capacity left, monitors it, and sends
    [Started_child pid] to [reply]. When [state.max_children] children are
    already accounted for, nothing is started and [Max_children] is sent to
    [reply] instead. In both cases control returns to [loop]. *)
and handle_start_child state reply child_spec =
  (* Capacity is checked against the count *before* adding the new child, so
     exactly [max_children] children may be running at once. Comparing the
     already-incremented count with [<] would cap the supervisor at
     [max_children - 1] and make [max_children = 1] refuse every request. *)
  if state.curr_children < state.max_children then (
    let curr_children = state.curr_children + 1 in
    let pid = Supervisor.start_child child_spec in
    Process.monitor pid;
    trace (fun f -> f "started child %d" curr_children);
    send reply (Started_child pid);
    loop { state with curr_children })
  else (
    send reply Max_children;
    loop state)

(** [init state] is the body of the supervisor process: it registers the
    calling process under [state.name], sets the [Trap_exit] process flag,
    traces the configured child limit and then enters [loop]. *)
let init (state : state) =
  let me = self () in
  register state.name me;
  Process.flag (Trap_exit true);
  let limit = state.max_children in
  trace (fun f -> f "max %d children" limit);
  loop state

(** [start_link state] spawns a linked process running [init state] and
    returns its pid wrapped in [Ok]. *)
let start_link state = Ok (spawn_link (fun () -> init state))

(** [child_spec ?max_children ~name ()] builds a [Supervisor.child_spec] that
    starts this dynamic supervisor via [start_link], beginning with zero
    children.

    @param max_children child limit stored in the initial state (default 50).
    @param name name the supervisor registers itself under when it starts. *)
let child_spec ?(max_children = 50) ~name () =
  Supervisor.child_spec start_link { max_children; curr_children = 0; name }

(** [start_child pid spec] asks the dynamic supervisor [pid] to start a child
    described by [spec], then blocks until a reply arrives.

    Returns [Ok child] when the supervisor answers [Started_child child] and
    [Error `Max_children] when it answers [Max_children]. All other messages
    are skipped by the selector. Replies are recognised by constructor only.

    NOTE(review): the [ref] is created before the request is sent and handed
    to [receive]; presumably this limits the receive to messages newer than
    the ref. Confirm against Riot's [receive]. *)
let start_child pid spec =
  let ref = Ref.make () in
  send pid (Start_child (self (), spec));
  (* Map each reply straight onto the result returned to the caller. *)
  let selector = function
    | Started_child child -> `select (Ok child)
    | Max_children -> `select (Error `Max_children)
    | _ -> `skip
  in
  receive ~selector ~ref ()