2 * Copyright (c) 2015, Facebook, Inc.
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree. An additional grant
7 * of patent rights can be found in the PATENTS file in the same directory.
10 #include <folly/wangle/acceptor/Acceptor.h>
12 #include <folly/wangle/acceptor/ManagedConnection.h>
13 #include <folly/wangle/ssl/SSLContextManager.h>
15 #include <boost/cast.hpp>
17 #include <folly/ScopeGuard.h>
18 #include <folly/io/async/EventBase.h>
20 #include <sys/socket.h>
21 #include <sys/types.h>
22 #include <folly/io/async/AsyncSSLSocket.h>
23 #include <folly/io/async/AsyncSocket.h>
24 #include <gflags/gflags.h>
27 using folly::wangle::ConnectionManager;
28 using folly::wangle::ManagedConnection;
29 using std::chrono::microseconds;
30 using std::chrono::milliseconds;
34 using std::shared_ptr;
// Grace period handed to the connection manager by drainAllConnections()
// when gracefully shutting down downstream connections.
40 DEFINE_int32(shutdown_idle_grace_ms, 5000, "milliseconds to wait before "
41 "closing idle conns");
// NOTE(review): this is a second definition of FLAGS_shutdown_idle_grace_ms
// with the same 5000 ms default as DEFINE_int32 above. Presumably the two are
// alternatives selected by a preprocessor guard (gflags vs. no gflags);
// confirm they are never compiled together.
43 const int32_t FLAGS_shutdown_idle_grace_ms = 5000;
// Empty next-protocol name. It is passed to connectionReady() and
// sslConnectionReady() when no next protocol applies.
46 static const std::string empty_string;
// Number of SSL handshakes in progress across all Acceptor instances (static
// data member). processEstablishedConnection() compares it against
// accConfig_.maxConcurrentSSLHandshakes.
47 std::atomic<uint64_t> Acceptor::totalNumPendingSSLConns_{0};
50 * Lightweight wrapper class to keep track of a newly
51 * accepted connection during SSL handshaking.
53 class AcceptorHandshakeHelper :
54 public AsyncSSLSocket::HandshakeCB,
55 public ManagedConnection {
// Registers itself with the acceptor's downstream connection manager and
// starts the server-side SSL handshake. Completion is reported through
// handshakeSuc() / handshakeErr().
57 AcceptorHandshakeHelper(AsyncSSLSocket::UniquePtr socket,
59 const SocketAddress& clientAddr,
60 std::chrono::steady_clock::time_point acceptTime,
62 : socket_(std::move(socket)), acceptor_(acceptor),
63 acceptTime_(acceptTime), clientAddr_(clientAddr),
65 acceptor_->downstreamConnectionManager_->addConnection(this, true);
// ClientHello parsing is only enabled when some SSL context config asked for
// it. Acceptor::init() ORs clientHelloParsingEnabled into parseClientHello_.
66 if(acceptor_->parseClientHello_) {
67 socket_->enableClientHelloParsing();
69 socket_->sslAccept(this);
// Timeout hook (presumably the ManagedConnection idle timer; confirm). It
// records TIMEOUT so that handshakeErr() reports that reason to
// Acceptor::updateSSLStats().
72 void timeoutExpired() noexcept override {
73 VLOG(4) << "SSL handshake timeout expired";
74 sslError_ = SSLErrorEnum::TIMEOUT;
77 void describe(std::ostream& os) const override {
78 os << "pending handshake on " << clientAddr_;
// A pending handshake always reports itself as busy.
80 bool isBusy() const override { return true; }
// Intentionally empty: shutdown notifications do nothing to a handshake
// that is still in progress.
81 void notifyPendingShutdown() override {}
82 void closeWhenIdle() override {}
// ManagedConnection hook for forcibly dropping this pending handshake.
84 void dropConnection() override {
85 VLOG(10) << "Dropping in progress handshake for " << clientAddr_;
// Nothing extra to dump for a pending handshake.
88 void dumpConnectionState(uint8_t loglevel) override {}
91 // AsyncSSLSocket::HandshakeCB API
// Success path:
//   1. Log the negotiated next protocol.
//   2. Fill the SSL-related fields of tinfo_ and record stats.
//   3. Deregister from the connection manager.
//   4. Hand the socket and next protocol (or empty_string) to
//      Acceptor::sslConnectionReady().
92 void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
94 const unsigned char* nextProto = nullptr;
95 unsigned nextProtoLength = 0;
96 sock->getSelectedNextProtocol(&nextProto, &nextProtoLength);
99 VLOG(3) << "Client selected next protocol " <<
100 string((const char*)nextProto, nextProtoLength);
102 VLOG(3) << "Client did not select a next protocol";
106 // Fill in the SSL-related fields of TransportInfo (tinfo_);
107 // other fields such as RTT are filled in by the Acceptor.
109 tinfo_.acceptTime = acceptTime_;
// Time from accept until handshake completion.
110 tinfo_.sslSetupTime = std::chrono::duration_cast<std::chrono::milliseconds>(
111 std::chrono::steady_clock::now() - acceptTime_
113 tinfo_.sslSetupBytesRead = sock->getRawBytesReceived();
114 tinfo_.sslSetupBytesWritten = sock->getRawBytesWritten();
// Server name and cipher name may be null. In that case the corresponding
// shared_ptr is left null instead of wrapping a null C string.
115 tinfo_.sslServerName = sock->getSSLServerName() ?
116 std::make_shared<std::string>(sock->getSSLServerName()) : nullptr;
117 tinfo_.sslCipher = sock->getNegotiatedCipherName() ?
118 std::make_shared<std::string>(sock->getNegotiatedCipherName()) : nullptr;
119 tinfo_.sslVersion = sock->getSSLVersion();
120 tinfo_.sslCertSize = sock->getSSLCertSize();
121 tinfo_.sslResume = SSLUtil::getResumeState(sock);
122 tinfo_.sslClientCiphers = std::make_shared<std::string>();
123 sock->getSSLClientCiphers(*tinfo_.sslClientCiphers);
124 tinfo_.sslServerCiphers = std::make_shared<std::string>();
125 sock->getSSLServerCiphers(*tinfo_.sslServerCiphers);
126 tinfo_.sslClientComprMethods =
127 std::make_shared<std::string>(sock->getSSLClientComprMethods());
128 tinfo_.sslClientExts =
129 std::make_shared<std::string>(sock->getSSLClientExts());
130 tinfo_.sslNextProtocol = std::make_shared<std::string>();
131 tinfo_.sslNextProtocol->assign(reinterpret_cast<const char*>(nextProto),
134 acceptor_->updateSSLStats(
137 SSLErrorEnum::NO_ERROR
// The handshake is finished, so this helper no longer needs to be tracked
// before the socket is handed off.
139 acceptor_->downstreamConnectionManager_->removeConnection(this);
140 acceptor_->sslConnectionReady(std::move(socket_), clientAddr_,
141 nextProto ? string((const char*)nextProto, nextProtoLength) :
142 empty_string, tinfo_);
// Failure path:
//   1. Log elapsed time and raw bytes received/sent.
//   2. Record stats with the stored sslError_. That is TIMEOUT if
//      timeoutExpired() ran; otherwise it is still its initial NO_ERROR.
//   3. Notify the acceptor via sslConnectionError().
146 void handshakeErr(AsyncSSLSocket* sock,
147 const AsyncSocketException& ex) noexcept override {
148 auto elapsedTime = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - acceptTime_);
149 VLOG(3) << "SSL handshake error after " << elapsedTime.count() <<
150 " ms; " << sock->getRawBytesReceived() << " bytes received & " <<
151 sock->getRawBytesWritten() << " bytes sent: " <<
153 acceptor_->updateSSLStats(sock, elapsedTime, sslError_);
154 acceptor_->sslConnectionError();
// Socket under handshake; moved into Acceptor::sslConnectionReady() on
// success.
158 AsyncSSLSocket::UniquePtr socket_;
// Accept timestamp; used to compute sslSetupTime / elapsed time.
160 std::chrono::steady_clock::time_point acceptTime_;
161 SocketAddress clientAddr_;
// SSL handshake details passed to the acceptor on success.
162 TransportInfo tinfo_;
// Failure reason reported by handshakeErr(); set to TIMEOUT by
// timeoutExpired().
163 SSLErrorEnum sslError_{SSLErrorEnum::NO_ERROR};
// Stores the server socket configuration and its socket options. The options
// are later applied to the listening sockets in init() and to every accepted
// fd in connectionAccepted().
166 Acceptor::Acceptor(const ServerSocketConfig& accConfig) :
167 accConfig_(accConfig),
168 socketOptions_(accConfig.getSocketOptions()) {
// Attaches this acceptor to `serverSocket` and `eventBase`. Steps:
//   1. CHECK that base_ is still null.
//   2. For SSL configs, create the SSL context manager if needed and register
//      the configured SSLContextConfigs.
//   3. Enter kRunning and create the downstream connection manager.
//   4. Register as an accept callback.
//   5. Apply socketOptions_ to each listening socket.
172 Acceptor::init(AsyncServerSocket* serverSocket,
173 EventBase* eventBase) {
174 CHECK(nullptr == this->base_);
176 if (accConfig_.isSSL()) {
177 if (!sslCtxManager_) {
178 sslCtxManager_ = folly::make_unique<SSLContextManager>(
181 accConfig_.strictSSL, nullptr);
183 for (const auto& sslCtxConfig : accConfig_.sslContextConfigs) {
184 sslCtxManager_->addSSLContextConfig(
186 accConfig_.sslCacheOptions,
187 &accConfig_.initialTicketSeeds,
188 accConfig_.bindAddress,
// ClientHello parsing is enabled if any context config asks for it.
190 parseClientHello_ |= sslCtxConfig.clientHelloParsingEnabled;
// processEstablishedConnection() builds every SSL socket from the default
// context, so one must exist.
193 CHECK(sslCtxManager_->getDefaultSSLCtx());
197 state_ = State::kRunning;
198 downstreamConnectionManager_ = ConnectionManager::makeUnique(
199 eventBase, accConfig_.connectionIdleTimeout, this);
202 serverSocket->addAcceptCallback(this, eventBase);
204 for (auto& fd : serverSocket->getSockets()) {
208 for (const auto& opt: socketOptions_) {
209 opt.first.apply(fd, opt.second);
// Acceptor destructor.
215 Acceptor::~Acceptor(void) {
// Registers an additional SSL context config. It uses the acceptor's SSL
// cache options, initial ticket seeds and bind address.
// NOTE(review): sslCtxManager_ is dereferenced without a null check. In this
// file it is only created by init() when accConfig_.isSSL(), so calling this
// on a non-SSL or not-yet-initialized acceptor looks unsafe. Confirm that
// callers guarantee the manager exists.
218 void Acceptor::addSSLContextConfig(const SSLContextConfig& sslCtxConfig) {
219 sslCtxManager_->addSSLContextConfig(sslCtxConfig,
220 accConfig_.sslCacheOptions,
221 &accConfig_.initialTicketSeeds,
222 accConfig_.bindAddress,
// Starts a graceful shutdown of all downstream connections, passing a grace
// period of FLAGS_shutdown_idle_grace_ms ("milliseconds to wait before
// closing idle conns"). It does nothing once the connection manager has been
// released; checkDrained() and dropAllConnections() reset it.
227 Acceptor::drainAllConnections() {
228 if (downstreamConnectionManager_) {
229 downstreamConnectionManager_->initiateGracefulShutdown(
230 std::chrono::milliseconds(FLAGS_shutdown_idle_grace_ms));
// Installs the load-shedding configuration and the connection counter that
// canAccept() consults. `counter` is stored as the raw pointer it was passed
// as. Presumably the caller keeps ownership and keeps it alive for the
// acceptor's lifetime; confirm.
234 void Acceptor::setLoadShedConfig(const LoadShedConfiguration& from,
235 IConnectionCounter* counter) {
236 loadShedConfig_ = from;
237 connectionCounter_ = counter;
// Admission control for a new connection from `address`. It checks, in order:
//   1. whether a connection counter is configured;
//   2. whether that counter has a non-zero maximum;
//   3. whether the current count is below the maximum;
//   4. whether the address is whitelisted by the load-shed config;
//   5. whether the connection count across all acceptors is below the
//      load-shed maximum.
// It also emits a VLOG(4) "not whitelisted" message for `address`.
240 bool Acceptor::canAccept(const SocketAddress& address) {
241 if (!connectionCounter_) {
245 uint64_t maxConnections = connectionCounter_->getMaxConnections();
246 if (maxConnections == 0) {
250 uint64_t currentConnections = connectionCounter_->getNumConnections();
251 if (currentConnections < maxConnections) {
255 if (loadShedConfig_.isWhitelisted(address)) {
259 // Take care of comparing connection count against max connections across
260 // all acceptors. Expensive since a lock must be taken to get the counter.
261 auto connectionCountForLoadShedding = getConnectionCountForLoadShedding();
262 if (connectionCountForLoadShedding < loadShedConfig_.getMaxConnections()) {
266 VLOG(4) << address.describe() << " not whitelisted";
// Accept callback for the server socket (registered in init()). It first
// consults canAccept() for the client address. For an admitted fd it then:
//   1. stamps the accept time;
//   2. applies socketOptions_ to the fd;
//   3. continues in onDoneAcceptingConnection().
271 Acceptor::connectionAccepted(
272 int fd, const SocketAddress& clientAddr) noexcept {
273 if (!canAccept(clientAddr)) {
277 auto acceptTime = std::chrono::steady_clock::now();
278 for (const auto& opt: socketOptions_) {
279 opt.first.apply(fd, opt.second);
282 onDoneAcceptingConnection(fd, clientAddr, acceptTime);
// Called once an accepted fd has been configured. It forwards the fd, client
// address, accept time and a TransportInfo to processEstablishedConnection().
285 void Acceptor::onDoneAcceptingConnection(
287 const SocketAddress& clientAddr,
288 std::chrono::steady_clock::time_point acceptTime) noexcept {
290 processEstablishedConnection(fd, clientAddr, acceptTime, tinfo);
// Wraps an accepted fd in a transport.
// SSL:
//   - Builds an AsyncSSLSocket from the default SSL context.
//   - Counts it as a pending handshake, both per-acceptor and process-wide.
//   - Drops it if the process-wide count exceeds
//     accConfig_.maxConcurrentSSLHandshakes; otherwise starts an
//     AcceptorHandshakeHelper.
// Plaintext:
//   - Records the accept time and calls connectionReady() with an empty
//     next-protocol name.
294 Acceptor::processEstablishedConnection(
296 const SocketAddress& clientAddr,
297 std::chrono::steady_clock::time_point acceptTime,
298 TransportInfo& tinfo) noexcept {
299 if (accConfig_.isSSL()) {
300 CHECK(sslCtxManager_);
301 AsyncSSLSocket::UniquePtr sslSock(
302 makeNewAsyncSSLSocket(
303 sslCtxManager_->getDefaultSSLCtx(), base_, fd));
// Both counters are incremented before the limit check. On the drop path,
// sslConnectionError() decrements them again.
304 ++numPendingSSLConns_;
305 ++totalNumPendingSSLConns_;
306 if (totalNumPendingSSLConns_ > accConfig_.maxConcurrentSSLHandshakes) {
307 VLOG(2) << "dropped SSL handshake on " << accConfig_.name <<
308 " too many handshakes in progress";
309 updateSSLStats(sslSock.get(), std::chrono::milliseconds(0),
310 SSLErrorEnum::DROPPED);
311 sslConnectionError();
// NOTE(review): the helper is not stored here. It registers itself with
// downstreamConnectionManager_ in its constructor. Presumably it is freed
// after reporting success or failure; confirm.
314 new AcceptorHandshakeHelper(
323 tinfo.acceptTime = acceptTime;
324 AsyncSocket::UniquePtr sock(makeNewAsyncSocket(base_, fd));
325 connectionReady(std::move(sock), clientAddr, empty_string, tinfo);
// Final hand-off for every connection: plaintext, or SSL after a successful
// handshake. It caps reads per event, initializes `tinfo` from the socket,
// and passes everything to onNewConnection().
330 Acceptor::connectionReady(
331 AsyncSocket::UniquePtr sock,
332 const SocketAddress& clientAddr,
333 const string& nextProtocolName,
334 TransportInfo& tinfo) {
335 // Limit the number of reads from the socket per poll loop iteration,
336 // both to keep memory usage under control and to prevent one fast-
337 // writing client from starving other connections.
338 sock->setMaxReadsPerEvent(16);
339 tinfo.initWithSocket(sock.get());
340 onNewConnection(std::move(sock), &clientAddr, nextProtocolName, tinfo);
// Success path from AcceptorHandshakeHelper. It hands the connection to
// connectionReady(), decrements the per-acceptor and process-wide
// pending-handshake counters, and then checks whether the acceptor is
// draining.
344 Acceptor::sslConnectionReady(AsyncSocket::UniquePtr sock,
345 const SocketAddress& clientAddr,
346 const string& nextProtocol,
347 TransportInfo& tinfo) {
348 CHECK(numPendingSSLConns_ > 0);
349 connectionReady(std::move(sock), clientAddr, nextProtocol, tinfo);
350 --numPendingSSLConns_;
351 --totalNumPendingSSLConns_;
352 if (state_ == State::kDraining) {
// Failure path for a pending SSL handshake. It is used both for handshake
// errors and for handshakes dropped for exceeding the concurrency limit. It
// decrements the per-acceptor and process-wide pending counters, then checks
// whether the acceptor is draining.
358 Acceptor::sslConnectionError() {
359 CHECK(numPendingSSLConns_ > 0);
360 --numPendingSSLConns_;
361 --totalNumPendingSSLConns_;
362 if (state_ == State::kDraining) {
// Accept-error callback: only logs the error and keeps the acceptor running.
368 Acceptor::acceptError(const std::exception& ex) noexcept {
369 // An error occurred.
370 // The most likely error is out of FDs. AsyncServerSocket will back off
371 // briefly if we are out of FDs, then continue accepting later.
372 // Just log a message here.
373 LOG(ERROR) << "error accepting on acceptor socket: " << ex.what();
// Called when the server socket stops accepting. It starts a graceful drain
// of open connections and, unless already in kDone, switches to kDraining.
377 Acceptor::acceptStopped() noexcept {
378 VLOG(3) << "Acceptor " << this << " acceptStopped()";
379 // Drain the open client connections
380 drainAllConnections();
382 // If we haven't yet finished draining, begin doing so by marking ourselves
383 // as in the draining state. We must be sure to hit checkDrained() here, as
384 // if we're completely idle, we should consider ourselves drained
385 // immediately (as there is no outstanding work to complete to cause us to
386 // re-evaluate this).
387 if (state_ != State::kDone) {
388 state_ = State::kDraining;
// ConnectionManager callback (this acceptor is passed as the callback in
// init()) for when the manager has no connections left. It only acts while
// the acceptor is draining.
394 Acceptor::onEmpty(const ConnectionManager& cm) {
395 VLOG(3) << "Acceptor=" << this << " onEmpty()";
396 if (state_ == State::kDraining) {
// Finishes draining once nothing is outstanding. "Nothing outstanding" means
// no forced shutdown in progress, no downstream connections and no pending
// SSL handshakes. It then releases the connection manager, switches to
// kDone, and calls onConnectionsDrained(). It must only be called in the
// kDraining state.
402 Acceptor::checkDrained() {
403 CHECK(state_ == State::kDraining);
404 if (forceShutdownInProgress_ ||
405 (downstreamConnectionManager_->getNumConnections() != 0) ||
406 (numPendingSSLConns_ != 0)) {
410 VLOG(2) << "All connections drained from Acceptor=" << this << " in thread "
413 downstreamConnectionManager_.reset();
415 state_ = State::kDone;
417 onConnectionsDrained();
// Idle timeout for accepted connections, taken from the server socket config.
421 Acceptor::getConnTimeout() const {
422 return accConfig_.connectionIdleTimeout;
// Registers an established connection with the downstream connection manager.
425 void Acceptor::addConnection(ManagedConnection* conn) {
426 // Add the socket to the timeout manager so that it can be cleaned
427 // up after being left idle for a long time.
428 downstreamConnectionManager_->addConnection(conn, true);
// Requests an immediate stop by scheduling dropAllConnections() on base_'s
// event-base thread.
// NOTE(review): the callback runs asynchronously but uses a [&] capture.
// Only `this` is used, so [this] would state the capture more accurately.
// The Acceptor must outlive the queued callback.
432 Acceptor::forceStop() {
433 base_->runInEventBaseThread([&] { dropAllConnections(); });
437 Acceptor::dropAllConnections() {
438 if (downstreamConnectionManager_) {
439 VLOG(3) << "Dropping all connections from Acceptor=" << this <<
440 " in thread " << base_;
441 assert(base_->isInEventBaseThread());
442 forceShutdownInProgress_ = true;
443 downstreamConnectionManager_->dropAllConnections();
444 CHECK(downstreamConnectionManager_->getNumConnections() == 0);
445 downstreamConnectionManager_.reset();
447 CHECK(numPendingSSLConns_ == 0);
449 state_ = State::kDone;
450 onConnectionsDrained();