From 5ff51e977e67cf7759d048b5d928d6ec85186705 Mon Sep 17 00:00:00 2001 From: Andrii Grynenko Date: Tue, 8 Dec 2015 12:51:40 -0800 Subject: [PATCH] Fix TLRefCount race around thread local destruction and fix RefCount unit test Reviewed By: pavlo-fb Differential Revision: D2708425 fb-gh-sync-id: 665d077210503df4f4e8aa8f88ce5b9b277582f3 --- folly/experimental/TLRefCount.h | 24 +++++++++++++++++++----- folly/experimental/test/RefCountTest.cpp | 12 ++++++++++-- 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/folly/experimental/TLRefCount.h b/folly/experimental/TLRefCount.h index 6d20b5ef..7a06d47c 100644 --- a/folly/experimental/TLRefCount.h +++ b/folly/experimental/TLRefCount.h @@ -15,6 +15,7 @@ */ #pragma once +#include #include namespace folly { @@ -26,6 +27,10 @@ class TLRefCount { TLRefCount() : localCount_([&]() { return new LocalRefCount(*this); + }), + collectGuard_(&collectBaton_, [](void* p) { + auto baton = reinterpret_cast*>(p); + baton->post(); }) { } @@ -91,6 +96,9 @@ class TLRefCount { count.collect(); } + collectGuard_.reset(); + collectBaton_.wait(); + state_ = State::GLOBAL; } @@ -106,7 +114,11 @@ class TLRefCount { class LocalRefCount { public: explicit LocalRefCount(TLRefCount& refCount) : - refCount_(refCount) {} + refCount_(refCount) { + std::lock_guard lg(refCount.globalMutex_); + + collectGuard_ = refCount.collectGuard_; + } ~LocalRefCount() { collect(); @@ -115,13 +127,13 @@ class TLRefCount { void collect() { std::lock_guard lg(collectMutex_); - if (collectDone_) { + if (!collectGuard_) { return; } collectCount_ = count_; refCount_.globalCount_ += collectCount_; - collectDone_ = true; + collectGuard_.reset(); } bool operator++() { @@ -143,7 +155,7 @@ class TLRefCount { if (UNLIKELY(refCount_.state_.load() != State::LOCAL)) { std::lock_guard lg(collectMutex_); - if (!collectDone_) { + if (collectGuard_) { return true; } if (collectCount_ != count) { @@ -159,13 +171,15 @@ class TLRefCount { std::mutex collectMutex_; Int collectCount_{0}; - bool collectDone_{false}; + std::shared_ptr collectGuard_; }; std::atomic state_{State::LOCAL}; folly::ThreadLocal localCount_; std::atomic globalCount_{1}; std::mutex globalMutex_; + folly::Baton<> collectBaton_; + std::shared_ptr collectGuard_; }; } diff --git a/folly/experimental/test/RefCountTest.cpp b/folly/experimental/test/RefCountTest.cpp index 74fa2447..48ad80d9 100644 --- a/folly/experimental/test/RefCountTest.cpp +++ b/folly/experimental/test/RefCountTest.cpp @@ -35,12 +35,16 @@ void basicTest() { folly::Baton<> b; std::vector ts; + folly::Baton<> threadBatons[numThreads]; for (size_t t = 0; t < numThreads; ++t) { - ts.emplace_back([&count, &b, &got0, numIters, t]() { + ts.emplace_back([&count, &b, &got0, numIters, t, &threadBatons]() { for (size_t i = 0; i < numIters; ++i) { auto ret = ++count; EXPECT_TRUE(ret > 1); + if (i == 0) { + threadBatons[t].post(); + } } if (t == 0) { @@ -58,10 +62,14 @@ void basicTest() { }); } + for (size_t t = 0; t < numThreads; ++t) { + threadBatons[t].wait(); + } + b.wait(); count.useGlobal(); - EXPECT_TRUE(--count > 0); + --count; for (auto& t: ts) { t.join(); -- 2.34.1