From 3a033f27c7a96c9b1e6a51adaab03cee660aa027 Mon Sep 17 00:00:00 2001 From: Kyle Nekritz Date: Mon, 9 Jan 2017 11:51:34 -0800 Subject: [PATCH] Add pre received data API to AsyncSSLSocket. Summary: This allows something else (ie fizz) to read data from a socket, and then later decide to to accept an SSL connection with OpenSSL by inserting the data it read in front of future reads on the socket. Reviewed By: anirudhvr Differential Revision: D4325634 fbshipit-source-id: 05076d2d911fda681b9c4e5d9d3375559293ea35 --- folly/io/async/AsyncSSLSocket.cpp | 37 +++++++++++++++++++--- folly/io/async/AsyncSSLSocket.h | 4 +++ folly/io/async/test/AsyncSSLSocketTest.cpp | 35 ++++++++++++++++++++ 3 files changed, 72 insertions(+), 4 deletions(-) diff --git a/folly/io/async/AsyncSSLSocket.cpp b/folly/io/async/AsyncSSLSocket.cpp index 70a4640e..7910bd02 100644 --- a/folly/io/async/AsyncSSLSocket.cpp +++ b/folly/io/async/AsyncSSLSocket.cpp @@ -452,6 +452,10 @@ void AsyncSSLSocket::sslAccept( /* register for a read operation (waiting for CLIENT HELLO) */ updateEventRegistration(EventHandler::READ, EventHandler::WRITE); + + if (preReceivedData_) { + handleRead(); + } } #if OPENSSL_VERSION_NUMBER >= 0x009080bfL @@ -1610,12 +1614,31 @@ int AsyncSSLSocket::bioRead(BIO* b, char* out, int outl) { if (!out) { return 0; } - auto result = recv(OpenSSLUtils::getBioFd(b, nullptr), out, outl, 0); BIO_clear_retry_flags(b); - if (result <= 0 && OpenSSLUtils::getBioShouldRetryWrite(result)) { - BIO_set_retry_read(b); + + auto appData = OpenSSLUtils::getBioAppData(b); + CHECK(appData); + auto sslSock = reinterpret_cast(appData); + + if (sslSock->preReceivedData_ && !sslSock->preReceivedData_->empty()) { + VLOG(5) << "AsyncSSLSocket::bioRead() this=" << sslSock + << ", reading pre-received data"; + + Cursor cursor(sslSock->preReceivedData_.get()); + auto len = cursor.pullAtMost(out, outl); + + IOBufQueue queue; + queue.append(std::move(sslSock->preReceivedData_)); + queue.trimStart(len); + sslSock->preReceivedData_ = queue.move(); + return len; + } else { + auto result = recv(OpenSSLUtils::getBioFd(b, nullptr), out, outl, 0); + if (result <= 0 && OpenSSLUtils::getBioShouldRetryWrite(result)) { + BIO_set_retry_read(b); + } + return result; } - return result; } int AsyncSSLSocket::sslVerifyCallback( @@ -1632,6 +1655,12 @@ int AsyncSSLSocket::sslVerifyCallback( preverifyOk; } +void AsyncSSLSocket::setPreReceivedData(std::unique_ptr data) { + CHECK(sslState_ == STATE_UNINIT || sslState_ == STATE_UNENCRYPTED); + CHECK(!preReceivedData_); + preReceivedData_ = std::move(data); +} + void AsyncSSLSocket::enableClientHelloParsing() { parseClientHello_ = true; clientHelloInfo_.reset(new ssl::ClientHelloInfo()); diff --git a/folly/io/async/AsyncSSLSocket.h b/folly/io/async/AsyncSSLSocket.h index 0d106214..01ff7bd2 100644 --- a/folly/io/async/AsyncSSLSocket.h +++ b/folly/io/async/AsyncSSLSocket.h @@ -278,6 +278,8 @@ class AsyncSSLSocket : public virtual AsyncSocket { virtual size_t getRawBytesReceived() const override; void enableClientHelloParsing(); + void setPreReceivedData(std::unique_ptr data); + /** * Accept an SSL connection on the socket. * @@ -818,6 +820,8 @@ class AsyncSSLSocket : public virtual AsyncSocket { std::chrono::steady_clock::time_point handshakeEndTime_; std::chrono::milliseconds handshakeConnectTimeout_{0}; bool sessionResumptionAttempted_{false}; + + std::unique_ptr preReceivedData_; }; } // namespace diff --git a/folly/io/async/test/AsyncSSLSocketTest.cpp b/folly/io/async/test/AsyncSSLSocketTest.cpp index 1508d0e9..87c1453c 100644 --- a/folly/io/async/test/AsyncSSLSocketTest.cpp +++ b/folly/io/async/test/AsyncSSLSocketTest.cpp @@ -1961,6 +1961,41 @@ TEST(AsyncSSLSocketTest, HandshakeTFORefused) { EXPECT_THAT(ccb.error, testing::HasSubstr("refused")); } +TEST(AsyncSSLSocketTest, TestPreReceivedData) { + EventBase clientEventBase; + EventBase serverEventBase; + auto clientCtx = std::make_shared(); + auto dfServerCtx = std::make_shared(); + std::array fds; + getfds(fds.data()); + getctx(clientCtx, dfServerCtx); + + AsyncSSLSocket::UniquePtr clientSockPtr( + new AsyncSSLSocket(clientCtx, &clientEventBase, fds[0], false)); + AsyncSSLSocket::UniquePtr serverSockPtr( + new AsyncSSLSocket(dfServerCtx, &serverEventBase, fds[1], true)); + auto clientSock = clientSockPtr.get(); + auto serverSock = serverSockPtr.get(); + SSLHandshakeClient client(std::move(clientSockPtr), true, true); + + // Steal some data from the server. + clientEventBase.loopOnce(); + std::array buf; + recv(fds[1], buf.data(), buf.size(), 0); + + serverSock->setPreReceivedData(IOBuf::wrapBuffer(range(buf))); + SSLHandshakeServer server(std::move(serverSockPtr), true, true); + while (!client.handshakeSuccess_ && !client.handshakeError_) { + serverEventBase.loopOnce(); + clientEventBase.loopOnce(); + } + + EXPECT_TRUE(client.handshakeSuccess_); + EXPECT_TRUE(server.handshakeSuccess_); + EXPECT_EQ( + serverSock->getRawBytesReceived(), clientSock->getRawBytesWritten()); +} + #endif } // namespace -- 2.34.1