78054edd0a1719505bfe0c0169a9d1637b4c90b5
[folly.git] / folly / experimental / TLRefCount.h
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17
18 #include <folly/ThreadLocal.h>
19 #include <folly/experimental/AsymmetricMemoryBarrier.h>
20
21 namespace folly {
22
23 class TLRefCount {
24  public:
25   using Int = int64_t;
26
27   TLRefCount()
28       : localCount_([&]() { return new LocalRefCount(*this); }),
29         collectGuard_(this, [](void*) {}) {}
30
31   ~TLRefCount() noexcept {
32     assert(globalCount_.load() == 0);
33     assert(state_.load() == State::GLOBAL);
34   }
35
36   // This can't increment from 0.
37   Int operator++() noexcept {
38     auto& localCount = *localCount_;
39
40     if (++localCount) {
41       return 42;
42     }
43
44     if (state_.load() == State::GLOBAL_TRANSITION) {
45       std::lock_guard<std::mutex> lg(globalMutex_);
46     }
47
48     assert(state_.load() == State::GLOBAL);
49
50     auto value = globalCount_.load();
51     do {
52       if (value == 0) {
53         return 0;
54       }
55     } while (!globalCount_.compare_exchange_weak(value, value+1));
56
57     return value + 1;
58   }
59
60   Int operator--() noexcept {
61     auto& localCount = *localCount_;
62
63     if (--localCount) {
64       return 42;
65     }
66
67     if (state_.load() == State::GLOBAL_TRANSITION) {
68       std::lock_guard<std::mutex> lg(globalMutex_);
69     }
70
71     assert(state_.load() == State::GLOBAL);
72
73     return globalCount_-- - 1;
74   }
75
76   Int operator*() const {
77     if (state_ != State::GLOBAL) {
78       return 42;
79     }
80     return globalCount_.load();
81   }
82
83   void useGlobal() noexcept {
84     std::lock_guard<std::mutex> lg(globalMutex_);
85
86     state_ = State::GLOBAL_TRANSITION;
87
88     asymmetricHeavyBarrier();
89
90     std::weak_ptr<void> collectGuardWeak = collectGuard_;
91
92     // Make sure we can't create new LocalRefCounts
93     collectGuard_.reset();
94
95     while (!collectGuardWeak.expired()) {
96       auto accessor = localCount_.accessAllThreads();
97       for (auto& count : accessor) {
98         count.collect();
99       }
100     }
101
102     state_ = State::GLOBAL;
103   }
104
105  private:
106   using AtomicInt = std::atomic<Int>;
107
108   enum class State {
109     LOCAL,
110     GLOBAL_TRANSITION,
111     GLOBAL
112   };
113
114   class LocalRefCount {
115    public:
116     explicit LocalRefCount(TLRefCount& refCount) :
117         refCount_(refCount) {
118       std::lock_guard<std::mutex> lg(refCount.globalMutex_);
119
120       collectGuard_ = refCount.collectGuard_;
121     }
122
123     ~LocalRefCount() {
124       collect();
125     }
126
127     void collect() {
128       std::lock_guard<std::mutex> lg(collectMutex_);
129
130       if (!collectGuard_) {
131         return;
132       }
133
134       collectCount_ = count_.load();
135       refCount_.globalCount_.fetch_add(collectCount_);
136       collectGuard_.reset();
137     }
138
139     bool operator++() {
140       return update(1);
141     }
142
143     bool operator--() {
144       return update(-1);
145     }
146
147    private:
148     bool update(Int delta) {
149       if (UNLIKELY(refCount_.state_.load() != State::LOCAL)) {
150         return false;
151       }
152
153       // This is equivalent to atomic fetch_add. We know that this operation
154       // is always performed from a single thread. asymmetricLightBarrier()
155       // makes things faster than atomic fetch_add on platforms with native
156       // support.
157       auto count = count_.load(std::memory_order_relaxed) + delta;
158       count_.store(count, std::memory_order_relaxed);
159
160       asymmetricLightBarrier();
161
162       if (UNLIKELY(refCount_.state_.load() != State::LOCAL)) {
163         std::lock_guard<std::mutex> lg(collectMutex_);
164
165         if (collectGuard_) {
166           return true;
167         }
168         if (collectCount_ != count) {
169           return false;
170         }
171       }
172
173       return true;
174     }
175
176     AtomicInt count_{0};
177     TLRefCount& refCount_;
178
179     std::mutex collectMutex_;
180     Int collectCount_{0};
181     std::shared_ptr<void> collectGuard_;
182   };
183
184   std::atomic<State> state_{State::LOCAL};
185   folly::ThreadLocal<LocalRefCount, TLRefCount> localCount_;
186   std::atomic<int64_t> globalCount_{1};
187   std::mutex globalMutex_;
188   std::shared_ptr<void> collectGuard_;
189 };
190
191 }