2 * Copyright 2016 Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
18 #include <folly/ThreadLocal.h>
// Member-initializer list of the TLRefCount constructor (its signature is not
// present in this copy). localCount_ is given a factory that builds a
// LocalRefCount bound to *this — presumably invoked lazily once per thread by
// folly::ThreadLocal; confirm. collectGuard_ is a shared_ptr<void> pointing at
// this with a no-op deleter, so it never frees anything; in the visible code it
// is only used to track how many LocalRefCounts still hold a copy of it.
27 : localCount_([&]() { return new LocalRefCount(*this); }),
28 collectGuard_(this, [](void*) {}) {}
// Destruction is only valid once the counter has been switched to global mode
// (state_ == GLOBAL, as set at the end of useGlobal()) and the global count
// has dropped to zero. Both preconditions are checked only by debug asserts.
30 ~TLRefCount() noexcept {
31 assert(globalCount_.load() == 0);
32 assert(state_.load() == State::GLOBAL);
// Pre-increment. Obtains the calling thread's LocalRefCount from the
// folly::ThreadLocal. If useGlobal() is mid-switch (GLOBAL_TRANSITION),
// acquiring globalMutex_ blocks until useGlobal() returns, because it holds
// that mutex for its whole duration; state_ is then asserted to be GLOBAL.
// In global mode, globalCount_ is incremented with a compare_exchange_weak
// retry loop.
// NOTE(review): parts of this body (the local fast path, the `do {` that opens
// the CAS loop, and the return statement) are absent from this copy.
35 // This can't increment from 0.
36 Int operator++() noexcept {
37 auto& localCount = *localCount_;
43 if (state_.load() == State::GLOBAL_TRANSITION) {
44 std::lock_guard<std::mutex> lg(globalMutex_);
47 assert(state_.load() == State::GLOBAL);
49 auto value = globalCount_.load();
54 } while (!globalCount_.compare_exchange_weak(value, value+1));
// Pre-decrement. Obtains the calling thread's LocalRefCount. If a switch to
// global mode is in progress (GLOBAL_TRANSITION), it waits on globalMutex_
// until useGlobal() finishes, and state_ is then asserted to be GLOBAL.
// In global mode it atomically decrements globalCount_ and returns the
// post-decrement value (old value - 1).
// NOTE(review): the local-mode path is not present in this copy.
59 Int operator--() noexcept {
60 auto& localCount = *localCount_;
66 if (state_.load() == State::GLOBAL_TRANSITION) {
67 std::lock_guard<std::mutex> lg(globalMutex_);
70 assert(state_.load() == State::GLOBAL);
72 return globalCount_-- - 1;
// Reads the current count. Only the global-mode result (globalCount_) is
// visible here. What is returned while state_ != GLOBAL is not shown in this
// copy — NOTE(review): confirm before relying on the value outside GLOBAL mode.
75 Int operator*() const {
76 if (state_ != State::GLOBAL) {
79 return globalCount_.load();
// Switches the counter from per-thread (LOCAL) counting to the single shared
// globalCount_.
//
// It holds globalMutex_ for its entire duration. This blocks LocalRefCount
// construction and the GLOBAL_TRANSITION waits in operator++/-- until it
// returns.
//
// It takes a weak_ptr to collectGuard_ and then drops this object's own
// reference. After that, only LocalRefCounts that copied the guard earlier keep
// it alive, and each one releases its copy when it is collected. The loop
// revisits every thread's LocalRefCount until the weak_ptr reports the guard
// expired. Only then is state_ published as GLOBAL.
//
// NOTE(review): the body of the inner loop is not present in this copy —
// presumably it triggers each LocalRefCount's collection; confirm.
82 void useGlobal() noexcept {
83 std::lock_guard<std::mutex> lg(globalMutex_);
85 state_ = State::GLOBAL_TRANSITION;
87 std::weak_ptr<void> collectGuardWeak = collectGuard_;
89 // Make sure we can't create new LocalRefCounts
90 collectGuard_.reset();
92 while (!collectGuardWeak.expired()) {
93 auto accessor = localCount_.accessAllThreads();
94 for (auto& count : accessor) {
99 state_ = State::GLOBAL;
// Atomic wrapper over the counter value type Int.
103 using AtomicInt = std::atomic<Int>;
// Per-thread share of a TLRefCount (one instance per thread via the parent's
// folly::ThreadLocal). It accumulates deltas in count_ while the parent is in
// LOCAL mode. It is folded into the parent's globalCount_ when the parent
// switches to global mode; the collectGuard_ check below limits this to one
// collection per instance.
111 class LocalRefCount {
// Binds to the parent counter and copies the parent's collectGuard_ while
// holding the parent's globalMutex_. Because useGlobal() resets that guard
// under the same mutex, the copy is ordered against the reset:
//  - an instance created before the reset holds a live copy, which useGlobal()
//    must wait for;
//  - an instance created after the reset gets an empty guard.
113 explicit LocalRefCount(TLRefCount& refCount) :
114 refCount_(refCount) {
115 std::lock_guard<std::mutex> lg(refCount.globalMutex_);
117 collectGuard_ = refCount.collectGuard_;
// Collection (the enclosing function's signature is not present in this
// copy), performed under collectMutex_:
//  - If collectGuard_ is already empty, the `if` presumably returns early
//    (its body is missing here). That means either already collected, or
//    created after useGlobal() dropped the parent's guard.
//  - Otherwise it snapshots count_ into collectCount_, adds that snapshot to
//    the parent's globalCount_, and releases collectGuard_ so useGlobal() can
//    observe the guard expiring.
125 std::lock_guard<std::mutex> lg(collectMutex_);
127 if (!collectGuard_) {
131 collectCount_ = count_.load();
132 refCount_.globalCount_.fetch_add(collectCount_);
133 collectGuard_.reset();
// Adds delta to this thread's count_ and returns a bool. The return statements
// are not present in this copy, so the meaning of the result is not documented
// here.
// The parent's state_ is checked before the add (that branch's body is not
// shown) and again after it, because useGlobal() may start collecting
// concurrently. In the second case, collectMutex_ serializes with collection,
// and the post-add value is compared with collectCount_ (the snapshot folded
// into globalCount_). This is presumably to decide whether this delta was
// included in that snapshot; confirm.
145 bool update(Int delta) {
146 if (UNLIKELY(refCount_.state_.load() != State::LOCAL)) {
150 auto count = count_ += delta;
152 if (UNLIKELY(refCount_.state_.load() != State::LOCAL)) {
153 std::lock_guard<std::mutex> lg(collectMutex_);
158 if (collectCount_ != count) {
// Parent counter this instance reports into.
167 TLRefCount& refCount_;
// Serializes collection with update()'s post-add check.
169 std::mutex collectMutex_;
// Value of count_ captured at collection time and added to globalCount_.
170 Int collectCount_{0};
// This instance's copy of the parent's collectGuard_. It is empty if copied
// after useGlobal() reset the parent's guard, and is reset once collected.
171 std::shared_ptr<void> collectGuard_;
// Current mode. It starts as LOCAL; useGlobal() moves it to GLOBAL_TRANSITION
// and then to GLOBAL.
174 std::atomic<State> state_{State::LOCAL};
// Per-thread LocalRefCount instances, created via the factory passed in the
// constructor's initializer list.
175 folly::ThreadLocal<LocalRefCount, TLRefCount> localCount_;
// Count used in GLOBAL mode. It starts at 1 and receives each LocalRefCount's
// collected total (fetch_add) during useGlobal().
// NOTE(review): declared as std::atomic<int64_t> rather than AtomicInt.
176 std::atomic<int64_t> globalCount_{1};
// Serializes useGlobal() with LocalRefCount construction and with
// operator++/-- callers that observe GLOBAL_TRANSITION.
177 std::mutex globalMutex_;
// Guard whose outstanding copies (held by LocalRefCounts created before
// useGlobal() reset it) tell useGlobal() when every thread's count has been
// collected.
178 std::shared_ptr<void> collectGuard_;