Allow adding tasks to TaskIterator dynamically
authorAndrii Grynenko <andrii@fb.com>
Thu, 5 May 2016 01:15:53 +0000 (18:15 -0700)
committerFacebook Github Bot 2 <facebook-github-bot-2-bot@fb.com>
Thu, 5 May 2016 05:55:00 +0000 (22:55 -0700)
Reviewed By: yfeldblum

Differential Revision: D3244669

fb-gh-sync-id: 73fa4ecb0432a802e67ef922255a896d96f32374
fbshipit-source-id: 73fa4ecb0432a802e67ef922255a896d96f32374

folly/experimental/fibers/AddTasks-inl.h
folly/experimental/fibers/AddTasks.h
folly/experimental/fibers/test/FibersTest.cpp

index 6a5c1da4c2e2a098afe33fe81fdc3d442a380340..f0a712e9628439e2c42613e92fc8b5d2d8341457 100644 (file)
 #include <memory>
 #include <vector>
 
-#include <folly/experimental/fibers/FiberManager.h>
-
 namespace folly {
 namespace fibers {
 
 template <typename T>
 TaskIterator<T>::TaskIterator(TaskIterator&& other) noexcept
-    : context_(std::move(other.context_)), id_(other.id_) {}
-
-template <typename T>
-TaskIterator<T>::TaskIterator(std::shared_ptr<Context> context)
-    : context_(std::move(context)), id_(-1) {
-  assert(context_);
-}
+    : context_(std::move(other.context_)), id_(other.id_), fm_(other.fm_) {}
 
 template <typename T>
 inline bool TaskIterator<T>::hasCompleted() const {
@@ -92,6 +84,30 @@ inline size_t TaskIterator<T>::getTaskID() const {
   return id_;
 }
 
+template <typename T>
+template <typename F>
+void TaskIterator<T>::addTask(F&& func) {
+  static_assert(
+      std::is_convertible<typename std::result_of<F()>::type, T>::value,
+      "TaskIterator<T>: T must be convertible from func()'s return type");
+
+  auto taskId = context_->totalTasks++;
+
+  fm_.addTask(
+      [ taskId, context = context_, func = std::forward<F>(func) ]() mutable {
+        context->results.emplace_back(
+            taskId, folly::makeTryWith(std::move(func)));
+
+        // Check for awaiting iterator.
+        if (context->promise.hasValue()) {
+          if (--context->tasksToFulfillPromise == 0) {
+            context->promise->setValue();
+            context->promise.clear();
+          }
+        }
+      });
+}
+
 template <class InputIterator>
 TaskIterator<typename std::result_of<
     typename std::iterator_traits<InputIterator>::value_type()>::type>
@@ -101,32 +117,15 @@ addTasks(InputIterator first, InputIterator last) {
       ResultType;
   typedef TaskIterator<ResultType> IteratorType;
 
-  auto context = std::make_shared<typename IteratorType::Context>();
-  context->totalTasks = std::distance(first, last);
-  context->results.reserve(context->totalTasks);
-
-  for (size_t i = 0; first != last; ++i, ++first) {
-#ifdef __clang__
-#pragma clang diagnostic push // ignore generalized lambda capture warning
-#pragma clang diagnostic ignored "-Wc++1y-extensions"
-#endif
-    addTask([ i, context, f = std::move(*first) ]() {
-      context->results.emplace_back(i, folly::makeTryWith(std::move(f)));
-
-      // Check for awaiting iterator.
-      if (context->promise.hasValue()) {
-        if (--context->tasksToFulfillPromise == 0) {
-          context->promise->setValue();
-          context->promise.clear();
-        }
-      }
-    });
-#ifdef __clang__
-#pragma clang diagnostic pop
-#endif
+  IteratorType iterator;
+
+  for (; first != last; ++first) {
+    iterator.addTask(std::move(*first));
   }
 
-  return IteratorType(std::move(context));
+  iterator.context_->results.reserve(iterator.context_->totalTasks);
+
+  return std::move(iterator);
 }
 }
 }
index 9e3019d3da345260650eff231dcca6adf4ad11be..be4c25b1ec4a48e8d2d8870ed9aa45aa9f659ac2 100644 (file)
@@ -19,6 +19,7 @@
 #include <vector>
 
 #include <folly/Optional.h>
+#include <folly/experimental/fibers/FiberManager.h>
 #include <folly/experimental/fibers/Promise.h>
 #include <folly/futures/Try.h>
 
@@ -49,6 +50,8 @@ class TaskIterator {
  public:
   typedef T value_type;
 
+  TaskIterator() : fm_(FiberManager::getFiberManager()) {}
+
   // not copyable
   TaskIterator(const TaskIterator& other) = delete;
   TaskIterator& operator=(const TaskIterator& other) = delete;
@@ -57,6 +60,14 @@ class TaskIterator {
   TaskIterator(TaskIterator&& other) noexcept;
   TaskIterator& operator=(TaskIterator&& other) = delete;
 
+  /**
+   * Add one more task to the TaskIterator.
+   *
+   * @param func task to be added, will be scheduled on current FiberManager
+   */
+  template <typename F>
+  void addTask(F&& func);
+
   /**
    * @return True if there are tasks immediately available to be consumed (no
    *         need to await on them).
@@ -111,10 +122,9 @@ class TaskIterator {
     size_t tasksToFulfillPromise{0};
   };
 
-  std::shared_ptr<Context> context_;
-  size_t id_;
-
-  explicit TaskIterator(std::shared_ptr<Context> context);
+  std::shared_ptr<Context> context_{std::make_shared<Context>()};
+  size_t id_{std::numeric_limits<size_t>::max()};
+  FiberManager& fm_;
 
   folly::Try<T> awaitNextResult();
 };
index 08d182ff2890df8775e06cca992e9affcaaf0fc0..c19e95a0a3e3f95853ee87507348206c406b3e0e 100644 (file)
@@ -463,7 +463,7 @@ TEST(FiberManager, addTasksVoidThrow) {
   loopController.loop(std::move(loopFunc));
 }
 
-TEST(FiberManager, reserve) {
+TEST(FiberManager, addTasksReserve) {
   std::vector<Promise<int>> pendingFibers;
   bool taskAdded = false;
 
@@ -517,6 +517,42 @@ TEST(FiberManager, reserve) {
   loopController.loop(std::move(loopFunc));
 }
 
+TEST(FiberManager, addTaskDynamic) {
+  folly::EventBase evb;
+
+  Baton batons[3];
+
+  auto makeTask = [&](size_t taskId) {
+    return [&, taskId]() -> size_t {
+      batons[taskId].wait();
+      return taskId;
+    };
+  };
+
+  getFiberManager(evb)
+      .addTaskFuture([&]() {
+        TaskIterator<size_t> iterator;
+
+        iterator.addTask(makeTask(0));
+        iterator.addTask(makeTask(1));
+
+        batons[1].post();
+
+        EXPECT_EQ(1, iterator.awaitNext());
+
+        iterator.addTask(makeTask(2));
+
+        batons[2].post();
+
+        EXPECT_EQ(2, iterator.awaitNext());
+
+        batons[0].post();
+
+        EXPECT_EQ(0, iterator.awaitNext());
+      })
+      .waitVia(&evb);
+}
+
 TEST(FiberManager, forEach) {
   std::vector<Promise<int>> pendingFibers;
   bool taskAdded = false;