2 * Copyright (c) 2014, Facebook, Inc.
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree. An additional grant
7 * of patent rights can be found in the PATENTS file in the same directory.
10 #include <folly/wangle/acceptor/Acceptor.h>
12 #include <folly/wangle/acceptor/ManagedConnection.h>
13 #include <folly/wangle/ssl/SSLContextManager.h>
15 #include <boost/cast.hpp>
17 #include <folly/ScopeGuard.h>
18 #include <folly/io/async/EventBase.h>
20 #include <sys/socket.h>
21 #include <sys/types.h>
22 #include <folly/io/async/AsyncSSLSocket.h>
23 #include <folly/io/async/AsyncSocket.h>
26 using folly::wangle::ConnectionManager;
27 using folly::wangle::ManagedConnection;
28 using std::chrono::microseconds;
29 using std::chrono::milliseconds;
33 using std::shared_ptr;
39 DEFINE_int32(shutdown_idle_grace_ms, 5000, "milliseconds to wait before "
40 "closing idle conns");
42 const int32_t FLAGS_shutdown_idle_grace_ms = 5000;
// Shared empty next-protocol name. It is handed to connectionReady() when no
// protocol was negotiated: SSL connections whose handshake selected no next
// protocol, and all plaintext connections.
45 static const std::string empty_string;
// Static data member: the number of SSL handshakes in progress, summed over
// every Acceptor instance. The per-instance count is numPendingSSLConns_.
// It is incremented when a handshake starts and decremented when the handshake
// succeeds (sslConnectionReady) or fails (sslConnectionError). When it exceeds
// accConfig_.maxConcurrentSSLHandshakes, new handshakes are dropped.
// NOTE(review): presumably atomic because Acceptors may run on different
// EventBase threads — confirm.
46 std::atomic<uint64_t> Acceptor::totalNumPendingSSLConns_{0};
49 * Lightweight wrapper class to keep track of a newly
50 * accepted connection during SSL handshaking.
52 class AcceptorHandshakeHelper :
53 public AsyncSSLSocket::HandshakeCB,
54 public ManagedConnection {
56 AcceptorHandshakeHelper(AsyncSSLSocket::UniquePtr socket,
58 const SocketAddress& clientAddr,
59 std::chrono::steady_clock::time_point acceptTime)
60 : socket_(std::move(socket)), acceptor_(acceptor),
61 acceptTime_(acceptTime), clientAddr_(clientAddr) {
62 acceptor_->downstreamConnectionManager_->addConnection(this, true);
63 if(acceptor_->parseClientHello_) {
64 socket_->enableClientHelloParsing();
66 socket_->sslAccept(this);
69 virtual void timeoutExpired() noexcept {
70 VLOG(4) << "SSL handshake timeout expired";
71 sslError_ = SSLErrorEnum::TIMEOUT;
74 virtual void describe(std::ostream& os) const {
75 os << "pending handshake on " << clientAddr_;
77 virtual bool isBusy() const {
80 virtual void notifyPendingShutdown() {}
81 virtual void closeWhenIdle() {}
83 virtual void dropConnection() {
84 VLOG(10) << "Dropping in progress handshake for " << clientAddr_;
87 virtual void dumpConnectionState(uint8_t loglevel) {
91 // AsyncSSLSocket::HandshakeCallback API
92 virtual void handshakeSuc(AsyncSSLSocket* sock) noexcept {
94 const unsigned char* nextProto = nullptr;
95 unsigned nextProtoLength = 0;
96 sock->getSelectedNextProtocol(&nextProto, &nextProtoLength);
99 VLOG(3) << "Client selected next protocol " <<
100 string((const char*)nextProto, nextProtoLength);
102 VLOG(3) << "Client did not select a next protocol";
106 // fill in SSL-related fields from TransportInfo
107 // the other fields like RTT are filled in the Acceptor
110 tinfo.acceptTime = acceptTime_;
111 tinfo.sslSetupTime = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - acceptTime_);
112 tinfo.sslSetupBytesRead = sock->getRawBytesReceived();
113 tinfo.sslSetupBytesWritten = sock->getRawBytesWritten();
114 tinfo.sslServerName = sock->getSSLServerName() ?
115 std::make_shared<std::string>(sock->getSSLServerName()) : nullptr;
116 tinfo.sslCipher = sock->getNegotiatedCipherName() ?
117 std::make_shared<std::string>(sock->getNegotiatedCipherName()) : nullptr;
118 tinfo.sslVersion = sock->getSSLVersion();
119 tinfo.sslCertSize = sock->getSSLCertSize();
120 tinfo.sslResume = SSLUtil::getResumeState(sock);
121 sock->getSSLClientCiphers(tinfo.sslClientCiphers);
122 sock->getSSLServerCiphers(tinfo.sslServerCiphers);
123 tinfo.sslClientComprMethods = sock->getSSLClientComprMethods();
124 tinfo.sslClientExts = sock->getSSLClientExts();
125 tinfo.sslNextProtocol.assign(
126 reinterpret_cast<const char*>(nextProto),
129 acceptor_->updateSSLStats(sock, tinfo.sslSetupTime, SSLErrorEnum::NO_ERROR);
130 acceptor_->downstreamConnectionManager_->removeConnection(this);
131 acceptor_->sslConnectionReady(std::move(socket_), clientAddr_,
132 nextProto ? string((const char*)nextProto, nextProtoLength) :
133 empty_string, tinfo);
137 virtual void handshakeErr(AsyncSSLSocket* sock,
138 const AsyncSocketException& ex) noexcept {
139 auto elapsedTime = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - acceptTime_);
140 VLOG(3) << "SSL handshake error after " << elapsedTime.count() <<
141 " ms; " << sock->getRawBytesReceived() << " bytes received & " <<
142 sock->getRawBytesWritten() << " bytes sent: " <<
144 acceptor_->updateSSLStats(sock, elapsedTime, sslError_);
145 acceptor_->sslConnectionError();
149 AsyncSSLSocket::UniquePtr socket_;
151 std::chrono::steady_clock::time_point acceptTime_;
152 SocketAddress clientAddr_;
153 SSLErrorEnum sslError_{SSLErrorEnum::NO_ERROR};
156 Acceptor::Acceptor(const ServerSocketConfig& accConfig) :
157 accConfig_(accConfig),
158 socketOptions_(accConfig.getSocketOptions()) {
162 Acceptor::init(AsyncServerSocket* serverSocket,
163 EventBase* eventBase) {
164 CHECK(nullptr == this->base_);
166 if (accConfig_.isSSL()) {
167 if (!sslCtxManager_) {
168 sslCtxManager_ = folly::make_unique<SSLContextManager>(
171 accConfig_.strictSSL, nullptr);
173 for (const auto& sslCtxConfig : accConfig_.sslContextConfigs) {
174 sslCtxManager_->addSSLContextConfig(
176 accConfig_.sslCacheOptions,
177 &accConfig_.initialTicketSeeds,
178 accConfig_.bindAddress,
180 parseClientHello_ |= sslCtxConfig.clientHelloParsingEnabled;
183 CHECK(sslCtxManager_->getDefaultSSLCtx());
187 state_ = State::kRunning;
188 downstreamConnectionManager_ = ConnectionManager::makeUnique(
189 eventBase, accConfig_.connectionIdleTimeout, this);
192 serverSocket->addAcceptCallback(this, eventBase);
194 for (auto& fd : serverSocket->getSockets()) {
198 for (const auto& opt: socketOptions_) {
199 opt.first.apply(fd, opt.second);
205 Acceptor::~Acceptor(void) {
208 void Acceptor::addSSLContextConfig(const SSLContextConfig& sslCtxConfig) {
209 sslCtxManager_->addSSLContextConfig(sslCtxConfig,
210 accConfig_.sslCacheOptions,
211 &accConfig_.initialTicketSeeds,
212 accConfig_.bindAddress,
217 Acceptor::drainAllConnections() {
218 if (downstreamConnectionManager_) {
219 downstreamConnectionManager_->initiateGracefulShutdown(
220 std::chrono::milliseconds(FLAGS_shutdown_idle_grace_ms));
224 void Acceptor::setLoadShedConfig(const LoadShedConfiguration& from,
225 IConnectionCounter* counter) {
226 loadShedConfig_ = from;
227 connectionCounter_ = counter;
230 bool Acceptor::canAccept(const SocketAddress& address) {
231 if (!connectionCounter_) {
235 uint64_t maxConnections = connectionCounter_->getMaxConnections();
236 if (maxConnections == 0) {
240 uint64_t currentConnections = connectionCounter_->getNumConnections();
241 if (currentConnections < maxConnections) {
245 if (loadShedConfig_.isWhitelisted(address)) {
249 // Take care of comparing connection count against max connections across
250 // all acceptors. Expensive since a lock must be taken to get the counter.
251 auto connectionCountForLoadShedding = getConnectionCountForLoadShedding();
252 if (connectionCountForLoadShedding < loadShedConfig_.getMaxConnections()) {
256 VLOG(4) << address.describe() << " not whitelisted";
261 Acceptor::connectionAccepted(
262 int fd, const SocketAddress& clientAddr) noexcept {
263 if (!canAccept(clientAddr)) {
267 auto acceptTime = std::chrono::steady_clock::now();
268 for (const auto& opt: socketOptions_) {
269 opt.first.apply(fd, opt.second);
272 onDoneAcceptingConnection(fd, clientAddr, acceptTime);
275 void Acceptor::onDoneAcceptingConnection(
277 const SocketAddress& clientAddr,
278 std::chrono::steady_clock::time_point acceptTime) noexcept {
279 processEstablishedConnection(fd, clientAddr, acceptTime);
283 Acceptor::processEstablishedConnection(
285 const SocketAddress& clientAddr,
286 std::chrono::steady_clock::time_point acceptTime) noexcept {
287 if (accConfig_.isSSL()) {
288 CHECK(sslCtxManager_);
289 AsyncSSLSocket::UniquePtr sslSock(
290 makeNewAsyncSSLSocket(
291 sslCtxManager_->getDefaultSSLCtx(), base_, fd));
292 ++numPendingSSLConns_;
293 ++totalNumPendingSSLConns_;
294 if (totalNumPendingSSLConns_ > accConfig_.maxConcurrentSSLHandshakes) {
295 VLOG(2) << "dropped SSL handshake on " << accConfig_.name <<
296 " too many handshakes in progress";
297 updateSSLStats(sslSock.get(), std::chrono::milliseconds(0),
298 SSLErrorEnum::DROPPED);
299 sslConnectionError();
302 new AcceptorHandshakeHelper(
303 std::move(sslSock), this, clientAddr, acceptTime);
307 tinfo.acceptTime = acceptTime;
308 AsyncSocket::UniquePtr sock(makeNewAsyncSocket(base_, fd));
309 connectionReady(std::move(sock), clientAddr, empty_string, tinfo);
314 Acceptor::connectionReady(
315 AsyncSocket::UniquePtr sock,
316 const SocketAddress& clientAddr,
317 const string& nextProtocolName,
318 TransportInfo& tinfo) {
319 // Limit the number of reads from the socket per poll loop iteration,
320 // both to keep memory usage under control and to prevent one fast-
321 // writing client from starving other connections.
322 sock->setMaxReadsPerEvent(16);
323 tinfo.initWithSocket(sock.get());
324 onNewConnection(std::move(sock), &clientAddr, nextProtocolName, tinfo);
328 Acceptor::sslConnectionReady(AsyncSocket::UniquePtr sock,
329 const SocketAddress& clientAddr,
330 const string& nextProtocol,
331 TransportInfo& tinfo) {
332 CHECK(numPendingSSLConns_ > 0);
333 connectionReady(std::move(sock), clientAddr, nextProtocol, tinfo);
334 --numPendingSSLConns_;
335 --totalNumPendingSSLConns_;
336 if (state_ == State::kDraining) {
342 Acceptor::sslConnectionError() {
343 CHECK(numPendingSSLConns_ > 0);
344 --numPendingSSLConns_;
345 --totalNumPendingSSLConns_;
346 if (state_ == State::kDraining) {
352 Acceptor::acceptError(const std::exception& ex) noexcept {
353 // An error occurred.
354 // The most likely error is out of FDs. AsyncServerSocket will back off
355 // briefly if we are out of FDs, then continue accepting later.
356 // Just log a message here.
357 LOG(ERROR) << "error accepting on acceptor socket: " << ex.what();
361 Acceptor::acceptStopped() noexcept {
362 VLOG(3) << "Acceptor " << this << " acceptStopped()";
363 // Drain the open client connections
364 drainAllConnections();
366 // If we haven't yet finished draining, begin doing so by marking ourselves
367 // as in the draining state. We must be sure to hit checkDrained() here, as
368 // if we're completely idle, we can should consider ourself drained
369 // immediately (as there is no outstanding work to complete to cause us to
370 // re-evaluate this).
371 if (state_ != State::kDone) {
372 state_ = State::kDraining;
378 Acceptor::onEmpty(const ConnectionManager& cm) {
379 VLOG(3) << "Acceptor=" << this << " onEmpty()";
380 if (state_ == State::kDraining) {
386 Acceptor::checkDrained() {
387 CHECK(state_ == State::kDraining);
388 if (forceShutdownInProgress_ ||
389 (downstreamConnectionManager_->getNumConnections() != 0) ||
390 (numPendingSSLConns_ != 0)) {
394 VLOG(2) << "All connections drained from Acceptor=" << this << " in thread "
397 downstreamConnectionManager_.reset();
399 state_ = State::kDone;
401 onConnectionsDrained();
405 Acceptor::getConnTimeout() const {
406 return accConfig_.connectionIdleTimeout;
409 void Acceptor::addConnection(ManagedConnection* conn) {
410 // Add the socket to the timeout manager so that it can be cleaned
411 // up after being left idle for a long time.
412 downstreamConnectionManager_->addConnection(conn, true);
416 Acceptor::forceStop() {
417 base_->runInEventBaseThread([&] { dropAllConnections(); });
421 Acceptor::dropAllConnections() {
422 if (downstreamConnectionManager_) {
423 VLOG(3) << "Dropping all connections from Acceptor=" << this <<
424 " in thread " << base_;
425 assert(base_->isInEventBaseThread());
426 forceShutdownInProgress_ = true;
427 downstreamConnectionManager_->dropAllConnections();
428 CHECK(downstreamConnectionManager_->getNumConnections() == 0);
429 downstreamConnectionManager_.reset();
431 CHECK(numPendingSSLConns_ == 0);
433 state_ = State::kDone;
434 onConnectionsDrained();