2 * Copyright 2017 Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include <folly/executors/ThreadPoolExecutor.h>
19 #include <folly/concurrency/GlobalThreadPoolList.h>
// Constructor. The leading size_t argument is left unnamed and is not used
// in this definition. The initializer list stores the (moved-in) thread
// factory, records isWaitForAll, allocates a fresh TaskStatsCallbackRegistry,
// and names the thread-pool hook "Wangle::ThreadPoolExecutor". The
// constructor body itself is empty.
23 ThreadPoolExecutor::ThreadPoolExecutor(
24 size_t /* numThreads */,
25 std::shared_ptr<ThreadFactory> threadFactory,
27 : threadFactory_(std::move(threadFactory)),
28 isWaitForAll_(isWaitForAll),
29 taskStatsCallbacks_(std::make_shared<TaskStatsCallbackRegistry>()),
30 threadPoolHook_("Wangle::ThreadPoolExecutor") {}
// Destructor. CHECKs that the thread list is already empty at this point.
// NOTE(review): presumably the pool must have been stopped or joined before
// destruction for this check to hold -- confirm against callers/subclasses.
32 ThreadPoolExecutor::~ThreadPoolExecutor() {
33 CHECK_EQ(0, threadList_.get().size());
// Builds a Task from the function to run, an expiration duration and a
// callback to invoke on expiry. The current folly::RequestContext is saved
// into the task at construction time.
36 ThreadPoolExecutor::Task::Task(
38 std::chrono::milliseconds expiration,
39 Func&& expireCallback)
40 : func_(std::move(func)),
41 expiration_(expiration),
42 expireCallback_(std::move(expireCallback)),
43 context_(folly::RequestContext::saveContext()) {
44 // Assume that the task is enqueued on creation, so the construction time
// doubles as the enqueue timestamp that wait time is later measured from.
45 enqueueTime_ = std::chrono::steady_clock::now();
// Processes one task on behalf of `thread` and records its statistics.
//  - waitTime is the interval from the task's enqueue time to the start of
//    this function.
//  - A task with a positive expiration whose wait time has reached that
//    expiration is marked expired and its expire callback, if set, is called.
//  - std::exception-derived errors are caught and logged with their dynamic
//    type name and what() text; a separate message is logged for thrown
//    objects that are not exceptions.
//  - Finally the thread's lastActiveTime is refreshed and every registered
//    task-stats callback is invoked with the task's stats.
48 void ThreadPoolExecutor::runTask(const ThreadPtr& thread, Task&& task) {
50 auto startTime = std::chrono::steady_clock::now();
51 task.stats_.waitTime = startTime - task.enqueueTime_;
// An expiration of zero (or less) disables the expiry check.
52 if (task.expiration_ > std::chrono::milliseconds(0) &&
53 task.stats_.waitTime >= task.expiration_) {
54 task.stats_.expired = true;
55 if (task.expireCallback_ != nullptr) {
56 task.expireCallback_();
// Install the request context that was captured when the task was created.
59 folly::RequestContextScopeGuard rctx(task.context_);
62 } catch (const std::exception& e) {
63 LOG(ERROR) << "ThreadPoolExecutor: func threw unhandled "
64 << typeid(e).name() << " exception: " << e.what();
66 LOG(ERROR) << "ThreadPoolExecutor: func threw unhandled non-exception "
// runTime is measured from the same startTime used for waitTime.
69 task.stats_.runTime = std::chrono::steady_clock::now() - startTime;
72 thread->lastActiveTime = std::chrono::steady_clock::now();
// The callback list is accessed through withRLock while callbacks run.
73 thread->taskStatsCallbacks->callbackList.withRLock([&](auto& callbacks) {
// inCallback is raised while callbacks execute; subscribeToTaskStats()
// checks this flag and throws if a callback tries to subscribe.
74 *thread->taskStatsCallbacks->inCallback = true;
76 *thread->taskStatsCallbacks->inCallback = false;
79 for (auto& callback : callbacks) {
80 callback(task.stats_);
82 } catch (const std::exception& e) {
83 LOG(ERROR) << "ThreadPoolExecutor: task stats callback threw "
85 << typeid(e).name() << " exception: " << e.what();
87 LOG(ERROR) << "ThreadPoolExecutor: task stats callback threw "
88 "unhandled non-exception object";
// Returns the current number of entries in the thread list, read while
// holding threadListLock_ as a reader.
93 size_t ThreadPoolExecutor::numThreads() {
94 RWSpinLock::ReadHolder r{&threadListLock_};
95 return threadList_.get().size();
// Resizes the pool to `n` threads. With threadListLock_ write-held, the
// current thread count is compared against `n`: missing threads are added,
// surplus threads are removed (with isJoin = true). The removed threads are
// then joined, and the function CHECKs that the thread list has exactly `n`
// entries and that no stopped threads remain queued.
98 void ThreadPoolExecutor::setNumThreads(size_t n) {
// Stays 0 unless the pool shrinks, in which case it is the number of
// threads that must be joined afterwards.
99 size_t numThreadsToJoin = 0;
101 RWSpinLock::WriteHolder w{&threadListLock_};
102 const auto current = threadList_.get().size();
// NOTE(review): presumably guarded by `n > current` (the matching branch
// below tests `n < current`), so the subtraction cannot wrap -- confirm.
104 addThreads(n - current);
105 } else if (n < current) {
106 numThreadsToJoin = current - n;
107 removeThreads(numThreadsToJoin, true);
110 joinStoppedThreads(numThreadsToJoin);
111 CHECK_EQ(n, threadList_.get().size());
112 CHECK_EQ(0, stoppedThreads_.size());
// Adds `n` threads to the pool in three passes: create and launch each
// thread, wait for every new thread to signal startup, then notify all
// observers about each new thread.
// Precondition (per the original note): threadListLock_ is writelocked
// by the caller.
115 // threadListLock_ is writelocked
116 void ThreadPoolExecutor::addThreads(size_t n) {
117 std::vector<ThreadPtr> newThreads;
118 for (size_t i = 0; i < n; i++) {
119 newThreads.push_back(makeThread());
121 for (auto& thread : newThreads) {
122 // TODO need a notion of failing to create the thread
123 // and then handling for that case
// The factory-created thread runs threadRun(thread) on this executor.
124 thread->handle = threadFactory_->newThread(
125 std::bind(&ThreadPoolExecutor::threadRun, this, thread));
126 threadList_.add(thread);
// Block until each new thread has posted its startup baton.
128 for (auto& thread : newThreads) {
129 thread->startupBaton.wait();
// Observers are only told about threads after all of them have started.
131 for (auto& o : observers_) {
132 for (auto& thread : newThreads) {
133 o->threadStarted(thread.get());
// Removes `n` threads from the pool; CHECKs that `n` does not exceed the
// number of threads currently in the thread list.
// NOTE(review): the meaning of `isJoin` is not evident from the visible
// code; callers pass true from setNumThreads()/join() and false from stop().
// Precondition (per the original note): threadListLock_ is writelocked
// by the caller.
138 // threadListLock_ is writelocked
139 void ThreadPoolExecutor::removeThreads(size_t n, bool isJoin) {
140 CHECK_LE(n, threadList_.get().size());
// Takes `n` threads, one at a time, from the stopped-thread queue and joins
// each one's handle.
145 void ThreadPoolExecutor::joinStoppedThreads(size_t n) {
146 for (size_t i = 0; i < n; i++) {
147 auto thread = stoppedThreads_.take();
148 thread->handle.join();
// Shuts the pool down: with threadListLock_ write-held, removes every thread
// currently in the thread list (isJoin = false), then joins that many
// stopped threads and CHECKs that both the thread list and the
// stopped-thread queue are empty.
152 void ThreadPoolExecutor::stop() {
155 RWSpinLock::WriteHolder w{&threadListLock_};
156 n = threadList_.get().size();
157 removeThreads(n, false);
159 joinStoppedThreads(n);
160 CHECK_EQ(0, threadList_.get().size());
161 CHECK_EQ(0, stoppedThreads_.size());
// Same sequence as stop(), except that the threads are removed with
// isJoin = true: remove all threads under the write lock, join them, then
// CHECK that the thread list and the stopped-thread queue are empty.
164 void ThreadPoolExecutor::join() {
167 RWSpinLock::WriteHolder w{&threadListLock_};
168 n = threadList_.get().size();
169 removeThreads(n, true);
171 joinStoppedThreads(n);
172 CHECK_EQ(0, threadList_.get().size());
173 CHECK_EQ(0, stoppedThreads_.size());
// Builds a PoolStats snapshot while holding threadListLock_ as a reader:
// total thread count, idle vs. active thread counts, the longest idle time
// among idle threads, the pending task count, and a total task count.
176 ThreadPoolExecutor::PoolStats ThreadPoolExecutor::getPoolStats() {
// Taken once so that every thread's idle time uses the same reference point.
177 const auto now = std::chrono::steady_clock::now();
178 RWSpinLock::ReadHolder r{&threadListLock_};
179 ThreadPoolExecutor::PoolStats stats;
180 stats.threadCount = threadList_.get().size();
// NOTE(review): `auto thread` copies each list element per iteration;
// `const auto&` would likely suffice -- confirm.
181 for (auto thread : threadList_.get()) {
183 stats.idleThreadCount++;
184 const std::chrono::nanoseconds idleTime = now - thread->lastActiveTime;
185 stats.maxIdleTime = std::max(stats.maxIdleTime, idleTime);
187 stats.activeThreadCount++;
190 stats.pendingTaskCount = getPendingTaskCountImpl(r);
// Total = tasks still pending + one per active thread.
191 stats.totalTaskCount = stats.pendingTaskCount + stats.activeThreadCount;
// Returns the pending task count by acquiring threadListLock_ as a reader
// and delegating to getPendingTaskCountImpl(), which receives the holder.
195 uint64_t ThreadPoolExecutor::getPendingTaskCount() {
196 RWSpinLock::ReadHolder r{&threadListLock_};
197 return getPendingTaskCountImpl(r);
// Definition of the static atomic counter Thread::nextId, starting at 0.
// NOTE(review): presumably the source of per-thread ids -- confirm in header.
200 std::atomic<uint64_t> ThreadPoolExecutor::Thread::nextId(0);
// Registers `cb` to be invoked with each task's stats.
// Throws std::runtime_error if the inCallback flag is set, i.e. when called
// while task-stats callbacks are being run (see runTask()). Otherwise the
// callback is appended to the callback list under its write lock.
202 void ThreadPoolExecutor::subscribeToTaskStats(TaskStatsCallback cb) {
203 if (*taskStatsCallbacks_->inCallback) {
204 throw std::runtime_error("cannot subscribe in task stats callback");
206 taskStatsCallbacks_->callbackList.wlock()->push_back(std::move(cb));
// Pushes `item` onto the internal queue while holding mutex_.
209 void ThreadPoolExecutor::StoppedThreadQueue::add(
210 ThreadPoolExecutor::ThreadPtr item) {
211 std::lock_guard<std::mutex> guard(mutex_);
212 queue_.push(std::move(item));
// Removes and returns a thread from the queue. With mutex_ held, a non-empty
// queue has its front element moved out.
// NOTE(review): behaviour when the queue is empty is not evident from the
// visible code -- confirm whether this waits or returns immediately.
216 ThreadPoolExecutor::ThreadPtr ThreadPoolExecutor::StoppedThreadQueue::take() {
219 std::lock_guard<std::mutex> guard(mutex_);
220 if (queue_.size() > 0) {
221 auto item = std::move(queue_.front());
// Returns the number of queued items, read while holding mutex_.
230 size_t ThreadPoolExecutor::StoppedThreadQueue::size() {
231 std::lock_guard<std::mutex> guard(mutex_);
232 return queue_.size();
// Registers observer `o` and immediately reports every thread already in the
// thread list to it via threadPreviouslyStarted().
// NOTE(review): observers_ is modified here while threadListLock_ is held
// only as a reader, so two concurrent addObserver()/removeObserver() calls
// could mutate observers_ at the same time -- confirm whether a write lock
// (or external serialization by callers) is intended.
235 void ThreadPoolExecutor::addObserver(std::shared_ptr<Observer> o) {
236 RWSpinLock::ReadHolder r{&threadListLock_};
237 observers_.push_back(o);
238 for (auto& thread : threadList_.get()) {
239 o->threadPreviouslyStarted(thread.get());
243 void ThreadPoolExecutor::removeObserver(std::shared_ptr<Observer> o) {
244 RWSpinLock::ReadHolder r{&threadListLock_};
245 for (auto& thread : threadList_.get()) {
246 o->threadNotYetStopped(thread.get());
249 for (auto it = observers_.begin(); it != observers_.end(); it++) {
251 observers_.erase(it);