--- /dev/null
+/*
+ * Copyright 2017-present Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <folly/Optional.h>
+#include <folly/concurrency/detail/ConcurrentHashMap-detail.h>
+#include <folly/experimental/hazptr/hazptr.h>
+#include <atomic>
+#include <mutex>
+
+namespace folly {
+
+/**
+ * Based on Java's ConcurrentHashMap
+ *
+ * Readers are always wait-free.
+ * Writers are sharded, but take a lock.
+ *
+ * The interface is as close to std::unordered_map as possible, but there
+ * are a handful of changes:
+ *
+ * * Iterators hold hazard pointers to the returned elements. Elements can only
+ * be accessed while Iterators are still valid!
+ *
+ * * Therefore operator[] and at() return copies, since they do not
+ * return an iterator. The returned value is const, to remind you
+ * that changes do not affect the value in the map.
+ *
+ * * erase() calls the hash function, and may throw if the hash
+ * function throws an exception.
+ *
+ * * clear() initializes new segments, and is not noexcept.
+ *
+ * * The interface adds assign_if_equal, since find() doesn't take a lock.
+ *
+ * * Only a const version of find() is supported, along with const iterators.
+ * Mutation must go through the provided functions, such as assign().
+ *
+ * * Iteration visits every bucket in the table, unlike
+ * std::unordered_map, which iterates over a linked list of elements.
+ * If the table is sparse, this may be more expensive.
+ *
+ * * The rehash policy is power-of-two bucket counts, using the supplied
+ * load factor.
+ *
+ * * Allocator must be stateless.
+ *
+ * * ValueTypes without copy constructors will work, but pessimize the
+ * implementation.
+ *
+ * Comparisons:
+ * Single-threaded performance is extremely similar to std::unordered_map.
+ *
+ * Multithreaded performance beats anything except the lock-free
+ * atomic maps (AtomicUnorderedMap, AtomicHashMap), BUT only
+ * if you can perfectly size the atomic maps, and you don't
+ * need erase(). If you don't know the size in advance or
+ * your workload frequently calls erase(), this is the
+ * better choice.
+ */
+
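+/*
+ * Example usage (a minimal sketch; the key and value types here are
+ * arbitrary examples):
+ *
+ *   folly::ConcurrentHashMap<std::string, int> map;
+ *   map.insert("apple", 1);            // returns {ConstIterator, bool}
+ *   auto it = map.find("apple");       // wait-free; holds a hazard pointer
+ *   if (it != map.cend()) {
+ *     int v = it->second;              // valid only while `it` is alive
+ *   }
+ *   const int copy = map["apple"];     // returns a copy, not a reference
+ *   map.erase("apple");                // calls the hash function; may throw
+ */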
+template <
+ typename KeyType,
+ typename ValueType,
+ typename HashFn = std::hash<KeyType>,
+ typename KeyEqual = std::equal_to<KeyType>,
+ typename Allocator = std::allocator<uint8_t>,
+ uint8_t ShardBits = 8,
+ template <typename> class Atom = std::atomic,
+ class Mutex = std::mutex>
+class ConcurrentHashMap {
+ using SegmentT = detail::ConcurrentHashMapSegment<
+ KeyType,
+ ValueType,
+ ShardBits,
+ HashFn,
+ KeyEqual,
+ Allocator,
+ Atom,
+ Mutex>;
+ static constexpr uint64_t NumShards = (1 << ShardBits);
+  // Slightly higher than 1.0, so that even if hashing to shards isn't
+  // perfectly balanced, reserve(size) will still work without
+  // rehashing.
+ float load_factor_ = 1.05;
+
+ public:
+ class ConstIterator;
+
+ typedef KeyType key_type;
+ typedef ValueType mapped_type;
+ typedef std::pair<const KeyType, ValueType> value_type;
+ typedef std::size_t size_type;
+ typedef HashFn hasher;
+ typedef KeyEqual key_equal;
+ typedef ConstIterator const_iterator;
+
+ /*
+   * Construct a ConcurrentHashMap with 1 << ShardBits shards, and the
+   * given size and max_size. Both size and max_size will be rounded up
+   * to the next power of two, if they are not already powers of two, so
+   * that we can index into the shards efficiently.
+ *
+ * Insertion functions will throw bad_alloc if max_size is exceeded.
+ */
+ explicit ConcurrentHashMap(size_t size = 8, size_t max_size = 0) {
+ size_ = folly::nextPowTwo(size);
+ if (max_size != 0) {
+ max_size_ = folly::nextPowTwo(max_size);
+ }
+ CHECK(max_size_ == 0 || max_size_ >= size_);
+ for (uint64_t i = 0; i < NumShards; i++) {
+ segments_[i].store(nullptr, std::memory_order_relaxed);
+ }
+ }
+
+ ConcurrentHashMap(ConcurrentHashMap&& o) noexcept {
+ for (uint64_t i = 0; i < NumShards; i++) {
+ segments_[i].store(
+ o.segments_[i].load(std::memory_order_relaxed),
+ std::memory_order_relaxed);
+ o.segments_[i].store(nullptr, std::memory_order_relaxed);
+ }
+ }
+
+ ConcurrentHashMap& operator=(ConcurrentHashMap&& o) {
+ for (uint64_t i = 0; i < NumShards; i++) {
+ auto seg = segments_[i].load(std::memory_order_relaxed);
+ if (seg) {
+ seg->~SegmentT();
+ Allocator().deallocate((uint8_t*)seg, sizeof(SegmentT));
+ }
+ segments_[i].store(
+ o.segments_[i].load(std::memory_order_relaxed),
+ std::memory_order_relaxed);
+ o.segments_[i].store(nullptr, std::memory_order_relaxed);
+ }
+ return *this;
+ }
+
+ ~ConcurrentHashMap() {
+ for (uint64_t i = 0; i < NumShards; i++) {
+ auto seg = segments_[i].load(std::memory_order_relaxed);
+ if (seg) {
+ seg->~SegmentT();
+ Allocator().deallocate((uint8_t*)seg, sizeof(SegmentT));
+ }
+ }
+ }
+
+ bool empty() const noexcept {
+ for (uint64_t i = 0; i < NumShards; i++) {
+ auto seg = segments_[i].load(std::memory_order_acquire);
+ if (seg) {
+ if (!seg->empty()) {
+ return false;
+ }
+ }
+ }
+ return true;
+ }
+
+ ConstIterator find(const KeyType& k) const {
+ auto segment = pickSegment(k);
+ ConstIterator res(this, segment);
+ auto seg = segments_[segment].load(std::memory_order_acquire);
+ if (!seg || !seg->find(res.it_, k)) {
+ res.segment_ = NumShards;
+ }
+ return res;
+ }
+
+ ConstIterator cend() const noexcept {
+ return ConstIterator(NumShards);
+ }
+
+ ConstIterator cbegin() const noexcept {
+ return ConstIterator(this);
+ }
+
+ std::pair<ConstIterator, bool> insert(
+ std::pair<key_type, mapped_type>&& foo) {
+ auto segment = pickSegment(foo.first);
+ std::pair<ConstIterator, bool> res(
+ std::piecewise_construct,
+ std::forward_as_tuple(this, segment),
+ std::forward_as_tuple(false));
+ res.second = ensureSegment(segment)->insert(res.first.it_, std::move(foo));
+ return res;
+ }
+
+ std::pair<ConstIterator, bool> insert(const KeyType& k, const ValueType& v) {
+ auto segment = pickSegment(k);
+ std::pair<ConstIterator, bool> res(
+ std::piecewise_construct,
+ std::forward_as_tuple(this, segment),
+ std::forward_as_tuple(false));
+ res.second = ensureSegment(segment)->insert(res.first.it_, k, v);
+ return res;
+ }
+
+ template <typename... Args>
+ std::pair<ConstIterator, bool> try_emplace(const KeyType& k, Args&&... args) {
+ auto segment = pickSegment(k);
+ std::pair<ConstIterator, bool> res(
+ std::piecewise_construct,
+ std::forward_as_tuple(this, segment),
+ std::forward_as_tuple(false));
+ res.second = ensureSegment(segment)->try_emplace(
+ res.first.it_, k, std::forward<Args>(args)...);
+ return res;
+ }
+
+ template <typename... Args>
+ std::pair<ConstIterator, bool> emplace(Args&&... args) {
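+    // The node is constructed up front: the key is only known once the
+    // value_type has been built from args, and we need it to pick a
+    // segment. If the key already exists, the node is destroyed below.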
+ using Node = typename SegmentT::Node;
+ auto node = (Node*)Allocator().allocate(sizeof(Node));
+ new (node) Node(std::forward<Args>(args)...);
+ auto segment = pickSegment(node->getItem().first);
+ std::pair<ConstIterator, bool> res(
+ std::piecewise_construct,
+ std::forward_as_tuple(this, segment),
+ std::forward_as_tuple(false));
+ res.second = ensureSegment(segment)->emplace(
+ res.first.it_, node->getItem().first, node);
+ if (!res.second) {
+ node->~Node();
+ Allocator().deallocate((uint8_t*)node, sizeof(Node));
+ }
+ return res;
+ }
+
+ std::pair<ConstIterator, bool> insert_or_assign(
+ const KeyType& k,
+ const ValueType& v) {
+ auto segment = pickSegment(k);
+ std::pair<ConstIterator, bool> res(
+ std::piecewise_construct,
+ std::forward_as_tuple(this, segment),
+ std::forward_as_tuple(false));
+ res.second = ensureSegment(segment)->insert_or_assign(res.first.it_, k, v);
+ return res;
+ }
+
+ folly::Optional<ConstIterator> assign(const KeyType& k, const ValueType& v) {
+ auto segment = pickSegment(k);
+ ConstIterator res(this, segment);
+ auto seg = segments_[segment].load(std::memory_order_acquire);
+ if (!seg) {
+ return folly::Optional<ConstIterator>();
+ } else {
+ auto r = seg->assign(res.it_, k, v);
+ if (!r) {
+ return folly::Optional<ConstIterator>();
+ }
+ }
+ return res;
+ }
+
+  // Assign the value at key k to desired if and only if it is currently
+  // equal to expected.
+ folly::Optional<ConstIterator> assign_if_equal(
+ const KeyType& k,
+ const ValueType& expected,
+ const ValueType& desired) {
+ auto segment = pickSegment(k);
+ ConstIterator res(this, segment);
+ auto seg = segments_[segment].load(std::memory_order_acquire);
+ if (!seg) {
+ return folly::Optional<ConstIterator>();
+ } else {
+ auto r = seg->assign_if_equal(res.it_, k, expected, desired);
+ if (!r) {
+ return folly::Optional<ConstIterator>();
+ }
+ }
+ return res;
+ }
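+
+  // A minimal sketch of using assign_if_equal as a compare-and-swap style
+  // update (caller-side code; the map and key below are arbitrary examples):
+  //
+  //   ConcurrentHashMap<int, int> counts;
+  //   counts.insert(42, 0);
+  //   bool updated = false;
+  //   while (!updated) {
+  //     auto cur = counts.find(42);
+  //     if (cur == counts.cend()) {
+  //       break; // key was erased concurrently
+  //     }
+  //     updated =
+  //         bool(counts.assign_if_equal(42, cur->second, cur->second + 1));
+  //   }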
+
+ // Copying wrappers around insert and find.
+ // Only available for copyable types.
+ const ValueType operator[](const KeyType& key) {
+ auto item = insert(key, ValueType());
+ return item.first->second;
+ }
+
+ const ValueType at(const KeyType& key) const {
+ auto item = find(key);
+ if (item == cend()) {
+ throw std::out_of_range("at(): value out of range");
+ }
+ return item->second;
+ }
+
+ // TODO update assign interface, operator[], at
+
+ size_type erase(const key_type& k) {
+ auto segment = pickSegment(k);
+ auto seg = segments_[segment].load(std::memory_order_acquire);
+ if (!seg) {
+ return 0;
+ } else {
+ return seg->erase(k);
+ }
+ }
+
+ // Calls the hash function, and therefore may throw.
+ ConstIterator erase(ConstIterator& pos) {
+ auto segment = pickSegment(pos->first);
+ ConstIterator res(this, segment);
+    ensureSegment(segment)->erase(res.it_, pos.it_);
+    res.next(); // May point to the segment end, and may need to advance.
+ return res;
+ }
+
+  // NOT noexcept: allocates new buckets for each existing shard segment.
+ void clear() {
+ for (uint64_t i = 0; i < NumShards; i++) {
+ auto seg = segments_[i].load(std::memory_order_acquire);
+ if (seg) {
+ seg->clear();
+ }
+ }
+ }
+
+ void reserve(size_t count) {
+ count = count >> ShardBits;
+ for (uint64_t i = 0; i < NumShards; i++) {
+ auto seg = segments_[i].load(std::memory_order_acquire);
+ if (seg) {
+ seg->rehash(count);
+ }
+ }
+ }
+
+ // This is a rolling size, and is not exact at any moment in time.
+ size_t size() const noexcept {
+ size_t res = 0;
+ for (uint64_t i = 0; i < NumShards; i++) {
+ auto seg = segments_[i].load(std::memory_order_acquire);
+ if (seg) {
+ res += seg->size();
+ }
+ }
+ return res;
+ }
+
+ float max_load_factor() const {
+ return load_factor_;
+ }
+
+ void max_load_factor(float factor) {
+ for (uint64_t i = 0; i < NumShards; i++) {
+ auto seg = segments_[i].load(std::memory_order_acquire);
+ if (seg) {
+ seg->max_load_factor(factor);
+ }
+ }
+ }
+
+ class ConstIterator {
+ public:
+ friend class ConcurrentHashMap;
+
+ const value_type& operator*() const {
+ return *it_;
+ }
+
+ const value_type* operator->() const {
+ return &*it_;
+ }
+
+ ConstIterator& operator++() {
+ it_++;
+ next();
+ return *this;
+ }
+
+ ConstIterator operator++(int) {
+ auto prev = *this;
+ ++*this;
+ return prev;
+ }
+
+ bool operator==(const ConstIterator& o) const {
+ return it_ == o.it_ && segment_ == o.segment_;
+ }
+
+ bool operator!=(const ConstIterator& o) const {
+ return !(*this == o);
+ }
+
+    ConstIterator& operator=(const ConstIterator& o) {
+      it_ = o.it_;
+      segment_ = o.segment_;
+      parent_ = o.parent_;
+      return *this;
+    }
+
+    ConstIterator(const ConstIterator& o) {
+      it_ = o.it_;
+      segment_ = o.segment_;
+      parent_ = o.parent_;
+    }
+
+ ConstIterator(const ConcurrentHashMap* parent, uint64_t segment)
+ : segment_(segment), parent_(parent) {}
+
+ private:
+ // cbegin iterator
+ explicit ConstIterator(const ConcurrentHashMap* parent)
+ : it_(parent->ensureSegment(0)->cbegin()),
+ segment_(0),
+ parent_(parent) {
+ // Always iterate to the first element, could be in any shard.
+ next();
+ }
+
+ // cend iterator
+ explicit ConstIterator(uint64_t shards) : it_(nullptr), segment_(shards) {}
+
+ void next() {
+      // Check the shard bound before indexing segments_, so we never read
+      // one past the end of the shard array.
+      while (segment_ < parent_->NumShards &&
+             it_ == parent_->ensureSegment(segment_)->cend()) {
+        segment_++;
+        if (segment_ < parent_->NumShards) {
+          auto seg =
+              parent_->segments_[segment_].load(std::memory_order_acquire);
+          if (!seg) {
+            continue;
+          }
+          it_ = seg->cbegin();
+        }
+ }
+ }
+
+ typename SegmentT::Iterator it_;
+ uint64_t segment_;
+ const ConcurrentHashMap* parent_;
+ };
+
+ private:
+ uint64_t pickSegment(const KeyType& k) const {
+ auto h = HashFn()(k);
+ // Use the lowest bits for our shard bits.
+ //
+ // This works well even if the hash function is biased towards the
+ // low bits: The sharding will happen in the segments_ instead of
+ // in the segment buckets, so we'll still get write sharding as
+ // well.
+ //
+ // Low-bit bias happens often for std::hash using small numbers,
+ // since the integer hash function is the identity function.
+ return h & (NumShards - 1);
+ }
+
+ SegmentT* ensureSegment(uint64_t i) const {
+ auto seg = segments_[i].load(std::memory_order_acquire);
+ if (!seg) {
+ auto newseg = (SegmentT*)Allocator().allocate(sizeof(SegmentT));
+ new (newseg)
+ SegmentT(size_ >> ShardBits, load_factor_, max_size_ >> ShardBits);
+ if (!segments_[i].compare_exchange_strong(seg, newseg)) {
+ // seg is updated with new value, delete ours.
+ newseg->~SegmentT();
+ Allocator().deallocate((uint8_t*)newseg, sizeof(SegmentT));
+ } else {
+ seg = newseg;
+ }
+ }
+ return seg;
+ }
+
+ mutable Atom<SegmentT*> segments_[NumShards];
+ size_t size_{0};
+ size_t max_size_{0};
+};
+
+} // namespace folly
--- /dev/null
+/*
+ * Copyright 2017-present Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <folly/experimental/hazptr/hazptr.h>
+#include <atomic>
+#include <mutex>
+
+namespace folly {
+
+namespace detail {
+
+namespace concurrenthashmap {
+
+// hazptr retire() that can use an allocator.
+template <typename Allocator>
+class HazptrDeleter {
+ public:
+ template <typename Node>
+ void operator()(Node* node) {
+ node->~Node();
+ Allocator().deallocate((uint8_t*)node, sizeof(Node));
+ }
+};
+
+template <
+ typename KeyType,
+ typename ValueType,
+ typename Allocator,
+ typename Enabled = void>
+class ValueHolder {
+ public:
+ typedef std::pair<const KeyType, ValueType> value_type;
+
+ explicit ValueHolder(const ValueHolder& other) : item_(other.item_) {}
+
+ template <typename... Args>
+ ValueHolder(const KeyType& k, Args&&... args)
+ : item_(
+ std::piecewise_construct,
+ std::forward_as_tuple(k),
+ std::forward_as_tuple(std::forward<Args>(args)...)) {}
+ value_type& getItem() {
+ return item_;
+ }
+
+ private:
+ value_type item_;
+};
+
+// If the ValueType is not nothrow copy constructible, we add an extra
+// indirection instead. This costs extra allocations / deallocations and
+// pulls in an extra cacheline.
+template <typename KeyType, typename ValueType, typename Allocator>
+class ValueHolder<
+ KeyType,
+ ValueType,
+ Allocator,
+ std::enable_if_t<!std::is_nothrow_copy_constructible<ValueType>::value>> {
+ public:
+ typedef std::pair<const KeyType, ValueType> value_type;
+
+ explicit ValueHolder(const ValueHolder& other) {
+ other.owned_ = false;
+ item_ = other.item_;
+ }
+
+ template <typename... Args>
+ ValueHolder(const KeyType& k, Args&&... args) {
+ item_ = (value_type*)Allocator().allocate(sizeof(value_type));
+ new (item_) value_type(
+ std::piecewise_construct,
+ std::forward_as_tuple(k),
+ std::forward_as_tuple(std::forward<Args>(args)...));
+ }
+
+ ~ValueHolder() {
+ if (owned_) {
+ item_->~value_type();
+ Allocator().deallocate((uint8_t*)item_, sizeof(value_type));
+ }
+ }
+
+ value_type& getItem() {
+ return *item_;
+ }
+
+ private:
+ value_type* item_;
+ mutable bool owned_{true};
+};
+
+template <
+ typename KeyType,
+ typename ValueType,
+ typename Allocator,
+ template <typename> class Atom = std::atomic>
+class NodeT : public folly::hazptr::hazptr_obj_base<
+ NodeT<KeyType, ValueType, Allocator, Atom>,
+ concurrenthashmap::HazptrDeleter<Allocator>> {
+ public:
+ typedef std::pair<const KeyType, ValueType> value_type;
+
+ explicit NodeT(NodeT* other) : item_(other->item_) {}
+
+ template <typename... Args>
+ NodeT(const KeyType& k, Args&&... args)
+ : item_(k, std::forward<Args>(args)...) {}
+
+  /* Nodes are refcounted: If a node is retired (via retire()) while a
+     writer is traversing the chain, the rest of the chain must remain
+     valid until all readers are finished. This includes the shared tail
+     portion of the chain, as well as both the old and new hash buckets
+     that may point to the same portion; an erased node also keeps a
+     reference to its successor until it is reclaimed. */
+ void acquire() {
+ DCHECK(refcount_.load() != 0);
+ refcount_.fetch_add(1);
+ }
+ void release() {
+ if (refcount_.fetch_sub(1) == 1 /* was previously 1 */) {
+ this->retire(
+ folly::hazptr::default_hazptr_domain(),
+ concurrenthashmap::HazptrDeleter<Allocator>());
+ }
+ }
+ ~NodeT() {
+ auto next = next_.load(std::memory_order_acquire);
+ if (next) {
+ next->release();
+ }
+ }
+
+ value_type& getItem() {
+ return item_.getItem();
+ }
+ Atom<NodeT*> next_{nullptr};
+
+ private:
+ ValueHolder<KeyType, ValueType, Allocator> item_;
+ Atom<uint8_t> refcount_{1};
+};
+
+} // namespace concurrenthashmap
+
+/* A Segment is a single shard of the ConcurrentHashMap.
+ * All writes take the lock, while readers are all wait-free.
+ * Readers always proceed in parallel with the single writer.
+ *
+ *
+ * Possible additional optimizations:
+ *
+ * * insert / erase could be lock / wait free. Would need to be
+ * careful that assign and rehash don't conflict (possibly with
+ * reader/writer lock, or microlock per node or per bucket, etc).
+ * Java 8 goes halfway, and does a lock per bucket, except for the
+ * first item, which is inserted with a CAS (which is somewhat
+ * specific to Java having a lock per object).
+ *
+ * * I tried using trylock() and find() to warm the cache for insert()
+ * and erase() similar to Java 7, but didn't have much luck.
+ *
+ * * We could order elements using split ordering, for faster rehash,
+ * and no need to ever copy nodes. Note that a full split ordering
+ * including dummy nodes increases the memory usage by 2x, but we
+ * could split the difference and still require a lock to set bucket
+ * pointers.
+ *
+ * * hazptr acquire/release could be optimized more; in the
+ * single-threaded case, hazptr overhead is ~30% for a hot find()
+ * loop.
+ */
+template <
+ typename KeyType,
+ typename ValueType,
+ uint8_t ShardBits = 0,
+ typename HashFn = std::hash<KeyType>,
+ typename KeyEqual = std::equal_to<KeyType>,
+ typename Allocator = std::allocator<uint8_t>,
+ template <typename> class Atom = std::atomic,
+ class Mutex = std::mutex>
+class FOLLY_ALIGNED(64) ConcurrentHashMapSegment {
+ enum class InsertType {
+ DOES_NOT_EXIST, // insert/emplace operations. If key exists, return false.
+ MUST_EXIST, // assign operations. If key does not exist, return false.
+ ANY, // insert_or_assign.
+ MATCH, // assign_if_equal (not in std). For concurrent maps, a
+ // way to atomically change a value if equal to some other
+ // value.
+ };
+
+ public:
+ typedef KeyType key_type;
+ typedef ValueType mapped_type;
+ typedef std::pair<const KeyType, ValueType> value_type;
+ typedef std::size_t size_type;
+
+ using Node = concurrenthashmap::NodeT<KeyType, ValueType, Allocator, Atom>;
+ class Iterator;
+
+ ConcurrentHashMapSegment(
+ size_t initial_buckets,
+ float load_factor,
+ size_t max_size)
+ : load_factor_(load_factor) {
+ auto buckets = (Buckets*)Allocator().allocate(sizeof(Buckets));
+ initial_buckets = folly::nextPowTwo(initial_buckets);
+ if (max_size != 0) {
+ max_size_ = folly::nextPowTwo(max_size);
+ }
+    if (max_size_ > max_size) {
+      // Round down to the largest power of two not exceeding max_size.
+      max_size_ >>= 1;
+    }
+
+ CHECK(max_size_ == 0 || (folly::popcount(max_size_ - 1) + ShardBits <= 32));
+ new (buckets) Buckets(initial_buckets);
+ buckets_.store(buckets, std::memory_order_release);
+ load_factor_nodes_ = initial_buckets * load_factor_;
+ }
+
+ ~ConcurrentHashMapSegment() {
+ auto buckets = buckets_.load(std::memory_order_relaxed);
+ // We can delete and not retire() here, since users must have
+ // their own synchronization around destruction.
+ buckets->~Buckets();
+ Allocator().deallocate((uint8_t*)buckets, sizeof(Buckets));
+ }
+
+ size_t size() {
+ return size_;
+ }
+
+ bool empty() {
+ return size() == 0;
+ }
+
+ bool insert(Iterator& it, std::pair<key_type, mapped_type>&& foo) {
+ return insert(it, foo.first, foo.second);
+ }
+
+ bool insert(Iterator& it, const KeyType& k, const ValueType& v) {
+ auto node = (Node*)Allocator().allocate(sizeof(Node));
+ new (node) Node(k, v);
+ auto res = insert_internal(
+ it,
+ k,
+ InsertType::DOES_NOT_EXIST,
+ [](const ValueType&) { return false; },
+ node,
+ v);
+ if (!res) {
+ node->~Node();
+ Allocator().deallocate((uint8_t*)node, sizeof(Node));
+ }
+ return res;
+ }
+
+ template <typename... Args>
+ bool try_emplace(Iterator& it, const KeyType& k, Args&&... args) {
+ return insert_internal(
+ it,
+ k,
+ InsertType::DOES_NOT_EXIST,
+ [](const ValueType&) { return false; },
+ nullptr,
+ std::forward<Args>(args)...);
+ }
+
+ template <typename... Args>
+ bool emplace(Iterator& it, const KeyType& k, Node* node) {
+ return insert_internal(
+ it,
+ k,
+ InsertType::DOES_NOT_EXIST,
+ [](const ValueType&) { return false; },
+ node);
+ }
+
+ bool insert_or_assign(Iterator& it, const KeyType& k, const ValueType& v) {
+ return insert_internal(
+ it,
+ k,
+ InsertType::ANY,
+ [](const ValueType&) { return false; },
+ nullptr,
+ v);
+ }
+
+ bool assign(Iterator& it, const KeyType& k, const ValueType& v) {
+ auto node = (Node*)Allocator().allocate(sizeof(Node));
+ new (node) Node(k, v);
+ auto res = insert_internal(
+ it,
+ k,
+ InsertType::MUST_EXIST,
+ [](const ValueType&) { return false; },
+ node,
+ v);
+ if (!res) {
+ node->~Node();
+ Allocator().deallocate((uint8_t*)node, sizeof(Node));
+ }
+ return res;
+ }
+
+ bool assign_if_equal(
+ Iterator& it,
+ const KeyType& k,
+ const ValueType& expected,
+ const ValueType& desired) {
+ return insert_internal(
+ it,
+ k,
+ InsertType::MATCH,
+ [expected](const ValueType& v) { return v == expected; },
+ nullptr,
+ desired);
+ }
+
+ template <typename MatchFunc, typename... Args>
+ bool insert_internal(
+ Iterator& it,
+ const KeyType& k,
+ InsertType type,
+ MatchFunc match,
+ Node* cur,
+ Args&&... args) {
+ auto h = HashFn()(k);
+ std::unique_lock<Mutex> g(m_);
+
+ auto buckets = buckets_.load(std::memory_order_relaxed);
+ // Check for rehash needed for DOES_NOT_EXIST
+ if (size_ >= load_factor_nodes_ && type == InsertType::DOES_NOT_EXIST) {
+ if (max_size_ && size_ << 1 > max_size_) {
+ // Would exceed max size.
+ throw std::bad_alloc();
+ }
+ rehash(buckets->bucket_count_ << 1);
+ buckets = buckets_.load(std::memory_order_relaxed);
+ }
+
+ auto idx = getIdx(buckets, h);
+ auto head = &buckets->buckets_[idx];
+ auto node = head->load(std::memory_order_relaxed);
+ auto headnode = node;
+ auto prev = head;
+ it.buckets_hazptr_.reset(buckets);
+ while (node) {
+ // Is the key found?
+ if (KeyEqual()(k, node->getItem().first)) {
+ it.setNode(node, buckets, idx);
+ it.node_hazptr_.reset(node);
+ if (type == InsertType::MATCH) {
+ if (!match(node->getItem().second)) {
+ return false;
+ }
+ }
+ if (type == InsertType::DOES_NOT_EXIST) {
+ return false;
+ } else {
+ if (!cur) {
+ cur = (Node*)Allocator().allocate(sizeof(Node));
+ new (cur) Node(k, std::forward<Args>(args)...);
+ }
+ auto next = node->next_.load(std::memory_order_relaxed);
+ cur->next_.store(next, std::memory_order_relaxed);
+ if (next) {
+ next->acquire();
+ }
+ prev->store(cur, std::memory_order_release);
+ g.unlock();
+ // Release not under lock.
+ node->release();
+ return true;
+ }
+ }
+
+ prev = &node->next_;
+ node = node->next_.load(std::memory_order_relaxed);
+ }
+ if (type != InsertType::DOES_NOT_EXIST && type != InsertType::ANY) {
+ it.node_hazptr_.reset();
+ it.buckets_hazptr_.reset();
+ return false;
+ }
+ // Node not found, check for rehash on ANY
+ if (size_ >= load_factor_nodes_ && type == InsertType::ANY) {
+ if (max_size_ && size_ << 1 > max_size_) {
+ // Would exceed max size.
+ throw std::bad_alloc();
+ }
+ rehash(buckets->bucket_count_ << 1);
+
+ // Reload correct bucket.
+ buckets = buckets_.load(std::memory_order_relaxed);
+ it.buckets_hazptr_.reset(buckets);
+ idx = getIdx(buckets, h);
+ head = &buckets->buckets_[idx];
+ headnode = head->load(std::memory_order_relaxed);
+ }
+
+ // We found a slot to put the node.
+ size_++;
+ if (!cur) {
+ // InsertType::ANY
+ // OR DOES_NOT_EXIST, but only in the try_emplace case
+ DCHECK(type == InsertType::ANY || type == InsertType::DOES_NOT_EXIST);
+ cur = (Node*)Allocator().allocate(sizeof(Node));
+ new (cur) Node(k, std::forward<Args>(args)...);
+ }
+ cur->next_.store(headnode, std::memory_order_relaxed);
+ head->store(cur, std::memory_order_release);
+ it.setNode(cur, buckets, idx);
+ return true;
+ }
+
+ // Must hold lock.
+ void rehash(size_t bucket_count) {
+ auto buckets = buckets_.load(std::memory_order_relaxed);
+ auto newbuckets = (Buckets*)Allocator().allocate(sizeof(Buckets));
+ new (newbuckets) Buckets(bucket_count);
+
+ load_factor_nodes_ = bucket_count * load_factor_;
+
+ for (size_t i = 0; i < buckets->bucket_count_; i++) {
+ auto bucket = &buckets->buckets_[i];
+ auto node = bucket->load(std::memory_order_relaxed);
+ if (!node) {
+ continue;
+ }
+ auto h = HashFn()(node->getItem().first);
+ auto idx = getIdx(newbuckets, h);
+      // Reuse as long a run of nodes as possible from the end of the
+      // chain. Since the nodes don't have previous pointers, the longest
+      // trailing run whose nodes all hash to the same new bucket can be
+      // shared between the old table and the new one.
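+      // For example, with a chain A -> B -> C -> D where C and D map to
+      // the same new bucket and B does not, C and D are linked into the
+      // new table as-is and only A and B are cloned below.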
+ auto lastrun = node;
+ auto lastidx = idx;
+ auto count = 0;
+ auto last = node->next_.load(std::memory_order_relaxed);
+ for (; last != nullptr;
+ last = last->next_.load(std::memory_order_relaxed)) {
+ auto k = getIdx(newbuckets, HashFn()(last->getItem().first));
+ if (k != lastidx) {
+ lastidx = k;
+ lastrun = last;
+ count = 0;
+ }
+ count++;
+ }
+ // Set longest last run in new bucket, incrementing the refcount.
+ lastrun->acquire();
+ newbuckets->buckets_[lastidx].store(lastrun, std::memory_order_relaxed);
+ // Clone remaining nodes
+ for (; node != lastrun;
+ node = node->next_.load(std::memory_order_relaxed)) {
+ auto newnode = (Node*)Allocator().allocate(sizeof(Node));
+ new (newnode) Node(node);
+ auto k = getIdx(newbuckets, HashFn()(node->getItem().first));
+ auto prevhead = &newbuckets->buckets_[k];
+ newnode->next_.store(prevhead->load(std::memory_order_relaxed));
+ prevhead->store(newnode, std::memory_order_relaxed);
+ }
+ }
+
+ auto oldbuckets = buckets_.load(std::memory_order_relaxed);
+ buckets_.store(newbuckets, std::memory_order_release);
+ oldbuckets->retire(
+ folly::hazptr::default_hazptr_domain(),
+ concurrenthashmap::HazptrDeleter<Allocator>());
+ }
+
+ bool find(Iterator& res, const KeyType& k) {
+ folly::hazptr::hazptr_holder haznext;
+ auto h = HashFn()(k);
+ auto buckets = res.buckets_hazptr_.get_protected(buckets_);
+ auto idx = getIdx(buckets, h);
+ auto prev = &buckets->buckets_[idx];
+ auto node = res.node_hazptr_.get_protected(*prev);
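+    // Hand-over-hand protection: haznext protects the next node before we
+    // step to it, then swaps with res.node_hazptr_ so the node we may
+    // return to the caller stays protected.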
+ while (node) {
+ if (KeyEqual()(k, node->getItem().first)) {
+ res.setNode(node, buckets, idx);
+ return true;
+ }
+ node = haznext.get_protected(node->next_);
+ haznext.swap(res.node_hazptr_);
+ }
+ return false;
+ }
+
+ // Listed separately because we need a prev pointer.
+ size_type erase(const key_type& key) {
+ return erase_internal(key, nullptr);
+ }
+
+ size_type erase_internal(const key_type& key, Iterator* iter) {
+ Node* node{nullptr};
+ auto h = HashFn()(key);
+ {
+ std::lock_guard<Mutex> g(m_);
+
+ auto buckets = buckets_.load(std::memory_order_relaxed);
+ auto idx = getIdx(buckets, h);
+ auto head = &buckets->buckets_[idx];
+ node = head->load(std::memory_order_relaxed);
+ Node* prev = nullptr;
+ auto headnode = node;
+ while (node) {
+ if (KeyEqual()(key, node->getItem().first)) {
+ auto next = node->next_.load(std::memory_order_relaxed);
+ if (next) {
+ next->acquire();
+ }
+ if (prev) {
+ prev->next_.store(next, std::memory_order_release);
+ } else {
+ // Must be head of list.
+ head->store(next, std::memory_order_release);
+ }
+
+ if (iter) {
+ iter->buckets_hazptr_.reset(buckets);
+ iter->setNode(
+ node->next_.load(std::memory_order_acquire), buckets, idx);
+ }
+ size_--;
+ break;
+ }
+ prev = node;
+ node = node->next_.load(std::memory_order_relaxed);
+ }
+ }
+ // Delete the node while not under the lock.
+ if (node) {
+ node->release();
+ return 1;
+ }
+ DCHECK(!iter);
+
+ return 0;
+ }
+
+ // Unfortunately because we are reusing nodes on rehash, we can't
+ // have prev pointers in the bucket chain. We have to start the
+ // search from the bucket.
+ //
+ // This is a small departure from standard stl containers: erase may
+ // throw if hash or key_eq functions throw.
+ void erase(Iterator& res, Iterator& pos) {
+ auto cnt = erase_internal(pos->first, &res);
+ DCHECK(cnt == 1);
+ }
+
+ void clear() {
+ auto buckets = buckets_.load(std::memory_order_relaxed);
+ auto newbuckets = (Buckets*)Allocator().allocate(sizeof(Buckets));
+ new (newbuckets) Buckets(buckets->bucket_count_);
+ {
+ std::lock_guard<Mutex> g(m_);
+ buckets_.store(newbuckets, std::memory_order_release);
+ size_ = 0;
+ }
+ buckets->retire(
+ folly::hazptr::default_hazptr_domain(),
+ concurrenthashmap::HazptrDeleter<Allocator>());
+ }
+
+ void max_load_factor(float factor) {
+ std::lock_guard<Mutex> g(m_);
+ load_factor_ = factor;
+ auto buckets = buckets_.load(std::memory_order_relaxed);
+ load_factor_nodes_ = buckets->bucket_count_ * load_factor_;
+ }
+
+ Iterator cbegin() {
+ Iterator res;
+ auto buckets = res.buckets_hazptr_.get_protected(buckets_);
+ res.setNode(nullptr, buckets, 0);
+ res.next();
+ return res;
+ }
+
+ Iterator cend() {
+ return Iterator(nullptr);
+ }
+
+ // Could be optimized to avoid an extra pointer dereference by
+ // allocating buckets_ at the same time.
+ class Buckets : public folly::hazptr::hazptr_obj_base<
+ Buckets,
+ concurrenthashmap::HazptrDeleter<Allocator>> {
+ public:
+ explicit Buckets(size_t count) : bucket_count_(count) {
+ buckets_ =
+ (Atom<Node*>*)Allocator().allocate(sizeof(Atom<Node*>) * count);
+ new (buckets_) Atom<Node*>[ count ];
+ for (size_t i = 0; i < count; i++) {
+ buckets_[i].store(nullptr, std::memory_order_relaxed);
+ }
+ }
+ ~Buckets() {
+ for (size_t i = 0; i < bucket_count_; i++) {
+ auto elem = buckets_[i].load(std::memory_order_relaxed);
+ if (elem) {
+ elem->release();
+ }
+ }
+ Allocator().deallocate(
+ (uint8_t*)buckets_, sizeof(Atom<Node*>) * bucket_count_);
+ }
+
+ size_t bucket_count_;
+ Atom<Node*>* buckets_{nullptr};
+ };
+
+ public:
+ class Iterator {
+ public:
+ FOLLY_ALWAYS_INLINE Iterator() {}
+ FOLLY_ALWAYS_INLINE explicit Iterator(std::nullptr_t)
+ : buckets_hazptr_(nullptr), node_hazptr_(nullptr) {}
+ FOLLY_ALWAYS_INLINE ~Iterator() {}
+
+ void setNode(Node* node, Buckets* buckets, uint64_t idx) {
+ node_ = node;
+ buckets_ = buckets;
+ idx_ = idx;
+ }
+
+ const value_type& operator*() const {
+ DCHECK(node_);
+ return node_->getItem();
+ }
+
+ const value_type* operator->() const {
+ DCHECK(node_);
+ return &(node_->getItem());
+ }
+
+ const Iterator& operator++() {
+ DCHECK(node_);
+ node_ = node_hazptr_.get_protected(node_->next_);
+ if (!node_) {
+ ++idx_;
+ next();
+ }
+ return *this;
+ }
+
+ void next() {
+ while (!node_) {
+ if (idx_ >= buckets_->bucket_count_) {
+ break;
+ }
+ DCHECK(buckets_);
+ DCHECK(buckets_->buckets_);
+ node_ = node_hazptr_.get_protected(buckets_->buckets_[idx_]);
+ if (node_) {
+ break;
+ }
+ ++idx_;
+ }
+ }
+
+ Iterator operator++(int) {
+ auto prev = *this;
+ ++*this;
+ return prev;
+ }
+
+ bool operator==(const Iterator& o) const {
+ return node_ == o.node_;
+ }
+
+ bool operator!=(const Iterator& o) const {
+ return !(*this == o);
+ }
+
+ Iterator& operator=(const Iterator& o) {
+ node_ = o.node_;
+ node_hazptr_.reset(node_);
+ idx_ = o.idx_;
+ buckets_ = o.buckets_;
+ buckets_hazptr_.reset(buckets_);
+ return *this;
+ }
+
+ /* implicit */ Iterator(const Iterator& o) {
+ node_ = o.node_;
+ node_hazptr_.reset(node_);
+ idx_ = o.idx_;
+ buckets_ = o.buckets_;
+ buckets_hazptr_.reset(buckets_);
+ }
+
+ /* implicit */ Iterator(Iterator&& o) noexcept
+ : buckets_hazptr_(std::move(o.buckets_hazptr_)),
+ node_hazptr_(std::move(o.node_hazptr_)) {
+ node_ = o.node_;
+ buckets_ = o.buckets_;
+ idx_ = o.idx_;
+ }
+
+ // These are accessed directly from the functions above
+ folly::hazptr::hazptr_holder buckets_hazptr_;
+ folly::hazptr::hazptr_holder node_hazptr_;
+
+ private:
+ Node* node_{nullptr};
+ Buckets* buckets_{nullptr};
+ uint64_t idx_;
+ };
+
+ private:
+  // The shard index has already consumed the low ShardBits of the hash.
+  // Shift them away so the buckets use fresh bits.
+ uint64_t getIdx(Buckets* buckets, size_t hash) {
+ return (hash >> ShardBits) & (buckets->bucket_count_ - 1);
+ }
+
+ float load_factor_;
+ size_t load_factor_nodes_;
+ size_t size_{0};
+ size_t max_size_{0};
+ Atom<Buckets*> buckets_{nullptr};
+ Mutex m_;
+};
+} // namespace detail
+} // namespace folly
--- /dev/null
+/*
+ * Copyright 2017-present Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <atomic>
+#include <memory>
+#include <thread>
+
+#include <folly/Hash.h>
+#include <folly/concurrency/ConcurrentHashMap.h>
+#include <folly/portability/GTest.h>
+#include <folly/test/DeterministicSchedule.h>
+
+using namespace folly::test;
+using namespace folly;
+using namespace std;
+
+DEFINE_int64(seed, 0, "Seed for random number generators");
+
+TEST(ConcurrentHashMap, MapTest) {
+ ConcurrentHashMap<uint64_t, uint64_t> foomap(3);
+ foomap.max_load_factor(1.05);
+ EXPECT_TRUE(foomap.empty());
+ EXPECT_EQ(foomap.find(1), foomap.cend());
+ auto r = foomap.insert(1, 0);
+ EXPECT_TRUE(r.second);
+ auto r2 = foomap.insert(1, 0);
+ EXPECT_EQ(r.first->second, 0);
+ EXPECT_EQ(r.first->first, 1);
+ EXPECT_EQ(r2.first->second, 0);
+ EXPECT_EQ(r2.first->first, 1);
+ EXPECT_EQ(r.first, r2.first);
+ EXPECT_TRUE(r.second);
+ EXPECT_FALSE(r2.second);
+ EXPECT_FALSE(foomap.empty());
+ EXPECT_TRUE(foomap.insert(std::make_pair(2, 0)).second);
+ EXPECT_TRUE(foomap.insert_or_assign(2, 0).second);
+ EXPECT_TRUE(foomap.assign_if_equal(2, 0, 3));
+ EXPECT_TRUE(foomap.insert(3, 0).second);
+ EXPECT_NE(foomap.find(1), foomap.cend());
+ EXPECT_NE(foomap.find(2), foomap.cend());
+ EXPECT_EQ(foomap.find(2)->second, 3);
+ EXPECT_EQ(foomap[2], 3);
+ EXPECT_EQ(foomap[20], 0);
+ EXPECT_EQ(foomap.at(20), 0);
+ EXPECT_FALSE(foomap.insert(1, 0).second);
+ auto l = foomap.find(1);
+ foomap.erase(l);
+ EXPECT_FALSE(foomap.erase(1));
+ EXPECT_EQ(foomap.find(1), foomap.cend());
+ auto res = foomap.find(2);
+ EXPECT_NE(res, foomap.cend());
+ EXPECT_EQ(3, res->second);
+ EXPECT_FALSE(foomap.empty());
+ foomap.clear();
+ EXPECT_TRUE(foomap.empty());
+}
+
+TEST(ConcurrentHashMap, MaxSizeTest) {
+ ConcurrentHashMap<uint64_t, uint64_t> foomap(2, 16);
+ bool insert_failed = false;
+ for (int i = 0; i < 32; i++) {
+ auto res = foomap.insert(0, 0);
+ if (!res.second) {
+ insert_failed = true;
+ }
+ }
+ EXPECT_TRUE(insert_failed);
+}
+
+TEST(ConcurrentHashMap, MoveTest) {
+ ConcurrentHashMap<uint64_t, uint64_t> foomap(2, 16);
+ auto other = std::move(foomap);
+ auto other2 = std::move(other);
+ other = std::move(other2);
+}
+
+struct foo {
+ static int moved;
+ static int copied;
+ foo(foo&& o) noexcept {
+ (void*)&o;
+ moved++;
+ }
+ foo& operator=(foo&& o) {
+ (void*)&o;
+ moved++;
+ return *this;
+ }
+ foo& operator=(const foo& o) {
+ (void*)&o;
+ copied++;
+ return *this;
+ }
+ foo(const foo& o) {
+ (void*)&o;
+ copied++;
+ }
+ foo() {}
+};
+int foo::moved{0};
+int foo::copied{0};
+
+TEST(ConcurrentHashMap, EmplaceTest) {
+ ConcurrentHashMap<uint64_t, foo> foomap(200);
+ foomap.insert(1, foo());
+ EXPECT_EQ(foo::moved, 0);
+ EXPECT_EQ(foo::copied, 1);
+ foo::copied = 0;
+ // The difference between emplace and try_emplace:
+ // If insertion fails, try_emplace does not move its argument
+ foomap.try_emplace(1, foo());
+ EXPECT_EQ(foo::moved, 0);
+ EXPECT_EQ(foo::copied, 0);
+ foomap.emplace(1, foo());
+ EXPECT_EQ(foo::moved, 1);
+ EXPECT_EQ(foo::copied, 0);
+}
+
+TEST(ConcurrentHashMap, MapResizeTest) {
+ ConcurrentHashMap<uint64_t, uint64_t> foomap(2);
+ EXPECT_EQ(foomap.find(1), foomap.cend());
+ EXPECT_TRUE(foomap.insert(1, 0).second);
+ EXPECT_TRUE(foomap.insert(2, 0).second);
+ EXPECT_TRUE(foomap.insert(3, 0).second);
+ EXPECT_TRUE(foomap.insert(4, 0).second);
+ foomap.reserve(512);
+ EXPECT_NE(foomap.find(1), foomap.cend());
+ EXPECT_NE(foomap.find(2), foomap.cend());
+ EXPECT_FALSE(foomap.insert(1, 0).second);
+ EXPECT_TRUE(foomap.erase(1));
+ EXPECT_EQ(foomap.find(1), foomap.cend());
+ auto res = foomap.find(2);
+ EXPECT_NE(res, foomap.cend());
+ if (res != foomap.cend()) {
+ EXPECT_EQ(0, res->second);
+ }
+}
+
+// Ensure we can insert objects without copy constructors.
+TEST(ConcurrentHashMap, MapNoCopiesTest) {
+ struct Uncopyable {
+ Uncopyable(int i) {
+ (void*)&i;
+ }
+ Uncopyable(const Uncopyable& that) = delete;
+ };
+ ConcurrentHashMap<uint64_t, Uncopyable> foomap(2);
+ EXPECT_TRUE(foomap.try_emplace(1, 1).second);
+ EXPECT_TRUE(foomap.try_emplace(2, 2).second);
+ auto res = foomap.find(2);
+ EXPECT_NE(res, foomap.cend());
+
+ EXPECT_TRUE(foomap.try_emplace(3, 3).second);
+
+ auto res2 = foomap.find(2);
+ EXPECT_NE(res2, foomap.cend());
+ EXPECT_EQ(&(res->second), &(res2->second));
+}
+
+TEST(ConcurrentHashMap, MapUpdateTest) {
+ ConcurrentHashMap<uint64_t, uint64_t> foomap(2);
+ EXPECT_TRUE(foomap.insert(1, 10).second);
+ EXPECT_TRUE(bool(foomap.assign(1, 11)));
+ auto res = foomap.find(1);
+ EXPECT_NE(res, foomap.cend());
+ EXPECT_EQ(11, res->second);
+}
+
+TEST(ConcurrentHashMap, MapIterateTest2) {
+ ConcurrentHashMap<uint64_t, uint64_t> foomap(2);
+ auto begin = foomap.cbegin();
+ auto end = foomap.cend();
+ EXPECT_EQ(begin, end);
+}
+
+TEST(ConcurrentHashMap, MapIterateTest) {
+ ConcurrentHashMap<uint64_t, uint64_t> foomap(2);
+ EXPECT_EQ(foomap.cbegin(), foomap.cend());
+ EXPECT_TRUE(foomap.insert(1, 1).second);
+ EXPECT_TRUE(foomap.insert(2, 2).second);
+ auto iter = foomap.cbegin();
+ EXPECT_NE(iter, foomap.cend());
+ EXPECT_EQ(iter->first, 1);
+ EXPECT_EQ(iter->second, 1);
+ iter++;
+ EXPECT_NE(iter, foomap.cend());
+ EXPECT_EQ(iter->first, 2);
+ EXPECT_EQ(iter->second, 2);
+ iter++;
+ EXPECT_EQ(iter, foomap.cend());
+
+ int count = 0;
+ for (auto it = foomap.cbegin(); it != foomap.cend(); it++) {
+ count++;
+ }
+ EXPECT_EQ(count, 2);
+}
+
+// TODO: hazptrs must support DeterministicSchedule
+
+#define Atom std::atomic // DeterministicAtomic
+#define Mutex std::mutex // DeterministicMutex
+#define lib std // DeterministicSchedule
+#define join t.join() // DeterministicSchedule::join(t)
+// #define Atom DeterministicAtomic
+// #define Mutex DeterministicMutex
+// #define lib DeterministicSchedule
+// #define join DeterministicSchedule::join(t)
+
+TEST(ConcurrentHashMap, UpdateStressTest) {
+ DeterministicSchedule sched(DeterministicSchedule::uniform(FLAGS_seed));
+
+ // size must match iters for this test.
+ unsigned size = 128 * 128;
+ unsigned iters = size;
+ ConcurrentHashMap<
+ unsigned long,
+ unsigned long,
+ std::hash<unsigned long>,
+ std::equal_to<unsigned long>,
+ std::allocator<uint8_t>,
+ 8,
+ Atom,
+ Mutex>
+ m(2);
+
+ for (uint i = 0; i < size; i++) {
+ m.insert(i, i);
+ }
+ std::vector<std::thread> threads;
+ unsigned int num_threads = 32;
+ for (uint t = 0; t < num_threads; t++) {
+ threads.push_back(lib::thread([&, t]() {
+ int offset = (iters * t / num_threads);
+ for (uint i = 0; i < iters / num_threads; i++) {
+ unsigned long k = folly::hash::jenkins_rev_mix32((i + offset));
+ k = k % (iters / num_threads) + offset;
+ unsigned long val = 3;
+ auto res = m.find(k);
+ EXPECT_NE(res, m.cend());
+ EXPECT_EQ(k, res->second);
+ auto r = m.assign(k, res->second);
+ EXPECT_TRUE(r);
+ res = m.find(k);
+ EXPECT_NE(res, m.cend());
+ EXPECT_EQ(k, res->second);
+ // Another random insertion to force table resizes
+ val = size + i + offset;
+ EXPECT_TRUE(m.insert(val, val).second);
+ }
+ }));
+ }
+ for (auto& t : threads) {
+ join;
+ }
+}
+
+TEST(ConcurrentHashMap, EraseStressTest) {
+ DeterministicSchedule sched(DeterministicSchedule::uniform(FLAGS_seed));
+
+ unsigned size = 2;
+ unsigned iters = size * 128 * 2;
+ ConcurrentHashMap<
+ unsigned long,
+ unsigned long,
+ std::hash<unsigned long>,
+ std::equal_to<unsigned long>,
+ std::allocator<uint8_t>,
+ 8,
+ Atom,
+ Mutex>
+ m(2);
+
+ for (uint i = 0; i < size; i++) {
+ unsigned long k = folly::hash::jenkins_rev_mix32(i);
+ m.insert(k, k);
+ }
+ std::vector<std::thread> threads;
+ unsigned int num_threads = 32;
+ for (uint t = 0; t < num_threads; t++) {
+ threads.push_back(lib::thread([&, t]() {
+ int offset = (iters * t / num_threads);
+ for (uint i = 0; i < iters / num_threads; i++) {
+ unsigned long k = folly::hash::jenkins_rev_mix32((i + offset));
+ unsigned long val;
+ auto res = m.insert(k, k).second;
+ if (res) {
+ res = m.erase(k);
+ if (!res) {
+ printf("Faulre to erase thread %i val %li\n", t, k);
+ exit(0);
+ }
+ EXPECT_TRUE(res);
+ }
+ res = m.insert(k, k).second;
+ if (res) {
+ res = bool(m.assign(k, k));
+ if (!res) {
+ printf("Thread %i update fail %li res%i\n", t, k, res);
+ exit(0);
+ }
+ EXPECT_TRUE(res);
+ auto res = m.find(k);
+ if (res == m.cend()) {
+ printf("Thread %i lookup fail %li\n", t, k);
+ exit(0);
+ }
+ EXPECT_EQ(k, res->second);
+ }
+ }
+ }));
+ }
+ for (auto& t : threads) {
+ join;
+ }
+}
+
+TEST(ConcurrentHashMap, IterateStressTest) {
+ DeterministicSchedule sched(DeterministicSchedule::uniform(FLAGS_seed));
+
+ unsigned size = 2;
+ unsigned iters = size * 128 * 2;
+ ConcurrentHashMap<
+ unsigned long,
+ unsigned long,
+ std::hash<unsigned long>,
+ std::equal_to<unsigned long>,
+ std::allocator<uint8_t>,
+ 8,
+ Atom,
+ Mutex>
+ m(2);
+
+ for (uint i = 0; i < size; i++) {
+ unsigned long k = folly::hash::jenkins_rev_mix32(i);
+ m.insert(k, k);
+ }
+ for (uint i = 0; i < 10; i++) {
+ m.insert(i, i);
+ }
+ std::vector<std::thread> threads;
+ unsigned int num_threads = 32;
+ for (uint t = 0; t < num_threads; t++) {
+ threads.push_back(lib::thread([&, t]() {
+ int offset = (iters * t / num_threads);
+ for (uint i = 0; i < iters / num_threads; i++) {
+ unsigned long k = folly::hash::jenkins_rev_mix32((i + offset));
+ unsigned long val;
+ auto res = m.insert(k, k).second;
+ if (res) {
+ res = m.erase(k);
+ if (!res) {
+ printf("Faulre to erase thread %i val %li\n", t, k);
+ exit(0);
+ }
+ EXPECT_TRUE(res);
+ }
+ int count = 0;
+ for (auto it = m.cbegin(); it != m.cend(); it++) {
+ printf("Item is %li\n", it->first);
+ if (it->first < 10) {
+ count++;
+ }
+ }
+ EXPECT_EQ(count, 10);
+ }
+ }));
+ }
+ for (auto& t : threads) {
+ join;
+ }
+}
+
+TEST(ConcurrentHashMap, insertStressTest) {
+ DeterministicSchedule sched(DeterministicSchedule::uniform(FLAGS_seed));
+
+ unsigned size = 2;
+ unsigned iters = size * 64 * 4;
+ ConcurrentHashMap<
+ unsigned long,
+ unsigned long,
+ std::hash<unsigned long>,
+ std::equal_to<unsigned long>,
+ std::allocator<uint8_t>,
+ 8,
+ Atom,
+ Mutex>
+ m(2);
+
+ EXPECT_TRUE(m.insert(0, 0).second);
+ EXPECT_FALSE(m.insert(0, 0).second);
+ std::vector<std::thread> threads;
+ unsigned int num_threads = 32;
+ for (uint t = 0; t < num_threads; t++) {
+ threads.push_back(lib::thread([&, t]() {
+ int offset = (iters * t / num_threads);
+ for (uint i = 0; i < iters / num_threads; i++) {
+ auto var = offset + i + 1;
+ EXPECT_TRUE(m.insert(var, var).second);
+ EXPECT_FALSE(m.insert(0, 0).second);
+ }
+ }));
+ }
+ for (auto& t : threads) {
+ join;
+ }
+}
+
+TEST(ConcurrentHashMap, assignStressTest) {
+ DeterministicSchedule sched(DeterministicSchedule::uniform(FLAGS_seed));
+
+ unsigned size = 2;
+ unsigned iters = size * 64 * 4;
+ struct big_value {
+ uint64_t v1;
+ uint64_t v2;
+ uint64_t v3;
+ uint64_t v4;
+ uint64_t v5;
+ uint64_t v6;
+ uint64_t v7;
+ uint64_t v8;
+ void set(uint64_t v) {
+ v1 = v2 = v3 = v4 = v5 = v6 = v7 = v8 = v;
+ }
+ void check() const {
+ auto v = v1;
+ EXPECT_EQ(v, v8);
+ EXPECT_EQ(v, v7);
+ EXPECT_EQ(v, v6);
+ EXPECT_EQ(v, v5);
+ EXPECT_EQ(v, v4);
+ EXPECT_EQ(v, v3);
+ EXPECT_EQ(v, v2);
+ }
+ };
+ ConcurrentHashMap<
+ unsigned long,
+ big_value,
+ std::hash<unsigned long>,
+ std::equal_to<unsigned long>,
+ std::allocator<uint8_t>,
+ 8,
+ Atom,
+ Mutex>
+ m(2);
+
+ for (uint i = 0; i < iters; i++) {
+ big_value a;
+ a.set(i);
+ m.insert(i, a);
+ }
+
+ std::vector<std::thread> threads;
+ unsigned int num_threads = 32;
+ for (uint t = 0; t < num_threads; t++) {
+ threads.push_back(lib::thread([&]() {
+ for (uint i = 0; i < iters; i++) {
+ auto res = m.find(i);
+ EXPECT_NE(res, m.cend());
+ res->second.check();
+ big_value b;
+ b.set(res->second.v1 + 1);
+ m.assign(i, b);
+ }
+ }));
+ }
+ for (auto& t : threads) {
+ join;
+ }
+}
return false;
}
-inline class hazptr_tc& hazptr_tc() {
+FOLLY_ALWAYS_INLINE class hazptr_tc& hazptr_tc() {
static thread_local class hazptr_tc tc;
DEBUG_PRINT(&tc);
return tc;