Fix TLRefCount race around thread local destruction and fix RefCount unit test
authorAndrii Grynenko <andrii@fb.com>
Tue, 8 Dec 2015 20:51:40 +0000 (12:51 -0800)
committerfacebook-github-bot-1 <folly-bot@fb.com>
Tue, 8 Dec 2015 21:20:22 +0000 (13:20 -0800)
Reviewed By: pavlo-fb

Differential Revision: D2708425

fb-gh-sync-id: 665d077210503df4f4e8aa8f88ce5b9b277582f3

folly/experimental/TLRefCount.h
folly/experimental/test/RefCountTest.cpp

index 6d20b5ef45ce34a285a6b6ccca0c6542847a5c1c..7a06d47c2bb8f80a18d88ee432feea7a483450fe 100644 (file)
@@ -15,6 +15,7 @@
  */
 #pragma once
 
+#include <folly/Baton.h>
 #include <folly/ThreadLocal.h>
 
 namespace folly {
@@ -26,6 +27,10 @@ class TLRefCount {
   TLRefCount() :
       localCount_([&]() {
           return new LocalRefCount(*this);
+        }),
+      collectGuard_(&collectBaton_, [](void* p) {
+          auto baton = reinterpret_cast<folly::Baton<>*>(p);
+          baton->post();
         }) {
   }
 
@@ -91,6 +96,9 @@ class TLRefCount {
       count.collect();
     }
 
+    collectGuard_.reset();
+    collectBaton_.wait();
+
     state_ = State::GLOBAL;
   }
 
@@ -106,7 +114,11 @@ class TLRefCount {
   class LocalRefCount {
    public:
     explicit LocalRefCount(TLRefCount& refCount) :
-        refCount_(refCount) {}
+        refCount_(refCount) {
+      std::lock_guard<std::mutex> lg(refCount.globalMutex_);
+
+      collectGuard_ = refCount.collectGuard_;
+    }
 
     ~LocalRefCount() {
       collect();
@@ -115,13 +127,13 @@ class TLRefCount {
     void collect() {
       std::lock_guard<std::mutex> lg(collectMutex_);
 
-      if (collectDone_) {
+      if (!collectGuard_) {
         return;
       }
 
       collectCount_ = count_;
       refCount_.globalCount_ += collectCount_;
-      collectDone_ = true;
+      collectGuard_.reset();
     }
 
     bool operator++() {
@@ -143,7 +155,7 @@ class TLRefCount {
       if (UNLIKELY(refCount_.state_.load() != State::LOCAL)) {
         std::lock_guard<std::mutex> lg(collectMutex_);
 
-        if (!collectDone_) {
+        if (collectGuard_) {
           return true;
         }
         if (collectCount_ != count) {
@@ -159,13 +171,15 @@ class TLRefCount {
 
     std::mutex collectMutex_;
     Int collectCount_{0};
-    bool collectDone_{false};
+    std::shared_ptr<void> collectGuard_;
   };
 
   std::atomic<State> state_{State::LOCAL};
   folly::ThreadLocal<LocalRefCount, TLRefCount> localCount_;
   std::atomic<int64_t> globalCount_{1};
   std::mutex globalMutex_;
+  folly::Baton<> collectBaton_;
+  std::shared_ptr<void> collectGuard_;
 };
 
 }
index 74fa2447edff3a9310b055bd83abc70c9f6b7b53..48ad80d91e396d3a4667de45670645ecf499b110 100644 (file)
@@ -35,12 +35,16 @@ void basicTest() {
   folly::Baton<> b;
 
   std::vector<std::thread> ts;
+  folly::Baton<> threadBatons[numThreads];
   for (size_t t = 0; t < numThreads; ++t) {
-    ts.emplace_back([&count, &b, &got0, numIters, t]() {
+    ts.emplace_back([&count, &b, &got0, numIters, t, &threadBatons]() {
         for (size_t i = 0; i < numIters; ++i) {
           auto ret = ++count;
 
           EXPECT_TRUE(ret > 1);
+          if (i == 0) {
+            threadBatons[t].post();
+          }
         }
 
         if (t == 0) {
@@ -58,10 +62,14 @@ void basicTest() {
       });
   }
 
+  for (size_t t = 0; t < numThreads; ++t) {
+    threadBatons[t].wait();
+  }
+
   b.wait();
 
   count.useGlobal();
-  EXPECT_TRUE(--count > 0);
+  --count;
 
   for (auto& t: ts) {
     t.join();