From bae33d680bbba07b74d9bf9b2ba99766ebb3795e Mon Sep 17 00:00:00 2001 From: Dan Schatzberg Date: Mon, 29 Aug 2016 16:42:44 -0700 Subject: [PATCH] Fix ThreadCachedInt race condition Summary: Move ThreadLocal object destruction to occur under the lock to avoid races. This causes a few cascading changes - the Tag lock needs to be a recursive_mutex so that constructing a new object while destroying another does not deadlock. Also, forking requires a new mutex to avoid deadlocking on accessing a recursive_mutex across a fork(). Reviewed By: andriigrynenko Differential Revision: D3755446 fbshipit-source-id: bb4c4f29bab98d763490df29b460066f124303e0 --- folly/ThreadCachedInt.h | 10 ++- folly/ThreadLocal.h | 30 ++++--- folly/detail/ThreadLocalDetail.cpp | 31 ++++--- folly/detail/ThreadLocalDetail.h | 1 + folly/test/ThreadCachedIntTest.cpp | 134 +++++++++++++++++++++++++++++ 5 files changed, 180 insertions(+), 26 deletions(-) diff --git a/folly/ThreadCachedInt.h b/folly/ThreadCachedInt.h index fcc4a7ea..b8469721 100644 --- a/folly/ThreadCachedInt.h +++ b/folly/ThreadCachedInt.h @@ -63,8 +63,11 @@ class ThreadCachedInt : boost::noncopyable { // Reads the current value plus all the cached increments. Requires grabbing // a lock, so this is significantly slower than readFast(). IntT readFull() const { + // This could race with thread destruction and so the access lock should be + // acquired before reading the current value + auto accessor = cache_.accessAllThreads(); IntT ret = readFast(); - for (const auto& cache : cache_.accessAllThreads()) { + for (const auto& cache : accessor) { if (!cache.reset_.load(std::memory_order_acquire)) { ret += cache.val_.load(std::memory_order_relaxed); } @@ -82,8 +85,11 @@ class ThreadCachedInt : boost::noncopyable { // little off, however, but it should be much better than calling readFull() // and set(0) sequentially. 
IntT readFullAndReset() { + // This could race with thread destruction and so the access lock should be + // acquired before reading the current value + auto accessor = cache_.accessAllThreads(); IntT ret = readFastAndReset(); - for (auto& cache : cache_.accessAllThreads()) { + for (auto& cache : accessor) { if (!cache.reset_.load(std::memory_order_acquire)) { ret += cache.val_.load(std::memory_order_relaxed); cache.reset_.store(true, std::memory_order_release); diff --git a/folly/ThreadLocal.h b/folly/ThreadLocal.h index b4774207..55b7a246 100644 --- a/folly/ThreadLocal.h +++ b/folly/ThreadLocal.h @@ -36,10 +36,11 @@ #pragma once +#include #include #include #include -#include +#include #include #include @@ -249,6 +250,7 @@ class ThreadLocalPtr { friend class ThreadLocalPtr; threadlocal_detail::StaticMetaBase& meta_; + SharedMutex* accessAllThreadsLock_; std::mutex* lock_; uint32_t id_; @@ -321,10 +323,12 @@ class ThreadLocalPtr { Accessor& operator=(const Accessor&) = delete; Accessor(Accessor&& other) noexcept - : meta_(other.meta_), - lock_(other.lock_), - id_(other.id_) { + : meta_(other.meta_), + accessAllThreadsLock_(other.accessAllThreadsLock_), + lock_(other.lock_), + id_(other.id_) { other.id_ = 0; + other.accessAllThreadsLock_ = nullptr; other.lock_ = nullptr; } @@ -338,20 +342,23 @@ class ThreadLocalPtr { assert(&meta_ == &other.meta_); assert(lock_ == nullptr); using std::swap; + swap(accessAllThreadsLock_, other.accessAllThreadsLock_); swap(lock_, other.lock_); swap(id_, other.id_); } Accessor() - : meta_(threadlocal_detail::StaticMeta::instance()), - lock_(nullptr), - id_(0) { - } + : meta_(threadlocal_detail::StaticMeta::instance()), + accessAllThreadsLock_(nullptr), + lock_(nullptr), + id_(0) {} private: explicit Accessor(uint32_t id) - : meta_(threadlocal_detail::StaticMeta::instance()), - lock_(&meta_.lock_) { + : meta_(threadlocal_detail::StaticMeta::instance()), + accessAllThreadsLock_(&meta_.accessAllThreadsLock_), + lock_(&meta_.lock_) { + 
accessAllThreadsLock_->lock(); lock_->lock(); id_ = id; } @@ -359,8 +366,11 @@ class ThreadLocalPtr { void release() { if (lock_) { lock_->unlock(); + DCHECK(accessAllThreadsLock_ != nullptr); + accessAllThreadsLock_->unlock(); id_ = 0; lock_ = nullptr; + accessAllThreadsLock_ = nullptr; } } }; diff --git a/folly/detail/ThreadLocalDetail.cpp b/folly/detail/ThreadLocalDetail.cpp index 7dcb22d4..dfd17259 100644 --- a/folly/detail/ThreadLocalDetail.cpp +++ b/folly/detail/ThreadLocalDetail.cpp @@ -45,20 +45,23 @@ void StaticMetaBase::onThreadExit(void* ptr) { }; { - std::lock_guard g(meta.lock_); - meta.erase(&(*threadEntry)); - // No need to hold the lock any longer; the ThreadEntry is private to this - // thread now that it's been removed from meta. - } - // NOTE: User-provided deleter / object dtor itself may be using ThreadLocal - // with the same Tag, so dispose() calls below may (re)create some of the - // elements or even increase elementsCapacity, thus multiple cleanup rounds - // may be required. - for (bool shouldRun = true; shouldRun;) { - shouldRun = false; - FOR_EACH_RANGE (i, 0, threadEntry->elementsCapacity) { - if (threadEntry->elements[i].dispose(TLPDestructionMode::THIS_THREAD)) { - shouldRun = true; + SharedMutex::ReadHolder rlock(meta.accessAllThreadsLock_); + { + std::lock_guard g(meta.lock_); + meta.erase(&(*threadEntry)); + // No need to hold the lock any longer; the ThreadEntry is private to this + // thread now that it's been removed from meta. + } + // NOTE: User-provided deleter / object dtor itself may be using ThreadLocal + // with the same Tag, so dispose() calls below may (re)create some of the + // elements or even increase elementsCapacity, thus multiple cleanup rounds + // may be required. 
+ for (bool shouldRun = true; shouldRun;) { + shouldRun = false; + FOR_EACH_RANGE (i, 0, threadEntry->elementsCapacity) { + if (threadEntry->elements[i].dispose(TLPDestructionMode::THIS_THREAD)) { + shouldRun = true; + } } } } diff --git a/folly/detail/ThreadLocalDetail.h b/folly/detail/ThreadLocalDetail.h index 68676a23..e05acab5 100644 --- a/folly/detail/ThreadLocalDetail.h +++ b/folly/detail/ThreadLocalDetail.h @@ -296,6 +296,7 @@ struct StaticMetaBase { uint32_t nextId_; std::vector freeIds_; std::mutex lock_; + SharedMutex accessAllThreadsLock_; pthread_key_t pthreadKey_; ThreadEntry head_; ThreadEntry* (*threadEntry_)(); diff --git a/folly/test/ThreadCachedIntTest.cpp b/folly/test/ThreadCachedIntTest.cpp index 4f5c377e..a17c4cb6 100644 --- a/folly/test/ThreadCachedIntTest.cpp +++ b/folly/test/ThreadCachedIntTest.cpp @@ -17,6 +17,7 @@ #include #include +#include #include #include @@ -28,6 +29,139 @@ using namespace folly; +using std::unique_ptr; +using std::vector; + +using Counter = ThreadCachedInt; + +class ThreadCachedIntTest : public testing::Test { + public: + uint32_t GetDeadThreadsTotal(const Counter& counter) { + return counter.readFast(); + } +}; + +// Multithreaded tests. Creates a specified number of threads each of +// which iterates a different amount and dies. + +namespace { +// Set cacheSize to be large so cached data moves to target_ only when +// thread dies. +Counter g_counter_for_mt_slow(0, UINT32_MAX); +Counter g_counter_for_mt_fast(0, UINT32_MAX); + +// Used to sync between threads. The value of this variable is the +// maximum iteration index upto which Runner() is allowed to go. +uint32_t g_sync_for_mt(0); +std::condition_variable cv; +std::mutex cv_m; + +// Performs the specified number of iterations. Within each +// iteration, it increments counter 10 times. At the beginning of +// each iteration it checks g_sync_for_mt to see if it can proceed, +// otherwise goes into a loop sleeping and rechecking. 
+void Runner(Counter* counter, uint32_t iterations) { + for (uint32_t i = 0; i < iterations; ++i) { + std::unique_lock lk(cv_m); + cv.wait(lk, [i] { return i < g_sync_for_mt; }); + for (uint32_t j = 0; j < 10; ++j) { + counter->increment(1); + } + } +} +} + +// Slow test with fewer threads where there are more busy waits and +// many calls to readFull(). This attempts to test as many of the +// code paths in Counter as possible to ensure that counter values are +// properly passed from thread local state, both at calls to +// readFull() and at thread death. +TEST_F(ThreadCachedIntTest, MultithreadedSlow) { + static constexpr uint32_t kNumThreads = 20; + g_sync_for_mt = 0; + vector> threads(kNumThreads); + // Creates kNumThreads threads. Each thread performs a different + // number of iterations in Runner() - threads[0] performs 1 + // iteration, threads[1] performs 2 iterations, threads[2] performs + // 3 iterations, and so on. + for (uint32_t i = 0; i < kNumThreads; ++i) { + threads[i].reset(new std::thread(Runner, &g_counter_for_mt_slow, i + 1)); + } + // Variable to grab current counter value. + int32_t counter_value; + // The expected value of the counter. + int32_t total = 0; + // The expected value of GetDeadThreadsTotal(). + int32_t dead_total = 0; + // Each iteration of the following thread allows one additional + // iteration of the threads. Given that the threads perform + // different number of iterations from 1 through kNumThreads, one + // thread will complete in each of the iterations of the loop below. + for (uint32_t i = 0; i < kNumThreads; ++i) { + // Allow upto iteration i on all threads. + { + std::lock_guard lk(cv_m); + g_sync_for_mt = i + 1; + } + cv.notify_all(); + total += (kNumThreads - i) * 10; + // Loop until the counter reaches its expected value. 
+ do { + counter_value = g_counter_for_mt_slow.readFull(); + } while (counter_value < total); + // All threads have done what they can until iteration i, now make + // sure they don't go further by checking 10 more times in the + // following loop. + for (uint32_t j = 0; j < 10; ++j) { + counter_value = g_counter_for_mt_slow.readFull(); + EXPECT_EQ(total, counter_value); + } + dead_total += (i + 1) * 10; + EXPECT_GE(dead_total, GetDeadThreadsTotal(g_counter_for_mt_slow)); + } + // All threads are done. + for (uint32_t i = 0; i < kNumThreads; ++i) { + threads[i]->join(); + } + counter_value = g_counter_for_mt_slow.readFull(); + EXPECT_EQ(total, counter_value); + EXPECT_EQ(total, dead_total); + EXPECT_EQ(dead_total, GetDeadThreadsTotal(g_counter_for_mt_slow)); +} + +// Fast test with lots of threads and only one call to readFull() +// at the end. +TEST_F(ThreadCachedIntTest, MultithreadedFast) { + static constexpr uint32_t kNumThreads = 1000; + g_sync_for_mt = 0; + vector> threads(kNumThreads); + // Creates kNumThreads threads. Each thread performs a different + // number of iterations in Runner() - threads[0] performs 1 + // iteration, threads[1] performs 2 iterations, threads[2] performs + // 3 iterations, and so on. + for (uint32_t i = 0; i < kNumThreads; ++i) { + threads[i].reset(new std::thread(Runner, &g_counter_for_mt_fast, i + 1)); + } + // Let the threads run to completion. + { + std::lock_guard lk(cv_m); + g_sync_for_mt = kNumThreads; + } + cv.notify_all(); + // The expected value of the counter. + uint32_t total = 0; + for (uint32_t i = 0; i < kNumThreads; ++i) { + total += (kNumThreads - i) * 10; + } + // Wait for all threads to complete. 
+ for (uint32_t i = 0; i < kNumThreads; ++i) { + threads[i]->join(); + } + int32_t counter_value = g_counter_for_mt_fast.readFull(); + EXPECT_EQ(total, counter_value); + EXPECT_EQ(total, GetDeadThreadsTotal(g_counter_for_mt_fast)); +} + TEST(ThreadCachedInt, SingleThreadedNotCached) { ThreadCachedInt val(0, 0); EXPECT_EQ(0, val.readFast()); -- 2.34.1