allow passing function pointers to Future::onError()
[folly.git] / folly / futures / Future-inl.h
index c4f9ca2f307ab6b49e6c1286f1947346d72867f4..2a8b56b888084e1b89175164d1430fab6ebfab75 100644 (file)
@@ -48,7 +48,81 @@ typedef folly::Baton<> FutureBatonType;
 }
 
 namespace detail {
-  std::shared_ptr<Timekeeper> getTimekeeperSingleton();
+std::shared_ptr<Timekeeper> getTimekeeperSingleton();
+
+//  Guarantees that the stored functor is destructed before the stored promise
+//  may be fulfilled. Assumes the stored functor to be noexcept-destructible.
+template <typename T, typename F>
+class CoreCallbackState {
+ public:
+  template <typename FF>
+  CoreCallbackState(Promise<T>&& promise, FF&& func) noexcept(
+      noexcept(F(std::declval<FF>())))
+      : func_(std::forward<FF>(func)), promise_(std::move(promise)) {
+    assert(before_barrier());
+  }
+
+  CoreCallbackState(CoreCallbackState&& that) noexcept(
+      noexcept(F(std::declval<F>()))) {
+    if (that.before_barrier()) {
+      new (&func_) F(std::move(that.func_));
+      promise_ = that.stealPromise();
+    }
+  }
+
+  CoreCallbackState& operator=(CoreCallbackState&&) = delete;
+
+  ~CoreCallbackState() {
+    if (before_barrier()) {
+      stealPromise();
+    }
+  }
+
+  template <typename... Args>
+  auto invoke(Args&&... args) noexcept(
+      noexcept(std::declval<F&&>()(std::declval<Args&&>()...))) {
+    assert(before_barrier());
+    return std::move(func_)(std::forward<Args>(args)...);
+  }
+
+  template <typename... Args>
+  auto tryInvoke(Args&&... args) noexcept {
+    return makeTryWith([&] { return invoke(std::forward<Args>(args)...); });
+  }
+
+  void setTry(Try<T>&& t) {
+    stealPromise().setTry(std::move(t));
+  }
+
+  void setException(exception_wrapper&& ew) {
+    stealPromise().setException(std::move(ew));
+  }
+
+  Promise<T> stealPromise() noexcept {
+    assert(before_barrier());
+    func_.~F();
+    return std::move(promise_);
+  }
+
+ private:
+  bool before_barrier() const noexcept {
+    return !promise_.isFulfilled();
+  }
+
+  union {
+    F func_;
+  };
+  Promise<T> promise_{detail::EmptyConstruct{}};
+};
+
+template <typename T, typename F>
+inline auto makeCoreCallbackState(Promise<T>&& p, F&& f) noexcept(
+    noexcept(CoreCallbackState<T, _t<std::decay<F>>>(
+        std::declval<Promise<T>&&>(),
+        std::declval<F&&>()))) {
+  return CoreCallbackState<T, _t<std::decay<F>>>(
+      std::move(p), std::forward<F>(f));
+}
 }
 
 template <class T>
@@ -160,13 +234,13 @@ Future<T>::thenImplementation(F&& func, detail::argResult<isTry, F, Args...>) {
      in the destruction of the Future used to create it.
      */
   setCallback_(
-      [ func = std::forward<F>(func), pm = std::move(p) ](Try<T> && t) mutable {
+      [state = detail::makeCoreCallbackState(
+           std::move(p), std::forward<F>(func))](Try<T> && t) mutable {
         if (!isTry && t.hasException()) {
-          pm.setException(std::move(t.exception()));
+          state.setException(std::move(t.exception()));
         } else {
-          pm.setWith([&]() {
-            return std::move(func)(t.template get<isTry, Args>()...);
-          });
+          state.setTry(makeTryWith(
+              [&] { return state.invoke(t.template get<isTry, Args>()...); }));
         }
       });
 
@@ -191,30 +265,22 @@ Future<T>::thenImplementation(F&& func, detail::argResult<isTry, F, Args...>) {
   auto f = p.getFuture();
   f.core_->setExecutorNoLock(getExecutor());
 
-  setCallback_([ func = std::forward<F>(func), pm = std::move(p) ](
-      Try<T> && t) mutable {
-    auto ew = [&] {
-      if (!isTry && t.hasException()) {
-        return std::move(t.exception());
-      } else {
-        try {
-          auto f2 = std::move(func)(t.template get<isTry, Args>()...);
-          // that didn't throw, now we can steal p
-          f2.setCallback_([p = std::move(pm)](Try<B> && b) mutable {
-            p.setTry(std::move(b));
-          });
-          return exception_wrapper();
-        } catch (const std::exception& e) {
-          return exception_wrapper(std::current_exception(), e);
-        } catch (...) {
-          return exception_wrapper(std::current_exception());
+  setCallback_(
+      [state = detail::makeCoreCallbackState(
+           std::move(p), std::forward<F>(func))](Try<T> && t) mutable {
+        if (!isTry && t.hasException()) {
+          state.setException(std::move(t.exception()));
+        } else {
+          auto tf2 = state.tryInvoke(t.template get<isTry, Args>()...);
+          if (tf2.hasException()) {
+            state.setException(std::move(tf2.exception()));
+          } else {
+            tf2->setCallback_([p = state.stealPromise()](Try<B> && b) mutable {
+              p.setTry(std::move(b));
+            });
+          }
         }
-      }
-    }();
-    if (ew) {
-      pm.setException(std::move(ew));
-    }
-  });
+      });
 
   return f;
 }
@@ -256,7 +322,7 @@ typename std::enable_if<
   !detail::Extract<F>::ReturnsFuture::value,
   Future<T>>::type
 Future<T>::onError(F&& func) {
-  typedef typename detail::Extract<F>::FirstArg Exn;
+  typedef std::remove_reference_t<typename detail::Extract<F>::FirstArg> Exn;
   static_assert(
       std::is_same<typename detail::Extract<F>::RawReturn, T>::value,
       "Return type of onError callback must be T or Future<T>");
@@ -266,11 +332,12 @@ Future<T>::onError(F&& func) {
   auto f = p.getFuture();
 
   setCallback_(
-      [ func = std::forward<F>(func), pm = std::move(p) ](Try<T> && t) mutable {
-        if (!t.template withException<Exn>([&](Exn& e) {
-              pm.setWith([&] { return std::move(func)(e); });
-            })) {
-          pm.setTry(std::move(t));
+      [state = detail::makeCoreCallbackState(
+           std::move(p), std::forward<F>(func))](Try<T> && t) mutable {
+        if (auto e = t.template tryGetExceptionObject<Exn>()) {
+          state.setTry(makeTryWith([&] { return state.invoke(*e); }));
+        } else {
+          state.setTry(std::move(t));
         }
       });
 
@@ -288,34 +355,27 @@ Future<T>::onError(F&& func) {
   static_assert(
       std::is_same<typename detail::Extract<F>::Return, Future<T>>::value,
       "Return type of onError callback must be T or Future<T>");
-  typedef typename detail::Extract<F>::FirstArg Exn;
+  typedef std::remove_reference_t<typename detail::Extract<F>::FirstArg> Exn;
 
   Promise<T> p;
   auto f = p.getFuture();
 
-  setCallback_([ pm = std::move(p), func = std::forward<F>(func) ](
-      Try<T> && t) mutable {
-    if (!t.template withException<Exn>([&](Exn& e) {
-          auto ew = [&] {
-            try {
-              auto f2 = std::move(func)(e);
-              f2.setCallback_([pm = std::move(pm)](Try<T> && t2) mutable {
-                pm.setTry(std::move(t2));
-              });
-              return exception_wrapper();
-            } catch (const std::exception& e2) {
-              return exception_wrapper(std::current_exception(), e2);
-            } catch (...) {
-              return exception_wrapper(std::current_exception());
-            }
-          }();
-          if (ew) {
-            pm.setException(std::move(ew));
+  setCallback_(
+      [state = detail::makeCoreCallbackState(
+           std::move(p), std::forward<F>(func))](Try<T> && t) mutable {
+        if (auto e = t.template tryGetExceptionObject<Exn>()) {
+          auto tf2 = state.tryInvoke(*e);
+          if (tf2.hasException()) {
+            state.setException(std::move(tf2.exception()));
+          } else {
+            tf2->setCallback_([p = state.stealPromise()](Try<T> && t3) mutable {
+              p.setTry(std::move(t3));
+            });
           }
-        })) {
-      pm.setTry(std::move(t));
-    }
-  });
+        } else {
+          state.setTry(std::move(t));
+        }
+      });
 
   return f;
 }
@@ -324,7 +384,7 @@ template <class T>
 template <class F>
 Future<T> Future<T>::ensure(F&& func) {
   return this->then([funcw = std::forward<F>(func)](Try<T> && t) mutable {
-    funcw();
+    std::move(funcw)();
     return makeFuture(std::move(t));
   });
 }
@@ -333,7 +393,7 @@ template <class T>
 template <class F>
 Future<T> Future<T>::onTimeout(Duration dur, F&& func, Timekeeper* tk) {
   return within(dur, tk).onError([funcw = std::forward<F>(func)](
-      TimedOut const&) { return funcw(); });
+      TimedOut const&) { return std::move(funcw)(); });
 }
 
 template <class T>
@@ -349,26 +409,19 @@ Future<T>::onError(F&& func) {
   Promise<T> p;
   auto f = p.getFuture();
   setCallback_(
-      [ pm = std::move(p), func = std::forward<F>(func) ](Try<T> t) mutable {
+      [state = detail::makeCoreCallbackState(
+           std::move(p), std::forward<F>(func))](Try<T> t) mutable {
         if (t.hasException()) {
-          auto ew = [&] {
-            try {
-              auto f2 = std::move(func)(std::move(t.exception()));
-              f2.setCallback_([pm = std::move(pm)](Try<T> t2) mutable {
-                pm.setTry(std::move(t2));
-              });
-              return exception_wrapper();
-            } catch (const std::exception& e2) {
-              return exception_wrapper(std::current_exception(), e2);
-            } catch (...) {
-              return exception_wrapper(std::current_exception());
-            }
-          }();
-          if (ew) {
-            pm.setException(std::move(ew));
+          auto tf2 = state.tryInvoke(std::move(t.exception()));
+          if (tf2.hasException()) {
+            state.setException(std::move(tf2.exception()));
+          } else {
+            tf2->setCallback_([p = state.stealPromise()](Try<T> && t3) mutable {
+              p.setTry(std::move(t3));
+            });
           }
         } else {
-          pm.setTry(std::move(t));
+          state.setTry(std::move(t));
         }
       });
 
@@ -390,11 +443,13 @@ Future<T>::onError(F&& func) {
   Promise<T> p;
   auto f = p.getFuture();
   setCallback_(
-      [ pm = std::move(p), func = std::forward<F>(func) ](Try<T> t) mutable {
+      [state = detail::makeCoreCallbackState(
+           std::move(p), std::forward<F>(func))](Try<T> && t) mutable {
         if (t.hasException()) {
-          pm.setWith([&] { return std::move(func)(std::move(t.exception())); });
+          state.setTry(makeTryWith(
+              [&] { return state.invoke(std::move(t.exception())); }));
         } else {
-          pm.setTry(std::move(t));
+          state.setTry(std::move(t));
         }
       });
 
@@ -457,8 +512,7 @@ inline Future<T> Future<T>::via(Executor* executor, int8_t priority) & {
 
 template <class Func>
 auto via(Executor* x, Func&& func)
-  -> Future<typename isFuture<decltype(func())>::Inner>
-{
+    -> Future<typename isFuture<decltype(std::declval<Func>()())>::Inner> {
   // TODO make this actually more performant. :-P #7260175
   return via(x).then(std::forward<Func>(func));
 }
@@ -504,7 +558,7 @@ makeFutureWith(F&& func) {
   using InnerType =
       typename isFuture<typename std::result_of<F()>::type>::Inner;
   try {
-    return func();
+    return std::forward<F>(func)();
   } catch (std::exception& e) {
     return makeFuture<InnerType>(
         exception_wrapper(std::current_exception(), e));
@@ -522,9 +576,8 @@ typename std::enable_if<
 makeFutureWith(F&& func) {
   using LiftedResult =
       typename Unit::Lift<typename std::result_of<F()>::type>::type;
-  return makeFuture<LiftedResult>(makeTryWith([&func]() mutable {
-    return func();
-  }));
+  return makeFuture<LiftedResult>(
+      makeTryWith([&func]() mutable { return std::forward<F>(func)(); }));
 }
 
 template <class T>
@@ -574,7 +627,7 @@ collectAll(Fs&&... fs) {
   auto ctx = std::make_shared<detail::CollectAllVariadicContext<
     typename std::decay<Fs>::type::value_type...>>();
   detail::collectVariadicHelper<detail::CollectAllVariadicContext>(
-    ctx, std::forward<typename std::decay<Fs>::type>(fs)...);
+      ctx, std::forward<Fs>(fs)...);
   return ctx->p.getFuture();
 }
 
@@ -677,7 +730,7 @@ collect(Fs&&... fs) {
   auto ctx = std::make_shared<detail::CollectVariadicContext<
     typename std::decay<Fs>::type::value_type...>>();
   detail::collectVariadicHelper<detail::CollectVariadicContext>(
-    ctx, std::forward<typename std::decay<Fs>::type>(fs)...);
+      ctx, std::forward<Fs>(fs)...);
   return ctx->p.getFuture();
 }