#ifndef BASE_CONTAINERS_ENUM_SET_H_
#define BASE_CONTAINERS_ENUM_SET_H_
#include <bitset>
#include <compare>
#include <cstddef>
#include <initializer_list>
#include <optional>
#include <string>
#include <type_traits>
#include <utility>
#include "base/check.h"
#include "base/check_op.h"
#include "base/memory/raw_ptr.h"
#include "base/types/cxx23_to_underlying.h"
#include "build/build_config.h"
namespace base {
// Forward declarations. EnumSet befriends the set-algebra free functions
// below, so the class template and the function templates must both be
// declared before the class definition.
template <typename Enum, Enum Lo, Enum Hi>
class EnumSet;

// Returns the set of values contained in `set1`, `set2`, or both.
template <typename Enum, Enum Lo, Enum Hi>
EnumSet<Enum, Lo, Hi> Union(EnumSet<Enum, Lo, Hi> set1,
                            EnumSet<Enum, Lo, Hi> set2);

// Returns the set of values contained in both `set1` and `set2`.
template <typename Enum, Enum Lo, Enum Hi>
EnumSet<Enum, Lo, Hi> Intersection(EnumSet<Enum, Lo, Hi> set1,
                                   EnumSet<Enum, Lo, Hi> set2);

// Returns the set of values contained in `set1` but not in `set2`.
template <typename Enum, Enum Lo, Enum Hi>
EnumSet<Enum, Lo, Hi> Difference(EnumSet<Enum, Lo, Hi> set1,
                                 EnumSet<Enum, Lo, Hi> set2);
// A set of values of the enumeration type `E`, restricted to the inclusive
// range [MinEnumValue, MaxEnumValue]. Storage is a std::bitset with one bit per
// value in that range: bit i corresponds to the enum value `MinEnumValue + i`.
//
// Inserting an out-of-range value (`Put`, `PutRange`) CHECK-fails, whereas
// `Remove` and `Has` tolerate out-of-range values. Iteration visits the
// contained values in ascending order.
template <typename E, E MinEnumValue, E MaxEnumValue>
class EnumSet {
 private:
  static_assert(
      std::is_enum_v<E>,
      "First template parameter of EnumSet must be an enumeration type");

  // Returns true iff `value` lies in [MinEnumValue, MaxEnumValue].
  static constexpr bool InRange(E value) {
    return (value >= MinEnumValue) && (value <= MaxEnumValue);
  }

 public:
  using EnumType = E;
  static const E kMinValue = MinEnumValue;
  static const E kMaxValue = MaxEnumValue;
  // Number of distinct values the set can hold; also the bitset width.
  static const size_t kValueCount =
      to_underlying(kMaxValue) - to_underlying(kMinValue) + 1;
  static_assert(kMinValue <= kMaxValue,
                "min value must be no greater than max value");

  // Container-style alias for the element type.
  using value_type = EnumType;

 private:
  // Bit i is set iff the enum value `kMinValue + i` is in the set.
  using EnumBitSet = std::bitset<kValueCount>;

 public:
  // Iterates over the values contained in an EnumSet in ascending order.
  // A default-constructed Iterator is the end iterator. Equality compares only
  // the current bit index, so any exhausted iterator equals end(). An Iterator
  // keeps a pointer to its EnumSet's bitset and must not outlive that set.
  class Iterator {
   public:
    using value_type = EnumType;
    using size_type = size_t;
    using difference_type = ptrdiff_t;
    using pointer = EnumType*;
    using reference = EnumType&;
    using iterator_category = std::forward_iterator_tag;

    Iterator() : enums_(nullptr), i_(kValueCount) {}
    ~Iterator() = default;

    Iterator(const Iterator&) = default;
    Iterator& operator=(const Iterator&) = default;

    Iterator(Iterator&&) = default;
    Iterator& operator=(Iterator&&) = default;

    friend bool operator==(const Iterator& lhs, const Iterator& rhs) {
      return lhs.i_ == rhs.i_;
    }

    // Returns the enum value at the current position (by value).
    value_type operator*() const {
      DCHECK(Good());
      return FromIndex(i_);
    }

    // Advances to the next contained value, or to the end position.
    Iterator& operator++() {
      DCHECK(Good());
      i_ = FindNext(i_ + 1);
      return *this;
    }

    Iterator operator++(int) {
      DCHECK(Good());
      Iterator old(*this);
      i_ = FindNext(i_ + 1);
      // Plain `return old;` permits copy elision; `std::move` would block it.
      return old;
    }

   private:
    friend Iterator EnumSet::begin() const;

    // Positions the iterator on the lowest set bit of `enums`, or at the end
    // position if no bit is set.
    explicit Iterator(const EnumBitSet& enums)
        : enums_(&enums), i_(FindNext(0)) {}

    // True iff the iterator refers to a bitset and points at one of its set
    // bits.
    bool Good() const { return enums_ && i_ < kValueCount && enums_->test(i_); }

    // Returns the smallest index >= `i` whose bit is set, or kValueCount if
    // there is none.
    size_t FindNext(size_t i) const {
      while ((i < kValueCount) && !enums_->test(i)) {
        ++i;
      }
      return i;
    }

    raw_ptr<const EnumBitSet> enums_;
    size_t i_;
  };

  EnumSet() = default;

  ~EnumSet() = default;

  // Builds a set holding each value in `values`. CHECK-fails if any value is
  // outside [kMinValue, kMaxValue]. During constant evaluation each value's
  // index (value - kMinValue) must also be below 64.
  constexpr EnumSet(std::initializer_list<E> values) {
    if (std::is_constant_evaluated()) {
      enums_ = bitstring(values);
    } else {
      for (E value : values) {
        Put(value);
      }
    }
  }

  // Returns a set containing every value in [kMinValue, kMaxValue].
  static constexpr EnumSet All() {
    if (std::is_constant_evaluated()) {
      // The constexpr path builds the bits in a uint64_t, so it only works for
      // at most 64 values.
      if (kValueCount == 0) {
        return EnumSet();
      }
      // Computes 2^kValueCount - 1 without overflowing when kValueCount == 64.
      uint64_t mask = 1ULL << (kValueCount - 1);
      return EnumSet(EnumBitSet(mask - 1 + mask));
    } else {
      // At runtime, set every bit directly; this works for any kValueCount.
      EnumSet enum_set;
      enum_set.enums_.set();
      return enum_set;
    }
  }

  // Returns the set of all values in the inclusive range [start, end].
  // CHECK-fails if start > end, if either bound is out of range, or if the
  // index of `end` is 64 or more (the mask is built in a uint64_t).
  static constexpr EnumSet FromRange(E start, E end) {
    CHECK_LE(start, end);
    // (bit(end) - bit(start)) sets bits [start, end); OR-ing bit(end) closes
    // the range.
    return EnumSet(EnumBitSet(
        ((single_val_bitstring(end)) - (single_val_bitstring(start))) |
        (single_val_bitstring(end))));
  }

  // Builds a set from `bitmask`, where bit k stands for the enum value whose
  // underlying value is k. Bits below kMinValue are shifted out and bits above
  // kMaxValue are dropped by the bitset conversion.
  static constexpr EnumSet FromEnumBitmask(const uint64_t bitmask) {
    static_assert(to_underlying(kMaxValue) < 64,
                  "The highest enum value must be < 64 for FromEnumBitmask ");
    static_assert(to_underlying(kMinValue) >= 0,
                  "The lowest enum value must be >= 0 for FromEnumBitmask ");
    return EnumSet(EnumBitSet(bitmask >> to_underlying(kMinValue)));
  }

  // Inverse of FromEnumBitmask: bit k of the result is set iff the enum value
  // whose underlying value is k is in the set.
  uint64_t ToEnumBitmask() const {
    static_assert(to_underlying(kMaxValue) < 64,
                  "The highest enum value must be < 64 for ToEnumBitmask ");
    static_assert(to_underlying(kMinValue) >= 0,
                  "The lowest enum value must be >= 0 for ToEnumBitmask ");
    return enums_.to_ullong() << to_underlying(kMinValue);
  }

  // Returns the membership bits for the values whose underlying value lies in
  // [64 * n, 64 * n + 63]: bit k of the result stands for value 64 * n + k.
  // Returns std::nullopt if that word lies entirely above kMaxValue.
  std::optional<uint64_t> GetNth64bitWordBitmask(size_t n) const {
    if (to_underlying(kMaxValue) / 64 < n) {
      return std::nullopt;
    }
    const size_t word_start = n * 64;
    // The whole word lies below kMinValue, so it holds no values. Without this
    // guard the final left shift would be by 64 or more bits, which is
    // undefined behavior for uint64_t.
    if (to_underlying(kMinValue) >= word_start + 64) {
      return uint64_t{0};
    }

    std::bitset<kValueCount> mask = ~uint64_t{0};
    std::bitset<kValueCount> bits = enums_;
    // Align the bit for value `word_start` with bit 0, when it is stored.
    if (to_underlying(kMinValue) < word_start) {
      bits >>= word_start - to_underlying(kMinValue);
    }
    uint64_t result = (bits & mask).to_ullong();
    // kMinValue falls inside this word, so its bit belongs at position
    // kMinValue - word_start (guaranteed < 64 by the guard above).
    if (to_underlying(kMinValue) > word_start) {
      result <<= to_underlying(kMinValue) - word_start;
    }
    return result;
  }

  // Adds `value`. CHECK-fails if `value` is out of range.
  void Put(E value) { enums_.set(ToIndex(value)); }

  // Adds every value contained in `other`.
  void PutAll(EnumSet other) { enums_ |= other.enums_; }

  // Adds every value in the inclusive range [start, end]. CHECK-fails if
  // start > end or if either bound is out of range.
  void PutRange(E start, E end) {
    CHECK_LE(start, end);
    size_t endIndexInclusive = ToIndex(end);
    for (size_t current = ToIndex(start); current <= endIndexInclusive;
         ++current) {
      enums_.set(current);
    }
  }

  // Keeps only the values that are also contained in `other`.
  void RetainAll(EnumSet other) { enums_ &= other.enums_; }

  // Removes `value` if present. Out-of-range values are ignored.
  void Remove(E value) {
    if (InRange(value)) {
      enums_.reset(ToIndex(value));
    }
  }

  // Removes every value contained in `other`.
  void RemoveAll(EnumSet other) { enums_ &= ~other.enums_; }

  // Removes all values.
  void Clear() { enums_.reset(); }

  // Adds `value` if `should_be_present`, otherwise removes it.
  void PutOrRemove(E value, bool should_be_present) {
    if (should_be_present) {
      Put(value);
    } else {
      Remove(value);
    }
  }

  // Returns true iff `value` is in the set; false for out-of-range values.
  constexpr bool Has(E value) const {
    return InRange(value) && enums_[ToIndex(value)];
  }

  // Returns true iff every value in `other` is also in this set.
  bool HasAll(EnumSet other) const {
    return (enums_ & other.enums_) == other.enums_;
  }

  // Returns true iff this set and `other` share at least one value.
  bool HasAny(EnumSet other) const { return (enums_ & other.enums_).any(); }

  bool empty() const { return !enums_.any(); }

  // Returns the number of values in the set.
  size_t size() const { return enums_.count(); }

  Iterator begin() const { return Iterator(enums_); }

  Iterator end() const { return Iterator(); }

  friend bool operator==(const EnumSet& a, const EnumSet& b) = default;

  // Orders sets by their ToEnumBitmask() value, so it is only usable when
  // every representable value lies in [0, 63].
  friend auto operator<=>(const EnumSet& a, const EnumSet& b) {
    return a.ToEnumBitmask() <=> b.ToEnumBitmask();
  }

  // Returns one '0'/'1' character per representable value, in
  // std::bitset::to_string order (highest value first).
  std::string ToString() const { return enums_.to_string(); }

  // Abseil hashing hook: hashes the underlying bitset.
  template <typename H>
  friend H AbslHashValue(H h, EnumSet e) {
    return H::combine(std::move(h), e.enums_);
  }

 private:
  friend EnumSet Union<E, MinEnumValue, MaxEnumValue>(EnumSet set1,
                                                      EnumSet set2);
  friend EnumSet Intersection<E, MinEnumValue, MaxEnumValue>(EnumSet set1,
                                                             EnumSet set2);
  friend EnumSet Difference<E, MinEnumValue, MaxEnumValue>(EnumSet set1,
                                                           EnumSet set2);

  // ORs together the single-bit masks of `values`. Each value's index must be
  // below 64.
  static constexpr uint64_t bitstring(const std::initializer_list<E>& values) {
    uint64_t result = 0;
    for (E value : values) {
      result |= single_val_bitstring(value);
    }
    return result;
  }

  // Returns a uint64_t with only the bit for `val` set. CHECK-fails if `val`
  // is out of range or its index does not fit in 64 bits.
  static constexpr uint64_t single_val_bitstring(E val) {
    const uint64_t bitstring = 1;
    const size_t shift_amount = ToIndex(val);
    CHECK_LT(shift_amount, sizeof(bitstring) * 8);
    return bitstring << shift_amount;
  }

  // Wraps an existing bitset. During constant evaluation, CHECK-fails if the
  // set can hold more than 64 values.
  explicit constexpr EnumSet(EnumBitSet enums) : enums_(enums) {
    if (std::is_constant_evaluated()) {
      CHECK(kValueCount <= 64)
          << "Max number of enum values is 64 for constexpr constructor";
    }
  }

  // Maps an enum value to its bit index. CHECK-fails if out of range.
  static constexpr size_t ToIndex(E value) {
    CHECK(InRange(value));
    return static_cast<size_t>(to_underlying(value)) -
           static_cast<size_t>(to_underlying(MinEnumValue));
  }

  // Maps a bit index back to its enum value. DCHECK-fails if out of range.
  static E FromIndex(size_t i) {
    DCHECK_LT(i, kValueCount);
    return static_cast<E>(to_underlying(MinEnumValue) + i);
  }

  // Membership bits; see EnumBitSet for the layout.
  EnumBitSet enums_;
};
// Out-of-class definitions of the static data members that EnumSet declares
// (with their initializers) in-class. These provide storage for the members
// whenever one of them is odr-used.
template <typename Enum, Enum Lo, Enum Hi>
const Enum EnumSet<Enum, Lo, Hi>::kMinValue;

template <typename Enum, Enum Lo, Enum Hi>
const Enum EnumSet<Enum, Lo, Hi>::kMaxValue;

template <typename Enum, Enum Lo, Enum Hi>
const size_t EnumSet<Enum, Lo, Hi>::kValueCount;
// Returns a set holding every value that is in `set1`, `set2`, or both.
template <typename E, E Min, E Max>
EnumSet<E, Min, Max> Union(EnumSet<E, Min, Max> set1,
                           EnumSet<E, Min, Max> set2) {
  // `set1` is already a by-value copy, so accumulate the result into it.
  set1.PutAll(set2);
  return set1;
}
// Returns a set holding only the values present in both `set1` and `set2`.
template <typename E, E Min, E Max>
EnumSet<E, Min, Max> Intersection(EnumSet<E, Min, Max> set1,
                                  EnumSet<E, Min, Max> set2) {
  // `set1` is already a by-value copy, so narrow it in place.
  set1.RetainAll(set2);
  return set1;
}
// Returns a set holding the values of `set1` that are absent from `set2`.
template <typename E, E Min, E Max>
EnumSet<E, Min, Max> Difference(EnumSet<E, Min, Max> set1,
                                EnumSet<E, Min, Max> set2) {
  // `set1` is already a by-value copy, so strip `set2`'s values from it.
  set1.RemoveAll(set2);
  return set1;
}
}
#endif