Always override write bio method
authorSubodh Iyengar <subodh@fb.com>
Fri, 3 Jun 2016 19:40:37 +0000 (12:40 -0700)
committerFacebook Github Bot 7 <facebook-github-bot-7-bot@fb.com>
Fri, 3 Jun 2016 19:53:22 +0000 (12:53 -0700)
Summary:
Always overriding write bio method
allows us to more cleanly implement
features like eor tracking, support
multiple ssl libraries, and also TFO

Reviewed By: anirudhvr

Differential Revision: D3350482

fbshipit-source-id: ddd2333431f9d636d69c8325b2c18d7cc043b848

folly/io/async/AsyncSSLSocket.cpp
folly/io/async/AsyncSSLSocket.h

index 1656589cd34f1e8f6b2ae7bdc112d3edab2a22d4..4fa6b6fef26b57f6cbb62cb1c083a07f40fa1e3b 100644 (file)
@@ -223,16 +223,16 @@ void setup_SSL_CTX(SSL_CTX *ctx) {
 
 }
 
-BIO_METHOD eorAwareBioMethod;
+BIO_METHOD sslWriteBioMethod;
 
-void* initEorBioMethod(void) {
-  memcpy(&eorAwareBioMethod, BIO_s_socket(), sizeof(eorAwareBioMethod));
+void* initsslWriteBioMethod(void) {
+  memcpy(&sslWriteBioMethod, BIO_s_socket(), sizeof(sslWriteBioMethod));
   // override the bwrite method for MSG_EOR support
-  eorAwareBioMethod.bwrite = AsyncSSLSocket::eorAwareBioWrite;
+  sslWriteBioMethod.bwrite = AsyncSSLSocket::bioWrite;
 
-  // Note that the eorAwareBioMethod.type and eorAwareBioMethod.name are not
+  // Note that the sslWriteBioMethod.type and sslWriteBioMethod.name are not
   // set here. openssl code seems to be checking ".type == BIO_TYPE_SOCKET" and
-  // then have specific handlings. The eorAwareBioWrite should be compatible
+  // then have specific handlings. The sslWriteBioWrite should be compatible
   // with the one in openssl.
 
   // Return something here to enable AsyncSSLSocket to call this method using
@@ -314,8 +314,8 @@ AsyncSSLSocket::~AsyncSSLSocket() {
 void AsyncSSLSocket::init() {
   // Do this here to ensure we initialize this once before any use of
   // AsyncSSLSocket instances and not as part of library load.
-  static const auto eorAwareBioMethodInitializer = initEorBioMethod();
-  (void)eorAwareBioMethodInitializer;
+  static const auto sslWriteBioMethodInitializer = initsslWriteBioMethod();
+  (void)sslWriteBioMethodInitializer;
 
   setup_SSL_CTX(ctx_->getSSLCtx());
 }
@@ -401,36 +401,14 @@ std::string AsyncSSLSocket::getApplicationProtocol() noexcept {
 }
 
 bool AsyncSSLSocket::isEorTrackingEnabled() const {
-  if (ssl_ == nullptr) {
-    return false;
-  }
-  const BIO *wb = SSL_get_wbio(ssl_);
-  return wb && wb->method == &eorAwareBioMethod;
+  return trackEor_;
 }
 
 void AsyncSSLSocket::setEorTracking(bool track) {
-  BIO *wb = SSL_get_wbio(ssl_);
-  if (!wb) {
-    throw AsyncSocketException(AsyncSocketException::INVALID_STATE,
-                              "setting EOR tracking without an initialized "
-                              "BIO");
-  }
-
-  if (track) {
-    if (wb->method != &eorAwareBioMethod) {
-      // only do this if we didn't
-      wb->method = &eorAwareBioMethod;
-      BIO_set_app_data(wb, this);
-      appEorByteNo_ = 0;
-      minEorRawByteNo_ = 0;
-    }
-  } else if (wb->method == &eorAwareBioMethod) {
-    wb->method = BIO_s_socket();
-    BIO_set_app_data(wb, nullptr);
+  if (trackEor_ != track) {
+    trackEor_ = track;
     appEorByteNo_ = 0;
     minEorRawByteNo_ = 0;
-  } else {
-    CHECK(wb->method == BIO_s_socket());
   }
 }
 
@@ -703,6 +681,19 @@ void AsyncSSLSocket::applyVerificationOptions(SSL * ssl) {
   }
 }
 
+bool AsyncSSLSocket::setupSSLBio() {
+  auto wb = BIO_new(&sslWriteBioMethod);
+
+  if (!wb) {
+    return false;
+  }
+
+  BIO_set_app_data(wb, this);
+  BIO_set_fd(wb, fd_, BIO_NOCLOSE);
+  SSL_set_bio(ssl_, wb, wb);
+  return true;
+}
+
 void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout,
         const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
   DestructorGuard dg(this);
@@ -741,9 +732,15 @@ void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout,
     return failHandshake(__func__, ex);
   }
 
+  if (!setupSSLBio()) {
+    sslState_ = STATE_ERROR;
+    AsyncSocketException ex(
+        AsyncSocketException::INTERNAL_ERROR, "error creating SSL bio");
+    return failHandshake(__func__, ex);
+  }
+
   applyVerificationOptions(ssl_);
 
-  SSL_set_fd(ssl_, fd_);
   if (sslSession_ != nullptr) {
     SSL_set_session(ssl_, sslSession_);
     SSL_SESSION_free(sslSession_);
@@ -1010,7 +1007,14 @@ AsyncSSLSocket::handleAccept() noexcept {
                  << ", fd=" << fd_ << "): " << e.what();
       return failHandshake(__func__, ex);
     }
-    SSL_set_fd(ssl_, fd_);
+
+    if (!setupSSLBio()) {
+      sslState_ = STATE_ERROR;
+      AsyncSocketException ex(
+          AsyncSocketException::INTERNAL_ERROR, "error creating write bio");
+      return failHandshake(__func__, ex);
+    }
+
     SSL_set_ex_data(ssl_, getSSLExDataIndex(), this);
 
     applyVerificationOptions(ssl_);
@@ -1448,7 +1452,7 @@ AsyncSocket::WriteResult AsyncSSLSocket::performWrite(
 
 int AsyncSSLSocket::eorAwareSSLWrite(SSL *ssl, const void *buf, int n,
                                       bool eor) {
-  if (eor && SSL_get_wbio(ssl)->method == &eorAwareBioMethod) {
+  if (eor && trackEor_) {
     if (appEorByteNo_) {
       // cannot track for more than one app byte EOR
       CHECK(appEorByteNo_ == appBytesWritten_ + n);
@@ -1493,34 +1497,37 @@ void AsyncSSLSocket::sslInfoCallback(const SSL* ssl, int where, int ret) {
   }
 }
 
-int AsyncSSLSocket::eorAwareBioWrite(BIO *b, const char *in, int inl) {
+int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) {
   int ret;
   struct msghdr msg;
   struct iovec iov;
   int flags = 0;
-  AsyncSSLSocket *tsslSock;
+  AsyncSSLSockettsslSock;
 
-  iov.iov_base = const_cast<char *>(in);
+  iov.iov_base = const_cast<char*>(in);
   iov.iov_len = inl;
   memset(&msg, 0, sizeof(msg));
   msg.msg_iov = &iov;
   msg.msg_iovlen = 1;
 
-  tsslSock =
-    reinterpret_cast<AsyncSSLSocket*>(BIO_get_app_data(b));
-  if (tsslSock &&
-      tsslSock->minEorRawByteNo_ &&
+  auto appData = BIO_get_app_data(b);
+  CHECK(appData);
+
+  tsslSock = reinterpret_cast<AsyncSSLSocket*>(appData);
+  CHECK(tsslSock);
+
+  if (tsslSock->trackEor_ && tsslSock->minEorRawByteNo_ &&
       tsslSock->minEorRawByteNo_ <= BIO_number_written(b) + inl) {
     flags = MSG_EOR;
   }
 
-  ret = sendmsg(b->num, &msg, flags);
+  ret = sendmsg(BIO_get_fd(b, nullptr), &msg, flags);
   BIO_clear_retry_flags(b);
   if (ret <= 0) {
     if (BIO_sock_should_retry(ret))
       BIO_set_retry_write(b);
   }
-  return(ret);
+  return ret;
 }
 
 int AsyncSSLSocket::sslVerifyCallback(int preverifyOk,
index b4e4ca47b134a4099ff54e11c42d1cbe476b459e..40ceb87aa51a2935f604aa5e755b23a049a1f5d9 100644 (file)
@@ -652,7 +652,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
 
   static int getSSLExDataIndex();
   static AsyncSSLSocket* getFromSSL(const SSL *ssl);
-  static int eorAwareBioWrite(BIO *b, const char *in, int inl);
+  static int bioWrite(BIO* b, const char* in, int inl);
   void resetClientHelloParsing(SSL *ssl);
   static void clientHelloParsingCallback(int write_p, int version,
       int content_type, const void *buf, size_t len, SSL *ssl, void *arg);
@@ -774,6 +774,13 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    */
   void applyVerificationOptions(SSL * ssl);
 
+  /**
+   * Sets up SSL with a custom write bio which intercepts all writes.
+   *
+   * @return true, if succeeds and false if there is an error creating the bio.
+   */
+  bool setupSSLBio();
+
   /**
    * A SSL_write wrapper that understand EOR
    *
@@ -815,6 +822,9 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   // whether the SSL session was resumed using session ID or not
   bool sessionIDResumed_{false};
 
+  // Whether to track EOR or not.
+  bool trackEor_{false};
+
   // The app byte num that we are tracking for the MSG_EOR
   // Only one app EOR byte can be tracked.
   size_t appEorByteNo_{0};