From 6d3bf6468bcceef79ab89ac2c7a6df19832ef403 Mon Sep 17 00:00:00 2001 From: Mohammad Husain Date: Thu, 15 Oct 2015 21:13:25 -0700 Subject: [PATCH] Add connection event callback to AsyncServerSocket Summary: Adding a callback to AsyncServerSocket to get notified of client connection events. This can be used for example to record stats about these events. Reviewed By: @afrind Differential Revision: D2544776 fb-gh-sync-id: 20d22cfc939c5b937abec2b600c10b7228923ff3 --- folly/io/async/AsyncServerSocket.cpp | 45 ++++++- folly/io/async/AsyncServerSocket.h | 92 +++++++++++++- folly/io/async/test/AsyncSocketTest2.cpp | 152 +++++++++++++++++++++++ 3 files changed, 283 insertions(+), 6 deletions(-) diff --git a/folly/io/async/AsyncServerSocket.cpp b/folly/io/async/AsyncServerSocket.cpp index 75e0da30..b2b2d24b 100644 --- a/folly/io/async/AsyncServerSocket.cpp +++ b/folly/io/async/AsyncServerSocket.cpp @@ -91,6 +91,10 @@ void AsyncServerSocket::RemoteAcceptor::messageAvailable( switch (msg.type) { case MessageType::MSG_NEW_CONN: { + if (connectionEventCallback_) { + connectionEventCallback_->onConnectionDequeuedByAcceptCallback( + msg.fd, msg.address); + } callback_->connectionAccepted(msg.fd, msg.address); break; } @@ -515,7 +519,7 @@ void AsyncServerSocket::addAcceptCallback(AcceptCallback *callback, // callback more efficiently without having to use a notification queue. RemoteAcceptor* acceptor = nullptr; try { - acceptor = new RemoteAcceptor(callback); + acceptor = new RemoteAcceptor(callback, connectionEventCallback_); acceptor->start(eventBase, maxAtOnce, maxNumMsgsInQueue_); } catch (...) { callbacks_.pop_back(); @@ -722,6 +726,10 @@ void AsyncServerSocket::handlerReady( address.setFromSockaddr(saddr, addrLen); + if (clientSocket >= 0 && connectionEventCallback_) { + connectionEventCallback_->onConnectionAccepted(clientSocket, address); + } + std::chrono::time_point nowMs = std::chrono::steady_clock::now(); auto timeSinceLastAccept = std::max( @@ -737,6 +745,10 @@ void AsyncServerSocket::handlerReady( ++numDroppedConnections_; if (clientSocket >= 0) { closeNoInt(clientSocket); + if (connectionEventCallback_) { + connectionEventCallback_->onConnectionDropped(clientSocket, + address); + } } continue; } @@ -760,6 +772,9 @@ void AsyncServerSocket::handlerReady( } else { dispatchError("accept() failed", errno); } + if (connectionEventCallback_) { + connectionEventCallback_->onConnectionAcceptError(errno); + } return; } @@ -769,6 +784,9 @@ void AsyncServerSocket::handlerReady( closeNoInt(clientSocket); dispatchError("failed to set accepted socket to non-blocking mode", errno); + if (connectionEventCallback_) { + connectionEventCallback_->onConnectionDropped(clientSocket, address); + } return; } #endif @@ -795,6 +813,7 @@ void AsyncServerSocket::dispatchSocket(int socket, return; } + const SocketAddress addr(address); // Create a message to send over the notification queue QueueMessage msg; msg.type = MessageType::MSG_NEW_CONN; @@ -804,9 +823,13 @@ void AsyncServerSocket::dispatchSocket(int socket, // Loop until we find a free queue to write to while (true) { if (info->consumer->getQueue()->tryPutMessageNoThrow(std::move(msg))) { + if (connectionEventCallback_) { + connectionEventCallback_->onConnectionEnqueuedForAcceptCallback(socket, + addr); + } // Success! return. return; - } + } // We couldn't add to queue. Fall through to below @@ -831,6 +854,9 @@ void AsyncServerSocket::dispatchSocket(int socket, LOG(ERROR) << "failed to dispatch newly accepted socket:" << " all accept callback queues are full"; closeNoInt(socket); + if (connectionEventCallback_) { + connectionEventCallback_->onConnectionDropped(socket, addr); + } return; } @@ -886,6 +912,9 @@ void AsyncServerSocket::enterBackoff() { // since we won't be able to re-enable ourselves later. LOG(ERROR) << "failed to allocate AsyncServerSocket backoff" << " timer; unable to temporarly pause accepting"; + if (connectionEventCallback_) { + connectionEventCallback_->onBackoffError(); + } return; } } @@ -903,6 +932,9 @@ void AsyncServerSocket::enterBackoff() { if (!backoffTimeout_->scheduleTimeout(timeoutMS)) { LOG(ERROR) << "failed to schedule AsyncServerSocket backoff timer;" << "unable to temporarly pause accepting"; + if (connectionEventCallback_) { + connectionEventCallback_->onBackoffError(); + } return; } @@ -912,6 +944,9 @@ void AsyncServerSocket::enterBackoff() { for (auto& handler : sockets_) { handler.unregisterHandler(); } + if (connectionEventCallback_) { + connectionEventCallback_->onBackoffStarted(); + } } void AsyncServerSocket::backoffTimeoutExpired() { @@ -924,6 +959,9 @@ void AsyncServerSocket::backoffTimeoutExpired() { // If all of the callbacks were removed, we shouldn't re-enable accepts if (callbacks_.empty()) { + if (connectionEventCallback_) { + connectionEventCallback_->onBackoffEnded(); + } return; } @@ -942,6 +980,9 @@ void AsyncServerSocket::backoffTimeoutExpired() { abort(); } } + if (connectionEventCallback_) { + connectionEventCallback_->onBackoffEnded(); + } } diff --git a/folly/io/async/AsyncServerSocket.h b/folly/io/async/AsyncServerSocket.h index 935e1917..4f1194f7 100644 --- a/folly/io/async/AsyncServerSocket.h +++ b/folly/io/async/AsyncServerSocket.h @@ -64,6 +64,71 @@ class AsyncServerSocket : public DelayedDestruction // Disallow copy, move, and default construction. AsyncServerSocket(AsyncServerSocket&&) = delete; + /** + * A callback interface to get notified of client socket events. + * + * The ConnectionEventCallback implementations need to be thread-safe as the + * callbacks may be called from different threads. + */ + class ConnectionEventCallback { + public: + virtual ~ConnectionEventCallback() = default; + + /** + * onConnectionAccepted() is called right after a client connection + * is accepted using the system accept()/accept4() APIs. + */ + virtual void onConnectionAccepted(const int socket, + const SocketAddress& addr) noexcept = 0; + + /** + * onConnectionAcceptError() is called when an error occurred accepting + * a connection. + */ + virtual void onConnectionAcceptError(const int err) noexcept = 0; + + /** + * onConnectionDropped() is called when a connection is dropped, + * probably because of some error encountered. + */ + virtual void onConnectionDropped(const int socket, + const SocketAddress& addr) noexcept = 0; + + /** + * onConnectionEnqueuedForAcceptCallback() is called when the + * connection is successfully enqueued for an AcceptCallback to pick up. + */ + virtual void onConnectionEnqueuedForAcceptCallback( + const int socket, + const SocketAddress& addr) noexcept = 0; + + /** + * onConnectionDequeuedByAcceptCallback() is called when the + * connection is successfully dequeued by an AcceptCallback. + */ + virtual void onConnectionDequeuedByAcceptCallback( + const int socket, + const SocketAddress& addr) noexcept = 0; + + /** + * onBackoffStarted is called when the socket has successfully started + * backing off accepting new client sockets. + */ + virtual void onBackoffStarted() noexcept = 0; + + /** + * onBackoffEnded is called when the backoff period has ended and the socket + * has successfully resumed accepting new connections if there is any + * AcceptCallback registered. + */ + virtual void onBackoffEnded() noexcept = 0; + + /** + * onBackoffError is called when there is an error entering backoff + */ + virtual void onBackoffError() noexcept = 0; + }; + class AcceptCallback { public: virtual ~AcceptCallback() = default; @@ -320,8 +385,8 @@ class AsyncServerSocket : public DelayedDestruction * * When a new socket is accepted, one of the AcceptCallbacks will be invoked * with the new socket. The AcceptCallbacks are invoked in a round-robin - * fashion. This allows the accepted sockets to distributed among a pool of - * threads, each running its own EventBase object. This is a common model, + * fashion. This allows the accepted sockets to be distributed among a pool + * of threads, each running its own EventBase object. This is a common model, * since most asynchronous-style servers typically run one EventBase thread * per CPU. * @@ -584,6 +649,21 @@ class AsyncServerSocket : public DelayedDestruction return accepting_; } + /** + * Set the ConnectionEventCallback + */ + void setConnectionEventCallback( + ConnectionEventCallback* const connectionEventCallback) { + connectionEventCallback_ = connectionEventCallback; + } + + /** + * Get the ConnectionEventCallback + */ + ConnectionEventCallback* getConnectionEventCallback() const { + return connectionEventCallback_; + } + protected: /** * Protected destructor. @@ -618,8 +698,10 @@ class AsyncServerSocket : public DelayedDestruction class RemoteAcceptor : private NotificationQueue::Consumer { public: - explicit RemoteAcceptor(AcceptCallback *callback) - : callback_(callback) {} + explicit RemoteAcceptor(AcceptCallback *callback, + ConnectionEventCallback *connectionEventCallback) + : callback_(callback), + connectionEventCallback_(connectionEventCallback) {} ~RemoteAcceptor() = default; @@ -634,6 +716,7 @@ class AsyncServerSocket : public DelayedDestruction private: AcceptCallback *callback_; + ConnectionEventCallback* connectionEventCallback_; NotificationQueue queue_; }; @@ -738,6 +821,7 @@ class AsyncServerSocket : public DelayedDestruction bool reusePortEnabled_{false}; bool closeOnExec_; ShutdownSocketSet* shutdownSocketSet_; + ConnectionEventCallback* connectionEventCallback_{nullptr}; }; } // folly diff --git a/folly/io/async/test/AsyncSocketTest2.cpp b/folly/io/async/test/AsyncSocketTest2.cpp index 74da6621..2c822eef 100644 --- a/folly/io/async/test/AsyncSocketTest2.cpp +++ b/folly/io/async/test/AsyncSocketTest2.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -1452,6 +1453,113 @@ TEST(AsyncSocket, ConnectReadUninstallRead) { /////////////////////////////////////////////////////////////////////////// // AsyncServerSocket tests /////////////////////////////////////////////////////////////////////////// +namespace { +/** + * Helper ConnectionEventCallback class for the test code. + * It maintains counters protected by a spin lock. + */ +class TestConnectionEventCallback : + public AsyncServerSocket::ConnectionEventCallback { + public: + virtual void onConnectionAccepted( + const int socket, + const SocketAddress& addr) noexcept override { + folly::RWSpinLock::WriteHolder holder(spinLock_); + connectionAccepted_++; + } + + virtual void onConnectionAcceptError(const int err) noexcept override { + folly::RWSpinLock::WriteHolder holder(spinLock_); + connectionAcceptedError_++; + } + + virtual void onConnectionDropped( + const int socket, + const SocketAddress& addr) noexcept override { + folly::RWSpinLock::WriteHolder holder(spinLock_); + connectionDropped_++; + } + + virtual void onConnectionEnqueuedForAcceptCallback( + const int socket, + const SocketAddress& addr) noexcept override { + folly::RWSpinLock::WriteHolder holder(spinLock_); + connectionEnqueuedForAcceptCallback_++; + } + + virtual void onConnectionDequeuedByAcceptCallback( + const int socket, + const SocketAddress& addr) noexcept override { + folly::RWSpinLock::WriteHolder holder(spinLock_); + connectionDequeuedByAcceptCallback_++; + } + + virtual void onBackoffStarted() noexcept override { + folly::RWSpinLock::WriteHolder holder(spinLock_); + backoffStarted_++; + } + + virtual void onBackoffEnded() noexcept override { + folly::RWSpinLock::WriteHolder holder(spinLock_); + backoffEnded_++; + } + + virtual void onBackoffError() noexcept override { + folly::RWSpinLock::WriteHolder holder(spinLock_); + backoffError_++; + } + + unsigned int getConnectionAccepted() const { + folly::RWSpinLock::ReadHolder holder(spinLock_); + return connectionAccepted_; + } + + unsigned int getConnectionAcceptedError() const { + folly::RWSpinLock::ReadHolder holder(spinLock_); + return connectionAcceptedError_; + } + + unsigned int getConnectionDropped() const { + folly::RWSpinLock::ReadHolder holder(spinLock_); + return connectionDropped_; + } + + unsigned int getConnectionEnqueuedForAcceptCallback() const { + folly::RWSpinLock::ReadHolder holder(spinLock_); + return connectionEnqueuedForAcceptCallback_; + } + + unsigned int getConnectionDequeuedByAcceptCallback() const { + folly::RWSpinLock::ReadHolder holder(spinLock_); + return connectionDequeuedByAcceptCallback_; + } + + unsigned int getBackoffStarted() const { + folly::RWSpinLock::ReadHolder holder(spinLock_); + return backoffStarted_; + } + + unsigned int getBackoffEnded() const { + folly::RWSpinLock::ReadHolder holder(spinLock_); + return backoffEnded_; + } + + unsigned int getBackoffError() const { + folly::RWSpinLock::ReadHolder holder(spinLock_); + return backoffError_; + } + + private: + mutable folly::RWSpinLock spinLock_; + unsigned int connectionAccepted_{0}; + unsigned int connectionAcceptedError_{0}; + unsigned int connectionDropped_{0}; + unsigned int connectionEnqueuedForAcceptCallback_{0}; + unsigned int connectionDequeuedByAcceptCallback_{0}; + unsigned int backoffStarted_{0}; + unsigned int backoffEnded_{0}; + unsigned int backoffError_{0}; +}; /** * Helper AcceptCallback class for the test code @@ -1552,6 +1660,7 @@ class TestAcceptCallback : public AsyncServerSocket::AcceptCallback { std::deque events_; }; +} /** * Make sure accepted sockets have O_NONBLOCK and TCP_NODELAY set @@ -2043,3 +2152,46 @@ TEST(AsyncSocketTest, UnixDomainSocketTest) { int flags = fcntl(fd, F_GETFL, 0); CHECK_EQ(flags & O_NONBLOCK, O_NONBLOCK); } + +TEST(AsyncSocketTest, ConnectionEventCallbackDefault) { + EventBase eventBase; + TestConnectionEventCallback connectionEventCallback; + + // Create a server socket + std::shared_ptr serverSocket( + AsyncServerSocket::newSocket(&eventBase)); + serverSocket->setConnectionEventCallback(&connectionEventCallback); + serverSocket->bind(0); + serverSocket->listen(16); + folly::SocketAddress serverAddress; + serverSocket->getAddress(&serverAddress); + + // Add a callback to accept one connection then stop the loop + TestAcceptCallback acceptCallback; + acceptCallback.setConnectionAcceptedFn( + [&](int fd, const folly::SocketAddress& addr) { + serverSocket->removeAcceptCallback(&acceptCallback, nullptr); + }); + acceptCallback.setAcceptErrorFn([&](const std::exception& ex) { + serverSocket->removeAcceptCallback(&acceptCallback, nullptr); + }); + serverSocket->addAcceptCallback(&acceptCallback, nullptr); + serverSocket->startAccepting(); + + // Connect to the server socket + std::shared_ptr socket( + AsyncSocket::newSocket(&eventBase, serverAddress)); + + eventBase.loop(); + + // Validate the connection event counters + ASSERT_EQ(connectionEventCallback.getConnectionAccepted(), 1); + ASSERT_EQ(connectionEventCallback.getConnectionAcceptedError(), 0); + ASSERT_EQ(connectionEventCallback.getConnectionDropped(), 0); + ASSERT_EQ( + connectionEventCallback.getConnectionEnqueuedForAcceptCallback(), 1); + ASSERT_EQ(connectionEventCallback.getConnectionDequeuedByAcceptCallback(), 1); + ASSERT_EQ(connectionEventCallback.getBackoffStarted(), 0); + ASSERT_EQ(connectionEventCallback.getBackoffEnded(), 0); + ASSERT_EQ(connectionEventCallback.getBackoffError(), 0); +} -- 2.34.1