#include "base/task/common/checked_lock_impl.h"
#include <algorithm>
#include <optional>
#include <ostream>
#include <unordered_map>
#include <vector>
#include "base/check_op.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/raw_ptr_exclusion.h"
#include "base/no_destructor.h"
#include "base/synchronization/condition_variable.h"
#include "base/task/common/checked_lock.h"
#include "base/threading/platform_thread.h"
#include "base/threading/thread_local.h"
namespace base::internal {
namespace {
// Tracks the lock-ordering rules for every live CheckedLockImpl and DCHECKs
// that each acquisition obeys them.
//
// Two pieces of state are kept:
//  - a process-wide map from each registered lock to the single lock it was
//    registered as allowed to follow (possibly null), guarded by
//    |allowed_predecessor_map_lock_|;
//  - a thread-local list of the locks the current thread holds, in
//    acquisition order. Because it is thread-local, recording acquisitions
//    and releases does not need |allowed_predecessor_map_lock_|.
class SafeAcquisitionTracker {
 public:
  SafeAcquisitionTracker() = default;

  SafeAcquisitionTracker(const SafeAcquisitionTracker&) = delete;
  SafeAcquisitionTracker& operator=(const SafeAcquisitionTracker&) = delete;

  // Registers |lock| with |predecessor| as the only ordinary lock it may be
  // acquired directly after. With a null |predecessor|, |lock| may only be
  // acquired when the thread holds no lock or right after a universal
  // predecessor. AssertSafePredecessor() DCHECKs that |predecessor| was
  // registered first, which rules out predecessor cycles.
  void RegisterLock(const CheckedLockImpl* const lock,
                    const CheckedLockImpl* const predecessor) {
    DCHECK_NE(lock, predecessor) << "Reentrant locks are unsupported.";
    AutoLock auto_lock(allowed_predecessor_map_lock_);
    allowed_predecessor_map_[lock] = predecessor;
    AssertSafePredecessor(lock);
  }

  // Forgets |lock|. Erasing a lock that was never registered is a no-op.
  void UnregisterLock(const CheckedLockImpl* const lock) {
    AutoLock auto_lock(allowed_predecessor_map_lock_);
    allowed_predecessor_map_.erase(lock);
  }

  // DCHECKs that taking |lock| respects the registered ordering, then appends
  // it to the current thread's held-lock list.
  void RecordAcquisition(const CheckedLockImpl* const lock) {
    AssertSafeAcquire(lock);
    GetAcquiredLocksOnCurrentThread()->push_back(lock);
  }

  // Removes |lock| from the current thread's held-lock list. CHECKs that the
  // thread actually holds it. Release order need not be LIFO.
  void RecordRelease(const CheckedLockImpl* const lock) {
    LockVector* acquired_locks = GetAcquiredLocksOnCurrentThread();
    const auto iter_at_lock = std::ranges::find(*acquired_locks, lock);
    CHECK(iter_at_lock != acquired_locks->end());
    acquired_locks->erase(iter_at_lock);
  }

  // DCHECKs that the current thread holds no tracked lock.
  void AssertNoLockHeldOnCurrentThread() {
    // Read the TLS slot directly. A thread that has never acquired a lock has
    // no list yet, and allocating one just to test emptiness is wasteful.
    const LockVector* const acquired_locks = tls_acquired_locks_.Get();
    DCHECK(!acquired_locks || acquired_locks->empty());
  }

 private:
  using LockVector = std::vector<const CheckedLockImpl*>;
  using PredecessorMap =
      std::unordered_map<const CheckedLockImpl*, const CheckedLockImpl*>;

  // DCHECKs that |lock| may be acquired given the locks the current thread
  // already holds.
  void AssertSafeAcquire(const CheckedLockImpl* const lock) {
    const LockVector* acquired_locks = GetAcquiredLocksOnCurrentThread();

    // Taking a lock while holding none is always safe.
    if (acquired_locks->empty()) {
      return;
    }

    // A universal predecessor must be the first lock a thread acquires.
    DCHECK(!lock->is_universal_predecessor());

    const CheckedLockImpl* const previous_lock = acquired_locks->back();

    // Any lock may directly follow a universal predecessor.
    if (previous_lock->is_universal_predecessor()) {
      return;
    }

    // A universal successor may follow any lock except another universal
    // successor. Its registered predecessor plays no part in this check, so
    // neither the global map lock nor a map lookup is needed on this path.
    if (lock->is_universal_successor()) {
      DCHECK(!previous_lock->is_universal_successor());
      return;
    }

    // Otherwise the most recently acquired lock must be exactly the
    // predecessor |lock| was registered with.
    AutoLock auto_lock(allowed_predecessor_map_lock_);
    // at() cannot miss: |lock| registered itself when it was constructed.
    const CheckedLockImpl* const allowed_predecessor =
        allowed_predecessor_map_.at(lock);
    DCHECK_EQ(previous_lock, allowed_predecessor);
  }

  // DCHECKs that |lock|'s predecessor, if any, was registered before |lock|.
  // Every chain therefore bottoms out at a lock with a null predecessor, so
  // no cycle can form. The caller must hold |allowed_predecessor_map_lock_|.
  void AssertSafePredecessor(const CheckedLockImpl* lock) const {
    allowed_predecessor_map_lock_.AssertAcquired();
    const CheckedLockImpl* predecessor = allowed_predecessor_map_.at(lock);
    if (predecessor) {
      DCHECK(allowed_predecessor_map_.find(predecessor) !=
             allowed_predecessor_map_.end())
          << "CheckedLock was registered before its predecessor. "
          << "Potential cycle detected";
    }
  }

  // Returns the calling thread's held-lock list, creating it (owned by the
  // thread-local slot) on first use.
  LockVector* GetAcquiredLocksOnCurrentThread() {
    // Fast path: one TLS read when the list already exists, which is the
    // common case on every acquire/release.
    if (LockVector* const acquired_locks = tls_acquired_locks_.Get()) {
      return acquired_locks;
    }
    tls_acquired_locks_.Set(std::make_unique<LockVector>());
    return tls_acquired_locks_.Get();
  }

  // Guards |allowed_predecessor_map_|, which is shared by all threads.
  Lock allowed_predecessor_map_lock_;
  // Registered lock -> the lock it may directly follow (null if none).
  PredecessorMap allowed_predecessor_map_;
  // Locks held by the current thread, in acquisition order.
  RAW_PTR_EXCLUSION ThreadLocalOwnedPointer<LockVector> tls_acquired_locks_;
};
// Returns the single tracker shared by every CheckedLockImpl in the process.
// It is a function-local static constructed on first use and wrapped in
// NoDestructor, so it is never torn down.
SafeAcquisitionTracker& GetSafeAcquisitionTracker() {
  static base::NoDestructor<SafeAcquisitionTracker> instance;
  SafeAcquisitionTracker& shared_tracker = *instance;
  return shared_tracker;
}
}
// A lock with no explicit predecessor.
CheckedLockImpl::CheckedLockImpl() : CheckedLockImpl(nullptr) {}

// Registers this lock with the tracker. While other locks are held, it may
// only be acquired right after |predecessor| or a universal predecessor.
CheckedLockImpl::CheckedLockImpl(const CheckedLockImpl* predecessor) {
  // Nothing may ever be acquired after a universal successor, so it cannot
  // serve as anyone's predecessor.
  const bool predecessor_is_usable =
      !predecessor || !predecessor->is_universal_successor_;
  DCHECK(predecessor_is_usable);
  GetSafeAcquisitionTracker().RegisterLock(this, predecessor);
}

// Universal predecessors are intentionally not registered with the tracker.
// The tracker never needs their map entry: acquiring one is only allowed
// while no other lock is held.
CheckedLockImpl::CheckedLockImpl(UniversalPredecessor)
    : is_universal_predecessor_(true) {}

// Universal successors are registered with no predecessor.
CheckedLockImpl::CheckedLockImpl(UniversalSuccessor)
    : is_universal_successor_(true) {
  GetSafeAcquisitionTracker().RegisterLock(this, nullptr);
}

// Drops this lock from the tracker. This is harmless for universal
// predecessors, which were never registered.
CheckedLockImpl::~CheckedLockImpl() {
  SafeAcquisitionTracker& tracker = GetSafeAcquisitionTracker();
  tracker.UnregisterLock(this);
}
void CheckedLockImpl::AssertNoLockHeldOnCurrentThread() {
GetSafeAcquisitionTracker().AssertNoLockHeldOnCurrentThread();
}
void CheckedLockImpl::Acquire(subtle::LockTracking tracking) {
lock_.Acquire(tracking);
GetSafeAcquisitionTracker().RecordAcquisition(this);
}
void CheckedLockImpl::Release() {
lock_.Release();
GetSafeAcquisitionTracker().RecordRelease(this);
}
void CheckedLockImpl::AssertAcquired() const {
lock_.AssertAcquired();
}
void CheckedLockImpl::AssertNotHeld() const {
lock_.AssertNotHeld();
}
// Both helpers bind the condition variable directly to the wrapped |lock_|.
// NOTE(review): because the condition variable works on |lock_| itself and
// not through Acquire()/Release(), the release/reacquire around a wait is
// presumably invisible to the acquisition tracker. Confirm before relying on
// tracker state across waits.
ConditionVariable CheckedLockImpl::CreateConditionVariable() {
  auto* const underlying_lock = &lock_;
  return ConditionVariable(underlying_lock);
}

// Constructs the condition variable in place inside |opt|, replacing any
// value it already holds.
void CheckedLockImpl::CreateConditionVariableAndEmplace(
    std::optional<ConditionVariable>& opt) {
  auto* const underlying_lock = &lock_;
  opt.emplace(underlying_lock);
}
}