2 * Copyright 2016 Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
18 #include <folly/ThreadLocal.h>
19 #include <folly/experimental/AsymmetricMemoryBarrier.h>
// NOTE(review): constructor fragment — the constructor signature line is not
// visible in this excerpt. The initializer list sets up the factory that lazily
// creates one LocalRefCount per accessing thread, and collectGuard_ as a
// shared_ptr with a no-op deleter: it never owns memory, it only exists so its
// use_count tracks how many LocalRefCounts are still live (see useGlobal()).
28 : localCount_([&]() { return new LocalRefCount(*this); }),
29 collectGuard_(this, [](void*) {}) {}
// Destructor. Destruction is only legal after the count has been migrated to
// the global counter (state_ == GLOBAL) and has dropped to zero — both are
// asserted here. The closing brace is outside this excerpt.
31 ~TLRefCount() noexcept {
32 assert(globalCount_.load() == 0);
33 assert(state_.load() == State::GLOBAL);
// Increment. NOTE(review): several original lines are elided from this excerpt;
// the missing fast path presumably attempts localCount.update(+1) and returns
// early while state_ is LOCAL — confirm against the full source. The visible
// code is the slow path after the count has migrated to globalCount_.
36 // This can't increment from 0.
37 Int operator++() noexcept {
38 auto& localCount = *localCount_;
// If a transition is in flight, briefly take globalMutex_: useGlobal() holds
// it for the whole transition, so acquiring it here waits until state_ is
// GLOBAL before touching the global counter.
44 if (state_.load() == State::GLOBAL_TRANSITION) {
45 std::lock_guard<std::mutex> lg(globalMutex_);
48 assert(state_.load() == State::GLOBAL);
// CAS loop rather than fetch_add: the elided loop body (original lines 51-54)
// presumably bails out when value == 0 so a dead object cannot be revived
// ("can't increment from 0") — TODO confirm against the full source.
50 auto value = globalCount_.load();
55 } while (!globalCount_.compare_exchange_weak(value, value+1));
// Decrement; returns the post-decrement value. NOTE(review): the thread-local
// fast path (elided from this excerpt) presumably tries localCount.update(-1)
// first — confirm against the full source. Below is the global slow path.
60 Int operator--() noexcept {
61 auto& localCount = *localCount_;
// Same transition handshake as operator++: taking globalMutex_ blocks until
// useGlobal() has finished and state_ is GLOBAL.
67 if (state_.load() == State::GLOBAL_TRANSITION) {
68 std::lock_guard<std::mutex> lg(globalMutex_);
71 assert(state_.load() == State::GLOBAL);
// Plain atomic post-decrement; "- 1" converts it to the new (post-op) value.
// Unlike increment, decrementing needs no CAS guard against zero here.
73 return globalCount_-- - 1;
// Read the current count. Only meaningful once the count has been collected
// into globalCount_; the elided branch body (original lines 78-79) handles the
// pre-GLOBAL case — presumably returning a sentinel, since per-thread counts
// cannot be summed coherently while still LOCAL. TODO confirm with full source.
76 Int operator*() const {
77 if (state_ != State::GLOBAL) {
80 return globalCount_.load();
// Single-object convenience wrapper: builds a one-element array and (in the
// elided remainder of the body) presumably forwards to the static batch
// useGlobal() overload below — confirm against the full source.
83 void useGlobal() noexcept {
84 std::array<TLRefCount*, 1> ptrs{{this}};
// Batch transition of several TLRefCounts from per-thread counting to a single
// shared global counter. NOTE(review): several lines are elided from this
// excerpt (loop closing braces, and the collection step inside the accessor
// loop around original lines 108-111).
88 template <typename Container>
89 static void useGlobal(const Container& refCountPtrs) {
// Hold every object's globalMutex_ for the whole transition; operator++/--
// block on these locks while state_ is GLOBAL_TRANSITION.
90 std::vector<std::unique_lock<std::mutex>> lgs_;
91 for (auto refCountPtr : refCountPtrs) {
92 lgs_.emplace_back(refCountPtr->globalMutex_);
94 refCountPtr->state_ = State::GLOBAL_TRANSITION;
// Heavy barrier pairs with asymmetricLightBarrier() in LocalRefCount::update():
// after this, every thread is guaranteed to observe GLOBAL_TRANSITION before
// its next local update completes.
97 asymmetricHeavyBarrier();
99 for (auto refCountPtr : refCountPtrs) {
// Keep a weak_ptr so we can tell when every LocalRefCount (each holds a copy
// of collectGuard_) has flushed its count and released its reference.
100 std::weak_ptr<void> collectGuardWeak = refCountPtr->collectGuard_;
102 // Make sure we can't create new LocalRefCounts
103 refCountPtr->collectGuard_.reset();
// Spin until every thread-local counter has been collected into globalCount_
// (the per-count collection call sits in the elided lines inside this loop).
105 while (!collectGuardWeak.expired()) {
106 auto accessor = refCountPtr->localCount_.accessAllThreads();
107 for (auto& count : accessor) {
112 refCountPtr->state_ = State::GLOBAL;
// Alias used for the atomic counter representation (Int is declared outside
// this excerpt).
117 using AtomicInt = std::atomic<Int>;
// Per-thread counter. Each thread touching a TLRefCount gets one of these via
// folly::ThreadLocal; increments/decrements hit count_ without atomic RMW until
// the owner migrates to the global counter. NOTE(review): this excerpt elides
// many lines (the collect() signature around original line 137, update()'s
// early-out and tail, the destructor, and the closing brace).
125 class LocalRefCount {
// Registers under the owner's globalMutex_ so construction cannot race with a
// useGlobal() transition; copies collectGuard_ to advertise "this thread still
// holds an uncollected count" (useGlobal() waits for all copies to die).
127 explicit LocalRefCount(TLRefCount& refCount) :
128 refCount_(refCount) {
129 std::lock_guard<std::mutex> lg(refCount.globalMutex_);
131 collectGuard_ = refCount.collectGuard_;
// --- collect() body fragment (signature elided from this excerpt) ---
// Flushes this thread's tally into the owner's global counter exactly once:
// already-collected instances (null guard) bail out, otherwise the local total
// is added to globalCount_ and the guard released so useGlobal() can finish.
139 std::lock_guard<std::mutex> lg(collectMutex_);
141 if (!collectGuard_) {
145 collectCount_ = count_.load();
146 refCount_.globalCount_.fetch_add(collectCount_);
147 collectGuard_.reset();
// Applies +/-delta to the thread-local count. Returns false (in elided code —
// confirm) when the count has gone global and the caller must use the slow
// path instead.
159 bool update(Int delta) {
160 if (UNLIKELY(refCount_.state_.load() != State::LOCAL)) {
164 // This is equivalent to atomic fetch_add. We know that this operation
165 // is always performed from a single thread. asymmetricLightBarrier()
166 // makes things faster than atomic fetch_add on platforms with native
// Single-writer relaxed load+store instead of an atomic RMW; safe only because
// each LocalRefCount is written by exactly one thread.
168 auto count = count_.load(std::memory_order_relaxed) + delta;
169 count_.store(count, std::memory_order_relaxed);
// Pairs with asymmetricHeavyBarrier() in useGlobal(): orders the store above
// against the state_ re-check below without a full fence on the fast path.
171 asymmetricLightBarrier();
// Re-check after the barrier: if a transition began concurrently, reconcile
// under collectMutex_ — collect() may have already flushed a stale total, and
// the elided code (original lines 180+) presumably compensates for the
// difference between collectCount_ and count.
173 if (UNLIKELY(refCount_.state_.load() != State::LOCAL)) {
174 std::lock_guard<std::mutex> lg(collectMutex_);
179 if (collectCount_ != count) {
// Owning TLRefCount (non-owning back-reference).
188 TLRefCount& refCount_;
// Serializes collect() against the post-barrier reconciliation in update().
190 std::mutex collectMutex_;
// Value that was flushed to globalCount_ at collection time.
191 Int collectCount_{0};
// Copy of the owner's guard; non-null means "not yet collected".
192 std::shared_ptr<void> collectGuard_;
// Counting mode: LOCAL -> GLOBAL_TRANSITION -> GLOBAL (one-way, via useGlobal).
195 std::atomic<State> state_{State::LOCAL};
// Lazily creates one LocalRefCount per accessing thread (tag = TLRefCount).
196 folly::ThreadLocal<LocalRefCount, TLRefCount> localCount_;
// Shared counter used once state_ is GLOBAL; starts at 1 for the initial ref.
197 std::atomic<int64_t> globalCount_{1};
// Held for the whole useGlobal() transition; ++/-- acquire it to wait it out.
198 std::mutex globalMutex_;
// Liveness token (no-op deleter): each LocalRefCount copies it, and its
// expiry tells useGlobal() that all thread-local counts have been collected.
199 std::shared_ptr<void> collectGuard_;