From: Dave Watson Date: Tue, 14 Oct 2014 18:10:48 +0000 (-0700) Subject: Move SSL socket to folly X-Git-Tag: v0.22.0~188 X-Git-Url: http://demsky.eecs.uci.edu/git/?a=commitdiff_plain;h=ff9b70f3cd1f05fb8e8c4351248cd9f748c2644a;p=folly.git Move SSL socket to folly Summary: One of the last thrift -> folly moves. The only change was the exception types - there are small wrapper classes in thrift/lib/cpp/async left to convert from AsyncSocketException to TTransportException. Test Plan: run unit tests Reviewed By: dcsommer@fb.com Subscribers: jdperlow, trunkagent, doug, bmatheny, ssl-diffs@, njormrod, mshneer, folly-diffs@, fugalh, jsedgwick, andrewcox, alandau FB internal diff: D1632425 Signature: t1:1632425:1414526483:339ae107bacb073bdd8cf0942fd0f6b70990feb4 --- diff --git a/folly/Makefile.am b/folly/Makefile.am index 99e237a7..a95e558c 100644 --- a/folly/Makefile.am +++ b/folly/Makefile.am @@ -132,7 +132,9 @@ nobase_follyinclude_HEADERS = \ io/async/AsyncTimeout.h \ io/async/AsyncTransport.h \ io/async/AsyncServerSocket.h \ + io/async/AsyncSSLServerSocket.h \ io/async/AsyncSocket.h \ + io/async/AsyncSSLSocket.h \ io/async/AsyncSocketException.h \ io/async/DelayedDestruction.h \ io/async/EventBase.h \ @@ -265,7 +267,9 @@ libfolly_la_SOURCES = \ io/ShutdownSocketSet.cpp \ io/async/AsyncTimeout.cpp \ io/async/AsyncServerSocket.cpp \ + io/async/AsyncSSLServerSocket.cpp \ io/async/AsyncSocket.cpp \ + io/async/AsyncSSLSocket.cpp \ io/async/EventBase.cpp \ io/async/EventBaseManager.cpp \ io/async/EventHandler.cpp \ diff --git a/folly/io/async/AsyncSSLServerSocket.cpp b/folly/io/async/AsyncSSLServerSocket.cpp new file mode 100644 index 00000000..31ae9c8d --- /dev/null +++ b/folly/io/async/AsyncSSLServerSocket.cpp @@ -0,0 +1,101 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include + +using std::shared_ptr; + +namespace folly { + +AsyncSSLServerSocket::AsyncSSLServerSocket( + const shared_ptr& ctx, + EventBase* eventBase) + : eventBase_(eventBase) + , serverSocket_(new AsyncServerSocket(eventBase)) + , ctx_(ctx) + , sslCallback_(nullptr) { +} + +AsyncSSLServerSocket::~AsyncSSLServerSocket() { +} + +void AsyncSSLServerSocket::destroy() { + // Stop accepting on the underlying socket as soon as destroy is called + if (sslCallback_ != nullptr) { + serverSocket_->pauseAccepting(); + serverSocket_->removeAcceptCallback(this, nullptr); + } + serverSocket_->destroy(); + serverSocket_ = nullptr; + sslCallback_ = nullptr; + + DelayedDestruction::destroy(); +} + +void AsyncSSLServerSocket::setSSLAcceptCallback(SSLAcceptCallback* callback) { + SSLAcceptCallback *oldCallback = sslCallback_; + sslCallback_ = callback; + if (callback != nullptr && oldCallback == nullptr) { + serverSocket_->addAcceptCallback(this, nullptr); + serverSocket_->startAccepting(); + } else if (callback == nullptr && oldCallback != nullptr) { + serverSocket_->removeAcceptCallback(this, nullptr); + serverSocket_->pauseAccepting(); + } +} + +void AsyncSSLServerSocket::attachEventBase(EventBase* eventBase) { + assert(sslCallback_ == nullptr); + eventBase_ = eventBase; + serverSocket_->attachEventBase(eventBase); +} + +void AsyncSSLServerSocket::detachEventBase() { + serverSocket_->detachEventBase(); + eventBase_ = nullptr; +} + +void +AsyncSSLServerSocket::connectionAccepted( + int fd, + const folly::SocketAddress& clientAddr) noexcept { + shared_ptr sslSock; + try { + // Create a AsyncSSLSocket object with the fd. The socket should be + // added to the event base and in the state of accepting SSL connection. + sslSock = AsyncSSLSocket::newSocket(ctx_, eventBase_, fd); + } catch (const std::exception &e) { + LOG(ERROR) << "Exception %s caught while creating a AsyncSSLSocket " + "object with socket " << e.what() << fd; + ::close(fd); + sslCallback_->acceptError(e); + return; + } + + // TODO: Perform the SSL handshake before invoking the callback + sslCallback_->connectionAccepted(sslSock); +} + +void AsyncSSLServerSocket::acceptError(const std::exception& ex) + noexcept { + LOG(ERROR) << "AsyncSSLServerSocket accept error: " << ex.what(); + sslCallback_->acceptError(ex); +} + +} // namespace diff --git a/folly/io/async/AsyncSSLServerSocket.h b/folly/io/async/AsyncSSLServerSocket.h new file mode 100644 index 00000000..56d59493 --- /dev/null +++ b/folly/io/async/AsyncSSLServerSocket.h @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#pragma once + +#include +#include + + +namespace folly { +class SocketAddress; +class AsyncSSLSocket; + +class AsyncSSLServerSocket : public DelayedDestruction, + private AsyncServerSocket::AcceptCallback { + public: + class SSLAcceptCallback { + public: + virtual ~SSLAcceptCallback() {} + + /** + * connectionAccepted() is called whenever a new client connection is + * received. + * + * The SSLAcceptCallback will remain installed after connectionAccepted() + * returns. + * + * @param sock The newly accepted client socket. The + * SSLAcceptCallback + * assumes ownership of this socket, and is responsible + * for closing it when done. + */ + virtual void connectionAccepted( + const std::shared_ptr &sock) + noexcept = 0; + + /** + * acceptError() is called if an error occurs while accepting. + * + * The SSLAcceptCallback will remain installed even after an accept error. + * If the callback wants to uninstall itself and stop trying to accept new + * connections, it must explicit call setAcceptCallback(nullptr). + * + * @param ex An exception representing the error. + */ + virtual void acceptError(const std::exception& ex) noexcept = 0; + }; + + /** + * Create a new TAsyncSSLServerSocket with the specified EventBase. + * + * @param eventBase The EventBase to use for driving the asynchronous I/O. + * If this parameter is nullptr, attachEventBase() must be + * called before this socket can begin accepting + * connections. All TAsyncSSLSocket objects accepted by + * this server socket will be attached to this EventBase + * when they are created. + */ + explicit AsyncSSLServerSocket( + const std::shared_ptr& ctx, + EventBase* eventBase = nullptr); + + /** + * Destroy the socket. + * + * destroy() must be called to destroy the socket. The normal destructor is + * private, and should not be invoked directly. This prevents callers from + * deleting a TAsyncSSLServerSocket while it is invoking a callback. + */ + virtual void destroy(); + + virtual void bind(const folly::SocketAddress& address) { + serverSocket_->bind(address); + } + virtual void bind(uint16_t port) { + serverSocket_->bind(port); + } + void getAddress(folly::SocketAddress* addressReturn) { + serverSocket_->getAddress(addressReturn); + } + virtual void listen(int backlog) { + serverSocket_->listen(backlog); + } + + /** + * Helper function to create a shared_ptr. + * + * This passes in the correct destructor object, since TAsyncSSLServerSocket's + * destructor is protected and cannot be invoked directly. + */ + static std::shared_ptr newSocket( + const std::shared_ptr& ctx, + EventBase* evb) { + return std::shared_ptr( + new AsyncSSLServerSocket(ctx, evb), + Destructor()); + } + + /** + * Set the accept callback. + * + * This method may only be invoked from the EventBase's loop thread. + * + * @param callback The callback to invoke when a new socket + * connection is accepted and a new TAsyncSSLSocket is + * created. + * + * Throws TTransportException on error. + */ + void setSSLAcceptCallback(SSLAcceptCallback* callback); + + SSLAcceptCallback *getSSLAcceptCallback() const { + return sslCallback_; + } + + void attachEventBase(EventBase* eventBase); + void detachEventBase(); + + /** + * Returns the EventBase that the handler is currently attached to. + */ + EventBase* getEventBase() const { + return eventBase_; + } + + protected: + /** + * Protected destructor. + * + * Invoke destroy() instead to destroy the TAsyncSSLServerSocket. + */ + virtual ~AsyncSSLServerSocket(); + + protected: + virtual void connectionAccepted(int fd, + const folly::SocketAddress& clientAddr) + noexcept; + virtual void acceptError(const std::exception& ex) noexcept; + + EventBase* eventBase_; + AsyncServerSocket* serverSocket_; + // SSL context + std::shared_ptr ctx_; + // The accept callback + SSLAcceptCallback* sslCallback_; +}; + +} // namespace diff --git a/folly/io/async/AsyncSSLSocket.cpp b/folly/io/async/AsyncSSLSocket.cpp new file mode 100644 index 00000000..69efb113 --- /dev/null +++ b/folly/io/async/AsyncSSLSocket.cpp @@ -0,0 +1,1478 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +using folly::SocketAddress; +using folly::SSLContext; +using std::string; +using std::shared_ptr; + +using folly::Endian; +using folly::IOBuf; +using folly::io::Cursor; +using folly::io::PortableSpinLock; +using folly::io::PortableSpinLockGuard; +using std::unique_ptr; +using std::bind; + +namespace { +using folly::AsyncSocket; +using folly::AsyncSocketException; +using folly::AsyncSSLSocket; +using folly::Optional; + +/** Try to avoid calling SSL_write() for buffers smaller than this: */ +size_t MIN_WRITE_SIZE = 1500; + +// We have one single dummy SSL context so that we can implement attach +// and detach methods in a thread safe fashion without modifying opnessl. +static SSLContext *dummyCtx = nullptr; +static PortableSpinLock dummyCtxLock; + +// Numbers chosen as to not collide with functions in ssl.h +const uint8_t TASYNCSSLSOCKET_F_PERFORM_READ = 90; +const uint8_t TASYNCSSLSOCKET_F_PERFORM_WRITE = 91; + +// This converts "illegal" shutdowns into ZERO_RETURN +inline bool zero_return(int error, int rc) { + return (error == SSL_ERROR_ZERO_RETURN || (rc == 0 && errno == 0)); +} + +class AsyncSSLSocketConnector: public AsyncSocket::ConnectCallback, + public AsyncSSLSocket::HandshakeCB { + + private: + AsyncSSLSocket *sslSocket_; + AsyncSSLSocket::ConnectCallback *callback_; + int timeout_; + int64_t startTime_; + + protected: + virtual ~AsyncSSLSocketConnector() { + } + + public: + AsyncSSLSocketConnector(AsyncSSLSocket *sslSocket, + AsyncSocket::ConnectCallback *callback, + int timeout) : + sslSocket_(sslSocket), + callback_(callback), + timeout_(timeout), + startTime_(std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()).count()) { + } + + virtual void connectSuccess() noexcept { + VLOG(7) << "client socket connected"; + + int64_t timeoutLeft = 0; + if (timeout_ > 0) { + auto curTime = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()).count(); + + timeoutLeft = timeout_ - (curTime - startTime_); + if (timeoutLeft <= 0) { + AsyncSocketException ex(AsyncSocketException::TIMED_OUT, + "SSL connect timed out"); + fail(ex); + delete this; + return; + } + } + sslSocket_->sslConn(this, timeoutLeft); + } + + virtual void connectErr(const AsyncSocketException& ex) noexcept { + LOG(ERROR) << "TCP connect failed: " << ex.what(); + fail(ex); + delete this; + } + + virtual void handshakeSuc(AsyncSSLSocket *sock) noexcept { + VLOG(7) << "client handshake success"; + if (callback_) { + callback_->connectSuccess(); + } + delete this; + } + + virtual void handshakeErr(AsyncSSLSocket *socket, + const AsyncSocketException& ex) noexcept { + LOG(ERROR) << "client handshakeErr: " << ex.what(); + fail(ex); + delete this; + } + + void fail(const AsyncSocketException &ex) { + // fail is a noop if called twice + if (callback_) { + AsyncSSLSocket::ConnectCallback *cb = callback_; + callback_ = nullptr; + + cb->connectErr(ex); + sslSocket_->closeNow(); + // closeNow can call handshakeErr if it hasn't been called already. + // So this may have been deleted, no member variable access beyond this + // point + // Note that closeNow may invoke writeError callbacks if the socket had + // write data pending connection completion. + } + } +}; + +// XXX: implement an equivalent to corking for platforms with TCP_NOPUSH? +#ifdef TCP_CORK // Linux-only +/** + * Utility class that corks a TCP socket upon construction or uncorks + * the socket upon destruction + */ +class CorkGuard : private boost::noncopyable { + public: + CorkGuard(int fd, bool multipleWrites, bool haveMore, bool* corked): + fd_(fd), haveMore_(haveMore), corked_(corked) { + if (*corked_) { + // socket is already corked; nothing to do + return; + } + if (multipleWrites || haveMore) { + // We are performing multiple writes in this performWrite() call, + // and/or there are more calls to performWrite() that will be invoked + // later, so enable corking + int flag = 1; + setsockopt(fd_, IPPROTO_TCP, TCP_CORK, &flag, sizeof(flag)); + *corked_ = true; + } + } + + ~CorkGuard() { + if (haveMore_) { + // more data to come; don't uncork yet + return; + } + if (!*corked_) { + // socket isn't corked; nothing to do + return; + } + + int flag = 0; + setsockopt(fd_, IPPROTO_TCP, TCP_CORK, &flag, sizeof(flag)); + *corked_ = false; + } + + private: + int fd_; + bool haveMore_; + bool* corked_; +}; +#else +class CorkGuard : private boost::noncopyable { + public: + CorkGuard(int, bool, bool, bool*) {} +}; +#endif + +void setup_SSL_CTX(SSL_CTX *ctx) { +#ifdef SSL_MODE_RELEASE_BUFFERS + SSL_CTX_set_mode(ctx, + SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | + SSL_MODE_ENABLE_PARTIAL_WRITE + | SSL_MODE_RELEASE_BUFFERS + ); +#else + SSL_CTX_set_mode(ctx, + SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | + SSL_MODE_ENABLE_PARTIAL_WRITE + ); +#endif +} + +BIO_METHOD eorAwareBioMethod; + +__attribute__((__constructor__)) +void initEorBioMethod(void) { + memcpy(&eorAwareBioMethod, BIO_s_socket(), sizeof(eorAwareBioMethod)); + // override the bwrite method for MSG_EOR support + eorAwareBioMethod.bwrite = AsyncSSLSocket::eorAwareBioWrite; + + // Note that the eorAwareBioMethod.type and eorAwareBioMethod.name are not + // set here. openssl code seems to be checking ".type == BIO_TYPE_SOCKET" and + // then have specific handlings. The eorAwareBioWrite should be compatible + // with the one in openssl. +} + +} // anonymous namespace + +namespace folly { + +SSLException::SSLException(int sslError, int errno_copy): + AsyncSocketException( + AsyncSocketException::SSL_ERROR, + ERR_error_string(sslError, msg_), + sslError == SSL_ERROR_SYSCALL ? errno_copy : 0), error_(sslError) {} + +/** + * Create a client AsyncSSLSocket + */ +AsyncSSLSocket::AsyncSSLSocket(const shared_ptr &ctx, + EventBase* evb) : + AsyncSocket(evb), + ctx_(ctx), + handshakeTimeout_(this, evb) { + setup_SSL_CTX(ctx_->getSSLCtx()); +} + +/** + * Create a server/client AsyncSSLSocket + */ +AsyncSSLSocket::AsyncSSLSocket(const shared_ptr& ctx, + EventBase* evb, int fd, bool server) : + AsyncSocket(evb, fd), + server_(server), + ctx_(ctx), + handshakeTimeout_(this, evb) { + setup_SSL_CTX(ctx_->getSSLCtx()); + if (server) { + SSL_CTX_set_info_callback(ctx_->getSSLCtx(), + AsyncSSLSocket::sslInfoCallback); + } +} + +#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT) +/** + * Create a client AsyncSSLSocket and allow tlsext_hostname + * to be sent in Client Hello. + */ +AsyncSSLSocket::AsyncSSLSocket(const shared_ptr &ctx, + EventBase* evb, + const std::string& serverName) : + AsyncSocket(evb), + ctx_(ctx), + handshakeTimeout_(this, evb), + tlsextHostname_(serverName) { + setup_SSL_CTX(ctx_->getSSLCtx()); +} + +/** + * Create a client AsyncSSLSocket from an already connected fd + * and allow tlsext_hostname to be sent in Client Hello. + */ +AsyncSSLSocket::AsyncSSLSocket(const shared_ptr& ctx, + EventBase* evb, int fd, + const std::string& serverName) : + AsyncSocket(evb, fd), + ctx_(ctx), + handshakeTimeout_(this, evb), + tlsextHostname_(serverName) { + setup_SSL_CTX(ctx_->getSSLCtx()); +} +#endif + +AsyncSSLSocket::~AsyncSSLSocket() { + VLOG(3) << "actual destruction of AsyncSSLSocket(this=" << this + << ", evb=" << eventBase_ << ", fd=" << fd_ + << ", state=" << int(state_) << ", sslState=" + << sslState_ << ", events=" << eventFlags_ << ")"; +} + +void AsyncSSLSocket::closeNow() { + // Close the SSL connection. + if (ssl_ != nullptr && fd_ != -1) { + int rc = SSL_shutdown(ssl_); + if (rc == 0) { + rc = SSL_shutdown(ssl_); + } + if (rc < 0) { + ERR_clear_error(); + } + } + + if (sslSession_ != nullptr) { + SSL_SESSION_free(sslSession_); + sslSession_ = nullptr; + } + + sslState_ = STATE_CLOSED; + + if (handshakeTimeout_.isScheduled()) { + handshakeTimeout_.cancelTimeout(); + } + + DestructorGuard dg(this); + + if (handshakeCallback_) { + AsyncSocketException ex(AsyncSocketException::END_OF_FILE, + "SSL connection closed locally"); + HandshakeCB* callback = handshakeCallback_; + handshakeCallback_ = nullptr; + callback->handshakeErr(this, ex); + } + + if (ssl_ != nullptr) { + SSL_free(ssl_); + ssl_ = nullptr; + } + + // Close the socket. + AsyncSocket::closeNow(); +} + +void AsyncSSLSocket::shutdownWrite() { + // SSL sockets do not support half-shutdown, so just perform a full shutdown. + // + // (Performing a full shutdown here is more desirable than doing nothing at + // all. The purpose of shutdownWrite() is normally to notify the other end + // of the connection that no more data will be sent. If we do nothing, the + // other end will never know that no more data is coming, and this may result + // in protocol deadlock.) + close(); +} + +void AsyncSSLSocket::shutdownWriteNow() { + closeNow(); +} + +bool AsyncSSLSocket::good() const { + return (AsyncSocket::good() && + (sslState_ == STATE_ACCEPTING || sslState_ == STATE_CONNECTING || + sslState_ == STATE_ESTABLISHED)); +} + +// The TAsyncTransport definition of 'good' states that the transport is +// ready to perform reads and writes, so sslState_ == UNINIT must report !good. +// connecting can be true when the sslState_ == UNINIT because the AsyncSocket +// is connected but we haven't initiated the call to SSL_connect. +bool AsyncSSLSocket::connecting() const { + return (!server_ && + (AsyncSocket::connecting() || + (AsyncSocket::good() && (sslState_ == STATE_UNINIT || + sslState_ == STATE_CONNECTING)))); +} + +bool AsyncSSLSocket::isEorTrackingEnabled() const { + const BIO *wb = SSL_get_wbio(ssl_); + return wb && wb->method == &eorAwareBioMethod; +} + +void AsyncSSLSocket::setEorTracking(bool track) { + BIO *wb = SSL_get_wbio(ssl_); + if (!wb) { + throw AsyncSocketException(AsyncSocketException::INVALID_STATE, + "setting EOR tracking without an initialized " + "BIO"); + } + + if (track) { + if (wb->method != &eorAwareBioMethod) { + // only do this if we didn't + wb->method = &eorAwareBioMethod; + BIO_set_app_data(wb, this); + appEorByteNo_ = 0; + minEorRawByteNo_ = 0; + } + } else if (wb->method == &eorAwareBioMethod) { + wb->method = BIO_s_socket(); + BIO_set_app_data(wb, nullptr); + appEorByteNo_ = 0; + minEorRawByteNo_ = 0; + } else { + CHECK(wb->method == BIO_s_socket()); + } +} + +size_t AsyncSSLSocket::getRawBytesWritten() const { + BIO *b; + if (!ssl_ || !(b = SSL_get_wbio(ssl_))) { + return 0; + } + + return BIO_number_written(b); +} + +size_t AsyncSSLSocket::getRawBytesReceived() const { + BIO *b; + if (!ssl_ || !(b = SSL_get_rbio(ssl_))) { + return 0; + } + + return BIO_number_read(b); +} + + +void AsyncSSLSocket::invalidState(HandshakeCB* callback) { + LOG(ERROR) << "AsyncSSLSocket(this=" << this << ", fd=" << fd_ + << ", state=" << int(state_) << ", sslState=" << sslState_ << ", " + << "events=" << eventFlags_ << ", server=" << short(server_) << "): " + << "sslAccept/Connect() called in invalid " + << "state, handshake callback " << handshakeCallback_ << ", new callback " + << callback; + assert(!handshakeTimeout_.isScheduled()); + sslState_ = STATE_ERROR; + + AsyncSocketException ex(AsyncSocketException::INVALID_STATE, + "sslAccept() called with socket in invalid state"); + + if (callback) { + callback->handshakeErr(this, ex); + } + + // Check the socket state not the ssl state here. + if (state_ != StateEnum::CLOSED || state_ != StateEnum::ERROR) { + failHandshake(__func__, ex); + } +} + +void AsyncSSLSocket::sslAccept(HandshakeCB* callback, uint32_t timeout, + const SSLContext::SSLVerifyPeerEnum& verifyPeer) { + DestructorGuard dg(this); + assert(eventBase_->isInEventBaseThread()); + verifyPeer_ = verifyPeer; + + // Make sure we're in the uninitialized state + if (!server_ || sslState_ != STATE_UNINIT || handshakeCallback_ != nullptr) { + return invalidState(callback); + } + + sslState_ = STATE_ACCEPTING; + handshakeCallback_ = callback; + + if (timeout > 0) { + handshakeTimeout_.scheduleTimeout(timeout); + } + + /* register for a read operation (waiting for CLIENT HELLO) */ + updateEventRegistration(EventHandler::READ, EventHandler::WRITE); +} + +#if OPENSSL_VERSION_NUMBER >= 0x009080bfL +void AsyncSSLSocket::attachSSLContext( + const std::shared_ptr& ctx) { + + // Check to ensure we are in client mode. Changing a server's ssl + // context doesn't make sense since clients of that server would likely + // become confused when the server's context changes. + DCHECK(!server_); + DCHECK(!ctx_); + DCHECK(ctx); + DCHECK(ctx->getSSLCtx()); + ctx_ = ctx; + + // In order to call attachSSLContext, detachSSLContext must have been + // previously called which sets the socket's context to the dummy + // context. Thus we must acquire this lock. + PortableSpinLockGuard guard(dummyCtxLock); + SSL_set_SSL_CTX(ssl_, ctx->getSSLCtx()); +} + +void AsyncSSLSocket::detachSSLContext() { + DCHECK(ctx_); + ctx_.reset(); + // We aren't using the initial_ctx for now, and it can introduce race + // conditions in the destructor of the SSL object. +#ifndef OPENSSL_NO_TLSEXT + if (ssl_->initial_ctx) { + SSL_CTX_free(ssl_->initial_ctx); + ssl_->initial_ctx = nullptr; + } +#endif + PortableSpinLockGuard guard(dummyCtxLock); + if (nullptr == dummyCtx) { + // We need to lazily initialize the dummy context so we don't + // accidentally override any programmatic settings to openssl + dummyCtx = new SSLContext; + } + // We must remove this socket's references to its context right now + // since this socket could get passed to any thread. If the context has + // had its locking disabled, just doing a set in attachSSLContext() + // would not be thread safe. + SSL_set_SSL_CTX(ssl_, dummyCtx->getSSLCtx()); +} +#endif + +#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT) +void AsyncSSLSocket::switchServerSSLContext( + const std::shared_ptr& handshakeCtx) { + CHECK(server_); + if (sslState_ != STATE_ACCEPTING) { + // We log it here and allow the switch. + // It should not affect our re-negotiation support (which + // is not supported now). + VLOG(6) << "fd=" << getFd() + << " renegotation detected when switching SSL_CTX"; + } + + setup_SSL_CTX(handshakeCtx->getSSLCtx()); + SSL_CTX_set_info_callback(handshakeCtx->getSSLCtx(), + AsyncSSLSocket::sslInfoCallback); + handshakeCtx_ = handshakeCtx; + SSL_set_SSL_CTX(ssl_, handshakeCtx->getSSLCtx()); +} + +bool AsyncSSLSocket::isServerNameMatch() const { + CHECK(!server_); + + if (!ssl_) { + return false; + } + + SSL_SESSION *ss = SSL_get_session(ssl_); + if (!ss) { + return false; + } + + return (ss->tlsext_hostname ? true : false); +} + +void AsyncSSLSocket::setServerName(std::string serverName) noexcept { + tlsextHostname_ = std::move(serverName); +} + +#endif + +void AsyncSSLSocket::timeoutExpired() noexcept { + if (state_ == StateEnum::ESTABLISHED && + (sslState_ == STATE_CACHE_LOOKUP || + sslState_ == STATE_RSA_ASYNC_PENDING)) { + sslState_ = STATE_ERROR; + // We are expecting a callback in restartSSLAccept. The cache lookup + // and rsa-call necessarily have pointers to this ssl socket, so delay + // the cleanup until he calls us back. + } else { + assert(state_ == StateEnum::ESTABLISHED && + (sslState_ == STATE_CONNECTING || sslState_ == STATE_ACCEPTING)); + DestructorGuard dg(this); + AsyncSocketException ex(AsyncSocketException::TIMED_OUT, + (sslState_ == STATE_CONNECTING) ? + "SSL connect timed out" : "SSL accept timed out"); + failHandshake(__func__, ex); + } +} + +int AsyncSSLSocket::sslExDataIndex_ = -1; +std::mutex AsyncSSLSocket::mutex_; + +int AsyncSSLSocket::getSSLExDataIndex() { + if (sslExDataIndex_ < 0) { + std::lock_guard g(mutex_); + if (sslExDataIndex_ < 0) { + sslExDataIndex_ = SSL_get_ex_new_index(0, + (void*)"AsyncSSLSocket data index", nullptr, nullptr, nullptr); + } + } + return sslExDataIndex_; +} + +AsyncSSLSocket* AsyncSSLSocket::getFromSSL(const SSL *ssl) { + return static_cast(SSL_get_ex_data(ssl, + getSSLExDataIndex())); +} + +void AsyncSSLSocket::failHandshake(const char* fn, + const AsyncSocketException& ex) { + startFail(); + + if (handshakeTimeout_.isScheduled()) { + handshakeTimeout_.cancelTimeout(); + } + if (handshakeCallback_ != nullptr) { + HandshakeCB* callback = handshakeCallback_; + handshakeCallback_ = nullptr; + callback->handshakeErr(this, ex); + } + + finishFail(); +} + +void AsyncSSLSocket::invokeHandshakeCB() { + if (handshakeTimeout_.isScheduled()) { + handshakeTimeout_.cancelTimeout(); + } + if (handshakeCallback_) { + HandshakeCB* callback = handshakeCallback_; + handshakeCallback_ = nullptr; + callback->handshakeSuc(this); + } +} + +void AsyncSSLSocket::connect(ConnectCallback* callback, + const folly::SocketAddress& address, + int timeout, + const OptionMap &options, + const folly::SocketAddress& bindAddr) + noexcept { + assert(!server_); + assert(state_ == StateEnum::UNINIT); + assert(sslState_ == STATE_UNINIT); + AsyncSSLSocketConnector *connector = + new AsyncSSLSocketConnector(this, callback, timeout); + AsyncSocket::connect(connector, address, timeout, options, bindAddr); +} + +void AsyncSSLSocket::applyVerificationOptions(SSL * ssl) { + // apply the settings specified in verifyPeer_ + if (verifyPeer_ == SSLContext::SSLVerifyPeerEnum::USE_CTX) { + if(ctx_->needsPeerVerification()) { + SSL_set_verify(ssl, ctx_->getVerificationMode(), + AsyncSSLSocket::sslVerifyCallback); + } + } else { + if (verifyPeer_ == SSLContext::SSLVerifyPeerEnum::VERIFY || + verifyPeer_ == SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT) { + SSL_set_verify(ssl, SSLContext::getVerificationMode(verifyPeer_), + AsyncSSLSocket::sslVerifyCallback); + } + } +} + +void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout, + const SSLContext::SSLVerifyPeerEnum& verifyPeer) { + DestructorGuard dg(this); + assert(eventBase_->isInEventBaseThread()); + + verifyPeer_ = verifyPeer; + + // Make sure we're in the uninitialized state + if (server_ || sslState_ != STATE_UNINIT || handshakeCallback_ != nullptr) { + return invalidState(callback); + } + + sslState_ = STATE_CONNECTING; + handshakeCallback_ = callback; + + try { + ssl_ = ctx_->createSSL(); + } catch (std::exception &e) { + sslState_ = STATE_ERROR; + AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR, + "error calling SSLContext::createSSL()"); + LOG(ERROR) << "AsyncSSLSocket::sslConn(this=" << this << ", fd=" + << fd_ << "): " << e.what(); + return failHandshake(__func__, ex); + } + + applyVerificationOptions(ssl_); + + SSL_set_fd(ssl_, fd_); + if (sslSession_ != nullptr) { + SSL_set_session(ssl_, sslSession_); + SSL_SESSION_free(sslSession_); + sslSession_ = nullptr; + } +#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT) + if (tlsextHostname_.size()) { + SSL_set_tlsext_host_name(ssl_, tlsextHostname_.c_str()); + } +#endif + + SSL_set_ex_data(ssl_, getSSLExDataIndex(), this); + + if (timeout > 0) { + handshakeTimeout_.scheduleTimeout(timeout); + } + + handleConnect(); +} + +SSL_SESSION *AsyncSSLSocket::getSSLSession() { + if (ssl_ != nullptr && sslState_ == STATE_ESTABLISHED) { + return SSL_get1_session(ssl_); + } + + return sslSession_; +} + +void AsyncSSLSocket::setSSLSession(SSL_SESSION *session, bool takeOwnership) { + sslSession_ = session; + if (!takeOwnership && session != nullptr) { + // Increment the reference count + CRYPTO_add(&session->references, 1, CRYPTO_LOCK_SSL_SESSION); + } +} + +void AsyncSSLSocket::getSelectedNextProtocol(const unsigned char** protoName, + unsigned* protoLen) const { + if (!getSelectedNextProtocolNoThrow(protoName, protoLen)) { + throw AsyncSocketException(AsyncSocketException::NOT_SUPPORTED, + "NPN not supported"); + } +} + +bool AsyncSSLSocket::getSelectedNextProtocolNoThrow( + const unsigned char** protoName, + unsigned* protoLen) const { + *protoName = nullptr; + *protoLen = 0; +#ifdef OPENSSL_NPN_NEGOTIATED + SSL_get0_next_proto_negotiated(ssl_, protoName, protoLen); + return true; +#else + return false; +#endif +} + +bool AsyncSSLSocket::getSSLSessionReused() const { + if (ssl_ != nullptr && sslState_ == STATE_ESTABLISHED) { + return SSL_session_reused(ssl_); + } + return false; +} + +const char *AsyncSSLSocket::getNegotiatedCipherName() const { + return (ssl_ != nullptr) ? SSL_get_cipher_name(ssl_) : nullptr; +} + +const char *AsyncSSLSocket::getSSLServerName() const { +#ifdef SSL_CTRL_SET_TLSEXT_SERVERNAME_CB + return (ssl_ != nullptr) ? SSL_get_servername(ssl_, TLSEXT_NAMETYPE_host_name) + : nullptr; +#else + throw AsyncSocketException(AsyncSocketException::NOT_SUPPORTED, + "SNI not supported"); +#endif +} + +const char *AsyncSSLSocket::getSSLServerNameNoThrow() const { + try { + return getSSLServerName(); + } catch (AsyncSocketException& ex) { + return nullptr; + } +} + +int AsyncSSLSocket::getSSLVersion() const { + return (ssl_ != nullptr) ? SSL_version(ssl_) : 0; +} + +int AsyncSSLSocket::getSSLCertSize() const { + int certSize = 0; + X509 *cert = (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr; + if (cert) { + EVP_PKEY *key = X509_get_pubkey(cert); + certSize = EVP_PKEY_bits(key); + EVP_PKEY_free(key); + } + return certSize; +} + +bool AsyncSSLSocket::willBlock(int ret, int *errorOut) noexcept { + int error = *errorOut = SSL_get_error(ssl_, ret); + if (error == SSL_ERROR_WANT_READ) { + // Register for read event if not already. + updateEventRegistration(EventHandler::READ, EventHandler::WRITE); + return true; + } else if (error == SSL_ERROR_WANT_WRITE) { + VLOG(3) << "AsyncSSLSocket(fd=" << fd_ + << ", state=" << int(state_) << ", sslState=" + << sslState_ << ", events=" << eventFlags_ << "): " + << "SSL_ERROR_WANT_WRITE"; + // Register for write event if not already. + updateEventRegistration(EventHandler::WRITE, EventHandler::READ); + return true; +#ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP + } else if (error == SSL_ERROR_WANT_SESS_CACHE_LOOKUP) { + // We will block but we can't register our own socket. The callback that + // triggered this code will re-call handleAccept at the appropriate time. + + // We can only get here if the linked libssl.so has support for this feature + // as well, otherwise SSL_get_error cannot return our error code. + sslState_ = STATE_CACHE_LOOKUP; + + // Unregister for all events while blocked here + updateEventRegistration(EventHandler::NONE, + EventHandler::READ | EventHandler::WRITE); + + // The timeout (if set) keeps running here + return true; +#endif +#ifdef SSL_ERROR_WANT_RSA_ASYNC_PENDING + } else if (error == SSL_ERROR_WANT_RSA_ASYNC_PENDING) { + // Our custom openssl function has kicked off an async request to do + // modular exponentiation. When that call returns, a callback will + // be invoked that will re-call handleAccept. + sslState_ = STATE_RSA_ASYNC_PENDING; + + // Unregister for all events while blocked here + updateEventRegistration( + EventHandler::NONE, + EventHandler::READ | EventHandler::WRITE + ); + + // The timeout (if set) keeps running here + return true; +#endif + } else { + // SSL_ERROR_ZERO_RETURN is processed here so we can get some detail + // in the log + long lastError = ERR_get_error(); + VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", " + << "state=" << state_ << ", " + << "sslState=" << sslState_ << ", " + << "events=" << std::hex << eventFlags_ << "): " + << "SSL error: " << error << ", " + << "errno: " << errno << ", " + << "ret: " << ret << ", " + << "read: " << BIO_number_read(SSL_get_rbio(ssl_)) << ", " + << "written: " << BIO_number_written(SSL_get_wbio(ssl_)) << ", " + << "func: " << ERR_func_error_string(lastError) << ", " + << "reason: " << ERR_reason_error_string(lastError); + if (error != SSL_ERROR_SYSCALL) { + if (error == SSL_ERROR_SSL) { + *errorOut = lastError; + } + if ((unsigned long)lastError < 0x8000) { + errno = ENOSYS; + } else { + errno = lastError; + } + } + ERR_clear_error(); + return false; + } +} + +void AsyncSSLSocket::checkForImmediateRead() noexcept { + // openssl may have buffered data that it read from the socket already. + // In this case we have to process it immediately, rather than waiting for + // the socket to become readable again. + if (ssl_ != nullptr && SSL_pending(ssl_) > 0) { + AsyncSocket::handleRead(); + } +} + +void +AsyncSSLSocket::restartSSLAccept() +{ + VLOG(3) << "AsyncSSLSocket::restartSSLAccept() this=" << this << ", fd=" << fd_ + << ", state=" << int(state_) << ", " + << "sslState=" << sslState_ << ", events=" << eventFlags_; + DestructorGuard dg(this); + assert( + sslState_ == STATE_CACHE_LOOKUP || + sslState_ == STATE_RSA_ASYNC_PENDING || + sslState_ == STATE_ERROR || + sslState_ == STATE_CLOSED + ); + if (sslState_ == STATE_CLOSED) { + // I sure hope whoever closed this socket didn't delete it already, + // but this is not strictly speaking an error + return; + } + if (sslState_ == STATE_ERROR) { + // go straight to fail if timeout expired during lookup + AsyncSocketException ex(AsyncSocketException::TIMED_OUT, + "SSL accept timed out"); + failHandshake(__func__, ex); + return; + } + sslState_ = STATE_ACCEPTING; + this->handleAccept(); +} + +void +AsyncSSLSocket::handleAccept() noexcept { + VLOG(3) << "AsyncSSLSocket::handleAccept() this=" << this + << ", fd=" << fd_ << ", state=" << int(state_) << ", " + << "sslState=" << sslState_ << ", events=" << eventFlags_; + assert(server_); + assert(state_ == StateEnum::ESTABLISHED && + sslState_ == STATE_ACCEPTING); + if (!ssl_) { + /* lazily create the SSL structure */ + try { + ssl_ = ctx_->createSSL(); + } catch (std::exception &e) { + sslState_ = STATE_ERROR; + AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR, + "error calling SSLContext::createSSL()"); + LOG(ERROR) << "AsyncSSLSocket::handleAccept(this=" << this + << ", fd=" << fd_ << "): " << e.what(); + return failHandshake(__func__, ex); + } + SSL_set_fd(ssl_, fd_); + SSL_set_ex_data(ssl_, getSSLExDataIndex(), this); + + applyVerificationOptions(ssl_); + } + + if (server_ && parseClientHello_) { + SSL_set_msg_callback_arg(ssl_, this); + SSL_set_msg_callback(ssl_, &AsyncSSLSocket::clientHelloParsingCallback); + } + + errno = 0; + int ret = SSL_accept(ssl_); + if (ret <= 0) { + int error; + if (willBlock(ret, &error)) { + return; + } else { + sslState_ = STATE_ERROR; + SSLException ex(error, errno); + return failHandshake(__func__, ex); + } + } + + handshakeComplete_ = true; + updateEventRegistration(0, EventHandler::READ | EventHandler::WRITE); + + // Move into STATE_ESTABLISHED in the normal case that we are in + // STATE_ACCEPTING. + sslState_ = STATE_ESTABLISHED; + + VLOG(3) << "AsyncSSLSocket " << this << ": fd " << fd_ + << " successfully accepted; state=" << int(state_) + << ", sslState=" << sslState_ << ", events=" << eventFlags_; + + // Remember the EventBase we are attached to, before we start invoking any + // callbacks (since the callbacks may call detachEventBase()). + EventBase* originalEventBase = eventBase_; + + // Call the accept callback. + invokeHandshakeCB(); + + // Note that the accept callback may have changed our state. + // (set or unset the read callback, called write(), closed the socket, etc.) + // The following code needs to handle these situations correctly. + // + // If the socket has been closed, readCallback_ and writeReqHead_ will + // always be nullptr, so that will prevent us from trying to read or write. + // + // The main thing to check for is if eventBase_ is still originalEventBase. + // If not, we have been detached from this event base, so we shouldn't + // perform any more operations. + if (eventBase_ != originalEventBase) { + return; + } + + AsyncSocket::handleInitialReadWrite(); +} + +void +AsyncSSLSocket::handleConnect() noexcept { + VLOG(3) << "AsyncSSLSocket::handleConnect() this=" << this + << ", fd=" << fd_ << ", state=" << int(state_) << ", " + << "sslState=" << sslState_ << ", events=" << eventFlags_; + assert(!server_); + if (state_ < StateEnum::ESTABLISHED) { + return AsyncSocket::handleConnect(); + } + + assert(state_ == StateEnum::ESTABLISHED && + sslState_ == STATE_CONNECTING); + assert(ssl_); + + errno = 0; + int ret = SSL_connect(ssl_); + if (ret <= 0) { + int error; + if (willBlock(ret, &error)) { + return; + } else { + sslState_ = STATE_ERROR; + SSLException ex(error, errno); + return failHandshake(__func__, ex); + } + } + + handshakeComplete_ = true; + updateEventRegistration(0, EventHandler::READ | EventHandler::WRITE); + + // Move into STATE_ESTABLISHED in the normal case that we are in + // STATE_CONNECTING. + sslState_ = STATE_ESTABLISHED; + + VLOG(3) << "AsyncSSLSocket %p: fd %d successfully connected; " + << "state=" << int(state_) << ", sslState=" << sslState_ + << ", events=" << eventFlags_; + + // Remember the EventBase we are attached to, before we start invoking any + // callbacks (since the callbacks may call detachEventBase()). + EventBase* originalEventBase = eventBase_; + + // Call the handshake callback. + invokeHandshakeCB(); + + // Note that the connect callback may have changed our state. + // (set or unset the read callback, called write(), closed the socket, etc.) + // The following code needs to handle these situations correctly. + // + // If the socket has been closed, readCallback_ and writeReqHead_ will + // always be nullptr, so that will prevent us from trying to read or write. + // + // The main thing to check for is if eventBase_ is still originalEventBase. + // If not, we have been detached from this event base, so we shouldn't + // perform any more operations. + if (eventBase_ != originalEventBase) { + return; + } + + AsyncSocket::handleInitialReadWrite(); +} + +void +AsyncSSLSocket::handleRead() noexcept { + VLOG(5) << "AsyncSSLSocket::handleRead() this=" << this << ", fd=" << fd_ + << ", state=" << int(state_) << ", " + << "sslState=" << sslState_ << ", events=" << eventFlags_; + if (state_ < StateEnum::ESTABLISHED) { + return AsyncSocket::handleRead(); + } + + + if (sslState_ == STATE_ACCEPTING) { + assert(server_); + handleAccept(); + return; + } + else if (sslState_ == STATE_CONNECTING) { + assert(!server_); + handleConnect(); + return; + } + + // Normal read + AsyncSocket::handleRead(); +} + +ssize_t +AsyncSSLSocket::performRead(void* buf, size_t buflen) { + errno = 0; + ssize_t bytes = SSL_read(ssl_, buf, buflen); + if (server_ && renegotiateAttempted_) { + LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_) + << ", sslstate=" << sslState_ << ", events=" << eventFlags_ << "): " + << "client intitiated SSL renegotiation not permitted"; + // We pack our own SSLerr here with a dummy function + errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_READ, + SSL_CLIENT_RENEGOTIATION_ATTEMPT); + ERR_clear_error(); + return READ_ERROR; + } + if (bytes <= 0) { + int error = SSL_get_error(ssl_, bytes); + if (error == SSL_ERROR_WANT_READ) { + // The caller will register for read event if not already. + return READ_BLOCKING; + } else if (error == SSL_ERROR_WANT_WRITE) { + // TODO: Even though we are attempting to read data, SSL_read() may + // need to write data if renegotiation is being performed. We currently + // don't support this and just fail the read. + LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_) + << ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): " + << "unsupported SSL renegotiation during read", + errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_READ, + SSL_INVALID_RENEGOTIATION); + ERR_clear_error(); + return READ_ERROR; + } else { + // TODO: Fix this code so that it can return a proper error message + // to the callback, rather than relying on AsyncSocket code which + // can't handle SSL errors. + long lastError = ERR_get_error(); + + VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", " + << "state=" << state_ << ", " + << "sslState=" << sslState_ << ", " + << "events=" << std::hex << eventFlags_ << "): " + << "bytes: " << bytes << ", " + << "error: " << error << ", " + << "errno: " << errno << ", " + << "func: " << ERR_func_error_string(lastError) << ", " + << "reason: " << ERR_reason_error_string(lastError); + ERR_clear_error(); + if (zero_return(error, bytes)) { + return bytes; + } + if (error != SSL_ERROR_SYSCALL) { + if ((unsigned long)lastError < 0x8000) { + errno = ENOSYS; + } else { + errno = lastError; + } + } + return READ_ERROR; + } + } else { + appBytesReceived_ += bytes; + return bytes; + } +} + +void AsyncSSLSocket::handleWrite() noexcept { + VLOG(5) << "AsyncSSLSocket::handleWrite() this=" << this << ", fd=" << fd_ + << ", state=" << int(state_) << ", " + << "sslState=" << sslState_ << ", events=" << eventFlags_; + if (state_ < StateEnum::ESTABLISHED) { + return AsyncSocket::handleWrite(); + } + + if (sslState_ == STATE_ACCEPTING) { + assert(server_); + handleAccept(); + return; + } + + if (sslState_ == STATE_CONNECTING) { + assert(!server_); + handleConnect(); + return; + } + + // Normal write + AsyncSocket::handleWrite(); +} + +ssize_t AsyncSSLSocket::performWrite(const iovec* vec, + uint32_t count, + WriteFlags flags, + uint32_t* countWritten, + uint32_t* partialWritten) { + if (sslState_ != STATE_ESTABLISHED) { + LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_) + << ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): " + << "TODO: AsyncSSLSocket currently does not support calling " + << "write() before the handshake has fully completed"; + errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_WRITE, + SSL_EARLY_WRITE); + return -1; + } + + bool cork = isSet(flags, WriteFlags::CORK); + CorkGuard guard(fd_, count > 1, cork, &corked_); + + *countWritten = 0; + *partialWritten = 0; + ssize_t totalWritten = 0; + size_t bytesStolenFromNextBuffer = 0; + for (uint32_t i = 0; i < count; i++) { + const iovec* v = vec + i; + size_t offset = bytesStolenFromNextBuffer; + bytesStolenFromNextBuffer = 0; + size_t len = v->iov_len - offset; + const void* buf; + if (len == 0) { + (*countWritten)++; + continue; + } + buf = ((const char*)v->iov_base) + offset; + + ssize_t bytes; + errno = 0; + uint32_t buffersStolen = 0; + if ((len < MIN_WRITE_SIZE) && ((i + 1) < count)) { + // Combine this buffer with part or all of the next buffers in + // order to avoid really small-grained calls to SSL_write(). + // Each call to SSL_write() produces a separate record in + // the egress SSL stream, and we've found that some low-end + // mobile clients can't handle receiving an HTTP response + // header and the first part of the response body in two + // separate SSL records (even if those two records are in + // the same TCP packet). + char combinedBuf[MIN_WRITE_SIZE]; + memcpy(combinedBuf, buf, len); + do { + // INVARIANT: i + buffersStolen == complete chunks serialized + uint32_t nextIndex = i + buffersStolen + 1; + bytesStolenFromNextBuffer = std::min(vec[nextIndex].iov_len, + MIN_WRITE_SIZE - len); + memcpy(combinedBuf + len, vec[nextIndex].iov_base, + bytesStolenFromNextBuffer); + len += bytesStolenFromNextBuffer; + if (bytesStolenFromNextBuffer < vec[nextIndex].iov_len) { + // couldn't steal the whole buffer + break; + } else { + bytesStolenFromNextBuffer = 0; + buffersStolen++; + } + } while ((i + buffersStolen + 1) < count && (len < MIN_WRITE_SIZE)); + bytes = eorAwareSSLWrite( + ssl_, combinedBuf, len, + (isSet(flags, WriteFlags::EOR) && i + buffersStolen + 1 == count)); + + } else { + bytes = eorAwareSSLWrite(ssl_, buf, len, + (isSet(flags, WriteFlags::EOR) && i + 1 == count)); + } + + if (bytes <= 0) { + int error = SSL_get_error(ssl_, bytes); + if (error == SSL_ERROR_WANT_WRITE) { + // The caller will register for write event if not already. + *partialWritten = offset; + return totalWritten; + } else if (error == SSL_ERROR_WANT_READ) { + // TODO: Even though we are attempting to write data, SSL_write() may + // need to read data if renegotiation is being performed. We currently + // don't support this and just fail the write. + LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_) + << ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): " + << "unsupported SSL renegotiation during write", + errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_WRITE, + SSL_INVALID_RENEGOTIATION); + ERR_clear_error(); + return -1; + } else { + // TODO: Fix this code so that it can return a proper error message + // to the callback, rather than relying on AsyncSocket code which + // can't handle SSL errors. + long lastError = ERR_get_error(); + VLOG(3) << + "ERROR: AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_) + << ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): " + << "SSL error: " << error << ", errno: " << errno + << ", func: " << ERR_func_error_string(lastError) + << ", reason: " << ERR_reason_error_string(lastError); + if (error != SSL_ERROR_SYSCALL) { + if ((unsigned long)lastError < 0x8000) { + errno = ENOSYS; + } else { + errno = lastError; + } + } + ERR_clear_error(); + if (!zero_return(error, bytes)) { + return -1; + } // else fall through to below to correctly record totalWritten + } + } + + totalWritten += bytes; + + if (bytes == (ssize_t)len) { + // The full iovec is written. + (*countWritten) += 1 + buffersStolen; + i += buffersStolen; + // continue + } else { + bytes += offset; // adjust bytes to account for all of v + while (bytes >= (ssize_t)v->iov_len) { + // We combined this buf with part or all of the next one, and + // we managed to write all of this buf but not all of the bytes + // from the next one that we'd hoped to write. + bytes -= v->iov_len; + (*countWritten)++; + v = &(vec[++i]); + } + *partialWritten = bytes; + return totalWritten; + } + } + + return totalWritten; +} + +int AsyncSSLSocket::eorAwareSSLWrite(SSL *ssl, const void *buf, int n, + bool eor) { + if (eor && SSL_get_wbio(ssl)->method == &eorAwareBioMethod) { + if (appEorByteNo_) { + // cannot track for more than one app byte EOR + CHECK(appEorByteNo_ == appBytesWritten_ + n); + } else { + appEorByteNo_ = appBytesWritten_ + n; + } + + // 1. It is fine to keep updating minEorRawByteNo_. + // 2. It is _min_ in the sense that SSL record will add some overhead. + minEorRawByteNo_ = getRawBytesWritten() + n; + } + + n = sslWriteImpl(ssl, buf, n); + if (n > 0) { + appBytesWritten_ += n; + if (appEorByteNo_) { + if (getRawBytesWritten() >= minEorRawByteNo_) { + minEorRawByteNo_ = 0; + } + if(appBytesWritten_ == appEorByteNo_) { + appEorByteNo_ = 0; + } else { + CHECK(appBytesWritten_ < appEorByteNo_); + } + } + } + return n; +} + +void +AsyncSSLSocket::sslInfoCallback(const SSL *ssl, int where, int ret) { + AsyncSSLSocket *sslSocket = AsyncSSLSocket::getFromSSL(ssl); + if (sslSocket->handshakeComplete_ && (where & SSL_CB_HANDSHAKE_START)) { + sslSocket->renegotiateAttempted_ = true; + } +} + +int AsyncSSLSocket::eorAwareBioWrite(BIO *b, const char *in, int inl) { + int ret; + struct msghdr msg; + struct iovec iov; + int flags = 0; + AsyncSSLSocket *tsslSock; + + iov.iov_base = const_cast(in); + iov.iov_len = inl; + memset(&msg, 0, sizeof(msg)); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + tsslSock = + reinterpret_cast(BIO_get_app_data(b)); + if (tsslSock && + tsslSock->minEorRawByteNo_ && + tsslSock->minEorRawByteNo_ <= BIO_number_written(b) + inl) { + flags = MSG_EOR; + } + + errno = 0; + ret = sendmsg(b->num, &msg, flags); + BIO_clear_retry_flags(b); + if (ret <= 0) { + if (BIO_sock_should_retry(ret)) + BIO_set_retry_write(b); + } + return(ret); +} + +int AsyncSSLSocket::sslVerifyCallback(int preverifyOk, + X509_STORE_CTX* x509Ctx) { + SSL* ssl = (SSL*) X509_STORE_CTX_get_ex_data( + x509Ctx, SSL_get_ex_data_X509_STORE_CTX_idx()); + AsyncSSLSocket* self = AsyncSSLSocket::getFromSSL(ssl); + + VLOG(3) << "AsyncSSLSocket::sslVerifyCallback() this=" << self << ", " + << "fd=" << self->fd_ << ", preverifyOk=" << preverifyOk; + return (self->handshakeCallback_) ? + self->handshakeCallback_->handshakeVer(self, preverifyOk, x509Ctx) : + preverifyOk; +} + +void AsyncSSLSocket::enableClientHelloParsing() { + parseClientHello_ = true; + clientHelloInfo_.reset(new ClientHelloInfo()); +} + +void AsyncSSLSocket::resetClientHelloParsing(SSL *ssl) { + SSL_set_msg_callback(ssl, nullptr); + SSL_set_msg_callback_arg(ssl, nullptr); + clientHelloInfo_->clientHelloBuf_.clear(); +} + +void +AsyncSSLSocket::clientHelloParsingCallback(int written, int version, + int contentType, const void *buf, size_t len, SSL *ssl, void *arg) +{ + AsyncSSLSocket *sock = static_cast(arg); + if (written != 0) { + sock->resetClientHelloParsing(ssl); + return; + } + if (contentType != SSL3_RT_HANDSHAKE) { + sock->resetClientHelloParsing(ssl); + return; + } + if (len == 0) { + return; + } + + auto& clientHelloBuf = sock->clientHelloInfo_->clientHelloBuf_; + clientHelloBuf.append(IOBuf::wrapBuffer(buf, len)); + try { + Cursor cursor(clientHelloBuf.front()); + if (cursor.read() != SSL3_MT_CLIENT_HELLO) { + sock->resetClientHelloParsing(ssl); + return; + } + + if (cursor.totalLength() < 3) { + clientHelloBuf.trimEnd(len); + clientHelloBuf.append(IOBuf::copyBuffer(buf, len)); + return; + } + + uint32_t messageLength = cursor.read(); + messageLength <<= 8; + messageLength |= cursor.read(); + messageLength <<= 8; + messageLength |= cursor.read(); + if (cursor.totalLength() < messageLength) { + clientHelloBuf.trimEnd(len); + clientHelloBuf.append(IOBuf::copyBuffer(buf, len)); + return; + } + + sock->clientHelloInfo_->clientHelloMajorVersion_ = cursor.read(); + sock->clientHelloInfo_->clientHelloMinorVersion_ = cursor.read(); + + cursor.skip(4); // gmt_unix_time + cursor.skip(28); // random_bytes + + cursor.skip(cursor.read()); // session_id + + uint16_t cipherSuitesLength = cursor.readBE(); + for (int i = 0; i < cipherSuitesLength; i += 2) { + sock->clientHelloInfo_-> + clientHelloCipherSuites_.push_back(cursor.readBE()); + } + + uint8_t compressionMethodsLength = cursor.read(); + for (int i = 0; i < compressionMethodsLength; ++i) { + sock->clientHelloInfo_-> + clientHelloCompressionMethods_.push_back(cursor.readBE()); + } + + if (cursor.totalLength() > 0) { + uint16_t extensionsLength = cursor.readBE(); + while (extensionsLength) { + sock->clientHelloInfo_-> + clientHelloExtensions_.push_back(cursor.readBE()); + extensionsLength -= 2; + uint16_t extensionDataLength = cursor.readBE(); + extensionsLength -= 2; + cursor.skip(extensionDataLength); + extensionsLength -= extensionDataLength; + } + } + } catch (std::out_of_range& e) { + // we'll use what we found and cleanup below. + VLOG(4) << "AsyncSSLSocket::clientHelloParsingCallback(): " + << "buffer finished unexpectedly." << " AsyncSSLSocket socket=" << sock; + } + + sock->resetClientHelloParsing(ssl); +} + +} // namespace diff --git a/folly/io/async/AsyncSSLSocket.h b/folly/io/async/AsyncSSLSocket.h new file mode 100644 index 00000000..8f3c8bd0 --- /dev/null +++ b/folly/io/async/AsyncSSLSocket.h @@ -0,0 +1,749 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +using folly::io::Cursor; +using std::unique_ptr; + +namespace folly { + +class SSLException: public folly::AsyncSocketException { + public: + SSLException(int sslError, int errno_copy); + + int getSSLError() const { return error_; } + + protected: + int error_; + char msg_[256]; +}; + +/** + * A class for performing asynchronous I/O on an SSL connection. + * + * AsyncSSLSocket allows users to asynchronously wait for data on an + * SSL connection, and to asynchronously send data. + * + * The APIs for reading and writing are intentionally asymmetric. + * Waiting for data to read is a persistent API: a callback is + * installed, and is notified whenever new data is available. It + * continues to be notified of new events until it is uninstalled. + * + * AsyncSSLSocket does not provide read timeout functionality, + * because it typically cannot determine when the timeout should be + * active. Generally, a timeout should only be enabled when + * processing is blocked waiting on data from the remote endpoint. + * For server connections, the timeout should not be active if the + * server is currently processing one or more outstanding requests for + * this connection. For client connections, the timeout should not be + * active if there are no requests pending on the connection. + * Additionally, if a client has multiple pending requests, it will + * ususally want a separate timeout for each request, rather than a + * single read timeout. + * + * The write API is fairly intuitive: a user can request to send a + * block of data, and a callback will be informed once the entire + * block has been transferred to the kernel, or on error. + * AsyncSSLSocket does provide a send timeout, since most callers + * want to give up if the remote end stops responding and no further + * progress can be made sending the data. + */ +class AsyncSSLSocket : public virtual AsyncSocket { + public: + typedef std::unique_ptr UniquePtr; + + class HandshakeCB { + public: + virtual ~HandshakeCB() {} + + /** + * handshakeVer() is invoked during handshaking to give the + * application chance to validate it's peer's certificate. + * + * Note that OpenSSL performs only rudimentary internal + * consistency verification checks by itself. Any other validation + * like whether or not the certificate was issued by a trusted CA. + * The default implementation of this callback mimics what what + * OpenSSL does internally if SSL_VERIFY_PEER is set with no + * verification callback. + * + * See the passages on verify_callback in SSL_CTX_set_verify(3) + * for more details. + */ + virtual bool handshakeVer(AsyncSSLSocket* sock, + bool preverifyOk, + X509_STORE_CTX* ctx) noexcept { + return preverifyOk; + } + + /** + * handshakeSuc() is called when a new SSL connection is + * established, i.e., after SSL_accept/connect() returns successfully. + * + * The HandshakeCB will be uninstalled before handshakeSuc() + * is called. + * + * @param sock SSL socket on which the handshake was initiated + */ + virtual void handshakeSuc(AsyncSSLSocket *sock) noexcept = 0; + + /** + * handshakeErr() is called if an error occurs while + * establishing the SSL connection. + * + * The HandshakeCB will be uninstalled before handshakeErr() + * is called. + * + * @param sock SSL socket on which the handshake was initiated + * @param ex An exception representing the error. + */ + virtual void handshakeErr( + AsyncSSLSocket *sock, + const AsyncSocketException& ex) + noexcept = 0; + }; + + class HandshakeTimeout : public AsyncTimeout { + public: + HandshakeTimeout(AsyncSSLSocket* sslSocket, EventBase* eventBase) + : AsyncTimeout(eventBase) + , sslSocket_(sslSocket) {} + + virtual void timeoutExpired() noexcept { + sslSocket_->timeoutExpired(); + } + + private: + AsyncSSLSocket* sslSocket_; + }; + + + /** + * These are passed to the application via errno, packed in an SSL err which + * are outside the valid errno range. The values are chosen to be unique + * against values in ssl.h + */ + enum SSLError { + SSL_CLIENT_RENEGOTIATION_ATTEMPT = 900, + SSL_INVALID_RENEGOTIATION = 901, + SSL_EARLY_WRITE = 902 + }; + + /** + * Create a client AsyncSSLSocket + */ + AsyncSSLSocket(const std::shared_ptr &ctx, + EventBase* evb); + + /** + * Create a server/client AsyncSSLSocket from an already connected + * socket file descriptor. + * + * Note that while AsyncSSLSocket enables TCP_NODELAY for sockets it creates + * when connecting, it does not change the socket options when given an + * existing file descriptor. If callers want TCP_NODELAY enabled when using + * this version of the constructor, they need to explicitly call + * setNoDelay(true) after the constructor returns. + * + * @param ctx SSL context for this connection. + * @param evb EventBase that will manage this socket. + * @param fd File descriptor to take over (should be a connected socket). + * @param server Is socket in server mode? + */ + AsyncSSLSocket(const std::shared_ptr& ctx, + EventBase* evb, int fd, bool server = true); + + + /** + * Helper function to create a server/client shared_ptr. + */ + static std::shared_ptr newSocket( + const std::shared_ptr& ctx, + EventBase* evb, int fd, bool server=true) { + return std::shared_ptr( + new AsyncSSLSocket(ctx, evb, fd, server), + Destructor()); + } + + /** + * Helper function to create a client shared_ptr. + */ + static std::shared_ptr newSocket( + const std::shared_ptr& ctx, + EventBase* evb) { + return std::shared_ptr( + new AsyncSSLSocket(ctx, evb), + Destructor()); + } + + +#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT) + /** + * Create a client AsyncSSLSocket with tlsext_servername in + * the Client Hello message. + */ + AsyncSSLSocket(const std::shared_ptr &ctx, + EventBase* evb, + const std::string& serverName); + + /** + * Create a client AsyncSSLSocket from an already connected + * socket file descriptor. + * + * Note that while AsyncSSLSocket enables TCP_NODELAY for sockets it creates + * when connecting, it does not change the socket options when given an + * existing file descriptor. If callers want TCP_NODELAY enabled when using + * this version of the constructor, they need to explicitly call + * setNoDelay(true) after the constructor returns. + * + * @param ctx SSL context for this connection. + * @param evb EventBase that will manage this socket. + * @param fd File descriptor to take over (should be a connected socket). + * @param serverName tlsext_hostname that will be sent in ClientHello. + */ + AsyncSSLSocket(const std::shared_ptr& ctx, + EventBase* evb, + int fd, + const std::string& serverName); + + static std::shared_ptr newSocket( + const std::shared_ptr& ctx, + EventBase* evb, + const std::string& serverName) { + return std::shared_ptr( + new AsyncSSLSocket(ctx, evb, serverName), + Destructor()); + } +#endif + + /** + * TODO: implement support for SSL renegotiation. + * + * This involves proper handling of the SSL_ERROR_WANT_READ/WRITE + * code as a result of SSL_write/read(), instead of returning an + * error. In that case, the READ/WRITE event should be registered, + * and a flag (e.g., writeBlockedOnRead) should be set to indiciate + * the condition. In the next invocation of read/write callback, if + * the flag is on, performWrite()/performRead() should be called in + * addition to the normal call to performRead()/performWrite(), and + * the flag should be reset. + */ + + // Inherit TAsyncTransport methods from AsyncSocket except the + // following. + // See the documentation in TAsyncTransport.h + // TODO: implement graceful shutdown in close() + // TODO: implement detachSSL() that returns the SSL connection + virtual void closeNow(); + virtual void shutdownWrite(); + virtual void shutdownWriteNow(); + virtual bool good() const; + virtual bool connecting() const; + + bool isEorTrackingEnabled() const override; + virtual void setEorTracking(bool track); + virtual size_t getRawBytesWritten() const; + virtual size_t getRawBytesReceived() const; + void enableClientHelloParsing(); + + /** + * Accept an SSL connection on the socket. + * + * The callback will be invoked and uninstalled when an SSL + * connection has been established on the underlying socket. + * The value of verifyPeer determines the client verification method. + * By default, its set to use the value in the underlying context + * + * @param callback callback object to invoke on success/failure + * @param timeout timeout for this function in milliseconds, or 0 for no + * timeout + * @param verifyPeer SSLVerifyPeerEnum uses the options specified in the + * context by default, can be set explcitly to override the + * method in the context + */ + virtual void sslAccept(HandshakeCB* callback, uint32_t timeout = 0, + const folly::SSLContext::SSLVerifyPeerEnum& verifyPeer = + folly::SSLContext::SSLVerifyPeerEnum::USE_CTX); + + /** + * Invoke SSL accept following an asynchronous session cache lookup + */ + void restartSSLAccept(); + + /** + * Connect to the given address, invoking callback when complete or on error + * + * Note timeout applies to TCP + SSL connection time + */ + void connect(ConnectCallback* callback, + const folly::SocketAddress& address, + int timeout = 0, + const OptionMap &options = emptyOptionMap, + const folly::SocketAddress& bindAddr = anyAddress) + noexcept; + + using AsyncSocket::connect; + + /** + * Initiate an SSL connection on the socket + * THe callback will be invoked and uninstalled when an SSL connection + * has been establshed on the underlying socket. + * The verification option verifyPeer is applied if its passed explicitly. + * If its not, the options in SSLContext set on the underying SSLContext + * are applied. + * + * @param callback callback object to invoke on success/failure + * @param timeout timeout for this function in milliseconds, or 0 for no + * timeout + * @param verifyPeer SSLVerifyPeerEnum uses the options specified in the + * context by default, can be set explcitly to override the + * method in the context. If verification is turned on sets + * SSL_VERIFY_PEER and invokes + * HandshakeCB::handshakeVer(). + */ + virtual void sslConn(HandshakeCB *callback, uint64_t timeout = 0, + const folly::SSLContext::SSLVerifyPeerEnum& verifyPeer = + folly::SSLContext::SSLVerifyPeerEnum::USE_CTX); + + enum SSLStateEnum { + STATE_UNINIT, + STATE_ACCEPTING, + STATE_CACHE_LOOKUP, + STATE_RSA_ASYNC_PENDING, + STATE_CONNECTING, + STATE_ESTABLISHED, + STATE_REMOTE_CLOSED, /// remote end closed; we can still write + STATE_CLOSING, ///< close() called, but waiting on writes to complete + /// close() called with pending writes, before connect() has completed + STATE_CONNECTING_CLOSING, + STATE_CLOSED, + STATE_ERROR + }; + + SSLStateEnum getSSLState() const { return sslState_;} + + /** + * Get a handle to the negotiated SSL session. This increments the session + * refcount and must be deallocated by the caller. + */ + SSL_SESSION *getSSLSession(); + + /** + * Set the SSL session to be used during sslConn. AsyncSSLSocket will + * hold a reference to the session until it is destroyed or released by the + * underlying SSL structure. + * + * @param takeOwnership if true, AsyncSSLSocket will assume the caller's + * reference count to session. + */ + void setSSLSession(SSL_SESSION *session, bool takeOwnership = false); + + /** + * Get the name of the protocol selected by the client during + * Next Protocol Negotiation (NPN) + * + * Throw an exception if openssl does not support NPN + * + * @param protoName Name of the protocol (not guaranteed to be + * null terminated); will be set to nullptr if + * the client did not negotiate a protocol. + * Note: the AsyncSSLSocket retains ownership + * of this string. + * @param protoNameLen Length of the name. + */ + virtual void getSelectedNextProtocol(const unsigned char** protoName, + unsigned* protoLen) const; + + /** + * Get the name of the protocol selected by the client during + * Next Protocol Negotiation (NPN) + * + * @param protoName Name of the protocol (not guaranteed to be + * null terminated); will be set to nullptr if + * the client did not negotiate a protocol. + * Note: the AsyncSSLSocket retains ownership + * of this string. + * @param protoNameLen Length of the name. + * @return false if openssl does not support NPN + */ + virtual bool getSelectedNextProtocolNoThrow(const unsigned char** protoName, + unsigned* protoLen) const; + + /** + * Determine if the session specified during setSSLSession was reused + * or if the server rejected it and issued a new session. + */ + bool getSSLSessionReused() const; + + /** + * true if the session was resumed using session ID + */ + bool sessionIDResumed() const { return sessionIDResumed_; } + + void setSessionIDResumed(bool resumed) { + sessionIDResumed_ = resumed; + } + + /** + * Get the negociated cipher name for this SSL connection. + * Returns the cipher used or the constant value "NONE" when no SSL session + * has been established. + */ + const char *getNegotiatedCipherName() const; + + /** + * Get the server name for this SSL connection. + * Returns the server name used or the constant value "NONE" when no SSL + * session has been established. + * If openssl has no SNI support, throw TTransportException. + */ + const char *getSSLServerName() const; + + /** + * Get the server name for this SSL connection. + * Returns the server name used or the constant value "NONE" when no SSL + * session has been established. + * If openssl has no SNI support, return "NONE" + */ + const char *getSSLServerNameNoThrow() const; + + /** + * Get the SSL version for this connection. + * Possible return values are SSL2_VERSION, SSL3_VERSION, TLS1_VERSION, + * with hexa representations 0x200, 0x300, 0x301, + * or 0 if no SSL session has been established. + */ + int getSSLVersion() const; + + /** + * Get the certificate size used for this SSL connection. + */ + int getSSLCertSize() const; + + /* Get the number of bytes read from the wire (including protocol + * overhead). Returns 0 once the connection has been closed. + */ + unsigned long getBytesRead() const { + if (ssl_ != nullptr) { + return BIO_number_read(SSL_get_rbio(ssl_)); + } + return 0; + } + + /* Get the number of bytes written to the wire (including protocol + * overhead). Returns 0 once the connection has been closed. + */ + unsigned long getBytesWritten() const { + if (ssl_ != nullptr) { + return BIO_number_written(SSL_get_wbio(ssl_)); + } + return 0; + } + + virtual void attachEventBase(EventBase* eventBase) { + AsyncSocket::attachEventBase(eventBase); + handshakeTimeout_.attachEventBase(eventBase); + } + + virtual void detachEventBase() { + AsyncSocket::detachEventBase(); + handshakeTimeout_.detachEventBase(); + } + + virtual void attachTimeoutManager(TimeoutManager* manager) { + handshakeTimeout_.attachTimeoutManager(manager); + } + + virtual void detachTimeoutManager() { + handshakeTimeout_.detachTimeoutManager(); + } + +#if OPENSSL_VERSION_NUMBER >= 0x009080bfL + /** + * This function will set the SSL context for this socket to the + * argument. This should only be used on client SSL Sockets that have + * already called detachSSLContext(); + */ + void attachSSLContext(const std::shared_ptr& ctx); + + /** + * Detaches the SSL context for this socket. + */ + void detachSSLContext(); +#endif + +#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT) + /** + * Switch the SSLContext to continue the SSL handshake. + * It can only be used in server mode. + */ + void switchServerSSLContext( + const std::shared_ptr& handshakeCtx); + + /** + * Did server recognize/support the tlsext_hostname in Client Hello? + * It can only be used in client mode. + * + * @return true - tlsext_hostname is matched by the server + * false - tlsext_hostname is not matched or + * is not supported by server + */ + bool isServerNameMatch() const; + + /** + * Set the SNI hostname that we'll advertise to the server in the + * ClientHello message. + */ + void setServerName(std::string serverName) noexcept; +#endif + + void timeoutExpired() noexcept; + + /** + * Get the list of supported ciphers sent by the client in the client's + * preference order. + */ + void getSSLClientCiphers(std::string& clientCiphers) { + std::stringstream ciphersStream; + std::string cipherName; + + if (parseClientHello_ == false + || clientHelloInfo_->clientHelloCipherSuites_.empty()) { + clientCiphers = ""; + return; + } + + for (auto originalCipherCode : clientHelloInfo_->clientHelloCipherSuites_) + { + // OpenSSL expects code as a big endian char array + auto cipherCode = htons(originalCipherCode); + +#if defined(SSL_OP_NO_TLSv1_2) + const SSL_CIPHER* cipher = + TLSv1_2_method()->get_cipher_by_char((unsigned char*)&cipherCode); +#elif defined(SSL_OP_NO_TLSv1_1) + const SSL_CIPHER* cipher = + TLSv1_1_method()->get_cipher_by_char((unsigned char*)&cipherCode); +#elif defined(SSL_OP_NO_TLSv1) + const SSL_CIPHER* cipher = + TLSv1_method()->get_cipher_by_char((unsigned char*)&cipherCode); +#else + const SSL_CIPHER* cipher = + SSLv3_method()->get_cipher_by_char((unsigned char*)&cipherCode); +#endif + + if (cipher == nullptr) { + ciphersStream << std::setfill('0') << std::setw(4) << std::hex + << originalCipherCode << ":"; + } else { + ciphersStream << SSL_CIPHER_get_name(cipher) << ":"; + } + } + + clientCiphers = ciphersStream.str(); + clientCiphers.erase(clientCiphers.end() - 1); + } + + /** + * Get the list of compression methods sent by the client in TLS Hello. + */ + std::string getSSLClientComprMethods() { + if (!parseClientHello_) { + return ""; + } + return folly::join(":", clientHelloInfo_->clientHelloCompressionMethods_); + } + + /** + * Get the list of TLS extensions sent by the client in the TLS Hello. + */ + std::string getSSLClientExts() { + if (!parseClientHello_) { + return ""; + } + return folly::join(":", clientHelloInfo_->clientHelloExtensions_); + } + + /** + * Get the list of shared ciphers between the server and the client. + * Works well for only SSLv2, not so good for SSLv3 or TLSv1. + */ + void getSSLSharedCiphers(std::string& sharedCiphers) { + char ciphersBuffer[1024]; + ciphersBuffer[0] = '\0'; + SSL_get_shared_ciphers(ssl_, ciphersBuffer, sizeof(ciphersBuffer) - 1); + sharedCiphers = ciphersBuffer; + } + + /** + * Get the list of ciphers supported by the server in the server's + * preference order. + */ + void getSSLServerCiphers(std::string& serverCiphers) { + serverCiphers = SSL_get_cipher_list(ssl_, 0); + int i = 1; + const char *cipher; + while ((cipher = SSL_get_cipher_list(ssl_, i)) != nullptr) { + serverCiphers.append(":"); + serverCiphers.append(cipher); + i++; + } + } + + static int getSSLExDataIndex(); + static AsyncSSLSocket* getFromSSL(const SSL *ssl); + static int eorAwareBioWrite(BIO *b, const char *in, int inl); + void resetClientHelloParsing(SSL *ssl); + static void clientHelloParsingCallback(int write_p, int version, + int content_type, const void *buf, size_t len, SSL *ssl, void *arg); + + struct ClientHelloInfo { + folly::IOBufQueue clientHelloBuf_; + uint8_t clientHelloMajorVersion_; + uint8_t clientHelloMinorVersion_; + std::vector clientHelloCipherSuites_; + std::vector clientHelloCompressionMethods_; + std::vector clientHelloExtensions_; + }; + + // For unit-tests + ClientHelloInfo* getClientHelloInfo() { + return clientHelloInfo_.get(); + } + + protected: + + /** + * Protected destructor. + * + * Users of AsyncSSLSocket must never delete it directly. Instead, invoke + * destroy() instead. (See the documentation in TDelayedDestruction.h for + * more details.) + */ + ~AsyncSSLSocket(); + + // Inherit event notification methods from AsyncSocket except + // the following. + + void handleRead() noexcept; + void handleWrite() noexcept; + void handleAccept() noexcept; + void handleConnect() noexcept; + + void invalidState(HandshakeCB* callback); + bool willBlock(int ret, int *errorOut) noexcept; + + virtual void checkForImmediateRead() noexcept; + // AsyncSocket calls this at the wrong time for SSL + void handleInitialReadWrite() noexcept {} + + ssize_t performRead(void* buf, size_t buflen); + ssize_t performWrite(const iovec* vec, uint32_t count, WriteFlags flags, + uint32_t* countWritten, uint32_t* partialWritten); + + // This virtual wrapper around SSL_write exists solely for testing/mockability + virtual int sslWriteImpl(SSL *ssl, const void *buf, int n) { + return SSL_write(ssl, buf, n); + } + + /** + * Apply verification options passed to sslConn/sslAccept or those set + * in the underlying SSLContext object. + * + * @param ssl pointer to the SSL object on which verification options will be + * applied. If verifyPeer_ was explicitly set either via sslConn/sslAccept, + * those options override the settings in the underlying SSLContext. + */ + void applyVerificationOptions(SSL * ssl); + + /** + * A SSL_write wrapper that understand EOR + * + * @param ssl: SSL* object + * @param buf: Buffer to be written + * @param n: Number of bytes to be written + * @param eor: Does the last byte (buf[n-1]) have the app-last-byte? + * @return: The number of app bytes successfully written to the socket + */ + int eorAwareSSLWrite(SSL *ssl, const void *buf, int n, bool eor); + + // Inherit error handling methods from AsyncSocket, plus the following. + void failHandshake(const char* fn, const AsyncSocketException& ex); + + void invokeHandshakeCB(); + + static void sslInfoCallback(const SSL *ssl, int type, int val); + + static std::mutex mutex_; + static int sslExDataIndex_; + // Whether we've applied the TCP_CORK option to the socket + bool corked_{false}; + // SSL related members. + bool server_{false}; + // Used to prevent client-initiated renegotiation. Note that AsyncSSLSocket + // doesn't fully support renegotiation, so we could just fail all attempts + // to enforce this. Once it is supported, we should make it an option + // to disable client-initiated renegotiation. + bool handshakeComplete_{false}; + bool renegotiateAttempted_{false}; + SSLStateEnum sslState_{STATE_UNINIT}; + std::shared_ptr ctx_; + // Callback for SSL_accept() or SSL_connect() + HandshakeCB* handshakeCallback_{nullptr}; + SSL* ssl_{nullptr}; + SSL_SESSION *sslSession_{nullptr}; + HandshakeTimeout handshakeTimeout_; + // whether the SSL session was resumed using session ID or not + bool sessionIDResumed_{false}; + + // The app byte num that we are tracking for the MSG_EOR + // Only one app EOR byte can be tracked. + size_t appEorByteNo_{0}; + + // When openssl is about to sendmsg() across the minEorRawBytesNo_, + // it will pass MSG_EOR to sendmsg(). + size_t minEorRawByteNo_{0}; +#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT) + std::shared_ptr handshakeCtx_; + std::string tlsextHostname_; +#endif + folly::SSLContext::SSLVerifyPeerEnum + verifyPeer_{folly::SSLContext::SSLVerifyPeerEnum::USE_CTX}; + + // Callback for SSL_CTX_set_verify() + static int sslVerifyCallback(int preverifyOk, X509_STORE_CTX* ctx); + + bool parseClientHello_{false}; + unique_ptr clientHelloInfo_; +}; + +} // namespace diff --git a/folly/io/async/AsyncSocket.cpp b/folly/io/async/AsyncSocket.cpp index ceba5211..d18e3b6c 100644 --- a/folly/io/async/AsyncSocket.cpp +++ b/folly/io/async/AsyncSocket.cpp @@ -175,6 +175,14 @@ class AsyncSocket::WriteRequest { struct iovec writeOps_[]; ///< write operation(s) list }; +AsyncSocket::AsyncSocket() + : eventBase_(nullptr) + , writeTimeout_(this, nullptr) + , ioHandler_(this, nullptr) { + VLOG(5) << "new AsyncSocket(" << ")"; + init(); +} + AsyncSocket::AsyncSocket(EventBase* evb) : eventBase_(evb) , writeTimeout_(this, evb) diff --git a/folly/io/async/AsyncSocket.h b/folly/io/async/AsyncSocket.h index 77bd2b0c..33924b6d 100644 --- a/folly/io/async/AsyncSocket.h +++ b/folly/io/async/AsyncSocket.h @@ -184,6 +184,7 @@ class AsyncSocket : virtual public AsyncTransport { noexcept = 0; }; + explicit AsyncSocket(); /** * Create a new unconnected AsyncSocket. * @@ -549,6 +550,14 @@ class AsyncSocket : virtual public AsyncTransport { return setsockopt(fd_, level, optname, optval, sizeof(T)); } + enum class StateEnum : uint8_t { + UNINIT, + CONNECTING, + ESTABLISHED, + CLOSED, + ERROR + }; + protected: enum ReadResultEnum { READ_EOF = 0, @@ -565,14 +574,6 @@ class AsyncSocket : virtual public AsyncTransport { */ ~AsyncSocket(); - enum class StateEnum : uint8_t { - UNINIT, - CONNECTING, - ESTABLISHED, - CLOSED, - ERROR - }; - friend std::ostream& operator << (std::ostream& os, const StateEnum& state); enum ShutdownFlags { diff --git a/folly/io/async/AsyncTimeout.cpp b/folly/io/async/AsyncTimeout.cpp index 1bd5a831..2f94f760 100644 --- a/folly/io/async/AsyncTimeout.cpp +++ b/folly/io/async/AsyncTimeout.cpp @@ -44,9 +44,11 @@ AsyncTimeout::AsyncTimeout(EventBase* eventBase) event_set(&event_, -1, EV_TIMEOUT, &AsyncTimeout::libeventCallback, this); event_.ev_base = nullptr; - timeoutManager_->attachTimeoutManager( + if (eventBase) { + timeoutManager_->attachTimeoutManager( this, TimeoutManager::InternalEnum::NORMAL); + } RequestContext::getStaticContext(); }