From 9b8488e85fd3018975ce47f80d332e906cbf35fe Mon Sep 17 00:00:00 2001 From: Petr Lapukhov Date: Wed, 17 Aug 2016 05:39:00 -0700 Subject: [PATCH] Allow accept callbacks to be short-circuited in primary event-base Summary: It looks like we were effectively avoiding short-circuiting callbacks submitted for execution in primary event-base (evb == nulptr). The check was there, but it was never effective, since on `addAcceptCallback` we would mask the `nullptr` with our event base pointer. I see two ways to fix that: either modify the check if (info->eventBase == nullptr) { ...} on line 834 to compare to the presently attached event base, or store `eventBase = nullptr` into callbacks_ list (CallbackInfo struct). The second approach requires more changes (implemented here) but allows the caller to still submit callbacks for execution via notification queue event in primary event base by supplying eventBase parameter != nullptr in addAcceptCallback. I therefore chose the second approach. The existing unit-tests needed modification to avoid using the "broken" nullptr semantics (most cases were assuming it would be using notification queue signaling). I quickly looked at fbcode, and it looks like we only have a few cases of addAcceptCallback() with nullptr, the unit-tests for those are passing. NOTE: The removeAcceptCallback() semantics is different with regards to eventBase; nullptr here means "scan all callbacks regardless of event-base they belong to". Reviewed By: djwatson Differential Revision: D3714697 fbshipit-source-id: 2362bcff86a7e0604914b1cb7f1471fe4d03e78e --- folly/io/async/AsyncServerSocket.cpp | 45 ++++++++--- folly/io/async/test/AsyncSocketTest2.cpp | 97 +++++++++++++++++++----- 2 files changed, 110 insertions(+), 32 deletions(-) diff --git a/folly/io/async/AsyncServerSocket.cpp b/folly/io/async/AsyncServerSocket.cpp index cbabc6f4..99e1feb1 100644 --- a/folly/io/async/AsyncServerSocket.cpp +++ b/folly/io/async/AsyncServerSocket.cpp @@ -125,7 +125,7 @@ class AsyncServerSocket::BackoffTimeout : public AsyncTimeout { public: // Disallow copy, move, and default constructors. BackoffTimeout(BackoffTimeout&&) = delete; - BackoffTimeout(AsyncServerSocket* socket) + explicit BackoffTimeout(AsyncServerSocket* socket) : AsyncTimeout(socket->getEventBase()), socket_(socket) {} void timeoutExpired() noexcept override { socket_->backoffTimeoutExpired(); } @@ -219,7 +219,14 @@ int AsyncServerSocket::stopAccepting(int shutdownFlags) { for (std::vector::iterator it = callbacksCopy.begin(); it != callbacksCopy.end(); ++it) { - it->consumer->stop(it->eventBase, it->callback); + // consumer may not be set if we are running in primary event base + if (it->consumer) { + DCHECK(it->eventBase); + it->consumer->stop(it->eventBase, it->callback); + } else { + DCHECK(it->callback); + it->callback->acceptStopped(); + } } return result; @@ -513,12 +520,23 @@ void AsyncServerSocket::addAcceptCallback(AcceptCallback *callback, // start accepting once the callback is installed. bool runStartAccepting = accepting_ && callbacks_.empty(); + callbacks_.emplace_back(callback, eventBase); + + SCOPE_SUCCESS { + // If this is the first accept callback and we are supposed to be accepting, + // start accepting. + if (runStartAccepting) { + startAccepting(); + } + }; + if (!eventBase) { - eventBase = eventBase_; // Run in AsyncServerSocket's eventbase + // Run in AsyncServerSocket's eventbase; notify that we are + // starting to accept connections + callback->acceptStarted(); + return; } - callbacks_.emplace_back(callback, eventBase); - // Start the remote acceptor. // // It would be nice if we could avoid starting the remote acceptor if @@ -538,12 +556,6 @@ void AsyncServerSocket::addAcceptCallback(AcceptCallback *callback, throw; } callbacks_.back().consumer = acceptor; - - // If this is the first accept callback and we are supposed to be accepting, - // start accepting. - if (runStartAccepting) { - startAccepting(); - } } void AsyncServerSocket::removeAcceptCallback(AcceptCallback *callback, @@ -590,7 +602,16 @@ void AsyncServerSocket::removeAcceptCallback(AcceptCallback *callback, } } - info.consumer->stop(info.eventBase, info.callback); + if (info.consumer) { + // consumer could be nullptr is we run callbacks in primary event + // base + DCHECK(info.eventBase); + info.consumer->stop(info.eventBase, info.callback); + } else { + // callback invoked in the primary event base, just call directly + DCHECK(info.callback); + callback->acceptStopped(); + } // If we are supposed to be accepting but the last accept callback // was removed, unregister for events until a callback is added. diff --git a/folly/io/async/test/AsyncSocketTest2.cpp b/folly/io/async/test/AsyncSocketTest2.cpp index 35958c32..4fa0a02e 100644 --- a/folly/io/async/test/AsyncSocketTest2.cpp +++ b/folly/io/async/test/AsyncSocketTest2.cpp @@ -1733,12 +1733,12 @@ TEST(AsyncSocketTest, ServerAcceptOptions) { TestAcceptCallback acceptCallback; acceptCallback.setConnectionAcceptedFn( [&](int /* fd */, const folly::SocketAddress& /* addr */) { - serverSocket->removeAcceptCallback(&acceptCallback, nullptr); + serverSocket->removeAcceptCallback(&acceptCallback, &eventBase); }); acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) { - serverSocket->removeAcceptCallback(&acceptCallback, nullptr); + serverSocket->removeAcceptCallback(&acceptCallback, &eventBase); }); - serverSocket->addAcceptCallback(&acceptCallback, nullptr); + serverSocket->addAcceptCallback(&acceptCallback, &eventBase); serverSocket->startAccepting(); // Connect to the server socket @@ -1850,13 +1850,13 @@ TEST(AsyncSocketTest, RemoveAcceptCallback) { serverSocket->removeAcceptCallback(&cb7, nullptr); }); - serverSocket->addAcceptCallback(&cb1, nullptr); - serverSocket->addAcceptCallback(&cb2, nullptr); - serverSocket->addAcceptCallback(&cb3, nullptr); - serverSocket->addAcceptCallback(&cb4, nullptr); - serverSocket->addAcceptCallback(&cb5, nullptr); - serverSocket->addAcceptCallback(&cb6, nullptr); - serverSocket->addAcceptCallback(&cb7, nullptr); + serverSocket->addAcceptCallback(&cb1, &eventBase); + serverSocket->addAcceptCallback(&cb2, &eventBase); + serverSocket->addAcceptCallback(&cb3, &eventBase); + serverSocket->addAcceptCallback(&cb4, &eventBase); + serverSocket->addAcceptCallback(&cb5, &eventBase); + serverSocket->addAcceptCallback(&cb6, &eventBase); + serverSocket->addAcceptCallback(&cb7, &eventBase); serverSocket->startAccepting(); // Make several connections to the socket @@ -1959,14 +1959,14 @@ TEST(AsyncSocketTest, OtherThreadAcceptCallback) { cb1.setConnectionAcceptedFn( [&](int /* fd */, const folly::SocketAddress& /* addr */) { CHECK_EQ(thread_id, std::this_thread::get_id()); - serverSocket->removeAcceptCallback(&cb1, nullptr); + serverSocket->removeAcceptCallback(&cb1, &eventBase); }); cb1.setAcceptStoppedFn([&](){ CHECK_EQ(thread_id, std::this_thread::get_id()); }); // Test having callbacks remove other callbacks before them on the list, - serverSocket->addAcceptCallback(&cb1, nullptr); + serverSocket->addAcceptCallback(&cb1, &eventBase); serverSocket->startAccepting(); // Make several connections to the socket @@ -1999,20 +1999,22 @@ TEST(AsyncSocketTest, OtherThreadAcceptCallback) { } void serverSocketSanityTest(AsyncServerSocket* serverSocket) { + EventBase* eventBase = serverSocket->getEventBase(); + CHECK(eventBase); + // Add a callback to accept one connection then stop accepting TestAcceptCallback acceptCallback; acceptCallback.setConnectionAcceptedFn( [&](int /* fd */, const folly::SocketAddress& /* addr */) { - serverSocket->removeAcceptCallback(&acceptCallback, nullptr); + serverSocket->removeAcceptCallback(&acceptCallback, eventBase); }); acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) { - serverSocket->removeAcceptCallback(&acceptCallback, nullptr); + serverSocket->removeAcceptCallback(&acceptCallback, eventBase); }); - serverSocket->addAcceptCallback(&acceptCallback, nullptr); + serverSocket->addAcceptCallback(&acceptCallback, eventBase); serverSocket->startAccepting(); // Connect to the server socket - EventBase* eventBase = serverSocket->getEventBase(); folly::SocketAddress serverAddress; serverSocket->getAddress(&serverAddress); AsyncSocket::UniquePtr socket(new AsyncSocket(eventBase, serverAddress)); @@ -2181,12 +2183,12 @@ TEST(AsyncSocketTest, UnixDomainSocketTest) { TestAcceptCallback acceptCallback; acceptCallback.setConnectionAcceptedFn( [&](int /* fd */, const folly::SocketAddress& /* addr */) { - serverSocket->removeAcceptCallback(&acceptCallback, nullptr); + serverSocket->removeAcceptCallback(&acceptCallback, &eventBase); }); acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) { - serverSocket->removeAcceptCallback(&acceptCallback, nullptr); + serverSocket->removeAcceptCallback(&acceptCallback, &eventBase); }); - serverSocket->addAcceptCallback(&acceptCallback, nullptr); + serverSocket->addAcceptCallback(&acceptCallback, &eventBase); serverSocket->startAccepting(); // Connect to the server socket @@ -2232,7 +2234,7 @@ TEST(AsyncSocketTest, ConnectionEventCallbackDefault) { acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) { serverSocket->removeAcceptCallback(&acceptCallback, nullptr); }); - serverSocket->addAcceptCallback(&acceptCallback, nullptr); + serverSocket->addAcceptCallback(&acceptCallback, &eventBase); serverSocket->startAccepting(); // Connect to the server socket @@ -2253,6 +2255,61 @@ TEST(AsyncSocketTest, ConnectionEventCallbackDefault) { ASSERT_EQ(connectionEventCallback.getBackoffError(), 0); } +TEST(AsyncSocketTest, CallbackInPrimaryEventBase) { + EventBase eventBase; + TestConnectionEventCallback connectionEventCallback; + + // Create a server socket + std::shared_ptr serverSocket( + AsyncServerSocket::newSocket(&eventBase)); + serverSocket->setConnectionEventCallback(&connectionEventCallback); + serverSocket->bind(0); + serverSocket->listen(16); + folly::SocketAddress serverAddress; + serverSocket->getAddress(&serverAddress); + + // Add a callback to accept one connection then stop the loop + TestAcceptCallback acceptCallback; + acceptCallback.setConnectionAcceptedFn( + [&](int /* fd */, const folly::SocketAddress& /* addr */) { + serverSocket->removeAcceptCallback(&acceptCallback, nullptr); + }); + acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) { + serverSocket->removeAcceptCallback(&acceptCallback, nullptr); + }); + bool acceptStartedFlag{false}; + acceptCallback.setAcceptStartedFn([&acceptStartedFlag](){ + acceptStartedFlag = true; + }); + bool acceptStoppedFlag{false}; + acceptCallback.setAcceptStoppedFn([&acceptStoppedFlag](){ + acceptStoppedFlag = true; + }); + serverSocket->addAcceptCallback(&acceptCallback, nullptr); + serverSocket->startAccepting(); + + // Connect to the server socket + std::shared_ptr socket( + AsyncSocket::newSocket(&eventBase, serverAddress)); + + eventBase.loop(); + + ASSERT_TRUE(acceptStartedFlag); + ASSERT_TRUE(acceptStoppedFlag); + // Validate the connection event counters + ASSERT_EQ(connectionEventCallback.getConnectionAccepted(), 1); + ASSERT_EQ(connectionEventCallback.getConnectionAcceptedError(), 0); + ASSERT_EQ(connectionEventCallback.getConnectionDropped(), 0); + ASSERT_EQ( + connectionEventCallback.getConnectionEnqueuedForAcceptCallback(), 0); + ASSERT_EQ(connectionEventCallback.getConnectionDequeuedByAcceptCallback(), 0); + ASSERT_EQ(connectionEventCallback.getBackoffStarted(), 0); + ASSERT_EQ(connectionEventCallback.getBackoffEnded(), 0); + ASSERT_EQ(connectionEventCallback.getBackoffError(), 0); +} + + + /** * Test AsyncServerSocket::getNumPendingMessagesInQueue() */ -- 2.34.1