From 8aac0e33471dad37f6af287b0b85de2b35fc33a5 Mon Sep 17 00:00:00 2001 From: Subodh Iyengar Date: Fri, 3 Jun 2016 12:40:37 -0700 Subject: [PATCH] Always override write bio method Summary: Always overriding write bio method allows us to more cleanly implement features like eor tracking, support multiple ssl libraries, and also TFO Reviewed By: anirudhvr Differential Revision: D3350482 fbshipit-source-id: ddd2333431f9d636d69c8325b2c18d7cc043b848 --- folly/io/async/AsyncSSLSocket.cpp | 97 +++++++++++++++++-------------- folly/io/async/AsyncSSLSocket.h | 12 +++- 2 files changed, 63 insertions(+), 46 deletions(-) diff --git a/folly/io/async/AsyncSSLSocket.cpp b/folly/io/async/AsyncSSLSocket.cpp index 1656589c..4fa6b6fe 100644 --- a/folly/io/async/AsyncSSLSocket.cpp +++ b/folly/io/async/AsyncSSLSocket.cpp @@ -223,16 +223,16 @@ void setup_SSL_CTX(SSL_CTX *ctx) { } -BIO_METHOD eorAwareBioMethod; +BIO_METHOD sslWriteBioMethod; -void* initEorBioMethod(void) { - memcpy(&eorAwareBioMethod, BIO_s_socket(), sizeof(eorAwareBioMethod)); +void* initsslWriteBioMethod(void) { + memcpy(&sslWriteBioMethod, BIO_s_socket(), sizeof(sslWriteBioMethod)); // override the bwrite method for MSG_EOR support - eorAwareBioMethod.bwrite = AsyncSSLSocket::eorAwareBioWrite; + sslWriteBioMethod.bwrite = AsyncSSLSocket::bioWrite; - // Note that the eorAwareBioMethod.type and eorAwareBioMethod.name are not + // Note that the sslWriteBioMethod.type and sslWriteBioMethod.name are not // set here. openssl code seems to be checking ".type == BIO_TYPE_SOCKET" and - // then have specific handlings. The eorAwareBioWrite should be compatible + // then have specific handlings. The sslWriteBioWrite should be compatible // with the one in openssl. // Return something here to enable AsyncSSLSocket to call this method using @@ -314,8 +314,8 @@ AsyncSSLSocket::~AsyncSSLSocket() { void AsyncSSLSocket::init() { // Do this here to ensure we initialize this once before any use of // AsyncSSLSocket instances and not as part of library load. - static const auto eorAwareBioMethodInitializer = initEorBioMethod(); - (void)eorAwareBioMethodInitializer; + static const auto sslWriteBioMethodInitializer = initsslWriteBioMethod(); + (void)sslWriteBioMethodInitializer; setup_SSL_CTX(ctx_->getSSLCtx()); } @@ -401,36 +401,14 @@ std::string AsyncSSLSocket::getApplicationProtocol() noexcept { } bool AsyncSSLSocket::isEorTrackingEnabled() const { - if (ssl_ == nullptr) { - return false; - } - const BIO *wb = SSL_get_wbio(ssl_); - return wb && wb->method == &eorAwareBioMethod; + return trackEor_; } void AsyncSSLSocket::setEorTracking(bool track) { - BIO *wb = SSL_get_wbio(ssl_); - if (!wb) { - throw AsyncSocketException(AsyncSocketException::INVALID_STATE, - "setting EOR tracking without an initialized " - "BIO"); - } - - if (track) { - if (wb->method != &eorAwareBioMethod) { - // only do this if we didn't - wb->method = &eorAwareBioMethod; - BIO_set_app_data(wb, this); - appEorByteNo_ = 0; - minEorRawByteNo_ = 0; - } - } else if (wb->method == &eorAwareBioMethod) { - wb->method = BIO_s_socket(); - BIO_set_app_data(wb, nullptr); + if (trackEor_ != track) { + trackEor_ = track; appEorByteNo_ = 0; minEorRawByteNo_ = 0; - } else { - CHECK(wb->method == BIO_s_socket()); } } @@ -703,6 +681,19 @@ void AsyncSSLSocket::applyVerificationOptions(SSL * ssl) { } } +bool AsyncSSLSocket::setupSSLBio() { + auto wb = BIO_new(&sslWriteBioMethod); + + if (!wb) { + return false; + } + + BIO_set_app_data(wb, this); + BIO_set_fd(wb, fd_, BIO_NOCLOSE); + SSL_set_bio(ssl_, wb, wb); + return true; +} + void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout, const SSLContext::SSLVerifyPeerEnum& verifyPeer) { DestructorGuard dg(this); @@ -741,9 +732,15 @@ void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout, return failHandshake(__func__, ex); } + if (!setupSSLBio()) { + sslState_ = STATE_ERROR; + AsyncSocketException ex( + AsyncSocketException::INTERNAL_ERROR, "error creating SSL bio"); + return failHandshake(__func__, ex); + } + applyVerificationOptions(ssl_); - SSL_set_fd(ssl_, fd_); if (sslSession_ != nullptr) { SSL_set_session(ssl_, sslSession_); SSL_SESSION_free(sslSession_); @@ -1010,7 +1007,14 @@ AsyncSSLSocket::handleAccept() noexcept { << ", fd=" << fd_ << "): " << e.what(); return failHandshake(__func__, ex); } - SSL_set_fd(ssl_, fd_); + + if (!setupSSLBio()) { + sslState_ = STATE_ERROR; + AsyncSocketException ex( + AsyncSocketException::INTERNAL_ERROR, "error creating write bio"); + return failHandshake(__func__, ex); + } + SSL_set_ex_data(ssl_, getSSLExDataIndex(), this); applyVerificationOptions(ssl_); @@ -1448,7 +1452,7 @@ AsyncSocket::WriteResult AsyncSSLSocket::performWrite( int AsyncSSLSocket::eorAwareSSLWrite(SSL *ssl, const void *buf, int n, bool eor) { - if (eor && SSL_get_wbio(ssl)->method == &eorAwareBioMethod) { + if (eor && trackEor_) { if (appEorByteNo_) { // cannot track for more than one app byte EOR CHECK(appEorByteNo_ == appBytesWritten_ + n); @@ -1493,34 +1497,37 @@ void AsyncSSLSocket::sslInfoCallback(const SSL* ssl, int where, int ret) { } } -int AsyncSSLSocket::eorAwareBioWrite(BIO *b, const char *in, int inl) { +int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) { int ret; struct msghdr msg; struct iovec iov; int flags = 0; - AsyncSSLSocket *tsslSock; + AsyncSSLSocket* tsslSock; - iov.iov_base = const_cast(in); + iov.iov_base = const_cast(in); iov.iov_len = inl; memset(&msg, 0, sizeof(msg)); msg.msg_iov = &iov; msg.msg_iovlen = 1; - tsslSock = - reinterpret_cast(BIO_get_app_data(b)); - if (tsslSock && - tsslSock->minEorRawByteNo_ && + auto appData = BIO_get_app_data(b); + CHECK(appData); + + tsslSock = reinterpret_cast(appData); + CHECK(tsslSock); + + if (tsslSock->trackEor_ && tsslSock->minEorRawByteNo_ && tsslSock->minEorRawByteNo_ <= BIO_number_written(b) + inl) { flags = MSG_EOR; } - ret = sendmsg(b->num, &msg, flags); + ret = sendmsg(BIO_get_fd(b, nullptr), &msg, flags); BIO_clear_retry_flags(b); if (ret <= 0) { if (BIO_sock_should_retry(ret)) BIO_set_retry_write(b); } - return(ret); + return ret; } int AsyncSSLSocket::sslVerifyCallback(int preverifyOk, diff --git a/folly/io/async/AsyncSSLSocket.h b/folly/io/async/AsyncSSLSocket.h index b4e4ca47..40ceb87a 100644 --- a/folly/io/async/AsyncSSLSocket.h +++ b/folly/io/async/AsyncSSLSocket.h @@ -652,7 +652,7 @@ class AsyncSSLSocket : public virtual AsyncSocket { static int getSSLExDataIndex(); static AsyncSSLSocket* getFromSSL(const SSL *ssl); - static int eorAwareBioWrite(BIO *b, const char *in, int inl); + static int bioWrite(BIO* b, const char* in, int inl); void resetClientHelloParsing(SSL *ssl); static void clientHelloParsingCallback(int write_p, int version, int content_type, const void *buf, size_t len, SSL *ssl, void *arg); @@ -774,6 +774,13 @@ class AsyncSSLSocket : public virtual AsyncSocket { */ void applyVerificationOptions(SSL * ssl); + /** + * Sets up SSL with a custom write bio which intercepts all writes. + * + * @return true, if succeeds and false if there is an error creating the bio. + */ + bool setupSSLBio(); + /** * A SSL_write wrapper that understand EOR * @@ -815,6 +822,9 @@ class AsyncSSLSocket : public virtual AsyncSocket { // whether the SSL session was resumed using session ID or not bool sessionIDResumed_{false}; + // Whether to track EOR or not. + bool trackEor_{false}; + // The app byte num that we are tracking for the MSG_EOR // Only one app EOR byte can be tracked. size_t appEorByteNo_{0}; -- 2.34.1