// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "mojo/public/cpp/bindings/receiver_set.h"

#include <string>
#include <string_view>
#include <utility>

#include "base/check.h"
#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/memory/raw_ptr_exclusion.h"
#include "base/memory/weak_ptr.h"
#include "mojo/public/cpp/bindings/message.h"

namespace mojo {

// MessageFilter that relays each filter notification to the Entry it was
// created for and then, when one was supplied, to a nested filter.
class ReceiverSetState::Entry::DispatchFilter : public MessageFilter {
 public:
  // |entry| is stored by reference and is not owned. |nested_filter| may be
  // null, in which case only |entry| is notified.
  explicit DispatchFilter(Entry& entry,
                          std::unique_ptr<MessageFilter> nested_filter)
      : entry_(entry), nested_filter_(std::move(nested_filter)) {}
  DispatchFilter(const DispatchFilter&) = delete;
  DispatchFilter& operator=(const DispatchFilter&) = delete;
  ~DispatchFilter() override = default;

 private:
  // MessageFilter:
  bool WillDispatch(Message* message) override {
    // The entry is always notified first. The result is the nested filter's
    // verdict, or true when there is no nested filter.
    entry_.WillDispatch();
    return !nested_filter_ || nested_filter_->WillDispatch(message);
  }

  void DidDispatchOrReject(Message* message, bool accepted) override {
    // The entry is notified before the nested filter.
    entry_.DidDispatchOrReject();
    if (!nested_filter_)
      return;
    nested_filter_->DidDispatchOrReject(message, accepted);
  }

  // RAW_PTR_EXCLUSION: Binary size increase.
  RAW_PTR_EXCLUSION Entry& entry_;
  std::unique_ptr<MessageFilter> nested_filter_;
};

// Stores |state| and |id|, takes ownership of |receiver|, and installs two
// hooks on it: a DispatchFilter wrapping |filter| (which may be null), and a
// disconnect callback bound to this entry's OnDisconnect().
ReceiverSetState::Entry::Entry(ReceiverSetState& state,
                               ReceiverId id,
                               std::unique_ptr<ReceiverState> receiver,
                               std::unique_ptr<MessageFilter> filter)
    : state_(state), id_(id), receiver_(std::move(receiver)) {
  auto dispatch_filter =
      std::make_unique<DispatchFilter>(*this, std::move(filter));

  // NOTE(review): the unretained |this| is presumably safe because the
  // callback is handed to |receiver_|, which this entry owns - confirm against
  // ReceiverState::InstallDispatchHooks().
  auto disconnect_callback = base::BindRepeating(
      &ReceiverSetState::Entry::OnDisconnect, base::Unretained(this));

  receiver_->InstallDispatchHooks(std::move(dispatch_filter),
                                  std::move(disconnect_callback));
}

// Defaulted out of line; performs nothing beyond destroying the members.
ReceiverSetState::Entry::~Entry() = default;

// Publishes this entry's receiver context and id as the owning state's
// current dispatch context.
void ReceiverSetState::Entry::WillDispatch() {
  void* const context = receiver_->GetContext();
  state_.SetDispatchContext(context, id_);
}

void ReceiverSetState::Entry::DidDispatchOrReject() {
  state_.SetDispatchContext(nullptr, 0);
}

void ReceiverSetState::Entry::OnDisconnect(uint32_t custom_reason_code,
                                           const std::string& description) {
  WillDispatch();
  state_.OnDisconnect(id_, custom_reason_code, description);
}

// Constructs |entries_| from a PassKey instance; no other member is
// explicitly initialized here.
ReceiverSetState::ReceiverSetState() : entries_(PassKey()) {}

// Defaulted out of line; performs nothing beyond destroying the members.
ReceiverSetState::~ReceiverSetState() = default;

// Installs |handler| as the reason-less disconnect handler and clears any
// handler installed via set_disconnect_with_reason_handler(), so at most one
// of the two handlers is set after this call. OnDisconnect() runs whichever
// one is set.
void ReceiverSetState::set_disconnect_handler(base::RepeatingClosure handler) {
  disconnect_handler_ = std::move(handler);
  disconnect_with_reason_handler_.Reset();
}

// Installs |handler| as the disconnect handler that receives the custom
// reason code and description, and clears any handler installed via
// set_disconnect_handler(), so at most one of the two handlers is set after
// this call. OnDisconnect() runs whichever one is set.
void ReceiverSetState::set_disconnect_with_reason_handler(
    RepeatingConnectionErrorWithReasonCallback handler) {
  disconnect_with_reason_handler_ = std::move(handler);
  disconnect_handler_.Reset();
}

// Returns a one-shot callback that, when run with an error string, first
// forwards it to the callback obtained from mojo::GetBadMessageCallback() and
// then removes the receiver that was current when this method was called.
// The removal is skipped if this ReceiverSetState has been destroyed by then.
// DCHECKs that a dispatch context is currently set.
ReportBadMessageCallback ReceiverSetState::GetBadMessageCallback() {
  DCHECK(current_context_);

  ReportBadMessageCallback report_bad_message = mojo::GetBadMessageCallback();
  const ReceiverId offending_receiver = current_receiver();

  return base::BindOnce(
      [](ReportBadMessageCallback error_callback,
         base::WeakPtr<ReceiverSetState> receiver_set, ReceiverId receiver_id,
         std::string_view error) {
        std::move(error_callback).Run(error);
        // The set may have been destroyed since the callback was created.
        if (!receiver_set)
          return;
        receiver_set->Remove(receiver_id);
      },
      std::move(report_bad_message), weak_ptr_factory_.GetWeakPtr(),
      offending_receiver);
}

// Wraps |receiver| (and the optional |filter|) in a new Entry, stores it in
// |entries_| under a freshly allocated id, and returns that id.
//
// Ids come from pre-incrementing |next_receiver_id_|. An id of 0 is treated
// as counter overflow and is fatal; 0 is also the id that
// Entry::DidDispatchOrReject() uses when clearing the dispatch context.
ReceiverId ReceiverSetState::Add(std::unique_ptr<ReceiverState> receiver,
                                 std::unique_ptr<MessageFilter> filter) {
  const ReceiverId id = ++next_receiver_id_;
  CHECK_NE(0u, id) << "ReceiverId overflow";

  auto new_entry = std::make_unique<Entry>(*this, id, std::move(receiver),
                                           std::move(filter));
  entries_.insert({id, std::move(new_entry)});
  return id;
}

// Erases the entry registered under |id|, destroying it. Returns true if such
// an entry existed and false otherwise.
bool ReceiverSetState::Remove(ReceiverId id) {
  auto entry_it = entries_.find(id);
  const bool found = entry_it != entries_.end();
  if (found)
    entries_.erase(entry_it);
  return found;
}

// Like Remove(), but first resets the entry's receiver with
// |custom_reason_code| and |description| before erasing the entry. Returns
// false, without side effects, when no entry is registered under |id|.
bool ReceiverSetState::RemoveWithReason(ReceiverId id,
                                        uint32_t custom_reason_code,
                                        const std::string& description) {
  auto entry_it = entries_.find(id);
  if (entry_it != entries_.end()) {
    entry_it->second->receiver().ResetWithReason(custom_reason_code,
                                                 description);
    entries_.erase(entry_it);
    return true;
  }
  return false;
}

void ReceiverSetState::FlushForTesting() {
  // We avoid flushing while iterating over |entries_| because this set may be
  // mutated during individual flush operations.  Instead, snapshot the
  // ReceiverIds first, then iterate over them. This is less efficient, but it's
  // only a testing API. This also allows for correct behavior in reentrant
  // calls to FlushForTesting().
  std::vector<ReceiverId> ids;
  for (const auto& entry : entries_)
    ids.push_back(entry.first);

  auto weak_self = weak_ptr_factory_.GetWeakPtr();
  for (const auto& id : ids) {
    if (!weak_self)
      return;
    auto it = entries_.find(id);
    if (it != entries_.end())
      it->second->receiver().FlushForTesting();
  }
}

// Records |context| and |receiver_id| as the current dispatch context.
// Entry::WillDispatch() passes an entry's context and id;
// Entry::DidDispatchOrReject() passes (nullptr, 0) to clear them.
void ReceiverSetState::SetDispatchContext(void* context,
                                          ReceiverId receiver_id) {
  current_receiver_ = receiver_id;
  current_context_ = context;
}

// Handles disconnection of the receiver registered under |id|: removes its
// entry from the set, runs whichever disconnect handler is installed (the
// reason-less one takes precedence), and then clears the dispatch context.
//
// |id| must refer to an entry currently in the set; anything else is fatal.
// |custom_reason_code| and |description| are forwarded only to the
// with-reason handler.
void ReceiverSetState::OnDisconnect(ReceiverId id,
                                    uint32_t custom_reason_code,
                                    const std::string& description) {
  auto it = entries_.find(id);
  CHECK(it != entries_.end());

  // We keep the Entry alive throughout error dispatch: it is already gone
  // from |entries_| when the handler runs, but is only destroyed when this
  // function returns.
  std::unique_ptr<Entry> entry = std::move(it->second);
  entries_.erase(it);

  // The handler is arbitrary client code and may destroy |this|.
  auto weak_self = weak_ptr_factory_.GetWeakPtr();

  if (disconnect_handler_)
    disconnect_handler_.Run();
  else if (disconnect_with_reason_handler_)
    disconnect_with_reason_handler_.Run(custom_reason_code, description);

  // Entry::OnDisconnect() set the dispatch context to the disconnected
  // receiver before calling in here. That receiver is destroyed together
  // with |entry| when this function returns, so clear the context now rather
  // than leave |current_context_| and |current_receiver_| referring to a
  // destroyed receiver until the next dispatch.
  if (weak_self)
    SetDispatchContext(nullptr, 0);
}

}  // namespace mojo