Allow accept callbacks to be short-circuited in primary event-base
[folly.git] / folly / io / async / test / AsyncSocketTest2.cpp
index 35958c32fdcff3b672b034a4a738ff888594f721..4fa0a02ea37d9172259d5dd4b820e3e8c28e50c4 100644 (file)
@@ -1733,12 +1733,12 @@ TEST(AsyncSocketTest, ServerAcceptOptions) {
   TestAcceptCallback acceptCallback;
   acceptCallback.setConnectionAcceptedFn(
       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
-        serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
+        serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
       });
   acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
-    serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
+    serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
   });
-  serverSocket->addAcceptCallback(&acceptCallback, nullptr);
+  serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
   serverSocket->startAccepting();
 
   // Connect to the server socket
@@ -1850,13 +1850,13 @@ TEST(AsyncSocketTest, RemoveAcceptCallback) {
         serverSocket->removeAcceptCallback(&cb7, nullptr);
       });
 
-  serverSocket->addAcceptCallback(&cb1, nullptr);
-  serverSocket->addAcceptCallback(&cb2, nullptr);
-  serverSocket->addAcceptCallback(&cb3, nullptr);
-  serverSocket->addAcceptCallback(&cb4, nullptr);
-  serverSocket->addAcceptCallback(&cb5, nullptr);
-  serverSocket->addAcceptCallback(&cb6, nullptr);
-  serverSocket->addAcceptCallback(&cb7, nullptr);
+  serverSocket->addAcceptCallback(&cb1, &eventBase);
+  serverSocket->addAcceptCallback(&cb2, &eventBase);
+  serverSocket->addAcceptCallback(&cb3, &eventBase);
+  serverSocket->addAcceptCallback(&cb4, &eventBase);
+  serverSocket->addAcceptCallback(&cb5, &eventBase);
+  serverSocket->addAcceptCallback(&cb6, &eventBase);
+  serverSocket->addAcceptCallback(&cb7, &eventBase);
   serverSocket->startAccepting();
 
   // Make several connections to the socket
@@ -1959,14 +1959,14 @@ TEST(AsyncSocketTest, OtherThreadAcceptCallback) {
   cb1.setConnectionAcceptedFn(
       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
         CHECK_EQ(thread_id, std::this_thread::get_id());
-        serverSocket->removeAcceptCallback(&cb1, nullptr);
+        serverSocket->removeAcceptCallback(&cb1, &eventBase);
       });
   cb1.setAcceptStoppedFn([&](){
     CHECK_EQ(thread_id, std::this_thread::get_id());
   });
 
   // Test having callbacks remove other callbacks before them on the list,
-  serverSocket->addAcceptCallback(&cb1, nullptr);
+  serverSocket->addAcceptCallback(&cb1, &eventBase);
   serverSocket->startAccepting();
 
   // Make several connections to the socket
@@ -1999,20 +1999,22 @@ TEST(AsyncSocketTest, OtherThreadAcceptCallback) {
 }
 
 void serverSocketSanityTest(AsyncServerSocket* serverSocket) {
+  EventBase* eventBase = serverSocket->getEventBase();
+  CHECK(eventBase);
+
   // Add a callback to accept one connection then stop accepting
   TestAcceptCallback acceptCallback;
   acceptCallback.setConnectionAcceptedFn(
       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
-        serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
+        serverSocket->removeAcceptCallback(&acceptCallback, eventBase);
       });
   acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
-    serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
+    serverSocket->removeAcceptCallback(&acceptCallback, eventBase);
   });
-  serverSocket->addAcceptCallback(&acceptCallback, nullptr);
+  serverSocket->addAcceptCallback(&acceptCallback, eventBase);
   serverSocket->startAccepting();
 
   // Connect to the server socket
-  EventBase* eventBase = serverSocket->getEventBase();
   folly::SocketAddress serverAddress;
   serverSocket->getAddress(&serverAddress);
   AsyncSocket::UniquePtr socket(new AsyncSocket(eventBase, serverAddress));
@@ -2181,12 +2183,12 @@ TEST(AsyncSocketTest, UnixDomainSocketTest) {
   TestAcceptCallback acceptCallback;
   acceptCallback.setConnectionAcceptedFn(
       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
-        serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
+        serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
       });
   acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
-    serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
+    serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
   });
-  serverSocket->addAcceptCallback(&acceptCallback, nullptr);
+  serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
   serverSocket->startAccepting();
 
   // Connect to the server socket
@@ -2232,7 +2234,7 @@ TEST(AsyncSocketTest, ConnectionEventCallbackDefault) {
   acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
     serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
   });
-  serverSocket->addAcceptCallback(&acceptCallback, nullptr);
+  serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
   serverSocket->startAccepting();
 
   // Connect to the server socket
@@ -2253,6 +2255,61 @@ TEST(AsyncSocketTest, ConnectionEventCallbackDefault) {
   ASSERT_EQ(connectionEventCallback.getBackoffError(), 0);
 }
 
+TEST(AsyncSocketTest, CallbackInPrimaryEventBase) {
+  EventBase eventBase;
+  TestConnectionEventCallback connectionEventCallback;
+
+  // Create a server socket
+  std::shared_ptr<AsyncServerSocket> serverSocket(
+      AsyncServerSocket::newSocket(&eventBase));
+  serverSocket->setConnectionEventCallback(&connectionEventCallback);
+  serverSocket->bind(0);
+  serverSocket->listen(16);
+  folly::SocketAddress serverAddress;
+  serverSocket->getAddress(&serverAddress);
+
+  // Add a callback to accept one connection then stop the loop
+  TestAcceptCallback acceptCallback;
+  acceptCallback.setConnectionAcceptedFn(
+      [&](int /* fd */, const folly::SocketAddress& /* addr */) {
+        serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
+      });
+  acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
+    serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
+  });
+  bool acceptStartedFlag{false};
+  acceptCallback.setAcceptStartedFn([&acceptStartedFlag](){
+    acceptStartedFlag = true;
+  });
+  bool acceptStoppedFlag{false};
+  acceptCallback.setAcceptStoppedFn([&acceptStoppedFlag](){
+    acceptStoppedFlag = true;
+  });
+  serverSocket->addAcceptCallback(&acceptCallback, nullptr);
+  serverSocket->startAccepting();
+
+  // Connect to the server socket
+  std::shared_ptr<AsyncSocket> socket(
+      AsyncSocket::newSocket(&eventBase, serverAddress));
+
+  eventBase.loop();
+
+  ASSERT_TRUE(acceptStartedFlag);
+  ASSERT_TRUE(acceptStoppedFlag);
+  // Validate the connection event counters
+  ASSERT_EQ(connectionEventCallback.getConnectionAccepted(), 1);
+  ASSERT_EQ(connectionEventCallback.getConnectionAcceptedError(), 0);
+  ASSERT_EQ(connectionEventCallback.getConnectionDropped(), 0);
+  ASSERT_EQ(
+      connectionEventCallback.getConnectionEnqueuedForAcceptCallback(), 0);
+  ASSERT_EQ(connectionEventCallback.getConnectionDequeuedByAcceptCallback(), 0);
+  ASSERT_EQ(connectionEventCallback.getBackoffStarted(), 0);
+  ASSERT_EQ(connectionEventCallback.getBackoffEnded(), 0);
+  ASSERT_EQ(connectionEventCallback.getBackoffError(), 0);
+}
+
+
+
 /**
  * Test AsyncServerSocket::getNumPendingMessagesInQueue()
  */