2 * Copyright 2015 Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
18 #include <folly/ThreadLocal.h>
28 return new LocalRefCount(*this);
32 ~TLRefCount() noexcept {
33 assert(globalCount_.load() == 0);
34 assert(state_.load() == State::GLOBAL);
37 // This can't increment from 0.
38 Int operator++() noexcept {
39 auto& localCount = *localCount_;
45 if (state_.load() == State::GLOBAL_TRANSITION) {
46 std::lock_guard<std::mutex> lg(globalMutex_);
49 assert(state_.load() == State::GLOBAL);
51 auto value = globalCount_.load();
56 } while (!globalCount_.compare_exchange_weak(value, value+1));
61 Int operator--() noexcept {
62 auto& localCount = *localCount_;
68 if (state_.load() == State::GLOBAL_TRANSITION) {
69 std::lock_guard<std::mutex> lg(globalMutex_);
72 assert(state_.load() == State::GLOBAL);
74 return --globalCount_;
77 Int operator*() const {
78 if (state_ != State::GLOBAL) {
81 return globalCount_.load();
84 void useGlobal() noexcept {
85 std::lock_guard<std::mutex> lg(globalMutex_);
87 state_ = State::GLOBAL_TRANSITION;
89 auto accessor = localCount_.accessAllThreads();
90 for (auto& count : accessor) {
94 state_ = State::GLOBAL;
98 using AtomicInt = std::atomic<Int>;
106 class LocalRefCount {
108 explicit LocalRefCount(TLRefCount& refCount) :
109 refCount_(refCount) {}
116 std::lock_guard<std::mutex> lg(collectMutex_);
122 collectCount_ = count_;
123 refCount_.globalCount_ += collectCount_;
136 bool update(Int delta) {
137 if (UNLIKELY(refCount_.state_.load() != State::LOCAL)) {
141 auto count = count_ += delta;
143 if (UNLIKELY(refCount_.state_.load() != State::LOCAL)) {
144 std::lock_guard<std::mutex> lg(collectMutex_);
149 if (collectCount_ != count) {
158 TLRefCount& refCount_;
160 std::mutex collectMutex_;
161 Int collectCount_{0};
162 bool collectDone_{false};
165 std::atomic<State> state_{State::LOCAL};
166 folly::ThreadLocal<LocalRefCount, TLRefCount> localCount_;
167 std::atomic<int64_t> globalCount_{1};
168 std::mutex globalMutex_;