2 * Copyright 2015 Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
18 #include <folly/Executor.h>
19 #include <folly/wangle/concurrent/LifoSemMPMCQueue.h>
20 #include <folly/wangle/concurrent/NamedThreadFactory.h>
21 #include <folly/wangle/rx/Observable.h>
22 #include <folly/Baton.h>
23 #include <folly/Memory.h>
24 #include <folly/RWSpinLock.h>
30 #include <glog/logging.h>
32 namespace folly { namespace wangle {
// ThreadPoolExecutor derives virtually from Executor.
34 class ThreadPoolExecutor : public virtual Executor {
// Explicit constructor. The only parameter visible here is the ThreadFactory,
// passed as a shared_ptr.
// NOTE(review): the parameter list looks truncated in this copy of the file;
// confirm the full signature against the upstream header.
36 explicit ThreadPoolExecutor(
38 std::shared_ptr<ThreadFactory> threadFactory);
// Destructor; declared here, definition not shown in this file.
40 ~ThreadPoolExecutor();
// Pure-virtual override of add(Func); every concrete pool must implement it.
42 virtual void add(Func func) override = 0;
// NOTE(review): the two lines below are the tail of another pure-virtual
// declaration whose first line is missing from this copy. Presumably an add()
// overload taking an expiration (milliseconds) and a callback to run when the
// task expires; confirm against the upstream header.
45 std::chrono::milliseconds expiration,
46 Func expireCallback) = 0;
// Replaces the pool's thread factory with the one supplied by the caller.
// CHECK-fails unless numThreads() == 0, so the factory can only be swapped
// while numThreads() reports zero.
48 void setThreadFactory(std::shared_ptr<ThreadFactory> threadFactory) {
49 CHECK(numThreads() == 0);
// Move-assign the by-value parameter into the member (no extra refcount copy).
50 threadFactory_ = std::move(threadFactory);
// Returns a copy of the shared_ptr stored in threadFactory_, so the caller
// shares ownership of the factory with the pool.
53 std::shared_ptr<ThreadFactory> getThreadFactory(void) {
54 return threadFactory_;
// Declaration only; the definition is not part of this file.
// NOTE(review): presumably resizes the pool to `numThreads` threads - confirm
// against the implementation.
58 void setNumThreads(size_t numThreads);
60 * stop() is best effort - there is no guarantee that unexecuted tasks won't
61 * be executed before it returns. Specifically, IOThreadPoolExecutor's stop()
62 * behaves like join().
// Default constructor: sets all five counters to zero.
68 PoolStats() : threadCount(0), idleThreadCount(0), activeThreadCount(0),
69 pendingTaskCount(0), totalTaskCount(0) {}
// Thread counters (size_t).
70 size_t threadCount, idleThreadCount, activeThreadCount;
// Task counters (64-bit unsigned).
71 uint64_t pendingTaskCount, totalTaskCount;
// Returns a PoolStats by value. Declaration only; the definition is not part
// of this file.
74 PoolStats getPoolStats();
// Default constructor: expired = false, both durations = 0.
// NOTE(review): the declaration of `expired` is not visible in this copy.
77 TaskStats() : expired(false), waitTime(0), runTime(0) {}
// Durations are held with nanosecond resolution.
// NOTE(review): presumably waitTime is time spent queued and runTime is time
// spent executing - confirm against runTask's implementation.
79 std::chrono::nanoseconds waitTime;
80 std::chrono::nanoseconds runTime;
// Subscribes `observer` to taskStatsSubject_ and returns the Subscription
// produced by that subscribe() call.
83 Subscription<TaskStats> subscribeToTaskStats(
84 const ObserverPtr<TaskStats>& observer) {
85 return taskStatsSubject_->subscribe(observer);
89 * Base class for threads created with ThreadPoolExecutor.
90 * Some subclasses have methods that operate on these
// Virtual, defaulted destructor: derived handle types can be destroyed through
// a ThreadHandle pointer.
95 virtual ~ThreadHandle() = default;
99 * Observer interface for thread start/stop.
100 * Provides hooks so actions can be taken when
101 * threads are created
// Pure-virtual hooks; each receives a raw ThreadHandle pointer.
105 virtual void threadStarted(ThreadHandle*) = 0;
106 virtual void threadStopped(ThreadHandle*) = 0;
// Virtual hooks that have a default body in this class.
// NOTE(review): the bodies are missing from this copy; presumably they forward
// to threadStarted / threadStopped respectively - confirm upstream.
107 virtual void threadPreviouslyStarted(ThreadHandle* h) {
110 virtual void threadNotYetStopped(ThreadHandle* h) {
// Virtual, defaulted destructor for use through Observer pointers.
113 virtual ~Observer() = default;
// Register / unregister an Observer, passed as a shared_ptr. Declarations
// only; the definitions are not part of this file.
// NOTE(review): presumably these add to / erase from observers_ - confirm.
116 void addObserver(std::shared_ptr<Observer>);
117 void removeObserver(std::shared_ptr<Observer>);
// Declarations only; the definitions are not part of this file.
// NOTE(review): `n` is presumably the number of threads to add or remove, and
// `isJoin` presumably selects join-style rather than stop-style removal -
// confirm against the implementation.
120 // Prerequisite: threadListLock_ writelocked
121 void addThreads(size_t n);
122 // Prerequisite: threadListLock_ writelocked
123 void removeThreads(size_t n, bool isJoin);
// Per-thread state record; extends ThreadHandle.
// NOTE(review): FOLLY_ALIGN_TO_AVOID_FALSE_SHARING is defined elsewhere; by
// its name it aligns the struct to keep instances off shared cache lines -
// confirm.
125 struct FOLLY_ALIGN_TO_AVOID_FALSE_SHARING Thread : public ThreadHandle {
// Explicit constructor taking the owning pool.
// NOTE(review): part of the initializer list is missing from this copy; the
// visible part copies the pool's taskStatsSubject_ shared_ptr.
126 explicit Thread(ThreadPoolExecutor* pool)
130 taskStatsSubject(pool->taskStatsSubject_) {}
// Static atomic counter shared by all Thread objects.
// NOTE(review): presumably the source of each thread's `id` - confirm.
134 static std::atomic<uint64_t> nextId;
// Baton that threadRun is required to post once the thread is ready to consume
// work (per the comment on threadRun in this file).
138 Baton<> startupBaton;
// The pool's TaskStats subject, shared with this thread.
139 std::shared_ptr<Subject<TaskStats>> taskStatsSubject;
// Shared-ownership handle to a Thread; used throughout the pool's interfaces.
142 typedef std::shared_ptr<Thread> ThreadPtr;
// NOTE(review): the two lines below are the tail of a constructor declaration
// whose head (and the enclosing struct header) is missing from this copy. The
// visible parameters are an expiration in milliseconds and an rvalue callback.
147 std::chrono::milliseconds expiration,
148 Func&& expireCallback);
// steady_clock timestamp. NOTE(review): presumably taken when the task is
// enqueued - confirm.
151 std::chrono::steady_clock::time_point enqueueTime_;
// Expiration duration in milliseconds.
152 std::chrono::milliseconds expiration_;
// Callback associated with expiry. NOTE(review): presumably invoked instead of
// the task when it has expired - confirm against runTask.
153 Func expireCallback_;
// Static helper taking the thread's ThreadPtr and the Task by rvalue
// reference. Declaration only; the definition is not part of this file.
// NOTE(review): presumably runs the task and publishes a TaskStats - confirm.
156 static void runTask(const ThreadPtr& thread, Task&& task);
// Pure virtual: each concrete pool supplies its own thread main function.
158 // The function that will be bound to pool threads. It must call
159 // thread->startupBaton.post() when it's ready to consume work.
160 virtual void threadRun(ThreadPtr thread) = 0;
// Pure virtual: each concrete pool defines how its threads are stopped.
162 // Stop n threads and put their ThreadPtrs in the threadsStopped_ queue
163 // Prerequisite: threadListLock_ writelocked
164 virtual void stopThreads(size_t n) = 0;
// Virtual factory hook: the default builds a plain Thread bound to this pool
// via make_shared. Being virtual, subclasses may override it.
166 // Create a suitable Thread struct
167 virtual ThreadPtr makeThread() {
168 return std::make_shared<Thread>(this);
// Pure virtual: each concrete pool reports its own pending-task count.
171 // Prerequisite: threadListLock_ readlocked
172 virtual uint64_t getPendingTaskCount() = 0;
// Inserts `state` into vec_ at the position found by lower_bound under
// compare() (ordering by Thread id), so an already-sorted vec_ stays sorted.
176 void add(const ThreadPtr& state) {
177 auto it = std::lower_bound(vec_.begin(), vec_.end(), state,
178 // compare() is a static member function; passing it directly as the
179 // predicate would hand the STL algorithm a function pointer, which the
180 // compiler may be unable to inline.
181 // Wrapping the call in a lambda gives the algorithm an inlinable functor,
182 // so the compiler can inline compare() as well.
183 [&](const ThreadPtr& ts1, const ThreadPtr& ts2) -> bool { // inline
184 return compare(ts1, ts2);
// Insert before the first element that does not order before `state`.
186 vec_.insert(it, state);
// Removes `state` from vec_. Uses equal_range with the same id ordering as
// add() to locate the entries equivalent to `state`, then erases the match.
189 void remove(const ThreadPtr& state) {
190 auto itPair = std::equal_range(vec_.begin(), vec_.end(), state,
// Lambda wrapper around compare(), same as in add().
192 [&](const ThreadPtr& ts1, const ThreadPtr& ts2) -> bool { // inline
193 return compare(ts1, ts2);
// CHECK-fail if the range starts at end().
195 CHECK(itPair.first != vec_.end());
// CHECK-fail unless the range holds exactly one element; this also rejects an
// empty range (state absent) and duplicate ids.
196 CHECK(std::next(itPair.first) == itPair.second);
197 vec_.erase(itPair.first);
// Const accessor returning a const reference to a vector of ThreadPtr.
// NOTE(review): the body is missing from this copy; presumably it returns
// vec_ - confirm upstream.
200 const std::vector<ThreadPtr>& get() const {
// Strict ordering used by add() and remove(): true when ts1's id is less than
// ts2's id. Dereferences both pointers, so neither may be null.
205 static bool compare(const ThreadPtr& ts1, const ThreadPtr& ts2) {
206 return ts1->id < ts2->id;
// Backing storage; add() inserts at the lower_bound position under compare(),
// which keeps the elements ordered by Thread id.
209 std::vector<ThreadPtr> vec_;
// BlockingQueue of ThreadPtr that overrides add / take / size.
// NOTE(review): any synchronization members are not visible in this copy, and
// the overrides are defined outside this file.
212 class StoppedThreadQueue : public BlockingQueue<ThreadPtr> {
214 void add(ThreadPtr item) override;
215 ThreadPtr take() override;
216 size_t size() override;
// Underlying FIFO container for the queued ThreadPtrs.
221 std::queue<ThreadPtr> queue_;
// Factory read and replaced through getThreadFactory() / setThreadFactory().
224 std::shared_ptr<ThreadFactory> threadFactory_;
// Collection of the pool's Thread records.
225 ThreadList threadList_;
// Reader/writer spin lock; the "Prerequisite" comments in this file state
// which member functions expect it read- or write-locked by the caller.
226 RWSpinLock threadListLock_;
// Queue of stopped threads.
// NOTE(review): presumably drained by the code that joins them - confirm.
227 StoppedThreadQueue stoppedThreads_;
228 std::atomic<bool> isJoin_; // whether the current downsizing is a join
// Subject behind subscribeToTaskStats(); each Thread also copies this pointer.
230 std::shared_ptr<Subject<TaskStats>> taskStatsSubject_;
// Observers held by shared_ptr.
// NOTE(review): presumably maintained by addObserver / removeObserver - confirm.
231 std::vector<std::shared_ptr<Observer>> observers_;