cleaning up RequestContext
[folly.git] / folly / io / async / AsyncSSLSocket.cpp
1 /*
2  * Copyright 2015 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/io/async/AsyncSSLSocket.h>
18
19 #include <folly/io/async/EventBase.h>
20
21 #include <boost/noncopyable.hpp>
22 #include <errno.h>
23 #include <fcntl.h>
24 #include <netinet/in.h>
25 #include <netinet/tcp.h>
26 #include <openssl/err.h>
27 #include <openssl/asn1.h>
28 #include <openssl/ssl.h>
29 #include <sys/types.h>
30 #include <sys/socket.h>
31 #include <unistd.h>
32 #include <chrono>
33
34 #include <folly/Bits.h>
35 #include <folly/SocketAddress.h>
36 #include <folly/SpinLock.h>
37 #include <folly/io/IOBuf.h>
38 #include <folly/io/Cursor.h>
39
40 using folly::SocketAddress;
41 using folly::SSLContext;
42 using std::string;
43 using std::shared_ptr;
44
45 using folly::Endian;
46 using folly::IOBuf;
47 using folly::SpinLock;
48 using folly::SpinLockGuard;
49 using folly::io::Cursor;
50 using std::unique_ptr;
51 using std::bind;
52
53 namespace {
54 using folly::AsyncSocket;
55 using folly::AsyncSocketException;
56 using folly::AsyncSSLSocket;
57 using folly::Optional;
58
59 // We have one single dummy SSL context so that we can implement attach
60 // and detach methods in a thread safe fashion without modifying opnessl.
61 static SSLContext *dummyCtx = nullptr;
62 static SpinLock dummyCtxLock;
63
64 // Numbers chosen as to not collide with functions in ssl.h
65 const uint8_t TASYNCSSLSOCKET_F_PERFORM_READ = 90;
66 const uint8_t TASYNCSSLSOCKET_F_PERFORM_WRITE = 91;
67
68 // If given min write size is less than this, buffer will be allocated on
69 // stack, otherwise it is allocated on heap
70 const size_t MAX_STACK_BUF_SIZE = 2048;
71
72 // This converts "illegal" shutdowns into ZERO_RETURN
73 inline bool zero_return(int error, int rc) {
74   return (error == SSL_ERROR_ZERO_RETURN || (rc == 0 && errno == 0));
75 }
76
77 class AsyncSSLSocketConnector: public AsyncSocket::ConnectCallback,
78                                 public AsyncSSLSocket::HandshakeCB {
79
80  private:
81   AsyncSSLSocket *sslSocket_;
82   AsyncSSLSocket::ConnectCallback *callback_;
83   int timeout_;
84   int64_t startTime_;
85
86  protected:
87   virtual ~AsyncSSLSocketConnector() {
88   }
89
90  public:
91   AsyncSSLSocketConnector(AsyncSSLSocket *sslSocket,
92                            AsyncSocket::ConnectCallback *callback,
93                            int timeout) :
94       sslSocket_(sslSocket),
95       callback_(callback),
96       timeout_(timeout),
97       startTime_(std::chrono::duration_cast<std::chrono::milliseconds>(
98                    std::chrono::steady_clock::now().time_since_epoch()).count()) {
99   }
100
101   virtual void connectSuccess() noexcept {
102     VLOG(7) << "client socket connected";
103
104     int64_t timeoutLeft = 0;
105     if (timeout_ > 0) {
106       auto curTime = std::chrono::duration_cast<std::chrono::milliseconds>(
107         std::chrono::steady_clock::now().time_since_epoch()).count();
108
109       timeoutLeft = timeout_ - (curTime - startTime_);
110       if (timeoutLeft <= 0) {
111         AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
112                                 "SSL connect timed out");
113         fail(ex);
114         delete this;
115         return;
116       }
117     }
118     sslSocket_->sslConn(this, timeoutLeft);
119   }
120
121   virtual void connectErr(const AsyncSocketException& ex) noexcept {
122     LOG(ERROR) << "TCP connect failed: " <<  ex.what();
123     fail(ex);
124     delete this;
125   }
126
127   virtual void handshakeSuc(AsyncSSLSocket *sock) noexcept {
128     VLOG(7) << "client handshake success";
129     if (callback_) {
130       callback_->connectSuccess();
131     }
132     delete this;
133   }
134
135   virtual void handshakeErr(AsyncSSLSocket *socket,
136                               const AsyncSocketException& ex) noexcept {
137     LOG(ERROR) << "client handshakeErr: " << ex.what();
138     fail(ex);
139     delete this;
140   }
141
142   void fail(const AsyncSocketException &ex) {
143     // fail is a noop if called twice
144     if (callback_) {
145       AsyncSSLSocket::ConnectCallback *cb = callback_;
146       callback_ = nullptr;
147
148       cb->connectErr(ex);
149       sslSocket_->closeNow();
150       // closeNow can call handshakeErr if it hasn't been called already.
151       // So this may have been deleted, no member variable access beyond this
152       // point
153       // Note that closeNow may invoke writeError callbacks if the socket had
154       // write data pending connection completion.
155     }
156   }
157 };
158
159 // XXX: implement an equivalent to corking for platforms with TCP_NOPUSH?
160 #ifdef TCP_CORK // Linux-only
161 /**
162  * Utility class that corks a TCP socket upon construction or uncorks
163  * the socket upon destruction
164  */
165 class CorkGuard : private boost::noncopyable {
166  public:
167   CorkGuard(int fd, bool multipleWrites, bool haveMore, bool* corked):
168     fd_(fd), haveMore_(haveMore), corked_(corked) {
169     if (*corked_) {
170       // socket is already corked; nothing to do
171       return;
172     }
173     if (multipleWrites || haveMore) {
174       // We are performing multiple writes in this performWrite() call,
175       // and/or there are more calls to performWrite() that will be invoked
176       // later, so enable corking
177       int flag = 1;
178       setsockopt(fd_, IPPROTO_TCP, TCP_CORK, &flag, sizeof(flag));
179       *corked_ = true;
180     }
181   }
182
183   ~CorkGuard() {
184     if (haveMore_) {
185       // more data to come; don't uncork yet
186       return;
187     }
188     if (!*corked_) {
189       // socket isn't corked; nothing to do
190       return;
191     }
192
193     int flag = 0;
194     setsockopt(fd_, IPPROTO_TCP, TCP_CORK, &flag, sizeof(flag));
195     *corked_ = false;
196   }
197
198  private:
199   int fd_;
200   bool haveMore_;
201   bool* corked_;
202 };
203 #else
204 class CorkGuard : private boost::noncopyable {
205  public:
206   CorkGuard(int, bool, bool, bool*) {}
207 };
208 #endif
209
210 void setup_SSL_CTX(SSL_CTX *ctx) {
211 #ifdef SSL_MODE_RELEASE_BUFFERS
212   SSL_CTX_set_mode(ctx,
213                    SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER |
214                    SSL_MODE_ENABLE_PARTIAL_WRITE
215                    | SSL_MODE_RELEASE_BUFFERS
216                    );
217 #else
218   SSL_CTX_set_mode(ctx,
219                    SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER |
220                    SSL_MODE_ENABLE_PARTIAL_WRITE
221                    );
222 #endif
223 }
224
225 BIO_METHOD eorAwareBioMethod;
226
227 void* initEorBioMethod(void) {
228   memcpy(&eorAwareBioMethod, BIO_s_socket(), sizeof(eorAwareBioMethod));
229   // override the bwrite method for MSG_EOR support
230   eorAwareBioMethod.bwrite = AsyncSSLSocket::eorAwareBioWrite;
231
232   // Note that the eorAwareBioMethod.type and eorAwareBioMethod.name are not
233   // set here. openssl code seems to be checking ".type == BIO_TYPE_SOCKET" and
234   // then have specific handlings. The eorAwareBioWrite should be compatible
235   // with the one in openssl.
236
237   // Return something here to enable AsyncSSLSocket to call this method using
238   // a function-scoped static.
239   return nullptr;
240 }
241
242 } // anonymous namespace
243
244 namespace folly {
245
246 SSLException::SSLException(int sslError, int errno_copy):
247     AsyncSocketException(
248       AsyncSocketException::SSL_ERROR,
249       ERR_error_string(sslError, msg_),
250       sslError == SSL_ERROR_SYSCALL ? errno_copy : 0), error_(sslError) {}
251
252 /**
253  * Create a client AsyncSSLSocket
254  */
255 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
256                                  EventBase* evb) :
257     AsyncSocket(evb),
258     ctx_(ctx),
259     handshakeTimeout_(this, evb) {
260   init();
261 }
262
263 /**
264  * Create a server/client AsyncSSLSocket
265  */
266 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
267                                  EventBase* evb, int fd, bool server) :
268     AsyncSocket(evb, fd),
269     server_(server),
270     ctx_(ctx),
271     handshakeTimeout_(this, evb) {
272   init();
273   if (server) {
274     SSL_CTX_set_info_callback(ctx_->getSSLCtx(),
275                               AsyncSSLSocket::sslInfoCallback);
276   }
277 }
278
279 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
280 /**
281  * Create a client AsyncSSLSocket and allow tlsext_hostname
282  * to be sent in Client Hello.
283  */
284 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
285                                  EventBase* evb,
286                                  const std::string& serverName) :
287     AsyncSSLSocket(ctx, evb) {
288   tlsextHostname_ = serverName;
289 }
290
291 /**
292  * Create a client AsyncSSLSocket from an already connected fd
293  * and allow tlsext_hostname to be sent in Client Hello.
294  */
295 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
296                                  EventBase* evb, int fd,
297                                  const std::string& serverName) :
298     AsyncSSLSocket(ctx, evb, fd, false) {
299   tlsextHostname_ = serverName;
300 }
301 #endif
302
303 AsyncSSLSocket::~AsyncSSLSocket() {
304   VLOG(3) << "actual destruction of AsyncSSLSocket(this=" << this
305           << ", evb=" << eventBase_ << ", fd=" << fd_
306           << ", state=" << int(state_) << ", sslState="
307           << sslState_ << ", events=" << eventFlags_ << ")";
308 }
309
310 void AsyncSSLSocket::init() {
311   // Do this here to ensure we initialize this once before any use of
312   // AsyncSSLSocket instances and not as part of library load.
313   static const auto eorAwareBioMethodInitializer = initEorBioMethod();
314   setup_SSL_CTX(ctx_->getSSLCtx());
315 }
316
317 void AsyncSSLSocket::closeNow() {
318   // Close the SSL connection.
319   if (ssl_ != nullptr && fd_ != -1) {
320     int rc = SSL_shutdown(ssl_);
321     if (rc == 0) {
322       rc = SSL_shutdown(ssl_);
323     }
324     if (rc < 0) {
325       ERR_clear_error();
326     }
327   }
328
329   if (sslSession_ != nullptr) {
330     SSL_SESSION_free(sslSession_);
331     sslSession_ = nullptr;
332   }
333
334   sslState_ = STATE_CLOSED;
335
336   if (handshakeTimeout_.isScheduled()) {
337     handshakeTimeout_.cancelTimeout();
338   }
339
340   DestructorGuard dg(this);
341
342   if (handshakeCallback_) {
343     AsyncSocketException ex(AsyncSocketException::END_OF_FILE,
344                            "SSL connection closed locally");
345     HandshakeCB* callback = handshakeCallback_;
346     handshakeCallback_ = nullptr;
347     callback->handshakeErr(this, ex);
348   }
349
350   if (ssl_ != nullptr) {
351     SSL_free(ssl_);
352     ssl_ = nullptr;
353   }
354
355   // Close the socket.
356   AsyncSocket::closeNow();
357 }
358
359 void AsyncSSLSocket::shutdownWrite() {
360   // SSL sockets do not support half-shutdown, so just perform a full shutdown.
361   //
362   // (Performing a full shutdown here is more desirable than doing nothing at
363   // all.  The purpose of shutdownWrite() is normally to notify the other end
364   // of the connection that no more data will be sent.  If we do nothing, the
365   // other end will never know that no more data is coming, and this may result
366   // in protocol deadlock.)
367   close();
368 }
369
370 void AsyncSSLSocket::shutdownWriteNow() {
371   closeNow();
372 }
373
374 bool AsyncSSLSocket::good() const {
375   return (AsyncSocket::good() &&
376           (sslState_ == STATE_ACCEPTING || sslState_ == STATE_CONNECTING ||
377            sslState_ == STATE_ESTABLISHED));
378 }
379
380 // The TAsyncTransport definition of 'good' states that the transport is
381 // ready to perform reads and writes, so sslState_ == UNINIT must report !good.
382 // connecting can be true when the sslState_ == UNINIT because the AsyncSocket
383 // is connected but we haven't initiated the call to SSL_connect.
384 bool AsyncSSLSocket::connecting() const {
385   return (!server_ &&
386           (AsyncSocket::connecting() ||
387            (AsyncSocket::good() && (sslState_ == STATE_UNINIT ||
388                                      sslState_ == STATE_CONNECTING))));
389 }
390
391 bool AsyncSSLSocket::isEorTrackingEnabled() const {
392   const BIO *wb = SSL_get_wbio(ssl_);
393   return wb && wb->method == &eorAwareBioMethod;
394 }
395
396 void AsyncSSLSocket::setEorTracking(bool track) {
397   BIO *wb = SSL_get_wbio(ssl_);
398   if (!wb) {
399     throw AsyncSocketException(AsyncSocketException::INVALID_STATE,
400                               "setting EOR tracking without an initialized "
401                               "BIO");
402   }
403
404   if (track) {
405     if (wb->method != &eorAwareBioMethod) {
406       // only do this if we didn't
407       wb->method = &eorAwareBioMethod;
408       BIO_set_app_data(wb, this);
409       appEorByteNo_ = 0;
410       minEorRawByteNo_ = 0;
411     }
412   } else if (wb->method == &eorAwareBioMethod) {
413     wb->method = BIO_s_socket();
414     BIO_set_app_data(wb, nullptr);
415     appEorByteNo_ = 0;
416     minEorRawByteNo_ = 0;
417   } else {
418     CHECK(wb->method == BIO_s_socket());
419   }
420 }
421
422 size_t AsyncSSLSocket::getRawBytesWritten() const {
423   BIO *b;
424   if (!ssl_ || !(b = SSL_get_wbio(ssl_))) {
425     return 0;
426   }
427
428   return BIO_number_written(b);
429 }
430
431 size_t AsyncSSLSocket::getRawBytesReceived() const {
432   BIO *b;
433   if (!ssl_ || !(b = SSL_get_rbio(ssl_))) {
434     return 0;
435   }
436
437   return BIO_number_read(b);
438 }
439
440
441 void AsyncSSLSocket::invalidState(HandshakeCB* callback) {
442   LOG(ERROR) << "AsyncSSLSocket(this=" << this << ", fd=" << fd_
443              << ", state=" << int(state_) << ", sslState=" << sslState_ << ", "
444              << "events=" << eventFlags_ << ", server=" << short(server_) << "): "
445              << "sslAccept/Connect() called in invalid "
446              << "state, handshake callback " << handshakeCallback_ << ", new callback "
447              << callback;
448   assert(!handshakeTimeout_.isScheduled());
449   sslState_ = STATE_ERROR;
450
451   AsyncSocketException ex(AsyncSocketException::INVALID_STATE,
452                          "sslAccept() called with socket in invalid state");
453
454   if (callback) {
455     callback->handshakeErr(this, ex);
456   }
457
458   // Check the socket state not the ssl state here.
459   if (state_ != StateEnum::CLOSED || state_ != StateEnum::ERROR) {
460     failHandshake(__func__, ex);
461   }
462 }
463
464 void AsyncSSLSocket::sslAccept(HandshakeCB* callback, uint32_t timeout,
465       const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
466   DestructorGuard dg(this);
467   assert(eventBase_->isInEventBaseThread());
468   verifyPeer_ = verifyPeer;
469
470   // Make sure we're in the uninitialized state
471   if (!server_ || sslState_ != STATE_UNINIT || handshakeCallback_ != nullptr) {
472     return invalidState(callback);
473   }
474
475   sslState_ = STATE_ACCEPTING;
476   handshakeCallback_ = callback;
477
478   if (timeout > 0) {
479     handshakeTimeout_.scheduleTimeout(timeout);
480   }
481
482   /* register for a read operation (waiting for CLIENT HELLO) */
483   updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
484 }
485
486 #if OPENSSL_VERSION_NUMBER >= 0x009080bfL
487 void AsyncSSLSocket::attachSSLContext(
488   const std::shared_ptr<SSLContext>& ctx) {
489
490   // Check to ensure we are in client mode. Changing a server's ssl
491   // context doesn't make sense since clients of that server would likely
492   // become confused when the server's context changes.
493   DCHECK(!server_);
494   DCHECK(!ctx_);
495   DCHECK(ctx);
496   DCHECK(ctx->getSSLCtx());
497   ctx_ = ctx;
498
499   // In order to call attachSSLContext, detachSSLContext must have been
500   // previously called which sets the socket's context to the dummy
501   // context. Thus we must acquire this lock.
502   SpinLockGuard guard(dummyCtxLock);
503   SSL_set_SSL_CTX(ssl_, ctx->getSSLCtx());
504 }
505
506 void AsyncSSLSocket::detachSSLContext() {
507   DCHECK(ctx_);
508   ctx_.reset();
509   // We aren't using the initial_ctx for now, and it can introduce race
510   // conditions in the destructor of the SSL object.
511 #ifndef OPENSSL_NO_TLSEXT
512   if (ssl_->initial_ctx) {
513     SSL_CTX_free(ssl_->initial_ctx);
514     ssl_->initial_ctx = nullptr;
515   }
516 #endif
517   SpinLockGuard guard(dummyCtxLock);
518   if (nullptr == dummyCtx) {
519     // We need to lazily initialize the dummy context so we don't
520     // accidentally override any programmatic settings to openssl
521     dummyCtx = new SSLContext;
522   }
523   // We must remove this socket's references to its context right now
524   // since this socket could get passed to any thread. If the context has
525   // had its locking disabled, just doing a set in attachSSLContext()
526   // would not be thread safe.
527   SSL_set_SSL_CTX(ssl_, dummyCtx->getSSLCtx());
528 }
529 #endif
530
531 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
532 void AsyncSSLSocket::switchServerSSLContext(
533   const std::shared_ptr<SSLContext>& handshakeCtx) {
534   CHECK(server_);
535   if (sslState_ != STATE_ACCEPTING) {
536     // We log it here and allow the switch.
537     // It should not affect our re-negotiation support (which
538     // is not supported now).
539     VLOG(6) << "fd=" << getFd()
540             << " renegotation detected when switching SSL_CTX";
541   }
542
543   setup_SSL_CTX(handshakeCtx->getSSLCtx());
544   SSL_CTX_set_info_callback(handshakeCtx->getSSLCtx(),
545                             AsyncSSLSocket::sslInfoCallback);
546   handshakeCtx_ = handshakeCtx;
547   SSL_set_SSL_CTX(ssl_, handshakeCtx->getSSLCtx());
548 }
549
550 bool AsyncSSLSocket::isServerNameMatch() const {
551   CHECK(!server_);
552
553   if (!ssl_) {
554     return false;
555   }
556
557   SSL_SESSION *ss = SSL_get_session(ssl_);
558   if (!ss) {
559     return false;
560   }
561
562   return (ss->tlsext_hostname ? true : false);
563 }
564
565 void AsyncSSLSocket::setServerName(std::string serverName) noexcept {
566   tlsextHostname_ = std::move(serverName);
567 }
568
569 #endif
570
571 void AsyncSSLSocket::timeoutExpired() noexcept {
572   if (state_ == StateEnum::ESTABLISHED &&
573       (sslState_ == STATE_CACHE_LOOKUP ||
574        sslState_ == STATE_RSA_ASYNC_PENDING)) {
575     sslState_ = STATE_ERROR;
576     // We are expecting a callback in restartSSLAccept.  The cache lookup
577     // and rsa-call necessarily have pointers to this ssl socket, so delay
578     // the cleanup until he calls us back.
579   } else {
580     assert(state_ == StateEnum::ESTABLISHED &&
581            (sslState_ == STATE_CONNECTING || sslState_ == STATE_ACCEPTING));
582     DestructorGuard dg(this);
583     AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
584                            (sslState_ == STATE_CONNECTING) ?
585                            "SSL connect timed out" : "SSL accept timed out");
586     failHandshake(__func__, ex);
587   }
588 }
589
590 int AsyncSSLSocket::sslExDataIndex_ = -1;
591 std::mutex AsyncSSLSocket::mutex_;
592
593 int AsyncSSLSocket::getSSLExDataIndex() {
594   if (sslExDataIndex_ < 0) {
595     std::lock_guard<std::mutex> g(mutex_);
596     if (sslExDataIndex_ < 0) {
597       sslExDataIndex_ = SSL_get_ex_new_index(0,
598           (void*)"AsyncSSLSocket data index", nullptr, nullptr, nullptr);
599     }
600   }
601   return sslExDataIndex_;
602 }
603
604 AsyncSSLSocket* AsyncSSLSocket::getFromSSL(const SSL *ssl) {
605   return static_cast<AsyncSSLSocket *>(SSL_get_ex_data(ssl,
606       getSSLExDataIndex()));
607 }
608
609 void AsyncSSLSocket::failHandshake(const char* fn,
610                                     const AsyncSocketException& ex) {
611   startFail();
612
613   if (handshakeTimeout_.isScheduled()) {
614     handshakeTimeout_.cancelTimeout();
615   }
616   if (handshakeCallback_ != nullptr) {
617     HandshakeCB* callback = handshakeCallback_;
618     handshakeCallback_ = nullptr;
619     callback->handshakeErr(this, ex);
620   }
621
622   finishFail();
623 }
624
625 void AsyncSSLSocket::invokeHandshakeCB() {
626   if (handshakeTimeout_.isScheduled()) {
627     handshakeTimeout_.cancelTimeout();
628   }
629   if (handshakeCallback_) {
630     HandshakeCB* callback = handshakeCallback_;
631     handshakeCallback_ = nullptr;
632     callback->handshakeSuc(this);
633   }
634 }
635
636 void AsyncSSLSocket::connect(ConnectCallback* callback,
637                               const folly::SocketAddress& address,
638                               int timeout,
639                               const OptionMap &options,
640                               const folly::SocketAddress& bindAddr)
641                               noexcept {
642   assert(!server_);
643   assert(state_ == StateEnum::UNINIT);
644   assert(sslState_ == STATE_UNINIT);
645   AsyncSSLSocketConnector *connector =
646     new AsyncSSLSocketConnector(this, callback, timeout);
647   AsyncSocket::connect(connector, address, timeout, options, bindAddr);
648 }
649
650 void AsyncSSLSocket::applyVerificationOptions(SSL * ssl) {
651   // apply the settings specified in verifyPeer_
652   if (verifyPeer_ == SSLContext::SSLVerifyPeerEnum::USE_CTX) {
653     if(ctx_->needsPeerVerification()) {
654       SSL_set_verify(ssl, ctx_->getVerificationMode(),
655         AsyncSSLSocket::sslVerifyCallback);
656     }
657   } else {
658     if (verifyPeer_ == SSLContext::SSLVerifyPeerEnum::VERIFY ||
659         verifyPeer_ == SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT) {
660       SSL_set_verify(ssl, SSLContext::getVerificationMode(verifyPeer_),
661         AsyncSSLSocket::sslVerifyCallback);
662     }
663   }
664 }
665
666 void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout,
667         const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
668   DestructorGuard dg(this);
669   assert(eventBase_->isInEventBaseThread());
670
671   verifyPeer_ = verifyPeer;
672
673   // Make sure we're in the uninitialized state
674   if (server_ || sslState_ != STATE_UNINIT || handshakeCallback_ != nullptr) {
675     return invalidState(callback);
676   }
677
678   sslState_ = STATE_CONNECTING;
679   handshakeCallback_ = callback;
680
681   try {
682     ssl_ = ctx_->createSSL();
683   } catch (std::exception &e) {
684     sslState_ = STATE_ERROR;
685     AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
686                            "error calling SSLContext::createSSL()");
687     LOG(ERROR) << "AsyncSSLSocket::sslConn(this=" << this << ", fd="
688             << fd_ << "): " << e.what();
689     return failHandshake(__func__, ex);
690   }
691
692   applyVerificationOptions(ssl_);
693
694   SSL_set_fd(ssl_, fd_);
695   if (sslSession_ != nullptr) {
696     SSL_set_session(ssl_, sslSession_);
697     SSL_SESSION_free(sslSession_);
698     sslSession_ = nullptr;
699   }
700 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
701   if (tlsextHostname_.size()) {
702     SSL_set_tlsext_host_name(ssl_, tlsextHostname_.c_str());
703   }
704 #endif
705
706   SSL_set_ex_data(ssl_, getSSLExDataIndex(), this);
707
708   if (timeout > 0) {
709     handshakeTimeout_.scheduleTimeout(timeout);
710   }
711
712   handleConnect();
713 }
714
715 SSL_SESSION *AsyncSSLSocket::getSSLSession() {
716   if (ssl_ != nullptr && sslState_ == STATE_ESTABLISHED) {
717     return SSL_get1_session(ssl_);
718   }
719
720   return sslSession_;
721 }
722
723 void AsyncSSLSocket::setSSLSession(SSL_SESSION *session, bool takeOwnership) {
724   sslSession_ = session;
725   if (!takeOwnership && session != nullptr) {
726     // Increment the reference count
727     CRYPTO_add(&session->references, 1, CRYPTO_LOCK_SSL_SESSION);
728   }
729 }
730
731 void AsyncSSLSocket::getSelectedNextProtocol(const unsigned char** protoName,
732     unsigned* protoLen) const {
733   if (!getSelectedNextProtocolNoThrow(protoName, protoLen)) {
734     throw AsyncSocketException(AsyncSocketException::NOT_SUPPORTED,
735                               "NPN not supported");
736   }
737 }
738
739 bool AsyncSSLSocket::getSelectedNextProtocolNoThrow(
740   const unsigned char** protoName,
741   unsigned* protoLen) const {
742   *protoName = nullptr;
743   *protoLen = 0;
744 #ifdef OPENSSL_NPN_NEGOTIATED
745   SSL_get0_next_proto_negotiated(ssl_, protoName, protoLen);
746   return true;
747 #else
748   return false;
749 #endif
750 }
751
752 bool AsyncSSLSocket::getSSLSessionReused() const {
753   if (ssl_ != nullptr && sslState_ == STATE_ESTABLISHED) {
754     return SSL_session_reused(ssl_);
755   }
756   return false;
757 }
758
759 const char *AsyncSSLSocket::getNegotiatedCipherName() const {
760   return (ssl_ != nullptr) ? SSL_get_cipher_name(ssl_) : nullptr;
761 }
762
763 const char *AsyncSSLSocket::getSSLServerName() const {
764 #ifdef SSL_CTRL_SET_TLSEXT_SERVERNAME_CB
765   return (ssl_ != nullptr) ? SSL_get_servername(ssl_, TLSEXT_NAMETYPE_host_name)
766         : nullptr;
767 #else
768   throw AsyncSocketException(AsyncSocketException::NOT_SUPPORTED,
769                             "SNI not supported");
770 #endif
771 }
772
773 const char *AsyncSSLSocket::getSSLServerNameNoThrow() const {
774   try {
775     return getSSLServerName();
776   } catch (AsyncSocketException& ex) {
777     return nullptr;
778   }
779 }
780
781 int AsyncSSLSocket::getSSLVersion() const {
782   return (ssl_ != nullptr) ? SSL_version(ssl_) : 0;
783 }
784
785 int AsyncSSLSocket::getSSLCertSize() const {
786   int certSize = 0;
787   X509 *cert = (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr;
788   if (cert) {
789     EVP_PKEY *key = X509_get_pubkey(cert);
790     certSize = EVP_PKEY_bits(key);
791     EVP_PKEY_free(key);
792   }
793   return certSize;
794 }
795
796 bool AsyncSSLSocket::willBlock(int ret, int *errorOut) noexcept {
797   int error = *errorOut = SSL_get_error(ssl_, ret);
798   if (error == SSL_ERROR_WANT_READ) {
799     // Register for read event if not already.
800     updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
801     return true;
802   } else if (error == SSL_ERROR_WANT_WRITE) {
803     VLOG(3) << "AsyncSSLSocket(fd=" << fd_
804             << ", state=" << int(state_) << ", sslState="
805             << sslState_ << ", events=" << eventFlags_ << "): "
806             << "SSL_ERROR_WANT_WRITE";
807     // Register for write event if not already.
808     updateEventRegistration(EventHandler::WRITE, EventHandler::READ);
809     return true;
810 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
811   } else if (error == SSL_ERROR_WANT_SESS_CACHE_LOOKUP) {
812     // We will block but we can't register our own socket.  The callback that
813     // triggered this code will re-call handleAccept at the appropriate time.
814
815     // We can only get here if the linked libssl.so has support for this feature
816     // as well, otherwise SSL_get_error cannot return our error code.
817     sslState_ = STATE_CACHE_LOOKUP;
818
819     // Unregister for all events while blocked here
820     updateEventRegistration(EventHandler::NONE,
821                             EventHandler::READ | EventHandler::WRITE);
822
823     // The timeout (if set) keeps running here
824     return true;
825 #endif
826 #ifdef SSL_ERROR_WANT_RSA_ASYNC_PENDING
827   } else if (error == SSL_ERROR_WANT_RSA_ASYNC_PENDING) {
828     // Our custom openssl function has kicked off an async request to do
829     // modular exponentiation.  When that call returns, a callback will
830     // be invoked that will re-call handleAccept.
831     sslState_ = STATE_RSA_ASYNC_PENDING;
832
833     // Unregister for all events while blocked here
834     updateEventRegistration(
835       EventHandler::NONE,
836       EventHandler::READ | EventHandler::WRITE
837     );
838
839     // The timeout (if set) keeps running here
840     return true;
841 #endif
842   } else {
843     // SSL_ERROR_ZERO_RETURN is processed here so we can get some detail
844     // in the log
845     long lastError = ERR_get_error();
846     VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", "
847             << "state=" << state_ << ", "
848             << "sslState=" << sslState_ << ", "
849             << "events=" << std::hex << eventFlags_ << "): "
850             << "SSL error: " << error << ", "
851             << "errno: " << errno << ", "
852             << "ret: " << ret << ", "
853             << "read: " << BIO_number_read(SSL_get_rbio(ssl_)) << ", "
854             << "written: " << BIO_number_written(SSL_get_wbio(ssl_)) << ", "
855             << "func: " << ERR_func_error_string(lastError) << ", "
856             << "reason: " << ERR_reason_error_string(lastError);
857     if (error != SSL_ERROR_SYSCALL) {
858       if (error == SSL_ERROR_SSL) {
859         *errorOut = lastError;
860       }
861       if ((unsigned long)lastError < 0x8000) {
862         errno = ENOSYS;
863       } else {
864         errno = lastError;
865       }
866     }
867     ERR_clear_error();
868     return false;
869   }
870 }
871
872 void AsyncSSLSocket::checkForImmediateRead() noexcept {
873   // openssl may have buffered data that it read from the socket already.
874   // In this case we have to process it immediately, rather than waiting for
875   // the socket to become readable again.
876   if (ssl_ != nullptr && SSL_pending(ssl_) > 0) {
877     AsyncSocket::handleRead();
878   }
879 }
880
881 void
882 AsyncSSLSocket::restartSSLAccept()
883 {
884   VLOG(3) << "AsyncSSLSocket::restartSSLAccept() this=" << this << ", fd=" << fd_
885           << ", state=" << int(state_) << ", "
886           << "sslState=" << sslState_ << ", events=" << eventFlags_;
887   DestructorGuard dg(this);
888   assert(
889     sslState_ == STATE_CACHE_LOOKUP ||
890     sslState_ == STATE_RSA_ASYNC_PENDING ||
891     sslState_ == STATE_ERROR ||
892     sslState_ == STATE_CLOSED
893   );
894   if (sslState_ == STATE_CLOSED) {
895     // I sure hope whoever closed this socket didn't delete it already,
896     // but this is not strictly speaking an error
897     return;
898   }
899   if (sslState_ == STATE_ERROR) {
900     // go straight to fail if timeout expired during lookup
901     AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
902                            "SSL accept timed out");
903     failHandshake(__func__, ex);
904     return;
905   }
906   sslState_ = STATE_ACCEPTING;
907   this->handleAccept();
908 }
909
910 void
911 AsyncSSLSocket::handleAccept() noexcept {
912   VLOG(3) << "AsyncSSLSocket::handleAccept() this=" << this
913           << ", fd=" << fd_ << ", state=" << int(state_) << ", "
914           << "sslState=" << sslState_ << ", events=" << eventFlags_;
915   assert(server_);
916   assert(state_ == StateEnum::ESTABLISHED &&
917          sslState_ == STATE_ACCEPTING);
918   if (!ssl_) {
919     /* lazily create the SSL structure */
920     try {
921       ssl_ = ctx_->createSSL();
922     } catch (std::exception &e) {
923       sslState_ = STATE_ERROR;
924       AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
925                              "error calling SSLContext::createSSL()");
926       LOG(ERROR) << "AsyncSSLSocket::handleAccept(this=" << this
927                  << ", fd=" << fd_ << "): " << e.what();
928       return failHandshake(__func__, ex);
929     }
930     SSL_set_fd(ssl_, fd_);
931     SSL_set_ex_data(ssl_, getSSLExDataIndex(), this);
932
933     applyVerificationOptions(ssl_);
934   }
935
936   if (server_ && parseClientHello_) {
937     SSL_set_msg_callback_arg(ssl_, this);
938     SSL_set_msg_callback(ssl_, &AsyncSSLSocket::clientHelloParsingCallback);
939   }
940
941   errno = 0;
942   int ret = SSL_accept(ssl_);
943   if (ret <= 0) {
944     int error;
945     if (willBlock(ret, &error)) {
946       return;
947     } else {
948       sslState_ = STATE_ERROR;
949       SSLException ex(error, errno);
950       return failHandshake(__func__, ex);
951     }
952   }
953
954   handshakeComplete_ = true;
955   updateEventRegistration(0, EventHandler::READ | EventHandler::WRITE);
956
957   // Move into STATE_ESTABLISHED in the normal case that we are in
958   // STATE_ACCEPTING.
959   sslState_ = STATE_ESTABLISHED;
960
961   VLOG(3) << "AsyncSSLSocket " << this << ": fd " << fd_
962           << " successfully accepted; state=" << int(state_)
963           << ", sslState=" << sslState_ << ", events=" << eventFlags_;
964
965   // Remember the EventBase we are attached to, before we start invoking any
966   // callbacks (since the callbacks may call detachEventBase()).
967   EventBase* originalEventBase = eventBase_;
968
969   // Call the accept callback.
970   invokeHandshakeCB();
971
972   // Note that the accept callback may have changed our state.
973   // (set or unset the read callback, called write(), closed the socket, etc.)
974   // The following code needs to handle these situations correctly.
975   //
976   // If the socket has been closed, readCallback_ and writeReqHead_ will
977   // always be nullptr, so that will prevent us from trying to read or write.
978   //
979   // The main thing to check for is if eventBase_ is still originalEventBase.
980   // If not, we have been detached from this event base, so we shouldn't
981   // perform any more operations.
982   if (eventBase_ != originalEventBase) {
983     return;
984   }
985
986   AsyncSocket::handleInitialReadWrite();
987 }
988
989 void
990 AsyncSSLSocket::handleConnect() noexcept {
991   VLOG(3) <<  "AsyncSSLSocket::handleConnect() this=" << this
992           << ", fd=" << fd_ << ", state=" << int(state_) << ", "
993           << "sslState=" << sslState_ << ", events=" << eventFlags_;
994   assert(!server_);
995   if (state_ < StateEnum::ESTABLISHED) {
996     return AsyncSocket::handleConnect();
997   }
998
999   assert(state_ == StateEnum::ESTABLISHED &&
1000          sslState_ == STATE_CONNECTING);
1001   assert(ssl_);
1002
1003   errno = 0;
1004   int ret = SSL_connect(ssl_);
1005   if (ret <= 0) {
1006     int error;
1007     if (willBlock(ret, &error)) {
1008       return;
1009     } else {
1010       sslState_ = STATE_ERROR;
1011       SSLException ex(error, errno);
1012       return failHandshake(__func__, ex);
1013     }
1014   }
1015
1016   handshakeComplete_ = true;
1017   updateEventRegistration(0, EventHandler::READ | EventHandler::WRITE);
1018
1019   // Move into STATE_ESTABLISHED in the normal case that we are in
1020   // STATE_CONNECTING.
1021   sslState_ = STATE_ESTABLISHED;
1022
1023   VLOG(3) << "AsyncSSLSocket %p: fd %d successfully connected; "
1024           << "state=" << int(state_) << ", sslState=" << sslState_
1025           << ", events=" << eventFlags_;
1026
1027   // Remember the EventBase we are attached to, before we start invoking any
1028   // callbacks (since the callbacks may call detachEventBase()).
1029   EventBase* originalEventBase = eventBase_;
1030
1031   // Call the handshake callback.
1032   invokeHandshakeCB();
1033
1034   // Note that the connect callback may have changed our state.
1035   // (set or unset the read callback, called write(), closed the socket, etc.)
1036   // The following code needs to handle these situations correctly.
1037   //
1038   // If the socket has been closed, readCallback_ and writeReqHead_ will
1039   // always be nullptr, so that will prevent us from trying to read or write.
1040   //
1041   // The main thing to check for is if eventBase_ is still originalEventBase.
1042   // If not, we have been detached from this event base, so we shouldn't
1043   // perform any more operations.
1044   if (eventBase_ != originalEventBase) {
1045     return;
1046   }
1047
1048   AsyncSocket::handleInitialReadWrite();
1049 }
1050
1051 void
1052 AsyncSSLSocket::handleRead() noexcept {
1053   VLOG(5) << "AsyncSSLSocket::handleRead() this=" << this << ", fd=" << fd_
1054           << ", state=" << int(state_) << ", "
1055           << "sslState=" << sslState_ << ", events=" << eventFlags_;
1056   if (state_ < StateEnum::ESTABLISHED) {
1057     return AsyncSocket::handleRead();
1058   }
1059
1060
1061   if (sslState_ == STATE_ACCEPTING) {
1062     assert(server_);
1063     handleAccept();
1064     return;
1065   }
1066   else if (sslState_ == STATE_CONNECTING) {
1067     assert(!server_);
1068     handleConnect();
1069     return;
1070   }
1071
1072   // Normal read
1073   AsyncSocket::handleRead();
1074 }
1075
1076 ssize_t
1077 AsyncSSLSocket::performRead(void* buf, size_t buflen) {
1078   errno = 0;
1079   ssize_t bytes = SSL_read(ssl_, buf, buflen);
1080   if (server_ && renegotiateAttempted_) {
1081     LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1082                << ", sslstate=" << sslState_ << ", events=" << eventFlags_ << "): "
1083                << "client intitiated SSL renegotiation not permitted";
1084     // We pack our own SSLerr here with a dummy function
1085     errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_READ,
1086                      SSL_CLIENT_RENEGOTIATION_ATTEMPT);
1087     ERR_clear_error();
1088     return READ_ERROR;
1089   }
1090   if (bytes <= 0) {
1091     int error = SSL_get_error(ssl_, bytes);
1092     if (error == SSL_ERROR_WANT_READ) {
1093       // The caller will register for read event if not already.
1094       return READ_BLOCKING;
1095     } else if (error == SSL_ERROR_WANT_WRITE) {
1096       // TODO: Even though we are attempting to read data, SSL_read() may
1097       // need to write data if renegotiation is being performed.  We currently
1098       // don't support this and just fail the read.
1099       LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1100                  << ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): "
1101                  << "unsupported SSL renegotiation during read",
1102       errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_READ,
1103                        SSL_INVALID_RENEGOTIATION);
1104       ERR_clear_error();
1105       return READ_ERROR;
1106     } else {
1107       // TODO: Fix this code so that it can return a proper error message
1108       // to the callback, rather than relying on AsyncSocket code which
1109       // can't handle SSL errors.
1110       long lastError = ERR_get_error();
1111
1112       VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", "
1113               << "state=" << state_ << ", "
1114               << "sslState=" << sslState_ << ", "
1115               << "events=" << std::hex << eventFlags_ << "): "
1116               << "bytes: " << bytes << ", "
1117               << "error: " << error << ", "
1118               << "errno: " << errno << ", "
1119               << "func: " << ERR_func_error_string(lastError) << ", "
1120               << "reason: " << ERR_reason_error_string(lastError);
1121       ERR_clear_error();
1122       if (zero_return(error, bytes)) {
1123         return bytes;
1124       }
1125       if (error != SSL_ERROR_SYSCALL) {
1126         if ((unsigned long)lastError < 0x8000) {
1127           errno = ENOSYS;
1128         } else {
1129           errno = lastError;
1130         }
1131       }
1132       return READ_ERROR;
1133     }
1134   } else {
1135     appBytesReceived_ += bytes;
1136     return bytes;
1137   }
1138 }
1139
1140 void AsyncSSLSocket::handleWrite() noexcept {
1141   VLOG(5) << "AsyncSSLSocket::handleWrite() this=" << this << ", fd=" << fd_
1142           << ", state=" << int(state_) << ", "
1143           << "sslState=" << sslState_ << ", events=" << eventFlags_;
1144   if (state_ < StateEnum::ESTABLISHED) {
1145     return AsyncSocket::handleWrite();
1146   }
1147
1148   if (sslState_ == STATE_ACCEPTING) {
1149     assert(server_);
1150     handleAccept();
1151     return;
1152   }
1153
1154   if (sslState_ == STATE_CONNECTING) {
1155     assert(!server_);
1156     handleConnect();
1157     return;
1158   }
1159
1160   // Normal write
1161   AsyncSocket::handleWrite();
1162 }
1163
1164 ssize_t AsyncSSLSocket::performWrite(const iovec* vec,
1165                                       uint32_t count,
1166                                       WriteFlags flags,
1167                                       uint32_t* countWritten,
1168                                       uint32_t* partialWritten) {
1169   if (sslState_ != STATE_ESTABLISHED) {
1170     LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1171                << ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): "
1172                << "TODO: AsyncSSLSocket currently does not support calling "
1173                << "write() before the handshake has fully completed";
1174       errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_WRITE,
1175                        SSL_EARLY_WRITE);
1176       return -1;
1177   }
1178
1179   bool cork = isSet(flags, WriteFlags::CORK);
1180   CorkGuard guard(fd_, count > 1, cork, &corked_);
1181
1182   // Declare a buffer used to hold small write requests.  It could point to a
1183   // memory block either on stack or on heap. If it is on heap, we release it
1184   // manually when scope exits
1185   char* combinedBuf{nullptr};
1186   SCOPE_EXIT {
1187     // Note, always keep this check consistent with what we do below
1188     if (combinedBuf != nullptr && minWriteSize_ > MAX_STACK_BUF_SIZE) {
1189       delete[] combinedBuf;
1190     }
1191   };
1192
1193   *countWritten = 0;
1194   *partialWritten = 0;
1195   ssize_t totalWritten = 0;
1196   size_t bytesStolenFromNextBuffer = 0;
1197   for (uint32_t i = 0; i < count; i++) {
1198     const iovec* v = vec + i;
1199     size_t offset = bytesStolenFromNextBuffer;
1200     bytesStolenFromNextBuffer = 0;
1201     size_t len = v->iov_len - offset;
1202     const void* buf;
1203     if (len == 0) {
1204       (*countWritten)++;
1205       continue;
1206     }
1207     buf = ((const char*)v->iov_base) + offset;
1208
1209     ssize_t bytes;
1210     errno = 0;
1211     uint32_t buffersStolen = 0;
1212     if ((len < minWriteSize_) && ((i + 1) < count)) {
1213       // Combine this buffer with part or all of the next buffers in
1214       // order to avoid really small-grained calls to SSL_write().
1215       // Each call to SSL_write() produces a separate record in
1216       // the egress SSL stream, and we've found that some low-end
1217       // mobile clients can't handle receiving an HTTP response
1218       // header and the first part of the response body in two
1219       // separate SSL records (even if those two records are in
1220       // the same TCP packet).
1221
1222       if (combinedBuf == nullptr) {
1223         if (minWriteSize_ > MAX_STACK_BUF_SIZE) {
1224           // Allocate the buffer on heap
1225           combinedBuf = new char[minWriteSize_];
1226         } else {
1227           // Allocate the buffer on stack
1228           combinedBuf = (char*)alloca(minWriteSize_);
1229         }
1230       }
1231       assert(combinedBuf != nullptr);
1232
1233       memcpy(combinedBuf, buf, len);
1234       do {
1235         // INVARIANT: i + buffersStolen == complete chunks serialized
1236         uint32_t nextIndex = i + buffersStolen + 1;
1237         bytesStolenFromNextBuffer = std::min(vec[nextIndex].iov_len,
1238                                              minWriteSize_ - len);
1239         memcpy(combinedBuf + len, vec[nextIndex].iov_base,
1240                bytesStolenFromNextBuffer);
1241         len += bytesStolenFromNextBuffer;
1242         if (bytesStolenFromNextBuffer < vec[nextIndex].iov_len) {
1243           // couldn't steal the whole buffer
1244           break;
1245         } else {
1246           bytesStolenFromNextBuffer = 0;
1247           buffersStolen++;
1248         }
1249       } while ((i + buffersStolen + 1) < count && (len < minWriteSize_));
1250       bytes = eorAwareSSLWrite(
1251         ssl_, combinedBuf, len,
1252         (isSet(flags, WriteFlags::EOR) && i + buffersStolen + 1 == count));
1253
1254     } else {
1255       bytes = eorAwareSSLWrite(ssl_, buf, len,
1256                            (isSet(flags, WriteFlags::EOR) && i + 1 == count));
1257     }
1258
1259     if (bytes <= 0) {
1260       int error = SSL_get_error(ssl_, bytes);
1261       if (error == SSL_ERROR_WANT_WRITE) {
1262         // The caller will register for write event if not already.
1263         *partialWritten = offset;
1264         return totalWritten;
1265       } else if (error == SSL_ERROR_WANT_READ) {
1266         // TODO: Even though we are attempting to write data, SSL_write() may
1267         // need to read data if renegotiation is being performed.  We currently
1268         // don't support this and just fail the write.
1269         LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1270                    << ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): "
1271                    << "unsupported SSL renegotiation during write",
1272         errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_WRITE,
1273                          SSL_INVALID_RENEGOTIATION);
1274         ERR_clear_error();
1275         return -1;
1276       } else {
1277         // TODO: Fix this code so that it can return a proper error message
1278         // to the callback, rather than relying on AsyncSocket code which
1279         // can't handle SSL errors.
1280         long lastError = ERR_get_error();
1281         VLOG(3) <<
1282           "ERROR: AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1283                 << ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): "
1284                 << "SSL error: " << error << ", errno: " << errno
1285                 << ", func: " << ERR_func_error_string(lastError)
1286                 << ", reason: " << ERR_reason_error_string(lastError);
1287         if (error != SSL_ERROR_SYSCALL) {
1288           if ((unsigned long)lastError < 0x8000) {
1289             errno = ENOSYS;
1290           } else {
1291             errno = lastError;
1292           }
1293         }
1294         ERR_clear_error();
1295         if (!zero_return(error, bytes)) {
1296           return -1;
1297         } // else fall through to below to correctly record totalWritten
1298       }
1299     }
1300
1301     totalWritten += bytes;
1302
1303     if (bytes == (ssize_t)len) {
1304       // The full iovec is written.
1305       (*countWritten) += 1 + buffersStolen;
1306       i += buffersStolen;
1307       // continue
1308     } else {
1309       bytes += offset; // adjust bytes to account for all of v
1310       while (bytes >= (ssize_t)v->iov_len) {
1311         // We combined this buf with part or all of the next one, and
1312         // we managed to write all of this buf but not all of the bytes
1313         // from the next one that we'd hoped to write.
1314         bytes -= v->iov_len;
1315         (*countWritten)++;
1316         v = &(vec[++i]);
1317       }
1318       *partialWritten = bytes;
1319       return totalWritten;
1320     }
1321   }
1322
1323   return totalWritten;
1324 }
1325
1326 int AsyncSSLSocket::eorAwareSSLWrite(SSL *ssl, const void *buf, int n,
1327                                       bool eor) {
1328   if (eor && SSL_get_wbio(ssl)->method == &eorAwareBioMethod) {
1329     if (appEorByteNo_) {
1330       // cannot track for more than one app byte EOR
1331       CHECK(appEorByteNo_ == appBytesWritten_ + n);
1332     } else {
1333       appEorByteNo_ = appBytesWritten_ + n;
1334     }
1335
1336     // 1. It is fine to keep updating minEorRawByteNo_.
1337     // 2. It is _min_ in the sense that SSL record will add some overhead.
1338     minEorRawByteNo_ = getRawBytesWritten() + n;
1339   }
1340
1341   n = sslWriteImpl(ssl, buf, n);
1342   if (n > 0) {
1343     appBytesWritten_ += n;
1344     if (appEorByteNo_) {
1345       if (getRawBytesWritten() >= minEorRawByteNo_) {
1346         minEorRawByteNo_ = 0;
1347       }
1348       if(appBytesWritten_ == appEorByteNo_) {
1349         appEorByteNo_ = 0;
1350       } else {
1351         CHECK(appBytesWritten_ < appEorByteNo_);
1352       }
1353     }
1354   }
1355   return n;
1356 }
1357
1358 void
1359 AsyncSSLSocket::sslInfoCallback(const SSL *ssl, int where, int ret) {
1360   AsyncSSLSocket *sslSocket = AsyncSSLSocket::getFromSSL(ssl);
1361   if (sslSocket->handshakeComplete_ && (where & SSL_CB_HANDSHAKE_START)) {
1362     sslSocket->renegotiateAttempted_ = true;
1363   }
1364 }
1365
1366 int AsyncSSLSocket::eorAwareBioWrite(BIO *b, const char *in, int inl) {
1367   int ret;
1368   struct msghdr msg;
1369   struct iovec iov;
1370   int flags = 0;
1371   AsyncSSLSocket *tsslSock;
1372
1373   iov.iov_base = const_cast<char *>(in);
1374   iov.iov_len = inl;
1375   memset(&msg, 0, sizeof(msg));
1376   msg.msg_iov = &iov;
1377   msg.msg_iovlen = 1;
1378
1379   tsslSock =
1380     reinterpret_cast<AsyncSSLSocket*>(BIO_get_app_data(b));
1381   if (tsslSock &&
1382       tsslSock->minEorRawByteNo_ &&
1383       tsslSock->minEorRawByteNo_ <= BIO_number_written(b) + inl) {
1384     flags = MSG_EOR;
1385   }
1386
1387   errno = 0;
1388   ret = sendmsg(b->num, &msg, flags);
1389   BIO_clear_retry_flags(b);
1390   if (ret <= 0) {
1391     if (BIO_sock_should_retry(ret))
1392       BIO_set_retry_write(b);
1393   }
1394   return(ret);
1395 }
1396
1397 int AsyncSSLSocket::sslVerifyCallback(int preverifyOk,
1398                                        X509_STORE_CTX* x509Ctx) {
1399   SSL* ssl = (SSL*) X509_STORE_CTX_get_ex_data(
1400     x509Ctx, SSL_get_ex_data_X509_STORE_CTX_idx());
1401   AsyncSSLSocket* self = AsyncSSLSocket::getFromSSL(ssl);
1402
1403   VLOG(3) <<  "AsyncSSLSocket::sslVerifyCallback() this=" << self << ", "
1404           << "fd=" << self->fd_ << ", preverifyOk=" << preverifyOk;
1405   return (self->handshakeCallback_) ?
1406     self->handshakeCallback_->handshakeVer(self, preverifyOk, x509Ctx) :
1407     preverifyOk;
1408 }
1409
1410 void AsyncSSLSocket::enableClientHelloParsing()  {
1411     parseClientHello_ = true;
1412     clientHelloInfo_.reset(new ClientHelloInfo());
1413 }
1414
1415 void AsyncSSLSocket::resetClientHelloParsing(SSL *ssl)  {
1416   SSL_set_msg_callback(ssl, nullptr);
1417   SSL_set_msg_callback_arg(ssl, nullptr);
1418   clientHelloInfo_->clientHelloBuf_.clear();
1419 }
1420
1421 void
1422 AsyncSSLSocket::clientHelloParsingCallback(int written, int version,
1423     int contentType, const void *buf, size_t len, SSL *ssl, void *arg)
1424 {
1425   AsyncSSLSocket *sock = static_cast<AsyncSSLSocket*>(arg);
1426   if (written != 0) {
1427     sock->resetClientHelloParsing(ssl);
1428     return;
1429   }
1430   if (contentType != SSL3_RT_HANDSHAKE) {
1431     sock->resetClientHelloParsing(ssl);
1432     return;
1433   }
1434   if (len == 0) {
1435     return;
1436   }
1437
1438   auto& clientHelloBuf = sock->clientHelloInfo_->clientHelloBuf_;
1439   clientHelloBuf.append(IOBuf::wrapBuffer(buf, len));
1440   try {
1441     Cursor cursor(clientHelloBuf.front());
1442     if (cursor.read<uint8_t>() != SSL3_MT_CLIENT_HELLO) {
1443       sock->resetClientHelloParsing(ssl);
1444       return;
1445     }
1446
1447     if (cursor.totalLength() < 3) {
1448       clientHelloBuf.trimEnd(len);
1449       clientHelloBuf.append(IOBuf::copyBuffer(buf, len));
1450       return;
1451     }
1452
1453     uint32_t messageLength = cursor.read<uint8_t>();
1454     messageLength <<= 8;
1455     messageLength |= cursor.read<uint8_t>();
1456     messageLength <<= 8;
1457     messageLength |= cursor.read<uint8_t>();
1458     if (cursor.totalLength() < messageLength) {
1459       clientHelloBuf.trimEnd(len);
1460       clientHelloBuf.append(IOBuf::copyBuffer(buf, len));
1461       return;
1462     }
1463
1464     sock->clientHelloInfo_->clientHelloMajorVersion_ = cursor.read<uint8_t>();
1465     sock->clientHelloInfo_->clientHelloMinorVersion_ = cursor.read<uint8_t>();
1466
1467     cursor.skip(4); // gmt_unix_time
1468     cursor.skip(28); // random_bytes
1469
1470     cursor.skip(cursor.read<uint8_t>()); // session_id
1471
1472     uint16_t cipherSuitesLength = cursor.readBE<uint16_t>();
1473     for (int i = 0; i < cipherSuitesLength; i += 2) {
1474       sock->clientHelloInfo_->
1475         clientHelloCipherSuites_.push_back(cursor.readBE<uint16_t>());
1476     }
1477
1478     uint8_t compressionMethodsLength = cursor.read<uint8_t>();
1479     for (int i = 0; i < compressionMethodsLength; ++i) {
1480       sock->clientHelloInfo_->
1481         clientHelloCompressionMethods_.push_back(cursor.readBE<uint8_t>());
1482     }
1483
1484     if (cursor.totalLength() > 0) {
1485       uint16_t extensionsLength = cursor.readBE<uint16_t>();
1486       while (extensionsLength) {
1487         sock->clientHelloInfo_->
1488           clientHelloExtensions_.push_back(cursor.readBE<uint16_t>());
1489         extensionsLength -= 2;
1490         uint16_t extensionDataLength = cursor.readBE<uint16_t>();
1491         extensionsLength -= 2;
1492         cursor.skip(extensionDataLength);
1493         extensionsLength -= extensionDataLength;
1494       }
1495     }
1496   } catch (std::out_of_range& e) {
1497     // we'll use what we found and cleanup below.
1498     VLOG(4) << "AsyncSSLSocket::clientHelloParsingCallback(): "
1499       << "buffer finished unexpectedly." << " AsyncSSLSocket socket=" << sock;
1500   }
1501
1502   sock->resetClientHelloParsing(ssl);
1503 }
1504
1505 } // namespace