rearrange folly::Function so that its template arguments are deducable.
authorEric Niebler <eniebler@fb.com>
Wed, 4 May 2016 21:45:20 +0000 (14:45 -0700)
committerFacebook Github Bot 0 <facebook-github-bot-0-bot@fb.com>
Wed, 4 May 2016 21:50:25 +0000 (14:50 -0700)
Summary:
`folly::Function` was an alias to a more complex type with template arguments that could not be deduced. For example, the call to `foo` below was failing to compile.

```
template <class R, class... As>
void foo(folly::Function<R(As...)> f) {
}

int main() {
  foo( folly::Function<void(int)>{ [](int i){} } );
}
```

Rearrange the code so that folly::Function is no longer an alias, thus making its template arguments deducable.

Reviewed By: luciang, spacedentist

Differential Revision: D3256130

fb-gh-sync-id: fb403e48d161635b3b7f36e53b1679eb46cbfe7f
fbshipit-source-id: fb403e48d161635b3b7f36e53b1679eb46cbfe7f

folly/Function.h
folly/test/FunctionTest.cpp

index f5926ac343edc9a1967bb02cdd83bab50693b6fb..a1b230e0d446f9cd0089a2fedd8de835ec9f72ee 100644 (file)
 
 namespace folly {
 
-namespace impl {
-template <typename FunctionType, bool Const = false>
+template <typename FunctionType>
 class Function;
 
 template <typename ReturnType, typename... Args>
-Function<ReturnType(Args...), true> constCastFunction(
-    Function<ReturnType(Args...), false>&&) noexcept;
-} // impl
+Function<ReturnType(Args...) const> constCastFunction(
+    Function<ReturnType(Args...)>&&) noexcept;
 
 namespace detail {
 namespace function {
@@ -246,9 +244,6 @@ union Data {
   typename std::aligned_storage<6 * sizeof(void*)>::type small;
 };
 
-template <bool If, typename T>
-using ConstIf = typename std::conditional<If, const T, T>::type;
-
 template <typename Fun, typename FunT = typename std::decay<Fun>::type>
 using IsSmall = std::integral_constant<
     bool,
@@ -259,6 +254,8 @@ using IsSmall = std::integral_constant<
 using SmallTag = std::true_type;
 using HeapTag = std::false_type;
 
+struct CoerceTag {};
+
 template <typename T>
 bool isNullPtrFn(T* p) {
   return p == nullptr;
@@ -268,21 +265,135 @@ std::false_type isNullPtrFn(T&&) {
   return {};
 }
 
-template <typename ReturnType, typename... Args>
-ReturnType uninitCall(Data&, Args&&...) {
-  throw std::bad_function_call();
-}
 inline bool uninitNoop(Op, Data*, Data*) {
   return false;
 }
 
+template <typename FunctionType>
+struct FunctionTraits;
+
+template <typename ReturnType, typename... Args>
+struct FunctionTraits<ReturnType(Args...)> {
+  using Call = ReturnType (*)(Data&, Args&&...);
+  using IsConst = std::false_type;
+  using ConstSignature = ReturnType(Args...) const;
+  using NonConstSignature = ReturnType(Args...);
+  using OtherSignature = ConstSignature;
+
+  template <typename F, typename G = typename std::decay<F>::type>
+  using ResultOf = decltype(
+      static_cast<ReturnType>(std::declval<G&>()(std::declval<Args>()...)));
+
+  template <typename Fun>
+  static ReturnType callSmall(Data& p, Args&&... args) {
+    return static_cast<ReturnType>((*static_cast<Fun*>(
+        static_cast<void*>(&p.small)))(static_cast<Args&&>(args)...));
+  }
+
+  template <typename Fun>
+  static ReturnType callBig(Data& p, Args&&... args) {
+    return static_cast<ReturnType>(
+        (*static_cast<Fun*>(p.big))(static_cast<Args&&>(args)...));
+  }
+
+  static ReturnType uninitCall(Data&, Args&&...) {
+    throw std::bad_function_call();
+  }
+
+  ReturnType operator()(Args... args) {
+    auto& fn = *static_cast<Function<ReturnType(Args...)>*>(this);
+    return fn.call_(fn.data_, static_cast<Args&&>(args)...);
+  }
+
+  struct SharedFunctionImpl {
+    std::shared_ptr<Function<ReturnType(Args...)>> sp_;
+    ReturnType operator()(Args&&... args) const {
+      return (*sp_)(static_cast<Args&&>(args)...);
+    }
+  };
+};
+
+template <typename ReturnType, typename... Args>
+struct FunctionTraits<ReturnType(Args...) const> {
+  using Call = ReturnType (*)(Data&, Args&&...);
+  using IsConst = std::true_type;
+  using ConstSignature = ReturnType(Args...) const;
+  using NonConstSignature = ReturnType(Args...);
+  using OtherSignature = NonConstSignature;
+
+  template <typename F, typename G = typename std::decay<F>::type>
+  using ResultOf = decltype(static_cast<ReturnType>(
+      std::declval<const G&>()(std::declval<Args>()...)));
+
+  template <typename Fun>
+  static ReturnType callSmall(Data& p, Args&&... args) {
+    return static_cast<ReturnType>((*static_cast<const Fun*>(
+        static_cast<void*>(&p.small)))(static_cast<Args&&>(args)...));
+  }
+
+  template <typename Fun>
+  static ReturnType callBig(Data& p, Args&&... args) {
+    return static_cast<ReturnType>(
+        (*static_cast<const Fun*>(p.big))(static_cast<Args&&>(args)...));
+  }
+
+  static ReturnType uninitCall(Data&, Args&&...) {
+    throw std::bad_function_call();
+  }
+
+  ReturnType operator()(Args... args) const {
+    auto& fn = *static_cast<const Function<ReturnType(Args...) const>*>(this);
+    return fn.call_(fn.data_, static_cast<Args&&>(args)...);
+  }
+
+  struct SharedFunctionImpl {
+    std::shared_ptr<Function<ReturnType(Args...) const>> sp_;
+    ReturnType operator()(Args&&... args) const {
+      return (*sp_)(static_cast<Args&&>(args)...);
+    }
+  };
+};
+
+template <typename Fun>
+bool execSmall(Op o, Data* src, Data* dst) {
+  switch (o) {
+    case Op::MOVE:
+      ::new (static_cast<void*>(&dst->small))
+          Fun(std::move(*static_cast<Fun*>(static_cast<void*>(&src->small))));
+      FOLLY_FALLTHROUGH;
+    case Op::NUKE:
+      static_cast<Fun*>(static_cast<void*>(&src->small))->~Fun();
+      break;
+    case Op::FULL:
+      return true;
+    case Op::HEAP:
+      break;
+  }
+  return false;
+}
+
+template <typename Fun>
+bool execBig(Op o, Data* src, Data* dst) {
+  switch (o) {
+    case Op::MOVE:
+      dst->big = src->big;
+      src->big = nullptr;
+      break;
+    case Op::NUKE:
+      delete static_cast<Fun*>(src->big);
+      break;
+    case Op::FULL:
+    case Op::HEAP:
+      break;
+  }
+  return true;
+}
+
 } // namespace function
 } // namespace detail
 
-namespace impl {
-
-template <typename ReturnType, typename... Args, bool Const>
-class Function<ReturnType(Args...), Const> final {
+template <typename FunctionType>
+class Function final : private detail::function::FunctionTraits<FunctionType> {
   // These utility types are defined outside of the template to reduce
   // the number of instantiations, and then imported in the class
   // namespace for convenience.
@@ -290,98 +401,53 @@ class Function<ReturnType(Args...), Const> final {
   using Op = detail::function::Op;
   using SmallTag = detail::function::SmallTag;
   using HeapTag = detail::function::HeapTag;
-  using Call = ReturnType (*)(Data&, Args&&...);
+  using CoerceTag = detail::function::CoerceTag;
+
+  using Traits = detail::function::FunctionTraits<FunctionType>;
+  using Call = typename Traits::Call;
   using Exec = bool (*)(Op, Data*, Data*);
 
-  template <typename T>
-  using ConstIf = detail::function::ConstIf<Const, T>;
   template <typename Fun>
   using IsSmall = detail::function::IsSmall<Fun>;
 
-  /**
-   * @Function is const-safe:
-   * - @call_ takes @Data as non-const param to avoid code/data duplication.
-   * - @data_ can only be mutated if @constCastFunction is used.
-   */
+  using OtherSignature = typename Traits::OtherSignature;
+
+  // The `data_` member is mutable to allow `constCastFunction` to work without
+  // invoking undefined behavior. Const-correctness is only violated when
+  // `FunctionType` is a const function type (e.g., `int() const`) and `*this`
+  // is the result of calling `constCastFunction`.
   mutable Data data_;
-  Call call_{&detail::function::uninitCall<ReturnType, Args...>};
+  Call call_{&Traits::uninitCall};
   Exec exec_{&detail::function::uninitNoop};
 
-  friend Function<ReturnType(Args...), true> constCastFunction<>(
-      Function<ReturnType(Args...), false>&&) noexcept;
-  friend class Function<ReturnType(Args...), !Const>;
-
-  template <typename Fun>
-  struct OpsSmall {
-    using FunT = typename std::decay<Fun>::type;
-    static ReturnType call(Data& p, Args&&... args) {
-      return static_cast<ReturnType>((*static_cast<ConstIf<FunT>*>(
-          static_cast<void*>(&p.small)))(static_cast<Args&&>(args)...));
-    }
-    static bool exec(Op o, Data* src, Data* dst) {
-      switch (o) {
-        case Op::MOVE:
-          ::new (static_cast<void*>(&dst->small)) FunT(
-              std::move(*static_cast<FunT*>(static_cast<void*>(&src->small))));
-          FOLLY_FALLTHROUGH;
-        case Op::NUKE:
-          static_cast<FunT*>(static_cast<void*>(&src->small))->~FunT();
-          break;
-        case Op::FULL:
-          return true;
-        case Op::HEAP:
-          break;
-      }
-      return false;
-    }
-  };
+  friend Traits;
+  friend Function<typename Traits::ConstSignature> folly::constCastFunction<>(
+      Function<typename Traits::NonConstSignature>&&) noexcept;
+  friend class Function<OtherSignature>;
 
   template <typename Fun>
   Function(Fun&& fun, SmallTag) noexcept {
-    using Ops = OpsSmall<Fun>;
+    using FunT = typename std::decay<Fun>::type;
     if (!detail::function::isNullPtrFn(fun)) {
-      ::new (static_cast<void*>(&data_.small))
-          typename Ops::FunT(static_cast<Fun&&>(fun));
-      exec_ = &Ops::exec;
-      call_ = &Ops::call;
+      ::new (static_cast<void*>(&data_.small)) FunT(static_cast<Fun&&>(fun));
+      call_ = &Traits::template callSmall<FunT>;
+      exec_ = &detail::function::execSmall<FunT>;
     }
   }
 
-  template <typename Fun>
-  struct OpsHeap {
-    using FunT = typename std::decay<Fun>::type;
-    static ReturnType call(Data& p, Args&&... args) {
-      return static_cast<ReturnType>(
-          (*static_cast<ConstIf<FunT>*>(p.big))(static_cast<Args&&>(args)...));
-    }
-    static bool exec(Op o, Data* src, Data* dst) {
-      switch (o) {
-        case Op::MOVE:
-          dst->big = src->big;
-          src->big = nullptr;
-          break;
-        case Op::NUKE:
-          delete static_cast<FunT*>(src->big);
-          break;
-        case Op::FULL:
-        case Op::HEAP:
-          break;
-      }
-      return true;
-    }
-  };
-
   template <typename Fun>
   Function(Fun&& fun, HeapTag) {
-    using Ops = OpsHeap<Fun>;
-    data_.big = new typename Ops::FunT(static_cast<Fun&&>(fun));
-    call_ = &Ops::call;
-    exec_ = &Ops::exec;
+    using FunT = typename std::decay<Fun>::type;
+    data_.big = new FunT(static_cast<Fun&&>(fun));
+    call_ = &Traits::template callBig<FunT>;
+    exec_ = &detail::function::execBig<FunT>;
   }
 
-  template <typename F, typename G = typename std::decay<F>::type>
-  using ResultOf = decltype(static_cast<ReturnType>(
-      std::declval<ConstIf<G>&>()(std::declval<Args>()...)));
+  Function(Function<OtherSignature>&& that, CoerceTag) noexcept {
+    that.exec_(Op::MOVE, &that.data_, &data_);
+    std::swap(call_, that.call_);
+    std::swap(exec_, that.exec_);
+  }
 
  public:
   /**
@@ -427,7 +493,7 @@ class Function<ReturnType(Args...), Const> final {
    * \note `typename = ResultOf<Fun>` prevents this overload from being
    * selected by overload resolution when `fun` is not a compatible function.
    */
-  template <class Fun, typename = ResultOf<Fun>>
+  template <class Fun, typename = typename Traits::template ResultOf<Fun>>
   /* implicit */ Function(Fun&& fun) noexcept(IsSmall<Fun>::value)
       : Function(static_cast<Fun&&>(fun), IsSmall<Fun>{}) {}
 
@@ -435,13 +501,10 @@ class Function<ReturnType(Args...), Const> final {
    * For moving a `Function<X(Ys..) const>` into a `Function<X(Ys...)>`.
    */
   template <
-      bool OtherConst,
-      typename std::enable_if<!Const && OtherConst, int>::type = 0>
-  Function(Function<ReturnType(Args...), OtherConst>&& that) noexcept {
-    that.exec_(Op::MOVE, &that.data_, &data_);
-    std::swap(call_, that.call_);
-    std::swap(exec_, that.exec_);
-  }
+      bool Const = Traits::IsConst::value,
+      typename std::enable_if<!Const, int>::type = 0>
+  Function(Function<OtherSignature>&& that) noexcept
+      : Function(std::move(that), CoerceTag{}) {}
 
   /**
    * If `ptr` is null, constructs an empty `Function`. Otherwise,
@@ -489,7 +552,7 @@ class Function<ReturnType(Args...), Const> final {
    * \note `typename = ResultOf<Fun>` prevents this overload from being
    * selected by overload resolution when `fun` is not a compatible function.
    */
-  template <class Fun, typename = ResultOf<Fun>>
+  template <class Fun, typename = typename Traits::template ResultOf<Fun>>
   Function& operator=(Fun&& fun) noexcept(
       noexcept(/* implicit */ Function(std::declval<Fun>()))) {
     // Doing this in place is more efficient when we can do so safely.
@@ -526,31 +589,8 @@ class Function<ReturnType(Args...), Const> final {
 
   /**
    * Call the wrapped callable object with the specified arguments.
-   * If this `Function` object is a const `folly::Function` object,
-   * this overload shall not participate in overload resolution.
    */
-  template <
-      // `True` makes `operator()` a template so we can SFINAE on `Const`,
-      // which is non-deduced here.
-      bool True = true,
-      typename std::enable_if<True && !Const, int>::type = 0>
-  ReturnType operator()(Args... args) {
-    return call_(data_, static_cast<Args&&>(args)...);
-  }
-
-  /**
-   * Call the wrapped callable object with the specified arguments.
-   * If this `Function` object is not a const `folly::Function` object,
-   * this overload shall not participate in overload resolution.
-   */
-  template <
-      // `True` makes `operator()` a template so we can SFINAE on `Const`,
-      // which is non-deduced here.
-      bool True = true,
-      typename std::enable_if<True && Const, int>::type = 0>
-  ReturnType operator()(Args... args) const {
-    return call_(data_, static_cast<Args&&>(args)...);
-  }
+  using Traits::operator();
 
   /**
    * Exchanges the callable objects of `*this` and `that`.
@@ -582,80 +622,51 @@ class Function<ReturnType(Args...), Const> final {
    * Note that the returned `std::function` will share its state (i.e. captured
    * data) across all copies you make of it, so be very careful when copying.
    */
-  std::function<ReturnType(Args...)> asStdFunction() && {
-    struct Impl {
-      std::shared_ptr<Function> sp_;
-      ReturnType operator()(Args&&... args) const {
-        return (*sp_)(static_cast<Args&&>(args)...);
-      }
-    };
+  std::function<typename Traits::NonConstSignature> asStdFunction() && {
+    using Impl = typename Traits::SharedFunctionImpl;
     return Impl{std::make_shared<Function>(std::move(*this))};
   }
 };
 
-template <typename FunctionType, bool Const>
-void swap(
-    Function<FunctionType, Const>& lhs,
-    Function<FunctionType, Const>& rhs) noexcept {
+template <typename FunctionType>
+void swap(Function<FunctionType>& lhs, Function<FunctionType>& rhs) noexcept {
   lhs.swap(rhs);
 }
 
-template <typename FunctionType, bool Const>
-bool operator==(const Function<FunctionType, Const>& fn, std::nullptr_t) {
+template <typename FunctionType>
+bool operator==(const Function<FunctionType>& fn, std::nullptr_t) {
   return !fn;
 }
 
-template <typename FunctionType, bool Const>
-bool operator==(std::nullptr_t, const Function<FunctionType, Const>& fn) {
+template <typename FunctionType>
+bool operator==(std::nullptr_t, const Function<FunctionType>& fn) {
   return !fn;
 }
 
-template <typename FunctionType, bool Const>
-bool operator!=(const Function<FunctionType, Const>& fn, std::nullptr_t) {
+template <typename FunctionType>
+bool operator!=(const Function<FunctionType>& fn, std::nullptr_t) {
   return !(fn == nullptr);
 }
 
-template <typename FunctionType, bool Const>
-bool operator!=(std::nullptr_t, const Function<FunctionType, Const>& fn) {
+template <typename FunctionType>
+bool operator!=(std::nullptr_t, const Function<FunctionType>& fn) {
   return !(nullptr == fn);
 }
 
 /**
- * NOTE: See detailed note about @constCastFunction at the top of the file.
+ * NOTE: See detailed note about `constCastFunction` at the top of the file.
  * This is potentially dangerous and requires the equivalent of a `const_cast`.
  */
 template <typename ReturnType, typename... Args>
-Function<ReturnType(Args...), true> constCastFunction(
-    Function<ReturnType(Args...), false>&& that) noexcept {
-  Function<ReturnType(Args...), true> fn{};
-  that.exec_(detail::function::Op::MOVE, &that.data_, &fn.data_);
-  std::swap(fn.call_, that.call_);
-  std::swap(fn.exec_, that.exec_);
-  return fn;
+Function<ReturnType(Args...) const> constCastFunction(
+    Function<ReturnType(Args...)>&& that) noexcept {
+  return Function<ReturnType(Args...) const>{std::move(that),
+                                             detail::function::CoerceTag{}};
 }
 
-template <typename FunctionType>
-Function<FunctionType, true> constCastFunction(
-    Function<FunctionType, true>&& that) noexcept {
-  return std::move(that);
-}
-
-template <typename FunctionType>
-struct MakeFunction {};
-
-template <typename ReturnType, typename... Args>
-struct MakeFunction<ReturnType(Args...)> {
-  using type = Function<ReturnType(Args...), false>;
-};
-
 template <typename ReturnType, typename... Args>
-struct MakeFunction<ReturnType(Args...) const> {
-  using type = Function<ReturnType(Args...), true>;
-};
-} // namespace impl
-
-/* using override */ using impl::constCastFunction;
-
-template <typename FunctionType>
-using Function = typename impl::MakeFunction<FunctionType>::type;
+Function<ReturnType(Args...) const> constCastFunction(
+    Function<ReturnType(Args...) const>&& that) noexcept {
+  return std::move(that);
 }
+} // namespace folly
index 7c5159c9ae0873f28ac1b8e5dd209fefbbb435a6..3f1385160f3b1118eee2b35916dcaa9226d58d02 100644 (file)
@@ -49,6 +49,10 @@ struct Functor {
     return oldvalue;
   }
 };
+
+template <typename Ret, typename... Args>
+void deduceArgs(Function<Ret(Args...)>) {}
+
 } // namespace
 
 // TEST =====================================================================
@@ -849,3 +853,9 @@ TEST(Function, SelfMoveAssign) {
   f = std::move(g);
   EXPECT_TRUE(f);
 }
+
+TEST(Function, DeducableArguments) {
+  deduceArgs(Function<void()>{[] {}});
+  deduceArgs(Function<void(int, float)>{[](int, float) {}});
+  deduceArgs(Function<int(int, float)>{[](int i, float) { return i; }});
+}