2 * Copyright (c) 2014, Facebook, Inc.
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree. An additional grant
7 * of patent rights can be found in the PATENTS file in the same directory.
10 #include <folly/wangle/acceptor/Acceptor.h>
12 #include <folly/wangle/acceptor/ManagedConnection.h>
13 #include <folly/wangle/ssl/SSLContextManager.h>
15 #include <boost/cast.hpp>
17 #include <folly/ScopeGuard.h>
18 #include <folly/io/async/EventBase.h>
20 #include <sys/socket.h>
21 #include <sys/types.h>
22 #include <folly/io/async/AsyncSSLSocket.h>
23 #include <folly/io/async/AsyncSocket.h>
24 #include <gflags/gflags.h>
27 using folly::wangle::ConnectionManager;
28 using folly::wangle::ManagedConnection;
29 using std::chrono::microseconds;
30 using std::chrono::milliseconds;
34 using std::shared_ptr;
// Grace period (in milliseconds) that drainAllConnections() passes to
// ConnectionManager::initiateGracefulShutdown().
40 DEFINE_int32(shutdown_idle_grace_ms, 5000, "milliseconds to wait before "
41 "closing idle conns");
// NOTE(review): a second definition of FLAGS_shutdown_idle_grace_ms only makes
// sense as the alternate (non-gflags) branch of a preprocessor conditional; the
// guarding #if/#else/#endif lines are not visible in this chunk — confirm.
43 const int32_t FLAGS_shutdown_idle_grace_ms = 5000;
// Shared empty protocol name handed to connectionReady()/sslConnectionReady()
// when no next protocol was negotiated.
46 static const std::string empty_string;
// Static member: number of SSL handshakes in progress across all Acceptor
// instances. processEstablishedConnection() compares it against
// maxConcurrentSSLHandshakes before starting a new handshake.
47 std::atomic<uint64_t> Acceptor::totalNumPendingSSLConns_{0};
50 * Lightweight wrapper class to keep track of a newly
51 * accepted connection during SSL handshaking.
// AcceptorHandshakeHelper drives the server-side SSL handshake for one accepted
// socket. It registers itself with the owning Acceptor's connection manager and
// starts sslAccept(). On completion it either hands the socket to
// Acceptor::sslConnectionReady() or reports Acceptor::sslConnectionError().
// NOTE(review): this chunk is a lossy extraction; the embedded numbering skips
// lines (e.g. 58, the `acceptor` constructor parameter used below; 155, the
// `acceptor_` member declaration; the closing `};`). Confirm against the real file.
53 class AcceptorHandshakeHelper :
54 public AsyncSSLSocket::HandshakeCB,
55 public ManagedConnection {
// Takes ownership of the accepted SSL socket and registers this helper with
// acceptor_->downstreamConnectionManager_. Enables ClientHello parsing when the
// acceptor requests it, then starts the handshake with this object as the
// HandshakeCB.
57 AcceptorHandshakeHelper(AsyncSSLSocket::UniquePtr socket,
59 const SocketAddress& clientAddr,
60 std::chrono::steady_clock::time_point acceptTime)
61 : socket_(std::move(socket)), acceptor_(acceptor),
62 acceptTime_(acceptTime), clientAddr_(clientAddr) {
63 acceptor_->downstreamConnectionManager_->addConnection(this, true);
64 if(acceptor_->parseClientHello_) {
65 socket_->enableClientHelloParsing();
67 socket_->sslAccept(this);
// Timeout callback: records TIMEOUT in sslError_. handshakeErr() passes
// sslError_ to updateSSLStats().
// NOTE(review): the rest of this body (lines 73-74) is not visible; presumably
// it tears the handshake down so handshakeErr() fires — confirm.
70 virtual void timeoutExpired() noexcept {
71 VLOG(4) << "SSL handshake timeout expired";
72 sslError_ = SSLErrorEnum::TIMEOUT;
// Writes a short description of this pending handshake to `os`.
75 virtual void describe(std::ostream& os) const {
76 os << "pending handshake on " << clientAddr_;
// NOTE(review): the return value (lines 79-80) is not visible here.
78 virtual bool isBusy() const {
// Graceful-shutdown hooks are intentionally empty: a pending handshake is
// neither notified of shutdown nor closed when idle.
81 virtual void notifyPendingShutdown() {}
82 virtual void closeWhenIdle() {}
// Forced drop of an in-progress handshake.
// NOTE(review): only the log line is visible; the statements that actually
// drop the socket (lines 86-87) are missing from this chunk — confirm.
84 virtual void dropConnection() {
85 VLOG(10) << "Dropping in progress handshake for " << clientAddr_;
// NOTE(review): body not visible (lines 89-91 missing).
88 virtual void dumpConnectionState(uint8_t loglevel) {
92 // AsyncSSLSocket::HandshakeCallback API
// Handshake succeeded. This:
//  - logs the negotiated next protocol, if any;
//  - fills the SSL-related TransportInfo fields (setup time and bytes, SNI
//    name, cipher, version, cert size, resume state, client/server ciphers,
//    compression methods, extensions, next protocol);
//  - records success via updateSSLStats();
//  - deregisters from the connection manager;
//  - moves socket_ into Acceptor::sslConnectionReady() with the protocol name,
//    or empty_string if none was selected.
// NOTE(review): the if/else around the two VLOGs, the `tinfo` declaration
// (lines 109-110), the end of the assign() call (132-133) and the tail of the
// function (139-141) are not visible — confirm.
93 virtual void handshakeSuc(AsyncSSLSocket* sock) noexcept {
95 const unsigned char* nextProto = nullptr;
96 unsigned nextProtoLength = 0;
97 sock->getSelectedNextProtocol(&nextProto, &nextProtoLength);
100 VLOG(3) << "Client selected next protocol " <<
101 string((const char*)nextProto, nextProtoLength);
103 VLOG(3) << "Client did not select a next protocol";
107 // fill in SSL-related fields from TransportInfo
108 // the other fields like RTT are filled in the Acceptor
111 tinfo.acceptTime = acceptTime_;
112 tinfo.sslSetupTime = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - acceptTime_);
113 tinfo.sslSetupBytesRead = sock->getRawBytesReceived();
114 tinfo.sslSetupBytesWritten = sock->getRawBytesWritten();
// Optional string fields stay null when the socket reports nothing.
115 tinfo.sslServerName = sock->getSSLServerName() ?
116 std::make_shared<std::string>(sock->getSSLServerName()) : nullptr;
117 tinfo.sslCipher = sock->getNegotiatedCipherName() ?
118 std::make_shared<std::string>(sock->getNegotiatedCipherName()) : nullptr;
119 tinfo.sslVersion = sock->getSSLVersion();
120 tinfo.sslCertSize = sock->getSSLCertSize();
121 tinfo.sslResume = SSLUtil::getResumeState(sock);
122 tinfo.sslClientCiphers = std::make_shared<std::string>();
123 sock->getSSLClientCiphers(*tinfo.sslClientCiphers);
124 tinfo.sslServerCiphers = std::make_shared<std::string>();
125 sock->getSSLServerCiphers(*tinfo.sslServerCiphers);
126 tinfo.sslClientComprMethods =
127 std::make_shared<std::string>(sock->getSSLClientComprMethods());
128 tinfo.sslClientExts =
129 std::make_shared<std::string>(sock->getSSLClientExts());
130 tinfo.sslNextProtocol = std::make_shared<std::string>();
131 tinfo.sslNextProtocol->assign(reinterpret_cast<const char*>(nextProto),
// Deregister before handing off: after this point the acceptor owns the socket.
134 acceptor_->updateSSLStats(sock, tinfo.sslSetupTime, SSLErrorEnum::NO_ERROR);
135 acceptor_->downstreamConnectionManager_->removeConnection(this);
136 acceptor_->sslConnectionReady(std::move(socket_), clientAddr_,
137 nextProto ? string((const char*)nextProto, nextProtoLength) :
138 empty_string, tinfo);
// Handshake failed. Logs the elapsed time and the raw byte counts. Records
// sslError_ via updateSSLStats(); sslError_ is NO_ERROR unless e.g.
// timeoutExpired() set it. Finally notifies Acceptor::sslConnectionError() so
// the pending-handshake counters are released.
// NOTE(review): the end of the log statement (line 148) and the tail of the
// function (151-153) are not visible — confirm.
142 virtual void handshakeErr(AsyncSSLSocket* sock,
143 const AsyncSocketException& ex) noexcept {
144 auto elapsedTime = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - acceptTime_);
145 VLOG(3) << "SSL handshake error after " << elapsedTime.count() <<
146 " ms; " << sock->getRawBytesReceived() << " bytes received & " <<
147 sock->getRawBytesWritten() << " bytes sent: " <<
149 acceptor_->updateSSLStats(sock, elapsedTime, sslError_);
150 acceptor_->sslConnectionError();
// Owns the socket until handshakeSuc() moves it to the acceptor.
154 AsyncSSLSocket::UniquePtr socket_;
// NOTE(review): the `acceptor_` member declaration (line 155) is not visible.
// Time the connection was accepted; baseline for sslSetupTime / elapsedTime.
156 std::chrono::steady_clock::time_point acceptTime_;
157 SocketAddress clientAddr_;
// Error category reported to updateSSLStats() from handshakeErr().
158 SSLErrorEnum sslError_{SSLErrorEnum::NO_ERROR};
// Initializes accConfig_ from the given server socket configuration. Also
// caches its socket options in socketOptions_; init() applies them to the
// listening fds and connectionAccepted() to each accepted fd.
161 Acceptor::Acceptor(const ServerSocketConfig& accConfig) :
162 accConfig_(accConfig),
163 socketOptions_(accConfig.getSocketOptions()) {
// Binds this acceptor to `serverSocket` and `eventBase`. For SSL listeners it
// builds the SSL context manager (if one was not already set), loads every
// configured SSLContextConfig, and requires a default SSL context. It then
// marks the acceptor running, creates the downstream ConnectionManager, and
// registers itself as the socket's accept callback. Finally it applies
// socketOptions_ to the server socket's fds.
// NOTE(review): several lines are not visible in this chunk (e.g. 166, 170,
// 174-175, 180, 184, 200-202); presumably 170 assigns base_ — confirm.
167 Acceptor::init(AsyncServerSocket* serverSocket,
168 EventBase* eventBase) {
// base_ must still be null on entry; presumably this guards against calling
// init() twice — confirm.
169 CHECK(nullptr == this->base_);
171 if (accConfig_.isSSL()) {
172 if (!sslCtxManager_) {
173 sslCtxManager_ = folly::make_unique<SSLContextManager>(
176 accConfig_.strictSSL, nullptr);
178 for (const auto& sslCtxConfig : accConfig_.sslContextConfigs) {
179 sslCtxManager_->addSSLContextConfig(
181 accConfig_.sslCacheOptions,
182 &accConfig_.initialTicketSeeds,
183 accConfig_.bindAddress,
// ClientHello parsing is enabled if any context config asks for it.
185 parseClientHello_ |= sslCtxConfig.clientHelloParsingEnabled;
// processEstablishedConnection() creates every SSL socket from the default
// context, so one must exist.
188 CHECK(sslCtxManager_->getDefaultSSLCtx());
192 state_ = State::kRunning;
193 downstreamConnectionManager_ = ConnectionManager::makeUnique(
194 eventBase, accConfig_.connectionIdleTimeout, this);
197 serverSocket->addAcceptCallback(this, eventBase);
199 for (auto& fd : serverSocket->getSockets()) {
203 for (const auto& opt: socketOptions_) {
204 opt.first.apply(fd, opt.second);
// Destructor. NOTE(review): its body is not visible in this chunk.
210 Acceptor::~Acceptor(void) {
// Adds another SSL context configuration to sslCtxManager_, using the same
// cache options, ticket seeds and bind address from accConfig_ that init() uses.
// NOTE(review): dereferences sslCtxManager_ without a check; callers presumably
// use this only on SSL acceptors after init() — confirm. The remaining
// arguments/closing lines (218+) are not visible.
213 void Acceptor::addSSLContextConfig(const SSLContextConfig& sslCtxConfig) {
214 sslCtxManager_->addSSLContextConfig(sslCtxConfig,
215 accConfig_.sslCacheOptions,
216 &accConfig_.initialTicketSeeds,
217 accConfig_.bindAddress,
// Starts a graceful shutdown of the downstream connections, giving them
// FLAGS_shutdown_idle_grace_ms milliseconds. Does nothing if the connection
// manager has already been released.
222 Acceptor::drainAllConnections() {
223 if (downstreamConnectionManager_) {
224 downstreamConnectionManager_->initiateGracefulShutdown(
225 std::chrono::milliseconds(FLAGS_shutdown_idle_grace_ms));
// Installs the load-shedding configuration and the connection counter that
// canAccept() consults. `counter` is stored as a raw pointer (no ownership is
// taken here); presumably it must outlive this acceptor — confirm.
229 void Acceptor::setLoadShedConfig(const LoadShedConfiguration& from,
230 IConnectionCounter* counter) {
231 loadShedConfig_ = from;
232 connectionCounter_ = counter;
// Decides whether a new connection from `address` may be accepted.
// The checks, in order:
//  1. whether a connection counter is installed;
//  2. whether its max-connection limit is 0;
//  3. whether the current count is below that limit;
//  4. whether the address is on the load-shed whitelist;
//  5. whether the cross-acceptor load-shedding count is below
//     loadShedConfig_.getMaxConnections().
// NOTE(review): the return statements of each branch are not visible in this
// chunk; presumably the early branches accept and the fall-through (after the
// "not whitelisted" log) rejects — confirm.
235 bool Acceptor::canAccept(const SocketAddress& address) {
236 if (!connectionCounter_) {
240 uint64_t maxConnections = connectionCounter_->getMaxConnections();
241 if (maxConnections == 0) {
245 uint64_t currentConnections = connectionCounter_->getNumConnections();
246 if (currentConnections < maxConnections) {
250 if (loadShedConfig_.isWhitelisted(address)) {
254 // Take care of comparing connection count against max connections across
255 // all acceptors. Expensive since a lock must be taken to get the counter.
256 auto connectionCountForLoadShedding = getConnectionCountForLoadShedding();
257 if (connectionCountForLoadShedding < loadShedConfig_.getMaxConnections()) {
261 VLOG(4) << address.describe() << " not whitelisted";
// AcceptCallback entry point for a newly accepted fd. It:
//  - rejects the connection if canAccept() says no;
//  - timestamps the accept with the steady clock;
//  - applies socketOptions_ to the fd;
//  - forwards to onDoneAcceptingConnection().
// NOTE(review): the rejection branch (lines 269-271) is not visible; it should
// close `fd` before returning, otherwise the descriptor leaks — confirm.
266 Acceptor::connectionAccepted(
267 int fd, const SocketAddress& clientAddr) noexcept {
268 if (!canAccept(clientAddr)) {
272 auto acceptTime = std::chrono::steady_clock::now();
273 for (const auto& opt: socketOptions_) {
274 opt.first.apply(fd, opt.second);
277 onDoneAcceptingConnection(fd, clientAddr, acceptTime);
// Called once a connection has been accepted and configured; forwards it to
// processEstablishedConnection().
// NOTE(review): the `fd` parameter line (281) is not visible in this chunk.
280 void Acceptor::onDoneAcceptingConnection(
282 const SocketAddress& clientAddr,
283 std::chrono::steady_clock::time_point acceptTime) noexcept {
284 processEstablishedConnection(fd, clientAddr, acceptTime);
// Wraps an accepted fd in a transport.
// SSL listeners: creates an AsyncSSLSocket from the default SSL context and
// counts it as a pending handshake. If too many handshakes are in flight across
// all acceptors, the connection is dropped (stats recorded as DROPPED).
// Otherwise an AcceptorHandshakeHelper takes over the handshake.
// Plain listeners: creates an AsyncSocket and calls connectionReady()
// immediately with an empty protocol name.
// NOTE(review): lines 289 (the fd parameter), 305-306, 309-311 (the non-SSL
// branch header and `tinfo` declaration) are not visible in this chunk.
288 Acceptor::processEstablishedConnection(
290 const SocketAddress& clientAddr,
291 std::chrono::steady_clock::time_point acceptTime) noexcept {
292 if (accConfig_.isSSL()) {
293 CHECK(sslCtxManager_);
294 AsyncSSLSocket::UniquePtr sslSock(
295 makeNewAsyncSSLSocket(
296 sslCtxManager_->getDefaultSSLCtx(), base_, fd));
// Count this handshake before the limit check, so the check includes it.
297 ++numPendingSSLConns_;
298 ++totalNumPendingSSLConns_;
299 if (totalNumPendingSSLConns_ > accConfig_.maxConcurrentSSLHandshakes) {
300 VLOG(2) << "dropped SSL handshake on " << accConfig_.name <<
301 " too many handshakes in progress";
302 updateSSLStats(sslSock.get(), std::chrono::milliseconds(0),
303 SSLErrorEnum::DROPPED);
// sslConnectionError() undoes the two increments above.
// NOTE(review): an early return must follow here (not visible); without it a
// handshake helper would still be created for the dropped connection.
304 sslConnectionError();
// The helper pointer is not stored here: the constructor registers it with
// downstreamConnectionManager_ and starts the handshake. Its release is not
// visible in this chunk.
307 new AcceptorHandshakeHelper(
308 std::move(sslSock), this, clientAddr, acceptTime);
312 tinfo.acceptTime = acceptTime;
313 AsyncSocket::UniquePtr sock(makeNewAsyncSocket(base_, fd));
314 connectionReady(std::move(sock), clientAddr, empty_string, tinfo);
// Final step for every established connection (plain or post-handshake).
// Caps reads per event loop iteration, fills the socket-derived TransportInfo
// fields, and passes the socket to onNewConnection() together with the
// negotiated protocol name.
319 Acceptor::connectionReady(
320 AsyncSocket::UniquePtr sock,
321 const SocketAddress& clientAddr,
322 const string& nextProtocolName,
323 TransportInfo& tinfo) {
324 // Limit the number of reads from the socket per poll loop iteration,
325 // both to keep memory usage under control and to prevent one fast-
326 // writing client from starving other connections.
327 sock->setMaxReadsPerEvent(16);
328 tinfo.initWithSocket(sock.get());
329 onNewConnection(std::move(sock), &clientAddr, nextProtocolName, tinfo);
// Called by AcceptorHandshakeHelper after a successful handshake. Delivers the
// connection through connectionReady(), then releases this handshake from the
// per-acceptor and global pending counters.
// The counters are decremented only after connectionReady() returns, so the
// handshake still counts as pending while onNewConnection() runs.
// NOTE(review): the draining branch body (lines 342-344) is not visible;
// presumably it re-evaluates checkDrained() — confirm.
333 Acceptor::sslConnectionReady(AsyncSocket::UniquePtr sock,
334 const SocketAddress& clientAddr,
335 const string& nextProtocol,
336 TransportInfo& tinfo) {
337 CHECK(numPendingSSLConns_ > 0);
338 connectionReady(std::move(sock), clientAddr, nextProtocol, tinfo);
339 --numPendingSSLConns_;
340 --totalNumPendingSSLConns_;
341 if (state_ == State::kDraining) {
// Releases a pending SSL handshake that failed or was dropped, decrementing
// both the per-acceptor and global pending counters.
// NOTE(review): the draining branch body (lines 352-354) is not visible;
// presumably it re-evaluates checkDrained() — confirm.
347 Acceptor::sslConnectionError() {
348 CHECK(numPendingSSLConns_ > 0);
349 --numPendingSSLConns_;
350 --totalNumPendingSSLConns_;
351 if (state_ == State::kDraining) {
// AcceptCallback error hook: logs the failure and takes no further action.
357 Acceptor::acceptError(const std::exception& ex) noexcept {
358 // An error occurred.
359 // The most likely error is out of FDs. AsyncServerSocket will back off
360 // briefly if we are out of FDs, then continue accepting later.
361 // Just log a message here.
362 LOG(ERROR) << "error accepting on acceptor socket: " << ex.what();
// AcceptCallback hook invoked when the server socket stops accepting. Begins a
// graceful drain of existing connections and moves the acceptor into the
// draining state unless it has already finished.
// NOTE(review): the checkDrained() call the comment below requires (lines
// 378-379) is not visible in this chunk — confirm it is present.
366 Acceptor::acceptStopped() noexcept {
367 VLOG(3) << "Acceptor " << this << " acceptStopped()";
368 // Drain the open client connections
369 drainAllConnections();
371 // If we haven't yet finished draining, begin doing so by marking ourselves
372 // as in the draining state. We must be sure to hit checkDrained() here, as
373 // if we're completely idle, we should consider ourselves drained
374 // immediately (as there is no outstanding work to complete to cause us to
375 // re-evaluate this).
376 if (state_ != State::kDone) {
377 state_ = State::kDraining;
// ConnectionManager callback: the manager has no more connections.
// NOTE(review): the draining branch body (lines 386-388) is not visible;
// presumably it calls checkDrained() — confirm.
383 Acceptor::onEmpty(const ConnectionManager& cm) {
384 VLOG(3) << "Acceptor=" << this << " onEmpty()";
385 if (state_ == State::kDraining) {
// Completes draining once no work remains: no forced shutdown in progress, no
// managed connections and no pending SSL handshakes. Then it releases the
// connection manager, marks the acceptor done, and fires
// onConnectionsDrained(). Must only be called while draining.
391 Acceptor::checkDrained() {
392 CHECK(state_ == State::kDraining);
// dropAllConnections() sets forceShutdownInProgress_ and signals completion
// itself, so this path must not also finish.
// NOTE(review): the early-exit body (lines 396-398) is not visible.
393 if (forceShutdownInProgress_ ||
394 (downstreamConnectionManager_->getNumConnections() != 0) ||
395 (numPendingSSLConns_ != 0)) {
399 VLOG(2) << "All connections drained from Acceptor=" << this << " in thread "
402 downstreamConnectionManager_.reset();
404 state_ = State::kDone;
406 onConnectionsDrained();
// Returns the configured connection idle timeout (the same value init() passes
// to the ConnectionManager).
410 Acceptor::getConnTimeout() const {
411 return accConfig_.connectionIdleTimeout;
// Registers an established connection with the downstream ConnectionManager.
// Requires the manager to exist (i.e. after init() and before it is reset).
414 void Acceptor::addConnection(ManagedConnection* conn) {
415 // Add the socket to the timeout manager so that it can be cleaned
416 // up after being left idle for a long time.
417 downstreamConnectionManager_->addConnection(conn, true);
// Schedules dropAllConnections() on this acceptor's event base thread. It may
// be called from another thread.
// NOTE(review): the lambda uses only `this` (captured via [&]), so the Acceptor
// must stay alive until the queued callback runs — confirm callers guarantee it.
421 Acceptor::forceStop() {
422 base_->runInEventBaseThread([&] { dropAllConnections(); });
// Immediately drops every managed connection and finishes shutdown. Must run on
// the event base thread.
// forceShutdownInProgress_ makes checkDrained() bail out while connections are
// being torn down here. Pending handshake helpers register with the same
// connection manager, so after dropping them no SSL handshake may remain
// pending (enforced by the CHECK).
426 Acceptor::dropAllConnections() {
427 if (downstreamConnectionManager_) {
428 VLOG(3) << "Dropping all connections from Acceptor=" << this <<
429 " in thread " << base_;
430 assert(base_->isInEventBaseThread());
431 forceShutdownInProgress_ = true;
432 downstreamConnectionManager_->dropAllConnections();
433 CHECK(downstreamConnectionManager_->getNumConnections() == 0);
434 downstreamConnectionManager_.reset();
436 CHECK(numPendingSSLConns_ == 0);
438 state_ = State::kDone;
439 onConnectionsDrained();