*/
#pragma once
-#include <folly/Baton.h>
#include <folly/ThreadLocal.h>
namespace folly {
public:
using Int = int64_t;
- TLRefCount() :
- localCount_([&]() {
- return new LocalRefCount(*this);
- }),
- collectGuard_(&collectBaton_, [](void* p) {
- auto baton = reinterpret_cast<folly::Baton<>*>(p);
- baton->post();
- }) {
- }
+ TLRefCount()
+ : localCount_([&]() { return new LocalRefCount(*this); }),
+ collectGuard_(this, [](void*) {}) {}
~TLRefCount() noexcept {
assert(globalCount_.load() == 0);
state_ = State::GLOBAL_TRANSITION;
- auto accessor = localCount_.accessAllThreads();
- for (auto& count : accessor) {
- count.collect();
- }
+ std::weak_ptr<void> collectGuardWeak = collectGuard_;
+ // Make sure we can't create new LocalRefCounts
collectGuard_.reset();
- collectBaton_.wait();
+
+ while (!collectGuardWeak.expired()) {
+ auto accessor = localCount_.accessAllThreads();
+ for (auto& count : accessor) {
+ count.collect();
+ }
+ }
state_ = State::GLOBAL;
}
return;
}
- collectCount_ = count_;
+ collectCount_ = count_.load();
refCount_.globalCount_.fetch_add(collectCount_);
collectGuard_.reset();
}
return true;
}
- Int count_{0};
+ AtomicInt count_{0};
TLRefCount& refCount_;
std::mutex collectMutex_;
folly::ThreadLocal<LocalRefCount, TLRefCount> localCount_;
std::atomic<int64_t> globalCount_{1};
std::mutex globalMutex_;
- folly::Baton<> collectBaton_;
std::shared_ptr<void> collectGuard_;
};
EXPECT_EQ(0, ++count);
}
+template <typename RefCount>
+void stressTest() {
+ constexpr size_t kItersCount = 10000;
+
+ for (size_t i = 0; i < kItersCount; ++i) {
+ RefCount count;
+ std::mutex mutex;
+ int a{1};
+
+ std::thread t1([&]() {
+ if (++count) {
+ {
+ std::lock_guard<std::mutex> lg(mutex);
+ EXPECT_EQ(1, a);
+ }
+ --count;
+ }
+ });
+
+ std::thread t2([&]() {
+ count.useGlobal();
+ if (--count == 0) {
+ std::lock_guard<std::mutex> lg(mutex);
+ a = 0;
+ }
+ });
+
+ t1.join();
+ t2.join();
+
+ EXPECT_EQ(0, ++count);
+ }
+}
+
TEST(RCURefCount, Basic) {
basicTest<RCURefCount>();
}
basicTest<TLRefCount>();
}
+TEST(RCURefCount, Stress) {
+ stressTest<TLRefCount>();
+}
+
+TEST(TLRefCount, Stress) {
+ stressTest<TLRefCount>();
+}
}