Invoking correct callback during TFO fallback
authorSubodh Iyengar <subodh@fb.com>
Wed, 17 Aug 2016 04:52:13 +0000 (21:52 -0700)
committerFacebook Github Bot 2 <facebook-github-bot-2-bot@fb.com>
Wed, 17 Aug 2016 04:53:42 +0000 (21:53 -0700)
Summary:
If we fallback from SSL to TFO and the connection times
out, invokeConnectSuccess tries to deliver the connectError,
however we've already delivered the connect callback to the user.

This is bad because we have no way of reporting an error back.
This changes it so that when using SSL and we're scheduling a timeout
when we're falling back, we will schedule a timeout of our own which
will invoke AsyncSSLSocket's timeoutExpired. This will return a handshakeError
instead to the client.

Reviewed By: yfeldblum

Differential Revision: D3708699

fbshipit-source-id: 41fe668f00972c0875bb0318c6a6de863d3ab8f9

folly/io/async/AsyncSSLSocket.cpp
folly/io/async/AsyncSSLSocket.h
folly/io/async/AsyncSocket.cpp
folly/io/async/AsyncSocket.h
folly/io/async/test/AsyncSSLSocketTest.cpp

index 04ca1516d35877c9dc506892427b20423fab5aab..c16e6fb6c66393045e3d298d8503297b689aaedc 100644 (file)
@@ -253,7 +253,8 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
                                EventBase* evb, bool deferSecurityNegotiation) :
     AsyncSocket(evb),
     ctx_(ctx),
-    handshakeTimeout_(this, evb) {
+    handshakeTimeout_(this, evb),
+    connectionTimeout_(this, evb) {
   init();
   if (deferSecurityNegotiation) {
     sslState_ = STATE_UNENCRYPTED;
@@ -269,7 +270,8 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
     AsyncSocket(evb, fd),
     server_(server),
     ctx_(ctx),
-    handshakeTimeout_(this, evb) {
+    handshakeTimeout_(this, evb),
+    connectionTimeout_(this, evb) {
   init();
   if (server) {
     SSL_CTX_set_info_callback(ctx_->getSSLCtx(),
@@ -587,6 +589,12 @@ void AsyncSSLSocket::timeoutExpired() noexcept {
     // We are expecting a callback in restartSSLAccept.  The cache lookup
     // and rsa-call necessarily have pointers to this ssl socket, so delay
     // the cleanup until he calls us back.
+  } else if (state_ == StateEnum::CONNECTING) {
+    assert(sslState_ == STATE_CONNECTING);
+    DestructorGuard dg(this);
+    AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
+                           "Fallback connect timed out during TFO");
+    failHandshake(__func__, ex);
   } else {
     assert(state_ == StateEnum::ESTABLISHED &&
            (sslState_ == STATE_CONNECTING || sslState_ == STATE_ACCEPTING));
@@ -1157,15 +1165,45 @@ AsyncSSLSocket::handleConnect() noexcept {
   AsyncSocket::handleInitialReadWrite();
 }
 
+void AsyncSSLSocket::invokeConnectErr(const AsyncSocketException& ex) {
+  connectionTimeout_.cancelTimeout();
+  AsyncSocket::invokeConnectErr(ex);
+}
+
 void AsyncSSLSocket::invokeConnectSuccess() {
+  connectionTimeout_.cancelTimeout();
   if (sslState_ == SSLStateEnum::STATE_CONNECTING) {
     // If we failed TFO, we'd fall back to trying to connect the socket,
     // to setup things like timeouts.
     startSSLConnect();
   }
+  // still invoke the base class since it re-sets the connect time.
   AsyncSocket::invokeConnectSuccess();
 }
 
+void AsyncSSLSocket::scheduleConnectTimeout() {
+  if (sslState_ == SSLStateEnum::STATE_CONNECTING) {
+    // We fell back from TFO, and need to set the timeouts.
+    // We will not have a connect callback in this case, thus if the timer
+    // expires we would have no-one to notify.
+    // Thus we should reset even the connect timers to point to the handshake
+    // timeouts.
+    assert(connectCallback_ == nullptr);
+    // We use a different connect timeout here than the handshake timeout, so
+    // that we can disambiguate the 2 timers.
+    int timeout = connectTimeout_.count();
+    if (timeout > 0) {
+      if (!connectionTimeout_.scheduleTimeout(timeout)) {
+        throw AsyncSocketException(
+            AsyncSocketException::INTERNAL_ERROR,
+            withAddr("failed to schedule AsyncSSLSocket connect timeout"));
+      }
+    }
+    return;
+  }
+  AsyncSocket::scheduleConnectTimeout();
+}
+
 void AsyncSSLSocket::setReadCB(ReadCallback *callback) {
 #ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP
   // turn on the buffer movable in openssl
index 296641db3941a754c45b8e7b605365643b6e0a35..47ad97b08160279e9c9419eaa9a97a2a164688a0 100644 (file)
@@ -136,6 +136,20 @@ class AsyncSSLSocket : public virtual AsyncSocket {
     AsyncSSLSocket* sslSocket_;
   };
 
+  // Timer for if we fallback from SSL connects to TCP connects
+  class ConnectionTimeout : public AsyncTimeout {
+   public:
+    ConnectionTimeout(AsyncSSLSocket* sslSocket, EventBase* eventBase)
+        : AsyncTimeout(eventBase), sslSocket_(sslSocket) {}
+
+    virtual void timeoutExpired() noexcept override {
+      sslSocket_->timeoutExpired();
+    }
+
+   private:
+    AsyncSSLSocket* sslSocket_;
+  };
+
   /**
    * Create a client AsyncSSLSocket
    */
@@ -811,7 +825,9 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   void invokeHandshakeErr(const AsyncSocketException& ex);
   void invokeHandshakeCB();
 
+  void invokeConnectErr(const AsyncSocketException& ex) override;
   void invokeConnectSuccess() override;
+  void scheduleConnectTimeout() override;
 
   void cacheLocalPeerAddr();
 
@@ -836,6 +852,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   SSL* ssl_{nullptr};
   SSL_SESSION *sslSession_{nullptr};
   HandshakeTimeout handshakeTimeout_;
+  ConnectionTimeout connectionTimeout_;
   // whether the SSL session was resumed using session ID or not
   bool sessionIDResumed_{false};
 
index 18ef98752e732f578001fe860fd4d6311b5d80c9..68ff62991bdd5ed7c80bf67dfe7b1b84e43bced6 100644 (file)
@@ -472,7 +472,8 @@ int AsyncSocket::socketConnect(const struct sockaddr* saddr, socklen_t len) {
   if (rv < 0) {
     auto errnoCopy = errno;
     if (errnoCopy == EINPROGRESS) {
-      scheduleConnectTimeoutAndRegisterForEvents();
+      scheduleConnectTimeout();
+      registerForConnectEvents();
     } else {
       throw AsyncSocketException(
           AsyncSocketException::NOT_OPEN,
@@ -483,7 +484,7 @@ int AsyncSocket::socketConnect(const struct sockaddr* saddr, socklen_t len) {
   return rv;
 }
 
-void AsyncSocket::scheduleConnectTimeoutAndRegisterForEvents() {
+void AsyncSocket::scheduleConnectTimeout() {
   // Connection in progress.
   int timeout = connectTimeout_.count();
   if (timeout > 0) {
@@ -494,7 +495,9 @@ void AsyncSocket::scheduleConnectTimeoutAndRegisterForEvents() {
           withAddr("failed to schedule AsyncSocket connect timeout"));
     }
   }
+}
 
+void AsyncSocket::registerForConnectEvents() {
   // Register for write events, so we'll
   // be notified when the connection finishes/fails.
   // Note that we don't register for a persistent event here.
@@ -1781,7 +1784,8 @@ AsyncSocket::sendSocketMessage(int fd, struct msghdr* msg, int msg_flags) {
       // cookie.
       state_ = StateEnum::CONNECTING;
       try {
-        scheduleConnectTimeoutAndRegisterForEvents();
+        scheduleConnectTimeout();
+        registerForConnectEvents();
       } catch (const AsyncSocketException& ex) {
         return WriteResult(
             WRITE_ERROR, folly::make_unique<AsyncSocketException>(ex));
index 6e0fb77b0e0283238b394f3e99f4f969c606da8a..ca4272b32796c9fc9731b0ffc298995b515ddaa2 100644 (file)
@@ -838,7 +838,8 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
 
   int socketConnect(const struct sockaddr* addr, socklen_t len);
 
-  void scheduleConnectTimeoutAndRegisterForEvents();
+  virtual void scheduleConnectTimeout();
+  void registerForConnectEvents();
 
   bool updateEventRegistration();
 
@@ -869,7 +870,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
                  const AsyncSocketException& ex);
   void failWrite(const char* fn, const AsyncSocketException& ex);
   void failAllWrites(const AsyncSocketException& ex);
-  void invokeConnectErr(const AsyncSocketException& ex);
+  virtual void invokeConnectErr(const AsyncSocketException& ex);
   virtual void invokeConnectSuccess();
   void invalidState(ConnectCallback* callback);
   void invalidState(ReadCallback* callback);
index f09a4da4307bf8069e0eddad2992b1cfcf2eac0a..1622dcabb3df8881d95b0b8ed5a7bc5a0098d836 100644 (file)
@@ -1788,13 +1788,15 @@ class ConnCallback : public AsyncSocket::ConnectCallback {
     state = State::SUCCESS;
   }
 
-  virtual void connectErr(const AsyncSocketException&) noexcept override {
+  virtual void connectErr(const AsyncSocketException& ex) noexcept override {
     state = State::ERROR;
+    error = ex.what();
   }
 
   enum class State { WAITING, SUCCESS, ERROR };
 
   State state{State::WAITING};
+  std::string error;
 };
 
 template <class Cardinality>
@@ -1869,7 +1871,7 @@ TEST(AsyncSSLSocketTest, ConnectTFOTimeout) {
       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
   socket->enableTFO();
   EXPECT_THROW(
-      socket->open(std::chrono::milliseconds(1)), AsyncSocketException);
+      socket->open(std::chrono::milliseconds(20)), AsyncSocketException);
 }
 
 TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) {
@@ -1888,6 +1890,25 @@ TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) {
   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
 }
 
+TEST(AsyncSSLSocketTest, HandshakeTFOFallbackTimeout) {
+  // Start listening on a local port
+  EmptyReadCallback readCallback;
+  HandshakeCallback handshakeCallback(
+      &readCallback, HandshakeCallback::EXPECT_ERROR);
+  HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
+  TestSSLServer server(&acceptCallback, true);
+
+  EventBase evb;
+
+  auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
+  ConnCallback ccb;
+  socket->connect(&ccb, server.getAddress(), 100);
+
+  evb.loop();
+  EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
+  EXPECT_THAT(ccb.error, testing::HasSubstr("SSL connect timed out"));
+}
+
 #endif
 
 } // namespace