#ifdef UNSAFE_BUFFERS_BUILD
#pragma allow_unsafe_buffers
#endif
#include "mojo/core/channel_linux.h"
#include <fcntl.h>
#include <linux/futex.h>
#include <linux/memfd.h>
#include <sys/eventfd.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/syscall.h>
#include <sys/utsname.h>
#include <unistd.h>

#include <algorithm>
#include <atomic>
#include <cstring>
#include <limits>
#include <memory>
#include <utility>
#include "base/bits.h"
#include "base/files/scoped_file.h"
#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/location.h"
#include "base/logging.h"
#include "base/memory/page_size.h"
#include "base/memory/ptr_util.h"
#include "base/memory/raw_ptr_exclusion.h"
#include "base/memory/ref_counted.h"
#include "base/memory/shared_memory_security_policy.h"
#include "base/message_loop/io_watcher.h"
#include "base/message_loop/message_pump_for_io.h"
#include "base/metrics/histogram_macros.h"
#include "base/posix/eintr_wrapper.h"
#include "base/system/sys_info.h"
#include "base/task/single_thread_task_runner.h"
#include "base/task/task_runner.h"
#include "base/time/time.h"
#include "build/build_config.h"
#include "mojo/buildflags.h"
#include "mojo/core/embedder/features.h"
namespace mojo::core {
// Abstract wake-up mechanism used by the shared-memory channel: the writing
// side calls Notify() after publishing bytes, and the reading side is told
// about it through DataAvailable(), which runs the callback supplied at
// construction.
class DataAvailableNotifier {
 public:
  DataAvailableNotifier() = default;

  // |callback| is what DataAvailable() runs; a notifier built with the
  // default constructor has no callback and must not call DataAvailable().
  explicit DataAvailableNotifier(base::RepeatingClosure callback)
      : callback_(std::move(callback)) {}

  virtual ~DataAvailableNotifier() = default;

  // Signals the other side. Implementations report success as a bool.
  virtual bool Notify() = 0;
  // Consumes any pending signal. Implementations report success as a bool.
  virtual bool Clear() = 0;
  // Whether the underlying notification resource is usable.
  virtual bool is_valid() const = 0;

 protected:
  // Invoked by subclasses when a signal arrives; forwards to |callback_|.
  void DataAvailable() {
    DCHECK(callback_);
    callback_.Run();
  }

  base::RepeatingClosure callback_;
};
namespace {
// Seals every shared-memory memfd must carry: its size can neither shrink nor
// grow, and the seal set itself is frozen. F_SEAL_WRITE is not part of the
// set, so the region stays writable.
constexpr int kMemFDSeals = F_SEAL_GROW | F_SEAL_SHRINK | F_SEAL_SEAL;

// Process-wide configuration written by ChannelLinux::SetSharedMemParameters().
std::atomic<bool> g_params_set(false);
std::atomic<bool> g_use_shared_mem(false);
// Page count each new ChannelLinux uses for its outgoing shared buffer.
std::atomic<uint32_t> g_shared_mem_pages(4);
// Payload of an UPGRADE_OFFER control message. It is sent as raw bytes
// (see OfferSharedMemUpgradeInternal), so the member types and their order
// form the wire format.
struct UpgradeOfferMessage {
  // Notification mechanism identifiers ("versions").
  static constexpr int kEventFdNotifier = 1;
  static constexpr int kDefaultVersion = kEventFdNotifier;
  static constexpr int kDefaultPages = 4;

  // Only the eventfd-based notifier is currently understood.
  static bool IsValidVersion(int version) {
    switch (version) {
      case kEventFdNotifier:
        return true;
      default:
        return false;
    }
  }

  int version = kDefaultVersion;
  int num_pages = kDefaultPages;
};
// Rounds |size| up to the next multiple of the native pointer size.
constexpr size_t RoundUpToWordBoundary(size_t size) {
  constexpr size_t kWordSize = sizeof(void*);
  return base::bits::AlignUp(size, kWordSize);
}
// Creates a close-on-exec memfd of exactly |size| bytes (a non-zero multiple
// of the page size, CHECKed) and applies kMemFDSeals to it. On any failure
// the cause is logged and an invalid ScopedFD is returned.
base::ScopedFD CreateSealedMemFD(size_t size) {
  CHECK_GT(size, 0u);
  CHECK_EQ(size % base::GetPageSize(), 0u);

  // Invoked through syscall() directly rather than a libc wrapper.
  const int raw_fd =
      syscall(__NR_memfd_create, "mojo_channel_linux",
              MFD_CLOEXEC | MFD_ALLOW_SEALING);
  base::ScopedFD memfd(raw_fd);
  if (!memfd.is_valid()) {
    PLOG(ERROR) << "Unable to create memfd for shared memory channel";
    return base::ScopedFD();
  }

  const bool resized = ftruncate(memfd.get(), size) >= 0;
  if (!resized) {
    PLOG(ERROR) << "Unable to truncate memfd for shared memory channel";
    return base::ScopedFD();
  }

  const bool sealed = fcntl(memfd.get(), F_ADD_SEALS, kMemFDSeals) >= 0;
  if (!sealed) {
    PLOG(ERROR) << "Unable to seal memfd for shared memory channel";
    return base::ScopedFD();
  }

  return memfd;
}
// Returns true only if |fd| carries exactly kMemFDSeals (no more, no fewer).
// Fails (with a log) if the seals cannot be queried.
bool ValidateFDIsProperlySealedMemFD(const base::ScopedFD& fd) {
  const int seals = fcntl(fd.get(), F_GET_SEALS);
  if (seals < 0) {
    PLOG(ERROR) << "Unable to get seals on memfd for shared memory channel";
    return false;
  }
  return seals == kMemFDSeals;
}
// DataAvailableNotifier backed by a Linux eventfd.
// - Writing side: CreateWriteNotifier() creates the eventfd; in this file a
//   dup of it (take_dup()) is sent to the peer in the UPGRADE_OFFER.
// - Reading side: CreateReadNotifier() wraps the received fd and registers a
//   persistent read watch, so OnFdReadable() -> DataAvailable() runs the
//   callback whenever the eventfd becomes readable.
class EventFDNotifier : public DataAvailableNotifier,
public base::IOWatcher::FdWatcher {
public:
// NOTE(review): a defaulted move transfers |watch_|, but that watch was
// registered against the moved-from object (*this). This constructor appears
// unused in this file; confirm before relying on it.
EventFDNotifier(EventFDNotifier&& efd) = default;
EventFDNotifier(const EventFDNotifier&) = delete;
EventFDNotifier& operator=(const EventFDNotifier&) = delete;
~EventFDNotifier() override { reset(); }
// Close-on-exec and non-blocking. No EFD_SEMAPHORE, so a read() returns and
// zeroes the whole counter.
static constexpr int kEfdFlags = EFD_CLOEXEC | EFD_NONBLOCK;
// Creates a fresh eventfd with no watch or callback. Returns nullptr (after
// logging) if eventfd2 fails.
static std::unique_ptr<EventFDNotifier> CreateWriteNotifier() {
int fd = syscall(__NR_eventfd2, 0, kEfdFlags);
if (fd < 0) {
PLOG(ERROR) << "Unable to create an eventfd";
return nullptr;
}
return WrapFD(base::ScopedFD(fd));
}
// Wraps a received eventfd |efd| and starts watching it for readability;
// |cb| runs each time it becomes readable. Must be called on
// |io_task_runner| (DCHECKed here and in WaitForEventFDOnIOThread()).
static std::unique_ptr<EventFDNotifier> CreateReadNotifier(
base::ScopedFD efd,
base::RepeatingClosure cb,
scoped_refptr<base::SingleThreadTaskRunner> io_task_runner) {
DCHECK(io_task_runner->RunsTasksInCurrentSequence());
DCHECK(cb);
return WrapFDWithCallback(std::move(efd), std::move(cb), io_task_runner);
}
// Probes eventfd2 with deliberately invalid flags. EINVAL means the syscall
// exists; ENOSYS means it does not. EPERM is tolerated by the PCHECK
// (presumably a sandbox/seccomp denial) but reported as unsupported. Any
// other outcome, including success, is fatal.
static bool KernelSupported() {
int ret = syscall(__NR_eventfd2, 0, ~0);
PCHECK(ret < 0 && (errno == EINVAL || errno == ENOSYS || errno == EPERM));
return (ret < 0 && errno == EINVAL);
}
// Reads (and thereby resets) the eventfd counter. Returns true if a full
// 8-byte value was read. An empty counter (EWOULDBLOCK on this non-blocking
// fd) is expected and is not logged.
bool Clear() override {
uint64_t value = 0;
ssize_t res = HANDLE_EINTR(
read(fd_.get(), reinterpret_cast<void*>(&value), sizeof(value)));
if (res < static_cast<int64_t>(sizeof(value))) {
PLOG_IF(ERROR, errno != EWOULDBLOCK) << "eventfd read error";
}
return res == sizeof(value);
}
// Adds 1 to the eventfd counter. Returns true if the full 8-byte write
// succeeded.
bool Notify() override {
uint64_t value = 1;
ssize_t res = HANDLE_EINTR(write(fd_.get(), &value, sizeof(value)));
return res == sizeof(value);
}
bool is_valid() const override { return fd_.is_valid(); }
// base::IOWatcher::FdWatcher: the watched eventfd became readable.
void OnFdReadable(int fd) override {
DCHECK(fd == fd_.get());
DataAvailable();
}
// Only read watches are registered; nothing to do.
void OnFdWritable(int fd) override {}
// Releases ownership of the eventfd to the caller.
base::ScopedFD take() { return std::move(fd_); }
// Returns a dup() of the eventfd; the result is invalid if dup() failed.
base::ScopedFD take_dup() {
return base::ScopedFD(HANDLE_EINTR(dup(fd_.get())));
}
// Stops watching, then closes the fd, so the watch is gone before its
// descriptor is.
void reset() {
watch_.reset();
fd_.reset();
}
int fd() { return fd_.get(); }
private:
// Write-side constructor: owns |fd|, no callback, no watch.
explicit EventFDNotifier(base::ScopedFD fd) : fd_(std::move(fd)) {}
// Read-side constructor: owns |fd| and immediately registers the read watch.
explicit EventFDNotifier(
base::ScopedFD fd,
base::RepeatingClosure cb,
scoped_refptr<base::SingleThreadTaskRunner> io_task_runner)
: DataAvailableNotifier(std::move(cb)),
fd_(std::move(fd)),
io_task_runner_(io_task_runner) {
WaitForEventFDOnIOThread();
}
static std::unique_ptr<EventFDNotifier> WrapFD(base::ScopedFD fd) {
return base::WrapUnique<EventFDNotifier>(
new EventFDNotifier(std::move(fd)));
}
static std::unique_ptr<EventFDNotifier> WrapFDWithCallback(
base::ScopedFD fd,
base::RepeatingClosure cb,
scoped_refptr<base::SingleThreadTaskRunner> io_task_runner) {
return base::WrapUnique<EventFDNotifier>(
new EventFDNotifier(std::move(fd), std::move(cb), io_task_runner));
}
// Registers a persistent read watch on |fd_| with this object as the
// watcher; the registration lives as long as |watch_|.
void WaitForEventFDOnIOThread() {
DCHECK(io_task_runner_->RunsTasksInCurrentSequence());
watch_ = base::IOWatcher::Get()->WatchFileDescriptor(
fd_.get(), base::IOWatcher::FdWatchDuration::kPersistent,
base::IOWatcher::FdWatchMode::kRead, *this);
}
base::ScopedFD fd_;
std::unique_ptr<base::IOWatcher::FdWatch> watch_;
scoped_refptr<base::SingleThreadTaskRunner> io_task_runner_;
};
}
// A byte ring buffer living in a MAP_SHARED mapping of a sealed memfd that is
// shared with the remote end of the channel.
//
// Layout: the first kReservedSpace bytes hold a ControlStructure (try-lock
// flags plus read/write offsets); the rest is the data region, and all offsets
// are relative to it. One byte of the data region is always kept free
// (TryWrite rejects a write that would fill it), so read_pos == write_pos
// unambiguously means "empty".
//
// The remote end can overwrite the control area at will, so every operation
// validates the offsets it reads and reports kControlCorruption when they are
// out of range or change while the lock is held.
class ChannelLinux::SharedBuffer {
 public:
  // Transfers ownership of the mapping and leaves |other| invalid, so exactly
  // one destructor unmaps the region. (A defaulted move would copy the raw
  // pointer and both objects would munmap it.)
  SharedBuffer(SharedBuffer&& other) noexcept
      : base_ptr_(std::exchange(other.base_ptr_, nullptr)),
        len_(std::exchange(other.len_, 0)) {}
  SharedBuffer(const SharedBuffer&) = delete;
  SharedBuffer& operator=(const SharedBuffer&) = delete;
  ~SharedBuffer() { reset(); }

  enum class Error { kSuccess = 0, kGeneralError = 1, kControlCorruption = 2 };

  // Maps |size| bytes of |memfd| read/write and shared.
  //
  // Returns nullptr if:
  // - |memfd| is invalid;
  // - |size| leaves no room past the control header;
  // - the mapping reservation cannot be acquired;
  // - mmap() fails.
  //
  // The fd is not retained; the mapping stays valid after the caller closes
  // it.
  static std::unique_ptr<SharedBuffer> Create(const base::ScopedFD& memfd,
                                              size_t size) {
    if (!memfd.is_valid()) {
      return nullptr;
    }
    // A region no larger than the control header has no data region, and
    // usable_len() would underflow.
    if (size <= kReservedSpace) {
      LOG(ERROR) << "Unable to create shared buffer: size " << size
                 << " leaves no usable space";
      return nullptr;
    }
    if (!base::SharedMemorySecurityPolicy::AcquireReservationForMapping(size)) {
      LOG(ERROR)
          << "Unable to create shared buffer: unable to acquire reservation";
      return nullptr;
    }
    uint8_t* ptr = reinterpret_cast<uint8_t*>(mmap(
        nullptr, size, PROT_READ | PROT_WRITE, MAP_SHARED, memfd.get(), 0));
    if (ptr == MAP_FAILED) {
      PLOG(ERROR) << "Unable to map shared memory";
      base::SharedMemorySecurityPolicy::ReleaseReservationForMapping(size);
      return nullptr;
    }
    return base::WrapUnique<SharedBuffer>(new SharedBuffer(ptr, size));
  }

  // Start of the data region (just past the control header).
  uint8_t* usable_region_ptr() { return base_ptr_ + kReservedSpace; }
  // Size of the data region in bytes.
  size_t usable_len() const { return len_ - kReservedSpace; }
  bool is_valid() const { return base_ptr_ != nullptr && len_ > 0; }

  // Unmaps the region and releases its reservation. If munmap() fails, the
  // buffer is left as-is (still valid).
  void reset() {
    if (is_valid()) {
      if (munmap(base_ptr_, len_) < 0) {
        PLOG(ERROR) << "Unable to unmap shared buffer";
        return;
      }
      base::SharedMemorySecurityPolicy::ReleaseReservationForMapping(len_);
      base_ptr_ = nullptr;
      len_ = 0;
    }
  }

  // Placement-constructs a ControlStructure (flags clear, offsets zero) at the
  // start of the region. In this file only the offering side calls this,
  // before sending the memfd.
  void Initialize() { new (static_cast<void*>(base_ptr_)) ControlStructure; }

  // Appends |len| bytes from |data| to the ring while holding the write lock.
  //
  // Returns kGeneralError if the data does not fit or the lock is already
  // held; Write() then falls back to the regular transport. Returns
  // kControlCorruption if the shared offsets are invalid or moved
  // unexpectedly.
  Error TryWrite(const void* data, size_t len) {
    DCHECK(data);
    DCHECK(len);
    if (len > usable_len()) {
      return Error::kGeneralError;
    }
    if (!TryLockForWriting()) {
      return Error::kGeneralError;
    }
    uint32_t cur_read_pos = read_pos().load();
    uint32_t cur_write_pos = write_pos().load();
    if (!ValidateReadWritePositions(cur_read_pos, cur_write_pos)) {
      UnlockForWriting();
      return Error::kControlCorruption;
    }
    uint32_t space_available =
        usable_len() - NumBytesInUse(cur_read_pos, cur_write_pos);
    // Strictly greater-than: keep one byte free so "full" != "empty".
    if (space_available <= len) {
      UnlockForWriting();
      return Error::kGeneralError;
    }
    if ((usable_len() - cur_write_pos) > len) {
      memcpy(usable_region_ptr() + cur_write_pos, data, len);
    } else {
      // Wrap around the end of the data region.
      size_t copy1_len = usable_len() - cur_write_pos;
      memcpy(usable_region_ptr() + cur_write_pos, data, copy1_len);
      memcpy(usable_region_ptr(),
             reinterpret_cast<const uint8_t*>(data) + copy1_len,
             len - copy1_len);
    }
    // Publish the new write position only after the data is in place. We
    // hold the write lock, so any other change means the control area was
    // tampered with.
    if (write_pos().exchange((cur_write_pos + len) % usable_len()) !=
        cur_write_pos) {
      UnlockForWriting();
      return Error::kControlCorruption;
    }
    UnlockForWriting();
    return Error::kSuccess;
  }

  // Copies up to |len| available bytes into |data| and advances the read
  // position; |*bytes_read| receives the count (0 when the ring is empty).
  // The caller must hold the read lock (TryLockForReading()).
  Error TryReadLocked(void* data, uint32_t len, uint32_t* bytes_read) {
    uint32_t cur_read_pos = read_pos().load();
    uint32_t cur_write_pos = write_pos().load();
    if (!ValidateReadWritePositions(cur_read_pos, cur_write_pos)) {
      return Error::kControlCorruption;
    }
    uint32_t bytes_available_to_read =
        NumBytesInUse(cur_read_pos, cur_write_pos);
    bytes_available_to_read = std::min(bytes_available_to_read, len);
    if (bytes_available_to_read == 0) {
      *bytes_read = 0;
      return Error::kSuccess;
    }
    if (cur_read_pos < cur_write_pos) {
      memcpy(data, usable_region_ptr() + cur_read_pos, bytes_available_to_read);
    } else {
      // The readable bytes may wrap around the end of the data region.
      uint32_t bytes_from_read_to_end = usable_len() - cur_read_pos;
      bytes_from_read_to_end =
          std::min(bytes_from_read_to_end, bytes_available_to_read);
      memcpy(data, usable_region_ptr() + cur_read_pos, bytes_from_read_to_end);
      if (bytes_from_read_to_end < bytes_available_to_read) {
        memcpy(reinterpret_cast<uint8_t*>(data) + bytes_from_read_to_end,
               usable_region_ptr(),
               bytes_available_to_read - bytes_from_read_to_end);
      }
    }
    uint32_t new_read_pos =
        (cur_read_pos + bytes_available_to_read) % usable_len();
    // We hold the read lock, so the read position must not have moved.
    if (read_pos().exchange(new_read_pos) != cur_read_pos) {
      *bytes_read = 0;
      return Error::kControlCorruption;
    }
    *bytes_read = bytes_available_to_read;
    return Error::kSuccess;
  }

  // Non-blocking read lock; returns false if it is already held.
  bool TryLockForReading() {
    return !read_flag().test_and_set(std::memory_order_acquire);
  }
  void UnlockForReading() { read_flag().clear(std::memory_order_release); }

 private:
  // Header placed at the start of the shared mapping.
  struct ControlStructure {
    std::atomic_flag write_flag{false};
    std::atomic_uint32_t write_pos{0};
    std::atomic_flag read_flag{false};
    std::atomic_uint32_t read_pos{0};
    // Not read or written elsewhere in this file; reserved 4-byte-aligned
    // slot.
    alignas(4) volatile uint32_t futex = 0;
  };

  // Both offsets must lie inside the data region.
  bool ValidateReadWritePositions(uint32_t read_pos, uint32_t write_pos) {
    if (write_pos >= usable_len()) {
      LOG(ERROR) << "Write position of shared buffer is currently beyond the "
                    "usable length";
      return false;
    }
    if (read_pos >= usable_len()) {
      LOG(ERROR) << "Read position of shared buffer is currently beyond the "
                    "usable length";
      return false;
    }
    return true;
  }

  // Number of unread bytes between |read_pos| and |write_pos|, accounting for
  // wrap-around.
  uint32_t NumBytesInUse(uint32_t read_pos, uint32_t write_pos) {
    uint32_t bytes_in_use = 0;
    if (read_pos <= write_pos) {
      bytes_in_use = write_pos - read_pos;
    } else {
      bytes_in_use = write_pos + (usable_len() - read_pos);
    }
    return bytes_in_use;
  }

  // Non-blocking write lock; returns false if it is already held.
  bool TryLockForWriting() {
    return !write_flag().test_and_set(std::memory_order_acquire);
  }
  void UnlockForWriting() { write_flag().clear(std::memory_order_release); }

  constexpr static size_t kReservedSpace =
      RoundUpToWordBoundary(sizeof(ControlStructure));

  std::atomic_flag& write_flag() {
    DCHECK(is_valid());
    return reinterpret_cast<ControlStructure*>(base_ptr_)->write_flag;
  }
  std::atomic_flag& read_flag() {
    DCHECK(is_valid());
    return reinterpret_cast<ControlStructure*>(base_ptr_)->read_flag;
  }
  std::atomic_uint32_t& read_pos() {
    DCHECK(is_valid());
    return reinterpret_cast<ControlStructure*>(base_ptr_)->read_pos;
  }
  std::atomic_uint32_t& write_pos() {
    DCHECK(is_valid());
    return reinterpret_cast<ControlStructure*>(base_ptr_)->write_pos;
  }

  SharedBuffer(uint8_t* ptr, size_t len) : base_ptr_(ptr), len_(len) {}

  RAW_PTR_EXCLUSION uint8_t* base_ptr_ = nullptr;
  size_t len_ = 0;
};
// Builds the POSIX channel and snapshots the process-wide shared-memory page
// count; later changes to that setting do not affect this channel.
ChannelLinux::ChannelLinux(
    Delegate* delegate,
    ConnectionParams connection_params,
    HandlePolicy handle_policy,
    scoped_refptr<base::SingleThreadTaskRunner> io_task_runner)
    : ChannelPosix(delegate,
                   std::move(connection_params),
                   handle_policy,
                   io_task_runner),
      num_pages_(g_shared_mem_pages.load(std::memory_order_seq_cst)) {}
// Nothing beyond member destruction is needed here; the shared-memory
// notifiers are normally dropped earlier, in ShutDownOnIOThread().
ChannelLinux::~ChannelLinux() = default;
// Sends |message| through the shared-memory ring when that path is usable.
// The regular POSIX transport is used when:
// - the peer has not accepted the upgrade;
// - the message carries handles;
// - writes have been rejected;
// - the message does not fit in the ring.
void ChannelLinux::Write(MessagePtr message) {
  const bool use_posix_path =
      !shared_mem_writer_ || message->has_handles() || reject_writes_;
  if (use_posix_path) {
    ChannelPosix::Write(std::move(message));
    return;
  }

  switch (write_buffer_->TryWrite(message->data(),
                                  message->data_num_bytes())) {
    case SharedBuffer::Error::kSuccess:
      // Wake the reader now that the bytes are published.
      write_notifier_->Notify();
      return;
    case SharedBuffer::Error::kGeneralError:
      // No room (or lock busy): fall back to the socket transport.
      ChannelPosix::Write(std::move(message));
      return;
    case SharedBuffer::Error::kControlCorruption:
      // The remote end corrupted the shared control area. Stop using shared
      // memory and report the failure on the IO thread.
      reject_writes_ = true;
      io_task_runner_->PostTask(
          FROM_HERE, base::BindOnce(&ChannelLinux::OnWriteError, this,
                                    Channel::Error::kReceivedMalformedData));
      return;
  }
}
// Offers the shared-memory upgrade to the peer at most once per channel. The
// first call consumes the one-shot |offered_| flag even if upgrades turn out
// to be disabled or handles cannot be sent.
void ChannelLinux::OfferSharedMemUpgrade() {
  const bool already_offered = offered_.test_and_set();
  if (already_offered) {
    return;
  }
  if (!UpgradesEnabled()) {
    return;
  }
  // The offer carries file descriptors, so it only makes sense when this
  // channel can send handles.
  if (handle_policy() != HandlePolicy::kAcceptHandles) {
    return;
  }
  OfferSharedMemUpgradeInternal();
}
// Handles the shared-memory upgrade handshake (UPGRADE_OFFER / ACCEPT /
// REJECT). Every other control message is forwarded to ChannelPosix. Returns
// true for the three handshake messages regardless of outcome.
bool ChannelLinux::OnControlMessage(Message::MessageType message_type,
                                    const void* payload,
                                    size_t payload_size,
                                    std::vector<PlatformHandle> handles) {
  switch (message_type) {
    case Message::MessageType::UPGRADE_OFFER: {
      // The peer offers a sealed memfd (handles[0]) holding a ring buffer and
      // an eventfd (handles[1]) it signals after writing. Everything here
      // comes from the peer and is validated. Any failure answers with
      // UPGRADE_REJECT and leaves this side un-upgraded.
      if (payload_size < sizeof(UpgradeOfferMessage)) {
        LOG(ERROR) << "Received an UPGRADE_OFFER without a payload";
        return true;
      }
      const UpgradeOfferMessage* msg =
          reinterpret_cast<const UpgradeOfferMessage*>(payload);
      if (!UpgradeOfferMessage::IsValidVersion(msg->version)) {
        LOG(ERROR) << "Reject shared mem upgrade unexpected version: "
                   << msg->version;
        RejectUpgradeOffer();
        return true;
      }
      if (!is_for_ipcz()) {
        LOG(ERROR) << "Rejecting UPGRADE_OFFER for non-ipcz transport";
        RejectUpgradeOffer();
        return true;
      }
      if (handles.size() != 2) {
        LOG(ERROR) << "Received an UPGRADE_OFFER without two FDs";
        RejectUpgradeOffer();
        return true;
      }
      if (read_buffer_ || read_notifier_) {
        LOG(ERROR) << "Received an UPGRADE_OFFER on already upgraded channel";
        return true;
      }

      base::ScopedFD memfd(handles[0].TakeFD());
      if (memfd.is_valid() && !ValidateFDIsProperlySealedMemFD(memfd)) {
        PLOG(ERROR) << "Passed FD was not properly sealed";
        DLOG(FATAL) << "MemFD was NOT properly sealed";
        memfd.reset();
      }
      if (!memfd.is_valid()) {
        RejectUpgradeOffer();
        return true;
      }

      if (msg->num_pages <= 0 || msg->num_pages > 128) {
        LOG(ERROR) << "SharedMemory upgrade offer was received with invalid "
                      "number of pages: "
                   << msg->num_pages;
        RejectUpgradeOffer();
        // Must stop here: continuing would map a bogus size and could answer
        // with UPGRADE_ACCEPT after already rejecting.
        return true;
      }
      const size_t buffer_size =
          static_cast<size_t>(msg->num_pages) * base::GetPageSize();

      // The seals freeze the memfd's size, but the peer chose that size.
      // Touching mapped pages beyond end-of-file raises SIGBUS, so refuse a
      // memfd smaller than the advertised region.
      struct stat memfd_stat = {};
      if (fstat(memfd.get(), &memfd_stat) < 0 || memfd_stat.st_size < 0 ||
          static_cast<size_t>(memfd_stat.st_size) < buffer_size) {
        LOG(ERROR) << "SharedMemory upgrade offer memfd is smaller than the "
                      "advertised number of pages";
        RejectUpgradeOffer();
        return true;
      }

      base::ScopedFD efd(handles[1].TakeFD());
      if (!efd.is_valid()) {
        LOG(ERROR) << "Received an UPGRADE_OFFER with an invalid eventfd";
        RejectUpgradeOffer();
        return true;
      }

      std::unique_ptr<DataAvailableNotifier> read_notifier;
      if (msg->version == UpgradeOfferMessage::kEventFdNotifier) {
        read_notifier = EventFDNotifier::CreateReadNotifier(
            std::move(efd),
            base::BindRepeating(&ChannelLinux::SharedMemReadReady, this),
            io_task_runner_);
      }
      if (!read_notifier) {
        RejectUpgradeOffer();
        return true;
      }

      std::unique_ptr<SharedBuffer> read_sb =
          SharedBuffer::Create(memfd, buffer_size);
      if (!read_sb || !read_sb->is_valid()) {
        // |read_notifier| is still local, so it (and its watch) is destroyed
        // here instead of lingering with no buffer behind it.
        RejectUpgradeOffer();
        return true;
      }

      // Commit both halves only once both exist. A live notifier without a
      // buffer would trip CHECK(read_buffer_) in SharedMemReadReady().
      read_notifier_ = std::move(read_notifier);
      read_buffer_ = std::move(read_sb);
      read_buf_.resize(read_buffer_->usable_len());
      AcceptUpgradeOffer();
      // Having accepted, offer the same arrangement in the other direction.
      OfferSharedMemUpgrade();
      return true;
    }
    case Message::MessageType::UPGRADE_ACCEPT: {
      if (!write_buffer_ || !write_notifier_ || !write_notifier_->is_valid()) {
        LOG(ERROR) << "Received unexpected UPGRADE_ACCEPT";
        shared_mem_writer_ = false;
        write_buffer_.reset();
        write_notifier_.reset();
        return true;
      }
      shared_mem_writer_ = true;
      return true;
    }
    case Message::MessageType::UPGRADE_REJECT: {
      shared_mem_writer_ = false;
      write_buffer_.reset();
      write_notifier_.reset();
      return true;
    }
    default:
      break;
  }
  return ChannelPosix::OnControlMessage(message_type, payload, payload_size,
                                        std::move(handles));
}
// Read-notifier callback: drains the shared ring buffer and dispatches every
// complete message it contains. Stops at the first malformed message or
// dispatch failure, reporting kReceivedMalformedData exactly once.
void ChannelLinux::SharedMemReadReady() {
  CHECK(read_buffer_);
  if (read_buffer_->TryLockForReading()) {
    // Reset the notifier before draining. A notification that arrives while
    // we drain then stays pending rather than being consumed here.
    read_notifier_->Clear();
    bool read_fail = false;
    do {
      uint32_t bytes_read = 0;
      SharedBuffer::Error read_res = read_buffer_->TryReadLocked(
          read_buf_.data(), read_buf_.size(), &bytes_read);
      if (read_res == SharedBuffer::Error::kControlCorruption) {
        // The control area is untrustworthy; nothing more can be read.
        OnError(Error::kReceivedMalformedData);
        break;
      }
      if (bytes_read == 0) {
        break;
      }
      // Walk the chunk message by message; each dispatch reports the size it
      // consumed through |read_size_hint|.
      off_t data_offset = 0;
      while (bytes_read - data_offset > 0) {
        size_t read_size_hint = 0;
        DispatchResult result = TryDispatchMessage(
            base::span(reinterpret_cast<char*>(read_buf_.data() + data_offset),
                       static_cast<size_t>(bytes_read - data_offset)),
            &read_size_hint);
        if (result != DispatchResult::kOK) {
          LOG(ERROR) << "Recevied a bad message via shared memory";
          read_fail = true;
          OnError(Error::kReceivedMalformedData);
          break;
        }
        if (!DispatchDelayedMessages()) {
          LOG(ERROR) << "Error dispatching queued messages";
          read_fail = true;
          OnError(Error::kReceivedMalformedData);
          // Stop here too. Continuing would keep dispatching after the
          // channel reported an error, possibly reporting it again.
          break;
        }
        data_offset += read_size_hint;
      }
    } while (!read_fail);
    read_buffer_->UnlockForReading();
  }
}
// Latches |reject_writes_| first, so later Write() calls bypass the shared
// buffer, and then lets ChannelPosix handle the error.
void ChannelLinux::OnWriteError(Error error) {
  reject_writes_ = true;
  ChannelPosix::OnWriteError(error);
}
void ChannelLinux::RejectUpgradeOffer() {
if (is_for_ipcz()) {
ChannelPosix::Write(
Message::CreateIpczMessage({}, {}, Message::MessageType::UPGRADE_REJECT,
IncrementLastSentChannelSequenceNumber()));
} else {
ChannelPosix::RejectPreIpczUpgradeOffer();
}
}
void ChannelLinux::AcceptUpgradeOffer() {
CHECK(is_for_ipcz());
ChannelPosix::Write(
Message::CreateIpczMessage({}, {}, Message::MessageType::UPGRADE_ACCEPT,
IncrementLastSentChannelSequenceNumber()));
}
// Stops all shared-memory traffic, then shuts down the POSIX channel. The
// notifiers are destroyed here; the mapped buffers are left to the
// destructor.
void ChannelLinux::ShutDownOnIOThread() {
  reject_writes_ = true;
  read_notifier_ = nullptr;
  write_notifier_ = nullptr;
  ChannelPosix::ShutDownOnIOThread();
}
// Adds no Linux-specific start-up work; simply defers to ChannelPosix.
void ChannelLinux::StartOnIOThread() {
ChannelPosix::StartOnIOThread();
}
void ChannelLinux::OfferSharedMemUpgradeInternal() {
if (reject_writes_) {
return;
}
if (write_buffer_ || write_notifier_) {
LOG(ERROR) << "Upgrade attempted on an already upgraded channel";
return;
}
const size_t kSize = num_pages_ * base::GetPageSize();
base::ScopedFD memfd = CreateSealedMemFD(kSize);
if (!memfd.is_valid()) {
PLOG(ERROR) << "Unable to create memfd";
return;
}
bool properly_sealed = ValidateFDIsProperlySealedMemFD(memfd);
if (!properly_sealed) {
LOG(ERROR) << "FD was not properly sealed we cannot offer upgrade.";
return;
}
std::unique_ptr<SharedBuffer> write_buffer =
SharedBuffer::Create(memfd, kSize);
if (!write_buffer || !write_buffer->is_valid()) {
PLOG(ERROR) << "Unable to map shared memory";
return;
}
write_buffer->Initialize();
auto notifier_version = UpgradeOfferMessage::kEventFdNotifier;
std::unique_ptr<EventFDNotifier> write_notifier =
EventFDNotifier::CreateWriteNotifier();
if (!write_notifier) {
PLOG(ERROR) << "Failed to create eventfd write notifier";
return;
}
std::vector<PlatformHandle> fds;
fds.emplace_back(std::move(memfd));
fds.emplace_back(write_notifier->take_dup());
write_notifier_ = std::move(write_notifier);
write_buffer_ = std::move(write_buffer);
UpgradeOfferMessage offer_msg;
offer_msg.num_pages = num_pages_;
offer_msg.version = notifier_version;
MessagePtr msg;
DCHECK(is_for_ipcz());
auto data = base::span(reinterpret_cast<const uint8_t*>(&offer_msg),
sizeof(UpgradeOfferMessage));
msg = Message::CreateIpczMessage(data, std::move(fds),
Message::MessageType::UPGRADE_OFFER,
IncrementLastSentChannelSequenceNumber());
ChannelPosix::Write(std::move(msg));
}
// Reports whether the running kernel has everything the shared-memory upgrade
// needs: kernel 4.0 or newer, memfd_create, and eventfd2. The probe runs once
// and the result is cached in a function-local static.
bool ChannelLinux::KernelSupportsUpgradeRequirements() {
  static const bool kernel_supports_upgrade = [] {
    if (base::SysInfo::KernelVersionNumber::Current() <
        base::SysInfo::KernelVersionNumber(4, 0)) {
      return false;
    }
    // Call memfd_create with invalid flags:
    // - EINVAL means the syscall exists;
    // - ENOSYS means it does not;
    // - EPERM is tolerated (presumably a sandbox policy) but counts as
    //   unsupported.
    // Any other outcome is fatal.
    int probe = syscall(__NR_memfd_create, "", ~0);
    PCHECK(probe < 0 &&
           (errno == EINVAL || errno == ENOSYS || errno == EPERM));
    if (probe >= 0 || errno != EINVAL) {
      return false;
    }
    return EventFDNotifier::KernelSupported();
  }();
  return kernel_supports_upgrade;
}
// Whether shared-memory upgrades should be offered. Values set explicitly via
// SetSharedMemParameters() take precedence; otherwise the kMojoUseEventFd
// feature decides. (The condition was previously inverted: an explicit
// setting was ignored, and without one the default "false" was returned
// instead of consulting the feature.)
bool ChannelLinux::UpgradesEnabled() {
  if (g_params_set.load()) {
    return g_use_shared_mem.load();
  }
  return base::FeatureList::IsEnabled(kMojoUseEventFd);
}
// Overrides the process-wide shared-memory settings. The values are stored
// before |g_params_set| is raised. With the default sequentially consistent
// atomics, a reader that observes g_params_set == true also observes these
// values, never stale ones.
void ChannelLinux::SetSharedMemParameters(bool enabled, uint32_t num_pages) {
  g_use_shared_mem.store(enabled);
  g_shared_mem_pages.store(num_pages);
  g_params_set.store(true);
}
}