Add support for TFO connections
author Subodh Iyengar <subodh@fb.com>
Tue, 31 May 2016 04:08:33 +0000 (21:08 -0700)
committer Facebook Github Bot 2 <facebook-github-bot-2-bot@fb.com>
Tue, 31 May 2016 04:23:28 +0000 (21:23 -0700)
Summary:
This adds support for establishing connections
over TCP Fast Open (TFO).

The API introduced here retains the same
connect() + write() api that clients currently
use.

If enableTFO() is called then the connect will
be deferred to the first write. This only works
with request-response protocols since a write
must trigger the connect. There is a tradeoff here
for the simpler API, and we can address this with
other signals such as a short timeout in the future.

Even though the client might enable TFO, the program
might run on machines without TFO support.
There were 2 choices for supporting machines where
TFO might not be enabled:
1. Fall back to a normal connect if the TFO sendmsg fails
2. Check that TFO is supported on the machine before using it

Both of these have their tradeoffs; however, option 1 does
not require us to read from procfs in the common code
path, so that is the approach taken here.

Reviewed By: Orvid

Differential Revision: D3327480

fbshipit-source-id: 9ac3a0c7ad2d206b158fdc305641fedbd93aa44d

folly/detail/SocketFastOpen.cpp
folly/detail/SocketFastOpen.h
folly/io/async/AsyncSocket.cpp
folly/io/async/AsyncSocket.h
folly/io/async/test/AsyncSocketTest.h
folly/io/async/test/AsyncSocketTest2.cpp
folly/io/async/test/BlockingSocket.h
folly/io/async/test/SocketClient.cpp [new file with mode: 0644]

index 90938a9a91677fbe6e48e92929dfb5812e2fcfe0..4455377ebe3f99efc95c77f1ee21422cbf8738f1 100644 (file)
@@ -20,21 +20,42 @@ namespace folly {
 namespace detail {
 
 #if FOLLY_ALLOW_TFO
-ssize_t tfo_sendto(
-    int sockfd,
-    const void* buf,
-    size_t len,
-    int flags,
-    const struct sockaddr* dest_addr,
-    socklen_t addrlen) {
+
+#include <netinet/tcp.h>
+#include <stdio.h>
+
+// Sometimes these flags are not present in the headers,
+// so define them if not present.
+#if !defined(MSG_FASTOPEN)
+#define MSG_FASTOPEN 0x20000000
+#endif
+
+#if !defined(TCP_FASTOPEN)
+#define TCP_FASTOPEN 23
+#endif
+
+ssize_t tfo_sendmsg(int sockfd, const struct msghdr* msg, int flags) {
   flags |= MSG_FASTOPEN;
-  return sendto(sockfd, buf, len, flags, dest_addr, addrlen);
+  return sendmsg(sockfd, msg, flags);
 }
 
 int tfo_enable(int sockfd, size_t max_queue_size) {
   return setsockopt(
       sockfd, SOL_TCP, TCP_FASTOPEN, &max_queue_size, sizeof(max_queue_size));
 }
+
+#else
+
+ssize_t tfo_sendmsg(int sockfd, const struct msghdr* msg, int flags) {
+  errno = EOPNOTSUPP;
+  return -1;
+}
+
+int tfo_enable(int sockfd, size_t max_queue_size) {
+  errno = ENOPROTOOPT;
+  return -1;
+}
+
 #endif
 }
 }
index 1295a1e640967311400d06916b4cbe7a0a3e85b5..d01e89b9ecb9f29a61a144fdae17ae4fd96e1030 100644 (file)
 #include <folly/portability/Sockets.h>
 #include <sys/types.h>
 
-#if !defined(FOLLY_ALLOW_TFO) && defined(TCP_FASTOPEN) && defined(MSG_FASTOPEN)
+#if !defined(FOLLY_ALLOW_TFO) && defined(__linux__) && !defined(__ANDROID__)
+// only allow for linux right now
 #define FOLLY_ALLOW_TFO 1
 #endif
 
 namespace folly {
 namespace detail {
 
-#if FOLLY_ALLOW_TFO
-
 /**
- * tfo_sendto has the same semantics as sendto, but is used to
+ * tfo_sendmsg has the same semantics as sendmsg, but is used to
  * send with TFO data.
  */
-ssize_t tfo_sendto(
-    int sockfd,
-    const void* buf,
-    size_t len,
-    int flags,
-    const struct sockaddr* dest_addr,
-    socklen_t addlen);
+ssize_t tfo_sendmsg(int sockfd, const struct msghdr* msg, int flags);
 
 /**
  * Enable TFO on a listening socket.
  */
 int tfo_enable(int sockfd, size_t max_queue_size);
-#endif
 }
 }
index 0fe8e88613c013b759771b476ed6b7db33344d3c..cffe1f82ee1f9f2c2cff8ad1fe08c98a8dfe7c84 100644 (file)
@@ -16,8 +16,7 @@
 
 #include <folly/io/async/AsyncSocket.h>
 
-#include <folly/io/async/EventBase.h>
-#include <folly/io/async/EventHandler.h>
+#include <folly/ExceptionWrapper.h>
 #include <folly/SocketAddress.h>
 #include <folly/io/IOBuf.h>
 #include <folly/portability/Fcntl.h>
@@ -429,34 +428,12 @@ void AsyncSocket::connect(ConnectCallback* callback,
     // Perform the connect()
     address.getAddress(&addrStorage);
 
-    rv = ::connect(fd_, saddr, address.getActualSize());
-    if (rv < 0) {
-      auto errnoCopy = errno;
-      if (errnoCopy == EINPROGRESS) {
-        // Connection in progress.
-        if (timeout > 0) {
-          // Start a timer in case the connection takes too long.
-          if (!writeTimeout_.scheduleTimeout(timeout)) {
-            throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
-                withAddr("failed to schedule AsyncSocket connect timeout"));
-          }
-        }
-
-        // Register for write events, so we'll
-        // be notified when the connection finishes/fails.
-        // Note that we don't register for a persistent event here.
-        assert(eventFlags_ == EventHandler::NONE);
-        eventFlags_ = EventHandler::WRITE;
-        if (!ioHandler_.registerHandler(eventFlags_)) {
-          throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
-              withAddr("failed to register AsyncSocket connect handler"));
-        }
+    if (tfoEnabled_) {
+      state_ = StateEnum::FAST_OPEN;
+      tfoAttempted_ = true;
+    } else {
+      if (socketConnect(saddr, addr_.getActualSize()) < 0) {
         return;
-      } else {
-        throw AsyncSocketException(
-            AsyncSocketException::NOT_OPEN,
-            "connect failed (immediately)",
-            errnoCopy);
       }
     }
 
@@ -481,10 +458,52 @@ void AsyncSocket::connect(ConnectCallback* callback,
   VLOG(8) << "AsyncSocket::connect succeeded immediately; this=" << this;
   assert(readCallback_ == nullptr);
   assert(writeReqHead_ == nullptr);
-  state_ = StateEnum::ESTABLISHED;
+  if (state_ != StateEnum::FAST_OPEN) {
+    state_ = StateEnum::ESTABLISHED;
+  }
   invokeConnectSuccess();
 }
 
+int AsyncSocket::socketConnect(const struct sockaddr* saddr, socklen_t len) {
+  int rv = ::connect(fd_, saddr, len);
+  if (rv < 0) {
+    auto errnoCopy = errno;
+    if (errnoCopy == EINPROGRESS) {
+      scheduleConnectTimeoutAndRegisterForEvents();
+    } else {
+      throw AsyncSocketException(
+          AsyncSocketException::NOT_OPEN,
+          "connect failed (immediately)",
+          errnoCopy);
+    }
+  }
+  return rv;
+}
+
+void AsyncSocket::scheduleConnectTimeoutAndRegisterForEvents() {
+  // Connection in progress.
+  int timeout = connectTimeout_.count();
+  if (timeout > 0) {
+    // Start a timer in case the connection takes too long.
+    if (!writeTimeout_.scheduleTimeout(timeout)) {
+      throw AsyncSocketException(
+          AsyncSocketException::INTERNAL_ERROR,
+          withAddr("failed to schedule AsyncSocket connect timeout"));
+    }
+  }
+
+  // Register for write events, so we'll
+  // be notified when the connection finishes/fails.
+  // Note that we don't register for a persistent event here.
+  assert(eventFlags_ == EventHandler::NONE);
+  eventFlags_ = EventHandler::WRITE;
+  if (!ioHandler_.registerHandler(eventFlags_)) {
+    throw AsyncSocketException(
+        AsyncSocketException::INTERNAL_ERROR,
+        withAddr("failed to register AsyncSocket connect handler"));
+  }
+}
+
 void AsyncSocket::connect(ConnectCallback* callback,
                            const string& ip, uint16_t port,
                            int timeout,
@@ -502,7 +521,7 @@ void AsyncSocket::connect(ConnectCallback* callback,
 
 void AsyncSocket::cancelConnect() {
   connectCallback_ = nullptr;
-  if (state_ == StateEnum::CONNECTING) {
+  if (state_ == StateEnum::CONNECTING || state_ == StateEnum::FAST_OPEN) {
     closeNow();
   }
 }
@@ -514,7 +533,7 @@ void AsyncSocket::setSendTimeout(uint32_t milliseconds) {
   // If we are currently pending on write requests, immediately update
   // writeTimeout_ with the new value.
   if ((eventFlags_ & EventHandler::WRITE) &&
-      (state_ != StateEnum::CONNECTING)) {
+      (state_ != StateEnum::CONNECTING && state_ != StateEnum::FAST_OPEN)) {
     assert(state_ == StateEnum::ESTABLISHED);
     assert((shutdownFlags_ & SHUT_WRITE) == 0);
     if (sendTimeout_ > 0) {
@@ -573,6 +592,7 @@ void AsyncSocket::setReadCB(ReadCallback *callback) {
 
   switch ((StateEnum)state_) {
     case StateEnum::CONNECTING:
+    case StateEnum::FAST_OPEN:
       // For convenience, we allow the read callback to be set while we are
       // still connecting.  We just store the callback for now.  Once the
       // connection completes we'll register for read events.
@@ -683,7 +703,8 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
   uint32_t partialWritten = 0;
   int bytesWritten = 0;
   bool mustRegister = false;
-  if (state_ == StateEnum::ESTABLISHED && !connecting()) {
+  if ((state_ == StateEnum::ESTABLISHED || state_ == StateEnum::FAST_OPEN) &&
+      !connecting()) {
     if (writeReqHead_ == nullptr) {
       // If we are established and there are no other writes pending,
       // we can attempt to perform the write immediately.
@@ -715,7 +736,13 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
           bufferCallback_->onEgressBuffered();
         }
       }
-      mustRegister = true;
+      if (!connecting()) {
+        // Writes might put the socket back into connecting state
+        // if TFO is enabled, and using TFO fails.
+        // This means that write timeouts would not be active, however
+        // connect timeouts would affect this stage.
+        mustRegister = true;
+      }
     }
   } else if (!connecting()) {
     // Invalid state for writing
@@ -839,7 +866,7 @@ void AsyncSocket::closeNow() {
   switch (state_) {
     case StateEnum::ESTABLISHED:
     case StateEnum::CONNECTING:
-    {
+    case StateEnum::FAST_OPEN: {
       shutdownFlags_ |= (SHUT_READ | SHUT_WRITE);
       state_ = StateEnum::CLOSED;
 
@@ -995,6 +1022,13 @@ void AsyncSocket::shutdownWriteNow() {
       // immediately shut down the write side of the socket.
       shutdownFlags_ |= SHUT_WRITE_PENDING;
       return;
+    case StateEnum::FAST_OPEN:
+      // In fast open state we haven't called connect yet, and if we shut down
+      // the writes, we will never try to call connect, so shut everything down
+      shutdownFlags_ |= SHUT_WRITE;
+      // Immediately fail all write requests
+      failAllWrites(socketShutdownForWritesEx);
+      return;
     case StateEnum::CLOSED:
     case StateEnum::ERROR:
       // We should never get here.  SHUT_WRITE should always be set
@@ -1046,9 +1080,10 @@ bool AsyncSocket::hangup() const {
 }
 
 bool AsyncSocket::good() const {
-  return ((state_ == StateEnum::CONNECTING ||
-          state_ == StateEnum::ESTABLISHED) &&
-          (shutdownFlags_ == 0) && (eventBase_ != nullptr));
+  return (
+      (state_ == StateEnum::CONNECTING || state_ == StateEnum::FAST_OPEN ||
+       state_ == StateEnum::ESTABLISHED) &&
+      (shutdownFlags_ == 0) && (eventBase_ != nullptr));
 }
 
 bool AsyncSocket::error() const {
@@ -1695,17 +1730,97 @@ void AsyncSocket::timeoutExpired() noexcept {
   if (state_ == StateEnum::CONNECTING) {
     // connect() timed out
     // Unregister for I/O events.
-    AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
-                           "connect timed out");
-    failConnect(__func__, ex);
+    if (connectCallback_) {
+      AsyncSocketException ex(
+          AsyncSocketException::TIMED_OUT, "connect timed out");
+      failConnect(__func__, ex);
+    } else {
+      // we faced a connect error without a connect callback, which could
+      // happen due to TFO.
+      AsyncSocketException ex(
+          AsyncSocketException::TIMED_OUT, "write timed out during connection");
+      failWrite(__func__, ex);
+    }
   } else {
     // a normal write operation timed out
-    assert(state_ == StateEnum::ESTABLISHED);
     AsyncSocketException ex(AsyncSocketException::TIMED_OUT, "write timed out");
     failWrite(__func__, ex);
   }
 }
 
+ssize_t AsyncSocket::tfoSendMsg(int fd, struct msghdr* msg, int msg_flags) {
+  return detail::tfo_sendmsg(fd, msg, msg_flags);
+}
+
+AsyncSocket::WriteResult AsyncSocket::sendSocketMessage(
+    struct msghdr* msg,
+    int msg_flags) {
+  ssize_t totalWritten = 0;
+  if (state_ == StateEnum::FAST_OPEN) {
+    sockaddr_storage addr;
+    auto len = addr_.getAddress(&addr);
+    msg->msg_name = &addr;
+    msg->msg_namelen = len;
+    totalWritten = tfoSendMsg(fd_, msg, msg_flags);
+    if (totalWritten >= 0) {
+      tfoFinished_ = true;
+      state_ = StateEnum::ESTABLISHED;
+      handleInitialReadWrite();
+    } else if (errno == EINPROGRESS) {
+      VLOG(4) << "TFO falling back to connecting";
+      // A normal sendmsg doesn't return EINPROGRESS, however
+      // TFO might fall back to connecting if there is no
+      // cookie.
+      state_ = StateEnum::CONNECTING;
+      try {
+        scheduleConnectTimeoutAndRegisterForEvents();
+      } catch (const AsyncSocketException& ex) {
+        return WriteResult(
+            WRITE_ERROR, folly::make_unique<AsyncSocketException>(ex));
+      }
+      // Let's fake it that no bytes were written.
+      // Some clients check errno even if return code is 0, so we
+      // set it just in case.
+      errno = EAGAIN;
+      totalWritten = 0;
+    } else if (errno == EOPNOTSUPP) {
+      VLOG(4) << "TFO not supported";
+      // Try falling back to connecting.
+      state_ = StateEnum::CONNECTING;
+      try {
+        int ret = socketConnect((const sockaddr*)&addr, len);
+        if (ret == 0) {
+          // connect succeeded immediately
+          // Treat this like no data was written.
+          state_ = StateEnum::ESTABLISHED;
+          handleInitialReadWrite();
+        }
+        // If there was no exception during connections,
+        // we would return that no bytes were written.
+        // Some clients check errno even if return code is 0, so we
+        // set it just in case.
+        errno = EAGAIN;
+        totalWritten = 0;
+      } catch (const AsyncSocketException& ex) {
+        return WriteResult(
+            WRITE_ERROR, folly::make_unique<AsyncSocketException>(ex));
+      }
+    } else if (errno == EAGAIN) {
+      // Normally sendmsg would indicate that the write would block.
+      // However in the fast open case, it would indicate that sendmsg
+      // fell back to a connect. This is a return code from connect()
+      // instead, and is an error condition indicating no free local ports.
+      return WriteResult(
+          WRITE_ERROR,
+          folly::make_unique<AsyncSocketException>(
+              AsyncSocketException::UNKNOWN, "No more free local ports"));
+    }
+  } else {
+    totalWritten = ::sendmsg(fd_, msg, msg_flags);
+  }
+  return WriteResult(totalWritten);
+}
+
 AsyncSocket::WriteResult AsyncSocket::performWrite(
     const iovec* vec,
     uint32_t count,
@@ -1740,9 +1855,10 @@ AsyncSocket::WriteResult AsyncSocket::performWrite(
     // marks that this is the last byte of a record (response)
     msg_flags |= MSG_EOR;
   }
-  ssize_t totalWritten = ::sendmsg(fd_, &msg, msg_flags);
+  auto writeResult = sendSocketMessage(&msg, msg_flags);
+  auto totalWritten = writeResult.writeReturn;
   if (totalWritten < 0) {
-    if (errno == EAGAIN) {
+    if (!writeResult.exception && errno == EAGAIN) {
       // TCP buffer is full; we can't write any more data right now.
       *countWritten = 0;
       *partialWritten = 0;
@@ -1751,7 +1867,7 @@ AsyncSocket::WriteResult AsyncSocket::performWrite(
     // error
     *countWritten = 0;
     *partialWritten = 0;
-    return WriteResult(WRITE_ERROR);
+    return writeResult;
   }
 
   appBytesWritten_ += totalWritten;
index 2a77ac3b67465f68aa376827655844c727f711d4..f3e605d251cb95533315caf31af03d562f5027fb 100644 (file)
@@ -18,6 +18,7 @@
 
 #include <folly/Optional.h>
 #include <folly/SocketAddress.h>
+#include <folly/detail/SocketFastOpen.h>
 #include <folly/io/IOBuf.h>
 #include <folly/io/ShutdownSocketSet.h>
 #include <folly/io/async/AsyncSocketException.h>
@@ -416,6 +417,20 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
     return connectTimeout_;
   }
 
+  bool getTFOAttempted() const {
+    return tfoAttempted_;
+  }
+
+  /**
+   * Returns whether or not the attempt to use TFO
+   * finished successfully. This does not necessarily
+   * mean TFO worked, just that trying to use TFO
+   * succeeded.
+   */
+  bool getTFOFinished() const {
+    return tfoFinished_;
+  }
+
   // Methods controlling socket options
 
   /**
@@ -509,12 +524,24 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
     peek_ = peek;
   }
 
+  /**
+   * Enables TFO behavior on the AsyncSocket if FOLLY_ALLOW_TFO
+   * is set.
+   */
+  void enableTFO() {
+    // No-op if folly does not allow tfo
+#if FOLLY_ALLOW_TFO
+    tfoEnabled_ = true;
+#endif
+  }
+
   enum class StateEnum : uint8_t {
     UNINIT,
     CONNECTING,
     ESTABLISHED,
     CLOSED,
-    ERROR
+    ERROR,
+    FAST_OPEN,
   };
 
   void setBufferCallback(BufferCallback* cb);
@@ -784,6 +811,20 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
       uint32_t* countWritten,
       uint32_t* partialWritten);
 
+  /**
+   * Sends the message over the socket using sendmsg
+   *
+   * @param msg       Message to send
+   * @param msg_flags Flags to pass to sendmsg
+   */
+  AsyncSocket::WriteResult sendSocketMessage(struct msghdr* msg, int msg_flags);
+
+  virtual ssize_t tfoSendMsg(int fd, struct msghdr* msg, int msg_flags);
+
+  int socketConnect(const struct sockaddr* addr, socklen_t len);
+
+  void scheduleConnectTimeoutAndRegisterForEvents();
+
   bool updateEventRegistration();
 
   /**
@@ -854,6 +895,9 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
   std::chrono::milliseconds connectTimeout_{0};
 
   BufferCallback* bufferCallback_{nullptr};
+  bool tfoEnabled_{false};
+  bool tfoAttempted_{false};
+  bool tfoFinished_{false};
 };
 #ifdef _MSC_VER
 #pragma vtordisp(pop)
index a0087c13effa913e6d95782236e229fdae2c1d25..549d9c6b185229ebe778f554f31244ee299a5c93 100644 (file)
@@ -72,6 +72,7 @@ class WriteCallback : public folly::AsyncTransportWrapper::WriteCallback {
 
   void writeErr(size_t bytesWritten,
                 const folly::AsyncSocketException& ex) noexcept override {
+    LOG(ERROR) << ex.what();
     state = STATE_FAILED;
     this->bytesWritten = bytesWritten;
     exception = ex;
@@ -205,8 +206,7 @@ class TestServer {
  public:
   // Create a TestServer.
   // This immediately starts listening on an ephemeral port.
-  TestServer()
-    : fd_(-1) {
+  explicit TestServer(bool enableTFO = false) : fd_(-1) {
     fd_ = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
     if (fd_ < 0) {
       throw folly::AsyncSocketException(
@@ -221,6 +221,11 @@ class TestServer {
           "non-blocking mode",
           errno);
     }
+    if (enableTFO) {
+#if FOLLY_ALLOW_TFO
+      folly::detail::tfo_enable(fd_, 100);
+#endif
+    }
     if (listen(fd_, 10) != 0) {
       throw folly::AsyncSocketException(
           folly::AsyncSocketException::INTERNAL_ERROR,
index 23b806adefca6e908fc1096092a8cab5aac6e0d7..ca6460315d6108332581f34a598cbc7b7524ff7a 100644 (file)
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+#include <folly/ExceptionWrapper.h>
+#include <folly/RWSpinLock.h>
+#include <folly/Random.h>
+#include <folly/SocketAddress.h>
 #include <folly/io/async/AsyncServerSocket.h>
 #include <folly/io/async/AsyncSocket.h>
 #include <folly/io/async/AsyncTimeout.h>
 #include <folly/io/async/EventBase.h>
-#include <folly/RWSpinLock.h>
-#include <folly/SocketAddress.h>
-#include <folly/Random.h>
 
 #include <folly/io/IOBuf.h>
 #include <folly/io/async/test/AsyncSocketTest.h>
 #include <folly/portability/Unistd.h>
 #include <folly/test/SocketAddressTestHelper.h>
 
-#include <gtest/gtest.h>
 #include <boost/scoped_array.hpp>
-#include <iostream>
 #include <fcntl.h>
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
 #include <sys/types.h>
+#include <iostream>
 #include <thread>
 
 using namespace boost;
@@ -47,6 +49,7 @@ using std::chrono::milliseconds;
 using boost::scoped_array;
 
 using namespace folly;
+using namespace testing;
 
 class DelayedWrite: public AsyncTimeout {
  public:
@@ -100,6 +103,28 @@ TEST(AsyncSocketTest, Connect) {
   EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
 }
 
+enum class TFOState {
+  DISABLED,
+  ENABLED,
+};
+
+class AsyncSocketConnectTest : public ::testing::TestWithParam<TFOState> {};
+
+std::vector<TFOState> getTestingValues() {
+  std::vector<TFOState> vals;
+  vals.emplace_back(TFOState::DISABLED);
+
+#if FOLLY_ALLOW_TFO
+  vals.emplace_back(TFOState::ENABLED);
+#endif
+  return vals;
+}
+
+INSTANTIATE_TEST_CASE_P(
+    ConnectTests,
+    AsyncSocketConnectTest,
+    ::testing::ValuesIn(getTestingValues()));
+
 /**
  * Test connecting to a server that isn't listening
  */
@@ -115,10 +140,10 @@ TEST(AsyncSocketTest, ConnectRefused) {
 
   evb.loop();
 
-  CHECK_EQ(cb.state, STATE_FAILED);
-  CHECK_EQ(cb.exception.getType(), AsyncSocketException::NOT_OPEN);
+  EXPECT_EQ(STATE_FAILED, cb.state);
+  EXPECT_EQ(AsyncSocketException::NOT_OPEN, cb.exception.getType());
   EXPECT_LE(0, socket->getConnectTime().count());
-  EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
+  EXPECT_EQ(std::chrono::milliseconds(30), socket->getConnectTimeout());
 }
 
 /**
@@ -164,12 +189,17 @@ TEST(AsyncSocketTest, ConnectTimeout) {
  * Test writing immediately after connecting, without waiting for connect
  * to finish.
  */
-TEST(AsyncSocketTest, ConnectAndWrite) {
+TEST_P(AsyncSocketConnectTest, ConnectAndWrite) {
   TestServer server;
 
   // connect()
   EventBase evb;
   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+
+  if (GetParam() == TFOState::ENABLED) {
+    socket->enableTFO();
+  }
+
   ConnCallback ccb;
   socket->connect(&ccb, server.getAddress(), 30);
 
@@ -198,12 +228,16 @@ TEST(AsyncSocketTest, ConnectAndWrite) {
 /**
  * Test connecting using a nullptr connect callback.
  */
-TEST(AsyncSocketTest, ConnectNullCallback) {
+TEST_P(AsyncSocketConnectTest, ConnectNullCallback) {
   TestServer server;
 
   // connect()
   EventBase evb;
   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+  if (GetParam() == TFOState::ENABLED) {
+    socket->enableTFO();
+  }
+
   socket->connect(nullptr, server.getAddress(), 30);
 
   // write some data, just so we have some way of verifing
@@ -231,12 +265,15 @@ TEST(AsyncSocketTest, ConnectNullCallback) {
  *
  * This exercises the STATE_CONNECTING_CLOSING code.
  */
-TEST(AsyncSocketTest, ConnectWriteAndClose) {
+TEST_P(AsyncSocketConnectTest, ConnectWriteAndClose) {
   TestServer server;
 
   // connect()
   EventBase evb;
   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+  if (GetParam() == TFOState::ENABLED) {
+    socket->enableTFO();
+  }
   ConnCallback ccb;
   socket->connect(&ccb, server.getAddress(), 30);
 
@@ -374,18 +411,27 @@ TEST(AsyncSocketTest, ConnectWriteAndCloseNow) {
 /**
  * Test installing a read callback immediately, before connect() finishes.
  */
-TEST(AsyncSocketTest, ConnectAndRead) {
+TEST_P(AsyncSocketConnectTest, ConnectAndRead) {
   TestServer server;
 
   // connect()
   EventBase evb;
   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+  if (GetParam() == TFOState::ENABLED) {
+    socket->enableTFO();
+  }
+
   ConnCallback ccb;
   socket->connect(&ccb, server.getAddress(), 30);
 
   ReadCallback rcb;
   socket->setReadCB(&rcb);
 
+  if (GetParam() == TFOState::ENABLED) {
+    // Trigger a connection
+    socket->writeChain(nullptr, IOBuf::copyBuffer("hey"));
+  }
+
   // Even though we haven't looped yet, we should be able to accept
   // the connection and send data to it.
   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
@@ -399,7 +445,6 @@ TEST(AsyncSocketTest, ConnectAndRead) {
   evb.loop();
 
   CHECK_EQ(ccb.state, STATE_SUCCEEDED);
-  CHECK_EQ(rcb.state, STATE_SUCCEEDED);
   CHECK_EQ(rcb.buffers.size(), 1);
   CHECK_EQ(rcb.buffers[0].length, sizeof(buf));
   CHECK_EQ(memcmp(rcb.buffers[0].buffer, buf, sizeof(buf)), 0);
@@ -450,12 +495,15 @@ TEST(AsyncSocketTest, ConnectReadAndClose) {
  * Test both writing and installing a read callback immediately,
  * before connect() finishes.
  */
-TEST(AsyncSocketTest, ConnectWriteAndRead) {
+TEST_P(AsyncSocketConnectTest, ConnectWriteAndRead) {
   TestServer server;
 
   // connect()
   EventBase evb;
   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+  if (GetParam() == TFOState::ENABLED) {
+    socket->enableTFO();
+  }
   ConnCallback ccb;
   socket->connect(&ccb, server.getAddress(), 30);
 
@@ -2309,3 +2357,415 @@ TEST(AsyncSocketTest, BufferCallbackKill) {
   evb.loop();
   CHECK_EQ(ccb.state, STATE_SUCCEEDED);
 }
+
+#if FOLLY_ALLOW_TFO
+TEST(AsyncSocketTest, ConnectTFO) {
+  // Start listening on a local port
+  TestServer server(true);
+
+  // Connect using a AsyncSocket
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+  socket->enableTFO();
+  ConnCallback cb;
+  socket->connect(&cb, server.getAddress(), 30);
+
+  std::array<uint8_t, 128> buf;
+  memset(buf.data(), 'a', buf.size());
+
+  std::array<uint8_t, 3> readBuf;
+  auto sendBuf = IOBuf::copyBuffer("hey");
+
+  std::thread t([&] {
+    auto acceptedSocket = server.accept();
+    acceptedSocket->write(buf.data(), buf.size());
+    acceptedSocket->flush();
+    acceptedSocket->readAll(readBuf.data(), readBuf.size());
+    acceptedSocket->close();
+  });
+
+  evb.loop();
+
+  CHECK_EQ(cb.state, STATE_SUCCEEDED);
+  EXPECT_LE(0, socket->getConnectTime().count());
+  EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
+  EXPECT_TRUE(socket->getTFOAttempted());
+
+  // Should trigger the connect
+  WriteCallback write;
+  ReadCallback rcb;
+  socket->writeChain(&write, sendBuf->clone());
+  socket->setReadCB(&rcb);
+  evb.loop();
+
+  t.join();
+
+  EXPECT_EQ(STATE_SUCCEEDED, write.state);
+  EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
+  EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
+  ASSERT_EQ(1, rcb.buffers.size());
+  ASSERT_EQ(sizeof(buf), rcb.buffers[0].length);
+  EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
+}
+
+/**
+ * Test connecting to a server that isn't listening
+ */
+TEST(AsyncSocketTest, ConnectRefusedTFO) {
+  EventBase evb;
+
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+
+  socket->enableTFO();
+
+  // Hopefully nothing is actually listening on this address
+  folly::SocketAddress addr("::1", 65535);
+  ConnCallback cb;
+  socket->connect(&cb, addr, 30);
+
+  evb.loop();
+
+  WriteCallback write1;
+  // Trigger the connect if TFO attempt is supported.
+  socket->writeChain(&write1, IOBuf::copyBuffer("hey"));
+  evb.loop();
+  WriteCallback write2;
+  socket->writeChain(&write2, IOBuf::copyBuffer("hey"));
+  evb.loop();
+
+  if (!socket->getTFOFinished()) {
+    EXPECT_EQ(STATE_FAILED, write1.state);
+    EXPECT_FALSE(socket->getTFOFinished());
+  } else {
+    EXPECT_EQ(STATE_SUCCEEDED, write1.state);
+    EXPECT_TRUE(socket->getTFOFinished());
+  }
+
+  EXPECT_EQ(STATE_FAILED, write2.state);
+
+  EXPECT_EQ(STATE_SUCCEEDED, cb.state);
+  EXPECT_LE(0, socket->getConnectTime().count());
+  EXPECT_EQ(std::chrono::milliseconds(30), socket->getConnectTimeout());
+  EXPECT_TRUE(socket->getTFOAttempted());
+}
+
+/**
+ * Test calling closeNow() immediately after connecting.
+ */
+TEST(AsyncSocketTest, ConnectWriteAndCloseNowTFO) {
+  TestServer server(true);
+
+  // connect()
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+  socket->enableTFO();
+
+  ConnCallback ccb;
+  socket->connect(&ccb, server.getAddress(), 30);
+
+  // write()
+  std::array<char, 128> buf;
+  memset(buf.data(), 'a', buf.size());
+
+  // close()
+  socket->closeNow();
+
+  // Loop, although there shouldn't be anything to do.
+  evb.loop();
+
+  CHECK_EQ(ccb.state, STATE_SUCCEEDED);
+
+  ASSERT_TRUE(socket->isClosedBySelf());
+  ASSERT_FALSE(socket->isClosedByPeer());
+}
+
+/**
+ * Test calling close() immediately after connect()
+ */
+TEST(AsyncSocketTest, ConnectAndCloseTFO) {
+  TestServer server(true);
+
+  // Connect using a AsyncSocket
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+  socket->enableTFO();
+
+  ConnCallback ccb;
+  socket->connect(&ccb, server.getAddress(), 30);
+
+  socket->close();
+
+  // Loop, although there shouldn't be anything to do.
+  evb.loop();
+
+  // With TFO the connect callback reports success before close() is called
+  CHECK_EQ(ccb.state, STATE_SUCCEEDED);
+
+  ASSERT_TRUE(socket->isClosedBySelf());
+  ASSERT_FALSE(socket->isClosedByPeer());
+}
+
+class MockAsyncTFOSocket : public AsyncSocket {
+ public:
+  using UniquePtr = std::unique_ptr<MockAsyncTFOSocket, Destructor>;
+
+  explicit MockAsyncTFOSocket(EventBase* evb) : AsyncSocket(evb) {}
+
+  MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags));
+};
+
+TEST(AsyncSocketTest, TestTFOUnsupported) {
+  TestServer server(true);
+
+  // Connect using a AsyncSocket
+  EventBase evb;
+  auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
+  socket->enableTFO();
+
+  ConnCallback ccb;
+  socket->connect(&ccb, server.getAddress(), 30);
+  CHECK_EQ(ccb.state, STATE_SUCCEEDED);
+
+  ReadCallback rcb;
+  socket->setReadCB(&rcb);
+
+  EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
+      .WillOnce(SetErrnoAndReturn(EOPNOTSUPP, -1));
+  WriteCallback write;
+  auto sendBuf = IOBuf::copyBuffer("hey");
+  socket->writeChain(&write, sendBuf->clone());
+  EXPECT_EQ(STATE_WAITING, write.state);
+
+  std::array<uint8_t, 128> buf;
+  memset(buf.data(), 'a', buf.size());
+
+  std::array<uint8_t, 3> readBuf;
+
+  std::thread t([&] {
+    std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
+    acceptedSocket->write(buf.data(), buf.size());
+    acceptedSocket->flush();
+    acceptedSocket->readAll(readBuf.data(), readBuf.size());
+    acceptedSocket->close();
+  });
+
+  evb.loop();
+
+  t.join();
+  EXPECT_EQ(STATE_SUCCEEDED, ccb.state);
+  EXPECT_EQ(STATE_SUCCEEDED, write.state);
+
+  EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
+  EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
+  ASSERT_EQ(1, rcb.buffers.size());
+  ASSERT_EQ(sizeof(buf), rcb.buffers[0].length);
+  EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
+}
+
+TEST(AsyncSocketTest, TestTFOUnsupportedTimeout) {
+  // Try connecting to server that won't respond.
+  //
+  // This depends somewhat on the network where this test is run.
+  // Hopefully this IP will be routable but unresponsive.
+  // (Alternatively, we could try listening on a local raw socket, but that
+  // normally requires root privileges.)
+  auto host = SocketAddressTestHelper::isIPv6Enabled()
+      ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv6
+      : SocketAddressTestHelper::isIPv4Enabled()
+          ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv4
+          : nullptr;
+  SocketAddress addr(host, 65535);
+
+  // Connect using a AsyncSocket
+  EventBase evb;
+  auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
+  socket->enableTFO();
+
+  ConnCallback ccb;
+  // Set a very small timeout
+  socket->connect(&ccb, addr, 1);
+  EXPECT_EQ(STATE_SUCCEEDED, ccb.state);
+
+  ReadCallback rcb;
+  socket->setReadCB(&rcb);
+
+  EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
+      .WillOnce(SetErrnoAndReturn(EOPNOTSUPP, -1));
+  WriteCallback write;
+  socket->writeChain(&write, IOBuf::copyBuffer("hey"));
+
+  evb.loop();
+
+  EXPECT_EQ(STATE_FAILED, write.state);
+}
+
+TEST(AsyncSocketTest, TestTFOFallbackToConnect) {
+  TestServer server(true);
+
+  // Connect using an AsyncSocket whose tfoSendMsg() is mocked
+  EventBase evb;
+  auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
+  socket->enableTFO();
+
+  ConnCallback ccb;
+  socket->connect(&ccb, server.getAddress(), 30);
+  CHECK_EQ(ccb.state, STATE_SUCCEEDED); // glog CHECK, not a gtest assertion
+
+  ReadCallback rcb;
+  socket->setReadCB(&rcb);
+
+  EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
+      .WillOnce(Invoke([&](int fd, struct msghdr*, int) { // stand-in for TFO
+        sockaddr_storage addr;
+        auto len = server.getAddress().getAddress(&addr);
+        return connect(fd, (const struct sockaddr*)&addr, len); // plain connect
+      }));
+  WriteCallback write;
+  auto sendBuf = IOBuf::copyBuffer("hey");
+  socket->writeChain(&write, sendBuf->clone());
+  EXPECT_EQ(STATE_WAITING, write.state); // write is still pending at this point
+
+  std::array<uint8_t, 128> buf; // payload the server side sends to the client
+  memset(buf.data(), 'a', buf.size());
+
+  std::array<uint8_t, 3> readBuf; // receives the 3-byte "hey" on the server side
+
+  std::thread t([&] { // server side: accept, send buf, read 3 bytes, close
+    std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
+    acceptedSocket->write(buf.data(), buf.size());
+    acceptedSocket->flush();
+    acceptedSocket->readAll(readBuf.data(), readBuf.size());
+    acceptedSocket->close();
+  });
+
+  evb.loop();
+
+  t.join(); // readBuf is only inspected after the server thread has finished
+  EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
+
+  EXPECT_EQ(STATE_SUCCEEDED, ccb.state);
+  EXPECT_EQ(STATE_SUCCEEDED, write.state);
+
+  EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
+  ASSERT_EQ(1, rcb.buffers.size()); // all 128 bytes expected in one read buffer
+  ASSERT_EQ(buf.size(), rcb.buffers[0].length);
+  EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
+}
+
+TEST(AsyncSocketTest, TestTFOFallbackTimeout) {
+  // Try connecting to a server that won't respond, with tfoSendMsg replaced
+  // by a plain connect(); the queued write is expected to end up failed.
+  // This depends somewhat on the network where this test is run.
+  // Hopefully this IP will be routable but unresponsive.
+  // (Alternatively, we could try listening on a local raw socket, but that
+  // normally requires root privileges.)
+  auto host = SocketAddressTestHelper::isIPv6Enabled()
+      ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv6
+      : SocketAddressTestHelper::isIPv4Enabled()
+          ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv4
+          : nullptr; // NOTE(review): null host if no address family is enabled
+  SocketAddress addr(host, 65535);
+
+  // Connect using an AsyncSocket whose tfoSendMsg() is mocked
+  EventBase evb;
+  auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
+  socket->enableTFO();
+
+  ConnCallback ccb;
+  // Set a very small timeout (1 ms)
+  socket->connect(&ccb, addr, 1);
+  EXPECT_EQ(STATE_SUCCEEDED, ccb.state); // checked before evb.loop() runs
+
+  ReadCallback rcb;
+  socket->setReadCB(&rcb);
+
+  EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
+      .WillOnce(Invoke([&](int fd, struct msghdr*, int) { // stand-in for TFO
+        sockaddr_storage addr2;
+        auto len = addr.getAddress(&addr2);
+        return connect(fd, (const struct sockaddr*)&addr2, len); // plain connect
+      }));
+  WriteCallback write;
+  socket->writeChain(&write, IOBuf::copyBuffer("hey"));
+
+  evb.loop();
+
+  EXPECT_EQ(STATE_FAILED, write.state); // the write must not succeed
+}
+
+TEST(AsyncSocketTest, TestTFOEagain) {
+  TestServer server(true); // listening only; this test never calls accept()
+
+  // Connect using an AsyncSocket whose tfoSendMsg() is mocked
+  EventBase evb;
+  auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
+  socket->enableTFO();
+
+  ConnCallback ccb;
+  socket->connect(&ccb, server.getAddress(), 30);
+
+  EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
+      .WillOnce(SetErrnoAndReturn(EAGAIN, -1)); // the single TFO send fails
+  WriteCallback write;
+  socket->writeChain(&write, IOBuf::copyBuffer("hey"));
+
+  evb.loop();
+
+  EXPECT_EQ(STATE_SUCCEEDED, ccb.state); // connect callback still reports success
+  EXPECT_EQ(STATE_FAILED, write.state); // EAGAIN from tfoSendMsg fails the write
+}
+
+// Sending a large amount of data (10 KiB) in the first write, which will
+// definitely not fit into a single MSS. Uses a real (unmocked) AsyncSocket.
+TEST(AsyncSocketTest, ConnectTFOWithBigData) {
+  // Start listening on a local port
+  TestServer server(true);
+
+  // Connect using an AsyncSocket
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+  socket->enableTFO();
+  ConnCallback cb;
+  socket->connect(&cb, server.getAddress(), 30);
+
+  std::array<uint8_t, 128> buf; // payload the server side sends to the client
+  memset(buf.data(), 'a', buf.size());
+
+  constexpr size_t len = 10 * 1024;
+  auto sendBuf = IOBuf::create(len);
+  sendBuf->append(len); // NOTE(review): bytes never filled in; confirm intended
+  std::array<uint8_t, len> readBuf; // receives the client payload on the server
+
+  std::thread t([&] { // server side: accept, send buf, read len bytes, close
+    auto acceptedSocket = server.accept();
+    acceptedSocket->write(buf.data(), buf.size());
+    acceptedSocket->flush();
+    acceptedSocket->readAll(readBuf.data(), readBuf.size());
+    acceptedSocket->close();
+  });
+
+  evb.loop(); // first loop runs before any write has been queued
+
+  CHECK_EQ(cb.state, STATE_SUCCEEDED); // glog CHECK, not a gtest assertion
+  EXPECT_LE(0, socket->getConnectTime().count());
+  EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
+  EXPECT_TRUE(socket->getTFOAttempted());
+
+  // Should trigger the connect
+  WriteCallback write;
+  ReadCallback rcb;
+  socket->writeChain(&write, sendBuf->clone());
+  socket->setReadCB(&rcb);
+  evb.loop();
+
+  t.join(); // readBuf is only inspected after the server thread has finished
+
+  EXPECT_EQ(STATE_SUCCEEDED, write.state);
+  EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
+  EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
+  ASSERT_EQ(1, rcb.buffers.size()); // all 128 bytes expected in one read buffer
+  ASSERT_EQ(sizeof(buf), rcb.buffers[0].length);
+  EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
+}
+
+#endif
index 360cfcb18cfa342004ea7934a21a2cb610fde79f..3830648e61d43a9e3d0407857a3f8287f9fc0c1e 100644 (file)
@@ -40,6 +40,10 @@ class BlockingSocket : public folly::AsyncSocket::ConnectCallback,
     sock_->attachEventBase(&eventBase_);
   }
 
+  void setAddress(folly::SocketAddress address) { // sets the target open() uses
+    address_ = address; // open() passes address_ to sock_->connect()
+  }
+
   void open() {
     sock_->connect(this, address_);
     eventBase_.loop();
@@ -110,11 +114,15 @@ class BlockingSocket : public folly::AsyncSocket::ConnectCallback,
   }
 
   int32_t readHelper(uint8_t *buf, size_t len, bool all) {
+    if (!sock_->good()) {
+      return 0;
+    }
+
     readBuf_ = buf;
     readLen_ = len;
     sock_->setReadCB(this);
     while (!err_ && sock_->good() && readLen_ > 0) {
-      eventBase_.loop();
+      eventBase_.loopOnce();
       if (!all) {
         break;
       }
diff --git a/folly/io/async/test/SocketClient.cpp b/folly/io/async/test/SocketClient.cpp
new file mode 100644 (file)
index 0000000..7f20d48
--- /dev/null
@@ -0,0 +1,70 @@
+/*
+ * Copyright 2016 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <folly/io/async/test/BlockingSocket.h>
+
+#include <folly/ExceptionWrapper.h>
+#include <gflags/gflags.h>
+
+using namespace folly;
+
+DEFINE_string(host, "localhost", "Host"); // resolved via setFromHostPort()
+DEFINE_int32(port, 0, "port"); // 0 means unset: main() logs an error and exits
+DEFINE_bool(tfo, false, "enable tfo"); // only honoured under FOLLY_ALLOW_TFO
+DEFINE_string(msg, "", "Message to send"); // bytes written once after open()
+
+int main(int argc, char** argv) { // connect, send --msg once, echo the reply
+  gflags::ParseCommandLineFlags(&argc, &argv, true);
+
+  if (FLAGS_port == 0) { // the flag's default of 0 is treated as "not given"
+    LOG(ERROR) << "Must specify port";
+    exit(EXIT_FAILURE);
+  }
+
+  // Prep the socket: create it on a local EventBase, then detach it again
+  EventBase evb;
+  AsyncSocket::UniquePtr socket(new AsyncSocket(&evb));
+  socket->detachEventBase(); // presumably so BlockingSocket can attach its own
+
+  if (FLAGS_tfo) { // silently a no-op when FOLLY_ALLOW_TFO is not set
+#if FOLLY_ALLOW_TFO
+    socket->enableTFO();
+#endif
+  }
+
+  // Keep a raw, non-owning pointer: `socket` itself is moved from just below
+  auto sockAddr = socket.get();
+
+  BlockingSocket sock(std::move(socket));
+  SocketAddress addr;
+  addr.setFromHostPort(FLAGS_host, FLAGS_port);
+  sock.setAddress(addr);
+  sock.open();
+  LOG(INFO) << "connected to " << addr.getAddressStr();
+
+  sock.write((const uint8_t*)FLAGS_msg.data(), FLAGS_msg.size());
+
+  LOG(ERROR) << "TFO attempted: " << sockAddr->getTFOAttempted(); // info only
+  LOG(ERROR) << "TFO finished: " << sockAddr->getTFOFinished(); // info only
+
+  std::array<char, 1024> buf; // NOTE(review): <array> is not included directly
+  int32_t bytesRead = 0;
+  while ((bytesRead = sock.read((uint8_t*)buf.data(), buf.size())) != 0) {
+    std::cout << std::string(buf.data(), bytesRead); // <iostream> not included
+  }
+
+  sock.close(); // reached once read() has returned 0
+  return 0;
+}