From 727f779bf54deac4add02b20bd60ae50532c9d93 Mon Sep 17 00:00:00 2001 From: Dave Watson Date: Tue, 9 Jun 2015 11:34:10 -0700 Subject: [PATCH] AsyncSSLSocket StartTLS Summary: Adds a StartTLS mode to AsyncSSLSocket. Previously I could only find anyone doing something like this by using AsyncSocket, calling detachFd, then creating a new AsyncSSLSocket, and calling sslConn/sslAccept. That had a couple downsides: 1) All pointers to the previous AsyncSocket become invalid and similarly 2) have to be super careful reads/writes happen on the correct socket, are flushed before changing socket types, etc. This makes it super easy to just use the same AsyncSSLSocket for everything: a) Create AsyncSSLSocket in StartTLS mode b) send/recv anything c) Call sslAccept/sslConn. Existing writes are still flushed in the correct order, any additional writes are buffered until handshake completes d) Start receiving encrypted data. I made it a new mode (vs. the default), since it seems bad to unintentionally send unencrypted data. Use case is easy secure thrift upgrade (similar to how current kerberos does it) Test Plan: New unittest Reviewed By: afrind@fb.com Subscribers: doug, ssl-diffs@, folly-diffs@, yfeldblum, chalfant, haijunz, andrewcox, alandau, alikhtarov, jsedgwick, simpkins FB internal diff: D2120114 Signature: t1:2120114:1433798448:caeddc8feb6cc10fb34200ba97ea323bcaf09f7a --- folly/io/async/AsyncSSLSocket.cpp | 39 +++++++--- folly/io/async/AsyncSSLSocket.h | 28 +++++--- folly/io/async/test/AsyncSSLSocketTest.cpp | 82 ++++++++++++++++++++++ 3 files changed, 130 insertions(+), 19 deletions(-) diff --git a/folly/io/async/AsyncSSLSocket.cpp b/folly/io/async/AsyncSSLSocket.cpp index 3eeb932f..0489b6a4 100644 --- a/folly/io/async/AsyncSSLSocket.cpp +++ b/folly/io/async/AsyncSSLSocket.cpp @@ -253,18 +253,22 @@ SSLException::SSLException(int sslError, int errno_copy): * Create a client AsyncSSLSocket */ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr &ctx, - EventBase* evb) : + EventBase* evb, bool deferSecurityNegotiation) : AsyncSocket(evb), ctx_(ctx), handshakeTimeout_(this, evb) { init(); + if (deferSecurityNegotiation) { + sslState_ = STATE_UNENCRYPTED; + } } /** * Create a server/client AsyncSSLSocket */ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr& ctx, - EventBase* evb, int fd, bool server) : + EventBase* evb, int fd, bool server, + bool deferSecurityNegotiation) : AsyncSocket(evb, fd), server_(server), ctx_(ctx), @@ -274,6 +278,9 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr& ctx, SSL_CTX_set_info_callback(ctx_->getSSLCtx(), AsyncSSLSocket::sslInfoCallback); } + if (deferSecurityNegotiation) { + sslState_ = STATE_UNENCRYPTED; + } } #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT) @@ -283,8 +290,9 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr& ctx, */ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr &ctx, EventBase* evb, - const std::string& serverName) : - AsyncSSLSocket(ctx, evb) { + const std::string& serverName, + bool deferSecurityNegotiation) : + AsyncSSLSocket(ctx, evb, deferSecurityNegotiation) { tlsextHostname_ = serverName; } @@ -294,8 +302,9 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr &ctx, */ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr& ctx, EventBase* evb, int fd, - const std::string& serverName) : - AsyncSSLSocket(ctx, evb, fd, false) { + const std::string& serverName, + bool deferSecurityNegotiation) : + AsyncSSLSocket(ctx, evb, fd, false, deferSecurityNegotiation) { tlsextHostname_ = serverName; } #endif @@ -374,7 +383,7 @@ void AsyncSSLSocket::shutdownWriteNow() { bool AsyncSSLSocket::good() const { return (AsyncSocket::good() && (sslState_ == STATE_ACCEPTING || sslState_ == STATE_CONNECTING || - sslState_ == STATE_ESTABLISHED)); + sslState_ == STATE_ESTABLISHED || sslState_ == STATE_UNENCRYPTED)); } // The TAsyncTransport definition of 'good' states that the transport is @@ -468,7 +477,9 @@ void AsyncSSLSocket::sslAccept(HandshakeCB* callback, uint32_t timeout, verifyPeer_ = verifyPeer; // Make sure we're in the uninitialized state - if (!server_ || sslState_ != STATE_UNINIT || handshakeCallback_ != nullptr) { + if (!server_ || (sslState_ != STATE_UNINIT && + sslState_ != STATE_UNENCRYPTED) || + handshakeCallback_ != nullptr) { return invalidState(callback); } @@ -674,7 +685,9 @@ void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout, verifyPeer_ = verifyPeer; // Make sure we're in the uninitialized state - if (server_ || sslState_ != STATE_UNINIT || handshakeCallback_ != nullptr) { + if (server_ || (sslState_ != STATE_UNINIT && sslState_ != + STATE_UNENCRYPTED) || + handshakeCallback_ != nullptr) { return invalidState(callback); } @@ -1078,6 +1091,10 @@ AsyncSSLSocket::handleRead() noexcept { ssize_t AsyncSSLSocket::performRead(void* buf, size_t buflen) { + if (sslState_ == STATE_UNENCRYPTED) { + return AsyncSocket::performRead(buf, buflen); + } + errno = 0; ssize_t bytes = SSL_read(ssl_, buf, buflen); if (server_ && renegotiateAttempted_) { @@ -1169,6 +1186,10 @@ ssize_t AsyncSSLSocket::performWrite(const iovec* vec, WriteFlags flags, uint32_t* countWritten, uint32_t* partialWritten) { + if (sslState_ == STATE_UNENCRYPTED) { + return AsyncSocket::performWrite( + vec, count, flags, countWritten, partialWritten); + } if (sslState_ != STATE_ESTABLISHED) { LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_) << ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): " diff --git a/folly/io/async/AsyncSSLSocket.h b/folly/io/async/AsyncSSLSocket.h index cce4c18b..393b8765 100644 --- a/folly/io/async/AsyncSSLSocket.h +++ b/folly/io/async/AsyncSSLSocket.h @@ -162,7 +162,7 @@ class AsyncSSLSocket : public virtual AsyncSocket { * Create a client AsyncSSLSocket */ AsyncSSLSocket(const std::shared_ptr &ctx, - EventBase* evb); + EventBase* evb, bool deferSecurityNegotiation = false); /** * Create a server/client AsyncSSLSocket from an already connected @@ -178,9 +178,12 @@ class AsyncSSLSocket : public virtual AsyncSocket { * @param evb EventBase that will manage this socket. * @param fd File descriptor to take over (should be a connected socket). * @param server Is socket in server mode? + * @param deferSecurityNegotiation + * unencrypted data can be sent before sslConn/Accept */ AsyncSSLSocket(const std::shared_ptr& ctx, - EventBase* evb, int fd, bool server = true); + EventBase* evb, int fd, + bool server = true, bool deferSecurityNegotiation = false); /** @@ -188,9 +191,10 @@ class AsyncSSLSocket : public virtual AsyncSocket { */ static std::shared_ptr newSocket( const std::shared_ptr& ctx, - EventBase* evb, int fd, bool server=true) { + EventBase* evb, int fd, bool server=true, + bool deferSecurityNegotiation = false) { return std::shared_ptr( - new AsyncSSLSocket(ctx, evb, fd, server), + new AsyncSSLSocket(ctx, evb, fd, server, deferSecurityNegotiation), Destructor()); } @@ -199,9 +203,9 @@ class AsyncSSLSocket : public virtual AsyncSocket { */ static std::shared_ptr newSocket( const std::shared_ptr& ctx, - EventBase* evb) { + EventBase* evb, bool deferSecurityNegotiation = false) { return std::shared_ptr( - new AsyncSSLSocket(ctx, evb), + new AsyncSSLSocket(ctx, evb, deferSecurityNegotiation), Destructor()); } @@ -213,7 +217,8 @@ class AsyncSSLSocket : public virtual AsyncSocket { */ AsyncSSLSocket(const std::shared_ptr &ctx, EventBase* evb, - const std::string& serverName); + const std::string& serverName, + bool deferSecurityNegotiation = false); /** * Create a client AsyncSSLSocket from an already connected @@ -233,14 +238,16 @@ class AsyncSSLSocket : public virtual AsyncSocket { AsyncSSLSocket(const std::shared_ptr& ctx, EventBase* evb, int fd, - const std::string& serverName); + const std::string& serverName, + bool deferSecurityNegotiation = false); static std::shared_ptr newSocket( const std::shared_ptr& ctx, EventBase* evb, - const std::string& serverName) { + const std::string& serverName, + bool deferSecurityNegotiation = false) { return std::shared_ptr( - new AsyncSSLSocket(ctx, evb, serverName), + new AsyncSSLSocket(ctx, evb, serverName, deferSecurityNegotiation), Destructor()); } #endif @@ -336,6 +343,7 @@ class AsyncSSLSocket : public virtual AsyncSocket { enum SSLStateEnum { STATE_UNINIT, + STATE_UNENCRYPTED, STATE_ACCEPTING, STATE_CACHE_LOOKUP, STATE_RSA_ASYNC_PENDING, diff --git a/folly/io/async/test/AsyncSSLSocketTest.cpp b/folly/io/async/test/AsyncSSLSocketTest.cpp index 20f782a1..b3759cfe 100644 --- a/folly/io/async/test/AsyncSSLSocketTest.cpp +++ b/folly/io/async/test/AsyncSSLSocketTest.cpp @@ -1262,8 +1262,90 @@ TEST(AsyncSSLSocketTest, MinWriteSizeTest) { socket->setMinWriteSize(50000); EXPECT_EQ(50000, socket->getMinWriteSize()); } + +class ReadCallbackTerminator : public ReadCallback { + public: + ReadCallbackTerminator(EventBase* base, WriteCallbackBase *wcb) + : ReadCallback(wcb) + , base_(base) {} + + // Do not write data back, terminate the loop. + void readDataAvailable(size_t len) noexcept override { + std::cerr << "readDataAvailable, len " << len << std::endl; + + currentBuffer.length = len; + + buffers.push_back(currentBuffer); + currentBuffer.reset(); + state = STATE_SUCCEEDED; + + socket_->setReadCB(nullptr); + base_->terminateLoopSoon(); + } + private: + EventBase* base_; +}; + + +/** + * Test a full unencrypted codepath + */ +TEST(AsyncSSLSocketTest, UnencryptedTest) { + EventBase base; + + auto clientCtx = std::make_shared(); + auto serverCtx = std::make_shared(); + int fds[2]; + getfds(fds); + getctx(clientCtx, serverCtx); + auto client = AsyncSSLSocket::newSocket( + clientCtx, &base, fds[0], false, true); + auto server = AsyncSSLSocket::newSocket( + serverCtx, &base, fds[1], true, true); + + ReadCallbackTerminator readCallback(&base, nullptr); + server->setReadCB(&readCallback); + readCallback.setSocket(server); + + uint8_t buf[128]; + memset(buf, 'a', sizeof(buf)); + client->write(nullptr, buf, sizeof(buf)); + + // Check that bytes are unencrypted + char c; + EXPECT_EQ(1, recv(fds[1], &c, 1, MSG_PEEK)); + EXPECT_EQ('a', c); + + EventBaseAborter eba(&base, 3000); + base.loop(); + + EXPECT_EQ(1, readCallback.buffers.size()); + EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState()); + + server->setReadCB(&readCallback); + + // Unencrypted + server->sslAccept(nullptr); + client->sslConn(nullptr); + + // Do NOT wait for handshake, writing should be queued and happen after + + client->write(nullptr, buf, sizeof(buf)); + + // Check that bytes are *not* unencrypted + char c2; + EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK)); + EXPECT_NE('a', c2); + + + base.loop(); + + EXPECT_EQ(2, readCallback.buffers.size()); + EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState()); } +} // namespace + /////////////////////////////////////////////////////////////////////////// // init_unit_test_suite /////////////////////////////////////////////////////////////////////////// -- 2.34.1