Fix passing RequestContext to executor thread
authorAlex Landau <alandau@fb.com>
Wed, 24 Jun 2015 23:20:08 +0000 (16:20 -0700)
committerSara Golemon <sgolemon@fb.com>
Fri, 26 Jun 2015 18:45:41 +0000 (11:45 -0700)
Summary: If x->add() executes its lambda on a different thread and
doesn't pass the context on its own, the callback wouldn't
have the correct context set.

Reviewed By: @djwatson

Differential Revision: D2189318

folly/futures/detail/Core.h
folly/futures/test/FutureTest.cpp

index d21d74ce419a7895e02b6acef4dc838ca4d9a95b..f00188622d5f473c6cbdeb071d2adb4e9f58eeab 100644 (file)
@@ -313,8 +313,6 @@ class Core {
   }
 
   void doCallback() {
-    RequestContext::setContext(context_);
-
     Executor* x = executor_;
     int8_t priority;
     if (x) {
@@ -333,20 +331,24 @@ class Core {
         if (LIKELY(x->getNumPriorities() == 1)) {
           x->add([this]() mutable {
             SCOPE_EXIT { detachOne(); };
+            RequestContext::setContext(context_);
             callback_(std::move(*result_));
           });
         } else {
           x->addWithPriority([this]() mutable {
             SCOPE_EXIT { detachOne(); };
+            RequestContext::setContext(context_);
             callback_(std::move(*result_));
           }, priority);
         }
       } catch (...) {
         --attached_; // Account for extra ++attached_ before try
+        RequestContext::setContext(context_);
         result_ = Try<T>(exception_wrapper(std::current_exception()));
         callback_(std::move(*result_));
       }
     } else {
+      RequestContext::setContext(context_);
       callback_(std::move(*result_));
     }
   }
index 723b358e7052bc64d1c9192b0b990d555a63fa59..d6e1990c859b1c6a5df10392ac0ae7382dd8d5f6 100644 (file)
@@ -680,3 +680,57 @@ TEST(Future, thenDynamic) {
   p.setValue(2);
   EXPECT_EQ(f.get(), 5);
 }
+
+TEST(Future, RequestContext) {
+  class NewThreadExecutor : public Executor {
+   public:
+    ~NewThreadExecutor() override {
+      std::for_each(v_.begin(), v_.end(), [](std::thread& t){ t.join(); });
+    }
+    void add(Func f) override {
+      if (throwsOnAdd_) { throw std::exception(); }
+      v_.emplace_back(std::move(f));
+    }
+    void addWithPriority(Func f, int8_t prio) override { add(std::move(f)); }
+    uint8_t getNumPriorities() const override { return numPriorities_; }
+
+    void setHandlesPriorities() { numPriorities_ = 2; }
+    void setThrowsOnAdd() { throwsOnAdd_ = true; }
+   private:
+    std::vector<std::thread> v_;
+    uint8_t numPriorities_ = 1;
+    bool throwsOnAdd_ = false;
+  };
+
+  struct MyRequestData : RequestData {
+    MyRequestData(bool value = false) : value(value) {}
+    bool value;
+  };
+
+  NewThreadExecutor e;
+  RequestContext::create();
+  RequestContext::get()->setContextData("key",
+      folly::make_unique<MyRequestData>(true));
+  auto checker = [](int lineno) {
+    return [lineno](Try<int>&& t) {
+      auto d = static_cast<MyRequestData*>(
+        RequestContext::get()->getContextData("key"));
+      EXPECT_TRUE(d && d->value) << "on line " << lineno;
+    };
+  };
+
+  makeFuture(1).via(&e).then(checker(__LINE__));
+
+  e.setHandlesPriorities();
+  makeFuture(2).via(&e).then(checker(__LINE__));
+
+  Promise<int> p1, p2;
+  p1.getFuture().then(checker(__LINE__));
+
+  e.setThrowsOnAdd();
+  p2.getFuture().via(&e).then(checker(__LINE__));
+
+  RequestContext::create();
+  p1.setValue(3);
+  p2.setValue(4);
+}