apply all sockopts to listening sockets
[folly.git] / folly / wangle / acceptor / Acceptor.cpp
1 /*
2  *  Copyright (c) 2014, Facebook, Inc.
3  *  All rights reserved.
4  *
5  *  This source code is licensed under the BSD-style license found in the
6  *  LICENSE file in the root directory of this source tree. An additional grant
7  *  of patent rights can be found in the PATENTS file in the same directory.
8  *
9  */
10 #include <folly/wangle/acceptor/Acceptor.h>
11
12 #include <folly/wangle/acceptor/ManagedConnection.h>
13 #include <folly/wangle/ssl/SSLContextManager.h>
14
15 #include <boost/cast.hpp>
16 #include <fcntl.h>
17 #include <folly/ScopeGuard.h>
18 #include <folly/io/async/EventBase.h>
19 #include <fstream>
20 #include <sys/socket.h>
21 #include <sys/types.h>
22 #include <folly/io/async/AsyncSSLSocket.h>
23 #include <folly/io/async/AsyncSocket.h>
24 #include <unistd.h>
25
26 using folly::wangle::ConnectionManager;
27 using folly::wangle::ManagedConnection;
28 using std::chrono::microseconds;
29 using std::chrono::milliseconds;
30 using std::filebuf;
31 using std::ifstream;
32 using std::ios;
33 using std::shared_ptr;
34 using std::string;
35
36 namespace folly {
37
38 #ifndef NO_LIB_GFLAGS
39 DEFINE_int32(shutdown_idle_grace_ms, 5000, "milliseconds to wait before "
40              "closing idle conns");
41 #else
42 const int32_t FLAGS_shutdown_idle_grace_ms = 5000;
43 #endif
44
45 static const std::string empty_string;
46 std::atomic<uint64_t> Acceptor::totalNumPendingSSLConns_{0};
47
48 /**
49  * Lightweight wrapper class to keep track of a newly
50  * accepted connection during SSL handshaking.
51  */
52 class AcceptorHandshakeHelper :
53       public AsyncSSLSocket::HandshakeCB,
54       public ManagedConnection {
55  public:
56   AcceptorHandshakeHelper(AsyncSSLSocket::UniquePtr socket,
57                           Acceptor* acceptor,
58                           const SocketAddress& clientAddr,
59                           std::chrono::steady_clock::time_point acceptTime)
60     : socket_(std::move(socket)), acceptor_(acceptor),
61       acceptTime_(acceptTime), clientAddr_(clientAddr) {
62     acceptor_->downstreamConnectionManager_->addConnection(this, true);
63     if(acceptor_->parseClientHello_)  {
64       socket_->enableClientHelloParsing();
65     }
66     socket_->sslAccept(this);
67   }
68
69   virtual void timeoutExpired() noexcept {
70     VLOG(4) << "SSL handshake timeout expired";
71     sslError_ = SSLErrorEnum::TIMEOUT;
72     dropConnection();
73   }
74   virtual void describe(std::ostream& os) const {
75     os << "pending handshake on " << clientAddr_;
76   }
77   virtual bool isBusy() const {
78     return true;
79   }
80   virtual void notifyPendingShutdown() {}
81   virtual void closeWhenIdle() {}
82
83   virtual void dropConnection() {
84     VLOG(10) << "Dropping in progress handshake for " << clientAddr_;
85     socket_->closeNow();
86   }
87   virtual void dumpConnectionState(uint8_t loglevel) {
88   }
89
90  private:
91   // AsyncSSLSocket::HandshakeCallback API
92   virtual void handshakeSuc(AsyncSSLSocket* sock) noexcept {
93
94     const unsigned char* nextProto = nullptr;
95     unsigned nextProtoLength = 0;
96     sock->getSelectedNextProtocol(&nextProto, &nextProtoLength);
97     if (VLOG_IS_ON(3)) {
98       if (nextProto) {
99         VLOG(3) << "Client selected next protocol " <<
100             string((const char*)nextProto, nextProtoLength);
101       } else {
102         VLOG(3) << "Client did not select a next protocol";
103       }
104     }
105
106     // fill in SSL-related fields from TransportInfo
107     // the other fields like RTT are filled in the Acceptor
108     TransportInfo tinfo;
109     tinfo.ssl = true;
110     tinfo.acceptTime = acceptTime_;
111     tinfo.sslSetupTime = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - acceptTime_);
112     tinfo.sslSetupBytesRead = sock->getRawBytesReceived();
113     tinfo.sslSetupBytesWritten = sock->getRawBytesWritten();
114     tinfo.sslServerName = sock->getSSLServerName() ?
115       std::make_shared<std::string>(sock->getSSLServerName()) : nullptr;
116     tinfo.sslCipher = sock->getNegotiatedCipherName() ?
117       std::make_shared<std::string>(sock->getNegotiatedCipherName()) : nullptr;
118     tinfo.sslVersion = sock->getSSLVersion();
119     tinfo.sslCertSize = sock->getSSLCertSize();
120     tinfo.sslResume = SSLUtil::getResumeState(sock);
121     sock->getSSLClientCiphers(tinfo.sslClientCiphers);
122     sock->getSSLServerCiphers(tinfo.sslServerCiphers);
123     tinfo.sslClientComprMethods = sock->getSSLClientComprMethods();
124     tinfo.sslClientExts = sock->getSSLClientExts();
125     tinfo.sslNextProtocol.assign(
126         reinterpret_cast<const char*>(nextProto),
127         nextProtoLength);
128
129     acceptor_->updateSSLStats(sock, tinfo.sslSetupTime, SSLErrorEnum::NO_ERROR);
130     acceptor_->downstreamConnectionManager_->removeConnection(this);
131     acceptor_->sslConnectionReady(std::move(socket_), clientAddr_,
132         nextProto ? string((const char*)nextProto, nextProtoLength) :
133                                   empty_string, tinfo);
134     delete this;
135   }
136
137   virtual void handshakeErr(AsyncSSLSocket* sock,
138                             const AsyncSocketException& ex) noexcept {
139     auto elapsedTime = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - acceptTime_);
140     VLOG(3) << "SSL handshake error after " << elapsedTime.count() <<
141         " ms; " << sock->getRawBytesReceived() << " bytes received & " <<
142         sock->getRawBytesWritten() << " bytes sent: " <<
143         ex.what();
144     acceptor_->updateSSLStats(sock, elapsedTime, sslError_);
145     acceptor_->sslConnectionError();
146     delete this;
147   }
148
149   AsyncSSLSocket::UniquePtr socket_;
150   Acceptor* acceptor_;
151   std::chrono::steady_clock::time_point acceptTime_;
152   SocketAddress clientAddr_;
153   SSLErrorEnum sslError_{SSLErrorEnum::NO_ERROR};
154 };
155
156 Acceptor::Acceptor(const ServerSocketConfig& accConfig) :
157   accConfig_(accConfig),
158   socketOptions_(accConfig.getSocketOptions()) {
159 }
160
161 void
162 Acceptor::init(AsyncServerSocket* serverSocket,
163                EventBase* eventBase) {
164   CHECK(nullptr == this->base_);
165
166   if (accConfig_.isSSL()) {
167     if (!sslCtxManager_) {
168       sslCtxManager_ = folly::make_unique<SSLContextManager>(
169         eventBase,
170         "vip_" + getName(),
171         accConfig_.strictSSL, nullptr);
172     }
173     for (const auto& sslCtxConfig : accConfig_.sslContextConfigs) {
174       sslCtxManager_->addSSLContextConfig(
175         sslCtxConfig,
176         accConfig_.sslCacheOptions,
177         &accConfig_.initialTicketSeeds,
178         accConfig_.bindAddress,
179         cacheProvider_);
180       parseClientHello_ |= sslCtxConfig.clientHelloParsingEnabled;
181     }
182
183     CHECK(sslCtxManager_->getDefaultSSLCtx());
184   }
185
186   base_ = eventBase;
187   state_ = State::kRunning;
188   downstreamConnectionManager_ = ConnectionManager::makeUnique(
189     eventBase, accConfig_.connectionIdleTimeout, this);
190
191   if (serverSocket) {
192     serverSocket->addAcceptCallback(this, eventBase);
193
194     for (auto& fd : serverSocket->getSockets()) {
195       if (fd < 0) {
196         continue;
197       }
198       for (const auto& opt: socketOptions_) {
199         opt.first.apply(fd, opt.second);
200       }
201     }
202   }
203 }
204
205 Acceptor::~Acceptor(void) {
206 }
207
208 void Acceptor::addSSLContextConfig(const SSLContextConfig& sslCtxConfig) {
209   sslCtxManager_->addSSLContextConfig(sslCtxConfig,
210                                       accConfig_.sslCacheOptions,
211                                       &accConfig_.initialTicketSeeds,
212                                       accConfig_.bindAddress,
213                                       cacheProvider_);
214 }
215
216 void
217 Acceptor::drainAllConnections() {
218   if (downstreamConnectionManager_) {
219     downstreamConnectionManager_->initiateGracefulShutdown(
220       std::chrono::milliseconds(FLAGS_shutdown_idle_grace_ms));
221   }
222 }
223
224 void Acceptor::setLoadShedConfig(const LoadShedConfiguration& from,
225                        IConnectionCounter* counter) {
226   loadShedConfig_ = from;
227   connectionCounter_ = counter;
228 }
229
230 bool Acceptor::canAccept(const SocketAddress& address) {
231   if (!connectionCounter_) {
232     return true;
233   }
234
235   uint64_t maxConnections = connectionCounter_->getMaxConnections();
236   if (maxConnections == 0) {
237     return true;
238   }
239
240   uint64_t currentConnections = connectionCounter_->getNumConnections();
241   if (currentConnections < maxConnections) {
242     return true;
243   }
244
245   if (loadShedConfig_.isWhitelisted(address)) {
246     return true;
247   }
248
249   // Take care of comparing connection count against max connections across
250   // all acceptors. Expensive since a lock must be taken to get the counter.
251   auto connectionCountForLoadShedding = getConnectionCountForLoadShedding();
252   if (connectionCountForLoadShedding < loadShedConfig_.getMaxConnections()) {
253     return true;
254   }
255
256   VLOG(4) << address.describe() << " not whitelisted";
257   return false;
258 }
259
260 void
261 Acceptor::connectionAccepted(
262     int fd, const SocketAddress& clientAddr) noexcept {
263   if (!canAccept(clientAddr)) {
264     close(fd);
265     return;
266   }
267   auto acceptTime = std::chrono::steady_clock::now();
268   for (const auto& opt: socketOptions_) {
269     opt.first.apply(fd, opt.second);
270   }
271
272   onDoneAcceptingConnection(fd, clientAddr, acceptTime);
273 }
274
275 void Acceptor::onDoneAcceptingConnection(
276     int fd,
277     const SocketAddress& clientAddr,
278     std::chrono::steady_clock::time_point acceptTime) noexcept {
279   processEstablishedConnection(fd, clientAddr, acceptTime);
280 }
281
282 void
283 Acceptor::processEstablishedConnection(
284     int fd,
285     const SocketAddress& clientAddr,
286     std::chrono::steady_clock::time_point acceptTime) noexcept {
287   if (accConfig_.isSSL()) {
288     CHECK(sslCtxManager_);
289     AsyncSSLSocket::UniquePtr sslSock(
290       makeNewAsyncSSLSocket(
291         sslCtxManager_->getDefaultSSLCtx(), base_, fd));
292     ++numPendingSSLConns_;
293     ++totalNumPendingSSLConns_;
294     if (totalNumPendingSSLConns_ > accConfig_.maxConcurrentSSLHandshakes) {
295       VLOG(2) << "dropped SSL handshake on " << accConfig_.name <<
296         " too many handshakes in progress";
297       updateSSLStats(sslSock.get(), std::chrono::milliseconds(0),
298                      SSLErrorEnum::DROPPED);
299       sslConnectionError();
300       return;
301     }
302     new AcceptorHandshakeHelper(
303       std::move(sslSock), this, clientAddr, acceptTime);
304   } else {
305     TransportInfo tinfo;
306     tinfo.ssl = false;
307     tinfo.acceptTime = acceptTime;
308     AsyncSocket::UniquePtr sock(makeNewAsyncSocket(base_, fd));
309     connectionReady(std::move(sock), clientAddr, empty_string, tinfo);
310   }
311 }
312
313 void
314 Acceptor::connectionReady(
315     AsyncSocket::UniquePtr sock,
316     const SocketAddress& clientAddr,
317     const string& nextProtocolName,
318     TransportInfo& tinfo) {
319   // Limit the number of reads from the socket per poll loop iteration,
320   // both to keep memory usage under control and to prevent one fast-
321   // writing client from starving other connections.
322   sock->setMaxReadsPerEvent(16);
323   tinfo.initWithSocket(sock.get());
324   onNewConnection(std::move(sock), &clientAddr, nextProtocolName, tinfo);
325 }
326
327 void
328 Acceptor::sslConnectionReady(AsyncSocket::UniquePtr sock,
329                              const SocketAddress& clientAddr,
330                              const string& nextProtocol,
331                              TransportInfo& tinfo) {
332   CHECK(numPendingSSLConns_ > 0);
333   connectionReady(std::move(sock), clientAddr, nextProtocol, tinfo);
334   --numPendingSSLConns_;
335   --totalNumPendingSSLConns_;
336   if (state_ == State::kDraining) {
337     checkDrained();
338   }
339 }
340
341 void
342 Acceptor::sslConnectionError() {
343   CHECK(numPendingSSLConns_ > 0);
344   --numPendingSSLConns_;
345   --totalNumPendingSSLConns_;
346   if (state_ == State::kDraining) {
347     checkDrained();
348   }
349 }
350
351 void
352 Acceptor::acceptError(const std::exception& ex) noexcept {
353   // An error occurred.
354   // The most likely error is out of FDs.  AsyncServerSocket will back off
355   // briefly if we are out of FDs, then continue accepting later.
356   // Just log a message here.
357   LOG(ERROR) << "error accepting on acceptor socket: " << ex.what();
358 }
359
360 void
361 Acceptor::acceptStopped() noexcept {
362   VLOG(3) << "Acceptor " << this << " acceptStopped()";
363   // Drain the open client connections
364   drainAllConnections();
365
366   // If we haven't yet finished draining, begin doing so by marking ourselves
367   // as in the draining state. We must be sure to hit checkDrained() here, as
368   // if we're completely idle, we can should consider ourself drained
369   // immediately (as there is no outstanding work to complete to cause us to
370   // re-evaluate this).
371   if (state_ != State::kDone) {
372     state_ = State::kDraining;
373     checkDrained();
374   }
375 }
376
377 void
378 Acceptor::onEmpty(const ConnectionManager& cm) {
379   VLOG(3) << "Acceptor=" << this << " onEmpty()";
380   if (state_ == State::kDraining) {
381     checkDrained();
382   }
383 }
384
385 void
386 Acceptor::checkDrained() {
387   CHECK(state_ == State::kDraining);
388   if (forceShutdownInProgress_ ||
389       (downstreamConnectionManager_->getNumConnections() != 0) ||
390       (numPendingSSLConns_ != 0)) {
391     return;
392   }
393
394   VLOG(2) << "All connections drained from Acceptor=" << this << " in thread "
395           << base_;
396
397   downstreamConnectionManager_.reset();
398
399   state_ = State::kDone;
400
401   onConnectionsDrained();
402 }
403
404 milliseconds
405 Acceptor::getConnTimeout() const {
406   return accConfig_.connectionIdleTimeout;
407 }
408
409 void Acceptor::addConnection(ManagedConnection* conn) {
410   // Add the socket to the timeout manager so that it can be cleaned
411   // up after being left idle for a long time.
412   downstreamConnectionManager_->addConnection(conn, true);
413 }
414
415 void
416 Acceptor::forceStop() {
417   base_->runInEventBaseThread([&] { dropAllConnections(); });
418 }
419
420 void
421 Acceptor::dropAllConnections() {
422   if (downstreamConnectionManager_) {
423     VLOG(3) << "Dropping all connections from Acceptor=" << this <<
424       " in thread " << base_;
425     assert(base_->isInEventBaseThread());
426     forceShutdownInProgress_ = true;
427     downstreamConnectionManager_->dropAllConnections();
428     CHECK(downstreamConnectionManager_->getNumConnections() == 0);
429     downstreamConnectionManager_.reset();
430   }
431   CHECK(numPendingSSLConns_ == 0);
432
433   state_ = State::kDone;
434   onConnectionsDrained();
435 }
436
437 } // namespace