2 * Copyright 2014-present Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
16 #include <folly/Optional.h>
18 #include <folly/fibers/FiberManagerInternal.h>
19 #include <folly/fibers/ForEach.h>
24 template <class InputIterator>
25 typename std::vector<typename std::enable_if<
27 typename std::result_of<
28 typename std::iterator_traits<InputIterator>::value_type()>::type,
32 typename std::result_of<typename std::iterator_traits<
33 InputIterator>::value_type()>::type>>::type>
34 collectN(InputIterator first, InputIterator last, size_t n) {
35 typedef typename std::result_of<
36 typename std::iterator_traits<InputIterator>::value_type()>::type Result;
38 assert(std::distance(first, last) >= 0);
39 assert(n <= static_cast<size_t>(std::distance(first, last)));
42 std::vector<std::pair<size_t, Result>> results;
45 folly::Optional<Promise<void>> promise;
47 Context(size_t tasksTodo_) : tasksTodo(tasksTodo_) {
48 this->results.reserve(tasksTodo_);
51 auto context = std::make_shared<Context>(n);
53 await([first, last, context](Promise<void> promise) mutable {
54 context->promise = std::move(promise);
55 for (size_t i = 0; first != last; ++i, ++first) {
56 addTask([ i, context, f = std::move(*first) ]() {
59 if (context->tasksTodo == 0) {
62 context->results.emplace_back(i, std::move(result));
64 if (context->tasksTodo == 0) {
67 context->e = std::current_exception();
69 if (--context->tasksTodo == 0) {
70 context->promise->setValue();
76 if (context->e != std::exception_ptr()) {
77 std::rethrow_exception(context->e);
80 return std::move(context->results);
83 template <class InputIterator>
84 typename std::enable_if<
86 typename std::result_of<
87 typename std::iterator_traits<InputIterator>::value_type()>::type,
89 std::vector<size_t>>::type
90 collectN(InputIterator first, InputIterator last, size_t n) {
92 assert(std::distance(first, last) >= 0);
93 assert(n <= static_cast<size_t>(std::distance(first, last)));
96 std::vector<size_t> taskIndices;
99 folly::Optional<Promise<void>> promise;
101 Context(size_t tasksTodo_) : tasksTodo(tasksTodo_) {
102 this->taskIndices.reserve(tasksTodo_);
105 auto context = std::make_shared<Context>(n);
107 await([first, last, context](Promise<void> promise) mutable {
108 context->promise = std::move(promise);
109 for (size_t i = 0; first != last; ++i, ++first) {
110 addTask([ i, context, f = std::move(*first) ]() {
113 if (context->tasksTodo == 0) {
116 context->taskIndices.push_back(i);
118 if (context->tasksTodo == 0) {
121 context->e = std::current_exception();
123 if (--context->tasksTodo == 0) {
124 context->promise->setValue();
130 if (context->e != std::exception_ptr()) {
131 std::rethrow_exception(context->e);
134 return context->taskIndices;
/**
 * Runs every callable in [first, last) through forEach and returns their
 * results in input order.
 *
 * This overload is selected when the callables return a non-void value.
 *
 * @param first  start of the range of callables.
 * @param last   end of the range of callables.
 * @return       one result per callable; element i belongs to the i-th
 *               callable of the range.
 */
template <class InputIterator>
typename std::vector<
    typename std::enable_if<
        !std::is_same<
            typename std::result_of<typename std::iterator_traits<
                InputIterator>::value_type()>::type,
            void>::value,
        typename std::result_of<
            typename std::iterator_traits<InputIterator>::value_type()>::type>::
        type> inline collectAll(InputIterator first, InputIterator last) {
  using Result = typename std::result_of<
      typename std::iterator_traits<InputIterator>::value_type()>::type;

  const auto n = static_cast<size_t>(std::distance(first, last));

  // Results are stored in the order the callback receives them; order[id]
  // remembers where the result for input position id ended up.
  // NOTE(review): presumably forEach reports results in completion order with
  // id being the callable's position in the range -- confirm against forEach.
  std::vector<Result> results;
  results.reserve(n); // n is known up front; avoids regrowth while collecting
  std::vector<size_t> order(n);

  forEach(first, last, [&results, &order](size_t id, Result result) {
    order[id] = results.size();
    results.emplace_back(std::move(result));
  });
  assert(results.size() == n);

  // Rearrange into input order.
  std::vector<Result> orderedResults;
  orderedResults.reserve(n);
  for (size_t i = 0; i < n; ++i) {
    orderedResults.emplace_back(std::move(results[order[i]]));
  }

  return orderedResults;
}
/**
 * Runs every callable in [first, last) through forEach.
 *
 * This overload is selected when the callables return void, so there is
 * nothing to collect and nothing is returned.
 *
 * @param first  start of the range of callables.
 * @param last   end of the range of callables.
 */
template <class InputIterator>
typename std::enable_if<
    std::is_same<
        typename std::result_of<
            typename std::iterator_traits<InputIterator>::value_type()>::type,
        void>::value,
    void>::type inline collectAll(InputIterator first, InputIterator last) {
  // The per-task callback only receives an id here; it is deliberately a
  // no-op.
  forEach(first, last, [](size_t /* taskId */) {});
}
/**
 * Starts every callable in [first, last) and returns the first completion.
 *
 * This overload is selected when the callables return a non-void value. It
 * delegates to collectN with n == 1.
 *
 * @param first  start of the range of callables.
 * @param last   end of the range of callables.
 * @return       the single (index, value) pair produced by collectN.
 */
template <class InputIterator>
typename std::enable_if<
    !std::is_same<
        typename std::result_of<
            typename std::iterator_traits<InputIterator>::value_type()>::type,
        void>::value,
    typename std::pair<
        size_t,
        typename std::result_of<typename std::iterator_traits<
            InputIterator>::value_type()>::type>>::
    type inline collectAny(InputIterator first, InputIterator last) {
  auto firstCompleted = collectN(first, last, 1);
  assert(firstCompleted.size() == 1);
  // Move the element out of the vector; the vector is discarded on return.
  return std::move(firstCompleted.front());
}
/**
 * Starts every callable in [first, last) and returns the position of the
 * first one to complete.
 *
 * This overload is selected when the callables return void. It delegates to
 * collectN with n == 1.
 *
 * @param first  start of the range of callables.
 * @param last   end of the range of callables.
 * @return       the single index produced by collectN.
 */
template <class InputIterator>
typename std::enable_if<
    std::is_same<
        typename std::result_of<
            typename std::iterator_traits<InputIterator>::value_type()>::type,
        void>::value,
    size_t>::type inline collectAny(InputIterator first, InputIterator last) {
  auto result = collectN(first, last, 1);
  assert(result.size() == 1);
  // size_t is trivially copyable, so std::move would be a no-op cast here.
  return result[0];
}
207 } // namespace fibers