Make RequestContext provider overridable in order to save cost of setContext() on...
[folly.git] / folly / io / async / test / RequestContextTest.cpp
index 750ff7fac9bd0c3c2ad7c2f4e737aaf303d53e37..2ae4123b086fb687cfef49b0cbd51366832a5509 100644 (file)
@@ -77,6 +77,49 @@ TEST(RequestContext, SimpleTest) {
   EXPECT_TRUE(nullptr != RequestContext::get());
 }
 
+TEST(RequestContext, nonDefaultContextsAreThreadLocal) {
+  RequestContext* ctx1 = nullptr;
+  RequestContext* ctx2 = nullptr;
+
+  std::vector<std::thread> ts;
+  for (size_t i = 0; i < 2; ++i) {
+    auto*& ctx = (i == 0 ? ctx1 : ctx2);
+    ts.emplace_back([&ctx]() {
+      RequestContext::create();
+      ctx = RequestContext::get();
+    });
+  }
+  for (auto& t : ts) {
+    t.join();
+  }
+
+  EXPECT_NE(nullptr, ctx1);
+  EXPECT_NE(nullptr, ctx2);
+  EXPECT_NE(ctx1, ctx2);
+}
+
+TEST(RequestContext, customRequestContextProvider) {
+  auto customContext = std::make_shared<RequestContext>();
+  auto customProvider = [&customContext]() -> std::shared_ptr<RequestContext>& {
+    return customContext;
+  };
+
+  auto* const originalContext = RequestContext::get();
+  EXPECT_NE(nullptr, originalContext);
+
+  // Install new RequestContext provider
+  auto originalProvider =
+      RequestContext::setRequestContextProvider(std::move(customProvider));
+
+  auto* const newContext = RequestContext::get();
+  EXPECT_EQ(customContext.get(), newContext);
+  EXPECT_NE(originalContext, newContext);
+
+  // Restore original RequestContext provider
+  RequestContext::setRequestContextProvider(std::move(originalProvider));
+  EXPECT_EQ(originalContext, RequestContext::get());
+}
+
 TEST(RequestContext, setIfAbsentTest) {
   EXPECT_TRUE(RequestContext::get() != nullptr);