Fix races in TLRefCount
authorAndrii Grynenko <andrii@fb.com>
Thu, 28 Apr 2016 04:00:57 +0000 (21:00 -0700)
committerFacebook Github Bot 9 <facebook-github-bot-9-bot@fb.com>
Thu, 28 Apr 2016 04:05:31 +0000 (21:05 -0700)
Summary:
This fixes 2 races in TLRefCount:
1. Thread-local constructor race, exposed by the stress test. It was possible for LocalRefCount to be created (grabbing collectGuard), but not be added to the thread-local list, so that accessAllThreads wasn't collecting it. collectAll() was then blocking waiting on baton to be posted, causing a dead-lock.
2. LocalRefCount::count_ has to be made atomic, because otherwise += operation may be not flushed (nbronson explained the race in D3133443).

Reviewed By: djwatson

Differential Revision: D3166956

fb-gh-sync-id: 17d58a215ebfc572f8316ed46bafaa5e6a9e2368
fbshipit-source-id: 17d58a215ebfc572f8316ed46bafaa5e6a9e2368

folly/experimental/TLRefCount.h
folly/experimental/test/RefCountTest.cpp

index 87734e6ecb074ab494c96c409188be61bcdeb498..9667f4c100cb5b1d62c188a6b9c05585547c2c49 100644 (file)
@@ -15,7 +15,6 @@
  */
 #pragma once
 
-#include <folly/Baton.h>
 #include <folly/ThreadLocal.h>
 
 namespace folly {
@@ -24,15 +23,9 @@ class TLRefCount {
  public:
   using Int = int64_t;
 
-  TLRefCount() :
-      localCount_([&]() {
-          return new LocalRefCount(*this);
-        }),
-      collectGuard_(&collectBaton_, [](void* p) {
-          auto baton = reinterpret_cast<folly::Baton<>*>(p);
-          baton->post();
-        }) {
-  }
+  TLRefCount()
+      : localCount_([&]() { return new LocalRefCount(*this); }),
+        collectGuard_(this, [](void*) {}) {}
 
   ~TLRefCount() noexcept {
     assert(globalCount_.load() == 0);
@@ -91,13 +84,17 @@ class TLRefCount {
 
     state_ = State::GLOBAL_TRANSITION;
 
-    auto accessor = localCount_.accessAllThreads();
-    for (auto& count : accessor) {
-      count.collect();
-    }
+    std::weak_ptr<void> collectGuardWeak = collectGuard_;
 
+    // Make sure we can't create new LocalRefCounts
     collectGuard_.reset();
-    collectBaton_.wait();
+
+    while (!collectGuardWeak.expired()) {
+      auto accessor = localCount_.accessAllThreads();
+      for (auto& count : accessor) {
+        count.collect();
+      }
+    }
 
     state_ = State::GLOBAL;
   }
@@ -131,7 +128,7 @@ class TLRefCount {
         return;
       }
 
-      collectCount_ = count_;
+      collectCount_ = count_.load();
       refCount_.globalCount_.fetch_add(collectCount_);
       collectGuard_.reset();
     }
@@ -166,7 +163,7 @@ class TLRefCount {
       return true;
     }
 
-    Int count_{0};
+    AtomicInt count_{0};
     TLRefCount& refCount_;
 
     std::mutex collectMutex_;
@@ -178,7 +175,6 @@ class TLRefCount {
   folly::ThreadLocal<LocalRefCount, TLRefCount> localCount_;
   std::atomic<int64_t> globalCount_{1};
   std::mutex globalMutex_;
-  folly::Baton<> collectBaton_;
   std::shared_ptr<void> collectGuard_;
 };
 
index b9ef475bbdda32c8c6a1f6c86dc9e6a453b07382..a9b08143a9866a72ddb376f8e50eed836c18aa08 100644 (file)
@@ -83,6 +83,40 @@ void basicTest() {
   EXPECT_EQ(0, ++count);
 }
 
+template <typename RefCount>
+void stressTest() {
+  constexpr size_t kItersCount = 10000;
+
+  for (size_t i = 0; i < kItersCount; ++i) {
+    RefCount count;
+    std::mutex mutex;
+    int a{1};
+
+    std::thread t1([&]() {
+      if (++count) {
+        {
+          std::lock_guard<std::mutex> lg(mutex);
+          EXPECT_EQ(1, a);
+        }
+        --count;
+      }
+    });
+
+    std::thread t2([&]() {
+      count.useGlobal();
+      if (--count == 0) {
+        std::lock_guard<std::mutex> lg(mutex);
+        a = 0;
+      }
+    });
+
+    t1.join();
+    t2.join();
+
+    EXPECT_EQ(0, ++count);
+  }
+}
+
 TEST(RCURefCount, Basic) {
   basicTest<RCURefCount>();
 }
@@ -91,4 +125,11 @@ TEST(TLRefCount, Basic) {
   basicTest<TLRefCount>();
 }
 
+TEST(RCURefCount, Stress) {
+  stressTest<TLRefCount>();
+}
+
+TEST(TLRefCount, Stress) {
+  stressTest<TLRefCount>();
+}
 }