From 0be5c0a490a250b05590c62a9c70b6f654bfb546 Mon Sep 17 00:00:00 2001 From: Yedidya Feldblum Date: Mon, 8 May 2017 18:50:51 -0700 Subject: [PATCH] Destroy promise/future callback functions before waking waiters Summary: Code may pass a callback which captures an object with a destructor which mutates through a stored reference, triggering heap-use-after-free or stack-use-after-scope. ```lang=c++ void performDataRace() { auto number = std::make_unique(0); auto guard = folly::makeGuard([&number] { *number = 1; }); folly::via(getSomeExecutor(), [guard = std::move(guard)]() mutable {}).wait(); // data race - we may wake and destruct number before guard is destructed on the // executor thread, which is both stack-use-after-scope and heap-use-after-free! } ``` We can avoid this condition by always destructing the provided functor before setting any result on the promise. Reviewed By: spacedentist Differential Revision: D4982969 fbshipit-source-id: 71134c1657bdd4c38c12d8ca17f8335ef4c27352 --- folly/futures/Future-inl.h | 202 +++++++++++++------ folly/futures/Promise-inl.h | 4 + folly/futures/Promise.h | 10 + folly/futures/test/CallbackLifetimeTest.cpp | 207 ++++++++++++++++++++ folly/test/Makefile.am | 1 + 5 files changed, 360 insertions(+), 64 deletions(-) create mode 100644 folly/futures/test/CallbackLifetimeTest.cpp diff --git a/folly/futures/Future-inl.h b/folly/futures/Future-inl.h index 0e6c7e31..e1e34956 100644 --- a/folly/futures/Future-inl.h +++ b/folly/futures/Future-inl.h @@ -48,7 +48,76 @@ typedef folly::Baton<> FutureBatonType; } namespace detail { - std::shared_ptr getTimekeeperSingleton(); +std::shared_ptr getTimekeeperSingleton(); + +// Guarantees that the stored functor is destructed before the stored promise +// may be fulfilled. Assumes the stored functor to be noexcept-destructible. +template +class CoreCallbackState { + public: + template + CoreCallbackState(Promise&& promise, FF&& func) noexcept( + noexcept(F(std::declval()))) + : func_(std::forward(func)), promise_(std::move(promise)) { + assert(before_barrier()); + } + + CoreCallbackState(CoreCallbackState&& that) noexcept( + noexcept(F(std::declval()))) { + if (that.before_barrier()) { + new (&func_) F(std::move(that.func_)); + promise_ = that.stealPromise(); + } + } + + CoreCallbackState& operator=(CoreCallbackState&&) = delete; + + ~CoreCallbackState() { + if (before_barrier()) { + stealPromise(); + } + } + + template + auto invoke(Args&&... args) noexcept( + noexcept(std::declval()(std::declval()...))) { + assert(before_barrier()); + return std::move(func_)(std::forward(args)...); + } + + void setTry(Try&& t) { + stealPromise().setTry(std::move(t)); + } + + void setException(exception_wrapper&& ew) { + stealPromise().setException(std::move(ew)); + } + + Promise stealPromise() noexcept { + assert(before_barrier()); + func_.~F(); + return std::move(promise_); + } + + private: + bool before_barrier() const noexcept { + return !promise_.isFulfilled(); + } + + union { + F func_; + }; + Promise promise_{detail::EmptyConstruct{}}; +}; + +template +inline auto makeCoreCallbackState(Promise&& p, F&& f) noexcept( + noexcept(CoreCallbackState>>( + std::declval&&>(), + std::declval()))) { + return CoreCallbackState>>( + std::move(p), std::forward(f)); +} } template @@ -160,13 +229,13 @@ Future::thenImplementation(F&& func, detail::argResult) { in the destruction of the Future used to create it. */ setCallback_( - [ func = std::forward(func), pm = std::move(p) ](Try && t) mutable { + [state = detail::makeCoreCallbackState( + std::move(p), std::forward(func))](Try && t) mutable { if (!isTry && t.hasException()) { - pm.setException(std::move(t.exception())); + state.setException(std::move(t.exception())); } else { - pm.setWith([&]() { - return std::move(func)(t.template get()...); - }); + state.setTry(makeTryWith( + [&] { return state.invoke(t.template get()...); })); } }); @@ -191,30 +260,31 @@ Future::thenImplementation(F&& func, detail::argResult) { auto f = p.getFuture(); f.core_->setExecutorNoLock(getExecutor()); - setCallback_([ func = std::forward(func), pm = std::move(p) ]( - Try && t) mutable { - auto ew = [&] { - if (!isTry && t.hasException()) { - return std::move(t.exception()); - } else { - try { - auto f2 = std::move(func)(t.template get()...); - // that didn't throw, now we can steal p - f2.setCallback_([p = std::move(pm)](Try && b) mutable { - p.setTry(std::move(b)); - }); - return exception_wrapper(); - } catch (const std::exception& e) { - return exception_wrapper(std::current_exception(), e); - } catch (...) { - return exception_wrapper(std::current_exception()); + setCallback_( + [state = detail::makeCoreCallbackState( + std::move(p), std::forward(func))](Try && t) mutable { + auto ew = [&] { + if (!isTry && t.hasException()) { + return std::move(t.exception()); + } else { + try { + auto f2 = state.invoke(t.template get()...); + // that didn't throw, now we can steal p + f2.setCallback_([p = state.stealPromise()](Try && b) mutable { + p.setTry(std::move(b)); + }); + return exception_wrapper(); + } catch (const std::exception& e) { + return exception_wrapper(std::current_exception(), e); + } catch (...) { + return exception_wrapper(std::current_exception()); + } + } + }(); + if (ew) { + state.setException(std::move(ew)); } - } - }(); - if (ew) { - pm.setException(std::move(ew)); - } - }); + }); return f; } @@ -266,11 +336,12 @@ Future::onError(F&& func) { auto f = p.getFuture(); setCallback_( - [ func = std::forward(func), pm = std::move(p) ](Try && t) mutable { + [state = detail::makeCoreCallbackState( + std::move(p), std::forward(func))](Try && t) mutable { if (!t.template withException([&](Exn& e) { - pm.setWith([&] { return std::move(func)(e); }); + state.setTry(makeTryWith([&] { return state.invoke(e); })); })) { - pm.setTry(std::move(t)); + state.setTry(std::move(t)); } }); @@ -293,29 +364,29 @@ Future::onError(F&& func) { Promise p; auto f = p.getFuture(); - setCallback_([ pm = std::move(p), func = std::forward(func) ]( - Try && t) mutable { - if (!t.template withException([&](Exn& e) { - auto ew = [&] { - try { - auto f2 = std::move(func)(e); - f2.setCallback_([pm = std::move(pm)](Try && t2) mutable { - pm.setTry(std::move(t2)); - }); - return exception_wrapper(); - } catch (const std::exception& e2) { - return exception_wrapper(std::current_exception(), e2); - } catch (...) { - return exception_wrapper(std::current_exception()); - } - }(); - if (ew) { - pm.setException(std::move(ew)); - } - })) { - pm.setTry(std::move(t)); - } - }); + setCallback_( + [state = detail::makeCoreCallbackState( + std::move(p), std::forward(func))](Try && t) mutable { + if (!t.template withException([&](Exn& e) { + auto ew = [&] { + try { + auto f2 = state.invoke(e); + f2.setCallback_([p = state.stealPromise()]( + Try && t2) mutable { p.setTry(std::move(t2)); }); + return exception_wrapper(); + } catch (const std::exception& e2) { + return exception_wrapper(std::current_exception(), e2); + } catch (...) { + return exception_wrapper(std::current_exception()); + } + }(); + if (ew) { + state.setException(std::move(ew)); + } + })) { + state.setTry(std::move(t)); + } + }); return f; } @@ -349,13 +420,14 @@ Future::onError(F&& func) { Promise p; auto f = p.getFuture(); setCallback_( - [ pm = std::move(p), func = std::forward(func) ](Try t) mutable { + [state = detail::makeCoreCallbackState( + std::move(p), std::forward(func))](Try t) mutable { if (t.hasException()) { auto ew = [&] { try { - auto f2 = std::move(func)(std::move(t.exception())); - f2.setCallback_([pm = std::move(pm)](Try t2) mutable { - pm.setTry(std::move(t2)); + auto f2 = state.invoke(std::move(t.exception())); + f2.setCallback_([p = state.stealPromise()](Try t2) mutable { + p.setTry(std::move(t2)); }); return exception_wrapper(); } catch (const std::exception& e2) { @@ -365,10 +437,10 @@ Future::onError(F&& func) { } }(); if (ew) { - pm.setException(std::move(ew)); + state.setException(std::move(ew)); } } else { - pm.setTry(std::move(t)); + state.setTry(std::move(t)); } }); @@ -390,11 +462,13 @@ Future::onError(F&& func) { Promise p; auto f = p.getFuture(); setCallback_( - [ pm = std::move(p), func = std::forward(func) ](Try t) mutable { + [state = detail::makeCoreCallbackState( + std::move(p), std::forward(func))](Try t) mutable { if (t.hasException()) { - pm.setWith([&] { return std::move(func)(std::move(t.exception())); }); + state.setTry(makeTryWith( + [&] { return state.invoke(std::move(t.exception())); })); } else { - pm.setTry(std::move(t)); + state.setTry(std::move(t)); } }); diff --git a/folly/futures/Promise-inl.h b/folly/futures/Promise-inl.h index c55d34ca..61d4d933 100644 --- a/folly/futures/Promise-inl.h +++ b/folly/futures/Promise-inl.h @@ -59,6 +59,10 @@ void Promise::throwIfRetrieved() { } } +template +Promise::Promise(detail::EmptyConstruct) noexcept + : retrieved_(false), core_(nullptr) {} + template Promise::~Promise() { detach(); diff --git a/folly/futures/Promise.h b/folly/futures/Promise.h index a2793e29..75119129 100644 --- a/folly/futures/Promise.h +++ b/folly/futures/Promise.h @@ -25,6 +25,12 @@ namespace folly { // forward declaration template class Future; +namespace detail { +struct EmptyConstruct {}; +template +class CoreCallbackState; +} + template class Promise { public: @@ -98,6 +104,8 @@ class Promise { private: typedef typename Future::corePtr corePtr; template friend class Future; + template + friend class detail::CoreCallbackState; // Whether the Future has been retrieved (a one-time operation). bool retrieved_; @@ -105,6 +113,8 @@ class Promise { // shared core state object corePtr core_; + explicit Promise(detail::EmptyConstruct) noexcept; + void throwIfFulfilled(); void throwIfRetrieved(); void detach(); diff --git a/folly/futures/test/CallbackLifetimeTest.cpp b/folly/futures/test/CallbackLifetimeTest.cpp new file mode 100644 index 00000000..98fa9fdf --- /dev/null +++ b/folly/futures/test/CallbackLifetimeTest.cpp @@ -0,0 +1,207 @@ +/* + * Copyright 2017 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include + +#include +#include + +using namespace folly; + +namespace { + +/*** + * The basic premise is to check that the callback passed to then or onError + * is destructed before wait returns on the resulting future. + * + * The approach is to use callbacks where the destructor sleeps 500ms and then + * mutates a counter allocated on the caller stack. The caller checks the + * counter immediately after calling wait. Were the callback not destructed + * before wait returns, then we would very likely see an unchanged counter just + * after wait returns. But if, as we expect, the callback were destructed + * before wait returns, then we must be guaranteed to see a mutated counter + * just after wait returns. + * + * Note that the failure condition is not strictly guaranteed under load. :( + */ +class CallbackLifetimeTest : public testing::Test { + public: + using CounterPtr = std::unique_ptr; + + static bool kRaiseWillThrow() { + return true; + } + static constexpr auto kDelay() { + return std::chrono::milliseconds(500); + } + + auto mkC() { + return std::make_unique(0); + } + auto mkCGuard(CounterPtr& ptr) { + return makeGuard([&] { + /* sleep override */ std::this_thread::sleep_for(kDelay()); + ++*ptr; + }); + } + + static void raise() { + if (kRaiseWillThrow()) { // to avoid marking [[noreturn]] + throw std::runtime_error("raise"); + } + } + static Future raiseFut() { + raise(); + return makeFuture(); + } + + TestExecutor executor{2}; // need at least 2 threads for internal futures +}; +} + +TEST_F(CallbackLifetimeTest, thenReturnsValue) { + auto c = mkC(); + via(&executor).then([_ = mkCGuard(c)]{}).wait(); + EXPECT_EQ(1, *c); +} + +TEST_F(CallbackLifetimeTest, thenReturnsValueThrows) { + auto c = mkC(); + via(&executor).then([_ = mkCGuard(c)] { raise(); }).wait(); + EXPECT_EQ(1, *c); +} + +TEST_F(CallbackLifetimeTest, thenReturnsFuture) { + auto c = mkC(); + via(&executor).then([_ = mkCGuard(c)] { return makeFuture(); }).wait(); + EXPECT_EQ(1, *c); +} + +TEST_F(CallbackLifetimeTest, thenReturnsFutureThrows) { + auto c = mkC(); + via(&executor).then([_ = mkCGuard(c)] { return raiseFut(); }).wait(); + EXPECT_EQ(1, *c); +} + +TEST_F(CallbackLifetimeTest, onErrorTakesExnReturnsValueMatch) { + auto c = mkC(); + via(&executor) + .then(raise) + .onError([_ = mkCGuard(c)](std::exception&){}) + .wait(); + EXPECT_EQ(1, *c); +} + +TEST_F(CallbackLifetimeTest, onErrorTakesExnReturnsValueMatchThrows) { + auto c = mkC(); + via(&executor) + .then(raise) + .onError([_ = mkCGuard(c)](std::exception&) { raise(); }) + .wait(); + EXPECT_EQ(1, *c); +} + +TEST_F(CallbackLifetimeTest, onErrorTakesExnReturnsValueWrong) { + auto c = mkC(); + via(&executor) + .then(raise) + .onError([_ = mkCGuard(c)](std::logic_error&){}) + .wait(); + EXPECT_EQ(1, *c); +} + +TEST_F(CallbackLifetimeTest, onErrorTakesExnReturnsValueWrongThrows) { + auto c = mkC(); + via(&executor) + .then(raise) + .onError([_ = mkCGuard(c)](std::logic_error&) { raise(); }) + .wait(); + EXPECT_EQ(1, *c); +} + +TEST_F(CallbackLifetimeTest, onErrorTakesExnReturnsFutureMatch) { + auto c = mkC(); + via(&executor) + .then(raise) + .onError([_ = mkCGuard(c)](std::exception&) { return makeFuture(); }) + .wait(); + EXPECT_EQ(1, *c); +} + +TEST_F(CallbackLifetimeTest, onErrorTakesExnReturnsFutureMatchThrows) { + auto c = mkC(); + via(&executor) + .then(raise) + .onError([_ = mkCGuard(c)](std::exception&) { return raiseFut(); }) + .wait(); + EXPECT_EQ(1, *c); +} + +TEST_F(CallbackLifetimeTest, onErrorTakesExnReturnsFutureWrong) { + auto c = mkC(); + via(&executor) + .then(raise) + .onError([_ = mkCGuard(c)](std::logic_error&) { return makeFuture(); }) + .wait(); + EXPECT_EQ(1, *c); +} + +TEST_F(CallbackLifetimeTest, onErrorTakesExnReturnsFutureWrongThrows) { + auto c = mkC(); + via(&executor) + .then(raise) + .onError([_ = mkCGuard(c)](std::logic_error&) { return raiseFut(); }) + .wait(); + EXPECT_EQ(1, *c); +} + +TEST_F(CallbackLifetimeTest, onErrorTakesWrapReturnsValue) { + auto c = mkC(); + via(&executor) + .then(raise) + .onError([_ = mkCGuard(c)](exception_wrapper &&){}) + .wait(); + EXPECT_EQ(1, *c); +} + +TEST_F(CallbackLifetimeTest, onErrorTakesWrapReturnsValueThrows) { + auto c = mkC(); + via(&executor) + .then(raise) + .onError([_ = mkCGuard(c)](exception_wrapper &&) { raise(); }) + .wait(); + EXPECT_EQ(1, *c); +} + +TEST_F(CallbackLifetimeTest, onErrorTakesWrapReturnsFuture) { + auto c = mkC(); + via(&executor) + .then(raise) + .onError([_ = mkCGuard(c)](exception_wrapper &&) { return makeFuture(); }) + .wait(); + EXPECT_EQ(1, *c); +} + +TEST_F(CallbackLifetimeTest, onErrorTakesWrapReturnsFutureThrows) { + auto c = mkC(); + via(&executor) + .then(raise) + .onError([_ = mkCGuard(c)](exception_wrapper &&) { return raiseFut(); }) + .wait(); + EXPECT_EQ(1, *c); +} diff --git a/folly/test/Makefile.am b/folly/test/Makefile.am index d3d12b51..60fdea58 100644 --- a/folly/test/Makefile.am +++ b/folly/test/Makefile.am @@ -262,6 +262,7 @@ unit_test_LDADD = libfollytestmain.la TESTS += unit_test futures_test_SOURCES = \ + ../futures/test/CallbackLifetimeTest.cpp \ ../futures/test/CollectTest.cpp \ ../futures/test/ContextTest.cpp \ ../futures/test/CoreTest.cpp \ -- 2.34.1