From 12ace86198d5050388099abb0cdb58faa7d4ac74 Mon Sep 17 00:00:00 2001 From: Subodh Iyengar Date: Tue, 16 Aug 2016 21:52:13 -0700 Subject: [PATCH] Invoking correct callback during TFO fallback Summary: If we fallback from SSL to TFO and the connection times out, invokeConnectSuccess tries to deliver the connectError, however we've already delivered the connect callback to the user. This is bad because we have no way of reporting an error back. This changes it so that when using SSL and we're scheduling a timeout when we're falling back, we will schedule a timeout of our own which will invoke AsyncSSLSocket's timeoutExpired. This will return a handshakeError instead to the client. Reviewed By: yfeldblum Differential Revision: D3708699 fbshipit-source-id: 41fe668f00972c0875bb0318c6a6de863d3ab8f9 --- folly/io/async/AsyncSSLSocket.cpp | 42 ++++++++++++++++++++-- folly/io/async/AsyncSSLSocket.h | 17 +++++++++ folly/io/async/AsyncSocket.cpp | 10 ++++-- folly/io/async/AsyncSocket.h | 5 +-- folly/io/async/test/AsyncSSLSocketTest.cpp | 25 +++++++++++-- 5 files changed, 90 insertions(+), 9 deletions(-) diff --git a/folly/io/async/AsyncSSLSocket.cpp b/folly/io/async/AsyncSSLSocket.cpp index 04ca1516..c16e6fb6 100644 --- a/folly/io/async/AsyncSSLSocket.cpp +++ b/folly/io/async/AsyncSSLSocket.cpp @@ -253,7 +253,8 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr &ctx, EventBase* evb, bool deferSecurityNegotiation) : AsyncSocket(evb), ctx_(ctx), - handshakeTimeout_(this, evb) { + handshakeTimeout_(this, evb), + connectionTimeout_(this, evb) { init(); if (deferSecurityNegotiation) { sslState_ = STATE_UNENCRYPTED; @@ -269,7 +270,8 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr& ctx, AsyncSocket(evb, fd), server_(server), ctx_(ctx), - handshakeTimeout_(this, evb) { + handshakeTimeout_(this, evb), + connectionTimeout_(this, evb) { init(); if (server) { SSL_CTX_set_info_callback(ctx_->getSSLCtx(), @@ -587,6 +589,12 @@ void AsyncSSLSocket::timeoutExpired() noexcept { // We are expecting a callback in restartSSLAccept. The cache lookup // and rsa-call necessarily have pointers to this ssl socket, so delay // the cleanup until he calls us back. + } else if (state_ == StateEnum::CONNECTING) { + assert(sslState_ == STATE_CONNECTING); + DestructorGuard dg(this); + AsyncSocketException ex(AsyncSocketException::TIMED_OUT, + "Fallback connect timed out during TFO"); + failHandshake(__func__, ex); } else { assert(state_ == StateEnum::ESTABLISHED && (sslState_ == STATE_CONNECTING || sslState_ == STATE_ACCEPTING)); @@ -1157,15 +1165,45 @@ AsyncSSLSocket::handleConnect() noexcept { AsyncSocket::handleInitialReadWrite(); } +void AsyncSSLSocket::invokeConnectErr(const AsyncSocketException& ex) { + connectionTimeout_.cancelTimeout(); + AsyncSocket::invokeConnectErr(ex); +} + void AsyncSSLSocket::invokeConnectSuccess() { + connectionTimeout_.cancelTimeout(); if (sslState_ == SSLStateEnum::STATE_CONNECTING) { // If we failed TFO, we'd fall back to trying to connect the socket, // to setup things like timeouts. startSSLConnect(); } + // still invoke the base class since it re-sets the connect time. AsyncSocket::invokeConnectSuccess(); } +void AsyncSSLSocket::scheduleConnectTimeout() { + if (sslState_ == SSLStateEnum::STATE_CONNECTING) { + // We fell back from TFO, and need to set the timeouts. + // We will not have a connect callback in this case, thus if the timer + // expires we would have no-one to notify. + // Thus we should reset even the connect timers to point to the handshake + // timeouts. + assert(connectCallback_ == nullptr); + // We use a different connect timeout here than the handshake timeout, so + // that we can disambiguate the 2 timers. + int timeout = connectTimeout_.count(); + if (timeout > 0) { + if (!connectionTimeout_.scheduleTimeout(timeout)) { + throw AsyncSocketException( + AsyncSocketException::INTERNAL_ERROR, + withAddr("failed to schedule AsyncSSLSocket connect timeout")); + } + } + return; + } + AsyncSocket::scheduleConnectTimeout(); +} + void AsyncSSLSocket::setReadCB(ReadCallback *callback) { #ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP // turn on the buffer movable in openssl diff --git a/folly/io/async/AsyncSSLSocket.h b/folly/io/async/AsyncSSLSocket.h index 296641db..47ad97b0 100644 --- a/folly/io/async/AsyncSSLSocket.h +++ b/folly/io/async/AsyncSSLSocket.h @@ -136,6 +136,20 @@ class AsyncSSLSocket : public virtual AsyncSocket { AsyncSSLSocket* sslSocket_; }; + // Timer for if we fallback from SSL connects to TCP connects + class ConnectionTimeout : public AsyncTimeout { + public: + ConnectionTimeout(AsyncSSLSocket* sslSocket, EventBase* eventBase) + : AsyncTimeout(eventBase), sslSocket_(sslSocket) {} + + virtual void timeoutExpired() noexcept override { + sslSocket_->timeoutExpired(); + } + + private: + AsyncSSLSocket* sslSocket_; + }; + /** * Create a client AsyncSSLSocket */ @@ -811,7 +825,9 @@ class AsyncSSLSocket : public virtual AsyncSocket { void invokeHandshakeErr(const AsyncSocketException& ex); void invokeHandshakeCB(); + void invokeConnectErr(const AsyncSocketException& ex) override; void invokeConnectSuccess() override; + void scheduleConnectTimeout() override; void cacheLocalPeerAddr(); @@ -836,6 +852,7 @@ class AsyncSSLSocket : public virtual AsyncSocket { SSL* ssl_{nullptr}; SSL_SESSION *sslSession_{nullptr}; HandshakeTimeout handshakeTimeout_; + ConnectionTimeout connectionTimeout_; // whether the SSL session was resumed using session ID or not bool sessionIDResumed_{false}; diff --git a/folly/io/async/AsyncSocket.cpp b/folly/io/async/AsyncSocket.cpp index 18ef9875..68ff6299 100644 --- a/folly/io/async/AsyncSocket.cpp +++ b/folly/io/async/AsyncSocket.cpp @@ -472,7 +472,8 @@ int AsyncSocket::socketConnect(const struct sockaddr* saddr, socklen_t len) { if (rv < 0) { auto errnoCopy = errno; if (errnoCopy == EINPROGRESS) { - scheduleConnectTimeoutAndRegisterForEvents(); + scheduleConnectTimeout(); + registerForConnectEvents(); } else { throw AsyncSocketException( AsyncSocketException::NOT_OPEN, @@ -483,7 +484,7 @@ int AsyncSocket::socketConnect(const struct sockaddr* saddr, socklen_t len) { return rv; } -void AsyncSocket::scheduleConnectTimeoutAndRegisterForEvents() { +void AsyncSocket::scheduleConnectTimeout() { // Connection in progress. int timeout = connectTimeout_.count(); if (timeout > 0) { @@ -494,7 +495,9 @@ void AsyncSocket::scheduleConnectTimeoutAndRegisterForEvents() { withAddr("failed to schedule AsyncSocket connect timeout")); } } +} +void AsyncSocket::registerForConnectEvents() { // Register for write events, so we'll // be notified when the connection finishes/fails. // Note that we don't register for a persistent event here. @@ -1781,7 +1784,8 @@ AsyncSocket::sendSocketMessage(int fd, struct msghdr* msg, int msg_flags) { // cookie. state_ = StateEnum::CONNECTING; try { - scheduleConnectTimeoutAndRegisterForEvents(); + scheduleConnectTimeout(); + registerForConnectEvents(); } catch (const AsyncSocketException& ex) { return WriteResult( WRITE_ERROR, folly::make_unique(ex)); diff --git a/folly/io/async/AsyncSocket.h b/folly/io/async/AsyncSocket.h index 6e0fb77b..ca4272b3 100644 --- a/folly/io/async/AsyncSocket.h +++ b/folly/io/async/AsyncSocket.h @@ -838,7 +838,8 @@ class AsyncSocket : virtual public AsyncTransportWrapper { int socketConnect(const struct sockaddr* addr, socklen_t len); - void scheduleConnectTimeoutAndRegisterForEvents(); + virtual void scheduleConnectTimeout(); + void registerForConnectEvents(); bool updateEventRegistration(); @@ -869,7 +870,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper { const AsyncSocketException& ex); void failWrite(const char* fn, const AsyncSocketException& ex); void failAllWrites(const AsyncSocketException& ex); - void invokeConnectErr(const AsyncSocketException& ex); + virtual void invokeConnectErr(const AsyncSocketException& ex); virtual void invokeConnectSuccess(); void invalidState(ConnectCallback* callback); void invalidState(ReadCallback* callback); diff --git a/folly/io/async/test/AsyncSSLSocketTest.cpp b/folly/io/async/test/AsyncSSLSocketTest.cpp index f09a4da4..1622dcab 100644 --- a/folly/io/async/test/AsyncSSLSocketTest.cpp +++ b/folly/io/async/test/AsyncSSLSocketTest.cpp @@ -1788,13 +1788,15 @@ class ConnCallback : public AsyncSocket::ConnectCallback { state = State::SUCCESS; } - virtual void connectErr(const AsyncSocketException&) noexcept override { + virtual void connectErr(const AsyncSocketException& ex) noexcept override { state = State::ERROR; + error = ex.what(); } enum class State { WAITING, SUCCESS, ERROR }; State state{State::WAITING}; + std::string error; }; template @@ -1869,7 +1871,7 @@ TEST(AsyncSSLSocketTest, ConnectTFOTimeout) { std::make_shared(server.getAddress(), sslContext); socket->enableTFO(); EXPECT_THROW( - socket->open(std::chrono::milliseconds(1)), AsyncSocketException); + socket->open(std::chrono::milliseconds(20)), AsyncSocketException); } TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) { @@ -1888,6 +1890,25 @@ TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) { EXPECT_EQ(ConnCallback::State::ERROR, ccb.state); } +TEST(AsyncSSLSocketTest, HandshakeTFOFallbackTimeout) { + // Start listening on a local port + EmptyReadCallback readCallback; + HandshakeCallback handshakeCallback( + &readCallback, HandshakeCallback::EXPECT_ERROR); + HandshakeTimeoutCallback acceptCallback(&handshakeCallback); + TestSSLServer server(&acceptCallback, true); + + EventBase evb; + + auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1)); + ConnCallback ccb; + socket->connect(&ccb, server.getAddress(), 100); + + evb.loop(); + EXPECT_EQ(ConnCallback::State::ERROR, ccb.state); + EXPECT_THAT(ccb.error, testing::HasSubstr("SSL connect timed out")); +} + #endif } // namespace -- 2.34.1