Allow accept callbacks to be short-circuited in primary event-base
authorPetr Lapukhov <petr@fb.com>
Wed, 17 Aug 2016 12:39:00 +0000 (05:39 -0700)
committerFacebook Github Bot 3 <facebook-github-bot-3-bot@fb.com>
Wed, 17 Aug 2016 12:53:28 +0000 (05:53 -0700)
Summary:
It looks like we were effectively avoiding short-circuiting callbacks submitted for execution in primary event-base (evb == nulptr). The check was there, but it was never effective, since on `addAcceptCallback` we would mask the `nullptr` with our event base pointer.

I see two ways to fix that: either modify the check

    if (info->eventBase == nullptr) { ...} on line 834

to compare to the presently attached event base, or store `eventBase = nullptr` into callbacks_ list (CallbackInfo struct). The second approach requires more changes (implemented here) but allows the caller to still submit callbacks for execution via notification queue event in primary event base by supplying eventBase parameter != nullptr in addAcceptCallback. I therefore chose the second approach.

The existing unit-tests needed modification to avoid using the "broken" nullptr semantics (most cases were assuming it would be using notification queue signaling). I quickly looked at fbcode, and it looks like we only have a few cases of addAcceptCallback() with nullptr, the unit-tests for those are passing.

NOTE: The removeAcceptCallback() semantics is different with regards to eventBase; nullptr here means "scan all callbacks regardless of event-base they belong to".

Reviewed By: djwatson

Differential Revision: D3714697

fbshipit-source-id: 2362bcff86a7e0604914b1cb7f1471fe4d03e78e

folly/io/async/AsyncServerSocket.cpp
folly/io/async/test/AsyncSocketTest2.cpp

index cbabc6f429ea2881b311003086e281a8f0136997..99e1feb1da5b7cebc4ca887a161bfe149792b85d 100644 (file)
@@ -125,7 +125,7 @@ class AsyncServerSocket::BackoffTimeout : public AsyncTimeout {
  public:
   // Disallow copy, move, and default constructors.
   BackoffTimeout(BackoffTimeout&&) = delete;
-  BackoffTimeout(AsyncServerSocket* socket)
+  explicit BackoffTimeout(AsyncServerSocket* socket)
       : AsyncTimeout(socket->getEventBase()), socket_(socket) {}
 
   void timeoutExpired() noexcept override { socket_->backoffTimeoutExpired(); }
@@ -219,7 +219,14 @@ int AsyncServerSocket::stopAccepting(int shutdownFlags) {
   for (std::vector<CallbackInfo>::iterator it = callbacksCopy.begin();
        it != callbacksCopy.end();
        ++it) {
-    it->consumer->stop(it->eventBase, it->callback);
+    // consumer may not be set if we are running in primary event base
+    if (it->consumer) {
+      DCHECK(it->eventBase);
+      it->consumer->stop(it->eventBase, it->callback);
+    } else {
+      DCHECK(it->callback);
+      it->callback->acceptStopped();
+    }
   }
 
   return result;
@@ -513,12 +520,23 @@ void AsyncServerSocket::addAcceptCallback(AcceptCallback *callback,
   // start accepting once the callback is installed.
   bool runStartAccepting = accepting_ && callbacks_.empty();
 
+  callbacks_.emplace_back(callback, eventBase);
+
+  SCOPE_SUCCESS {
+    // If this is the first accept callback and we are supposed to be accepting,
+    // start accepting.
+    if (runStartAccepting) {
+      startAccepting();
+    }
+  };
+
   if (!eventBase) {
-    eventBase = eventBase_; // Run in AsyncServerSocket's eventbase
+    // Run in AsyncServerSocket's eventbase; notify that we are
+    // starting to accept connections
+    callback->acceptStarted();
+    return;
   }
 
-  callbacks_.emplace_back(callback, eventBase);
-
   // Start the remote acceptor.
   //
   // It would be nice if we could avoid starting the remote acceptor if
@@ -538,12 +556,6 @@ void AsyncServerSocket::addAcceptCallback(AcceptCallback *callback,
     throw;
   }
   callbacks_.back().consumer = acceptor;
-
-  // If this is the first accept callback and we are supposed to be accepting,
-  // start accepting.
-  if (runStartAccepting) {
-    startAccepting();
-  }
 }
 
 void AsyncServerSocket::removeAcceptCallback(AcceptCallback *callback,
@@ -590,7 +602,16 @@ void AsyncServerSocket::removeAcceptCallback(AcceptCallback *callback,
     }
   }
 
-  info.consumer->stop(info.eventBase, info.callback);
+  if (info.consumer) {
+    // consumer could be nullptr is we run callbacks in primary event
+    // base
+    DCHECK(info.eventBase);
+    info.consumer->stop(info.eventBase, info.callback);
+  } else {
+    // callback invoked in the primary event base, just call directly
+    DCHECK(info.callback);
+    callback->acceptStopped();
+  }
 
   // If we are supposed to be accepting but the last accept callback
   // was removed, unregister for events until a callback is added.
index 35958c32fdcff3b672b034a4a738ff888594f721..4fa0a02ea37d9172259d5dd4b820e3e8c28e50c4 100644 (file)
@@ -1733,12 +1733,12 @@ TEST(AsyncSocketTest, ServerAcceptOptions) {
   TestAcceptCallback acceptCallback;
   acceptCallback.setConnectionAcceptedFn(
       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
-        serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
+        serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
       });
   acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
-    serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
+    serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
   });
-  serverSocket->addAcceptCallback(&acceptCallback, nullptr);
+  serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
   serverSocket->startAccepting();
 
   // Connect to the server socket
@@ -1850,13 +1850,13 @@ TEST(AsyncSocketTest, RemoveAcceptCallback) {
         serverSocket->removeAcceptCallback(&cb7, nullptr);
       });
 
-  serverSocket->addAcceptCallback(&cb1, nullptr);
-  serverSocket->addAcceptCallback(&cb2, nullptr);
-  serverSocket->addAcceptCallback(&cb3, nullptr);
-  serverSocket->addAcceptCallback(&cb4, nullptr);
-  serverSocket->addAcceptCallback(&cb5, nullptr);
-  serverSocket->addAcceptCallback(&cb6, nullptr);
-  serverSocket->addAcceptCallback(&cb7, nullptr);
+  serverSocket->addAcceptCallback(&cb1, &eventBase);
+  serverSocket->addAcceptCallback(&cb2, &eventBase);
+  serverSocket->addAcceptCallback(&cb3, &eventBase);
+  serverSocket->addAcceptCallback(&cb4, &eventBase);
+  serverSocket->addAcceptCallback(&cb5, &eventBase);
+  serverSocket->addAcceptCallback(&cb6, &eventBase);
+  serverSocket->addAcceptCallback(&cb7, &eventBase);
   serverSocket->startAccepting();
 
   // Make several connections to the socket
@@ -1959,14 +1959,14 @@ TEST(AsyncSocketTest, OtherThreadAcceptCallback) {
   cb1.setConnectionAcceptedFn(
       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
         CHECK_EQ(thread_id, std::this_thread::get_id());
-        serverSocket->removeAcceptCallback(&cb1, nullptr);
+        serverSocket->removeAcceptCallback(&cb1, &eventBase);
       });
   cb1.setAcceptStoppedFn([&](){
     CHECK_EQ(thread_id, std::this_thread::get_id());
   });
 
   // Test having callbacks remove other callbacks before them on the list,
-  serverSocket->addAcceptCallback(&cb1, nullptr);
+  serverSocket->addAcceptCallback(&cb1, &eventBase);
   serverSocket->startAccepting();
 
   // Make several connections to the socket
@@ -1999,20 +1999,22 @@ TEST(AsyncSocketTest, OtherThreadAcceptCallback) {
 }
 
 void serverSocketSanityTest(AsyncServerSocket* serverSocket) {
+  EventBase* eventBase = serverSocket->getEventBase();
+  CHECK(eventBase);
+
   // Add a callback to accept one connection then stop accepting
   TestAcceptCallback acceptCallback;
   acceptCallback.setConnectionAcceptedFn(
       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
-        serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
+        serverSocket->removeAcceptCallback(&acceptCallback, eventBase);
       });
   acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
-    serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
+    serverSocket->removeAcceptCallback(&acceptCallback, eventBase);
   });
-  serverSocket->addAcceptCallback(&acceptCallback, nullptr);
+  serverSocket->addAcceptCallback(&acceptCallback, eventBase);
   serverSocket->startAccepting();
 
   // Connect to the server socket
-  EventBase* eventBase = serverSocket->getEventBase();
   folly::SocketAddress serverAddress;
   serverSocket->getAddress(&serverAddress);
   AsyncSocket::UniquePtr socket(new AsyncSocket(eventBase, serverAddress));
@@ -2181,12 +2183,12 @@ TEST(AsyncSocketTest, UnixDomainSocketTest) {
   TestAcceptCallback acceptCallback;
   acceptCallback.setConnectionAcceptedFn(
       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
-        serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
+        serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
       });
   acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
-    serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
+    serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
   });
-  serverSocket->addAcceptCallback(&acceptCallback, nullptr);
+  serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
   serverSocket->startAccepting();
 
   // Connect to the server socket
@@ -2232,7 +2234,7 @@ TEST(AsyncSocketTest, ConnectionEventCallbackDefault) {
   acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
     serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
   });
-  serverSocket->addAcceptCallback(&acceptCallback, nullptr);
+  serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
   serverSocket->startAccepting();
 
   // Connect to the server socket
@@ -2253,6 +2255,61 @@ TEST(AsyncSocketTest, ConnectionEventCallbackDefault) {
   ASSERT_EQ(connectionEventCallback.getBackoffError(), 0);
 }
 
+TEST(AsyncSocketTest, CallbackInPrimaryEventBase) {
+  EventBase eventBase;
+  TestConnectionEventCallback connectionEventCallback;
+
+  // Create a server socket
+  std::shared_ptr<AsyncServerSocket> serverSocket(
+      AsyncServerSocket::newSocket(&eventBase));
+  serverSocket->setConnectionEventCallback(&connectionEventCallback);
+  serverSocket->bind(0);
+  serverSocket->listen(16);
+  folly::SocketAddress serverAddress;
+  serverSocket->getAddress(&serverAddress);
+
+  // Add a callback to accept one connection then stop the loop
+  TestAcceptCallback acceptCallback;
+  acceptCallback.setConnectionAcceptedFn(
+      [&](int /* fd */, const folly::SocketAddress& /* addr */) {
+        serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
+      });
+  acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
+    serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
+  });
+  bool acceptStartedFlag{false};
+  acceptCallback.setAcceptStartedFn([&acceptStartedFlag](){
+    acceptStartedFlag = true;
+  });
+  bool acceptStoppedFlag{false};
+  acceptCallback.setAcceptStoppedFn([&acceptStoppedFlag](){
+    acceptStoppedFlag = true;
+  });
+  serverSocket->addAcceptCallback(&acceptCallback, nullptr);
+  serverSocket->startAccepting();
+
+  // Connect to the server socket
+  std::shared_ptr<AsyncSocket> socket(
+      AsyncSocket::newSocket(&eventBase, serverAddress));
+
+  eventBase.loop();
+
+  ASSERT_TRUE(acceptStartedFlag);
+  ASSERT_TRUE(acceptStoppedFlag);
+  // Validate the connection event counters
+  ASSERT_EQ(connectionEventCallback.getConnectionAccepted(), 1);
+  ASSERT_EQ(connectionEventCallback.getConnectionAcceptedError(), 0);
+  ASSERT_EQ(connectionEventCallback.getConnectionDropped(), 0);
+  ASSERT_EQ(
+      connectionEventCallback.getConnectionEnqueuedForAcceptCallback(), 0);
+  ASSERT_EQ(connectionEventCallback.getConnectionDequeuedByAcceptCallback(), 0);
+  ASSERT_EQ(connectionEventCallback.getBackoffStarted(), 0);
+  ASSERT_EQ(connectionEventCallback.getBackoffEnded(), 0);
+  ASSERT_EQ(connectionEventCallback.getBackoffError(), 0);
+}
+
+
+
 /**
  * Test AsyncServerSocket::getNumPendingMessagesInQueue()
  */