#include "mojo/core/ipcz_driver/mojo_message.h"
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <utility>
#include "base/compiler_specific.h"
#include "base/containers/span.h"
#include "base/numerics/safe_conversions.h"
#include "mojo/core/ipcz_api.h"
#include "mojo/core/ipcz_driver/data_pipe.h"
#include "mojo/core/scoped_ipcz_handle.h"
#include "third_party/abseil-cpp/absl/container/inlined_vector.h"
#include "third_party/ipcz/include/ipcz/ipcz.h"
namespace mojo::core::ipcz_driver {
namespace {
constexpr int kGrowthFactor = 2;
bool FixUpDataPipeHandles(std::vector<IpczHandle>& handles) {
absl::InlinedVector<DataPipe*, 2> data_pipes;
for (IpczHandle handle : handles) {
if (auto* data_pipe = DataPipe::FromBox(handle)) {
data_pipes.push_back(data_pipe);
}
}
if (handles.size() < data_pipes.size() * 2) {
return false;
}
const size_t first_data_pipe_portal = handles.size() - data_pipes.size();
for (size_t i = 0; i < data_pipes.size(); ++i) {
const IpczHandle handle = handles[first_data_pipe_portal + i];
if (!data_pipes[i]->AdoptPortal(ScopedIpczHandle(handle))) {
return false;
}
}
handles.resize(first_data_pipe_portal);
return true;
}
}
// Default constructor. Initial member values come from the declarations in
// mojo_message.h (not visible here).
MojoMessage::MojoMessage() = default;
// Builds a message whose payload is a copy of `data` and which takes ownership
// of `handles`. Handles still held by the message are closed by its destructor
// unless they have been consumed (GetData) or serialized away first.
MojoMessage::MojoMessage(std::vector<uint8_t> data,
                         std::vector<IpczHandle> handles)
    : handles_(std::move(handles)) {
  data_storage_.reset(new uint8_t[data.size()]);
  data_storage_size_ = data.size();
  std::ranges::copy(data, data_storage_.get());
  // `data_` is the view that data(), GetData() and SerializeForIpcz read.
  // Without pointing it at the copied bytes the payload would appear empty.
  data_ = UNSAFE_TODO(base::span(data_storage_.get(), data_storage_size_));
}
// Releases everything the message still owns. Every remaining valid ipcz
// handle is closed. If a context destructor was registered, it is then
// invoked on the stored context value.
MojoMessage::~MojoMessage() {
  for (const IpczHandle owned : handles_) {
    if (owned == IPCZ_INVALID_HANDLE) {
      continue;  // Slot was already transferred elsewhere.
    }
    GetIpczAPI().Close(owned, IPCZ_NO_FLAGS, nullptr);
  }

  if (!destructor_) {
    return;
  }
  destructor_(context_);
}
// Takes ownership of a received ipcz parcel and copies its payload bytes and
// attached handles into this message. Must only be called on a message that
// has neither payload storage nor a parcel yet (see the DCHECKs below).
void MojoMessage::SetParcel(ScopedIpczHandle parcel) {
DCHECK(!data_storage_);
DCHECK(!parcel_.is_valid());
parcel_ = std::move(parcel);
// NOTE(review): `data` and `num_bytes` are left uninitialized and are only
// written by a successful BeginGet(); in release builds (DCHECK off) a failing
// BeginGet() would leave them indeterminate below — confirm this cannot occur.
const volatile void* data;
size_t num_bytes;
size_t num_handles = 0;
IpczTransaction transaction;
// First attempt passes no handle buffer. RESOURCE_EXHAUSTED is handled below
// as "handles are attached": size handles_ from num_handles and retry.
IpczResult result =
GetIpczAPI().BeginGet(parcel_.get(), IPCZ_NO_FLAGS, nullptr, &data,
&num_bytes, nullptr, &num_handles, &transaction);
if (result == IPCZ_RESULT_RESOURCE_EXHAUSTED) {
handles_.resize(num_handles);
result = GetIpczAPI().BeginGet(parcel_.get(), IPCZ_NO_FLAGS, nullptr, &data,
&num_bytes, handles_.data(), &num_handles,
&transaction);
}
DCHECK_EQ(result, IPCZ_RESULT_OK);
// Copy the payload into owned storage before EndGet() closes the transaction
// (presumably `data` is only valid inside the BeginGet/EndGet bracket —
// confirm against ipcz docs). const_cast strips `volatile` for memcpy().
if (num_bytes > 0) {
data_storage_.reset(new uint8_t[num_bytes]);
UNSAFE_TODO(
memcpy(data_storage_.get(), const_cast<const void*>(data), num_bytes));
} else {
data_storage_.reset();
}
data_ = UNSAFE_TODO({data_storage_.get(), num_bytes});
data_storage_size_ = num_bytes;
result = GetIpczAPI().EndGet(parcel_.get(), transaction, IPCZ_NO_FLAGS,
nullptr, nullptr);
DCHECK_EQ(result, IPCZ_RESULT_OK);
// A malformed data-pipe layout drops the whole handle list.
// NOTE(review): clear() does not Close() the dropped handles, and some may
// already be owned by a DataPipe after a partial AdoptPortal() — confirm that
// leaking here is the intended trade-off versus risking a double close.
if (!FixUpDataPipeHandles(handles_)) {
handles_.clear();
}
}
// Pre-allocates payload storage for a message that has no payload yet. The
// new storage holds max(payload_buffer_size, kMinBufferSize) bytes and
// replaces any previous storage; the payload view is reset to empty.
//
// `buffer_size` (optional) receives the resulting storage capacity.
// Returns MOJO_RESULT_FAILED_PRECONDITION if a context is attached, the size
// is already committed, or payload bytes have already been appended.
MojoResult MojoMessage::ReserveCapacity(uint32_t payload_buffer_size,
                                        uint32_t* buffer_size) {
  DCHECK(!parcel_.is_valid());
  const bool can_reserve = !context_ && !size_committed_ && data_.empty();
  if (!can_reserve) {
    return MOJO_RESULT_FAILED_PRECONDITION;
  }

  const uint32_t capacity =
      std::max(payload_buffer_size, uint32_t{kMinBufferSize});
  data_storage_size_ = capacity;
  data_storage_ = DataPtr(new uint8_t[data_storage_size_]);
  data_ = UNSAFE_TODO(base::span(data_storage_.get(), 0u));

  if (buffer_size != nullptr) {
    *buffer_size = base::checked_cast<uint32_t>(data_storage_size_);
  }
  return MOJO_RESULT_OK;
}
// Extends the payload by `additional_num_bytes` (left uninitialized for the
// caller to fill) and appends `num_handles` handles from `handles`.
//
// Storage grows to at least max(new size, kMinBufferSize). When it must grow,
// it becomes the larger of that and kGrowthFactor times the old payload size.
// `buffer` (optional) receives the start of the whole payload buffer, and
// `buffer_size` (optional) its capacity. If `commit_size` is true, the size is
// frozen: further appends fail and GetData() becomes usable.
//
// Returns MOJO_RESULT_FAILED_PRECONDITION if a context is attached or the size
// was already committed.
MojoResult MojoMessage::AppendData(uint32_t additional_num_bytes,
                                   const MojoHandle* handles,
                                   uint32_t num_handles,
                                   void** buffer,
                                   uint32_t* buffer_size,
                                   bool commit_size) {
  DCHECK(!parcel_.is_valid());
  if (context_ || size_committed_) {
    return MOJO_RESULT_FAILED_PRECONDITION;
  }

  const size_t old_size = data_.size();
  const size_t grown_size = old_size + additional_num_bytes;
  const size_t needed_capacity = std::max(grown_size, kMinBufferSize);
  if (data_storage_size_ < needed_capacity) {
    // Carry over whatever the old buffer holds that the grown payload still
    // covers, bounded by the old buffer's size.
    const size_t bytes_to_keep = std::min(grown_size, data_storage_size_);
    const size_t new_capacity =
        std::max(old_size * kGrowthFactor, needed_capacity);
    DataPtr grown(new uint8_t[new_capacity]);
    std::ranges::copy(
        UNSAFE_TODO(base::span(data_storage_.get(), bytes_to_keep)),
        grown.get());
    data_storage_ = std::move(grown);
    data_storage_size_ = new_capacity;
  }
  data_ = UNSAFE_TODO(base::span(data_storage_.get(), grown_size));

  const auto incoming = UNSAFE_TODO(base::span(handles, num_handles));
  handles_.reserve(handles_.size() + num_handles);
  for (const MojoHandle incoming_handle : incoming) {
    handles_.push_back(incoming_handle);
  }

  if (buffer != nullptr) {
    *buffer = data_storage_.get();
  }
  if (buffer_size != nullptr) {
    *buffer_size = base::checked_cast<uint32_t>(data_storage_size_);
  }
  size_committed_ = commit_size;
  return MOJO_RESULT_OK;
}
// Exposes the payload and, optionally, the attached handles of a message that
// was received (has a parcel) or whose size has been committed.
//
// `buffer` / `num_bytes` (each optional) receive the payload pointer and size;
// the payload stays owned by this message. If `consume_handles` is true, the
// attached handles are copied into `handles` and ownership passes to the
// caller. In that case `*num_handles` is read as the capacity of `handles` and
// is always updated to the number of attached handles, 0 when there are none.
// If `consume_handles` is false, handle arguments are not touched.
//
// Returns MOJO_RESULT_FAILED_PRECONDITION if a context is still attached or
// the message is neither received nor size-committed, MOJO_RESULT_NOT_FOUND
// if handles were already consumed, MOJO_RESULT_RESOURCE_EXHAUSTED if there
// are handles but `handles` is null or too small, and MOJO_RESULT_OK
// otherwise.
//
// NOTE(review): declared as returning IpczResult but yields MojoResult codes;
// the declaration lives in mojo_message.h, so it is left unchanged here.
IpczResult MojoMessage::GetData(void** buffer,
                                uint32_t* num_bytes,
                                MojoHandle* handles,
                                uint32_t* num_handles,
                                bool consume_handles) {
  if (context_ || (!parcel_.is_valid() && !size_committed_)) {
    return MOJO_RESULT_FAILED_PRECONDITION;
  }
  if (consume_handles && handles_consumed_) {
    return MOJO_RESULT_NOT_FOUND;
  }

  if (buffer) {
    *buffer = data_.data();
  }
  if (num_bytes) {
    *num_bytes = base::checked_cast<uint32_t>(data_.size());
  }
  if (!consume_handles) {
    return MOJO_RESULT_OK;
  }

  const uint32_t capacity = num_handles ? *num_handles : 0;
  const uint32_t required_capacity =
      base::checked_cast<uint32_t>(handles_.size());
  // Report the true handle count even when it is zero. Otherwise a caller that
  // passed a nonzero capacity would read its own capacity back as the count.
  if (num_handles) {
    *num_handles = required_capacity;
  }
  if (required_capacity == 0) {
    return MOJO_RESULT_OK;
  }
  if (!handles || capacity < required_capacity) {
    return MOJO_RESULT_RESOURCE_EXHAUSTED;
  }

  std::ranges::copy(handles_, handles);
  handles_.clear();
  handles_consumed_ = true;
  return MOJO_RESULT_OK;
}
// For each boxed DataPipe among the current handles, takes the pipe's portal
// and appends it to the end of the handle list. FixUpDataPipeHandles() expects
// exactly this layout on the receiving side.
void MojoMessage::AttachDataPipePortals() {
  // Only the handles present on entry are scanned; appended portals are not
  // revisited.
  const size_t original_count = handles_.size();
  size_t index = 0;
  while (index < original_count) {
    auto* const pipe = ipcz_driver::DataPipe::FromBox(handles_[index]);
    if (pipe) {
      handles_.push_back(pipe->TakePortal().release());
    }
    ++index;
  }
}
// Attaches (or, with `context` == 0, clears) an unserialized message context
// together with its serializer and destructor callbacks.
//
// Returns MOJO_RESULT_ALREADY_EXISTS when replacing an existing nonzero
// context with another nonzero one. Returns MOJO_RESULT_FAILED_PRECONDITION
// once the message holds serialized state (a parcel, payload storage, or
// handles).
MojoResult MojoMessage::SetContext(uintptr_t context,
                                   MojoMessageContextSerializer serializer,
                                   MojoMessageContextDestructor destructor) {
  const bool would_replace_context = context_ != 0 && context != 0;
  if (would_replace_context) {
    return MOJO_RESULT_ALREADY_EXISTS;
  }

  const bool holds_serialized_state =
      parcel_.is_valid() || data_storage_ || !handles_.empty();
  if (holds_serialized_state) {
    return MOJO_RESULT_FAILED_PRECONDITION;
  }

  serializer_ = serializer;
  destructor_ = destructor;
  context_ = context;
  return MOJO_RESULT_OK;
}
// Converts an attached unserialized context into serialized payload/handles
// by invoking its serializer, then destroys the context.
//
// Returns MOJO_RESULT_FAILED_PRECONDITION if the message already holds
// serialized state, MOJO_RESULT_NOT_FOUND if no serializer is registered, and
// MOJO_RESULT_OK otherwise.
MojoResult MojoMessage::Serialize() {
  const bool holds_serialized_state =
      parcel_.is_valid() || data_storage_ || !handles_.empty();
  if (holds_serialized_state) {
    return MOJO_RESULT_FAILED_PRECONDITION;
  }
  if (serializer_ == nullptr) {
    return MOJO_RESULT_NOT_FOUND;
  }

  // Detach the context before running the serializer. AppendData() rejects
  // messages that still carry a context, and the serializer writes through it.
  const uintptr_t saved_context = context_;
  const MojoMessageContextSerializer serialize_fn = serializer_;
  const MojoMessageContextDestructor destroy_fn = destructor_;
  context_ = 0;
  serializer_ = nullptr;
  destructor_ = nullptr;

  serialize_fn(handle(), saved_context);
  if (destroy_fn) {
    destroy_fn(saved_context);
  }
  return MOJO_RESULT_OK;
}
// ipcz box serializer trampoline (installed by Box()). `object` is the
// MojoMessage pointer stored in the box; the remaining arguments are forwarded
// to SerializeForIpczImpl(). The flags and options arguments are ignored.
IpczResult MojoMessage::SerializeForIpcz(uintptr_t object,
                                         uint32_t,
                                         const void*,
                                         volatile void* data,
                                         size_t* num_bytes,
                                         IpczHandle* handles,
                                         size_t* num_handles) {
  auto* const message = reinterpret_cast<MojoMessage*>(object);
  return message->SerializeForIpczImpl(data, num_bytes, handles, num_handles);
}
void MojoMessage::DestroyForIpcz(uintptr_t object, uint32_t, const void*) {
delete reinterpret_cast<MojoMessage*>(object);
}
// Wraps `message` in an ipcz application-object box and returns the box
// handle. From then on the box owns the message: ipcz may call
// SerializeForIpcz to serialize it or DestroyForIpcz to delete it.
// CHECK-fails if ipcz refuses to create the box.
ScopedIpczHandle MojoMessage::Box(std::unique_ptr<MojoMessage> message) {
  IpczBoxContents contents = {.size = sizeof(contents)};
  contents.type = IPCZ_BOX_TYPE_APPLICATION_OBJECT;
  contents.object.application_object = message->handle();
  contents.serializer = &SerializeForIpcz;
  contents.destructor = &DestroyForIpcz;

  ScopedIpczHandle boxed;
  const IpczResult box_result =
      GetIpczAPI().Box(GetIpczNode(), &contents, IPCZ_NO_FLAGS, nullptr,
                       ScopedIpczHandle::Receiver(boxed));
  CHECK_EQ(IPCZ_RESULT_OK, box_result);

  // Ownership now lives in the box; give up ours without deleting.
  std::ignore = message.release();
  return boxed;
}
// Attempts to recover a MojoMessage carried as the sole content of `message`:
// no payload bytes and exactly one handle, which must be an ipcz box holding
// either an application object (presumably a MojoMessage boxed by Box()) or a
// subparcel.
//
// Returns nullptr and leaves `message` untouched if it does not have that
// shape, or if the box cannot be peeked or has another type. Once the box is
// actually unboxed, `message` no longer owns it (its handle slot is
// invalidated), even if a later step fails and nullptr is returned.
std::unique_ptr<MojoMessage> MojoMessage::UnwrapFrom(MojoMessage& message) {
  if (!message.data().empty() || message.handles().size() != 1) {
    return nullptr;
  }

  const IpczHandle box = message.handles()[0];
  IpczBoxContents contents = {.size = sizeof(contents)};
  // Peek first so a box of an unsupported type stays intact in `message`.
  const IpczResult peek_result =
      GetIpczAPI().Unbox(box, IPCZ_UNBOX_PEEK, nullptr, &contents);
  if (peek_result != IPCZ_RESULT_OK) {
    return nullptr;
  }
  if (contents.type != IPCZ_BOX_TYPE_APPLICATION_OBJECT &&
      contents.type != IPCZ_BOX_TYPE_SUBPARCEL) {
    return nullptr;
  }

  const IpczResult unbox_result =
      GetIpczAPI().Unbox(box, IPCZ_NO_FLAGS, nullptr, &contents);
  DCHECK_EQ(IPCZ_RESULT_OK, unbox_result);
  // The box has been consumed; stop `message`'s destructor from closing it.
  message.handles()[0] = IPCZ_INVALID_HANDLE;

  if (contents.type == IPCZ_BOX_TYPE_APPLICATION_OBJECT) {
    return base::WrapUnique(
        reinterpret_cast<MojoMessage*>(contents.object.application_object));
  }

  DCHECK_EQ(contents.type, IPCZ_BOX_TYPE_SUBPARCEL);
  ScopedIpczHandle subparcel(contents.object.subparcel);

  // Size query with no output buffers. RESOURCE_EXHAUSTED is the expected
  // answer and carries the required byte and handle counts.
  size_t num_bytes = 0;
  size_t num_handles = 0;
  const IpczResult get_query_result =
      GetIpczAPI().Get(subparcel.get(), IPCZ_NO_FLAGS, nullptr, nullptr,
                       &num_bytes, nullptr, &num_handles, nullptr);
  // NOTE(review): any other result is treated as failure, including OK for a
  // completely empty subparcel — confirm that is intended.
  if (get_query_result != IPCZ_RESULT_RESOURCE_EXHAUSTED) {
    return nullptr;
  }

  // Create the message with committed payload space, then read the subparcel
  // directly into its storage and handle list.
  void* buffer = nullptr;
  auto new_message = std::make_unique<ipcz_driver::MojoMessage>();
  const MojoResult append_result = new_message->AppendData(
      base::checked_cast<uint32_t>(num_bytes), nullptr, 0, &buffer, nullptr,
      /*commit_size=*/true);
  if (append_result != MOJO_RESULT_OK) {
    return nullptr;
  }

  new_message->handles().resize(num_handles);
  const IpczResult get_result = GetIpczAPI().Get(
      subparcel.get(), IPCZ_NO_FLAGS, nullptr, buffer, &num_bytes,
      new_message->handles().data(), &num_handles, nullptr);
  if (get_result != IPCZ_RESULT_OK) {
    return nullptr;
  }
  return new_message;
}
// Implements the ipcz box serializer for this message. The required byte and
// handle counts are always written back through `num_bytes` / `num_handles`
// (when non-null). Contents are only copied out when the capacities passed in
// through those same pointers are large enough.
//
// Returns IPCZ_RESULT_FAILED_PRECONDITION if an attached context cannot be
// serialized, IPCZ_RESULT_RESOURCE_EXHAUSTED if capacities are too small,
// IPCZ_RESULT_INVALID_ARGUMENT if a nonzero capacity comes with a null buffer,
// and IPCZ_RESULT_OK once the payload is copied and handle ownership has moved
// into `handles`.
IpczResult MojoMessage::SerializeForIpczImpl(volatile void* data,
                                             size_t* num_bytes,
                                             IpczHandle* handles,
                                             size_t* num_handles) {
  // Serialize() runs any attached context serializer. A FAILED_PRECONDITION
  // from it means the message already holds serialized state, which is
  // acceptable here. Any other failure (e.g. no serializer) is fatal.
  const MojoResult result = Serialize();
  if (result != MOJO_RESULT_OK && result != MOJO_RESULT_FAILED_PRECONDITION) {
    return IPCZ_RESULT_FAILED_PRECONDITION;
  }

  const size_t required_byte_capacity = data_.size();
  const size_t required_handle_capacity = handles_.size();
  const size_t byte_capacity = num_bytes ? *num_bytes : 0;
  const size_t handle_capacity = num_handles ? *num_handles : 0;
  if (num_bytes) {
    *num_bytes = required_byte_capacity;
  }
  if (num_handles) {
    *num_handles = required_handle_capacity;
  }
  if (byte_capacity < required_byte_capacity ||
      handle_capacity < required_handle_capacity) {
    return IPCZ_RESULT_RESOURCE_EXHAUSTED;
  }
  if ((byte_capacity && !data) || (handle_capacity && !handles)) {
    return IPCZ_RESULT_INVALID_ARGUMENT;
  }

  // memcpy() requires valid pointers even for a zero-length copy, and for an
  // empty payload `data` and data_.data() may both be null. For a non-empty
  // payload the checks above guarantee `data` is non-null. const_cast removes
  // the volatile qualifier so memcpy() can write to `data`.
  if (!data_.empty()) {
    UNSAFE_TODO(memcpy(const_cast<void*>(data), data_.data(), data_.size()));
  }
  // Hand handle ownership to the caller; invalidating our slots keeps the
  // destructor from closing them.
  for (size_t i = 0; i < handles_.size(); ++i) {
    UNSAFE_TODO(handles[i]) = std::exchange(handles_[i], IPCZ_INVALID_HANDLE);
  }
  return IPCZ_RESULT_OK;
}
}