From 0416e1ea440f816d5543feb5e834f09906ddfcab Mon Sep 17 00:00:00 2001
From: Andrii Grynenko <andrii@fb.com>
Date: Wed, 9 Nov 2016 20:19:57 -0800
Subject: [PATCH] Fix wrong use of upgrade lock

Reviewed By: yfeldblum, nbronson

Differential Revision: D4149681

fbshipit-source-id: 37bd1b0b7d1ad6e6fa813228307abebfe772012f
---
 folly/Singleton-inl.h |  18 +++-----
 folly/Singleton.cpp   | 105 +++++++++++++++++++-----------------------
 folly/Singleton.h     |  43 +++++++++--------
 3 files changed, 79 insertions(+), 87 deletions(-)

diff --git a/folly/Singleton-inl.h b/folly/Singleton-inl.h
index 19ed5c52..4a6c071b 100644
--- a/folly/Singleton-inl.h
+++ b/folly/Singleton-inl.h
@@ -72,12 +72,11 @@ void SingletonHolder<T>::registerSingletonMock(CreateFunc c, TeardownFunc t) {
   destroyInstance();
 
   {
-    RWSpinLock::WriteHolder wh(&vault_.mutex_);
+    auto creationOrder = vault_.creationOrder_.wlock();
 
-    auto it = std::find(
-        vault_.creation_order_.begin(), vault_.creation_order_.end(), type());
-    if (it != vault_.creation_order_.end()) {
-      vault_.creation_order_.erase(it);
+    auto it = std::find(creationOrder->begin(), creationOrder->end(), type());
+    if (it != creationOrder->end()) {
+      creationOrder->erase(it);
     }
   }
 
@@ -224,8 +223,8 @@ void SingletonHolder<T>::createInstance() {
 
   creating_thread_.store(std::this_thread::get_id(), std::memory_order_release);
 
-  RWSpinLock::ReadHolder rh(&vault_.stateMutex_);
-  if (vault_.state_ == SingletonVault::SingletonVaultState::Quiescing) {
+  auto state = vault_.state_.rlock();
+  if (state->state == SingletonVault::SingletonVaultState::Quiescing) {
     if (vault_.type_ != SingletonVault::Type::Relaxed) {
       LOG(FATAL) << "Requesting singleton after vault was destroyed.";
     }
@@ -278,10 +277,7 @@ void SingletonHolder<T>::createInstance() {
   // may access instance and instance_weak w/o synchronization.
   state_.store(SingletonHolderState::Living, std::memory_order_release);
 
-  {
-    RWSpinLock::WriteHolder wh(&vault_.mutex_);
-    vault_.creation_order_.push_back(type());
-  }
+  vault_.creationOrder_.wlock()->push_back(type());
 }
 
 }
diff --git a/folly/Singleton.cpp b/folly/Singleton.cpp
index 4a690c72..fe8d7db7 100644
--- a/folly/Singleton.cpp
+++ b/folly/Singleton.cpp
@@ -73,48 +73,41 @@ FatalHelper __attribute__ ((__init_priority__ (101))) fatalHelper;
 SingletonVault::~SingletonVault() { destroyInstances(); }
 
 void SingletonVault::registerSingleton(detail::SingletonHolderBase* entry) {
-  RWSpinLock::ReadHolder rh(&stateMutex_);
+  auto state = state_.rlock();
+  stateCheck(SingletonVaultState::Running, *state);
 
-  stateCheck(SingletonVaultState::Running);
-
-  if (UNLIKELY(registrationComplete_)) {
+  if (UNLIKELY(state->registrationComplete)) {
     LOG(ERROR) << "Registering singleton after registrationComplete().";
   }
 
-  RWSpinLock::ReadHolder rhMutex(&mutex_);
-  CHECK_THROW(singletons_.find(entry->type()) == singletons_.end(),
-              std::logic_error);
-
-  RWSpinLock::UpgradedHolder wh(&mutex_);
-  singletons_[entry->type()] = entry;
+  auto singletons = singletons_.wlock();
+  CHECK_THROW(
+      singletons->emplace(entry->type(), entry).second, std::logic_error);
 }
 
 void SingletonVault::addEagerInitSingleton(detail::SingletonHolderBase* entry) {
-  RWSpinLock::ReadHolder rh(&stateMutex_);
-
-  stateCheck(SingletonVaultState::Running);
+  auto state = state_.rlock();
+  stateCheck(SingletonVaultState::Running, *state);
 
-  if (UNLIKELY(registrationComplete_)) {
+  if (UNLIKELY(state->registrationComplete)) {
     LOG(ERROR) << "Registering for eager-load after registrationComplete().";
   }
 
-  RWSpinLock::ReadHolder rhMutex(&mutex_);
-  CHECK_THROW(singletons_.find(entry->type()) != singletons_.end(),
-              std::logic_error);
+  CHECK_THROW(singletons_.rlock()->count(entry->type()), std::logic_error);
 
-  RWSpinLock::UpgradedHolder wh(&mutex_);
-  eagerInitSingletons_.insert(entry);
+  auto eagerInitSingletons = eagerInitSingletons_.wlock();
+  eagerInitSingletons->insert(entry);
 }
 
 void SingletonVault::registrationComplete() {
   std::atexit([](){ SingletonVault::singleton()->destroyInstances(); });
 
-  RWSpinLock::WriteHolder wh(&stateMutex_);
-
-  stateCheck(SingletonVaultState::Running);
+  auto state = state_.wlock();
+  stateCheck(SingletonVaultState::Running, *state);
 
+  auto singletons = singletons_.rlock();
   if (type_ == Type::Strict) {
-    for (const auto& p : singletons_) {
+    for (const auto& p : *singletons) {
       if (p.second->hasLiveInstance()) {
         throw std::runtime_error(
             "Singleton created before registration was complete.");
@@ -122,38 +115,37 @@ void SingletonVault::registrationComplete() {
     }
   }
 
-  registrationComplete_ = true;
+  state->registrationComplete = true;
 }
 
 void SingletonVault::doEagerInit() {
-  std::unordered_set<detail::SingletonHolderBase*> singletonSet;
   {
-    RWSpinLock::ReadHolder rh(&stateMutex_);
-    stateCheck(SingletonVaultState::Running);
-    if (UNLIKELY(!registrationComplete_)) {
+    auto state = state_.rlock();
+    stateCheck(SingletonVaultState::Running, *state);
+    if (UNLIKELY(!state->registrationComplete)) {
       throw std::logic_error("registrationComplete() not yet called");
     }
-    singletonSet = eagerInitSingletons_; // copy set of pointers
   }
 
-  for (auto *single : singletonSet) {
+  auto eagerInitSingletons = eagerInitSingletons_.rlock();
+  for (auto* single : *eagerInitSingletons) {
     single->createInstance();
   }
 }
 
 void SingletonVault::doEagerInitVia(Executor& exe, folly::Baton<>* done) {
-  std::unordered_set<detail::SingletonHolderBase*> singletonSet;
   {
-    RWSpinLock::ReadHolder rh(&stateMutex_);
-    stateCheck(SingletonVaultState::Running);
-    if (UNLIKELY(!registrationComplete_)) {
+    auto state = state_.rlock();
+    stateCheck(SingletonVaultState::Running, *state);
+    if (UNLIKELY(!state->registrationComplete)) {
       throw std::logic_error("registrationComplete() not yet called");
     }
-    singletonSet = eagerInitSingletons_; // copy set of pointers
   }
 
-  auto countdown = std::make_shared<std::atomic<size_t>>(singletonSet.size());
-  for (auto* single : singletonSet) {
+  auto eagerInitSingletons = eagerInitSingletons_.rlock();
+  auto countdown =
+      std::make_shared<std::atomic<size_t>>(eagerInitSingletons->size());
+  for (auto* single : *eagerInitSingletons) {
     // countdown is retained by shared_ptr, and will be alive until last lambda
     // is done.  notifyBaton is provided by the caller, and expected to remain
     // present (if it's non-nullptr).  singletonSet can go out of scope but
@@ -179,36 +171,35 @@ void SingletonVault::doEagerInitVia(Executor& exe, folly::Baton<>* done) {
 }
 
 void SingletonVault::destroyInstances() {
-  RWSpinLock::WriteHolder state_wh(&stateMutex_);
-
-  if (state_ == SingletonVaultState::Quiescing) {
+  auto stateW = state_.wlock();
+  if (stateW->state == SingletonVaultState::Quiescing) {
     return;
   }
-  state_ = SingletonVaultState::Quiescing;
-
-  RWSpinLock::ReadHolder state_rh(std::move(state_wh));
+  stateW->state = SingletonVaultState::Quiescing;
 
+  auto stateR = stateW.moveFromWriteToRead();
   {
-    RWSpinLock::ReadHolder rh(&mutex_);
+    auto singletons = singletons_.rlock();
+    auto creationOrder = creationOrder_.rlock();
 
-    CHECK_GE(singletons_.size(), creation_order_.size());
+    CHECK_GE(singletons->size(), creationOrder->size());
 
     // Release all ReadMostlyMainPtrs at once
     {
       ReadMostlyMainPtrDeleter<> deleter;
-      for (auto& singleton_type : creation_order_) {
-        singletons_[singleton_type]->preDestroyInstance(deleter);
+      for (auto& singleton_type : *creationOrder) {
+        singletons->at(singleton_type)->preDestroyInstance(deleter);
       }
     }
 
-    for (auto type_iter = creation_order_.rbegin();
-         type_iter != creation_order_.rend();
+    for (auto type_iter = creationOrder->rbegin();
+         type_iter != creationOrder->rend();
          ++type_iter) {
-      singletons_[*type_iter]->destroyInstance();
+      singletons->at(*type_iter)->destroyInstance();
     }
 
-    for (auto& singleton_type: creation_order_) {
-      auto singleton = singletons_[singleton_type];
+    for (auto& singleton_type : *creationOrder) {
+      auto singleton = singletons->at(singleton_type);
       if (!singleton->hasLiveInstance()) {
         continue;
       }
@@ -218,17 +209,17 @@ void SingletonVault::destroyInstances() {
   }
 
   {
-    RWSpinLock::WriteHolder wh(&mutex_);
-    creation_order_.clear();
+    auto creationOrder = creationOrder_.wlock();
+    creationOrder->clear();
   }
 }
 
 void SingletonVault::reenableInstances() {
-  RWSpinLock::WriteHolder state_wh(&stateMutex_);
+  auto state = state_.wlock();
 
-  stateCheck(SingletonVaultState::Quiescing);
+  stateCheck(SingletonVaultState::Quiescing, *state);
 
-  state_ = SingletonVaultState::Running;
+  state->state = SingletonVaultState::Running;
 }
 
 void SingletonVault::scheduleDestroyInstances() {
diff --git a/folly/Singleton.h b/folly/Singleton.h
index 126efc97..c377f3ad 100644
--- a/folly/Singleton.h
+++ b/folly/Singleton.h
@@ -112,14 +112,15 @@
 
 #pragma once
 #include <folly/Baton.h>
+#include <folly/Demangle.h>
 #include <folly/Exception.h>
+#include <folly/Executor.h>
 #include <folly/Hash.h>
 #include <folly/Memory.h>
 #include <folly/RWSpinLock.h>
-#include <folly/Demangle.h>
-#include <folly/Executor.h>
-#include <folly/experimental/ReadMostlySharedPtr.h>
+#include <folly/Synchronized.h>
 #include <folly/detail/StaticSingletonManager.h>
+#include <folly/experimental/ReadMostlySharedPtr.h>
 
 #include <algorithm>
 #include <atomic>
@@ -409,9 +410,7 @@ class SingletonVault {
 
   // For testing; how many registered and living singletons we have.
   size_t registeredSingletonCount() const {
-    RWSpinLock::ReadHolder rh(&mutex_);
-
-    return singletons_.size();
+    return singletons_.rlock()->size();
   }
 
   /**
@@ -421,10 +420,10 @@ class SingletonVault {
   bool eagerInitComplete() const;
 
   size_t livingSingletonCount() const {
-    RWSpinLock::ReadHolder rh(&mutex_);
+    auto singletons = singletons_.rlock();
 
     size_t ret = 0;
-    for (const auto& p : singletons_) {
+    for (const auto& p : *singletons) {
       if (p.second->hasLiveInstance()) {
         ++ret;
       }
@@ -470,13 +469,20 @@ class SingletonVault {
     Quiescing,
   };
 
+  struct State {
+    SingletonVaultState state{SingletonVaultState::Running};
+    bool registrationComplete{false};
+  };
+
   // Each singleton in the vault can be in two states: dead
   // (registered but never created), living (CreateFunc returned an instance).
 
-  void stateCheck(SingletonVaultState expected,
-                  const char* msg="Unexpected singleton state change") {
-    if (expected != state_) {
-        throw std::logic_error(msg);
+  static void stateCheck(
+      SingletonVaultState expected,
+      const State& state,
+      const char* msg = "Unexpected singleton state change") {
+    if (expected != state.state) {
+      throw std::logic_error(msg);
     }
   }
 
@@ -496,14 +502,13 @@ class SingletonVault {
   typedef std::unordered_map<detail::TypeDescriptor,
                              detail::SingletonHolderBase*,
                              detail::TypeDescriptorHasher> SingletonMap;
+  folly::Synchronized<SingletonMap> singletons_;
+  folly::Synchronized<std::unordered_set<detail::SingletonHolderBase*>>
+      eagerInitSingletons_;
+  folly::Synchronized<std::vector<detail::TypeDescriptor>> creationOrder_;
+
+  folly::Synchronized<State> state_;
 
-  mutable folly::RWSpinLock mutex_;
-  SingletonMap singletons_;
-  std::unordered_set<detail::SingletonHolderBase*> eagerInitSingletons_;
-  std::vector<detail::TypeDescriptor> creation_order_;
-  SingletonVaultState state_{SingletonVaultState::Running};
-  bool registrationComplete_{false};
-  folly::RWSpinLock stateMutex_;
   Type type_{Type::Relaxed};
 };
 
-- 
2.34.1