switch (msg.type) {
case MessageType::MSG_NEW_CONN:
{
+ if (connectionEventCallback_) {
+ connectionEventCallback_->onConnectionDequeuedByAcceptCallback(
+ msg.fd, msg.address);
+ }
callback_->connectionAccepted(msg.fd, msg.address);
break;
}
// callback more efficiently without having to use a notification queue.
RemoteAcceptor* acceptor = nullptr;
try {
- acceptor = new RemoteAcceptor(callback);
+ acceptor = new RemoteAcceptor(callback, connectionEventCallback_);
acceptor->start(eventBase, maxAtOnce, maxNumMsgsInQueue_);
} catch (...) {
callbacks_.pop_back();
address.setFromSockaddr(saddr, addrLen);
+ if (clientSocket >= 0 && connectionEventCallback_) {
+ connectionEventCallback_->onConnectionAccepted(clientSocket, address);
+ }
+
std::chrono::time_point<std::chrono::steady_clock> nowMs =
std::chrono::steady_clock::now();
auto timeSinceLastAccept = std::max<int64_t>(
++numDroppedConnections_;
if (clientSocket >= 0) {
closeNoInt(clientSocket);
+ if (connectionEventCallback_) {
+ connectionEventCallback_->onConnectionDropped(clientSocket,
+ address);
+ }
}
continue;
}
} else {
dispatchError("accept() failed", errno);
}
+ if (connectionEventCallback_) {
+ connectionEventCallback_->onConnectionAcceptError(errno);
+ }
return;
}
closeNoInt(clientSocket);
dispatchError("failed to set accepted socket to non-blocking mode",
errno);
+ if (connectionEventCallback_) {
+ connectionEventCallback_->onConnectionDropped(clientSocket, address);
+ }
return;
}
#endif
return;
}
+ const SocketAddress addr(address);
// Create a message to send over the notification queue
QueueMessage msg;
msg.type = MessageType::MSG_NEW_CONN;
// Loop until we find a free queue to write to
while (true) {
if (info->consumer->getQueue()->tryPutMessageNoThrow(std::move(msg))) {
+ if (connectionEventCallback_) {
+ connectionEventCallback_->onConnectionEnqueuedForAcceptCallback(socket,
+ addr);
+ }
// Success! return.
return;
- }
+ }
// We couldn't add to queue. Fall through to below
LOG(ERROR) << "failed to dispatch newly accepted socket:"
<< " all accept callback queues are full";
closeNoInt(socket);
+ if (connectionEventCallback_) {
+ connectionEventCallback_->onConnectionDropped(socket, addr);
+ }
return;
}
// since we won't be able to re-enable ourselves later.
LOG(ERROR) << "failed to allocate AsyncServerSocket backoff"
<< " timer; unable to temporarly pause accepting";
+ if (connectionEventCallback_) {
+ connectionEventCallback_->onBackoffError();
+ }
return;
}
}
if (!backoffTimeout_->scheduleTimeout(timeoutMS)) {
LOG(ERROR) << "failed to schedule AsyncServerSocket backoff timer;"
<< "unable to temporarly pause accepting";
+ if (connectionEventCallback_) {
+ connectionEventCallback_->onBackoffError();
+ }
return;
}
for (auto& handler : sockets_) {
handler.unregisterHandler();
}
+ if (connectionEventCallback_) {
+ connectionEventCallback_->onBackoffStarted();
+ }
}
void AsyncServerSocket::backoffTimeoutExpired() {
// If all of the callbacks were removed, we shouldn't re-enable accepts
if (callbacks_.empty()) {
+ if (connectionEventCallback_) {
+ connectionEventCallback_->onBackoffEnded();
+ }
return;
}
abort();
}
}
+ if (connectionEventCallback_) {
+ connectionEventCallback_->onBackoffEnded();
+ }
}
// Disallow copy, move, and default construction.
AsyncServerSocket(AsyncServerSocket&&) = delete;
+ /**
+ * A callback interface to get notified of client socket events.
+ *
+ * The ConnectionEventCallback implementations need to be thread-safe as the
+ * callbacks may be called from different threads.
+ */
+ class ConnectionEventCallback {
+ public:
+ virtual ~ConnectionEventCallback() = default;
+
+ /**
+ * onConnectionAccepted() is called right after a client connection
+ * is accepted using the system accept()/accept4() APIs.
+ */
+ virtual void onConnectionAccepted(const int socket,
+ const SocketAddress& addr) noexcept = 0;
+
+ /**
+ * onConnectionAcceptError() is called when an error occurred accepting
+ * a connection.
+ */
+ virtual void onConnectionAcceptError(const int err) noexcept = 0;
+
+ /**
+ * onConnectionDropped() is called when a connection is dropped,
+ * probably because of some error encountered.
+ */
+ virtual void onConnectionDropped(const int socket,
+ const SocketAddress& addr) noexcept = 0;
+
+ /**
+ * onConnectionEnqueuedForAcceptCallback() is called when the
+ * connection is successfully enqueued for an AcceptCallback to pick up.
+ */
+ virtual void onConnectionEnqueuedForAcceptCallback(
+ const int socket,
+ const SocketAddress& addr) noexcept = 0;
+
+ /**
+ * onConnectionDequeuedByAcceptCallback() is called when the
+ * connection is successfully dequeued by an AcceptCallback.
+ */
+ virtual void onConnectionDequeuedByAcceptCallback(
+ const int socket,
+ const SocketAddress& addr) noexcept = 0;
+
+ /**
+ * onBackoffStarted is called when the socket has successfully started
+ * backing off accepting new client sockets.
+ */
+ virtual void onBackoffStarted() noexcept = 0;
+
+ /**
+ * onBackoffEnded is called when the backoff period has ended and the socket
+ * has successfully resumed accepting new connections if there is any
+ * AcceptCallback registered.
+ */
+ virtual void onBackoffEnded() noexcept = 0;
+
+ /**
+ * onBackoffError is called when there is an error entering backoff
+ */
+ virtual void onBackoffError() noexcept = 0;
+ };
+
class AcceptCallback {
public:
virtual ~AcceptCallback() = default;
*
* When a new socket is accepted, one of the AcceptCallbacks will be invoked
* with the new socket. The AcceptCallbacks are invoked in a round-robin
- * fashion. This allows the accepted sockets to distributed among a pool of
- * threads, each running its own EventBase object. This is a common model,
+ * fashion. This allows the accepted sockets to be distributed among a pool
+ * of threads, each running its own EventBase object. This is a common model,
* since most asynchronous-style servers typically run one EventBase thread
* per CPU.
*
return accepting_;
}
+ /**
+ * Set the ConnectionEventCallback
+ */
+ void setConnectionEventCallback(
+ ConnectionEventCallback* const connectionEventCallback) {
+ connectionEventCallback_ = connectionEventCallback;
+ }
+
+ /**
+ * Get the ConnectionEventCallback
+ */
+ ConnectionEventCallback* getConnectionEventCallback() const {
+ return connectionEventCallback_;
+ }
+
protected:
/**
* Protected destructor.
class RemoteAcceptor
: private NotificationQueue<QueueMessage>::Consumer {
public:
- explicit RemoteAcceptor(AcceptCallback *callback)
- : callback_(callback) {}
+ explicit RemoteAcceptor(AcceptCallback *callback,
+ ConnectionEventCallback *connectionEventCallback)
+ : callback_(callback),
+ connectionEventCallback_(connectionEventCallback) {}
~RemoteAcceptor() = default;
private:
AcceptCallback *callback_;
+ ConnectionEventCallback* connectionEventCallback_;
NotificationQueue<QueueMessage> queue_;
};
bool reusePortEnabled_{false};
bool closeOnExec_;
ShutdownSocketSet* shutdownSocketSet_;
+ ConnectionEventCallback* connectionEventCallback_{nullptr};
};
} // folly
#include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/AsyncTimeout.h>
#include <folly/io/async/EventBase.h>
+#include <folly/RWSpinLock.h>
#include <folly/SocketAddress.h>
#include <folly/io/IOBuf.h>
///////////////////////////////////////////////////////////////////////////
// AsyncServerSocket tests
///////////////////////////////////////////////////////////////////////////
+namespace {
+/**
+ * Helper ConnectionEventCallback class for the test code.
+ * It maintains counters protected by a spin lock.
+ */
+class TestConnectionEventCallback :
+ public AsyncServerSocket::ConnectionEventCallback {
+ public:
+ virtual void onConnectionAccepted(
+ const int socket,
+ const SocketAddress& addr) noexcept override {
+ folly::RWSpinLock::WriteHolder holder(spinLock_);
+ connectionAccepted_++;
+ }
+
+ virtual void onConnectionAcceptError(const int err) noexcept override {
+ folly::RWSpinLock::WriteHolder holder(spinLock_);
+ connectionAcceptedError_++;
+ }
+
+ virtual void onConnectionDropped(
+ const int socket,
+ const SocketAddress& addr) noexcept override {
+ folly::RWSpinLock::WriteHolder holder(spinLock_);
+ connectionDropped_++;
+ }
+
+ virtual void onConnectionEnqueuedForAcceptCallback(
+ const int socket,
+ const SocketAddress& addr) noexcept override {
+ folly::RWSpinLock::WriteHolder holder(spinLock_);
+ connectionEnqueuedForAcceptCallback_++;
+ }
+
+ virtual void onConnectionDequeuedByAcceptCallback(
+ const int socket,
+ const SocketAddress& addr) noexcept override {
+ folly::RWSpinLock::WriteHolder holder(spinLock_);
+ connectionDequeuedByAcceptCallback_++;
+ }
+
+ virtual void onBackoffStarted() noexcept override {
+ folly::RWSpinLock::WriteHolder holder(spinLock_);
+ backoffStarted_++;
+ }
+
+ virtual void onBackoffEnded() noexcept override {
+ folly::RWSpinLock::WriteHolder holder(spinLock_);
+ backoffEnded_++;
+ }
+
+ virtual void onBackoffError() noexcept override {
+ folly::RWSpinLock::WriteHolder holder(spinLock_);
+ backoffError_++;
+ }
+
+ unsigned int getConnectionAccepted() const {
+ folly::RWSpinLock::ReadHolder holder(spinLock_);
+ return connectionAccepted_;
+ }
+
+ unsigned int getConnectionAcceptedError() const {
+ folly::RWSpinLock::ReadHolder holder(spinLock_);
+ return connectionAcceptedError_;
+ }
+
+ unsigned int getConnectionDropped() const {
+ folly::RWSpinLock::ReadHolder holder(spinLock_);
+ return connectionDropped_;
+ }
+
+ unsigned int getConnectionEnqueuedForAcceptCallback() const {
+ folly::RWSpinLock::ReadHolder holder(spinLock_);
+ return connectionEnqueuedForAcceptCallback_;
+ }
+
+ unsigned int getConnectionDequeuedByAcceptCallback() const {
+ folly::RWSpinLock::ReadHolder holder(spinLock_);
+ return connectionDequeuedByAcceptCallback_;
+ }
+
+ unsigned int getBackoffStarted() const {
+ folly::RWSpinLock::ReadHolder holder(spinLock_);
+ return backoffStarted_;
+ }
+
+ unsigned int getBackoffEnded() const {
+ folly::RWSpinLock::ReadHolder holder(spinLock_);
+ return backoffEnded_;
+ }
+
+ unsigned int getBackoffError() const {
+ folly::RWSpinLock::ReadHolder holder(spinLock_);
+ return backoffError_;
+ }
+
+ private:
+ mutable folly::RWSpinLock spinLock_;
+ unsigned int connectionAccepted_{0};
+ unsigned int connectionAcceptedError_{0};
+ unsigned int connectionDropped_{0};
+ unsigned int connectionEnqueuedForAcceptCallback_{0};
+ unsigned int connectionDequeuedByAcceptCallback_{0};
+ unsigned int backoffStarted_{0};
+ unsigned int backoffEnded_{0};
+ unsigned int backoffError_{0};
+};
/**
* Helper AcceptCallback class for the test code
std::deque<EventInfo> events_;
};
+}
/**
* Make sure accepted sockets have O_NONBLOCK and TCP_NODELAY set
int flags = fcntl(fd, F_GETFL, 0);
CHECK_EQ(flags & O_NONBLOCK, O_NONBLOCK);
}
+
+TEST(AsyncSocketTest, ConnectionEventCallbackDefault) {
+ EventBase eventBase;
+ TestConnectionEventCallback connectionEventCallback;
+
+ // Create a server socket
+ std::shared_ptr<AsyncServerSocket> serverSocket(
+ AsyncServerSocket::newSocket(&eventBase));
+ serverSocket->setConnectionEventCallback(&connectionEventCallback);
+ serverSocket->bind(0);
+ serverSocket->listen(16);
+ folly::SocketAddress serverAddress;
+ serverSocket->getAddress(&serverAddress);
+
+ // Add a callback to accept one connection then stop the loop
+ TestAcceptCallback acceptCallback;
+ acceptCallback.setConnectionAcceptedFn(
+ [&](int fd, const folly::SocketAddress& addr) {
+ serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
+ });
+ acceptCallback.setAcceptErrorFn([&](const std::exception& ex) {
+ serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
+ });
+ serverSocket->addAcceptCallback(&acceptCallback, nullptr);
+ serverSocket->startAccepting();
+
+ // Connect to the server socket
+ std::shared_ptr<AsyncSocket> socket(
+ AsyncSocket::newSocket(&eventBase, serverAddress));
+
+ eventBase.loop();
+
+ // Validate the connection event counters
+ ASSERT_EQ(connectionEventCallback.getConnectionAccepted(), 1);
+ ASSERT_EQ(connectionEventCallback.getConnectionAcceptedError(), 0);
+ ASSERT_EQ(connectionEventCallback.getConnectionDropped(), 0);
+ ASSERT_EQ(
+ connectionEventCallback.getConnectionEnqueuedForAcceptCallback(), 1);
+ ASSERT_EQ(connectionEventCallback.getConnectionDequeuedByAcceptCallback(), 1);
+ ASSERT_EQ(connectionEventCallback.getBackoffStarted(), 0);
+ ASSERT_EQ(connectionEventCallback.getBackoffEnded(), 0);
+ ASSERT_EQ(connectionEventCallback.getBackoffError(), 0);
+}