Fix a race in Observable context destruction
authorAndrii Grynenko <andrii@fb.com>
Mon, 1 May 2017 21:51:49 +0000 (14:51 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 1 May 2017 22:05:36 +0000 (15:05 -0700)
Summary: In the subscribe callback It's possible that we lock the Context shared_ptr and while update is running, all other shared_ptr's are released. This will result in Context to be destroyed from the wrong thread (thread runnning subcribe callback), which is not desired.

Reviewed By: yfeldblum

Differential Revision: D4964605

fbshipit-source-id: 285327a6873ccb7393fa3067ba7e612c29dbc454

folly/experimental/observer/Observable-inl.h
folly/experimental/observer/Observer-inl.h
folly/experimental/observer/Observer.h
folly/experimental/observer/detail/ObserverManager.cpp
folly/experimental/observer/detail/ObserverManager.h
folly/experimental/observer/test/ObserverTest.cpp

index fdf62f0bc1c123d287cd05f63a44b87eed3005e3..231991e6adafab0c83827bd82b3061a73a4b7fb8 100644 (file)
@@ -22,7 +22,9 @@ template <typename Observable, typename Traits>
 class ObserverCreator<Observable, Traits>::Context {
  public:
   template <typename... Args>
-  Context(Args&&... args) : observable_(std::forward<Args>(args)...) {}
+  Context(Args&&... args) : observable_(std::forward<Args>(args)...) {
+    updateValue();
+  }
 
   ~Context() {
     if (value_.copy()) {
@@ -47,21 +49,11 @@ class ObserverCreator<Observable, Traits>::Context {
     // callbacks (getting new value from observable and storing it into value_
     // is not atomic).
     std::lock_guard<std::mutex> lg(updateMutex_);
-
-    {
-      auto newValue = Traits::get(observable_);
-      if (!newValue) {
-        throw std::logic_error("Observable returned nullptr.");
-      }
-      value_.swap(newValue);
-    }
+    updateValue();
 
     bool expected = false;
     if (updateRequested_.compare_exchange_strong(expected, true)) {
-      if (auto core = coreWeak_.lock()) {
-        observer_detail::ObserverManager::scheduleRefreshNewVersion(
-            std::move(core));
-      }
+      observer_detail::ObserverManager::scheduleRefreshNewVersion(coreWeak_);
     }
   }
 
@@ -71,6 +63,14 @@ class ObserverCreator<Observable, Traits>::Context {
   }
 
  private:
+  void updateValue() {
+    auto newValue = Traits::get(observable_);
+    if (!newValue) {
+      throw std::logic_error("Observable returned nullptr.");
+    }
+    value_.swap(newValue);
+  }
+
   folly::Synchronized<std::shared_ptr<const T>> value_;
   std::atomic<bool> updateRequested_{false};
 
@@ -89,24 +89,68 @@ ObserverCreator<Observable, Traits>::ObserverCreator(Args&&... args)
 template <typename Observable, typename Traits>
 Observer<typename ObserverCreator<Observable, Traits>::T>
 ObserverCreator<Observable, Traits>::getObserver()&& {
-  auto core = observer_detail::Core::create([context = context_]() {
+  // This master shared_ptr allows grabbing derived weak_ptrs, pointing to the
+  // the same Context object, but using a separate reference count. Master
+  // shared_ptr destructor then blocks until all shared_ptrs obtained from
+  // derived weak_ptrs are released.
+  class ContextMasterPointer {
+   public:
+    explicit ContextMasterPointer(std::shared_ptr<Context> context)
+        : contextMaster_(std::move(context)),
+          context_(
+              contextMaster_.get(),
+              [destroyBaton = destroyBaton_](Context*) {
+                destroyBaton->post();
+              }) {}
+    ~ContextMasterPointer() {
+      if (context_) {
+        context_.reset();
+        destroyBaton_->wait();
+      }
+    }
+    ContextMasterPointer(const ContextMasterPointer&) = delete;
+    ContextMasterPointer(ContextMasterPointer&&) = default;
+    ContextMasterPointer& operator=(const ContextMasterPointer&) = delete;
+    ContextMasterPointer& operator=(ContextMasterPointer&&) = default;
+
+    Context* operator->() const {
+      return contextMaster_.get();
+    }
+
+    std::weak_ptr<Context> get_weak() {
+      return context_;
+    }
+
+   private:
+    std::shared_ptr<folly::Baton<>> destroyBaton_{
+        std::make_shared<folly::Baton<>>()};
+    std::shared_ptr<Context> contextMaster_;
+    std::shared_ptr<Context> context_;
+  };
+  // We want to make sure that Context can only be destroyed when Core is
+  // destroyed. So we have to avoid the situation when subscribe callback is
+  // locking Context shared_ptr and remains the last to release it.
+  // We solve this by having Core hold the master shared_ptr and subscription
+  // callback gets derived weak_ptr.
+  ContextMasterPointer contextMaster(context_);
+  auto contextWeak = contextMaster.get_weak();
+  auto observer = makeObserver([context = std::move(contextMaster)]() {
     return context->get();
   });
 
-  context_->setCore(core);
-
-  context_->subscribe([contextWeak = std::weak_ptr<Context>(context_)] {
+  context_->setCore(observer.core_);
+  context_->subscribe([contextWeak = std::move(contextWeak)] {
     if (auto context = contextWeak.lock()) {
       context->update();
     }
   });
 
+  // Do an extra update in case observable was updated between observer creation
+  // and setting updates callback.
   context_->update();
   context_.reset();
 
-  DCHECK(core->getVersion() > 0);
-
-  return Observer<T>(std::move(core));
+  return observer;
 }
 }
 }
index 55088cdfd8b850f5f7ab1234b1e42050de957357..bdc62a092f6c1af3a1bae244b8f7f124b028267c 100644 (file)
@@ -38,10 +38,10 @@ Observer<observer_detail::ResultOfUnwrapSharedPtr<F>> makeObserver(
     F&& creator) {
   auto core = observer_detail::Core::
       create([creator = std::forward<F>(creator)]() mutable {
-        return std::static_pointer_cast<void>(creator());
+        return std::static_pointer_cast<const void>(creator());
       });
 
-  observer_detail::ObserverManager::scheduleRefreshNewVersion(core);
+  observer_detail::ObserverManager::initCore(core);
 
   return Observer<observer_detail::ResultOfUnwrapSharedPtr<F>>(core);
 }
index 662a0113608d87a6852cbffe162961d60849efd6..192293e5c0389e9c4b58287bad098e266644a98f 100644 (file)
@@ -134,6 +134,9 @@ class Observer {
   }
 
  private:
+  template <typename Observable, typename Traits>
+  friend class ObserverCreator;
+
   observer_detail::Core::Ptr core_;
 };
 
index 7654dff5e93399f665ada289af2a0de622caef1b..f909ef57f3da31a9d88df714e4157e1b10ba51d7 100644 (file)
@@ -106,28 +106,35 @@ class ObserverManager::NextQueue {
   explicit NextQueue(ObserverManager& manager)
       : manager_(manager), queue_(kNextQueueSize) {
     thread_ = std::thread([&]() {
-      Core::Ptr queueCore;
+      Core::WeakPtr queueCoreWeak;
 
       while (true) {
-        queue_.blockingRead(queueCore);
-
-        if (!queueCore) {
+        queue_.blockingRead(queueCoreWeak);
+        if (stop_) {
           return;
         }
 
         std::vector<Core::Ptr> cores;
-        cores.emplace_back(std::move(queueCore));
+        {
+          auto queueCore = queueCoreWeak.lock();
+          if (!queueCore) {
+            continue;
+          }
+          cores.emplace_back(std::move(queueCore));
+        }
 
         {
           SharedMutexReadPriority::WriteHolder wh(manager_.versionMutex_);
 
           // We can't pick more tasks from the queue after we bumped the
           // version, so we have to do this while holding the lock.
-          while (cores.size() < kNextQueueSize && queue_.read(queueCore)) {
-            if (!queueCore) {
+          while (cores.size() < kNextQueueSize && queue_.read(queueCoreWeak)) {
+            if (stop_) {
               return;
             }
-            cores.emplace_back(std::move(queueCore));
+            if (auto queueCore = queueCoreWeak.lock()) {
+              cores.emplace_back(std::move(queueCore));
+            }
           }
 
           ++manager_.version_;
@@ -140,20 +147,22 @@ class ObserverManager::NextQueue {
     });
   }
 
-  void add(Core::Ptr core) {
+  void add(Core::WeakPtr core) {
     queue_.blockingWrite(std::move(core));
   }
 
   ~NextQueue() {
-    // Emtpy element signals thread to terminate
-    queue_.blockingWrite(nullptr);
+    stop_ = true;
+    // Write to the queue to notify the thread.
+    queue_.blockingWrite(Core::WeakPtr());
     thread_.join();
   }
 
  private:
   ObserverManager& manager_;
-  MPMCQueue<Core::Ptr> queue_;
+  MPMCQueue<Core::WeakPtr> queue_;
   std::thread thread_;
+  std::atomic<bool> stop_{false};
 };
 
 ObserverManager::ObserverManager() {
@@ -172,7 +181,7 @@ void ObserverManager::scheduleCurrent(Function<void()> task) {
   currentQueue_->add(std::move(task));
 }
 
-void ObserverManager::scheduleNext(Core::Ptr core) {
+void ObserverManager::scheduleNext(Core::WeakPtr core) {
   nextQueue_->add(std::move(core));
 }
 
index 5e206dfb54087fa75028c63be6ed161be7906968..cfb1e70a05eef1a04c46e2bf8c98dbc9b534ca0d 100644 (file)
@@ -93,19 +93,19 @@ class ObserverManager {
     return future;
   }
 
-  static void scheduleRefreshNewVersion(Core::Ptr core) {
-    if (core->getVersion() == 0) {
-      scheduleRefresh(std::move(core), 1).get();
-      return;
-    }
-
+  static void scheduleRefreshNewVersion(Core::WeakPtr coreWeak) {
     auto instance = getInstance();
 
     if (!instance) {
       return;
     }
 
-    instance->scheduleNext(std::move(core));
+    instance->scheduleNext(std::move(coreWeak));
+  }
+
+  static void initCore(Core::Ptr core) {
+    DCHECK(core->getVersion() == 0);
+    scheduleRefresh(std::move(core), 1).get();
   }
 
   class DependencyRecorder {
@@ -189,7 +189,7 @@ class ObserverManager {
   struct Singleton;
 
   void scheduleCurrent(Function<void()>);
-  void scheduleNext(Core::Ptr);
+  void scheduleNext(Core::WeakPtr);
 
   class CurrentQueue;
   class NextQueue;
index 62fcf57bc9b1cb5cdca2f1c15ec59752ea9c394f..ed372f5aadd2b63354b09a44d05ab48ae2d64834 100644 (file)
@@ -262,3 +262,62 @@ TEST(Observer, TLObserver) {
   k = std::make_unique<folly::observer::TLObserver<int>>(createTLObserver(41));
   EXPECT_EQ(41, ***k);
 }
+
+TEST(Observer, SubscribeCallback) {
+  static auto mainThreadId = std::this_thread::get_id();
+  static std::function<void()> updatesCob;
+  static bool slowGet = false;
+  static std::atomic<size_t> getCallsStart{0};
+  static std::atomic<size_t> getCallsFinish{0};
+
+  struct Observable {
+    ~Observable() {
+      EXPECT_EQ(mainThreadId, std::this_thread::get_id());
+    }
+  };
+  struct Traits {
+    using element_type = int;
+    static std::shared_ptr<const int> get(Observable&) {
+      ++getCallsStart;
+      if (slowGet) {
+        /* sleep override */ std::this_thread::sleep_for(
+            std::chrono::seconds{2});
+      }
+      ++getCallsFinish;
+      return std::make_shared<const int>(42);
+    }
+
+    static void subscribe(Observable&, std::function<void()> cob) {
+      updatesCob = std::move(cob);
+    }
+
+    static void unsubscribe(Observable&) {}
+  };
+
+  std::thread cobThread;
+  {
+    auto observer =
+        folly::observer::ObserverCreator<Observable, Traits>().getObserver();
+
+    EXPECT_TRUE(updatesCob);
+    EXPECT_EQ(2, getCallsStart);
+    EXPECT_EQ(2, getCallsFinish);
+
+    updatesCob();
+    EXPECT_EQ(3, getCallsStart);
+    EXPECT_EQ(3, getCallsFinish);
+
+    slowGet = true;
+    cobThread = std::thread([] { updatesCob(); });
+    /* sleep override */ std::this_thread::sleep_for(std::chrono::seconds{1});
+    EXPECT_EQ(4, getCallsStart);
+    EXPECT_EQ(3, getCallsFinish);
+
+    // Observer is destroyed here
+  }
+
+  // Make sure that destroying the observer actually joined the updates callback
+  EXPECT_EQ(4, getCallsStart);
+  EXPECT_EQ(4, getCallsFinish);
+  cobThread.join();
+}