Allow stealing pointer bits
authorDave Watson <davejwatson@fb.com>
Wed, 26 Jul 2017 15:07:56 +0000 (08:07 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 26 Jul 2017 15:19:46 +0000 (08:19 -0700)
Summary:
Currently hazard pointers doesn't support stealing any of the pointer bits.
You can *almost* roll it yourself using try_protect, but this prevents
implementations from choosing their type of barrier.

This adds a new get_protected interface that you can use to steal bits, or
otherwise manipulate pointers as you would like.

This also adds a MWMR list based set example that uses it, that is wait-free
for readers (unlike the SWMR example, that is only lock-free).

Reviewed By: magedm

Differential Revision: D5455615

fbshipit-source-id: 53d282eda433e00b6b53cd804d4e1c32c74c2fb8

folly/experimental/hazptr/example/MWMRSet.h [new file with mode: 0644]
folly/experimental/hazptr/hazptr-impl.h
folly/experimental/hazptr/hazptr.h
folly/experimental/hazptr/test/HazptrTest.cpp

diff --git a/folly/experimental/hazptr/example/MWMRSet.h b/folly/experimental/hazptr/example/MWMRSet.h
new file mode 100644 (file)
index 0000000..49d561a
--- /dev/null
@@ -0,0 +1,251 @@
+/*
+ * Copyright 2017 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <folly/experimental/hazptr/debug.h>
+#include <folly/experimental/hazptr/hazptr.h>
+
+namespace folly {
+namespace hazptr {
+
+/** Set implemented as an ordered singly-linked list.
+ *
+ *  Multiple writers may add or remove elements. Multiple reader
+ *  threads may search the set concurrently with each other and with
+ *  the writers' operations.
+ */
+template <typename T>
+class MWMRListSet {
+  class Node : public hazptr_obj_base<Node> {
+    friend MWMRListSet;
+    T elem_;
+    std::atomic<uint64_t> refcount_{1};
+    std::atomic<Node*> next_{nullptr};
+
+    // Node must be refcounted for wait-free access: A deleted node
+    // may have hazptrs pointing at it, so the rest of the list (or at
+    // least, what existed at the time of the hazptr load) must still
+    // be accessible.
+    void release() {
+      if (refcount_.fetch_sub(1) == 1) {
+        this->retire();
+      }
+    }
+
+    // Optimization in the case that we know there are no hazptrs pointing
+    // at the list.
+    void releaseFast() {
+      if (refcount_.load(std::memory_order_relaxed) == 1) {
+        auto next = getPtr(next_.load(std::memory_order_relaxed));
+        if (next) {
+          next->releaseFast();
+          next_.store(nullptr, std::memory_order_relaxed);
+        }
+        delete this;
+      }
+    }
+
+    void acquire() {
+      DCHECK(refcount_.load() != 0);
+      refcount_.fetch_add(1);
+    }
+
+   public:
+    explicit Node(T e) : elem_(e) {
+      DEBUG_PRINT(this << " " << e);
+    }
+
+    ~Node() {
+      DEBUG_PRINT(this);
+      auto next = getPtr(next_.load(std::memory_order_relaxed));
+      if (next) {
+        next->release();
+      }
+    }
+  };
+
+  static bool getDeleted(Node* ptr) {
+    return uintptr_t(ptr) & 1;
+  }
+
+  static Node* getPtr(Node* ptr) {
+    return (Node*)(uintptr_t(ptr) & ~1UL);
+  }
+
+  mutable std::atomic<Node*> head_ = {nullptr};
+
+  // Remove a single deleted item.
+  // Although it doesn't have to be our item.
+  //
+  // Note that standard lock-free Michael linked lists put this in the
+  // contains() path, while this implementation leaves it only in
+  // remove(), such that contains() is wait-free.
+  void fixlist(
+      hazptr_holder& hptr_prev,
+      hazptr_holder& hptr_curr,
+      std::atomic<Node*>*& prev,
+      Node*& curr) const {
+    while (true) {
+      prev = &head_;
+      curr = hptr_curr.get_protected(*prev, getPtr);
+      while (getPtr(curr)) {
+        auto next = getPtr(curr)->next_.load(std::memory_order_acquire);
+        if (getDeleted(next)) {
+          auto nextp = getPtr(next);
+          if (nextp) {
+            nextp->acquire();
+          }
+          // Try to fix
+          auto curr_no_mark = getPtr(curr);
+          if (prev->compare_exchange_weak(curr_no_mark, nextp)) {
+            // Physically delete
+            curr_no_mark->release();
+            return;
+          } else {
+            if (nextp) {
+              nextp->release();
+            }
+            break;
+          }
+        }
+        prev = &(getPtr(curr)->next_);
+        curr = hptr_prev.get_protected(getPtr(curr)->next_, getPtr);
+
+        swap(hptr_curr, hptr_prev);
+      }
+      DCHECK(getPtr(curr));
+    }
+  }
+
+  /* wait-free set search */
+  bool find(
+      const T& val,
+      hazptr_holder& hptr_prev,
+      hazptr_holder& hptr_curr,
+      std::atomic<Node*>*& prev,
+      Node*& curr) const {
+    prev = &head_;
+    curr = hptr_curr.get_protected(*prev, getPtr);
+    while (getPtr(curr)) {
+      auto next = getPtr(curr)->next_.load(std::memory_order_acquire);
+      if (!getDeleted(next)) {
+        if (getPtr(curr)->elem_ == val) {
+          return true;
+        } else if (!(getPtr(curr)->elem_ < val)) {
+          break; // Because the list is sorted.
+        }
+      }
+      prev = &(getPtr(curr)->next_);
+      curr = hptr_prev.get_protected(getPtr(curr)->next_, getPtr);
+      /* Swap does not change the values of the owned hazard
+       * pointers themselves. After the swap, The hazard pointer
+       * owned by hptr_prev continues to protect the node that
+       * contains the pointer *prev. The hazard pointer owned by
+       * hptr_curr will continue to protect the node that contains
+       * the old *prev (unless the old prev was &head), which no
+       * longer needs protection, so hptr_curr's hazard pointer is
+       * now free to protect *curr in the next iteration (if curr !=
+       * null).
+       */
+      swap(hptr_curr, hptr_prev);
+    }
+
+    return false;
+  }
+
+ public:
+  explicit MWMRListSet() {}
+
+  ~MWMRListSet() {
+    Node* next = head_.load();
+    if (next) {
+      next->releaseFast();
+    }
+  }
+
+  bool add(T v) {
+    hazptr_holder hptr_prev;
+    hazptr_holder hptr_curr;
+    std::atomic<Node*>* prev;
+    Node* cur;
+
+    auto newnode = folly::make_unique<Node>(v);
+
+    while (true) {
+      if (find(v, hptr_prev, hptr_curr, prev, cur)) {
+        return false;
+      }
+      newnode->next_.store(cur, std::memory_order_relaxed);
+      auto cur_no_mark = getPtr(cur);
+      if (prev->compare_exchange_weak(cur_no_mark, newnode.get())) {
+        newnode.release();
+        return true;
+      }
+      // Ensure ~Node() destructor doesn't destroy next_
+      newnode->next_.store(nullptr, std::memory_order_relaxed);
+    }
+  }
+
+  bool remove(const T& v) {
+    hazptr_holder hptr_prev;
+    hazptr_holder hptr_curr;
+    std::atomic<Node*>* prev;
+    Node* curr;
+
+    while (true) {
+      if (!find(v, hptr_prev, hptr_curr, prev, curr)) {
+        return false;
+      }
+      auto next = getPtr(curr)->next_.load(std::memory_order_acquire);
+      auto next_no_mark = getPtr(next); // Ensure only one deleter wins
+      // Logically delete
+      if (!getPtr(curr)->next_.compare_exchange_weak(
+              next_no_mark, (Node*)(uintptr_t(next_no_mark) | 1))) {
+        continue;
+      }
+      if (next) {
+        next->acquire();
+      }
+
+      // Swing prev around
+      auto curr_no_mark = getPtr(curr); /* ensure not deleted */
+      if (prev->compare_exchange_weak(curr_no_mark, next)) {
+        // Physically delete
+        curr->release();
+        return true;
+      }
+      if (next) {
+        next->release();
+      }
+
+      // Someone else modified prev.  Call fixlist
+      // to unlink deleted element by re-walking list.
+      fixlist(hptr_prev, hptr_curr, prev, curr);
+    }
+  }
+
+  bool contains(const T& v) const {
+    hazptr_holder hptr_prev;
+    hazptr_holder hptr_curr;
+    std::atomic<Node*>* prev;
+    Node* curr;
+
+    return find(v, hptr_prev, hptr_curr, prev, curr);
+  }
+};
+
+} // namespace folly {
+} // namespace hazptr {
index 70123659053dedcd5133cbfb6fa1b6a36dfdd737..60cf5ecd6ec5cbfeccd17df32a94aa8ddc294711 100644 (file)
@@ -231,8 +231,16 @@ template <typename T>
 FOLLY_ALWAYS_INLINE bool hazptr_holder::try_protect(
     T*& ptr,
     const std::atomic<T*>& src) noexcept {
+  return try_protect(ptr, src, [](T* t) { return t; });
+}
+
+template <typename T, typename Func>
+FOLLY_ALWAYS_INLINE bool hazptr_holder::try_protect(
+    T*& ptr,
+    const std::atomic<T*>& src,
+    Func f) noexcept {
   DEBUG_PRINT(this << " " << ptr << " " << &src);
-  reset(ptr);
+  reset(f(ptr));
   /*** Full fence ***/ hazptr_mb::light();
   T* p = src.load(std::memory_order_acquire);
   if (p != ptr) {
@@ -246,8 +254,16 @@ FOLLY_ALWAYS_INLINE bool hazptr_holder::try_protect(
 template <typename T>
 FOLLY_ALWAYS_INLINE T* hazptr_holder::get_protected(
     const std::atomic<T*>& src) noexcept {
+  return get_protected(src, [](T* t) { return t; });
+}
+
+template <typename T, typename Func>
+FOLLY_ALWAYS_INLINE T* hazptr_holder::get_protected(
+    const std::atomic<T*>& src,
+    Func f) noexcept {
   T* p = src.load(std::memory_order_relaxed);
-  while (!try_protect(p, src)) {}
+  while (!try_protect(p, src, f)) {
+  }
   DEBUG_PRINT(this << " " << p << " " << &src);
   return p;
 }
index d5944adfa1b744f7065ec3fba55aaf1854f79f1c..301fdfab173d002fed36e94216fa4903fa9bae42 100644 (file)
@@ -119,10 +119,20 @@ class hazptr_holder {
   /* Returns a protected pointer from the source */
   template <typename T>
   T* get_protected(const std::atomic<T*>& src) noexcept;
+  /* Returns a protected pointer from the source, filtering
+     the protected pointer through function Func.  Useful for
+     stealing bits of the pointer word */
+  template <typename T, typename Func>
+  T* get_protected(const std::atomic<T*>& src, Func f) noexcept;
   /* Return true if successful in protecting ptr if src == ptr after
    * setting the hazard pointer.  Otherwise sets ptr to src. */
   template <typename T>
   bool try_protect(T*& ptr, const std::atomic<T*>& src) noexcept;
+  /* Return true if successful in protecting ptr if src == ptr after
+   * setting the hazard pointer, filtering the pointer through Func.
+   * Otherwise sets ptr to src. */
+  template <typename T, typename Func>
+  bool try_protect(T*& ptr, const std::atomic<T*>& src, Func f) noexcept;
   /* Set the hazard pointer to ptr */
   template <typename T>
   void reset(const T* ptr) noexcept;
index d8ca18dacf7ad406bab73c24aeadbd07d3b021b6..301aaa805d53cf53f75ac4497d674f6af0452560 100644 (file)
@@ -19,6 +19,7 @@
 
 #include <folly/experimental/hazptr/debug.h>
 #include <folly/experimental/hazptr/example/LockFreeLIFO.h>
+#include <folly/experimental/hazptr/example/MWMRSet.h>
 #include <folly/experimental/hazptr/example/SWMRList.h>
 #include <folly/experimental/hazptr/example/WideCAS.h>
 #include <folly/experimental/hazptr/hazptr.h>
@@ -224,6 +225,36 @@ TEST_F(HazptrTest, SWMRLIST) {
   }
 }
 
+TEST_F(HazptrTest, MWMRSet) {
+  using T = uint64_t;
+
+  CHECK_GT(FLAGS_num_threads, 0);
+  for (int i = 0; i < FLAGS_num_reps; ++i) {
+    DEBUG_PRINT("========== start of rep scope");
+    MWMRListSet<T> s;
+    std::vector<std::thread> threads(FLAGS_num_threads);
+    for (int tid = 0; tid < FLAGS_num_threads; ++tid) {
+      threads[tid] = std::thread([&s, tid]() {
+        for (int j = tid; j < FLAGS_num_ops; j += FLAGS_num_threads) {
+          s.contains(j);
+          s.add(j);
+          s.remove(j);
+        }
+      });
+    }
+    for (int j = 0; j < 10; ++j) {
+      s.add(j);
+    }
+    for (int j = 0; j < 10; ++j) {
+      s.remove(j);
+    }
+    for (auto& t : threads) {
+      t.join();
+    }
+    DEBUG_PRINT("========== end of rep scope");
+  }
+}
+
 TEST_F(HazptrTest, WIDECAS) {
   WideCAS s;
   std::string u = "";