remove always true if-predicate (gcc-5 -Wlogical-op)
[folly.git] / folly / io / async / AsyncSSLSocket.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/io/async/AsyncSSLSocket.h>
18
19 #include <folly/io/async/EventBase.h>
20 #include <folly/portability/Sockets.h>
21
22 #include <boost/noncopyable.hpp>
23 #include <errno.h>
24 #include <fcntl.h>
25 #include <openssl/err.h>
26 #include <openssl/asn1.h>
27 #include <openssl/ssl.h>
28 #include <sys/types.h>
29 #include <chrono>
30
31 #include <folly/Bits.h>
32 #include <folly/SocketAddress.h>
33 #include <folly/SpinLock.h>
34 #include <folly/io/Cursor.h>
35 #include <folly/io/IOBuf.h>
36 #include <folly/portability/OpenSSL.h>
37 #include <folly/portability/Unistd.h>
38
39 using folly::SocketAddress;
40 using folly::SSLContext;
41 using std::string;
42 using std::shared_ptr;
43
44 using folly::Endian;
45 using folly::IOBuf;
46 using folly::SpinLock;
47 using folly::SpinLockGuard;
48 using folly::io::Cursor;
49 using std::unique_ptr;
50 using std::bind;
51
52 namespace {
53 using folly::AsyncSocket;
54 using folly::AsyncSocketException;
55 using folly::AsyncSSLSocket;
56 using folly::Optional;
57 using folly::SSLContext;
58 // For OpenSSL portability API
59 using namespace folly::ssl;
60 using folly::ssl::OpenSSLUtils;
61
62 // We have one single dummy SSL context so that we can implement attach
63 // and detach methods in a thread safe fashion without modifying opnessl.
64 static SSLContext *dummyCtx = nullptr;
65 static SpinLock dummyCtxLock;
66
67 // If given min write size is less than this, buffer will be allocated on
68 // stack, otherwise it is allocated on heap
69 const size_t MAX_STACK_BUF_SIZE = 2048;
70
71 // This converts "illegal" shutdowns into ZERO_RETURN
72 inline bool zero_return(int error, int rc) {
73   return (error == SSL_ERROR_ZERO_RETURN || (rc == 0 && errno == 0));
74 }
75
76 class AsyncSSLSocketConnector: public AsyncSocket::ConnectCallback,
77                                 public AsyncSSLSocket::HandshakeCB {
78
79  private:
80   AsyncSSLSocket *sslSocket_;
81   AsyncSSLSocket::ConnectCallback *callback_;
82   int timeout_;
83   int64_t startTime_;
84
85  protected:
86   ~AsyncSSLSocketConnector() override {}
87
88  public:
89   AsyncSSLSocketConnector(AsyncSSLSocket *sslSocket,
90                            AsyncSocket::ConnectCallback *callback,
91                            int timeout) :
92       sslSocket_(sslSocket),
93       callback_(callback),
94       timeout_(timeout),
95       startTime_(std::chrono::duration_cast<std::chrono::milliseconds>(
96                    std::chrono::steady_clock::now().time_since_epoch()).count()) {
97   }
98
99   void connectSuccess() noexcept override {
100     VLOG(7) << "client socket connected";
101
102     int64_t timeoutLeft = 0;
103     if (timeout_ > 0) {
104       auto curTime = std::chrono::duration_cast<std::chrono::milliseconds>(
105         std::chrono::steady_clock::now().time_since_epoch()).count();
106
107       timeoutLeft = timeout_ - (curTime - startTime_);
108       if (timeoutLeft <= 0) {
109         AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
110                                 "SSL connect timed out");
111         fail(ex);
112         delete this;
113         return;
114       }
115     }
116     sslSocket_->sslConn(this, std::chrono::milliseconds(timeoutLeft));
117   }
118
119   void connectErr(const AsyncSocketException& ex) noexcept override {
120     VLOG(1) << "TCP connect failed: " << ex.what();
121     fail(ex);
122     delete this;
123   }
124
125   void handshakeSuc(AsyncSSLSocket* /* sock */) noexcept override {
126     VLOG(7) << "client handshake success";
127     if (callback_) {
128       callback_->connectSuccess();
129     }
130     delete this;
131   }
132
133   void handshakeErr(AsyncSSLSocket* /* socket */,
134                     const AsyncSocketException& ex) noexcept override {
135     VLOG(1) << "client handshakeErr: " << ex.what();
136     fail(ex);
137     delete this;
138   }
139
140   void fail(const AsyncSocketException &ex) {
141     // fail is a noop if called twice
142     if (callback_) {
143       AsyncSSLSocket::ConnectCallback *cb = callback_;
144       callback_ = nullptr;
145
146       cb->connectErr(ex);
147       sslSocket_->closeNow();
148       // closeNow can call handshakeErr if it hasn't been called already.
149       // So this may have been deleted, no member variable access beyond this
150       // point
151       // Note that closeNow may invoke writeError callbacks if the socket had
152       // write data pending connection completion.
153     }
154   }
155 };
156
157 void setup_SSL_CTX(SSL_CTX *ctx) {
158 #ifdef SSL_MODE_RELEASE_BUFFERS
159   SSL_CTX_set_mode(ctx,
160                    SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER |
161                    SSL_MODE_ENABLE_PARTIAL_WRITE
162                    | SSL_MODE_RELEASE_BUFFERS
163                    );
164 #else
165   SSL_CTX_set_mode(ctx,
166                    SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER |
167                    SSL_MODE_ENABLE_PARTIAL_WRITE
168                    );
169 #endif
170 // SSL_CTX_set_mode is a Macro
171 #ifdef SSL_MODE_WRITE_IOVEC
172   SSL_CTX_set_mode(ctx,
173                    SSL_CTX_get_mode(ctx)
174                    | SSL_MODE_WRITE_IOVEC);
175 #endif
176
177 }
178
179 BIO_METHOD sslWriteBioMethod;
180
181 void* initsslWriteBioMethod(void) {
182   memcpy(&sslWriteBioMethod, BIO_s_socket(), sizeof(sslWriteBioMethod));
183   // override the bwrite method for MSG_EOR support
184   OpenSSLUtils::setCustomBioWriteMethod(
185       &sslWriteBioMethod, AsyncSSLSocket::bioWrite);
186
187   // Note that the sslWriteBioMethod.type and sslWriteBioMethod.name are not
188   // set here. openssl code seems to be checking ".type == BIO_TYPE_SOCKET" and
189   // then have specific handlings. The sslWriteBioWrite should be compatible
190   // with the one in openssl.
191
192   // Return something here to enable AsyncSSLSocket to call this method using
193   // a function-scoped static.
194   return nullptr;
195 }
196
197 } // anonymous namespace
198
199 namespace folly {
200
201 /**
202  * Create a client AsyncSSLSocket
203  */
204 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
205                                EventBase* evb, bool deferSecurityNegotiation) :
206     AsyncSocket(evb),
207     ctx_(ctx),
208     handshakeTimeout_(this, evb),
209     connectionTimeout_(this, evb) {
210   init();
211   if (deferSecurityNegotiation) {
212     sslState_ = STATE_UNENCRYPTED;
213   }
214 }
215
216 /**
217  * Create a server/client AsyncSSLSocket
218  */
219 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
220                                EventBase* evb, int fd, bool server,
221                                bool deferSecurityNegotiation) :
222     AsyncSocket(evb, fd),
223     server_(server),
224     ctx_(ctx),
225     handshakeTimeout_(this, evb),
226     connectionTimeout_(this, evb) {
227   init();
228   if (server) {
229     SSL_CTX_set_info_callback(ctx_->getSSLCtx(),
230                               AsyncSSLSocket::sslInfoCallback);
231   }
232   if (deferSecurityNegotiation) {
233     sslState_ = STATE_UNENCRYPTED;
234   }
235 }
236
237 #if FOLLY_OPENSSL_HAS_SNI
238 /**
239  * Create a client AsyncSSLSocket and allow tlsext_hostname
240  * to be sent in Client Hello.
241  */
242 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
243                                  EventBase* evb,
244                                const std::string& serverName,
245                                bool deferSecurityNegotiation) :
246     AsyncSSLSocket(ctx, evb, deferSecurityNegotiation) {
247   tlsextHostname_ = serverName;
248 }
249
250 /**
251  * Create a client AsyncSSLSocket from an already connected fd
252  * and allow tlsext_hostname to be sent in Client Hello.
253  */
254 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
255                                  EventBase* evb, int fd,
256                                const std::string& serverName,
257                                bool deferSecurityNegotiation) :
258     AsyncSSLSocket(ctx, evb, fd, false, deferSecurityNegotiation) {
259   tlsextHostname_ = serverName;
260 }
261 #endif // FOLLY_OPENSSL_HAS_SNI
262
263 AsyncSSLSocket::~AsyncSSLSocket() {
264   VLOG(3) << "actual destruction of AsyncSSLSocket(this=" << this
265           << ", evb=" << eventBase_ << ", fd=" << fd_
266           << ", state=" << int(state_) << ", sslState="
267           << sslState_ << ", events=" << eventFlags_ << ")";
268 }
269
270 void AsyncSSLSocket::init() {
271   // Do this here to ensure we initialize this once before any use of
272   // AsyncSSLSocket instances and not as part of library load.
273   static const auto sslWriteBioMethodInitializer = initsslWriteBioMethod();
274   (void)sslWriteBioMethodInitializer;
275
276   setup_SSL_CTX(ctx_->getSSLCtx());
277 }
278
279 void AsyncSSLSocket::closeNow() {
280   // Close the SSL connection.
281   if (ssl_ != nullptr && fd_ != -1) {
282     int rc = SSL_shutdown(ssl_);
283     if (rc == 0) {
284       rc = SSL_shutdown(ssl_);
285     }
286     if (rc < 0) {
287       ERR_clear_error();
288     }
289   }
290
291   if (sslSession_ != nullptr) {
292     SSL_SESSION_free(sslSession_);
293     sslSession_ = nullptr;
294   }
295
296   sslState_ = STATE_CLOSED;
297
298   if (handshakeTimeout_.isScheduled()) {
299     handshakeTimeout_.cancelTimeout();
300   }
301
302   DestructorGuard dg(this);
303
304   invokeHandshakeErr(
305       AsyncSocketException(
306         AsyncSocketException::END_OF_FILE,
307         "SSL connection closed locally"));
308
309   if (ssl_ != nullptr) {
310     SSL_free(ssl_);
311     ssl_ = nullptr;
312   }
313
314   // Close the socket.
315   AsyncSocket::closeNow();
316 }
317
318 void AsyncSSLSocket::shutdownWrite() {
319   // SSL sockets do not support half-shutdown, so just perform a full shutdown.
320   //
321   // (Performing a full shutdown here is more desirable than doing nothing at
322   // all.  The purpose of shutdownWrite() is normally to notify the other end
323   // of the connection that no more data will be sent.  If we do nothing, the
324   // other end will never know that no more data is coming, and this may result
325   // in protocol deadlock.)
326   close();
327 }
328
329 void AsyncSSLSocket::shutdownWriteNow() {
330   closeNow();
331 }
332
333 bool AsyncSSLSocket::good() const {
334   return (AsyncSocket::good() &&
335           (sslState_ == STATE_ACCEPTING || sslState_ == STATE_CONNECTING ||
336            sslState_ == STATE_ESTABLISHED || sslState_ == STATE_UNENCRYPTED));
337 }
338
339 // The TAsyncTransport definition of 'good' states that the transport is
340 // ready to perform reads and writes, so sslState_ == UNINIT must report !good.
341 // connecting can be true when the sslState_ == UNINIT because the AsyncSocket
342 // is connected but we haven't initiated the call to SSL_connect.
343 bool AsyncSSLSocket::connecting() const {
344   return (!server_ &&
345           (AsyncSocket::connecting() ||
346            (AsyncSocket::good() && (sslState_ == STATE_UNINIT ||
347                                      sslState_ == STATE_CONNECTING))));
348 }
349
350 std::string AsyncSSLSocket::getApplicationProtocol() noexcept {
351   const unsigned char* protoName = nullptr;
352   unsigned protoLength;
353   if (getSelectedNextProtocolNoThrow(&protoName, &protoLength)) {
354     return std::string(reinterpret_cast<const char*>(protoName), protoLength);
355   }
356   return "";
357 }
358
359 bool AsyncSSLSocket::isEorTrackingEnabled() const {
360   return trackEor_;
361 }
362
363 void AsyncSSLSocket::setEorTracking(bool track) {
364   if (trackEor_ != track) {
365     trackEor_ = track;
366     appEorByteNo_ = 0;
367     minEorRawByteNo_ = 0;
368   }
369 }
370
371 size_t AsyncSSLSocket::getRawBytesWritten() const {
372   // The bio(s) in the write path are in a chain
373   // each bio flushes to the next and finally written into the socket
374   // to get the rawBytesWritten on the socket,
375   // get the write bytes of the last bio
376   BIO *b;
377   if (!ssl_ || !(b = SSL_get_wbio(ssl_))) {
378     return 0;
379   }
380   BIO* next = BIO_next(b);
381   while (next != NULL) {
382     b = next;
383     next = BIO_next(b);
384   }
385
386   return BIO_number_written(b);
387 }
388
389 size_t AsyncSSLSocket::getRawBytesReceived() const {
390   BIO *b;
391   if (!ssl_ || !(b = SSL_get_rbio(ssl_))) {
392     return 0;
393   }
394
395   return BIO_number_read(b);
396 }
397
398
399 void AsyncSSLSocket::invalidState(HandshakeCB* callback) {
400   LOG(ERROR) << "AsyncSSLSocket(this=" << this << ", fd=" << fd_
401              << ", state=" << int(state_) << ", sslState=" << sslState_ << ", "
402              << "events=" << eventFlags_ << ", server=" << short(server_)
403              << "): " << "sslAccept/Connect() called in invalid "
404              << "state, handshake callback " << handshakeCallback_
405              << ", new callback " << callback;
406   assert(!handshakeTimeout_.isScheduled());
407   sslState_ = STATE_ERROR;
408
409   AsyncSocketException ex(AsyncSocketException::INVALID_STATE,
410                          "sslAccept() called with socket in invalid state");
411
412   handshakeEndTime_ = std::chrono::steady_clock::now();
413   if (callback) {
414     callback->handshakeErr(this, ex);
415   }
416
417   failHandshake(__func__, ex);
418 }
419
420 void AsyncSSLSocket::sslAccept(
421     HandshakeCB* callback,
422     std::chrono::milliseconds timeout,
423     const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
424   DestructorGuard dg(this);
425   assert(eventBase_->isInEventBaseThread());
426   verifyPeer_ = verifyPeer;
427
428   // Make sure we're in the uninitialized state
429   if (!server_ || (sslState_ != STATE_UNINIT &&
430                    sslState_ != STATE_UNENCRYPTED) ||
431       handshakeCallback_ != nullptr) {
432     return invalidState(callback);
433   }
434
435   // Cache local and remote socket addresses to keep them available
436   // after socket file descriptor is closed.
437   if (cacheAddrOnFailure_ && -1 != getFd()) {
438     cacheLocalPeerAddr();
439   }
440
441   handshakeStartTime_ = std::chrono::steady_clock::now();
442   // Make end time at least >= start time.
443   handshakeEndTime_ = handshakeStartTime_;
444
445   sslState_ = STATE_ACCEPTING;
446   handshakeCallback_ = callback;
447
448   if (timeout > std::chrono::milliseconds::zero()) {
449     handshakeTimeout_.scheduleTimeout(timeout);
450   }
451
452   /* register for a read operation (waiting for CLIENT HELLO) */
453   updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
454 }
455
456 #if OPENSSL_VERSION_NUMBER >= 0x009080bfL
457 void AsyncSSLSocket::attachSSLContext(
458   const std::shared_ptr<SSLContext>& ctx) {
459
460   // Check to ensure we are in client mode. Changing a server's ssl
461   // context doesn't make sense since clients of that server would likely
462   // become confused when the server's context changes.
463   DCHECK(!server_);
464   DCHECK(!ctx_);
465   DCHECK(ctx);
466   DCHECK(ctx->getSSLCtx());
467   ctx_ = ctx;
468
469   // It's possible this could be attached before ssl_ is set up
470   if (!ssl_) {
471     return;
472   }
473
474   // In order to call attachSSLContext, detachSSLContext must have been
475   // previously called.
476   // We need to update the initial_ctx if necessary
477   auto sslCtx = ctx->getSSLCtx();
478   SSL_CTX_up_ref(sslCtx);
479 #ifndef OPENSSL_NO_TLSEXT
480   // note that detachSSLContext has already freed ssl_->initial_ctx
481   ssl_->initial_ctx = sslCtx;
482 #endif
483   // Detach sets the socket's context to the dummy context. Thus we must acquire
484   // this lock.
485   SpinLockGuard guard(dummyCtxLock);
486   SSL_set_SSL_CTX(ssl_, sslCtx);
487 }
488
489 void AsyncSSLSocket::detachSSLContext() {
490   DCHECK(ctx_);
491   ctx_.reset();
492   // It's possible for this to be called before ssl_ has been
493   // set up
494   if (!ssl_) {
495     return;
496   }
497 // Detach the initial_ctx as well.  Internally w/ OPENSSL_NO_TLSEXT
498 // it is used for session info.  It will be reattached in attachSSLContext
499 #ifndef OPENSSL_NO_TLSEXT
500   if (ssl_->initial_ctx) {
501     SSL_CTX_free(ssl_->initial_ctx);
502     ssl_->initial_ctx = nullptr;
503   }
504 #endif
505   SpinLockGuard guard(dummyCtxLock);
506   if (nullptr == dummyCtx) {
507     // We need to lazily initialize the dummy context so we don't
508     // accidentally override any programmatic settings to openssl
509     dummyCtx = new SSLContext;
510   }
511   // We must remove this socket's references to its context right now
512   // since this socket could get passed to any thread. If the context has
513   // had its locking disabled, just doing a set in attachSSLContext()
514   // would not be thread safe.
515   SSL_set_SSL_CTX(ssl_, dummyCtx->getSSLCtx());
516 }
517 #endif
518
519 #if FOLLY_OPENSSL_HAS_SNI
520 void AsyncSSLSocket::switchServerSSLContext(
521   const std::shared_ptr<SSLContext>& handshakeCtx) {
522   CHECK(server_);
523   if (sslState_ != STATE_ACCEPTING) {
524     // We log it here and allow the switch.
525     // It should not affect our re-negotiation support (which
526     // is not supported now).
527     VLOG(6) << "fd=" << getFd()
528             << " renegotation detected when switching SSL_CTX";
529   }
530
531   setup_SSL_CTX(handshakeCtx->getSSLCtx());
532   SSL_CTX_set_info_callback(handshakeCtx->getSSLCtx(),
533                             AsyncSSLSocket::sslInfoCallback);
534   handshakeCtx_ = handshakeCtx;
535   SSL_set_SSL_CTX(ssl_, handshakeCtx->getSSLCtx());
536 }
537
538 bool AsyncSSLSocket::isServerNameMatch() const {
539   CHECK(!server_);
540
541   if (!ssl_) {
542     return false;
543   }
544
545   SSL_SESSION *ss = SSL_get_session(ssl_);
546   if (!ss) {
547     return false;
548   }
549
550   if(!ss->tlsext_hostname) {
551     return false;
552   }
553   return (tlsextHostname_.compare(ss->tlsext_hostname) ? false : true);
554 }
555
556 void AsyncSSLSocket::setServerName(std::string serverName) noexcept {
557   tlsextHostname_ = std::move(serverName);
558 }
559
560 #endif // FOLLY_OPENSSL_HAS_SNI
561
562 void AsyncSSLSocket::timeoutExpired() noexcept {
563   if (state_ == StateEnum::ESTABLISHED &&
564       (sslState_ == STATE_CACHE_LOOKUP ||
565        sslState_ == STATE_ASYNC_PENDING)) {
566     sslState_ = STATE_ERROR;
567     // We are expecting a callback in restartSSLAccept.  The cache lookup
568     // and rsa-call necessarily have pointers to this ssl socket, so delay
569     // the cleanup until he calls us back.
570   } else if (state_ == StateEnum::CONNECTING) {
571     assert(sslState_ == STATE_CONNECTING);
572     DestructorGuard dg(this);
573     AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
574                            "Fallback connect timed out during TFO");
575     failHandshake(__func__, ex);
576   } else {
577     assert(state_ == StateEnum::ESTABLISHED &&
578            (sslState_ == STATE_CONNECTING || sslState_ == STATE_ACCEPTING));
579     DestructorGuard dg(this);
580     AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
581                            (sslState_ == STATE_CONNECTING) ?
582                            "SSL connect timed out" : "SSL accept timed out");
583     failHandshake(__func__, ex);
584   }
585 }
586
587 int AsyncSSLSocket::getSSLExDataIndex() {
588   static auto index = SSL_get_ex_new_index(
589       0, (void*)"AsyncSSLSocket data index", nullptr, nullptr, nullptr);
590   return index;
591 }
592
593 AsyncSSLSocket* AsyncSSLSocket::getFromSSL(const SSL *ssl) {
594   return static_cast<AsyncSSLSocket *>(SSL_get_ex_data(ssl,
595       getSSLExDataIndex()));
596 }
597
598 void AsyncSSLSocket::failHandshake(const char* /* fn */,
599                                    const AsyncSocketException& ex) {
600   startFail();
601   if (handshakeTimeout_.isScheduled()) {
602     handshakeTimeout_.cancelTimeout();
603   }
604   invokeHandshakeErr(ex);
605   finishFail();
606 }
607
608 void AsyncSSLSocket::invokeHandshakeErr(const AsyncSocketException& ex) {
609   handshakeEndTime_ = std::chrono::steady_clock::now();
610   if (handshakeCallback_ != nullptr) {
611     HandshakeCB* callback = handshakeCallback_;
612     handshakeCallback_ = nullptr;
613     callback->handshakeErr(this, ex);
614   }
615 }
616
617 void AsyncSSLSocket::invokeHandshakeCB() {
618   handshakeEndTime_ = std::chrono::steady_clock::now();
619   if (handshakeTimeout_.isScheduled()) {
620     handshakeTimeout_.cancelTimeout();
621   }
622   if (handshakeCallback_) {
623     HandshakeCB* callback = handshakeCallback_;
624     handshakeCallback_ = nullptr;
625     callback->handshakeSuc(this);
626   }
627 }
628
629 void AsyncSSLSocket::cacheLocalPeerAddr() {
630   SocketAddress address;
631   try {
632     getLocalAddress(&address);
633     getPeerAddress(&address);
634   } catch (const std::system_error& e) {
635     // The handle can be still valid while the connection is already closed.
636     if (e.code() != std::error_code(ENOTCONN, std::system_category())) {
637       throw;
638     }
639   }
640 }
641
642 void AsyncSSLSocket::connect(ConnectCallback* callback,
643                               const folly::SocketAddress& address,
644                               int timeout,
645                               const OptionMap &options,
646                               const folly::SocketAddress& bindAddr)
647                               noexcept {
648   assert(!server_);
649   assert(state_ == StateEnum::UNINIT);
650   assert(sslState_ == STATE_UNINIT);
651   AsyncSSLSocketConnector *connector =
652     new AsyncSSLSocketConnector(this, callback, timeout);
653   AsyncSocket::connect(connector, address, timeout, options, bindAddr);
654 }
655
656 void AsyncSSLSocket::applyVerificationOptions(SSL * ssl) {
657   // apply the settings specified in verifyPeer_
658   if (verifyPeer_ == SSLContext::SSLVerifyPeerEnum::USE_CTX) {
659     if(ctx_->needsPeerVerification()) {
660       SSL_set_verify(ssl, ctx_->getVerificationMode(),
661         AsyncSSLSocket::sslVerifyCallback);
662     }
663   } else {
664     if (verifyPeer_ == SSLContext::SSLVerifyPeerEnum::VERIFY ||
665         verifyPeer_ == SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT) {
666       SSL_set_verify(ssl, SSLContext::getVerificationMode(verifyPeer_),
667         AsyncSSLSocket::sslVerifyCallback);
668     }
669   }
670 }
671
672 bool AsyncSSLSocket::setupSSLBio() {
673   auto wb = BIO_new(&sslWriteBioMethod);
674
675   if (!wb) {
676     return false;
677   }
678
679   OpenSSLUtils::setBioAppData(wb, this);
680   OpenSSLUtils::setBioFd(wb, fd_, BIO_NOCLOSE);
681   SSL_set_bio(ssl_, wb, wb);
682   return true;
683 }
684
685 void AsyncSSLSocket::sslConn(
686     HandshakeCB* callback,
687     std::chrono::milliseconds timeout,
688     const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
689   DestructorGuard dg(this);
690   assert(eventBase_->isInEventBaseThread());
691
692   // Cache local and remote socket addresses to keep them available
693   // after socket file descriptor is closed.
694   if (cacheAddrOnFailure_ && -1 != getFd()) {
695     cacheLocalPeerAddr();
696   }
697
698   verifyPeer_ = verifyPeer;
699
700   // Make sure we're in the uninitialized state
701   if (server_ || (sslState_ != STATE_UNINIT && sslState_ !=
702                   STATE_UNENCRYPTED) ||
703       handshakeCallback_ != nullptr) {
704     return invalidState(callback);
705   }
706
707   sslState_ = STATE_CONNECTING;
708   handshakeCallback_ = callback;
709
710   try {
711     ssl_ = ctx_->createSSL();
712   } catch (std::exception &e) {
713     sslState_ = STATE_ERROR;
714     AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
715                            "error calling SSLContext::createSSL()");
716     LOG(ERROR) << "AsyncSSLSocket::sslConn(this=" << this << ", fd="
717             << fd_ << "): " << e.what();
718     return failHandshake(__func__, ex);
719   }
720
721   if (!setupSSLBio()) {
722     sslState_ = STATE_ERROR;
723     AsyncSocketException ex(
724         AsyncSocketException::INTERNAL_ERROR, "error creating SSL bio");
725     return failHandshake(__func__, ex);
726   }
727
728   applyVerificationOptions(ssl_);
729
730   if (sslSession_ != nullptr) {
731     sessionResumptionAttempted_ = true;
732     SSL_set_session(ssl_, sslSession_);
733     SSL_SESSION_free(sslSession_);
734     sslSession_ = nullptr;
735   }
736 #if FOLLY_OPENSSL_HAS_SNI
737   if (tlsextHostname_.size()) {
738     SSL_set_tlsext_host_name(ssl_, tlsextHostname_.c_str());
739   }
740 #endif
741
742   SSL_set_ex_data(ssl_, getSSLExDataIndex(), this);
743
744   handshakeConnectTimeout_ = timeout;
745   startSSLConnect();
746 }
747
748 // This could be called multiple times, during normal ssl connections
749 // and after TFO fallback.
750 void AsyncSSLSocket::startSSLConnect() {
751   handshakeStartTime_ = std::chrono::steady_clock::now();
752   // Make end time at least >= start time.
753   handshakeEndTime_ = handshakeStartTime_;
754   if (handshakeConnectTimeout_ > std::chrono::milliseconds::zero()) {
755     handshakeTimeout_.scheduleTimeout(handshakeConnectTimeout_);
756   }
757   handleConnect();
758 }
759
760 SSL_SESSION *AsyncSSLSocket::getSSLSession() {
761   if (ssl_ != nullptr && sslState_ == STATE_ESTABLISHED) {
762     return SSL_get1_session(ssl_);
763   }
764
765   return sslSession_;
766 }
767
768 const SSL* AsyncSSLSocket::getSSL() const {
769   return ssl_;
770 }
771
772 void AsyncSSLSocket::setSSLSession(SSL_SESSION *session, bool takeOwnership) {
773   sslSession_ = session;
774   if (!takeOwnership && session != nullptr) {
775     // Increment the reference count
776     // This API exists in BoringSSL and OpenSSL 1.1.0
777     SSL_SESSION_up_ref(session);
778   }
779 }
780
781 void AsyncSSLSocket::getSelectedNextProtocol(
782     const unsigned char** protoName,
783     unsigned* protoLen,
784     SSLContext::NextProtocolType* protoType) const {
785   if (!getSelectedNextProtocolNoThrow(protoName, protoLen, protoType)) {
786     throw AsyncSocketException(AsyncSocketException::NOT_SUPPORTED,
787                               "NPN not supported");
788   }
789 }
790
791 bool AsyncSSLSocket::getSelectedNextProtocolNoThrow(
792     const unsigned char** protoName,
793     unsigned* protoLen,
794     SSLContext::NextProtocolType* protoType) const {
795   *protoName = nullptr;
796   *protoLen = 0;
797 #if FOLLY_OPENSSL_HAS_ALPN
798   SSL_get0_alpn_selected(ssl_, protoName, protoLen);
799   if (*protoLen > 0) {
800     if (protoType) {
801       *protoType = SSLContext::NextProtocolType::ALPN;
802     }
803     return true;
804   }
805 #endif
806 #ifdef OPENSSL_NPN_NEGOTIATED
807   SSL_get0_next_proto_negotiated(ssl_, protoName, protoLen);
808   if (protoType) {
809     *protoType = SSLContext::NextProtocolType::NPN;
810   }
811   return true;
812 #else
813   (void)protoType;
814   return false;
815 #endif
816 }
817
818 bool AsyncSSLSocket::getSSLSessionReused() const {
819   if (ssl_ != nullptr && sslState_ == STATE_ESTABLISHED) {
820     return SSL_session_reused(ssl_);
821   }
822   return false;
823 }
824
825 const char *AsyncSSLSocket::getNegotiatedCipherName() const {
826   return (ssl_ != nullptr) ? SSL_get_cipher_name(ssl_) : nullptr;
827 }
828
829 /* static */
830 const char* AsyncSSLSocket::getSSLServerNameFromSSL(SSL* ssl) {
831   if (ssl == nullptr) {
832     return nullptr;
833   }
834 #ifdef SSL_CTRL_SET_TLSEXT_SERVERNAME_CB
835   return SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
836 #else
837   return nullptr;
838 #endif
839 }
840
841 const char *AsyncSSLSocket::getSSLServerName() const {
842 #ifdef SSL_CTRL_SET_TLSEXT_SERVERNAME_CB
843   return getSSLServerNameFromSSL(ssl_);
844 #else
845   throw AsyncSocketException(AsyncSocketException::NOT_SUPPORTED,
846                              "SNI not supported");
847 #endif
848 }
849
850 const char *AsyncSSLSocket::getSSLServerNameNoThrow() const {
851   return getSSLServerNameFromSSL(ssl_);
852 }
853
854 int AsyncSSLSocket::getSSLVersion() const {
855   return (ssl_ != nullptr) ? SSL_version(ssl_) : 0;
856 }
857
858 const char *AsyncSSLSocket::getSSLCertSigAlgName() const {
859   X509 *cert = (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr;
860   if (cert) {
861     int nid = OBJ_obj2nid(cert->sig_alg->algorithm);
862     return OBJ_nid2ln(nid);
863   }
864   return nullptr;
865 }
866
867 int AsyncSSLSocket::getSSLCertSize() const {
868   int certSize = 0;
869   X509 *cert = (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr;
870   if (cert) {
871     EVP_PKEY *key = X509_get_pubkey(cert);
872     certSize = EVP_PKEY_bits(key);
873     EVP_PKEY_free(key);
874   }
875   return certSize;
876 }
877
878 const X509* AsyncSSLSocket::getSelfCert() const {
879   return (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr;
880 }
881
882 bool AsyncSSLSocket::willBlock(int ret,
883                                int* sslErrorOut,
884                                unsigned long* errErrorOut) noexcept {
885   *errErrorOut = 0;
886   int error = *sslErrorOut = SSL_get_error(ssl_, ret);
887   if (error == SSL_ERROR_WANT_READ) {
888     // Register for read event if not already.
889     updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
890     return true;
891   } else if (error == SSL_ERROR_WANT_WRITE) {
892     VLOG(3) << "AsyncSSLSocket(fd=" << fd_
893             << ", state=" << int(state_) << ", sslState="
894             << sslState_ << ", events=" << eventFlags_ << "): "
895             << "SSL_ERROR_WANT_WRITE";
896     // Register for write event if not already.
897     updateEventRegistration(EventHandler::WRITE, EventHandler::READ);
898     return true;
899 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
900   } else if (error == SSL_ERROR_WANT_SESS_CACHE_LOOKUP) {
901     // We will block but we can't register our own socket.  The callback that
902     // triggered this code will re-call handleAccept at the appropriate time.
903
904     // We can only get here if the linked libssl.so has support for this feature
905     // as well, otherwise SSL_get_error cannot return our error code.
906     sslState_ = STATE_CACHE_LOOKUP;
907
908     // Unregister for all events while blocked here
909     updateEventRegistration(EventHandler::NONE,
910                             EventHandler::READ | EventHandler::WRITE);
911
912     // The timeout (if set) keeps running here
913     return true;
914 #endif
915   } else if (0
916 #ifdef SSL_ERROR_WANT_RSA_ASYNC_PENDING
917       || error == SSL_ERROR_WANT_RSA_ASYNC_PENDING
918 #endif
919 #ifdef SSL_ERROR_WANT_ECDSA_ASYNC_PENDING
920       || error == SSL_ERROR_WANT_ECDSA_ASYNC_PENDING
921 #endif
922       ) {
923     // Our custom openssl function has kicked off an async request to do
924     // rsa/ecdsa private key operation.  When that call returns, a callback will
925     // be invoked that will re-call handleAccept.
926     sslState_ = STATE_ASYNC_PENDING;
927
928     // Unregister for all events while blocked here
929     updateEventRegistration(
930       EventHandler::NONE,
931       EventHandler::READ | EventHandler::WRITE
932     );
933
934     // The timeout (if set) keeps running here
935     return true;
936   } else {
937     unsigned long lastError = *errErrorOut = ERR_get_error();
938     VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", "
939             << "state=" << state_ << ", "
940             << "sslState=" << sslState_ << ", "
941             << "events=" << std::hex << eventFlags_ << "): "
942             << "SSL error: " << error << ", "
943             << "errno: " << errno << ", "
944             << "ret: " << ret << ", "
945             << "read: " << BIO_number_read(SSL_get_rbio(ssl_)) << ", "
946             << "written: " << BIO_number_written(SSL_get_wbio(ssl_)) << ", "
947             << "func: " << ERR_func_error_string(lastError) << ", "
948             << "reason: " << ERR_reason_error_string(lastError);
949     return false;
950   }
951 }
952
953 void AsyncSSLSocket::checkForImmediateRead() noexcept {
954   // openssl may have buffered data that it read from the socket already.
955   // In this case we have to process it immediately, rather than waiting for
956   // the socket to become readable again.
957   if (ssl_ != nullptr && SSL_pending(ssl_) > 0) {
958     AsyncSocket::handleRead();
959   }
960 }
961
962 void
963 AsyncSSLSocket::restartSSLAccept()
964 {
965   VLOG(3) << "AsyncSSLSocket::restartSSLAccept() this=" << this
966           << ", fd=" << fd_ << ", state=" << int(state_) << ", "
967           << "sslState=" << sslState_ << ", events=" << eventFlags_;
968   DestructorGuard dg(this);
969   assert(
970     sslState_ == STATE_CACHE_LOOKUP ||
971     sslState_ == STATE_ASYNC_PENDING ||
972     sslState_ == STATE_ERROR ||
973     sslState_ == STATE_CLOSED);
974   if (sslState_ == STATE_CLOSED) {
975     // I sure hope whoever closed this socket didn't delete it already,
976     // but this is not strictly speaking an error
977     return;
978   }
979   if (sslState_ == STATE_ERROR) {
980     // go straight to fail if timeout expired during lookup
981     AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
982                            "SSL accept timed out");
983     failHandshake(__func__, ex);
984     return;
985   }
986   sslState_ = STATE_ACCEPTING;
987   this->handleAccept();
988 }
989
990 void
991 AsyncSSLSocket::handleAccept() noexcept {
992   VLOG(3) << "AsyncSSLSocket::handleAccept() this=" << this
993           << ", fd=" << fd_ << ", state=" << int(state_) << ", "
994           << "sslState=" << sslState_ << ", events=" << eventFlags_;
995   assert(server_);
996   assert(state_ == StateEnum::ESTABLISHED &&
997          sslState_ == STATE_ACCEPTING);
998   if (!ssl_) {
999     /* lazily create the SSL structure */
1000     try {
1001       ssl_ = ctx_->createSSL();
1002     } catch (std::exception &e) {
1003       sslState_ = STATE_ERROR;
1004       AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
1005                              "error calling SSLContext::createSSL()");
1006       LOG(ERROR) << "AsyncSSLSocket::handleAccept(this=" << this
1007                  << ", fd=" << fd_ << "): " << e.what();
1008       return failHandshake(__func__, ex);
1009     }
1010
1011     if (!setupSSLBio()) {
1012       sslState_ = STATE_ERROR;
1013       AsyncSocketException ex(
1014           AsyncSocketException::INTERNAL_ERROR, "error creating write bio");
1015       return failHandshake(__func__, ex);
1016     }
1017
1018     SSL_set_ex_data(ssl_, getSSLExDataIndex(), this);
1019
1020     applyVerificationOptions(ssl_);
1021   }
1022
1023   if (server_ && parseClientHello_) {
1024     SSL_set_msg_callback(ssl_, &AsyncSSLSocket::clientHelloParsingCallback);
1025     SSL_set_msg_callback_arg(ssl_, this);
1026   }
1027
1028   int ret = SSL_accept(ssl_);
1029   if (ret <= 0) {
1030     int sslError;
1031     unsigned long errError;
1032     int errnoCopy = errno;
1033     if (willBlock(ret, &sslError, &errError)) {
1034       return;
1035     } else {
1036       sslState_ = STATE_ERROR;
1037       SSLException ex(sslError, errError, ret, errnoCopy);
1038       return failHandshake(__func__, ex);
1039     }
1040   }
1041
1042   handshakeComplete_ = true;
1043   updateEventRegistration(0, EventHandler::READ | EventHandler::WRITE);
1044
1045   // Move into STATE_ESTABLISHED in the normal case that we are in
1046   // STATE_ACCEPTING.
1047   sslState_ = STATE_ESTABLISHED;
1048
1049   VLOG(3) << "AsyncSSLSocket " << this << ": fd " << fd_
1050           << " successfully accepted; state=" << int(state_)
1051           << ", sslState=" << sslState_ << ", events=" << eventFlags_;
1052
1053   // Remember the EventBase we are attached to, before we start invoking any
1054   // callbacks (since the callbacks may call detachEventBase()).
1055   EventBase* originalEventBase = eventBase_;
1056
1057   // Call the accept callback.
1058   invokeHandshakeCB();
1059
1060   // Note that the accept callback may have changed our state.
1061   // (set or unset the read callback, called write(), closed the socket, etc.)
1062   // The following code needs to handle these situations correctly.
1063   //
1064   // If the socket has been closed, readCallback_ and writeReqHead_ will
1065   // always be nullptr, so that will prevent us from trying to read or write.
1066   //
1067   // The main thing to check for is if eventBase_ is still originalEventBase.
1068   // If not, we have been detached from this event base, so we shouldn't
1069   // perform any more operations.
1070   if (eventBase_ != originalEventBase) {
1071     return;
1072   }
1073
1074   AsyncSocket::handleInitialReadWrite();
1075 }
1076
1077 void
1078 AsyncSSLSocket::handleConnect() noexcept {
1079   VLOG(3) <<  "AsyncSSLSocket::handleConnect() this=" << this
1080           << ", fd=" << fd_ << ", state=" << int(state_) << ", "
1081           << "sslState=" << sslState_ << ", events=" << eventFlags_;
1082   assert(!server_);
1083   if (state_ < StateEnum::ESTABLISHED) {
1084     return AsyncSocket::handleConnect();
1085   }
1086
1087   assert(
1088       (state_ == StateEnum::FAST_OPEN || state_ == StateEnum::ESTABLISHED) &&
1089       sslState_ == STATE_CONNECTING);
1090   assert(ssl_);
1091
1092   auto originalState = state_;
1093   int ret = SSL_connect(ssl_);
1094   if (ret <= 0) {
1095     int sslError;
1096     unsigned long errError;
1097     int errnoCopy = errno;
1098     if (willBlock(ret, &sslError, &errError)) {
1099       // We fell back to connecting state due to TFO
1100       if (state_ == StateEnum::CONNECTING) {
1101         DCHECK_EQ(StateEnum::FAST_OPEN, originalState);
1102         if (handshakeTimeout_.isScheduled()) {
1103           handshakeTimeout_.cancelTimeout();
1104         }
1105       }
1106       return;
1107     } else {
1108       sslState_ = STATE_ERROR;
1109       SSLException ex(sslError, errError, ret, errnoCopy);
1110       return failHandshake(__func__, ex);
1111     }
1112   }
1113
1114   handshakeComplete_ = true;
1115   updateEventRegistration(0, EventHandler::READ | EventHandler::WRITE);
1116
1117   // Move into STATE_ESTABLISHED in the normal case that we are in
1118   // STATE_CONNECTING.
1119   sslState_ = STATE_ESTABLISHED;
1120
1121   VLOG(3) << "AsyncSSLSocket " << this << ": "
1122           << "fd " << fd_ << " successfully connected; "
1123           << "state=" << int(state_) << ", sslState=" << sslState_
1124           << ", events=" << eventFlags_;
1125
1126   // Remember the EventBase we are attached to, before we start invoking any
1127   // callbacks (since the callbacks may call detachEventBase()).
1128   EventBase* originalEventBase = eventBase_;
1129
1130   // Call the handshake callback.
1131   invokeHandshakeCB();
1132
1133   // Note that the connect callback may have changed our state.
1134   // (set or unset the read callback, called write(), closed the socket, etc.)
1135   // The following code needs to handle these situations correctly.
1136   //
1137   // If the socket has been closed, readCallback_ and writeReqHead_ will
1138   // always be nullptr, so that will prevent us from trying to read or write.
1139   //
1140   // The main thing to check for is if eventBase_ is still originalEventBase.
1141   // If not, we have been detached from this event base, so we shouldn't
1142   // perform any more operations.
1143   if (eventBase_ != originalEventBase) {
1144     return;
1145   }
1146
1147   AsyncSocket::handleInitialReadWrite();
1148 }
1149
1150 void AsyncSSLSocket::invokeConnectErr(const AsyncSocketException& ex) {
1151   connectionTimeout_.cancelTimeout();
1152   AsyncSocket::invokeConnectErr(ex);
1153   if (sslState_ == SSLStateEnum::STATE_CONNECTING) {
1154     if (handshakeTimeout_.isScheduled()) {
1155       handshakeTimeout_.cancelTimeout();
1156     }
1157     // If we fell back to connecting state during TFO and the connection
1158     // failed, it would be an SSL failure as well.
1159     invokeHandshakeErr(ex);
1160   }
1161 }
1162
1163 void AsyncSSLSocket::invokeConnectSuccess() {
1164   connectionTimeout_.cancelTimeout();
1165   if (sslState_ == SSLStateEnum::STATE_CONNECTING) {
1166     assert(tfoAttempted_);
1167     // If we failed TFO, we'd fall back to trying to connect the socket,
1168     // to setup things like timeouts.
1169     startSSLConnect();
1170   }
1171   // still invoke the base class since it re-sets the connect time.
1172   AsyncSocket::invokeConnectSuccess();
1173 }
1174
1175 void AsyncSSLSocket::scheduleConnectTimeout() {
1176   if (sslState_ == SSLStateEnum::STATE_CONNECTING) {
1177     // We fell back from TFO, and need to set the timeouts.
1178     // We will not have a connect callback in this case, thus if the timer
1179     // expires we would have no-one to notify.
1180     // Thus we should reset even the connect timers to point to the handshake
1181     // timeouts.
1182     assert(connectCallback_ == nullptr);
1183     // We use a different connect timeout here than the handshake timeout, so
1184     // that we can disambiguate the 2 timers.
1185     if (connectTimeout_.count() > 0) {
1186       if (!connectionTimeout_.scheduleTimeout(connectTimeout_)) {
1187         throw AsyncSocketException(
1188             AsyncSocketException::INTERNAL_ERROR,
1189             withAddr("failed to schedule AsyncSSLSocket connect timeout"));
1190       }
1191     }
1192     return;
1193   }
1194   AsyncSocket::scheduleConnectTimeout();
1195 }
1196
1197 void AsyncSSLSocket::setReadCB(ReadCallback *callback) {
1198 #ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP
1199   // turn on the buffer movable in openssl
1200   if (bufferMovableEnabled_ && ssl_ != nullptr && !isBufferMovable_ &&
1201       callback != nullptr && callback->isBufferMovable()) {
1202     SSL_set_mode(ssl_, SSL_get_mode(ssl_) | SSL_MODE_MOVE_BUFFER_OWNERSHIP);
1203     isBufferMovable_ = true;
1204   }
1205 #endif
1206
1207   AsyncSocket::setReadCB(callback);
1208 }
1209
1210 void AsyncSSLSocket::setBufferMovableEnabled(bool enabled) {
1211   bufferMovableEnabled_ = enabled;
1212 }
1213
1214 void AsyncSSLSocket::prepareReadBuffer(void** buf, size_t* buflen) {
1215   CHECK(readCallback_);
1216   if (isBufferMovable_) {
1217     *buf = nullptr;
1218     *buflen = 0;
1219   } else {
1220     // buf is necessary for SSLSocket without SSL_MODE_MOVE_BUFFER_OWNERSHIP
1221     readCallback_->getReadBuffer(buf, buflen);
1222   }
1223 }
1224
1225 void
1226 AsyncSSLSocket::handleRead() noexcept {
1227   VLOG(5) << "AsyncSSLSocket::handleRead() this=" << this << ", fd=" << fd_
1228           << ", state=" << int(state_) << ", "
1229           << "sslState=" << sslState_ << ", events=" << eventFlags_;
1230   if (state_ < StateEnum::ESTABLISHED) {
1231     return AsyncSocket::handleRead();
1232   }
1233
1234
1235   if (sslState_ == STATE_ACCEPTING) {
1236     assert(server_);
1237     handleAccept();
1238     return;
1239   }
1240   else if (sslState_ == STATE_CONNECTING) {
1241     assert(!server_);
1242     handleConnect();
1243     return;
1244   }
1245
1246   // Normal read
1247   AsyncSocket::handleRead();
1248 }
1249
1250 AsyncSocket::ReadResult
1251 AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) {
1252   VLOG(4) << "AsyncSSLSocket::performRead() this=" << this << ", buf=" << *buf
1253           << ", buflen=" << *buflen;
1254
1255   if (sslState_ == STATE_UNENCRYPTED) {
1256     return AsyncSocket::performRead(buf, buflen, offset);
1257   }
1258
1259   int bytes = 0;
1260   if (!isBufferMovable_) {
1261     bytes = SSL_read(ssl_, *buf, int(*buflen));
1262   }
1263 #ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP
1264   else {
1265     bytes = SSL_read_buf(ssl_, buf, (int *) offset, (int *) buflen);
1266   }
1267 #endif
1268
1269   if (server_ && renegotiateAttempted_) {
1270     LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1271                << ", sslstate=" << sslState_ << ", events=" << eventFlags_
1272                << "): client intitiated SSL renegotiation not permitted";
1273     return ReadResult(
1274         READ_ERROR,
1275         folly::make_unique<SSLException>(SSLError::CLIENT_RENEGOTIATION));
1276   }
1277   if (bytes <= 0) {
1278     int error = SSL_get_error(ssl_, bytes);
1279     if (error == SSL_ERROR_WANT_READ) {
1280       // The caller will register for read event if not already.
1281       if (errno == EWOULDBLOCK || errno == EAGAIN) {
1282         return ReadResult(READ_BLOCKING);
1283       } else {
1284         return ReadResult(READ_ERROR);
1285       }
1286     } else if (error == SSL_ERROR_WANT_WRITE) {
1287       // TODO: Even though we are attempting to read data, SSL_read() may
1288       // need to write data if renegotiation is being performed.  We currently
1289       // don't support this and just fail the read.
1290       LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1291                  << ", sslState=" << sslState_ << ", events=" << eventFlags_
1292                  << "): unsupported SSL renegotiation during read";
1293       return ReadResult(
1294           READ_ERROR,
1295           folly::make_unique<SSLException>(SSLError::INVALID_RENEGOTIATION));
1296     } else {
1297       if (zero_return(error, bytes)) {
1298         return ReadResult(bytes);
1299       }
1300       long errError = ERR_get_error();
1301       VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", "
1302               << "state=" << state_ << ", "
1303               << "sslState=" << sslState_ << ", "
1304               << "events=" << std::hex << eventFlags_ << "): "
1305               << "bytes: " << bytes << ", "
1306               << "error: " << error << ", "
1307               << "errno: " << errno << ", "
1308               << "func: " << ERR_func_error_string(errError) << ", "
1309               << "reason: " << ERR_reason_error_string(errError);
1310       return ReadResult(
1311           READ_ERROR,
1312           folly::make_unique<SSLException>(error, errError, bytes, errno));
1313     }
1314   } else {
1315     appBytesReceived_ += bytes;
1316     return ReadResult(bytes);
1317   }
1318 }
1319
1320 void AsyncSSLSocket::handleWrite() noexcept {
1321   VLOG(5) << "AsyncSSLSocket::handleWrite() this=" << this << ", fd=" << fd_
1322           << ", state=" << int(state_) << ", "
1323           << "sslState=" << sslState_ << ", events=" << eventFlags_;
1324   if (state_ < StateEnum::ESTABLISHED) {
1325     return AsyncSocket::handleWrite();
1326   }
1327
1328   if (sslState_ == STATE_ACCEPTING) {
1329     assert(server_);
1330     handleAccept();
1331     return;
1332   }
1333
1334   if (sslState_ == STATE_CONNECTING) {
1335     assert(!server_);
1336     handleConnect();
1337     return;
1338   }
1339
1340   // Normal write
1341   AsyncSocket::handleWrite();
1342 }
1343
1344 AsyncSocket::WriteResult AsyncSSLSocket::interpretSSLError(int rc, int error) {
1345   if (error == SSL_ERROR_WANT_READ) {
1346     // Even though we are attempting to write data, SSL_write() may
1347     // need to read data if renegotiation is being performed.  We currently
1348     // don't support this and just fail the write.
1349     LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1350                << ", sslState=" << sslState_ << ", events=" << eventFlags_
1351                << "): "
1352                << "unsupported SSL renegotiation during write";
1353     return WriteResult(
1354         WRITE_ERROR,
1355         folly::make_unique<SSLException>(SSLError::INVALID_RENEGOTIATION));
1356   } else {
1357     if (zero_return(error, rc)) {
1358       return WriteResult(0);
1359     }
1360     auto errError = ERR_get_error();
1361     VLOG(3) << "ERROR: AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1362             << ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): "
1363             << "SSL error: " << error << ", errno: " << errno
1364             << ", func: " << ERR_func_error_string(errError)
1365             << ", reason: " << ERR_reason_error_string(errError);
1366     return WriteResult(
1367         WRITE_ERROR,
1368         folly::make_unique<SSLException>(error, errError, rc, errno));
1369   }
1370 }
1371
1372 AsyncSocket::WriteResult AsyncSSLSocket::performWrite(
1373     const iovec* vec,
1374     uint32_t count,
1375     WriteFlags flags,
1376     uint32_t* countWritten,
1377     uint32_t* partialWritten) {
1378   if (sslState_ == STATE_UNENCRYPTED) {
1379     return AsyncSocket::performWrite(
1380       vec, count, flags, countWritten, partialWritten);
1381   }
1382   if (sslState_ != STATE_ESTABLISHED) {
1383     LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1384                << ", sslState=" << sslState_
1385                << ", events=" << eventFlags_ << "): "
1386                << "TODO: AsyncSSLSocket currently does not support calling "
1387                << "write() before the handshake has fully completed";
1388     return WriteResult(
1389         WRITE_ERROR, folly::make_unique<SSLException>(SSLError::EARLY_WRITE));
1390   }
1391
1392   // Declare a buffer used to hold small write requests.  It could point to a
1393   // memory block either on stack or on heap. If it is on heap, we release it
1394   // manually when scope exits
1395   char* combinedBuf{nullptr};
1396   SCOPE_EXIT {
1397     // Note, always keep this check consistent with what we do below
1398     if (combinedBuf != nullptr && minWriteSize_ > MAX_STACK_BUF_SIZE) {
1399       delete[] combinedBuf;
1400     }
1401   };
1402
1403   *countWritten = 0;
1404   *partialWritten = 0;
1405   ssize_t totalWritten = 0;
1406   size_t bytesStolenFromNextBuffer = 0;
1407   for (uint32_t i = 0; i < count; i++) {
1408     const iovec* v = vec + i;
1409     size_t offset = bytesStolenFromNextBuffer;
1410     bytesStolenFromNextBuffer = 0;
1411     size_t len = v->iov_len - offset;
1412     const void* buf;
1413     if (len == 0) {
1414       (*countWritten)++;
1415       continue;
1416     }
1417     buf = ((const char*)v->iov_base) + offset;
1418
1419     ssize_t bytes;
1420     uint32_t buffersStolen = 0;
1421     auto sslWriteBuf = buf;
1422     if ((len < minWriteSize_) && ((i + 1) < count)) {
1423       // Combine this buffer with part or all of the next buffers in
1424       // order to avoid really small-grained calls to SSL_write().
1425       // Each call to SSL_write() produces a separate record in
1426       // the egress SSL stream, and we've found that some low-end
1427       // mobile clients can't handle receiving an HTTP response
1428       // header and the first part of the response body in two
1429       // separate SSL records (even if those two records are in
1430       // the same TCP packet).
1431
1432       if (combinedBuf == nullptr) {
1433         if (minWriteSize_ > MAX_STACK_BUF_SIZE) {
1434           // Allocate the buffer on heap
1435           combinedBuf = new char[minWriteSize_];
1436         } else {
1437           // Allocate the buffer on stack
1438           combinedBuf = (char*)alloca(minWriteSize_);
1439         }
1440       }
1441       assert(combinedBuf != nullptr);
1442       sslWriteBuf = combinedBuf;
1443
1444       memcpy(combinedBuf, buf, len);
1445       do {
1446         // INVARIANT: i + buffersStolen == complete chunks serialized
1447         uint32_t nextIndex = i + buffersStolen + 1;
1448         bytesStolenFromNextBuffer = std::min(vec[nextIndex].iov_len,
1449                                              minWriteSize_ - len);
1450         memcpy(combinedBuf + len, vec[nextIndex].iov_base,
1451                bytesStolenFromNextBuffer);
1452         len += bytesStolenFromNextBuffer;
1453         if (bytesStolenFromNextBuffer < vec[nextIndex].iov_len) {
1454           // couldn't steal the whole buffer
1455           break;
1456         } else {
1457           bytesStolenFromNextBuffer = 0;
1458           buffersStolen++;
1459         }
1460       } while ((i + buffersStolen + 1) < count && (len < minWriteSize_));
1461     }
1462
1463     // Advance any empty buffers immediately after.
1464     if (bytesStolenFromNextBuffer == 0) {
1465       while ((i + buffersStolen + 1) < count &&
1466              vec[i + buffersStolen + 1].iov_len == 0) {
1467         buffersStolen++;
1468       }
1469     }
1470
1471     corkCurrentWrite_ =
1472         isSet(flags, WriteFlags::CORK) || (i + buffersStolen + 1 < count);
1473     bytes = eorAwareSSLWrite(
1474         ssl_,
1475         sslWriteBuf,
1476         int(len),
1477         (isSet(flags, WriteFlags::EOR) && i + buffersStolen + 1 == count));
1478
1479     if (bytes <= 0) {
1480       int error = SSL_get_error(ssl_, int(bytes));
1481       if (error == SSL_ERROR_WANT_WRITE) {
1482         // The caller will register for write event if not already.
1483         *partialWritten = uint32_t(offset);
1484         return WriteResult(totalWritten);
1485       }
1486       auto writeResult = interpretSSLError(int(bytes), error);
1487       if (writeResult.writeReturn < 0) {
1488         return writeResult;
1489       } // else fall through to below to correctly record totalWritten
1490     }
1491
1492     totalWritten += bytes;
1493
1494     if (bytes == (ssize_t)len) {
1495       // The full iovec is written.
1496       (*countWritten) += 1 + buffersStolen;
1497       i += buffersStolen;
1498       // continue
1499     } else {
1500       bytes += offset; // adjust bytes to account for all of v
1501       while (bytes >= (ssize_t)v->iov_len) {
1502         // We combined this buf with part or all of the next one, and
1503         // we managed to write all of this buf but not all of the bytes
1504         // from the next one that we'd hoped to write.
1505         bytes -= v->iov_len;
1506         (*countWritten)++;
1507         v = &(vec[++i]);
1508       }
1509       *partialWritten = uint32_t(bytes);
1510       return WriteResult(totalWritten);
1511     }
1512   }
1513
1514   return WriteResult(totalWritten);
1515 }
1516
1517 int AsyncSSLSocket::eorAwareSSLWrite(SSL *ssl, const void *buf, int n,
1518                                       bool eor) {
1519   if (eor && trackEor_) {
1520     if (appEorByteNo_) {
1521       // cannot track for more than one app byte EOR
1522       CHECK(appEorByteNo_ == appBytesWritten_ + n);
1523     } else {
1524       appEorByteNo_ = appBytesWritten_ + n;
1525     }
1526
1527     // 1. It is fine to keep updating minEorRawByteNo_.
1528     // 2. It is _min_ in the sense that SSL record will add some overhead.
1529     minEorRawByteNo_ = getRawBytesWritten() + n;
1530   }
1531
1532   n = sslWriteImpl(ssl, buf, n);
1533   if (n > 0) {
1534     appBytesWritten_ += n;
1535     if (appEorByteNo_) {
1536       if (getRawBytesWritten() >= minEorRawByteNo_) {
1537         minEorRawByteNo_ = 0;
1538       }
1539       if(appBytesWritten_ == appEorByteNo_) {
1540         appEorByteNo_ = 0;
1541       } else {
1542         CHECK(appBytesWritten_ < appEorByteNo_);
1543       }
1544     }
1545   }
1546   return n;
1547 }
1548
1549 void AsyncSSLSocket::sslInfoCallback(const SSL* ssl, int where, int ret) {
1550   AsyncSSLSocket *sslSocket = AsyncSSLSocket::getFromSSL(ssl);
1551   if (sslSocket->handshakeComplete_ && (where & SSL_CB_HANDSHAKE_START)) {
1552     sslSocket->renegotiateAttempted_ = true;
1553   }
1554   if (where & SSL_CB_READ_ALERT) {
1555     const char* type = SSL_alert_type_string(ret);
1556     if (type) {
1557       const char* desc = SSL_alert_desc_string(ret);
1558       sslSocket->alertsReceived_.emplace_back(
1559           *type, StringPiece(desc, std::strlen(desc)));
1560     }
1561   }
1562 }
1563
1564 int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) {
1565   struct msghdr msg;
1566   struct iovec iov;
1567   int flags = 0;
1568   AsyncSSLSocket* tsslSock;
1569
1570   iov.iov_base = const_cast<char*>(in);
1571   iov.iov_len = inl;
1572   memset(&msg, 0, sizeof(msg));
1573   msg.msg_iov = &iov;
1574   msg.msg_iovlen = 1;
1575
1576   auto appData = OpenSSLUtils::getBioAppData(b);
1577   CHECK(appData);
1578
1579   tsslSock = reinterpret_cast<AsyncSSLSocket*>(appData);
1580   CHECK(tsslSock);
1581
1582   if (tsslSock->trackEor_ && tsslSock->minEorRawByteNo_ &&
1583       tsslSock->minEorRawByteNo_ <= BIO_number_written(b) + inl) {
1584     flags = MSG_EOR;
1585   }
1586
1587 #ifdef MSG_NOSIGNAL
1588   flags |= MSG_NOSIGNAL;
1589 #endif
1590
1591 #ifdef MSG_MORE
1592   if (tsslSock->corkCurrentWrite_) {
1593     flags |= MSG_MORE;
1594   }
1595 #endif
1596
1597   auto result = tsslSock->sendSocketMessage(
1598       OpenSSLUtils::getBioFd(b, nullptr), &msg, flags);
1599   BIO_clear_retry_flags(b);
1600   if (!result.exception && result.writeReturn <= 0) {
1601     if (OpenSSLUtils::getBioShouldRetryWrite(int(result.writeReturn))) {
1602       BIO_set_retry_write(b);
1603     }
1604   }
1605   return int(result.writeReturn);
1606 }
1607
1608 int AsyncSSLSocket::sslVerifyCallback(
1609     int preverifyOk,
1610     X509_STORE_CTX* x509Ctx) {
1611   SSL* ssl = (SSL*) X509_STORE_CTX_get_ex_data(
1612     x509Ctx, SSL_get_ex_data_X509_STORE_CTX_idx());
1613   AsyncSSLSocket* self = AsyncSSLSocket::getFromSSL(ssl);
1614
1615   VLOG(3) <<  "AsyncSSLSocket::sslVerifyCallback() this=" << self << ", "
1616           << "fd=" << self->fd_ << ", preverifyOk=" << preverifyOk;
1617   return (self->handshakeCallback_) ?
1618     self->handshakeCallback_->handshakeVer(self, preverifyOk, x509Ctx) :
1619     preverifyOk;
1620 }
1621
1622 void AsyncSSLSocket::enableClientHelloParsing()  {
1623     parseClientHello_ = true;
1624     clientHelloInfo_.reset(new ssl::ClientHelloInfo());
1625 }
1626
1627 void AsyncSSLSocket::resetClientHelloParsing(SSL *ssl)  {
1628   SSL_set_msg_callback(ssl, nullptr);
1629   SSL_set_msg_callback_arg(ssl, nullptr);
1630   clientHelloInfo_->clientHelloBuf_.clear();
1631 }
1632
1633 void AsyncSSLSocket::clientHelloParsingCallback(int written,
1634                                                 int /* version */,
1635                                                 int contentType,
1636                                                 const void* buf,
1637                                                 size_t len,
1638                                                 SSL* ssl,
1639                                                 void* arg) {
1640   AsyncSSLSocket *sock = static_cast<AsyncSSLSocket*>(arg);
1641   if (written != 0) {
1642     sock->resetClientHelloParsing(ssl);
1643     return;
1644   }
1645   if (contentType != SSL3_RT_HANDSHAKE) {
1646     return;
1647   }
1648   if (len == 0) {
1649     return;
1650   }
1651
1652   auto& clientHelloBuf = sock->clientHelloInfo_->clientHelloBuf_;
1653   clientHelloBuf.append(IOBuf::wrapBuffer(buf, len));
1654   try {
1655     Cursor cursor(clientHelloBuf.front());
1656     if (cursor.read<uint8_t>() != SSL3_MT_CLIENT_HELLO) {
1657       sock->resetClientHelloParsing(ssl);
1658       return;
1659     }
1660
1661     if (cursor.totalLength() < 3) {
1662       clientHelloBuf.trimEnd(len);
1663       clientHelloBuf.append(IOBuf::copyBuffer(buf, len));
1664       return;
1665     }
1666
1667     uint32_t messageLength = cursor.read<uint8_t>();
1668     messageLength <<= 8;
1669     messageLength |= cursor.read<uint8_t>();
1670     messageLength <<= 8;
1671     messageLength |= cursor.read<uint8_t>();
1672     if (cursor.totalLength() < messageLength) {
1673       clientHelloBuf.trimEnd(len);
1674       clientHelloBuf.append(IOBuf::copyBuffer(buf, len));
1675       return;
1676     }
1677
1678     sock->clientHelloInfo_->clientHelloMajorVersion_ = cursor.read<uint8_t>();
1679     sock->clientHelloInfo_->clientHelloMinorVersion_ = cursor.read<uint8_t>();
1680
1681     cursor.skip(4); // gmt_unix_time
1682     cursor.skip(28); // random_bytes
1683
1684     cursor.skip(cursor.read<uint8_t>()); // session_id
1685
1686     uint16_t cipherSuitesLength = cursor.readBE<uint16_t>();
1687     for (int i = 0; i < cipherSuitesLength; i += 2) {
1688       sock->clientHelloInfo_->
1689         clientHelloCipherSuites_.push_back(cursor.readBE<uint16_t>());
1690     }
1691
1692     uint8_t compressionMethodsLength = cursor.read<uint8_t>();
1693     for (int i = 0; i < compressionMethodsLength; ++i) {
1694       sock->clientHelloInfo_->
1695         clientHelloCompressionMethods_.push_back(cursor.readBE<uint8_t>());
1696     }
1697
1698     if (cursor.totalLength() > 0) {
1699       uint16_t extensionsLength = cursor.readBE<uint16_t>();
1700       while (extensionsLength) {
1701         ssl::TLSExtension extensionType =
1702             static_cast<ssl::TLSExtension>(cursor.readBE<uint16_t>());
1703         sock->clientHelloInfo_->
1704           clientHelloExtensions_.push_back(extensionType);
1705         extensionsLength -= 2;
1706         uint16_t extensionDataLength = cursor.readBE<uint16_t>();
1707         extensionsLength -= 2;
1708         extensionsLength -= extensionDataLength;
1709
1710         if (extensionType == ssl::TLSExtension::SIGNATURE_ALGORITHMS) {
1711           cursor.skip(2);
1712           extensionDataLength -= 2;
1713           while (extensionDataLength) {
1714             ssl::HashAlgorithm hashAlg =
1715                 static_cast<ssl::HashAlgorithm>(cursor.readBE<uint8_t>());
1716             ssl::SignatureAlgorithm sigAlg =
1717                 static_cast<ssl::SignatureAlgorithm>(cursor.readBE<uint8_t>());
1718             extensionDataLength -= 2;
1719             sock->clientHelloInfo_->
1720               clientHelloSigAlgs_.emplace_back(hashAlg, sigAlg);
1721           }
1722         } else if (extensionType == ssl::TLSExtension::SUPPORTED_VERSIONS) {
1723           cursor.skip(1);
1724           extensionDataLength -= 1;
1725           while (extensionDataLength) {
1726             sock->clientHelloInfo_->clientHelloSupportedVersions_.push_back(
1727                 cursor.readBE<uint16_t>());
1728             extensionDataLength -= 2;
1729           }
1730         } else {
1731           cursor.skip(extensionDataLength);
1732         }
1733       }
1734     }
1735   } catch (std::out_of_range&) {
1736     // we'll use what we found and cleanup below.
1737     VLOG(4) << "AsyncSSLSocket::clientHelloParsingCallback(): "
1738       << "buffer finished unexpectedly." << " AsyncSSLSocket socket=" << sock;
1739   }
1740
1741   sock->resetClientHelloParsing(ssl);
1742 }
1743
1744 void AsyncSSLSocket::getSSLClientCiphers(
1745     std::string& clientCiphers,
1746     bool convertToString) const {
1747   std::string ciphers;
1748
1749   if (parseClientHello_ == false
1750       || clientHelloInfo_->clientHelloCipherSuites_.empty()) {
1751     clientCiphers = "";
1752     return;
1753   }
1754
1755   bool first = true;
1756   for (auto originalCipherCode : clientHelloInfo_->clientHelloCipherSuites_)
1757   {
1758     if (first) {
1759       first = false;
1760     } else {
1761       ciphers +=  ":";
1762     }
1763
1764     bool nameFound = convertToString;
1765
1766     if (convertToString) {
1767       const auto& name = OpenSSLUtils::getCipherName(originalCipherCode);
1768       if (name.empty()) {
1769         nameFound = false;
1770       } else {
1771         ciphers += name;
1772       }
1773     }
1774
1775     if (!nameFound) {
1776       folly::hexlify(
1777           std::array<uint8_t, 2>{{
1778               static_cast<uint8_t>((originalCipherCode >> 8) & 0xffL),
1779               static_cast<uint8_t>(originalCipherCode & 0x00ffL) }},
1780           ciphers,
1781           /* append to ciphers = */ true);
1782     }
1783   }
1784
1785   clientCiphers = std::move(ciphers);
1786 }
1787
1788 std::string AsyncSSLSocket::getSSLClientComprMethods() const {
1789   if (!parseClientHello_) {
1790     return "";
1791   }
1792   return folly::join(":", clientHelloInfo_->clientHelloCompressionMethods_);
1793 }
1794
1795 std::string AsyncSSLSocket::getSSLClientExts() const {
1796   if (!parseClientHello_) {
1797     return "";
1798   }
1799   return folly::join(":", clientHelloInfo_->clientHelloExtensions_);
1800 }
1801
1802 std::string AsyncSSLSocket::getSSLClientSigAlgs() const {
1803   if (!parseClientHello_) {
1804     return "";
1805   }
1806
1807   std::string sigAlgs;
1808   sigAlgs.reserve(clientHelloInfo_->clientHelloSigAlgs_.size() * 4);
1809   for (size_t i = 0; i < clientHelloInfo_->clientHelloSigAlgs_.size(); i++) {
1810     if (i) {
1811       sigAlgs.push_back(':');
1812     }
1813     sigAlgs.append(folly::to<std::string>(
1814         clientHelloInfo_->clientHelloSigAlgs_[i].first));
1815     sigAlgs.push_back(',');
1816     sigAlgs.append(folly::to<std::string>(
1817         clientHelloInfo_->clientHelloSigAlgs_[i].second));
1818   }
1819
1820   return sigAlgs;
1821 }
1822
1823 std::string AsyncSSLSocket::getSSLClientSupportedVersions() const {
1824   if (!parseClientHello_) {
1825     return "";
1826   }
1827   return folly::join(":", clientHelloInfo_->clientHelloSupportedVersions_);
1828 }
1829
1830 std::string AsyncSSLSocket::getSSLAlertsReceived() const {
1831   std::string ret;
1832
1833   for (const auto& alert : alertsReceived_) {
1834     if (!ret.empty()) {
1835       ret.append(",");
1836     }
1837     ret.append(folly::to<std::string>(alert.first, ": ", alert.second));
1838   }
1839
1840   return ret;
1841 }
1842
1843 void AsyncSSLSocket::getSSLSharedCiphers(std::string& sharedCiphers) const {
1844   char ciphersBuffer[1024];
1845   ciphersBuffer[0] = '\0';
1846   SSL_get_shared_ciphers(ssl_, ciphersBuffer, sizeof(ciphersBuffer) - 1);
1847   sharedCiphers = ciphersBuffer;
1848 }
1849
1850 void AsyncSSLSocket::getSSLServerCiphers(std::string& serverCiphers) const {
1851   serverCiphers = SSL_get_cipher_list(ssl_, 0);
1852   int i = 1;
1853   const char *cipher;
1854   while ((cipher = SSL_get_cipher_list(ssl_, i)) != nullptr) {
1855     serverCiphers.append(":");
1856     serverCiphers.append(cipher);
1857     i++;
1858   }
1859 }
1860
1861 } // namespace