2 * Copyright 2017-present Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
18 #include <folly/Optional.h>
19 #include <folly/concurrency/detail/ConcurrentHashMap-detail.h>
20 #include <folly/experimental/hazptr/hazptr.h>
27 * Based on Java's ConcurrentHashMap
29 * Readers are always wait-free.
30 * Writers are sharded, but take a lock.
32 * The interface is as close to std::unordered_map as possible, but there
33 * are a handful of changes:
35 * * Iterators hold hazard pointers to the returned elements. Elements can only
36 * be accessed while Iterators are still valid!
38 * * Therefore operator[] and at() return copies, since they do not
39 * return an iterator. The returned value is const, to remind you
40 * that changes do not affect the value in the map.
42 * * erase() calls the hash function, and may fail if the hash
43 * function throws an exception.
45 * * clear() initializes new segments, and is not noexcept.
47 * * The interface adds assign_if_equal, since find() doesn't take a lock.
49 * * Only const version of find() is supported, and const iterators.
50 * Mutation must use functions provided, like assign().
52 * * iteration iterates over all the buckets in the table, unlike
53 * std::unordered_map which iterates over a linked list of elements.
54 * If the table is sparse, this may be more expensive.
56 * * rehash policy is a power of two, using supplied factor.
58 * * Allocator must be stateless.
60 * * ValueTypes without copy constructors will work, but pessimize the
64 * Single-threaded performance is extremely similar to std::unordered_map.
66 * Multithreaded performance beats anything except the lock-free
67 * atomic maps (AtomicUnorderedMap, AtomicHashMap), BUT only
68 * if you can perfectly size the atomic maps, and you don't
69 * need erase(). If you don't know the size in advance or
70 * your workload frequently calls erase(), this is the
// Tail of the class template parameter list (the head, declaring KeyType and
// ValueType, is on elided lines above): hash and key-equality functors, a
// stateless byte allocator, the shard-count exponent, the atomic wrapper
// template, and the per-shard mutex type.
77     typename HashFn = std::hash<KeyType>,
78     typename KeyEqual = std::equal_to<KeyType>,
79     typename Allocator = std::allocator<uint8_t>,
80     uint8_t ShardBits = 8,
81     template <typename> class Atom = std::atomic,
82     class Mutex = std::mutex>
83 class ConcurrentHashMap {
84   using SegmentT = detail::ConcurrentHashMapSegment<
// Number of shards is a power of two so pickSegment() can select a shard
// with a cheap mask: h & (NumShards - 1).
93   static constexpr uint64_t NumShards = (1 << ShardBits);
94   // Slightly higher than 1.0, in case hashing to shards isn't
95   // perfectly balanced, reserve(size) will still work without
97   float load_factor_ = 1.05;
// Standard unordered_map-style member typedefs. Note value_type carries a
// const key, and only const iteration is exposed (const_iterator).
102   typedef KeyType key_type;
103   typedef ValueType mapped_type;
104   typedef std::pair<const KeyType, ValueType> value_type;
105   typedef std::size_t size_type;
106   typedef HashFn hasher;
107   typedef KeyEqual key_equal;
108   typedef ConstIterator const_iterator;
111    * Construct a ConcurrentHashMap with 1 << ShardBits shards, size
112    * and max_size given. Both size and max_size will be rounded up to
113    * the next power of two, if they are not already a power of two, so
114    * that we can index in to Shards efficiently.
116    * Insertion functions will throw bad_alloc if max_size is exceeded.
118   explicit ConcurrentHashMap(size_t size = 8, size_t max_size = 0) {
119     size_ = folly::nextPowTwo(size);
// NOTE(review): the rounding of max_size below appears to be guarded by an
// elided line (max_size == 0 should mean "unbounded") — confirm against the
// full source.
121     max_size_ = folly::nextPowTwo(max_size);
123     CHECK(max_size_ == 0 || max_size_ >= size_);
// Segments are created lazily by ensureSegment(); start every slot null.
// Relaxed ordering is fine here: no other thread can see the map yet.
124     for (uint64_t i = 0; i < NumShards; i++) {
125       segments_[i].store(nullptr, std::memory_order_relaxed);
// Move construction: steal each segment pointer from the source and null it
// out there, leaving the moved-from map empty. Relaxed ordering assumes no
// concurrent access to either map during the move.
129   ConcurrentHashMap(ConcurrentHashMap&& o) noexcept {
130     for (uint64_t i = 0; i < NumShards; i++) {
132           o.segments_[i].load(std::memory_order_relaxed),
133           std::memory_order_relaxed);
134       o.segments_[i].store(nullptr, std::memory_order_relaxed);
// Move assignment: tear down this map's own segments (deallocation visible
// below; the explicit destructor call is presumably on an elided line —
// confirm), then steal the source's segment pointers as in the move ctor.
// Not noexcept: not safe to call concurrently with other operations.
138   ConcurrentHashMap& operator=(ConcurrentHashMap&& o) {
139     for (uint64_t i = 0; i < NumShards; i++) {
140       auto seg = segments_[i].load(std::memory_order_relaxed);
// Segments were placement-constructed into Allocator-provided storage, so
// storage is returned via the stateless Allocator.
143         Allocator().deallocate((uint8_t*)seg, sizeof(SegmentT));
146           o.segments_[i].load(std::memory_order_relaxed),
147           std::memory_order_relaxed);
148       o.segments_[i].store(nullptr, std::memory_order_relaxed);
// Destructor: release every lazily-created segment. The matching explicit
// SegmentT destructor call is presumably on an elided line before the
// deallocate — confirm against the full source.
153   ~ConcurrentHashMap() {
154     for (uint64_t i = 0; i < NumShards; i++) {
155       auto seg = segments_[i].load(std::memory_order_relaxed);
158         Allocator().deallocate((uint8_t*)seg, sizeof(SegmentT));
// True iff every shard is empty. A null segment slot counts as empty since
// segments are only created on first write. Acquire load pairs with the
// release publication in ensureSegment().
163   bool empty() const noexcept {
164     for (uint64_t i = 0; i < NumShards; i++) {
165       auto seg = segments_[i].load(std::memory_order_acquire);
// Wait-free lookup. If the shard was never created, or the segment lookup
// fails, the iterator's segment_ is set to NumShards, which is exactly how
// cend() is represented, so `find(k) == cend()` signals "not found".
175   ConstIterator find(const KeyType& k) const {
176     auto segment = pickSegment(k);
177     ConstIterator res(this, segment);
178     auto seg = segments_[segment].load(std::memory_order_acquire);
179     if (!seg || !seg->find(res.it_, k)) {
180       res.segment_ = NumShards;
// End sentinel: an iterator whose segment_ equals NumShards (one past the
// last shard index).
185   ConstIterator cend() const noexcept {
186     return ConstIterator(NumShards);
// Begin: delegates to the ConstIterator(parent) ctor, which advances to the
// first element across shards.
189   ConstIterator cbegin() const noexcept {
190     return ConstIterator(this);
// Insert by rvalue pair. Takes pair<key_type, mapped_type> (non-const key)
// rather than value_type so the key can be moved into the node.
// Returns {iterator to the element, whether insertion happened}.
193   std::pair<ConstIterator, bool> insert(
194       std::pair<key_type, mapped_type>&& foo) {
195     auto segment = pickSegment(foo.first);
// Piecewise construction: build the ConstIterator member in place so it is
// bound to (this, segment) before the segment fills in res.first.it_.
196     std::pair<ConstIterator, bool> res(
197         std::piecewise_construct,
198         std::forward_as_tuple(this, segment),
199         std::forward_as_tuple(false));
200     res.second = ensureSegment(segment)->insert(res.first.it_, std::move(foo));
// Insert with separately forwarded key and value; avoids constructing an
// intermediate pair. Returns {iterator, inserted?}.
204   template <typename Key, typename Value>
205   std::pair<ConstIterator, bool> insert(Key&& k, Value&& v) {
206     auto segment = pickSegment(k);
207     std::pair<ConstIterator, bool> res(
208         std::piecewise_construct,
209         std::forward_as_tuple(this, segment),
210         std::forward_as_tuple(false));
211     res.second = ensureSegment(segment)->insert(
212         res.first.it_, std::forward<Key>(k), std::forward<Value>(v));
// try_emplace: forwards args to construct the mapped value only if the key
// is not already present (value construction is deferred to the segment).
// Returns {iterator, inserted?}.
216   template <typename Key, typename... Args>
217   std::pair<ConstIterator, bool> try_emplace(Key&& k, Args&&... args) {
218     auto segment = pickSegment(k);
219     std::pair<ConstIterator, bool> res(
220         std::piecewise_construct,
221         std::forward_as_tuple(this, segment),
222         std::forward_as_tuple(false));
223     res.second = ensureSegment(segment)->try_emplace(
224         res.first.it_, std::forward<Key>(k), std::forward<Args>(args)...);
// emplace: unlike try_emplace, the node must be fully constructed up front,
// because the key (needed for pickSegment) only exists after construction.
228   template <typename... Args>
229   std::pair<ConstIterator, bool> emplace(Args&&... args) {
230     using Node = typename SegmentT::Node;
// Raw allocate + placement-new; ownership passes to the segment on success.
231     auto node = (Node*)Allocator().allocate(sizeof(Node));
232     new (node) Node(std::forward<Args>(args)...);
233     auto segment = pickSegment(node->getItem().first);
234     std::pair<ConstIterator, bool> res(
235         std::piecewise_construct,
236         std::forward_as_tuple(this, segment),
237         std::forward_as_tuple(false));
238     res.second = ensureSegment(segment)->emplace(
239         res.first.it_, node->getItem().first, node);
// On the failed-insert path the node is freed here; its destructor call is
// presumably on an elided line just above — confirm against the full source.
242       Allocator().deallocate((uint8_t*)node, sizeof(Node));
// insert_or_assign: inserts if absent, otherwise overwrites the mapped
// value. Returns {iterator, true if an insert (not an assign) happened}.
247   template <typename Key, typename Value>
248   std::pair<ConstIterator, bool> insert_or_assign(Key&& k, Value&& v) {
249     auto segment = pickSegment(k);
250     std::pair<ConstIterator, bool> res(
251         std::piecewise_construct,
252         std::forward_as_tuple(this, segment),
253         std::forward_as_tuple(false));
254     res.second = ensureSegment(segment)->insert_or_assign(
255         res.first.it_, std::forward<Key>(k), std::forward<Value>(v));
// assign: overwrite the value only if the key already exists; never inserts.
// Returns an empty Optional when the shard was never created (key cannot be
// present) or when the segment-level assign fails; note it does NOT call
// ensureSegment(), so a missing shard is not allocated just to say "no".
259   template <typename Key, typename Value>
260   folly::Optional<ConstIterator> assign(Key&& k, Value&& v) {
261     auto segment = pickSegment(k);
262     ConstIterator res(this, segment);
263     auto seg = segments_[segment].load(std::memory_order_acquire);
265       return folly::Optional<ConstIterator>();
268           seg->assign(res.it_, std::forward<Key>(k), std::forward<Value>(v));
270         return folly::Optional<ConstIterator>();
276   // Assign to desired if and only if key k is equal to expected
// Compare-and-swap-style update at the map level: only assigns `desired`
// when the current mapped value equals `expected`. Returns an empty
// Optional if the shard doesn't exist or the comparison/assignment fails.
277   template <typename Key, typename Value>
278   folly::Optional<ConstIterator>
279   assign_if_equal(Key&& k, const ValueType& expected, Value&& desired) {
280     auto segment = pickSegment(k);
281     ConstIterator res(this, segment);
282     auto seg = segments_[segment].load(std::memory_order_acquire);
284       return folly::Optional<ConstIterator>();
286       auto r = seg->assign_if_equal(
288           std::forward<Key>(k),
290           std::forward<Value>(desired));
292         return folly::Optional<ConstIterator>();
298   // Copying wrappers around insert and find.
299   // Only available for copyable types.
// operator[] default-inserts (like std::unordered_map) but returns a COPY of
// the value, const, since no mutable reference into the map can be handed
// out safely (see the header comment on iterators and hazard pointers).
300   const ValueType operator[](const KeyType& key) {
301     auto item = insert(key, ValueType());
302     return item.first->second;
// at(): copy of the mapped value, or std::out_of_range if the key is absent.
305   const ValueType at(const KeyType& key) const {
306     auto item = find(key);
307     if (item == cend()) {
308       throw std::out_of_range("at(): value out of range");
313   // TODO update assign interface, operator[], at
// erase by key: returns the number of elements removed (0 or 1). Like
// assign(), a never-created shard means the key can't exist, so no segment
// is allocated. Calls the hash function, and therefore may throw.
315   size_type erase(const key_type& k) {
316     auto segment = pickSegment(k);
317     auto seg = segments_[segment].load(std::memory_order_acquire);
321       return seg->erase(k);
325   // Calls the hash function, and therefore may throw.
// erase by iterator: removes *pos and returns an iterator advanced past it;
// next() may step into the following shard if this one is exhausted.
326   ConstIterator erase(ConstIterator& pos) {
327     auto segment = pickSegment(pos->first);
328     ConstIterator res(this, segment);
330     ensureSegment(segment)->erase(res.it_, pos.it_);
331     res.next(); // May point to segment end, and need to advance.
335   // NOT noexcept, initializes new shard segments vs.
// clear(): per-shard reset; acquire load pairs with ensureSegment's publish.
337     for (uint64_t i = 0; i < NumShards; i++) {
338       auto seg = segments_[i].load(std::memory_order_acquire);
// reserve(): split the requested capacity evenly across shards (count is
// divided by the shard count via the shift) and forward to each existing
// segment. Shards that were never created are skipped (handling elided).
345   void reserve(size_t count) {
346     count = count >> ShardBits;
347     for (uint64_t i = 0; i < NumShards; i++) {
348       auto seg = segments_[i].load(std::memory_order_acquire);
355   // This is a rolling size, and is not exact at any moment in time.
// size(): sums per-segment sizes; only approximate under concurrent writes.
356   size_t size() const noexcept {
358     for (uint64_t i = 0; i < NumShards; i++) {
359       auto seg = segments_[i].load(std::memory_order_acquire);
// max_load_factor(): getter returns the stored factor; setter forwards the
// new factor to every already-created segment.
367   float max_load_factor() const {
371   void max_load_factor(float factor) {
372     for (uint64_t i = 0; i < NumShards; i++) {
373       auto seg = segments_[i].load(std::memory_order_acquire);
375         seg->max_load_factor(factor);
// Const iterator over all shards. Holds a SegmentT::Iterator (which pins the
// current element via hazard pointers — see the class header comment) plus
// the current shard index and a pointer back to the parent map.
380   class ConstIterator {
382     friend class ConcurrentHashMap;
384     const value_type& operator*() const {
388     const value_type* operator->() const {
392     ConstIterator& operator++() {
398     ConstIterator operator++(int) {
// Equality compares both the segment-level iterator and the shard index;
// cend() is represented by segment_ == NumShards.
404     bool operator==(const ConstIterator& o) const {
405       return it_ == o.it_ && segment_ == o.segment_;
408     bool operator!=(const ConstIterator& o) const {
409       return !(*this == o);
412     ConstIterator& operator=(const ConstIterator& o) {
414       segment_ = o.segment_;
418     ConstIterator(const ConstIterator& o) {
420       segment_ = o.segment_;
423     ConstIterator(const ConcurrentHashMap* parent, uint64_t segment)
424         : segment_(segment), parent_(parent) {}
// begin-style ctor: start at shard 0's cbegin, then advance to the first
// real element, which may live in any shard.
428     explicit ConstIterator(const ConcurrentHashMap* parent)
429         : it_(parent->ensureSegment(0)->cbegin()),
432       // Always iterate to the first element, could be in any shard.
// end-style ctor: shards == NumShards produces the cend() sentinel.
437     explicit ConstIterator(uint64_t shards) : it_(nullptr), segment_(shards) {}
// Advance across shard boundaries: while the current shard is exhausted,
// move to the next shard (ensureSegment creates it if needed) until an
// element is found or segment_ reaches NumShards (== cend()).
440       while (it_ == parent_->ensureSegment(segment_)->cend() &&
441              segment_ < parent_->NumShards) {
443         auto seg = parent_->segments_[segment_].load(std::memory_order_acquire);
444         if (segment_ < parent_->NumShards) {
453     typename SegmentT::Iterator it_;
455     const ConcurrentHashMap* parent_;
// Map a key to its shard index: hash, then mask off the low ShardBits bits
// (valid because NumShards is a power of two).
459   uint64_t pickSegment(const KeyType& k) const {
460     auto h = HashFn()(k);
461     // Use the lowest bits for our shard bits.
463     // This works well even if the hash function is biased towards the
464     // low bits: The sharding will happen in the segments_ instead of
465     // in the segment buckets, so we'll still get write sharding as
468     // Low-bit bias happens often for std::hash using small numbers,
469     // since the integer hash function is the identity function.
470     return h & (NumShards - 1);
// Lazily create shard i with a lock-free publish: allocate + placement-new
// a SegmentT sized for this shard (size_ and max_size_ divided by the shard
// count), then CAS it into the slot. If another thread won the race, `seg`
// is reloaded with the winner and our copy is freed (its destructor call is
// presumably on an elided line before the deallocate — confirm).
473   SegmentT* ensureSegment(uint64_t i) const {
474     SegmentT* seg = segments_[i].load(std::memory_order_acquire);
476       SegmentT* newseg = (SegmentT*)Allocator().allocate(sizeof(SegmentT));
477       newseg = new (newseg)
478           SegmentT(size_ >> ShardBits, load_factor_, max_size_ >> ShardBits);
479       if (!segments_[i].compare_exchange_strong(seg, newseg)) {
480         // seg is updated with new value, delete ours.
482         Allocator().deallocate((uint8_t*)newseg, sizeof(SegmentT));
// Shard pointer array; mutable + Atom so const readers (find, ConstIterator)
// can lazily create segments via ensureSegment().
490   mutable Atom<SegmentT*> segments_[NumShards];