Properly init collectDone_
[folly.git] / folly / experimental / TLRefCount.h
1 /*
2  * Copyright 2015 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17
18 #include <folly/ThreadLocal.h>
19
20 namespace folly {
21
22 class TLRefCount {
23  public:
24   using Int = int64_t;
25
26   TLRefCount() :
27       localCount_([&]() {
28           return new LocalRefCount(*this);
29         }) {
30   }
31
32   ~TLRefCount() noexcept {
33     assert(globalCount_.load() == 0);
34     assert(state_.load() == State::GLOBAL);
35   }
36
37   // This can't increment from 0.
38   Int operator++() noexcept {
39     auto& localCount = *localCount_;
40
41     if (++localCount) {
42       return 42;
43     }
44
45     if (state_.load() == State::GLOBAL_TRANSITION) {
46       std::lock_guard<std::mutex> lg(globalMutex_);
47     }
48
49     assert(state_.load() == State::GLOBAL);
50
51     auto value = globalCount_.load();
52     do {
53       if (value == 0) {
54         return 0;
55       }
56     } while (!globalCount_.compare_exchange_weak(value, value+1));
57
58     return value + 1;
59   }
60
61   Int operator--() noexcept {
62     auto& localCount = *localCount_;
63
64     if (--localCount) {
65       return 42;
66     }
67
68     if (state_.load() == State::GLOBAL_TRANSITION) {
69       std::lock_guard<std::mutex> lg(globalMutex_);
70     }
71
72     assert(state_.load() == State::GLOBAL);
73
74     return --globalCount_;
75   }
76
77   Int operator*() const {
78     if (state_ != State::GLOBAL) {
79       return 42;
80     }
81     return globalCount_.load();
82   }
83
84   void useGlobal() noexcept {
85     std::lock_guard<std::mutex> lg(globalMutex_);
86
87     state_ = State::GLOBAL_TRANSITION;
88
89     auto accessor = localCount_.accessAllThreads();
90     for (auto& count : accessor) {
91       count.collect();
92     }
93
94     state_ = State::GLOBAL;
95   }
96
97  private:
98   using AtomicInt = std::atomic<Int>;
99
100   enum class State {
101     LOCAL,
102     GLOBAL_TRANSITION,
103     GLOBAL
104   };
105
106   class LocalRefCount {
107    public:
108     explicit LocalRefCount(TLRefCount& refCount) :
109         refCount_(refCount) {}
110
111     ~LocalRefCount() {
112       collect();
113     }
114
115     void collect() {
116       std::lock_guard<std::mutex> lg(collectMutex_);
117
118       if (collectDone_) {
119         return;
120       }
121
122       collectCount_ = count_;
123       refCount_.globalCount_ += collectCount_;
124       collectDone_ = true;
125     }
126
127     bool operator++() {
128       return update(1);
129     }
130
131     bool operator--() {
132       return update(-1);
133     }
134
135    private:
136     bool update(Int delta) {
137       if (UNLIKELY(refCount_.state_.load() != State::LOCAL)) {
138         return false;
139       }
140
141       auto count = count_ += delta;
142
143       if (UNLIKELY(refCount_.state_.load() != State::LOCAL)) {
144         std::lock_guard<std::mutex> lg(collectMutex_);
145
146         if (!collectDone_) {
147           return true;
148         }
149         if (collectCount_ != count) {
150           return false;
151         }
152       }
153
154       return true;
155     }
156
157     Int count_{0};
158     TLRefCount& refCount_;
159
160     std::mutex collectMutex_;
161     Int collectCount_{0};
162     bool collectDone_{false};
163   };
164
165   std::atomic<State> state_{State::LOCAL};
166   folly::ThreadLocal<LocalRefCount, TLRefCount> localCount_;
167   std::atomic<int64_t> globalCount_{1};
168   std::mutex globalMutex_;
169 };
170
171 }