Dead shift in ConcurrentHashMapSegment
[folly.git] / folly / concurrency / detail / ConcurrentHashMap-detail.h
1 /*
2  * Copyright 2017-present Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17
18 #include <folly/experimental/hazptr/hazptr.h>
19 #include <atomic>
20 #include <mutex>
21
22 namespace folly {
23
24 namespace detail {
25
26 namespace concurrenthashmap {
27
28 // hazptr retire() that can use an allocator.
29 template <typename Allocator>
30 class HazptrDeleter {
31  public:
32   template <typename Node>
33   void operator()(Node* node) {
34     node->~Node();
35     Allocator().deallocate((uint8_t*)node, sizeof(Node));
36   }
37 };
38
39 template <
40     typename KeyType,
41     typename ValueType,
42     typename Allocator,
43     typename Enabled = void>
44 class ValueHolder {
45  public:
46   typedef std::pair<const KeyType, ValueType> value_type;
47
48   explicit ValueHolder(const ValueHolder& other) : item_(other.item_) {}
49
50   template <typename... Args>
51   ValueHolder(const KeyType& k, Args&&... args)
52       : item_(
53             std::piecewise_construct,
54             std::forward_as_tuple(k),
55             std::forward_as_tuple(std::forward<Args>(args)...)) {}
56   value_type& getItem() {
57     return item_;
58   }
59
60  private:
61   value_type item_;
62 };
63
64 // If the ValueType is not copy constructible, we can instead add
65 // an extra indirection.  Adds more allocations / deallocations and
66 // pulls in an extra cacheline.
67 template <typename KeyType, typename ValueType, typename Allocator>
68 class ValueHolder<
69     KeyType,
70     ValueType,
71     Allocator,
72     std::enable_if_t<!std::is_nothrow_copy_constructible<ValueType>::value>> {
73  public:
74   typedef std::pair<const KeyType, ValueType> value_type;
75
76   explicit ValueHolder(const ValueHolder& other) {
77     other.owned_ = false;
78     item_ = other.item_;
79   }
80
81   template <typename... Args>
82   ValueHolder(const KeyType& k, Args&&... args) {
83     item_ = (value_type*)Allocator().allocate(sizeof(value_type));
84     new (item_) value_type(
85         std::piecewise_construct,
86         std::forward_as_tuple(k),
87         std::forward_as_tuple(std::forward<Args>(args)...));
88   }
89
90   ~ValueHolder() {
91     if (owned_) {
92       item_->~value_type();
93       Allocator().deallocate((uint8_t*)item_, sizeof(value_type));
94     }
95   }
96
97   value_type& getItem() {
98     return *item_;
99   }
100
101  private:
102   value_type* item_;
103   mutable bool owned_{true};
104 };
105
106 template <
107     typename KeyType,
108     typename ValueType,
109     typename Allocator,
110     template <typename> class Atom = std::atomic>
111 class NodeT : public folly::hazptr::hazptr_obj_base<
112                   NodeT<KeyType, ValueType, Allocator, Atom>,
113                   concurrenthashmap::HazptrDeleter<Allocator>> {
114  public:
115   typedef std::pair<const KeyType, ValueType> value_type;
116
117   explicit NodeT(NodeT* other) : item_(other->item_) {}
118
119   template <typename... Args>
120   NodeT(const KeyType& k, Args&&... args)
121       : item_(k, std::forward<Args>(args)...) {}
122
123   /* Nodes are refcounted: If a node is retired() while a writer is
124      traversing the chain, the rest of the chain must remain valid
125      until all readers are finished.  This includes the shared tail
126      portion of the chain, as well as both old/new hash buckets that
127      may point to the same portion, and erased nodes may increase the
128      refcount */
129   void acquire() {
130     DCHECK(refcount_.load() != 0);
131     refcount_.fetch_add(1);
132   }
133   void release() {
134     if (refcount_.fetch_sub(1) == 1 /* was previously 1 */) {
135       this->retire(
136           folly::hazptr::default_hazptr_domain(),
137           concurrenthashmap::HazptrDeleter<Allocator>());
138     }
139   }
140   ~NodeT() {
141     auto next = next_.load(std::memory_order_acquire);
142     if (next) {
143       next->release();
144     }
145   }
146
147   value_type& getItem() {
148     return item_.getItem();
149   }
150   Atom<NodeT*> next_{nullptr};
151
152  private:
153   ValueHolder<KeyType, ValueType, Allocator> item_;
154   Atom<uint8_t> refcount_{1};
155 };
156
157 } // namespace concurrenthashmap
158
159 /* A Segment is a single shard of the ConcurrentHashMap.
160  * All writes take the lock, while readers are all wait-free.
161  * Readers always proceed in parallel with the single writer.
162  *
163  *
164  * Possible additional optimizations:
165  *
166  * * insert / erase could be lock / wait free.  Would need to be
167  *   careful that assign and rehash don't conflict (possibly with
168  *   reader/writer lock, or microlock per node or per bucket, etc).
169  *   Java 8 goes halfway, and and does lock per bucket, except for the
170  *   first item, that is inserted with a CAS (which is somewhat
171  *   specific to java having a lock per object)
172  *
173  * * I tried using trylock() and find() to warm the cache for insert()
174  *   and erase() similar to Java 7, but didn't have much luck.
175  *
176  * * We could order elements using split ordering, for faster rehash,
177  *   and no need to ever copy nodes.  Note that a full split ordering
178  *   including dummy nodes increases the memory usage by 2x, but we
179  *   could split the difference and still require a lock to set bucket
180  *   pointers.
181  *
182  * * hazptr acquire/release could be optimized more, in
183  *   single-threaded case, hazptr overhead is ~30% for a hot find()
184  *   loop.
185  */
186 template <
187     typename KeyType,
188     typename ValueType,
189     uint8_t ShardBits = 0,
190     typename HashFn = std::hash<KeyType>,
191     typename KeyEqual = std::equal_to<KeyType>,
192     typename Allocator = std::allocator<uint8_t>,
193     template <typename> class Atom = std::atomic,
194     class Mutex = std::mutex>
195 class FOLLY_ALIGNED(64) ConcurrentHashMapSegment {
196   enum class InsertType {
197     DOES_NOT_EXIST, // insert/emplace operations.  If key exists, return false.
198     MUST_EXIST, // assign operations.  If key does not exist, return false.
199     ANY, // insert_or_assign.
200     MATCH, // assign_if_equal (not in std).  For concurrent maps, a
201            // way to atomically change a value if equal to some other
202            // value.
203   };
204
205  public:
206   typedef KeyType key_type;
207   typedef ValueType mapped_type;
208   typedef std::pair<const KeyType, ValueType> value_type;
209   typedef std::size_t size_type;
210
211   using Node = concurrenthashmap::NodeT<KeyType, ValueType, Allocator, Atom>;
212   class Iterator;
213
214   ConcurrentHashMapSegment(
215       size_t initial_buckets,
216       float load_factor,
217       size_t max_size)
218       : load_factor_(load_factor), max_size_(max_size) {
219     auto buckets = (Buckets*)Allocator().allocate(sizeof(Buckets));
220     initial_buckets = folly::nextPowTwo(initial_buckets);
221     DCHECK(
222         max_size_ == 0 ||
223         (isPowTwo(max_size_) &&
224          (folly::popcount(max_size_ - 1) + ShardBits <= 32)));
225     new (buckets) Buckets(initial_buckets);
226     buckets_.store(buckets, std::memory_order_release);
227     load_factor_nodes_ = initial_buckets * load_factor_;
228   }
229
230   ~ConcurrentHashMapSegment() {
231     auto buckets = buckets_.load(std::memory_order_relaxed);
232     // We can delete and not retire() here, since users must have
233     // their own synchronization around destruction.
234     buckets->~Buckets();
235     Allocator().deallocate((uint8_t*)buckets, sizeof(Buckets));
236   }
237
238   size_t size() {
239     return size_;
240   }
241
242   bool empty() {
243     return size() == 0;
244   }
245
246   bool insert(Iterator& it, std::pair<key_type, mapped_type>&& foo) {
247     return insert(it, foo.first, foo.second);
248   }
249
250   bool insert(Iterator& it, const KeyType& k, const ValueType& v) {
251     auto node = (Node*)Allocator().allocate(sizeof(Node));
252     new (node) Node(k, v);
253     auto res = insert_internal(
254         it,
255         k,
256         InsertType::DOES_NOT_EXIST,
257         [](const ValueType&) { return false; },
258         node,
259         v);
260     if (!res) {
261       node->~Node();
262       Allocator().deallocate((uint8_t*)node, sizeof(Node));
263     }
264     return res;
265   }
266
267   template <typename... Args>
268   bool try_emplace(Iterator& it, const KeyType& k, Args&&... args) {
269     return insert_internal(
270         it,
271         k,
272         InsertType::DOES_NOT_EXIST,
273         [](const ValueType&) { return false; },
274         nullptr,
275         std::forward<Args>(args)...);
276   }
277
278   template <typename... Args>
279   bool emplace(Iterator& it, const KeyType& k, Node* node) {
280     return insert_internal(
281         it,
282         k,
283         InsertType::DOES_NOT_EXIST,
284         [](const ValueType&) { return false; },
285         node);
286   }
287
288   bool insert_or_assign(Iterator& it, const KeyType& k, const ValueType& v) {
289     return insert_internal(
290         it,
291         k,
292         InsertType::ANY,
293         [](const ValueType&) { return false; },
294         nullptr,
295         v);
296   }
297
298   bool assign(Iterator& it, const KeyType& k, const ValueType& v) {
299     auto node = (Node*)Allocator().allocate(sizeof(Node));
300     new (node) Node(k, v);
301     auto res = insert_internal(
302         it,
303         k,
304         InsertType::MUST_EXIST,
305         [](const ValueType&) { return false; },
306         node,
307         v);
308     if (!res) {
309       node->~Node();
310       Allocator().deallocate((uint8_t*)node, sizeof(Node));
311     }
312     return res;
313   }
314
315   bool assign_if_equal(
316       Iterator& it,
317       const KeyType& k,
318       const ValueType& expected,
319       const ValueType& desired) {
320     return insert_internal(
321         it,
322         k,
323         InsertType::MATCH,
324         [expected](const ValueType& v) { return v == expected; },
325         nullptr,
326         desired);
327   }
328
329   template <typename MatchFunc, typename... Args>
330   bool insert_internal(
331       Iterator& it,
332       const KeyType& k,
333       InsertType type,
334       MatchFunc match,
335       Node* cur,
336       Args&&... args) {
337     auto h = HashFn()(k);
338     std::unique_lock<Mutex> g(m_);
339
340     auto buckets = buckets_.load(std::memory_order_relaxed);
341     // Check for rehash needed for DOES_NOT_EXIST
342     if (size_ >= load_factor_nodes_ && type == InsertType::DOES_NOT_EXIST) {
343       if (max_size_ && size_ << 1 > max_size_) {
344         // Would exceed max size.
345         throw std::bad_alloc();
346       }
347       rehash(buckets->bucket_count_ << 1);
348       buckets = buckets_.load(std::memory_order_relaxed);
349     }
350
351     auto idx = getIdx(buckets, h);
352     auto head = &buckets->buckets_[idx];
353     auto node = head->load(std::memory_order_relaxed);
354     auto headnode = node;
355     auto prev = head;
356     it.buckets_hazptr_.reset(buckets);
357     while (node) {
358       // Is the key found?
359       if (KeyEqual()(k, node->getItem().first)) {
360         it.setNode(node, buckets, idx);
361         it.node_hazptr_.reset(node);
362         if (type == InsertType::MATCH) {
363           if (!match(node->getItem().second)) {
364             return false;
365           }
366         }
367         if (type == InsertType::DOES_NOT_EXIST) {
368           return false;
369         } else {
370           if (!cur) {
371             cur = (Node*)Allocator().allocate(sizeof(Node));
372             new (cur) Node(k, std::forward<Args>(args)...);
373           }
374           auto next = node->next_.load(std::memory_order_relaxed);
375           cur->next_.store(next, std::memory_order_relaxed);
376           if (next) {
377             next->acquire();
378           }
379           prev->store(cur, std::memory_order_release);
380           g.unlock();
381           // Release not under lock.
382           node->release();
383           return true;
384         }
385       }
386
387       prev = &node->next_;
388       node = node->next_.load(std::memory_order_relaxed);
389     }
390     if (type != InsertType::DOES_NOT_EXIST && type != InsertType::ANY) {
391       it.node_hazptr_.reset();
392       it.buckets_hazptr_.reset();
393       return false;
394     }
395     // Node not found, check for rehash on ANY
396     if (size_ >= load_factor_nodes_ && type == InsertType::ANY) {
397       if (max_size_ && size_ << 1 > max_size_) {
398         // Would exceed max size.
399         throw std::bad_alloc();
400       }
401       rehash(buckets->bucket_count_ << 1);
402
403       // Reload correct bucket.
404       buckets = buckets_.load(std::memory_order_relaxed);
405       it.buckets_hazptr_.reset(buckets);
406       idx = getIdx(buckets, h);
407       head = &buckets->buckets_[idx];
408       headnode = head->load(std::memory_order_relaxed);
409     }
410
411     // We found a slot to put the node.
412     size_++;
413     if (!cur) {
414       // InsertType::ANY
415       // OR DOES_NOT_EXIST, but only in the try_emplace case
416       DCHECK(type == InsertType::ANY || type == InsertType::DOES_NOT_EXIST);
417       cur = (Node*)Allocator().allocate(sizeof(Node));
418       new (cur) Node(k, std::forward<Args>(args)...);
419     }
420     cur->next_.store(headnode, std::memory_order_relaxed);
421     head->store(cur, std::memory_order_release);
422     it.setNode(cur, buckets, idx);
423     return true;
424   }
425
426   // Must hold lock.
427   void rehash(size_t bucket_count) {
428     auto buckets = buckets_.load(std::memory_order_relaxed);
429     auto newbuckets = (Buckets*)Allocator().allocate(sizeof(Buckets));
430     new (newbuckets) Buckets(bucket_count);
431
432     load_factor_nodes_ = bucket_count * load_factor_;
433
434     for (size_t i = 0; i < buckets->bucket_count_; i++) {
435       auto bucket = &buckets->buckets_[i];
436       auto node = bucket->load(std::memory_order_relaxed);
437       if (!node) {
438         continue;
439       }
440       auto h = HashFn()(node->getItem().first);
441       auto idx = getIdx(newbuckets, h);
442       // Reuse as long a chain as possible from the end.  Since the
443       // nodes don't have previous pointers, the longest last chain
444       // will be the same for both the previous hashmap and the new one,
445       // assuming all the nodes hash to the same bucket.
446       auto lastrun = node;
447       auto lastidx = idx;
448       auto count = 0;
449       auto last = node->next_.load(std::memory_order_relaxed);
450       for (; last != nullptr;
451            last = last->next_.load(std::memory_order_relaxed)) {
452         auto k = getIdx(newbuckets, HashFn()(last->getItem().first));
453         if (k != lastidx) {
454           lastidx = k;
455           lastrun = last;
456           count = 0;
457         }
458         count++;
459       }
460       // Set longest last run in new bucket, incrementing the refcount.
461       lastrun->acquire();
462       newbuckets->buckets_[lastidx].store(lastrun, std::memory_order_relaxed);
463       // Clone remaining nodes
464       for (; node != lastrun;
465            node = node->next_.load(std::memory_order_relaxed)) {
466         auto newnode = (Node*)Allocator().allocate(sizeof(Node));
467         new (newnode) Node(node);
468         auto k = getIdx(newbuckets, HashFn()(node->getItem().first));
469         auto prevhead = &newbuckets->buckets_[k];
470         newnode->next_.store(prevhead->load(std::memory_order_relaxed));
471         prevhead->store(newnode, std::memory_order_relaxed);
472       }
473     }
474
475     auto oldbuckets = buckets_.load(std::memory_order_relaxed);
476     buckets_.store(newbuckets, std::memory_order_release);
477     oldbuckets->retire(
478         folly::hazptr::default_hazptr_domain(),
479         concurrenthashmap::HazptrDeleter<Allocator>());
480   }
481
482   bool find(Iterator& res, const KeyType& k) {
483     folly::hazptr::hazptr_holder haznext;
484     auto h = HashFn()(k);
485     auto buckets = res.buckets_hazptr_.get_protected(buckets_);
486     auto idx = getIdx(buckets, h);
487     auto prev = &buckets->buckets_[idx];
488     auto node = res.node_hazptr_.get_protected(*prev);
489     while (node) {
490       if (KeyEqual()(k, node->getItem().first)) {
491         res.setNode(node, buckets, idx);
492         return true;
493       }
494       node = haznext.get_protected(node->next_);
495       haznext.swap(res.node_hazptr_);
496     }
497     return false;
498   }
499
500   // Listed separately because we need a prev pointer.
501   size_type erase(const key_type& key) {
502     return erase_internal(key, nullptr);
503   }
504
505   size_type erase_internal(const key_type& key, Iterator* iter) {
506     Node* node{nullptr};
507     auto h = HashFn()(key);
508     {
509       std::lock_guard<Mutex> g(m_);
510
511       auto buckets = buckets_.load(std::memory_order_relaxed);
512       auto idx = getIdx(buckets, h);
513       auto head = &buckets->buckets_[idx];
514       node = head->load(std::memory_order_relaxed);
515       Node* prev = nullptr;
516       auto headnode = node;
517       while (node) {
518         if (KeyEqual()(key, node->getItem().first)) {
519           auto next = node->next_.load(std::memory_order_relaxed);
520           if (next) {
521             next->acquire();
522           }
523           if (prev) {
524             prev->next_.store(next, std::memory_order_release);
525           } else {
526             // Must be head of list.
527             head->store(next, std::memory_order_release);
528           }
529
530           if (iter) {
531             iter->buckets_hazptr_.reset(buckets);
532             iter->setNode(
533                 node->next_.load(std::memory_order_acquire), buckets, idx);
534           }
535           size_--;
536           break;
537         }
538         prev = node;
539         node = node->next_.load(std::memory_order_relaxed);
540       }
541     }
542     // Delete the node while not under the lock.
543     if (node) {
544       node->release();
545       return 1;
546     }
547     DCHECK(!iter);
548
549     return 0;
550   }
551
552   // Unfortunately because we are reusing nodes on rehash, we can't
553   // have prev pointers in the bucket chain.  We have to start the
554   // search from the bucket.
555   //
556   // This is a small departure from standard stl containers: erase may
557   // throw if hash or key_eq functions throw.
558   void erase(Iterator& res, Iterator& pos) {
559     auto cnt = erase_internal(pos->first, &res);
560     DCHECK(cnt == 1);
561   }
562
563   void clear() {
564     auto buckets = buckets_.load(std::memory_order_relaxed);
565     auto newbuckets = (Buckets*)Allocator().allocate(sizeof(Buckets));
566     new (newbuckets) Buckets(buckets->bucket_count_);
567     {
568       std::lock_guard<Mutex> g(m_);
569       buckets_.store(newbuckets, std::memory_order_release);
570       size_ = 0;
571     }
572     buckets->retire(
573         folly::hazptr::default_hazptr_domain(),
574         concurrenthashmap::HazptrDeleter<Allocator>());
575   }
576
577   void max_load_factor(float factor) {
578     std::lock_guard<Mutex> g(m_);
579     load_factor_ = factor;
580     auto buckets = buckets_.load(std::memory_order_relaxed);
581     load_factor_nodes_ = buckets->bucket_count_ * load_factor_;
582   }
583
584   Iterator cbegin() {
585     Iterator res;
586     auto buckets = res.buckets_hazptr_.get_protected(buckets_);
587     res.setNode(nullptr, buckets, 0);
588     res.next();
589     return res;
590   }
591
592   Iterator cend() {
593     return Iterator(nullptr);
594   }
595
596   // Could be optimized to avoid an extra pointer dereference by
597   // allocating buckets_ at the same time.
598   class Buckets : public folly::hazptr::hazptr_obj_base<
599                       Buckets,
600                       concurrenthashmap::HazptrDeleter<Allocator>> {
601    public:
602     explicit Buckets(size_t count) : bucket_count_(count) {
603       buckets_ =
604           (Atom<Node*>*)Allocator().allocate(sizeof(Atom<Node*>) * count);
605       new (buckets_) Atom<Node*>[ count ];
606       for (size_t i = 0; i < count; i++) {
607         buckets_[i].store(nullptr, std::memory_order_relaxed);
608       }
609     }
610     ~Buckets() {
611       for (size_t i = 0; i < bucket_count_; i++) {
612         auto elem = buckets_[i].load(std::memory_order_relaxed);
613         if (elem) {
614           elem->release();
615         }
616       }
617       Allocator().deallocate(
618           (uint8_t*)buckets_, sizeof(Atom<Node*>) * bucket_count_);
619     }
620
621     size_t bucket_count_;
622     Atom<Node*>* buckets_{nullptr};
623   };
624
625  public:
626   class Iterator {
627    public:
628     FOLLY_ALWAYS_INLINE Iterator() {}
629     FOLLY_ALWAYS_INLINE explicit Iterator(std::nullptr_t)
630         : buckets_hazptr_(nullptr), node_hazptr_(nullptr) {}
631     FOLLY_ALWAYS_INLINE ~Iterator() {}
632
633     void setNode(Node* node, Buckets* buckets, uint64_t idx) {
634       node_ = node;
635       buckets_ = buckets;
636       idx_ = idx;
637     }
638
639     const value_type& operator*() const {
640       DCHECK(node_);
641       return node_->getItem();
642     }
643
644     const value_type* operator->() const {
645       DCHECK(node_);
646       return &(node_->getItem());
647     }
648
649     const Iterator& operator++() {
650       DCHECK(node_);
651       node_ = node_hazptr_.get_protected(node_->next_);
652       if (!node_) {
653         ++idx_;
654         next();
655       }
656       return *this;
657     }
658
659     void next() {
660       while (!node_) {
661         if (idx_ >= buckets_->bucket_count_) {
662           break;
663         }
664         DCHECK(buckets_);
665         DCHECK(buckets_->buckets_);
666         node_ = node_hazptr_.get_protected(buckets_->buckets_[idx_]);
667         if (node_) {
668           break;
669         }
670         ++idx_;
671       }
672     }
673
674     Iterator operator++(int) {
675       auto prev = *this;
676       ++*this;
677       return prev;
678     }
679
680     bool operator==(const Iterator& o) const {
681       return node_ == o.node_;
682     }
683
684     bool operator!=(const Iterator& o) const {
685       return !(*this == o);
686     }
687
688     Iterator& operator=(const Iterator& o) {
689       node_ = o.node_;
690       node_hazptr_.reset(node_);
691       idx_ = o.idx_;
692       buckets_ = o.buckets_;
693       buckets_hazptr_.reset(buckets_);
694       return *this;
695     }
696
697     /* implicit */ Iterator(const Iterator& o) {
698       node_ = o.node_;
699       node_hazptr_.reset(node_);
700       idx_ = o.idx_;
701       buckets_ = o.buckets_;
702       buckets_hazptr_.reset(buckets_);
703     }
704
705     /* implicit */ Iterator(Iterator&& o) noexcept
706         : buckets_hazptr_(std::move(o.buckets_hazptr_)),
707           node_hazptr_(std::move(o.node_hazptr_)) {
708       node_ = o.node_;
709       buckets_ = o.buckets_;
710       idx_ = o.idx_;
711     }
712
713     // These are accessed directly from the functions above
714     folly::hazptr::hazptr_holder buckets_hazptr_;
715     folly::hazptr::hazptr_holder node_hazptr_;
716
717    private:
718     Node* node_{nullptr};
719     Buckets* buckets_{nullptr};
720     uint64_t idx_;
721   };
722
723  private:
724   // Shards have already used low ShardBits of the hash.
725   // Shift it over to use fresh bits.
726   uint64_t getIdx(Buckets* buckets, size_t hash) {
727     return (hash >> ShardBits) & (buckets->bucket_count_ - 1);
728   }
729
730   float load_factor_;
731   size_t load_factor_nodes_;
732   size_t size_{0};
733   size_t const max_size_;
734   Atom<Buckets*> buckets_{nullptr};
735   Mutex m_;
736 };
737 } // namespace detail
738 } // namespace folly