38154a862fc9db56ffd87733550dc44f5e2a741c
[folly.git] / folly / executors / ThreadPoolExecutor.h
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #pragma once
18 #include <folly/Executor.h>
19 #include <folly/Memory.h>
20 #include <folly/RWSpinLock.h>
21 #include <folly/executors/GlobalThreadPoolList.h>
22 #include <folly/executors/task_queue/LifoSemMPMCQueue.h>
23 #include <folly/executors/thread_factory/NamedThreadFactory.h>
24 #include <folly/io/async/Request.h>
25 #include <folly/synchronization/Baton.h>
26
27 #include <algorithm>
28 #include <mutex>
29 #include <queue>
30
31 #include <glog/logging.h>
32
33 namespace folly {
34
35 class ThreadPoolExecutor : public virtual folly::Executor {
36  public:
37   explicit ThreadPoolExecutor(
38       size_t numThreads,
39       std::shared_ptr<ThreadFactory> threadFactory,
40       bool isWaitForAll = false);
41
42   ~ThreadPoolExecutor() override;
43
44   void add(Func func) override = 0;
45   virtual void
46   add(Func func, std::chrono::milliseconds expiration, Func expireCallback) = 0;
47
48   void setThreadFactory(std::shared_ptr<ThreadFactory> threadFactory) {
49     CHECK(numThreads() == 0);
50     threadFactory_ = std::move(threadFactory);
51   }
52
53   std::shared_ptr<ThreadFactory> getThreadFactory() {
54     return threadFactory_;
55   }
56
57   size_t numThreads();
58   void setNumThreads(size_t numThreads);
59   /*
60    * stop() is best effort - there is no guarantee that unexecuted tasks won't
61    * be executed before it returns. Specifically, IOThreadPoolExecutor's stop()
62    * behaves like join().
63    */
64   void stop();
65   void join();
66
67   struct PoolStats {
68     PoolStats()
69         : threadCount(0),
70           idleThreadCount(0),
71           activeThreadCount(0),
72           pendingTaskCount(0),
73           totalTaskCount(0),
74           maxIdleTime(0) {}
75     size_t threadCount, idleThreadCount, activeThreadCount;
76     uint64_t pendingTaskCount, totalTaskCount;
77     std::chrono::nanoseconds maxIdleTime;
78   };
79
80   PoolStats getPoolStats();
81   uint64_t getPendingTaskCount();
82
83   struct TaskStats {
84     TaskStats() : expired(false), waitTime(0), runTime(0) {}
85     bool expired;
86     std::chrono::nanoseconds waitTime;
87     std::chrono::nanoseconds runTime;
88   };
89
90   using TaskStatsCallback = std::function<void(TaskStats)>;
91   void subscribeToTaskStats(TaskStatsCallback cb);
92
93   /**
94    * Base class for threads created with ThreadPoolExecutor.
95    * Some subclasses have methods that operate on these
96    * handles.
97    */
98   class ThreadHandle {
99    public:
100     virtual ~ThreadHandle() = default;
101   };
102
103   /**
104    * Observer interface for thread start/stop.
105    * Provides hooks so actions can be taken when
106    * threads are created
107    */
108   class Observer {
109    public:
110     virtual void threadStarted(ThreadHandle*) = 0;
111     virtual void threadStopped(ThreadHandle*) = 0;
112     virtual void threadPreviouslyStarted(ThreadHandle* h) {
113       threadStarted(h);
114     }
115     virtual void threadNotYetStopped(ThreadHandle* h) {
116       threadStopped(h);
117     }
118     virtual ~Observer() = default;
119   };
120
121   void addObserver(std::shared_ptr<Observer>);
122   void removeObserver(std::shared_ptr<Observer>);
123
124  protected:
125   // Prerequisite: threadListLock_ writelocked
126   void addThreads(size_t n);
127   // Prerequisite: threadListLock_ writelocked
128   void removeThreads(size_t n, bool isJoin);
129
130   struct TaskStatsCallbackRegistry;
131
132   struct alignas(hardware_destructive_interference_size) Thread
133       : public ThreadHandle {
134     explicit Thread(ThreadPoolExecutor* pool)
135         : id(nextId++),
136           handle(),
137           idle(true),
138           lastActiveTime(std::chrono::steady_clock::now()),
139           taskStatsCallbacks(pool->taskStatsCallbacks_) {}
140
141     ~Thread() override = default;
142
143     static std::atomic<uint64_t> nextId;
144     uint64_t id;
145     std::thread handle;
146     bool idle;
147     std::chrono::steady_clock::time_point lastActiveTime;
148     folly::Baton<> startupBaton;
149     std::shared_ptr<TaskStatsCallbackRegistry> taskStatsCallbacks;
150   };
151
152   typedef std::shared_ptr<Thread> ThreadPtr;
153
154   struct Task {
155     explicit Task(
156         Func&& func,
157         std::chrono::milliseconds expiration,
158         Func&& expireCallback);
159     Func func_;
160     TaskStats stats_;
161     std::chrono::steady_clock::time_point enqueueTime_;
162     std::chrono::milliseconds expiration_;
163     Func expireCallback_;
164     std::shared_ptr<folly::RequestContext> context_;
165   };
166
167   static void runTask(const ThreadPtr& thread, Task&& task);
168
169   // The function that will be bound to pool threads. It must call
170   // thread->startupBaton.post() when it's ready to consume work.
171   virtual void threadRun(ThreadPtr thread) = 0;
172
173   // Stop n threads and put their ThreadPtrs in the stoppedThreads_ queue
174   // and remove them from threadList_, either synchronize or asynchronize
175   // Prerequisite: threadListLock_ writelocked
176   virtual void stopThreads(size_t n) = 0;
177
178   // Join n stopped threads and remove them from waitingForJoinThreads_ queue.
179   // Should not hold a lock because joining thread operation may invoke some
180   // cleanup operations on the thread, and those cleanup operations may
181   // require a lock on ThreadPoolExecutor.
182   void joinStoppedThreads(size_t n);
183
184   // Create a suitable Thread struct
185   virtual ThreadPtr makeThread() {
186     return std::make_shared<Thread>(this);
187   }
188
189   // Prerequisite: threadListLock_ readlocked
190   virtual uint64_t getPendingTaskCountImpl(const RWSpinLock::ReadHolder&) = 0;
191
192   class ThreadList {
193    public:
194     void add(const ThreadPtr& state) {
195       auto it = std::lower_bound(
196           vec_.begin(),
197           vec_.end(),
198           state,
199           // compare method is a static method of class
200           // and therefore cannot be inlined by compiler
201           // as a template predicate of the STL algorithm
202           // but wrapped up with the lambda function (lambda will be inlined)
203           // compiler can inline compare method as well
204           [&](const ThreadPtr& ts1, const ThreadPtr& ts2) -> bool { // inline
205             return compare(ts1, ts2);
206           });
207       vec_.insert(it, state);
208     }
209
210     void remove(const ThreadPtr& state) {
211       auto itPair = std::equal_range(
212           vec_.begin(),
213           vec_.end(),
214           state,
215           // the same as above
216           [&](const ThreadPtr& ts1, const ThreadPtr& ts2) -> bool { // inline
217             return compare(ts1, ts2);
218           });
219       CHECK(itPair.first != vec_.end());
220       CHECK(std::next(itPair.first) == itPair.second);
221       vec_.erase(itPair.first);
222     }
223
224     const std::vector<ThreadPtr>& get() const {
225       return vec_;
226     }
227
228    private:
229     static bool compare(const ThreadPtr& ts1, const ThreadPtr& ts2) {
230       return ts1->id < ts2->id;
231     }
232
233     std::vector<ThreadPtr> vec_;
234   };
235
236   class StoppedThreadQueue : public BlockingQueue<ThreadPtr> {
237    public:
238     void add(ThreadPtr item) override;
239     ThreadPtr take() override;
240     size_t size() override;
241
242    private:
243     folly::LifoSem sem_;
244     std::mutex mutex_;
245     std::queue<ThreadPtr> queue_;
246   };
247
248   std::shared_ptr<ThreadFactory> threadFactory_;
249   const bool isWaitForAll_; // whether to wait till event base loop exits
250
251   ThreadList threadList_;
252   folly::RWSpinLock threadListLock_;
253   StoppedThreadQueue stoppedThreads_;
254   std::atomic<bool> isJoin_; // whether the current downsizing is a join
255
256   struct TaskStatsCallbackRegistry {
257     folly::ThreadLocal<bool> inCallback;
258     folly::Synchronized<std::vector<TaskStatsCallback>> callbackList;
259   };
260   std::shared_ptr<TaskStatsCallbackRegistry> taskStatsCallbacks_;
261   std::vector<std::shared_ptr<Observer>> observers_;
262   folly::ThreadPoolListHook threadPoolHook_;
263 };
264
265 } // namespace folly