Move AsyncSocket tests from thrift to folly
authorAlan Frindell <afrind@fb.com>
Thu, 2 Apr 2015 17:20:49 +0000 (10:20 -0700)
committerafrind <afrind@fb.com>
Thu, 2 Apr 2015 19:02:41 +0000 (12:02 -0700)
Summary: These tests belong with the code that they test.  The old tests had a couple dependencies on TSocket/TSSLSocket, so I wrote a BlockingSocket wrapper for AsyncSocket/AsyncSSLSocket

Test Plan: Ran the tests

Reviewed By: alandau@fb.com

Subscribers: doug, net-systems@, alandau, bmatheny, mshneer, folly-diffs@, yfeldblum, chalfant

FB internal diff: D1959955

Signature: t1:1959955:1427917833:73d334846cf248f8bb215f3eb5b596df7f7cee4f

folly/Makefile.am
folly/io/async/test/AsyncSSLSocketTest.cpp [new file with mode: 0644]
folly/io/async/test/AsyncSSLSocketTest.h [new file with mode: 0644]
folly/io/async/test/AsyncSSLSocketTest2.cpp [new file with mode: 0644]
folly/io/async/test/AsyncSSLSocketWriteTest.cpp [new file with mode: 0644]
folly/io/async/test/AsyncSocketTest2.cpp [new file with mode: 0644]
folly/io/async/test/BlockingSocket.h [new file with mode: 0644]
folly/io/async/test/certs/ca-cert.pem [new file with mode: 0644]
folly/io/async/test/certs/tests-cert.pem [new file with mode: 0644]
folly/io/async/test/certs/tests-key.pem [new file with mode: 0644]
folly/io/test/ShutdownSocketSetTest.cpp [new file with mode: 0644]

index a5aec762e285bb8c3e2cfae65165208e8ec95ccc..2dcf677128aaefcab6bad1c17b21bc1d3bc5a610 100644 (file)
@@ -162,6 +162,8 @@ nobase_follyinclude_HEADERS = \
        io/async/Request.h \
        io/async/SSLContext.h \
        io/async/TimeoutManager.h \
+       io/async/test/AsyncSSLSocketTest.h \
+       io/async/test/BlockingSocket.h \
        io/async/test/TimeUtil.h \
        io/async/test/UndelayedDestruction.h \
        io/async/test/Util.h \
diff --git a/folly/io/async/test/AsyncSSLSocketTest.cpp b/folly/io/async/test/AsyncSSLSocketTest.cpp
new file mode 100644 (file)
index 0000000..bc89979
--- /dev/null
@@ -0,0 +1,1221 @@
+/*
+ * Copyright 2015 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <folly/io/async/test/AsyncSSLSocketTest.h>
+
+#include <signal.h>
+#include <pthread.h>
+
+#include <folly/io/async/AsyncSSLSocket.h>
+#include <folly/io/async/EventBase.h>
+#include <folly/SocketAddress.h>
+
+#include <folly/io/async/test/BlockingSocket.h>
+
+#include <gtest/gtest.h>
+#include <iostream>
+#include <list>
+#include <set>
+#include <unistd.h>
+#include <fcntl.h>
+#include <poll.h>
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <netinet/tcp.h>
+#include <folly/io/Cursor.h>
+
+using std::string;
+using std::vector;
+using std::min;
+using std::cerr;
+using std::endl;
+using std::list;
+
+namespace folly {
+uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
+uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
+uint32_t TestSSLAsyncCacheServer::lookupDelay_ = 0;
+
+const char* testCert = "folly/io/async/test/certs/tests-cert.pem";
+const char* testKey = "folly/io/async/test/certs/tests-key.pem";
+const char* testCA = "folly/io/async/test/certs/ca-cert.pem";
+
+TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase *acb) :
+ctx_(new folly::SSLContext),
+    acb_(acb),
+  socket_(new folly::AsyncServerSocket(&evb_)) {
+  // Set up the SSL context
+  ctx_->loadCertificate(testCert);
+  ctx_->loadPrivateKey(testKey);
+  ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+
+  acb_->ctx_ = ctx_;
+  acb_->base_ = &evb_;
+
+  //set up the listening socket
+  socket_->bind(0);
+  socket_->getAddress(&address_);
+  socket_->listen(100);
+  socket_->addAcceptCallback(acb_, &evb_);
+  socket_->startAccepting();
+
+  int ret = pthread_create(&thread_, nullptr, Main, this);
+  assert(ret == 0);
+
+  std::cerr << "Accepting connections on " << address_ << std::endl;
+}
+
+void getfds(int fds[2]) {
+  if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
+    FAIL() << "failed to create socketpair: " << strerror(errno);
+  }
+  for (int idx = 0; idx < 2; ++idx) {
+    int flags = fcntl(fds[idx], F_GETFL, 0);
+    if (flags == -1) {
+      FAIL() << "failed to get flags for socket " << idx << ": "
+             << strerror(errno);
+    }
+    if (fcntl(fds[idx], F_SETFL, flags | O_NONBLOCK) != 0) {
+      FAIL() << "failed to put socket " << idx << " in non-blocking mode: "
+             << strerror(errno);
+    }
+  }
+}
+
+void getctx(
+  std::shared_ptr<folly::SSLContext> clientCtx,
+  std::shared_ptr<folly::SSLContext> serverCtx) {
+  clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+
+  serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+  serverCtx->loadCertificate(
+      testCert);
+  serverCtx->loadPrivateKey(
+      testKey);
+}
+
+void sslsocketpair(
+  EventBase* eventBase,
+  AsyncSSLSocket::UniquePtr* clientSock,
+  AsyncSSLSocket::UniquePtr* serverSock) {
+  auto clientCtx = std::make_shared<folly::SSLContext>();
+  auto serverCtx = std::make_shared<folly::SSLContext>();
+  int fds[2];
+  getfds(fds);
+  getctx(clientCtx, serverCtx);
+  clientSock->reset(new AsyncSSLSocket(
+                      clientCtx, eventBase, fds[0], false));
+  serverSock->reset(new AsyncSSLSocket(
+                      serverCtx, eventBase, fds[1], true));
+
+  // (*clientSock)->setSendTimeout(100);
+  // (*serverSock)->setSendTimeout(100);
+}
+
+
+/**
+ * Test connecting to, writing to, reading from, and closing the
+ * connection to the SSL server.
+ */
+TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
+  // Start listening on a local port
+  WriteCallbackBase writeCallback;
+  ReadCallback readCallback(&writeCallback);
+  HandshakeCallback handshakeCallback(&readCallback);
+  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+  TestSSLServer server(&acceptCallback);
+
+  // Set up SSL context.
+  std::shared_ptr<SSLContext> sslContext(new SSLContext());
+  sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+  //sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
+  //sslContext->authenticate(true, false);
+
+  // connect
+  auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
+                                                 sslContext);
+  socket->open();
+
+  // write()
+  uint8_t buf[128];
+  memset(buf, 'a', sizeof(buf));
+  socket->write(buf, sizeof(buf));
+
+  // read()
+  uint8_t readbuf[128];
+  uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
+  EXPECT_EQ(bytesRead, 128);
+  EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
+
+  // close()
+  socket->close();
+
+  cerr << "ConnectWriteReadClose test completed" << endl;
+}
+
+/**
+ * Negative test for handshakeError().
+ */
+TEST(AsyncSSLSocketTest, HandshakeError) {
+  // Start listening on a local port
+  WriteCallbackBase writeCallback;
+  ReadCallback readCallback(&writeCallback);
+  HandshakeCallback handshakeCallback(&readCallback);
+  HandshakeErrorCallback acceptCallback(&handshakeCallback);
+  TestSSLServer server(&acceptCallback);
+
+  // Set up SSL context.
+  std::shared_ptr<SSLContext> sslContext(new SSLContext());
+  sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+
+  // connect
+  auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
+                                                 sslContext);
+  // read()
+  bool ex = false;
+  try {
+    socket->open();
+
+    uint8_t readbuf[128];
+    uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
+  } catch (AsyncSocketException &e) {
+    ex = true;
+  }
+  EXPECT_TRUE(ex);
+
+  // close()
+  socket->close();
+  cerr << "HandshakeError test completed" << endl;
+}
+
+/**
+ * Negative test for readError().
+ */
+TEST(AsyncSSLSocketTest, ReadError) {
+  // Start listening on a local port
+  WriteCallbackBase writeCallback;
+  ReadErrorCallback readCallback(&writeCallback);
+  HandshakeCallback handshakeCallback(&readCallback);
+  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+  TestSSLServer server(&acceptCallback);
+
+  // Set up SSL context.
+  std::shared_ptr<SSLContext> sslContext(new SSLContext());
+  sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+
+  // connect
+  auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
+                                                 sslContext);
+  socket->open();
+
+  // write something to trigger ssl handshake
+  uint8_t buf[128];
+  memset(buf, 'a', sizeof(buf));
+  socket->write(buf, sizeof(buf));
+
+  socket->close();
+  cerr << "ReadError test completed" << endl;
+}
+
+/**
+ * Negative test for writeError().
+ */
+TEST(AsyncSSLSocketTest, WriteError) {
+  // Start listening on a local port
+  WriteCallbackBase writeCallback;
+  WriteErrorCallback readCallback(&writeCallback);
+  HandshakeCallback handshakeCallback(&readCallback);
+  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+  TestSSLServer server(&acceptCallback);
+
+  // Set up SSL context.
+  std::shared_ptr<SSLContext> sslContext(new SSLContext());
+  sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+
+  // connect
+  auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
+                                                 sslContext);
+  socket->open();
+
+  // write something to trigger ssl handshake
+  uint8_t buf[128];
+  memset(buf, 'a', sizeof(buf));
+  socket->write(buf, sizeof(buf));
+
+  socket->close();
+  cerr << "WriteError test completed" << endl;
+}
+
+/**
+ * Test a socket with TCP_NODELAY unset.
+ */
+TEST(AsyncSSLSocketTest, SocketWithDelay) {
+  // Start listening on a local port
+  WriteCallbackBase writeCallback;
+  ReadCallback readCallback(&writeCallback);
+  HandshakeCallback handshakeCallback(&readCallback);
+  SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
+  TestSSLServer server(&acceptCallback);
+
+  // Set up SSL context.
+  std::shared_ptr<SSLContext> sslContext(new SSLContext());
+  sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+
+  // connect
+  auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
+                                                 sslContext);
+  socket->open();
+
+  // write()
+  uint8_t buf[128];
+  memset(buf, 'a', sizeof(buf));
+  socket->write(buf, sizeof(buf));
+
+  // read()
+  uint8_t readbuf[128];
+  uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
+  EXPECT_EQ(bytesRead, 128);
+  EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
+
+  // close()
+  socket->close();
+
+  cerr << "SocketWithDelay test completed" << endl;
+}
+
+TEST(AsyncSSLSocketTest, NpnTestOverlap) {
+  EventBase eventBase;
+  std::shared_ptr<SSLContext> clientCtx(new SSLContext);
+  std::shared_ptr<SSLContext> serverCtx(new SSLContext);;
+  int fds[2];
+  getfds(fds);
+  getctx(clientCtx, serverCtx);
+
+  clientCtx->setAdvertisedNextProtocols({"blub","baz"});
+  serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
+
+  AsyncSSLSocket::UniquePtr clientSock(
+    new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+  AsyncSSLSocket::UniquePtr serverSock(
+    new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
+  NpnClient client(std::move(clientSock));
+  NpnServer server(std::move(serverSock));
+
+  eventBase.loop();
+
+  EXPECT_TRUE(client.nextProtoLength != 0);
+  EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
+  EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
+                           server.nextProtoLength), 0);
+  string selected((const char*)client.nextProto, client.nextProtoLength);
+  EXPECT_EQ(selected.compare("baz"), 0);
+}
+
+TEST(AsyncSSLSocketTest, NpnTestUnset) {
+  // Identical to above test, except that we want unset NPN before
+  // looping.
+  EventBase eventBase;
+  std::shared_ptr<SSLContext> clientCtx(new SSLContext);
+  std::shared_ptr<SSLContext> serverCtx(new SSLContext);;
+  int fds[2];
+  getfds(fds);
+  getctx(clientCtx, serverCtx);
+
+  clientCtx->setAdvertisedNextProtocols({"blub","baz"});
+  serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
+
+  AsyncSSLSocket::UniquePtr clientSock(
+    new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+  AsyncSSLSocket::UniquePtr serverSock(
+    new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
+
+  // unsetting NPN for any of [client, server] is enought to make NPN not
+  // work
+  clientCtx->unsetNextProtocols();
+
+  NpnClient client(std::move(clientSock));
+  NpnServer server(std::move(serverSock));
+
+  eventBase.loop();
+
+  EXPECT_TRUE(client.nextProtoLength == 0);
+  EXPECT_TRUE(server.nextProtoLength == 0);
+  EXPECT_TRUE(client.nextProto == nullptr);
+  EXPECT_TRUE(server.nextProto == nullptr);
+}
+
+TEST(AsyncSSLSocketTest, NpnTestNoOverlap) {
+  EventBase eventBase;
+  std::shared_ptr<SSLContext> clientCtx(new SSLContext);
+  std::shared_ptr<SSLContext> serverCtx(new SSLContext);;
+  int fds[2];
+  getfds(fds);
+  getctx(clientCtx, serverCtx);
+
+  clientCtx->setAdvertisedNextProtocols({"blub"});
+  serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
+
+  AsyncSSLSocket::UniquePtr clientSock(
+    new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+  AsyncSSLSocket::UniquePtr serverSock(
+    new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
+  NpnClient client(std::move(clientSock));
+  NpnServer server(std::move(serverSock));
+
+  eventBase.loop();
+
+  EXPECT_TRUE(client.nextProtoLength != 0);
+  EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
+  EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
+                           server.nextProtoLength), 0);
+  string selected((const char*)client.nextProto, client.nextProtoLength);
+  EXPECT_EQ(selected.compare("blub"), 0);
+}
+
+TEST(AsyncSSLSocketTest, RandomizedNpnTest) {
+  // Probability that this test will fail is 2^-64, which could be considered
+  // as negligible.
+  const int kTries = 64;
+
+  std::set<string> selectedProtocols;
+  for (int i = 0; i < kTries; ++i) {
+    EventBase eventBase;
+    std::shared_ptr<SSLContext> clientCtx = std::make_shared<SSLContext>();
+    std::shared_ptr<SSLContext> serverCtx = std::make_shared<SSLContext>();
+    int fds[2];
+    getfds(fds);
+    getctx(clientCtx, serverCtx);
+
+    clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"});
+    serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}},
+        {1, {"bar"}}});
+
+
+    AsyncSSLSocket::UniquePtr clientSock(
+      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+    AsyncSSLSocket::UniquePtr serverSock(
+      new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
+    NpnClient client(std::move(clientSock));
+    NpnServer server(std::move(serverSock));
+
+    eventBase.loop();
+
+    EXPECT_TRUE(client.nextProtoLength != 0);
+    EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
+    EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
+                             server.nextProtoLength), 0);
+    string selected((const char*)client.nextProto, client.nextProtoLength);
+    selectedProtocols.insert(selected);
+  }
+  EXPECT_EQ(selectedProtocols.size(), 2);
+}
+
+
+#ifndef OPENSSL_NO_TLSEXT
+/**
+ * 1. Client sends TLSEXT_HOSTNAME in client hello.
+ * 2. Server found a match SSL_CTX and use this SSL_CTX to
+ *    continue the SSL handshake.
+ * 3. Server sends back TLSEXT_HOSTNAME in server hello.
+ */
+TEST(AsyncSSLSocketTest, SNITestMatch) {
+  EventBase eventBase;
+  std::shared_ptr<SSLContext> clientCtx(new SSLContext);
+  std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
+  // Use the same SSLContext to continue the handshake after
+  // tlsext_hostname match.
+  std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
+  const std::string serverName("xyz.newdev.facebook.com");
+  int fds[2];
+  getfds(fds);
+  getctx(clientCtx, dfServerCtx);
+
+  AsyncSSLSocket::UniquePtr clientSock(
+    new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
+  AsyncSSLSocket::UniquePtr serverSock(
+    new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
+  SNIClient client(std::move(clientSock));
+  SNIServer server(std::move(serverSock),
+                   dfServerCtx,
+                   hskServerCtx,
+                   serverName);
+
+  eventBase.loop();
+
+  EXPECT_TRUE(client.serverNameMatch);
+  EXPECT_TRUE(server.serverNameMatch);
+}
+
+/**
+ * 1. Client sends TLSEXT_HOSTNAME in client hello.
+ * 2. Server cannot find a matching SSL_CTX and continue to use
+ *    the current SSL_CTX to do the handshake.
+ * 3. Server does not send back TLSEXT_HOSTNAME in server hello.
+ */
+TEST(AsyncSSLSocketTest, SNITestNotMatch) {
+  EventBase eventBase;
+  std::shared_ptr<SSLContext> clientCtx(new SSLContext);
+  std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
+  // Use the same SSLContext to continue the handshake after
+  // tlsext_hostname match.
+  std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
+  const std::string clientRequestingServerName("foo.com");
+  const std::string serverExpectedServerName("xyz.newdev.facebook.com");
+
+  int fds[2];
+  getfds(fds);
+  getctx(clientCtx, dfServerCtx);
+
+  AsyncSSLSocket::UniquePtr clientSock(
+    new AsyncSSLSocket(clientCtx,
+                        &eventBase,
+                        fds[0],
+                        clientRequestingServerName));
+  AsyncSSLSocket::UniquePtr serverSock(
+    new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
+  SNIClient client(std::move(clientSock));
+  SNIServer server(std::move(serverSock),
+                   dfServerCtx,
+                   hskServerCtx,
+                   serverExpectedServerName);
+
+  eventBase.loop();
+
+  EXPECT_TRUE(!client.serverNameMatch);
+  EXPECT_TRUE(!server.serverNameMatch);
+}
+
+/**
+ * 1. Client does not send TLSEXT_HOSTNAME in client hello.
+ * 2. Server does not send back TLSEXT_HOSTNAME in server hello.
+ */
+TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) {
+  EventBase eventBase;
+  std::shared_ptr<SSLContext> clientCtx(new SSLContext);
+  std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
+  // Use the same SSLContext to continue the handshake after
+  // tlsext_hostname match.
+  std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
+  const std::string serverExpectedServerName("xyz.newdev.facebook.com");
+
+  int fds[2];
+  getfds(fds);
+  getctx(clientCtx, dfServerCtx);
+
+  AsyncSSLSocket::UniquePtr clientSock(
+    new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+  AsyncSSLSocket::UniquePtr serverSock(
+    new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
+  SNIClient client(std::move(clientSock));
+  SNIServer server(std::move(serverSock),
+                   dfServerCtx,
+                   hskServerCtx,
+                   serverExpectedServerName);
+
+  eventBase.loop();
+
+  EXPECT_TRUE(!client.serverNameMatch);
+  EXPECT_TRUE(!server.serverNameMatch);
+}
+
+#endif
+/**
+ * Test SSL client socket
+ */
+TEST(AsyncSSLSocketTest, SSLClientTest) {
+  // Start listening on a local port
+  WriteCallbackBase writeCallback;
+  ReadCallback readCallback(&writeCallback);
+  HandshakeCallback handshakeCallback(&readCallback);
+  SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
+  TestSSLServer server(&acceptCallback);
+
+  // Set up SSL client
+  EventBase eventBase;
+  std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
+                                             1));
+
+  client->connect();
+  EventBaseAborter eba(&eventBase, 3000);
+  eventBase.loop();
+
+  EXPECT_EQ(client->getMiss(), 1);
+  EXPECT_EQ(client->getHit(), 0);
+
+  cerr << "SSLClientTest test completed" << endl;
+}
+
+
+/**
+ * Test SSL client socket session re-use
+ */
+TEST(AsyncSSLSocketTest, SSLClientTestReuse) {
+  // Start listening on a local port
+  WriteCallbackBase writeCallback;
+  ReadCallback readCallback(&writeCallback);
+  HandshakeCallback handshakeCallback(&readCallback);
+  SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
+  TestSSLServer server(&acceptCallback);
+
+  // Set up SSL client
+  EventBase eventBase;
+  std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
+                                             10));
+
+  client->connect();
+  EventBaseAborter eba(&eventBase, 3000);
+  eventBase.loop();
+
+  EXPECT_EQ(client->getMiss(), 1);
+  EXPECT_EQ(client->getHit(), 9);
+
+  cerr << "SSLClientTestReuse test completed" << endl;
+}
+
+/**
+ * Test SSL client socket timeout
+ */
+TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
+  // Start listening on a local port
+  EmptyReadCallback readCallback;
+  HandshakeCallback handshakeCallback(&readCallback,
+                                      HandshakeCallback::EXPECT_ERROR);
+  HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
+  TestSSLServer server(&acceptCallback);
+
+  // Set up SSL client
+  EventBase eventBase;
+  std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
+                                             1, 10));
+  client->connect(true /* write before connect completes */);
+  EventBaseAborter eba(&eventBase, 3000);
+  eventBase.loop();
+
+  usleep(100000);
+  // This is checking that the connectError callback precedes any queued
+  // writeError callbacks.  This matches AsyncSocket's behavior
+  EXPECT_EQ(client->getWriteAfterConnectErrors(), 1);
+  EXPECT_EQ(client->getErrors(), 1);
+  EXPECT_EQ(client->getMiss(), 0);
+  EXPECT_EQ(client->getHit(), 0);
+
+  cerr << "SSLClientTimeoutTest test completed" << endl;
+}
+
+
+/**
+ * Test SSL server async cache
+ */
+TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTest) {
+  // Start listening on a local port
+  WriteCallbackBase writeCallback;
+  ReadCallback readCallback(&writeCallback);
+  HandshakeCallback handshakeCallback(&readCallback);
+  SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
+  TestSSLAsyncCacheServer server(&acceptCallback);
+
+  // Set up SSL client
+  EventBase eventBase;
+  std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
+                                             10, 500));
+
+  client->connect();
+  EventBaseAborter eba(&eventBase, 3000);
+  eventBase.loop();
+
+  EXPECT_EQ(server.getAsyncCallbacks(), 18);
+  EXPECT_EQ(server.getAsyncLookups(), 9);
+  EXPECT_EQ(client->getMiss(), 10);
+  EXPECT_EQ(client->getHit(), 0);
+
+  cerr << "SSLServerAsyncCacheTest test completed" << endl;
+}
+
+
+/**
+ * Test SSL server accept timeout with cache path
+ */
+TEST(AsyncSSLSocketTest, SSLServerTimeoutTest) {
+  // Start listening on a local port
+  WriteCallbackBase writeCallback;
+  ReadCallback readCallback(&writeCallback);
+  EmptyReadCallback clientReadCallback;
+  HandshakeCallback handshakeCallback(&readCallback);
+  SSLServerAcceptCallback acceptCallback(&handshakeCallback, 50);
+  TestSSLAsyncCacheServer server(&acceptCallback);
+
+  // Set up SSL client
+  EventBase eventBase;
+  // only do a TCP connect
+  std::shared_ptr<AsyncSocket> sock = AsyncSocket::newSocket(&eventBase);
+  sock->connect(nullptr, server.getAddress());
+  clientReadCallback.tcpSocket_ = sock;
+  sock->setReadCB(&clientReadCallback);
+
+  EventBaseAborter eba(&eventBase, 3000);
+  eventBase.loop();
+
+  EXPECT_EQ(readCallback.state, STATE_WAITING);
+
+  cerr << "SSLServerTimeoutTest test completed" << endl;
+}
+
+/**
+ * Test SSL server accept timeout with cache path
+ */
+TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTimeoutTest) {
+  // Start listening on a local port
+  WriteCallbackBase writeCallback;
+  ReadCallback readCallback(&writeCallback);
+  HandshakeCallback handshakeCallback(&readCallback);
+  SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback, 50);
+  TestSSLAsyncCacheServer server(&acceptCallback);
+
+  // Set up SSL client
+  EventBase eventBase;
+  std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
+                                             2));
+
+  client->connect();
+  EventBaseAborter eba(&eventBase, 3000);
+  eventBase.loop();
+
+  EXPECT_EQ(server.getAsyncCallbacks(), 1);
+  EXPECT_EQ(server.getAsyncLookups(), 1);
+  EXPECT_EQ(client->getErrors(), 1);
+  EXPECT_EQ(client->getMiss(), 1);
+  EXPECT_EQ(client->getHit(), 0);
+
+  cerr << "SSLServerAsyncCacheTimeoutTest test completed" << endl;
+}
+
+/**
+ * Test SSL server accept timeout with cache path
+ */
+TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
+  // Start listening on a local port
+  WriteCallbackBase writeCallback;
+  ReadCallback readCallback(&writeCallback);
+  HandshakeCallback handshakeCallback(&readCallback,
+                                      HandshakeCallback::EXPECT_ERROR);
+  SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
+  TestSSLAsyncCacheServer server(&acceptCallback, 500);
+
+  // Set up SSL client
+  EventBase eventBase;
+  std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
+                                             2, 100));
+
+  client->connect();
+  EventBaseAborter eba(&eventBase, 3000);
+  eventBase.loop();
+
+  server.getEventBase().runInEventBaseThread([&handshakeCallback]{
+      handshakeCallback.closeSocket();});
+  // give time for the cache lookup to come back and find it closed
+  usleep(500000);
+
+  EXPECT_EQ(server.getAsyncCallbacks(), 1);
+  EXPECT_EQ(server.getAsyncLookups(), 1);
+  EXPECT_EQ(client->getErrors(), 1);
+  EXPECT_EQ(client->getMiss(), 1);
+  EXPECT_EQ(client->getHit(), 0);
+
+  cerr << "SSLServerCacheCloseTest test completed" << endl;
+}
+
+/**
+ * Verify Client Ciphers obtained using SSL MSG Callback.
+ */
+TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) {
+  EventBase eventBase;
+  auto clientCtx = std::make_shared<SSLContext>();
+  auto serverCtx = std::make_shared<SSLContext>();
+  serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
+  serverCtx->ciphers("RSA:!SHA:!NULL:!SHA256@STRENGTH");
+  serverCtx->loadPrivateKey(testKey);
+  serverCtx->loadCertificate(testCert);
+  serverCtx->loadTrustedCertificates(testCA);
+  serverCtx->loadClientCAList(testCA);
+
+  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
+  clientCtx->ciphers("RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5");
+  clientCtx->loadPrivateKey(testKey);
+  clientCtx->loadCertificate(testCert);
+  clientCtx->loadTrustedCertificates(testCA);
+
+  int fds[2];
+  getfds(fds);
+
+  AsyncSSLSocket::UniquePtr clientSock(
+      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+  AsyncSSLSocket::UniquePtr serverSock(
+      new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
+
+  SSLHandshakeClient client(std::move(clientSock), true, true);
+  SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);
+
+  eventBase.loop();
+
+  EXPECT_EQ(server.clientCiphers_,
+            "RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5:00ff");
+  EXPECT_TRUE(client.handshakeVerify_);
+  EXPECT_TRUE(client.handshakeSuccess_);
+  EXPECT_TRUE(!client.handshakeError_);
+  EXPECT_TRUE(server.handshakeVerify_);
+  EXPECT_TRUE(server.handshakeSuccess_);
+  EXPECT_TRUE(!server.handshakeError_);
+}
+
+TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) {
+  EventBase eventBase;
+  auto ctx = std::make_shared<SSLContext>();
+
+  int fds[2];
+  getfds(fds);
+
+  int bufLen = 42;
+  uint8_t majorVersion = 18;
+  uint8_t minorVersion = 25;
+
+  // Create callback buf
+  auto buf = IOBuf::create(bufLen);
+  buf->append(bufLen);
+  folly::io::RWPrivateCursor cursor(buf.get());
+  cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
+  cursor.write<uint16_t>(0);
+  cursor.write<uint8_t>(38);
+  cursor.write<uint8_t>(majorVersion);
+  cursor.write<uint8_t>(minorVersion);
+  cursor.skip(32);
+  cursor.write<uint32_t>(0);
+
+  SSL* ssl = ctx->createSSL();
+  AsyncSSLSocket::UniquePtr sock(
+      new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
+  sock->enableClientHelloParsing();
+
+  // Test client hello parsing in one packet
+  AsyncSSLSocket::clientHelloParsingCallback(
+      0, 0, SSL3_RT_HANDSHAKE, buf->data(), buf->length(), ssl, sock.get());
+  buf.reset();
+
+  auto parsedClientHello = sock->getClientHelloInfo();
+  EXPECT_TRUE(parsedClientHello != nullptr);
+  EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
+  EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
+}
+
+TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) {
+  EventBase eventBase;
+  auto ctx = std::make_shared<SSLContext>();
+
+  int fds[2];
+  getfds(fds);
+
+  int bufLen = 42;
+  uint8_t majorVersion = 18;
+  uint8_t minorVersion = 25;
+
+  // Create callback buf
+  auto buf = IOBuf::create(bufLen);
+  buf->append(bufLen);
+  folly::io::RWPrivateCursor cursor(buf.get());
+  cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
+  cursor.write<uint16_t>(0);
+  cursor.write<uint8_t>(38);
+  cursor.write<uint8_t>(majorVersion);
+  cursor.write<uint8_t>(minorVersion);
+  cursor.skip(32);
+  cursor.write<uint32_t>(0);
+
+  SSL* ssl = ctx->createSSL();
+  AsyncSSLSocket::UniquePtr sock(
+      new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
+  sock->enableClientHelloParsing();
+
+  // Test parsing with two packets with first packet size < 3
+  auto bufCopy = folly::IOBuf::copyBuffer(buf->data(), 2);
+  AsyncSSLSocket::clientHelloParsingCallback(
+      0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
+      ssl, sock.get());
+  bufCopy.reset();
+  bufCopy = folly::IOBuf::copyBuffer(buf->data() + 2, buf->length() - 2);
+  AsyncSSLSocket::clientHelloParsingCallback(
+      0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
+      ssl, sock.get());
+  bufCopy.reset();
+
+  auto parsedClientHello = sock->getClientHelloInfo();
+  EXPECT_TRUE(parsedClientHello != nullptr);
+  EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
+  EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
+}
+
+TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) {
+  EventBase eventBase;
+  auto ctx = std::make_shared<SSLContext>();
+
+  int fds[2];
+  getfds(fds);
+
+  int bufLen = 42;
+  uint8_t majorVersion = 18;
+  uint8_t minorVersion = 25;
+
+  // Create callback buf
+  auto buf = IOBuf::create(bufLen);
+  buf->append(bufLen);
+  folly::io::RWPrivateCursor cursor(buf.get());
+  cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
+  cursor.write<uint16_t>(0);
+  cursor.write<uint8_t>(38);
+  cursor.write<uint8_t>(majorVersion);
+  cursor.write<uint8_t>(minorVersion);
+  cursor.skip(32);
+  cursor.write<uint32_t>(0);
+
+  SSL* ssl = ctx->createSSL();
+  AsyncSSLSocket::UniquePtr sock(
+      new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
+  sock->enableClientHelloParsing();
+
+  // Test parsing with multiple small packets
+  for (uint64_t i = 0; i < buf->length(); i += 3) {
+    auto bufCopy = folly::IOBuf::copyBuffer(
+        buf->data() + i, std::min((uint64_t)3, buf->length() - i));
+    AsyncSSLSocket::clientHelloParsingCallback(
+        0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
+        ssl, sock.get());
+    bufCopy.reset();
+  }
+
+  auto parsedClientHello = sock->getClientHelloInfo();
+  EXPECT_TRUE(parsedClientHello != nullptr);
+  EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
+  EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
+}
+
+/**
+ * Verify sucessful behavior of SSL certificate validation.
+ */
+TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) {
+  EventBase eventBase;
+  auto clientCtx = std::make_shared<SSLContext>();
+  auto dfServerCtx = std::make_shared<SSLContext>();
+
+  int fds[2];
+  getfds(fds);
+  getctx(clientCtx, dfServerCtx);
+
+  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
+  dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
+
+  AsyncSSLSocket::UniquePtr clientSock(
+    new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+  AsyncSSLSocket::UniquePtr serverSock(
+    new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
+
+  SSLHandshakeClient client(std::move(clientSock), true, true);
+  clientCtx->loadTrustedCertificates(testCA);
+
+  SSLHandshakeServer server(std::move(serverSock), true, true);
+
+  eventBase.loop();
+
+  EXPECT_TRUE(client.handshakeVerify_);
+  EXPECT_TRUE(client.handshakeSuccess_);
+  EXPECT_TRUE(!client.handshakeError_);
+  EXPECT_TRUE(!server.handshakeVerify_);
+  EXPECT_TRUE(server.handshakeSuccess_);
+  EXPECT_TRUE(!server.handshakeError_);
+}
+
+/**
+ * Verify that the client's verification callback is able to fail SSL
+ * connection establishment.
+ */
+TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) {
+  EventBase eventBase;
+  auto clientCtx = std::make_shared<SSLContext>();
+  auto dfServerCtx = std::make_shared<SSLContext>();
+
+  int fds[2];
+  getfds(fds);
+  getctx(clientCtx, dfServerCtx);
+
+  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
+  dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
+
+  AsyncSSLSocket::UniquePtr clientSock(
+    new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+  AsyncSSLSocket::UniquePtr serverSock(
+    new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
+
+  SSLHandshakeClient client(std::move(clientSock), true, false);
+  clientCtx->loadTrustedCertificates(testCA);
+
+  SSLHandshakeServer server(std::move(serverSock), true, true);
+
+  eventBase.loop();
+
+  EXPECT_TRUE(client.handshakeVerify_);
+  EXPECT_TRUE(!client.handshakeSuccess_);
+  EXPECT_TRUE(client.handshakeError_);
+  EXPECT_TRUE(!server.handshakeVerify_);
+  EXPECT_TRUE(!server.handshakeSuccess_);
+  EXPECT_TRUE(server.handshakeError_);
+}
+
+/**
+ * Verify that the options in SSLContext can be overridden in
+ * sslConnect/Accept.i.e specifying that no validation should be performed
+ * allows an otherwise-invalid certificate to be accepted and doesn't fire
+ * the validation callback.
+ */
+TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) {
+  EventBase eventBase;
+  auto clientCtx = std::make_shared<SSLContext>();
+  auto dfServerCtx = std::make_shared<SSLContext>();
+
+  int fds[2];
+  getfds(fds);
+  getctx(clientCtx, dfServerCtx);
+
+  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
+  dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
+
+  AsyncSSLSocket::UniquePtr clientSock(
+    new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+  AsyncSSLSocket::UniquePtr serverSock(
+    new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
+
+  SSLHandshakeClientNoVerify client(std::move(clientSock), false, false);
+  clientCtx->loadTrustedCertificates(testCA);
+
+  SSLHandshakeServerNoVerify server(std::move(serverSock), false, false);
+
+  eventBase.loop();
+
+  EXPECT_TRUE(!client.handshakeVerify_);
+  EXPECT_TRUE(client.handshakeSuccess_);
+  EXPECT_TRUE(!client.handshakeError_);
+  EXPECT_TRUE(!server.handshakeVerify_);
+  EXPECT_TRUE(server.handshakeSuccess_);
+  EXPECT_TRUE(!server.handshakeError_);
+}
+
+/**
+ * Verify that the options in SSLContext can be overridden in
+ * sslConnect/Accept. Enable verification even if context says otherwise.
+ * Test requireClientCert with client cert
+ */
+TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) {
+  EventBase eventBase;
+  auto clientCtx = std::make_shared<SSLContext>();
+  auto serverCtx = std::make_shared<SSLContext>();
+  serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
+  serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+  serverCtx->loadPrivateKey(testKey);
+  serverCtx->loadCertificate(testCert);
+  serverCtx->loadTrustedCertificates(testCA);
+  serverCtx->loadClientCAList(testCA);
+
+  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
+  clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+  clientCtx->loadPrivateKey(testKey);
+  clientCtx->loadCertificate(testCert);
+  clientCtx->loadTrustedCertificates(testCA);
+
+  int fds[2];
+  getfds(fds);
+
+  AsyncSSLSocket::UniquePtr clientSock(
+      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+  AsyncSSLSocket::UniquePtr serverSock(
+      new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
+
+  SSLHandshakeClientDoVerify client(std::move(clientSock), true, true);
+  SSLHandshakeServerDoVerify server(std::move(serverSock), true, true);
+
+  eventBase.loop();
+
+  EXPECT_TRUE(client.handshakeVerify_);
+  EXPECT_TRUE(client.handshakeSuccess_);
+  EXPECT_FALSE(client.handshakeError_);
+  EXPECT_TRUE(server.handshakeVerify_);
+  EXPECT_TRUE(server.handshakeSuccess_);
+  EXPECT_FALSE(server.handshakeError_);
+}
+
+/**
+ * Verify that the client's verification callback is able to override
+ * the preverification failure and allow a successful connection.
+ */
+TEST(AsyncSSLSocketTest, SSLHandshakeValidationOverride) {
+  EventBase eventBase;
+  auto clientCtx = std::make_shared<SSLContext>();
+  auto dfServerCtx = std::make_shared<SSLContext>();
+
+  int fds[2];
+  getfds(fds);
+  getctx(clientCtx, dfServerCtx);
+
+  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
+  dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
+
+  AsyncSSLSocket::UniquePtr clientSock(
+    new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+  AsyncSSLSocket::UniquePtr serverSock(
+    new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
+
+  SSLHandshakeClient client(std::move(clientSock), false, true);
+  SSLHandshakeServer server(std::move(serverSock), true, true);
+
+  eventBase.loop();
+
+  EXPECT_TRUE(client.handshakeVerify_);
+  EXPECT_TRUE(client.handshakeSuccess_);
+  EXPECT_TRUE(!client.handshakeError_);
+  EXPECT_TRUE(!server.handshakeVerify_);
+  EXPECT_TRUE(server.handshakeSuccess_);
+  EXPECT_TRUE(!server.handshakeError_);
+}
+
+/**
+ * Verify that specifying that no validation should be performed allows an
+ * otherwise-invalid certificate to be accepted and doesn't fire the validation
+ * callback.
+ */
+TEST(AsyncSSLSocketTest, SSLHandshakeValidationSkip) {
+  EventBase eventBase;
+  auto clientCtx = std::make_shared<SSLContext>();
+  auto dfServerCtx = std::make_shared<SSLContext>();
+
+  int fds[2];
+  getfds(fds);
+  getctx(clientCtx, dfServerCtx);
+
+  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
+  dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
+
+  AsyncSSLSocket::UniquePtr clientSock(
+    new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+  AsyncSSLSocket::UniquePtr serverSock(
+    new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
+
+  SSLHandshakeClient client(std::move(clientSock), false, false);
+  SSLHandshakeServer server(std::move(serverSock), false, false);
+
+  eventBase.loop();
+
+  EXPECT_TRUE(!client.handshakeVerify_);
+  EXPECT_TRUE(client.handshakeSuccess_);
+  EXPECT_TRUE(!client.handshakeError_);
+  EXPECT_TRUE(!server.handshakeVerify_);
+  EXPECT_TRUE(server.handshakeSuccess_);
+  EXPECT_TRUE(!server.handshakeError_);
+}
+
+/**
+ * Test requireClientCert with client cert
+ */
+TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
+  EventBase eventBase;
+  auto clientCtx = std::make_shared<SSLContext>();
+  auto serverCtx = std::make_shared<SSLContext>();
+  serverCtx->setVerificationOption(
+      SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
+  serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+  serverCtx->loadPrivateKey(testKey);
+  serverCtx->loadCertificate(testCert);
+  serverCtx->loadTrustedCertificates(testCA);
+  serverCtx->loadClientCAList(testCA);
+
+  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
+  clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+  clientCtx->loadPrivateKey(testKey);
+  clientCtx->loadCertificate(testCert);
+  clientCtx->loadTrustedCertificates(testCA);
+
+  int fds[2];
+  getfds(fds);
+
+  AsyncSSLSocket::UniquePtr clientSock(
+      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+  AsyncSSLSocket::UniquePtr serverSock(
+      new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
+
+  SSLHandshakeClient client(std::move(clientSock), true, true);
+  SSLHandshakeServer server(std::move(serverSock), true, true);
+
+  eventBase.loop();
+
+  EXPECT_TRUE(client.handshakeVerify_);
+  EXPECT_TRUE(client.handshakeSuccess_);
+  EXPECT_FALSE(client.handshakeError_);
+  EXPECT_TRUE(server.handshakeVerify_);
+  EXPECT_TRUE(server.handshakeSuccess_);
+  EXPECT_FALSE(server.handshakeError_);
+}
+
+
+/**
+ * Test requireClientCert with no client cert
+ */
+TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
+  EventBase eventBase;
+  auto clientCtx = std::make_shared<SSLContext>();
+  auto serverCtx = std::make_shared<SSLContext>();
+  serverCtx->setVerificationOption(
+      SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
+  serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+  serverCtx->loadPrivateKey(testKey);
+  serverCtx->loadCertificate(testCert);
+  serverCtx->loadTrustedCertificates(testCA);
+  serverCtx->loadClientCAList(testCA);
+  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
+  clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+
+  int fds[2];
+  getfds(fds);
+
+  AsyncSSLSocket::UniquePtr clientSock(
+      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+  AsyncSSLSocket::UniquePtr serverSock(
+      new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
+
+  SSLHandshakeClient client(std::move(clientSock), false, false);
+  SSLHandshakeServer server(std::move(serverSock), false, false);
+
+  eventBase.loop();
+
+  EXPECT_FALSE(server.handshakeVerify_);
+  EXPECT_FALSE(server.handshakeSuccess_);
+  EXPECT_TRUE(server.handshakeError_);
+}
+}
+
+///////////////////////////////////////////////////////////////////////////
+// init_unit_test_suite
+///////////////////////////////////////////////////////////////////////////
+namespace {
+struct Initializer {
+  Initializer() {
+    signal(SIGPIPE, SIG_IGN);
+  }
+};
+Initializer initializer;
+} // anonymous
diff --git a/folly/io/async/test/AsyncSSLSocketTest.h b/folly/io/async/test/AsyncSSLSocketTest.h
new file mode 100644 (file)
index 0000000..78623f7
--- /dev/null
@@ -0,0 +1,1277 @@
+/*
+ * Copyright 2015 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <signal.h>
+#include <pthread.h>
+
+#include <folly/io/async/AsyncServerSocket.h>
+#include <folly/io/async/AsyncSSLSocket.h>
+#include <folly/io/async/AsyncSocket.h>
+#include <folly/io/async/AsyncTransport.h>
+#include <folly/io/async/EventBase.h>
+#include <folly/io/async/AsyncTimeout.h>
+#include <folly/SocketAddress.h>
+
+#include <gtest/gtest.h>
+#include <iostream>
+#include <list>
+#include <unistd.h>
+#include <fcntl.h>
+#include <poll.h>
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <netinet/tcp.h>
+
+namespace folly {
+
+enum StateEnum {
+  STATE_WAITING,
+  STATE_SUCCEEDED,
+  STATE_FAILED
+};
+
+// The destructors of all callback classes assert that the state is
+// STATE_SUCCEEDED, for both possitive and negative tests. The tests
+// are responsible for setting the succeeded state properly before the
+// destructors are called.
+
+class WriteCallbackBase :
+public AsyncTransportWrapper::WriteCallback {
+public:
+  WriteCallbackBase()
+      : state(STATE_WAITING)
+      , bytesWritten(0)
+      , exception(AsyncSocketException::UNKNOWN, "none") {}
+
+  ~WriteCallbackBase() {
+    EXPECT_EQ(state, STATE_SUCCEEDED);
+  }
+
+  void setSocket(
+    const std::shared_ptr<AsyncSSLSocket> &socket) {
+    socket_ = socket;
+  }
+
+  void writeSuccess() noexcept override {
+    std::cerr << "writeSuccess" << std::endl;
+    state = STATE_SUCCEEDED;
+  }
+
+  void writeErr(
+    size_t bytesWritten,
+    const AsyncSocketException& ex) noexcept override {
+    std::cerr << "writeError: bytesWritten " << bytesWritten
+         << ", exception " << ex.what() << std::endl;
+
+    state = STATE_FAILED;
+    this->bytesWritten = bytesWritten;
+    exception = ex;
+    socket_->close();
+    socket_->detachEventBase();
+  }
+
+  std::shared_ptr<AsyncSSLSocket> socket_;
+  StateEnum state;
+  size_t bytesWritten;
+  AsyncSocketException exception;
+};
+
+class ReadCallbackBase :
+public AsyncTransportWrapper::ReadCallback {
+public:
+  explicit ReadCallbackBase(WriteCallbackBase *wcb)
+      : wcb_(wcb)
+      , state(STATE_WAITING) {}
+
+  ~ReadCallbackBase() {
+    EXPECT_EQ(state, STATE_SUCCEEDED);
+  }
+
+  void setSocket(
+    const std::shared_ptr<AsyncSSLSocket> &socket) {
+    socket_ = socket;
+  }
+
+  void setState(StateEnum s) {
+    state = s;
+    if (wcb_) {
+      wcb_->state = s;
+    }
+  }
+
+  void readErr(
+    const AsyncSocketException& ex) noexcept override {
+    std::cerr << "readError " << ex.what() << std::endl;
+    state = STATE_FAILED;
+    socket_->close();
+    socket_->detachEventBase();
+  }
+
+  void readEOF() noexcept override {
+    std::cerr << "readEOF" << std::endl;
+
+    socket_->close();
+    socket_->detachEventBase();
+  }
+
+  std::shared_ptr<AsyncSSLSocket> socket_;
+  WriteCallbackBase *wcb_;
+  StateEnum state;
+};
+
+class ReadCallback : public ReadCallbackBase {
+public:
+  explicit ReadCallback(WriteCallbackBase *wcb)
+      : ReadCallbackBase(wcb)
+      , buffers() {}
+
+  ~ReadCallback() {
+    for (std::vector<Buffer>::iterator it = buffers.begin();
+         it != buffers.end();
+         ++it) {
+      it->free();
+    }
+    currentBuffer.free();
+  }
+
+  void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+    if (!currentBuffer.buffer) {
+      currentBuffer.allocate(4096);
+    }
+    *bufReturn = currentBuffer.buffer;
+    *lenReturn = currentBuffer.length;
+  }
+
+  void readDataAvailable(size_t len) noexcept override {
+    std::cerr << "readDataAvailable, len " << len << std::endl;
+
+    currentBuffer.length = len;
+
+    wcb_->setSocket(socket_);
+
+    // Write back the same data.
+    socket_->write(wcb_, currentBuffer.buffer, len);
+
+    buffers.push_back(currentBuffer);
+    currentBuffer.reset();
+    state = STATE_SUCCEEDED;
+  }
+
+  class Buffer {
+  public:
+    Buffer() : buffer(nullptr), length(0) {}
+    Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
+
+    void reset() {
+      buffer = nullptr;
+      length = 0;
+    }
+    void allocate(size_t length) {
+      assert(buffer == nullptr);
+      this->buffer = static_cast<char*>(malloc(length));
+      this->length = length;
+    }
+    void free() {
+      ::free(buffer);
+      reset();
+    }
+
+    char* buffer;
+    size_t length;
+  };
+
+  std::vector<Buffer> buffers;
+  Buffer currentBuffer;
+};
+
+class ReadErrorCallback : public ReadCallbackBase {
+public:
+  explicit ReadErrorCallback(WriteCallbackBase *wcb)
+      : ReadCallbackBase(wcb) {}
+
+  // Return nullptr buffer to trigger readError()
+  void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+    *bufReturn = nullptr;
+    *lenReturn = 0;
+  }
+
+  void readDataAvailable(size_t len) noexcept override {
+    // This should never to called.
+    FAIL();
+  }
+
+  void readErr(
+    const AsyncSocketException& ex) noexcept override {
+    ReadCallbackBase::readErr(ex);
+    std::cerr << "ReadErrorCallback::readError" << std::endl;
+    setState(STATE_SUCCEEDED);
+  }
+};
+
+class WriteErrorCallback : public ReadCallback {
+public:
+  explicit WriteErrorCallback(WriteCallbackBase *wcb)
+      : ReadCallback(wcb) {}
+
+  void readDataAvailable(size_t len) noexcept override {
+    std::cerr << "readDataAvailable, len " << len << std::endl;
+
+    currentBuffer.length = len;
+
+    // close the socket before writing to trigger writeError().
+    ::close(socket_->getFd());
+
+    wcb_->setSocket(socket_);
+
+    // Write back the same data.
+    socket_->write(wcb_, currentBuffer.buffer, len);
+
+    if (wcb_->state == STATE_FAILED) {
+      setState(STATE_SUCCEEDED);
+    } else {
+      state = STATE_FAILED;
+    }
+
+    buffers.push_back(currentBuffer);
+    currentBuffer.reset();
+  }
+
+  void readErr(const AsyncSocketException& ex) noexcept override {
+    std::cerr << "readError " << ex.what() << std::endl;
+    // do nothing since this is expected
+  }
+};
+
+class EmptyReadCallback : public ReadCallback {
+public:
+  explicit EmptyReadCallback()
+      : ReadCallback(nullptr) {}
+
+  void readErr(const AsyncSocketException& ex) noexcept override {
+    std::cerr << "readError " << ex.what() << std::endl;
+    state = STATE_FAILED;
+    tcpSocket_->close();
+    tcpSocket_->detachEventBase();
+  }
+
+  void readEOF() noexcept override {
+    std::cerr << "readEOF" << std::endl;
+
+    tcpSocket_->close();
+    tcpSocket_->detachEventBase();
+    state = STATE_SUCCEEDED;
+  }
+
+  std::shared_ptr<AsyncSocket> tcpSocket_;
+};
+
+class HandshakeCallback :
+public AsyncSSLSocket::HandshakeCB {
+public:
+  enum ExpectType {
+    EXPECT_SUCCESS,
+    EXPECT_ERROR
+  };
+
+  explicit HandshakeCallback(ReadCallbackBase *rcb,
+                             ExpectType expect = EXPECT_SUCCESS):
+      state(STATE_WAITING),
+      rcb_(rcb),
+      expect_(expect) {}
+
+  void setSocket(
+    const std::shared_ptr<AsyncSSLSocket> &socket) {
+    socket_ = socket;
+  }
+
+  void setState(StateEnum s) {
+    state = s;
+    rcb_->setState(s);
+  }
+
+  // Functions inherited from AsyncSSLSocketHandshakeCallback
+  void handshakeSuc(AsyncSSLSocket *sock) noexcept override {
+    EXPECT_EQ(sock, socket_.get());
+    std::cerr << "HandshakeCallback::connectionAccepted" << std::endl;
+    rcb_->setSocket(socket_);
+    sock->setReadCB(rcb_);
+    state = (expect_ == EXPECT_SUCCESS) ? STATE_SUCCEEDED : STATE_FAILED;
+  }
+  void handshakeErr(
+    AsyncSSLSocket *sock,
+    const AsyncSocketException& ex) noexcept override {
+    std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl;
+    state = (expect_ == EXPECT_ERROR) ? STATE_SUCCEEDED : STATE_FAILED;
+    if (expect_ == EXPECT_ERROR) {
+      // rcb will never be invoked
+      rcb_->setState(STATE_SUCCEEDED);
+    }
+  }
+
+  ~HandshakeCallback() {
+    EXPECT_EQ(state, STATE_SUCCEEDED);
+  }
+
+  void closeSocket() {
+    socket_->close();
+    state = STATE_SUCCEEDED;
+  }
+
+  StateEnum state;
+  std::shared_ptr<AsyncSSLSocket> socket_;
+  ReadCallbackBase *rcb_;
+  ExpectType expect_;
+};
+
+class SSLServerAcceptCallbackBase:
+public folly::AsyncServerSocket::AcceptCallback {
+public:
+  explicit SSLServerAcceptCallbackBase(HandshakeCallback *hcb):
+  state(STATE_WAITING), hcb_(hcb) {}
+
+  ~SSLServerAcceptCallbackBase() {
+    EXPECT_EQ(state, STATE_SUCCEEDED);
+  }
+
+  void acceptError(const std::exception& ex) noexcept override {
+    std::cerr << "SSLServerAcceptCallbackBase::acceptError "
+              << ex.what() << std::endl;
+    state = STATE_FAILED;
+  }
+
+  void connectionAccepted(int fd, const folly::SocketAddress& clientAddr)
+    noexcept override{
+    printf("Connection accepted\n");
+    std::shared_ptr<AsyncSSLSocket> sslSock;
+    try {
+      // Create a AsyncSSLSocket object with the fd. The socket should be
+      // added to the event base and in the state of accepting SSL connection.
+      sslSock = AsyncSSLSocket::newSocket(ctx_, base_, fd);
+    } catch (const std::exception &e) {
+      LOG(ERROR) << "Exception %s caught while creating a AsyncSSLSocket "
+        "object with socket " << e.what() << fd;
+      ::close(fd);
+      acceptError(e);
+      return;
+    }
+
+    connAccepted(sslSock);
+  }
+
+  virtual void connAccepted(
+    const std::shared_ptr<folly::AsyncSSLSocket> &s) = 0;
+
+  StateEnum state;
+  HandshakeCallback *hcb_;
+  std::shared_ptr<folly::SSLContext> ctx_;
+  folly::EventBase* base_;
+};
+
+class SSLServerAcceptCallback: public SSLServerAcceptCallbackBase {
+public:
+  uint32_t timeout_;
+
+  explicit SSLServerAcceptCallback(HandshakeCallback *hcb,
+                                   uint32_t timeout = 0):
+      SSLServerAcceptCallbackBase(hcb),
+      timeout_(timeout) {}
+
+  virtual ~SSLServerAcceptCallback() {
+    if (timeout_ > 0) {
+      // if we set a timeout, we expect failure
+      EXPECT_EQ(hcb_->state, STATE_FAILED);
+      hcb_->setState(STATE_SUCCEEDED);
+    }
+  }
+
+  // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
+  void connAccepted(
+    const std::shared_ptr<folly::AsyncSSLSocket> &s)
+    noexcept override {
+    auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
+    std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
+
+    hcb_->setSocket(sock);
+    sock->sslAccept(hcb_, timeout_);
+    EXPECT_EQ(sock->getSSLState(),
+                      AsyncSSLSocket::STATE_ACCEPTING);
+
+    state = STATE_SUCCEEDED;
+  }
+};
+
+class SSLServerAcceptCallbackDelay: public SSLServerAcceptCallback {
+public:
+  explicit SSLServerAcceptCallbackDelay(HandshakeCallback *hcb):
+      SSLServerAcceptCallback(hcb) {}
+
+  // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
+  void connAccepted(
+    const std::shared_ptr<folly::AsyncSSLSocket> &s)
+    noexcept override {
+
+    auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
+
+    std::cerr << "SSLServerAcceptCallbackDelay::connAccepted"
+              << std::endl;
+    int fd = sock->getFd();
+
+#ifndef TCP_NOPUSH
+    {
+    // The accepted connection should already have TCP_NODELAY set
+    int value;
+    socklen_t valueLength = sizeof(value);
+    int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
+    EXPECT_EQ(rc, 0);
+    EXPECT_EQ(value, 1);
+    }
+#endif
+
+    // Unset the TCP_NODELAY option.
+    int value = 0;
+    socklen_t valueLength = sizeof(value);
+    int rc = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, valueLength);
+    EXPECT_EQ(rc, 0);
+
+    rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
+    EXPECT_EQ(rc, 0);
+    EXPECT_EQ(value, 0);
+
+    SSLServerAcceptCallback::connAccepted(sock);
+  }
+};
+
+class SSLServerAsyncCacheAcceptCallback: public SSLServerAcceptCallback {
+public:
+  explicit SSLServerAsyncCacheAcceptCallback(HandshakeCallback *hcb,
+                                             uint32_t timeout = 0):
+    SSLServerAcceptCallback(hcb, timeout) {}
+
+  // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
+  void connAccepted(
+    const std::shared_ptr<folly::AsyncSSLSocket> &s)
+    noexcept override {
+    auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
+
+    std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
+
+    hcb_->setSocket(sock);
+    sock->sslAccept(hcb_, timeout_);
+    ASSERT_TRUE((sock->getSSLState() ==
+                 AsyncSSLSocket::STATE_ACCEPTING) ||
+                (sock->getSSLState() ==
+                 AsyncSSLSocket::STATE_CACHE_LOOKUP));
+
+    state = STATE_SUCCEEDED;
+  }
+};
+
+
+class HandshakeErrorCallback: public SSLServerAcceptCallbackBase {
+public:
+  explicit HandshakeErrorCallback(HandshakeCallback *hcb):
+  SSLServerAcceptCallbackBase(hcb)  {}
+
+  // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
+  void connAccepted(
+    const std::shared_ptr<folly::AsyncSSLSocket> &s)
+    noexcept override {
+    auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
+
+    std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
+
+    // The first call to sslAccept() should succeed.
+    hcb_->setSocket(sock);
+    sock->sslAccept(hcb_);
+    EXPECT_EQ(sock->getSSLState(),
+                      AsyncSSLSocket::STATE_ACCEPTING);
+
+    // The second call to sslAccept() should fail.
+    HandshakeCallback callback2(hcb_->rcb_);
+    callback2.setSocket(sock);
+    sock->sslAccept(&callback2);
+    EXPECT_EQ(sock->getSSLState(),
+                      AsyncSSLSocket::STATE_ERROR);
+
+    // Both callbacks should be in the error state.
+    EXPECT_EQ(hcb_->state, STATE_FAILED);
+    EXPECT_EQ(callback2.state, STATE_FAILED);
+
+    sock->detachEventBase();
+
+    state = STATE_SUCCEEDED;
+    hcb_->setState(STATE_SUCCEEDED);
+    callback2.setState(STATE_SUCCEEDED);
+  }
+};
+
+class HandshakeTimeoutCallback: public SSLServerAcceptCallbackBase {
+public:
+  explicit HandshakeTimeoutCallback(HandshakeCallback *hcb):
+  SSLServerAcceptCallbackBase(hcb)  {}
+
+  // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
+  void connAccepted(
+    const std::shared_ptr<folly::AsyncSSLSocket> &s)
+    noexcept override {
+    std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
+
+    auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
+
+    hcb_->setSocket(sock);
+    sock->getEventBase()->tryRunAfterDelay([=] {
+        std::cerr << "Delayed SSL accept, client will have close by now"
+                  << std::endl;
+        // SSL accept will fail
+        EXPECT_EQ(
+          sock->getSSLState(),
+          AsyncSSLSocket::STATE_UNINIT);
+        hcb_->socket_->sslAccept(hcb_);
+        // This registers for an event
+        EXPECT_EQ(
+          sock->getSSLState(),
+          AsyncSSLSocket::STATE_ACCEPTING);
+
+        state = STATE_SUCCEEDED;
+      }, 100);
+  }
+};
+
+
+class TestSSLServer {
+ protected:
+  EventBase evb_;
+  std::shared_ptr<folly::SSLContext> ctx_;
+  SSLServerAcceptCallbackBase *acb_;
+  folly::AsyncServerSocket *socket_;
+  folly::SocketAddress address_;
+  pthread_t thread_;
+
+  static void *Main(void *ctx) {
+    TestSSLServer *self = static_cast<TestSSLServer*>(ctx);
+    self->evb_.loop();
+    std::cerr << "Server thread exited event loop" << std::endl;
+    return nullptr;
+  }
+
+ public:
+  // Create a TestSSLServer.
+  // This immediately starts listening on the given port.
+  explicit TestSSLServer(SSLServerAcceptCallbackBase *acb);
+
+  // Kill the thread.
+  ~TestSSLServer() {
+    evb_.runInEventBaseThread([&](){
+      socket_->stopAccepting();
+    });
+    std::cerr << "Waiting for server thread to exit" << std::endl;
+    pthread_join(thread_, nullptr);
+  }
+
+  EventBase &getEventBase() { return evb_; }
+
+  const folly::SocketAddress& getAddress() const {
+    return address_;
+  }
+};
+
+class TestSSLAsyncCacheServer : public TestSSLServer {
+ public:
+  explicit TestSSLAsyncCacheServer(SSLServerAcceptCallbackBase *acb,
+        int lookupDelay = 100) :
+      TestSSLServer(acb) {
+    SSL_CTX *sslCtx = ctx_->getSSLCtx();
+    SSL_CTX_sess_set_get_cb(sslCtx,
+                            TestSSLAsyncCacheServer::getSessionCallback);
+    SSL_CTX_set_session_cache_mode(
+      sslCtx, SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_SERVER);
+    asyncCallbacks_ = 0;
+    asyncLookups_ = 0;
+    lookupDelay_ = lookupDelay;
+  }
+
+  uint32_t getAsyncCallbacks() const { return asyncCallbacks_; }
+  uint32_t getAsyncLookups() const { return asyncLookups_; }
+
+ private:
+  static uint32_t asyncCallbacks_;
+  static uint32_t asyncLookups_;
+  static uint32_t lookupDelay_;
+
+  static SSL_SESSION *getSessionCallback(SSL *ssl,
+                                         unsigned char *sess_id,
+                                         int id_len,
+                                         int *copyflag) {
+    *copyflag = 0;
+    asyncCallbacks_++;
+#ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
+    if (!SSL_want_sess_cache_lookup(ssl)) {
+      // libssl.so mismatch
+      std::cerr << "no async support" << std::endl;
+      return nullptr;
+    }
+
+    AsyncSSLSocket *sslSocket =
+        AsyncSSLSocket::getFromSSL(ssl);
+    assert(sslSocket != nullptr);
+    // Going to simulate an async cache by just running delaying the miss 100ms
+    if (asyncCallbacks_ % 2 == 0) {
+      // This socket is already blocked on lookup, return miss
+      std::cerr << "returning miss" << std::endl;
+    } else {
+      // fresh meat - block it
+      std::cerr << "async lookup" << std::endl;
+      sslSocket->getEventBase()->tryRunAfterDelay(
+        std::bind(&AsyncSSLSocket::restartSSLAccept,
+                  sslSocket), lookupDelay_);
+      *copyflag = SSL_SESSION_CB_WOULD_BLOCK;
+      asyncLookups_++;
+    }
+#endif
+    return nullptr;
+  }
+};
+
+void getfds(int fds[2]);
+
+void getctx(
+  std::shared_ptr<folly::SSLContext> clientCtx,
+  std::shared_ptr<folly::SSLContext> serverCtx);
+
+void sslsocketpair(
+  EventBase* eventBase,
+  AsyncSSLSocket::UniquePtr* clientSock,
+  AsyncSSLSocket::UniquePtr* serverSock);
+
+class BlockingWriteClient :
+  private AsyncSSLSocket::HandshakeCB,
+  private AsyncTransportWrapper::WriteCallback {
+ public:
+  explicit BlockingWriteClient(
+    AsyncSSLSocket::UniquePtr socket)
+    : socket_(std::move(socket)),
+      bufLen_(2500),
+      iovCount_(2000) {
+    // Fill buf_
+    buf_.reset(new uint8_t[bufLen_]);
+    for (uint32_t n = 0; n < sizeof(buf_); ++n) {
+      buf_[n] = n % 0xff;
+    }
+
+    // Initialize iov_
+    iov_.reset(new struct iovec[iovCount_]);
+    for (uint32_t n = 0; n < iovCount_; ++n) {
+      iov_[n].iov_base = buf_.get() + n;
+      if (n & 0x1) {
+        iov_[n].iov_len = n % bufLen_;
+      } else {
+        iov_[n].iov_len = bufLen_ - (n % bufLen_);
+      }
+    }
+
+    socket_->sslConn(this, 100);
+  }
+
+  struct iovec* getIovec() const {
+    return iov_.get();
+  }
+  uint32_t getIovecCount() const {
+    return iovCount_;
+  }
+
+ private:
+  void handshakeSuc(AsyncSSLSocket*) noexcept override {
+    socket_->writev(this, iov_.get(), iovCount_);
+  }
+  void handshakeErr(
+    AsyncSSLSocket*,
+    const AsyncSocketException& ex) noexcept override {
+    ADD_FAILURE() << "client handshake error: " << ex.what();
+  }
+  void writeSuccess() noexcept override {
+    socket_->close();
+  }
+  void writeErr(
+    size_t bytesWritten,
+    const AsyncSocketException& ex) noexcept override {
+    ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
+                  << ex.what();
+  }
+
+  AsyncSSLSocket::UniquePtr socket_;
+  uint32_t bufLen_;
+  uint32_t iovCount_;
+  std::unique_ptr<uint8_t[]> buf_;
+  std::unique_ptr<struct iovec[]> iov_;
+};
+
+class BlockingWriteServer :
+    private AsyncSSLSocket::HandshakeCB,
+    private AsyncTransportWrapper::ReadCallback {
+ public:
+  explicit BlockingWriteServer(
+    AsyncSSLSocket::UniquePtr socket)
+    : socket_(std::move(socket)),
+      bufSize_(2500 * 2000),
+      bytesRead_(0) {
+    buf_.reset(new uint8_t[bufSize_]);
+    socket_->sslAccept(this, 100);
+  }
+
+  void checkBuffer(struct iovec* iov, uint32_t count) const {
+    uint32_t idx = 0;
+    for (uint32_t n = 0; n < count; ++n) {
+      size_t bytesLeft = bytesRead_ - idx;
+      int rc = memcmp(buf_.get() + idx, iov[n].iov_base,
+                      std::min(iov[n].iov_len, bytesLeft));
+      if (rc != 0) {
+        FAIL() << "buffer mismatch at iovec " << n << "/" << count
+               << ": rc=" << rc;
+
+      }
+      if (iov[n].iov_len > bytesLeft) {
+        FAIL() << "server did not read enough data: "
+               << "ended at byte " << bytesLeft << "/" << iov[n].iov_len
+               << " in iovec " << n << "/" << count;
+      }
+
+      idx += iov[n].iov_len;
+    }
+    if (idx != bytesRead_) {
+      ADD_FAILURE() << "server read extra data: " << bytesRead_
+                    << " bytes read; expected " << idx;
+    }
+  }
+
+ private:
+  void handshakeSuc(AsyncSSLSocket*) noexcept override {
+    // Wait 10ms before reading, so the client's writes will initially block.
+    socket_->getEventBase()->tryRunAfterDelay(
+        [this] { socket_->setReadCB(this); }, 10);
+  }
+  void handshakeErr(
+    AsyncSSLSocket*,
+    const AsyncSocketException& ex) noexcept override {
+    ADD_FAILURE() << "server handshake error: " << ex.what();
+  }
+  void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+    *bufReturn = buf_.get() + bytesRead_;
+    *lenReturn = bufSize_ - bytesRead_;
+  }
+  void readDataAvailable(size_t len) noexcept override {
+    bytesRead_ += len;
+    socket_->setReadCB(nullptr);
+    socket_->getEventBase()->tryRunAfterDelay(
+        [this] { socket_->setReadCB(this); }, 2);
+  }
+  void readEOF() noexcept override {
+    socket_->close();
+  }
+  void readErr(
+    const AsyncSocketException& ex) noexcept override {
+    ADD_FAILURE() << "server read error: " << ex.what();
+  }
+
+  AsyncSSLSocket::UniquePtr socket_;
+  uint32_t bufSize_;
+  uint32_t bytesRead_;
+  std::unique_ptr<uint8_t[]> buf_;
+};
+
+class NpnClient :
+  private AsyncSSLSocket::HandshakeCB,
+  private AsyncTransportWrapper::WriteCallback {
+ public:
+  explicit NpnClient(
+    AsyncSSLSocket::UniquePtr socket)
+      : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
+    socket_->sslConn(this);
+  }
+
+  const unsigned char* nextProto;
+  unsigned nextProtoLength;
+ private:
+  void handshakeSuc(AsyncSSLSocket*) noexcept override {
+    socket_->getSelectedNextProtocol(&nextProto,
+                                     &nextProtoLength);
+  }
+  void handshakeErr(
+    AsyncSSLSocket*,
+    const AsyncSocketException& ex) noexcept override {
+    ADD_FAILURE() << "client handshake error: " << ex.what();
+  }
+  void writeSuccess() noexcept override {
+    socket_->close();
+  }
+  void writeErr(
+    size_t bytesWritten,
+    const AsyncSocketException& ex) noexcept override {
+    ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
+                  << ex.what();
+  }
+
+  AsyncSSLSocket::UniquePtr socket_;
+};
+
+class NpnServer :
+    private AsyncSSLSocket::HandshakeCB,
+    private AsyncTransportWrapper::ReadCallback {
+ public:
+  explicit NpnServer(AsyncSSLSocket::UniquePtr socket)
+      : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
+    socket_->sslAccept(this);
+  }
+
+  const unsigned char* nextProto;
+  unsigned nextProtoLength;
+ private:
+  void handshakeSuc(AsyncSSLSocket*) noexcept override {
+    socket_->getSelectedNextProtocol(&nextProto,
+                                     &nextProtoLength);
+  }
+  void handshakeErr(
+    AsyncSSLSocket*,
+    const AsyncSocketException& ex) noexcept override {
+    ADD_FAILURE() << "server handshake error: " << ex.what();
+  }
+  void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+    *lenReturn = 0;
+  }
+  void readDataAvailable(size_t len) noexcept override {
+  }
+  void readEOF() noexcept override {
+    socket_->close();
+  }
+  void readErr(
+    const AsyncSocketException& ex) noexcept override {
+    ADD_FAILURE() << "server read error: " << ex.what();
+  }
+
+  AsyncSSLSocket::UniquePtr socket_;
+};
+
+#ifndef OPENSSL_NO_TLSEXT
+class SNIClient :
+  private AsyncSSLSocket::HandshakeCB,
+  private AsyncTransportWrapper::WriteCallback {
+ public:
+  explicit SNIClient(
+    AsyncSSLSocket::UniquePtr socket)
+      : serverNameMatch(false), socket_(std::move(socket)) {
+    socket_->sslConn(this);
+  }
+
+  bool serverNameMatch;
+
+ private:
+  void handshakeSuc(AsyncSSLSocket*) noexcept override {
+    serverNameMatch = socket_->isServerNameMatch();
+  }
+  void handshakeErr(
+    AsyncSSLSocket*,
+    const AsyncSocketException& ex) noexcept override {
+    ADD_FAILURE() << "client handshake error: " << ex.what();
+  }
+  void writeSuccess() noexcept override {
+    socket_->close();
+  }
+  void writeErr(
+    size_t bytesWritten,
+    const AsyncSocketException& ex) noexcept override {
+    ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
+                  << ex.what();
+  }
+
+  AsyncSSLSocket::UniquePtr socket_;
+};
+
+class SNIServer :
+    private AsyncSSLSocket::HandshakeCB,
+    private AsyncTransportWrapper::ReadCallback {
+ public:
+  explicit SNIServer(
+    AsyncSSLSocket::UniquePtr socket,
+    const std::shared_ptr<folly::SSLContext>& ctx,
+    const std::shared_ptr<folly::SSLContext>& sniCtx,
+    const std::string& expectedServerName)
+      : serverNameMatch(false), socket_(std::move(socket)), sniCtx_(sniCtx),
+        expectedServerName_(expectedServerName) {
+    ctx->setServerNameCallback(std::bind(&SNIServer::serverNameCallback, this,
+                                         std::placeholders::_1));
+    socket_->sslAccept(this);
+  }
+
+  bool serverNameMatch;
+
+ private:
+  void handshakeSuc(AsyncSSLSocket* ssl) noexcept override {}
+  void handshakeErr(
+    AsyncSSLSocket*,
+    const AsyncSocketException& ex) noexcept override {
+    ADD_FAILURE() << "server handshake error: " << ex.what();
+  }
+  void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+    *lenReturn = 0;
+  }
+  void readDataAvailable(size_t len) noexcept override {
+  }
+  void readEOF() noexcept override {
+    socket_->close();
+  }
+  void readErr(
+    const AsyncSocketException& ex) noexcept override {
+    ADD_FAILURE() << "server read error: " << ex.what();
+  }
+
+  folly::SSLContext::ServerNameCallbackResult
+    serverNameCallback(SSL *ssl) {
+    const char *sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
+    if (sniCtx_ &&
+        sn &&
+        !strcasecmp(expectedServerName_.c_str(), sn)) {
+      AsyncSSLSocket *sslSocket =
+          AsyncSSLSocket::getFromSSL(ssl);
+      sslSocket->switchServerSSLContext(sniCtx_);
+      serverNameMatch = true;
+      return folly::SSLContext::SERVER_NAME_FOUND;
+    } else {
+      serverNameMatch = false;
+      return folly::SSLContext::SERVER_NAME_NOT_FOUND;
+    }
+  }
+
+  AsyncSSLSocket::UniquePtr socket_;
+  std::shared_ptr<folly::SSLContext> sniCtx_;
+  std::string expectedServerName_;
+};
+#endif
+
+class SSLClient : public AsyncSocket::ConnectCallback,
+                  public AsyncTransportWrapper::WriteCallback,
+                  public AsyncTransportWrapper::ReadCallback
+{
+ private:
+  EventBase *eventBase_;
+  std::shared_ptr<AsyncSSLSocket> sslSocket_;
+  SSL_SESSION *session_;
+  std::shared_ptr<folly::SSLContext> ctx_;
+  uint32_t requests_;
+  folly::SocketAddress address_;
+  uint32_t timeout_;
+  char buf_[128];
+  char readbuf_[128];
+  uint32_t bytesRead_;
+  uint32_t hit_;
+  uint32_t miss_;
+  uint32_t errors_;
+  uint32_t writeAfterConnectErrors_;
+
+ public:
+  SSLClient(EventBase *eventBase,
+            const folly::SocketAddress& address,
+            uint32_t requests, uint32_t timeout = 0)
+      : eventBase_(eventBase),
+        session_(nullptr),
+        requests_(requests),
+        address_(address),
+        timeout_(timeout),
+        bytesRead_(0),
+        hit_(0),
+        miss_(0),
+        errors_(0),
+        writeAfterConnectErrors_(0) {
+    ctx_.reset(new folly::SSLContext());
+    ctx_->setOptions(SSL_OP_NO_TICKET);
+    ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+    memset(buf_, 'a', sizeof(buf_));
+  }
+
+  ~SSLClient() {
+    if (session_) {
+      SSL_SESSION_free(session_);
+    }
+    if (errors_ == 0) {
+      EXPECT_EQ(bytesRead_, sizeof(buf_));
+    }
+  }
+
+  uint32_t getHit() const { return hit_; }
+
+  uint32_t getMiss() const { return miss_; }
+
+  uint32_t getErrors() const { return errors_; }
+
+  uint32_t getWriteAfterConnectErrors() const {
+    return writeAfterConnectErrors_;
+  }
+
+  void connect(bool writeNow = false) {
+    sslSocket_ = AsyncSSLSocket::newSocket(
+      ctx_, eventBase_);
+    if (session_ != nullptr) {
+      sslSocket_->setSSLSession(session_);
+    }
+    requests_--;
+    sslSocket_->connect(this, address_, timeout_);
+    if (sslSocket_ && writeNow) {
+      // write some junk, used in an error test
+      sslSocket_->write(this, buf_, sizeof(buf_));
+    }
+  }
+
+  void connectSuccess() noexcept override {
+    std::cerr << "client SSL socket connected" << std::endl;
+    if (sslSocket_->getSSLSessionReused()) {
+      hit_++;
+    } else {
+      miss_++;
+      if (session_ != nullptr) {
+        SSL_SESSION_free(session_);
+      }
+      session_ = sslSocket_->getSSLSession();
+    }
+
+    // write()
+    sslSocket_->write(this, buf_, sizeof(buf_));
+    sslSocket_->setReadCB(this);
+    memset(readbuf_, 'b', sizeof(readbuf_));
+    bytesRead_ = 0;
+  }
+
+  void connectErr(
+    const AsyncSocketException& ex) noexcept override {
+    std::cerr << "SSLClient::connectError: " << ex.what() << std::endl;
+    errors_++;
+    sslSocket_.reset();
+  }
+
+  void writeSuccess() noexcept override {
+    std::cerr << "client write success" << std::endl;
+  }
+
+  void writeErr(
+    size_t bytesWritten,
+    const AsyncSocketException& ex)
+    noexcept override {
+    std::cerr << "client writeError: " << ex.what() << std::endl;
+    if (!sslSocket_) {
+      writeAfterConnectErrors_++;
+    }
+  }
+
+  void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+    *bufReturn = readbuf_ + bytesRead_;
+    *lenReturn = sizeof(readbuf_) - bytesRead_;
+  }
+
+  void readEOF() noexcept override {
+    std::cerr << "client readEOF" << std::endl;
+  }
+
+  void readErr(
+    const AsyncSocketException& ex) noexcept override {
+    std::cerr << "client readError: " << ex.what() << std::endl;
+  }
+
+  void readDataAvailable(size_t len) noexcept override {
+    std::cerr << "client read data: " << len << std::endl;
+    bytesRead_ += len;
+    if (len == sizeof(buf_)) {
+      EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
+      sslSocket_->closeNow();
+      sslSocket_.reset();
+      if (requests_ != 0) {
+        connect();
+      }
+    }
+  }
+
+};
+
+class SSLHandshakeBase :
+  public AsyncSSLSocket::HandshakeCB,
+  private AsyncTransportWrapper::WriteCallback {
+ public:
+  explicit SSLHandshakeBase(
+   AsyncSSLSocket::UniquePtr socket,
+   bool preverifyResult,
+   bool verifyResult) :
+    handshakeVerify_(false),
+    handshakeSuccess_(false),
+    handshakeError_(false),
+    socket_(std::move(socket)),
+    preverifyResult_(preverifyResult),
+    verifyResult_(verifyResult) {
+  }
+
+  bool handshakeVerify_;
+  bool handshakeSuccess_;
+  bool handshakeError_;
+
+ protected:
+  AsyncSSLSocket::UniquePtr socket_;
+  bool preverifyResult_;
+  bool verifyResult_;
+
+  // HandshakeCallback
+  bool handshakeVer(
+   AsyncSSLSocket* sock,
+   bool preverifyOk,
+   X509_STORE_CTX* ctx) noexcept override {
+    handshakeVerify_ = true;
+
+    EXPECT_EQ(preverifyResult_, preverifyOk);
+    return verifyResult_;
+  }
+
+  void handshakeSuc(AsyncSSLSocket*) noexcept override {
+    handshakeSuccess_ = true;
+  }
+
+  void handshakeErr(
+   AsyncSSLSocket*,
+   const AsyncSocketException& ex) noexcept override {
+    handshakeError_ = true;
+  }
+
+  // WriteCallback
+  void writeSuccess() noexcept override {
+    socket_->close();
+  }
+
+  void writeErr(
+   size_t bytesWritten,
+   const AsyncSocketException& ex) noexcept override {
+    ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
+                  << ex.what();
+  }
+};
+
+class SSLHandshakeClient : public SSLHandshakeBase {
+ public:
+  SSLHandshakeClient(
+   AsyncSSLSocket::UniquePtr socket,
+   bool preverifyResult,
+   bool verifyResult) :
+    SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
+    socket_->sslConn(this, 0);
+  }
+};
+
+class SSLHandshakeClientNoVerify : public SSLHandshakeBase {
+ public:
+  SSLHandshakeClientNoVerify(
+   AsyncSSLSocket::UniquePtr socket,
+   bool preverifyResult,
+   bool verifyResult) :
+    SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
+    socket_->sslConn(this, 0,
+      folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
+  }
+};
+
+class SSLHandshakeClientDoVerify : public SSLHandshakeBase {
+ public:
+  SSLHandshakeClientDoVerify(
+   AsyncSSLSocket::UniquePtr socket,
+   bool preverifyResult,
+   bool verifyResult) :
+    SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
+    socket_->sslConn(this, 0,
+      folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
+  }
+};
+
+class SSLHandshakeServer : public SSLHandshakeBase {
+ public:
+  SSLHandshakeServer(
+      AsyncSSLSocket::UniquePtr socket,
+      bool preverifyResult,
+      bool verifyResult)
+    : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
+    socket_->sslAccept(this, 0);
+  }
+};
+
+class SSLHandshakeServerParseClientHello : public SSLHandshakeBase {
+ public:
+  SSLHandshakeServerParseClientHello(
+      AsyncSSLSocket::UniquePtr socket,
+      bool preverifyResult,
+      bool verifyResult)
+      : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
+    socket_->enableClientHelloParsing();
+    socket_->sslAccept(this, 0);
+  }
+
+  std::string clientCiphers_, sharedCiphers_, serverCiphers_, chosenCipher_;
+
+ protected:
+  void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
+    handshakeSuccess_ = true;
+    sock->getSSLSharedCiphers(sharedCiphers_);
+    sock->getSSLServerCiphers(serverCiphers_);
+    sock->getSSLClientCiphers(clientCiphers_);
+    chosenCipher_ = sock->getNegotiatedCipherName();
+  }
+};
+
+
+class SSLHandshakeServerNoVerify : public SSLHandshakeBase {
+ public:
+  SSLHandshakeServerNoVerify(
+      AsyncSSLSocket::UniquePtr socket,
+      bool preverifyResult,
+      bool verifyResult)
+    : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
+    socket_->sslAccept(this, 0,
+      folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
+  }
+};
+
+class SSLHandshakeServerDoVerify : public SSLHandshakeBase {
+ public:
+  SSLHandshakeServerDoVerify(
+      AsyncSSLSocket::UniquePtr socket,
+      bool preverifyResult,
+      bool verifyResult)
+    : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
+    socket_->sslAccept(this, 0,
+      folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
+  }
+};
+
+class EventBaseAborter : public AsyncTimeout {
+ public:
+  EventBaseAborter(EventBase* eventBase,
+                   uint32_t timeoutMS)
+    : AsyncTimeout(
+      eventBase, AsyncTimeout::InternalEnum::INTERNAL)
+    , eventBase_(eventBase) {
+    scheduleTimeout(timeoutMS);
+  }
+
+  void timeoutExpired() noexcept override {
+    FAIL() << "test timed out";
+    eventBase_->terminateLoopSoon();
+  }
+
+ private:
+  EventBase* eventBase_;
+};
+
+}
diff --git a/folly/io/async/test/AsyncSSLSocketTest2.cpp b/folly/io/async/test/AsyncSSLSocketTest2.cpp
new file mode 100644 (file)
index 0000000..5f4818e
--- /dev/null
@@ -0,0 +1,147 @@
+/*
+ * Copyright 2015 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <folly/io/async/test/AsyncSSLSocketTest.h>
+
+#include <gtest/gtest.h>
+#include <pthread.h>
+
+#include <folly/io/async/AsyncSSLSocket.h>
+#include <folly/io/async/EventBase.h>
+
+using std::string;
+using std::vector;
+using std::min;
+using std::cerr;
+using std::endl;
+using std::list;
+
+namespace folly {
+
+class AttachDetachClient : public AsyncSocket::ConnectCallback,
+                           public AsyncTransportWrapper::WriteCallback,
+                           public AsyncTransportWrapper::ReadCallback {
+ private:
+  EventBase *eventBase_;
+  std::shared_ptr<AsyncSSLSocket> sslSocket_;
+  std::shared_ptr<SSLContext> ctx_;
+  folly::SocketAddress address_;
+  char buf_[128];
+  char readbuf_[128];
+  uint32_t bytesRead_;
+ public:
+  AttachDetachClient(EventBase *eventBase, const folly::SocketAddress& address)
+      : eventBase_(eventBase), address_(address), bytesRead_(0) {
+    ctx_.reset(new SSLContext());
+    ctx_->setOptions(SSL_OP_NO_TICKET);
+    ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+  }
+
+  void connect() {
+    sslSocket_ = AsyncSSLSocket::newSocket(ctx_, eventBase_);
+    sslSocket_->connect(this, address_);
+  }
+
+  void connectSuccess() noexcept override {
+    cerr << "client SSL socket connected" << endl;
+
+    for (int i = 0; i < 1000; ++i) {
+      sslSocket_->detachSSLContext();
+      sslSocket_->attachSSLContext(ctx_);
+    }
+
+    EXPECT_EQ(ctx_->getSSLCtx()->references, 2);
+
+    sslSocket_->write(this, buf_, sizeof(buf_));
+    sslSocket_->setReadCB(this);
+    memset(readbuf_, 'b', sizeof(readbuf_));
+    bytesRead_ = 0;
+  }
+
+  void connectErr(const AsyncSocketException& ex) noexcept override
+  {
+    cerr << "AttachDetachClient::connectError: " << ex.what() << endl;
+    sslSocket_.reset();
+  }
+
+  void writeSuccess() noexcept override {
+    cerr << "client write success" << endl;
+  }
+
+  void writeErr(size_t bytesWritten, const AsyncSocketException& ex)
+    noexcept override {
+    cerr << "client writeError: " << ex.what() << endl;
+  }
+
+  void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+    *bufReturn = readbuf_ + bytesRead_;
+    *lenReturn = sizeof(readbuf_) - bytesRead_;
+  }
+  void readEOF() noexcept override {
+    cerr << "client readEOF" << endl;
+  }
+
+  void readErr(const AsyncSocketException& ex) noexcept override {
+    cerr << "client readError: " << ex.what() << endl;
+  }
+
+  void readDataAvailable(size_t len) noexcept override {
+    cerr << "client read data: " << len << endl;
+    bytesRead_ += len;
+    if (len == sizeof(buf_)) {
+      EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
+      sslSocket_->closeNow();
+    }
+  }
+};
+
+/**
+ * Test passing contexts between threads
+ */
+TEST(AsyncSSLSocketTest2, AttachDetachSSLContext) {
+  // Start listening on a local port
+  WriteCallbackBase writeCallback;
+  ReadCallback readCallback(&writeCallback);
+  HandshakeCallback handshakeCallback(&readCallback);
+  SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
+  TestSSLServer server(&acceptCallback);
+
+  EventBase eventBase;
+  EventBaseAborter eba(&eventBase, 3000);
+  std::shared_ptr<AttachDetachClient> client(
+    new AttachDetachClient(&eventBase, server.getAddress()));
+
+  client->connect();
+  eventBase.loop();
+}
+
+}
+///////////////////////////////////////////////////////////////////////////
+// init_unit_test_suite
+///////////////////////////////////////////////////////////////////////////
+
+namespace {
+using folly::SSLContext;
+struct Initializer {
+  Initializer() {
+    signal(SIGPIPE, SIG_IGN);
+    SSLContext::setSSLLockTypes({
+        {CRYPTO_LOCK_EVP_PKEY, SSLContext::LOCK_NONE},
+        {CRYPTO_LOCK_SSL_SESSION, SSLContext::LOCK_SPINLOCK},
+        {CRYPTO_LOCK_SSL_CTX, SSLContext::LOCK_NONE}});
+  }
+};
+Initializer initializer;
+} // anonymous
diff --git a/folly/io/async/test/AsyncSSLSocketWriteTest.cpp b/folly/io/async/test/AsyncSSLSocketWriteTest.cpp
new file mode 100644 (file)
index 0000000..fd9dac7
--- /dev/null
@@ -0,0 +1,396 @@
+/*
+ * Copyright 2015 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <folly/Foreach.h>
+#include <folly/io/Cursor.h>
+#include <folly/io/async/AsyncSSLSocket.h>
+#include <folly/io/async/AsyncSocket.h>
+#include <folly/io/async/EventBase.h>
+
+#include <gtest/gtest.h>
+#include <gmock/gmock.h>
+#include <string>
+#include <vector>
+
+using std::string;
+using namespace testing;
+
+namespace folly {
+
+class MockAsyncSSLSocket : public AsyncSSLSocket{
+ public:
+  static std::shared_ptr<MockAsyncSSLSocket> newSocket(
+    const std::shared_ptr<SSLContext>& ctx,
+    EventBase* evb) {
+    auto sock = std::shared_ptr<MockAsyncSSLSocket>(
+      new MockAsyncSSLSocket(ctx, evb),
+      Destructor());
+    sock->ssl_ = SSL_new(ctx->getSSLCtx());
+    SSL_set_fd(sock->ssl_, -1);
+    return sock;
+  }
+
+  // Fake constructor sets the state to established without call to connect
+  // or accept
+  MockAsyncSSLSocket(const std::shared_ptr<SSLContext>& ctx,
+                      EventBase* evb)
+      : AsyncSocket(evb), AsyncSSLSocket(ctx, evb) {
+    state_ = AsyncSocket::StateEnum::ESTABLISHED;
+    sslState_ = AsyncSSLSocket::SSLStateEnum::STATE_ESTABLISHED;
+  }
+
+  // mock the calls to SSL_write to see the buffer length and contents
+  MOCK_METHOD3(sslWriteImpl, int(SSL *ssl, const void *buf, int n));
+
+  // mock the calls to getRawBytesWritten()
+  MOCK_CONST_METHOD0(getRawBytesWritten, size_t());
+
+  // public wrapper for protected interface
+  ssize_t testPerformWrite(const iovec* vec, uint32_t count, WriteFlags flags,
+                           uint32_t* countWritten, uint32_t* partialWritten) {
+    return performWrite(vec, count, flags, countWritten, partialWritten);
+  }
+
+  void checkEor(size_t appEor, size_t rawEor) {
+    EXPECT_EQ(appEor, appEorByteNo_);
+    EXPECT_EQ(rawEor, minEorRawByteNo_);
+  }
+
+  void setAppBytesWritten(size_t n) {
+    appBytesWritten_ = n;
+  }
+};
+
+class AsyncSSLSocketWriteTest : public testing::Test {
+ public:
+  AsyncSSLSocketWriteTest() :
+      sslContext_(new SSLContext()),
+      sock_(MockAsyncSSLSocket::newSocket(sslContext_, &eventBase_)) {
+    for (int i = 0; i < 500; i++) {
+      memcpy(source_ + i * 26, "abcdefghijklmnopqrstuvwxyz", 26);
+    }
+  }
+
+  // Make an iovec containing chunks of the reference text with requested sizes
+  // for each chunk
+  iovec *makeVec(std::vector<uint32_t> sizes) {
+    iovec *vec = new iovec[sizes.size()];
+    int i = 0;
+    int pos = 0;
+    for (auto size: sizes) {
+      vec[i].iov_base = (void *)(source_ + pos);
+      vec[i++].iov_len = size;
+      pos += size;
+    }
+    return vec;
+  }
+
+  // Verify that the given buf/pos matches the reference text
+  void verifyVec(const void *buf, int n, int pos) {
+    ASSERT_EQ(memcmp(source_ + pos, buf, n), 0);
+  }
+
+  // Update a vec on partial write
+  void consumeVec(iovec *vec, uint32_t countWritten, uint32_t partialWritten) {
+    vec[countWritten].iov_base =
+      ((char *)vec[countWritten].iov_base) + partialWritten;
+    vec[countWritten].iov_len -= partialWritten;
+  }
+
+  EventBase eventBase_;
+  std::shared_ptr<SSLContext> sslContext_;
+  std::shared_ptr<MockAsyncSSLSocket> sock_;
+  char source_[26 * 500];
+};
+
+
+// The entire vec fits in one packet
+TEST_F(AsyncSSLSocketWriteTest, write_coalescing1) {
+  int n = 3;
+  iovec *vec = makeVec({3, 3, 3});
+  EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 9))
+    .WillOnce(Invoke([this] (SSL *, const void *buf, int n) {
+          verifyVec(buf, n, 0);
+          return 9; }));
+  uint32_t countWritten = 0;
+  uint32_t partialWritten = 0;
+  sock_->testPerformWrite(vec, n, WriteFlags::NONE, &countWritten,
+                          &partialWritten);
+  EXPECT_EQ(countWritten, n);
+  EXPECT_EQ(partialWritten, 0);
+}
+
+// First packet is full, second two go in one packet
+TEST_F(AsyncSSLSocketWriteTest, write_coalescing2) {
+  int n = 3;
+  iovec *vec = makeVec({1500, 3, 3});
+  int pos = 0;
+  EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
+    .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
+          verifyVec(buf, n, pos);
+          pos += n;
+          return n; }));
+  EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 6))
+    .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
+          verifyVec(buf, n, pos);
+          pos += n;
+          return n; }));
+  uint32_t countWritten = 0;
+  uint32_t partialWritten = 0;
+  sock_->testPerformWrite(vec, n, WriteFlags::NONE, &countWritten,
+                          &partialWritten);
+  EXPECT_EQ(countWritten, n);
+  EXPECT_EQ(partialWritten, 0);
+}
+
+// Two exactly full packets (coalesce ends midway through second chunk)
+TEST_F(AsyncSSLSocketWriteTest, write_coalescing3) {
+  int n = 3;
+  iovec *vec = makeVec({1000, 1000, 1000});
+  int pos = 0;
+  EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
+    .Times(2)
+    .WillRepeatedly(Invoke([this, &pos] (SSL *, const void *buf, int n) {
+          verifyVec(buf, n, pos);
+          pos += n;
+          return n; }));
+  uint32_t countWritten = 0;
+  uint32_t partialWritten = 0;
+  sock_->testPerformWrite(vec, n, WriteFlags::NONE, &countWritten,
+                          &partialWritten);
+  EXPECT_EQ(countWritten, n);
+  EXPECT_EQ(partialWritten, 0);
+}
+
+// Partial write success midway through a coalesced vec
+TEST_F(AsyncSSLSocketWriteTest, write_coalescing4) {
+  int n = 5;
+  iovec *vec = makeVec({300, 300, 300, 300, 300});
+  int pos = 0;
+  EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
+    .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
+          verifyVec(buf, n, pos);
+          pos += 1000;
+          return 1000; /* 500 bytes "pending" */ }));
+  uint32_t countWritten = 0;
+  uint32_t partialWritten = 0;
+  sock_->testPerformWrite(vec, n, WriteFlags::NONE, &countWritten,
+                          &partialWritten);
+  EXPECT_EQ(countWritten, 3);
+  EXPECT_EQ(partialWritten, 100);
+  consumeVec(vec, countWritten, partialWritten);
+  EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 500))
+    .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
+          verifyVec(buf, n, pos);
+          pos += n;
+          return 500; }));
+  sock_->testPerformWrite(vec + countWritten, n - countWritten,
+                          WriteFlags::NONE,
+                          &countWritten, &partialWritten);
+  EXPECT_EQ(countWritten, 2);
+  EXPECT_EQ(partialWritten, 0);
+}
+
+// coalesce ends exactly on a buffer boundary
+TEST_F(AsyncSSLSocketWriteTest, write_coalescing5) {
+  int n = 3;
+  iovec *vec = makeVec({1000, 500, 500});
+  int pos = 0;
+  EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
+    .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
+          verifyVec(buf, n, pos);
+          pos += n;
+          return n; }));
+  EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 500))
+    .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
+          verifyVec(buf, n, pos);
+          pos += n;
+          return n; }));
+  uint32_t countWritten = 0;
+  uint32_t partialWritten = 0;
+  sock_->testPerformWrite(vec, n, WriteFlags::NONE, &countWritten,
+                          &partialWritten);
+  EXPECT_EQ(countWritten, 3);
+  EXPECT_EQ(partialWritten, 0);
+}
+
+// partial write midway through first chunk
+TEST_F(AsyncSSLSocketWriteTest, write_coalescing6) {
+  int n = 2;
+  iovec *vec = makeVec({1000, 500});
+  int pos = 0;
+  EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
+    .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
+          verifyVec(buf, n, pos);
+          pos += 700;
+          return 700; }));
+  uint32_t countWritten = 0;
+  uint32_t partialWritten = 0;
+  sock_->testPerformWrite(vec, n, WriteFlags::NONE, &countWritten,
+                          &partialWritten);
+  EXPECT_EQ(countWritten, 0);
+  EXPECT_EQ(partialWritten, 700);
+  consumeVec(vec, countWritten, partialWritten);
+  EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 800))
+    .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
+          verifyVec(buf, n, pos);
+          pos += n;
+          return n; }));
+  sock_->testPerformWrite(vec + countWritten, n - countWritten,
+                          WriteFlags::NONE,
+                          &countWritten, &partialWritten);
+  EXPECT_EQ(countWritten, 2);
+  EXPECT_EQ(partialWritten, 0);
+}
+
+// Repeat coalescing2 with WriteFlags::EOR
+TEST_F(AsyncSSLSocketWriteTest, write_with_eor1) {
+  int n = 3;
+  iovec *vec = makeVec({1500, 3, 3});
+  int pos = 0;
+  const size_t initAppBytesWritten = 500;
+  const size_t appEor = initAppBytesWritten + 1506;
+
+  sock_->setAppBytesWritten(initAppBytesWritten);
+  EXPECT_FALSE(sock_->isEorTrackingEnabled());
+  sock_->setEorTracking(true);
+  EXPECT_TRUE(sock_->isEorTrackingEnabled());
+
+  EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
+    // rawBytesWritten after writting initAppBytesWritten + 1500
+    // + some random SSL overhead
+    .WillOnce(Return(3600))
+    // rawBytesWritten after writting last 6 bytes
+    // + some random SSL overhead
+    .WillOnce(Return(3728));
+  EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
+    .WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int n) {
+          // the first 1500 does not have the EOR byte
+          sock_->checkEor(0, 0);
+          verifyVec(buf, n, pos);
+          pos += n;
+          return n; }));
+  EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 6))
+    .WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int n) {
+          sock_->checkEor(appEor, 3600 + n);
+          verifyVec(buf, n, pos);
+          pos += n;
+          return n; }));
+
+  uint32_t countWritten = 0;
+  uint32_t partialWritten = 0;
+  sock_->testPerformWrite(vec, n , WriteFlags::EOR,
+                          &countWritten, &partialWritten);
+  EXPECT_EQ(countWritten, n);
+  EXPECT_EQ(partialWritten, 0);
+  sock_->checkEor(0, 0);
+}
+
+// coalescing with left over at the last chunk
+// WriteFlags::EOR turned on
+TEST_F(AsyncSSLSocketWriteTest, write_with_eor2) {
+  int n = 3;
+  iovec *vec = makeVec({600, 600, 600});
+  int pos = 0;
+  const size_t initAppBytesWritten = 500;
+  const size_t appEor = initAppBytesWritten + 1800;
+
+  sock_->setAppBytesWritten(initAppBytesWritten);
+  sock_->setEorTracking(true);
+
+  EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
+    // rawBytesWritten after writting initAppBytesWritten +  1500 bytes
+    // + some random SSL overhead
+    .WillOnce(Return(3600))
+    // rawBytesWritten after writting last 300 bytes
+    // + some random SSL overhead
+    .WillOnce(Return(4100));
+  EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
+    .WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int n) {
+          // the first 1500 does not have the EOR byte
+          sock_->checkEor(0, 0);
+          verifyVec(buf, n, pos);
+          pos += n;
+          return n; }));
+  EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 300))
+    .WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int n) {
+          sock_->checkEor(appEor, 3600 + n);
+          verifyVec(buf, n, pos);
+          pos += n;
+          return n; }));
+
+  uint32_t countWritten = 0;
+  uint32_t partialWritten = 0;
+  sock_->testPerformWrite(vec, n, WriteFlags::EOR,
+                          &countWritten, &partialWritten);
+  EXPECT_EQ(countWritten, n);
+  EXPECT_EQ(partialWritten, 0);
+  sock_->checkEor(0, 0);
+}
+
+// WriteFlags::EOR set
+// One buf in iovec
+// Partial write at 1000-th byte
+TEST_F(AsyncSSLSocketWriteTest, write_with_eor3) {
+  int n = 1;
+  iovec *vec = makeVec({1600});
+  int pos = 0;
+  const size_t initAppBytesWritten = 500;
+  const size_t appEor = initAppBytesWritten + 1600;
+
+  sock_->setAppBytesWritten(initAppBytesWritten);
+  sock_->setEorTracking(true);
+
+  EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
+    // rawBytesWritten after the initAppBytesWritten
+    // + some random SSL overhead
+    .WillOnce(Return(2000))
+    // rawBytesWritten after the initAppBytesWritten + 1000 (with 100 overhead)
+    // + some random SSL overhead
+    .WillOnce(Return(3100));
+  EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1600))
+    .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
+          sock_->checkEor(appEor, 2000 + n);
+          verifyVec(buf, n, pos);
+          pos += 1000;
+          return 1000; }));
+
+  uint32_t countWritten = 0;
+  uint32_t partialWritten = 0;
+  sock_->testPerformWrite(vec, n, WriteFlags::EOR,
+                          &countWritten, &partialWritten);
+  EXPECT_EQ(countWritten, 0);
+  EXPECT_EQ(partialWritten, 1000);
+  sock_->checkEor(appEor, 2000 + 1600);
+  consumeVec(vec, countWritten, partialWritten);
+
+  EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
+    .WillOnce(Return(3100))
+    .WillOnce(Return(3800));
+  EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 600))
+    .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
+          sock_->checkEor(appEor, 3100 + n);
+          verifyVec(buf, n, pos);
+          pos += n;
+          return n; }));
+  sock_->testPerformWrite(vec + countWritten, n - countWritten,
+                          WriteFlags::EOR,
+                          &countWritten, &partialWritten);
+  EXPECT_EQ(countWritten, n);
+  EXPECT_EQ(partialWritten, 0);
+  sock_->checkEor(0, 0);
+}
+
+}
diff --git a/folly/io/async/test/AsyncSocketTest2.cpp b/folly/io/async/test/AsyncSocketTest2.cpp
new file mode 100644 (file)
index 0000000..147bec9
--- /dev/null
@@ -0,0 +1,2103 @@
+/*
+ * Copyright 2015 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <folly/io/async/AsyncServerSocket.h>
+#include <folly/io/async/AsyncSocket.h>
+#include <folly/io/async/AsyncTimeout.h>
+#include <folly/io/async/EventBase.h>
+#include <folly/SocketAddress.h>
+
+#include <folly/io/IOBuf.h>
+#include <folly/io/async/test/BlockingSocket.h>
+#include <folly/io/async/test/Util.h>
+
+#include <gtest/gtest.h>
+#include <boost/scoped_array.hpp>
+#include <iostream>
+#include <unistd.h>
+#include <fcntl.h>
+#include <poll.h>
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <netinet/tcp.h>
+#include <thread>
+
+using namespace boost;
+
+using std::string;
+using std::vector;
+using std::min;
+using std::cerr;
+using std::endl;
+using std::unique_ptr;
+using std::chrono::milliseconds;
+using boost::scoped_array;
+
+using namespace folly;
+
+enum StateEnum {
+  STATE_WAITING,
+  STATE_SUCCEEDED,
+  STATE_FAILED
+};
+
+typedef std::function<void()> VoidCallback;
+
+
+class ConnCallback : public AsyncSocket::ConnectCallback {
+ public:
+  ConnCallback()
+    : state(STATE_WAITING)
+    , exception(AsyncSocketException::UNKNOWN, "none") {}
+
+  void connectSuccess() noexcept override {
+    state = STATE_SUCCEEDED;
+    if (successCallback) {
+      successCallback();
+    }
+  }
+
+  void connectErr(const AsyncSocketException& ex) noexcept override {
+    state = STATE_FAILED;
+    exception = ex;
+    if (errorCallback) {
+      errorCallback();
+    }
+  }
+
+  StateEnum state;
+  AsyncSocketException exception;
+  VoidCallback successCallback;
+  VoidCallback errorCallback;
+};
+
+class WriteCallback : public AsyncTransportWrapper::WriteCallback {
+ public:
+  WriteCallback()
+    : state(STATE_WAITING)
+    , bytesWritten(0)
+    , exception(AsyncSocketException::UNKNOWN, "none") {}
+
+  void writeSuccess() noexcept override {
+    state = STATE_SUCCEEDED;
+    if (successCallback) {
+      successCallback();
+    }
+  }
+
+  void writeErr(size_t bytesWritten,
+                const AsyncSocketException& ex) noexcept override {
+    state = STATE_FAILED;
+    this->bytesWritten = bytesWritten;
+    exception = ex;
+    if (errorCallback) {
+      errorCallback();
+    }
+  }
+
+  StateEnum state;
+  size_t bytesWritten;
+  AsyncSocketException exception;
+  VoidCallback successCallback;
+  VoidCallback errorCallback;
+};
+
+class ReadCallback : public AsyncTransportWrapper::ReadCallback {
+ public:
+  ReadCallback()
+    : state(STATE_WAITING)
+    , exception(AsyncSocketException::UNKNOWN, "none")
+    , buffers() {}
+
+  ~ReadCallback() {
+    for (vector<Buffer>::iterator it = buffers.begin();
+         it != buffers.end();
+         ++it) {
+      it->free();
+    }
+    currentBuffer.free();
+  }
+
+  void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+    if (!currentBuffer.buffer) {
+      currentBuffer.allocate(4096);
+    }
+    *bufReturn = currentBuffer.buffer;
+    *lenReturn = currentBuffer.length;
+  }
+
+  void readDataAvailable(size_t len) noexcept override {
+    currentBuffer.length = len;
+    buffers.push_back(currentBuffer);
+    currentBuffer.reset();
+    if (dataAvailableCallback) {
+      dataAvailableCallback();
+    }
+  }
+
+  void readEOF() noexcept override {
+    state = STATE_SUCCEEDED;
+  }
+
+  void readErr(const AsyncSocketException& ex) noexcept override {
+    state = STATE_FAILED;
+    exception = ex;
+  }
+
+  void verifyData(const char* expected, size_t expectedLen) const {
+    size_t offset = 0;
+    for (size_t idx = 0; idx < buffers.size(); ++idx) {
+      const auto& buf = buffers[idx];
+      size_t cmpLen = std::min(buf.length, expectedLen - offset);
+      CHECK_EQ(memcmp(buf.buffer, expected + offset, cmpLen), 0);
+      CHECK_EQ(cmpLen, buf.length);
+      offset += cmpLen;
+    }
+    CHECK_EQ(offset, expectedLen);
+  }
+
+  class Buffer {
+   public:
+    Buffer() : buffer(nullptr), length(0) {}
+    Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
+
+    void reset() {
+      buffer = nullptr;
+      length = 0;
+    }
+    void allocate(size_t length) {
+      assert(buffer == nullptr);
+      this->buffer = static_cast<char*>(malloc(length));
+      this->length = length;
+    }
+    void free() {
+      ::free(buffer);
+      reset();
+    }
+
+    char* buffer;
+    size_t length;
+  };
+
+  StateEnum state;
+  AsyncSocketException exception;
+  vector<Buffer> buffers;
+  Buffer currentBuffer;
+  VoidCallback dataAvailableCallback;
+};
+
+class ReadVerifier {
+};
+
+class TestServer {
+ public:
+  // Create a TestServer.
+  // This immediately starts listening on an ephemeral port.
+  TestServer()
+    : fd_(-1) {
+    fd_ = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
+    if (fd_ < 0) {
+      throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
+                                "failed to create test server socket", errno);
+    }
+    if (fcntl(fd_, F_SETFL, O_NONBLOCK) != 0) {
+      throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
+                                "failed to put test server socket in "
+                                "non-blocking mode", errno);
+    }
+    if (listen(fd_, 10) != 0) {
+      throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
+                                "failed to listen on test server socket",
+                                errno);
+    }
+
+    address_.setFromLocalAddress(fd_);
+    // The local address will contain 0.0.0.0.
+    // Change it to 127.0.0.1, so it can be used to connect to the server
+    address_.setFromIpPort("127.0.0.1", address_.getPort());
+  }
+
+  // Get the address for connecting to the server
+  const folly::SocketAddress& getAddress() const {
+    return address_;
+  }
+
+  int acceptFD(int timeout=50) {
+    struct pollfd pfd;
+    pfd.fd = fd_;
+    pfd.events = POLLIN;
+    int ret = poll(&pfd, 1, timeout);
+    if (ret == 0) {
+      throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
+                                "test server accept() timed out");
+    } else if (ret < 0) {
+      throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
+                                "test server accept() poll failed", errno);
+    }
+
+    int acceptedFd = ::accept(fd_, nullptr, nullptr);
+    if (acceptedFd < 0) {
+      throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
+                                "test server accept() failed", errno);
+    }
+
+    return acceptedFd;
+  }
+
+  std::shared_ptr<BlockingSocket> accept(int timeout=50) {
+    int fd = acceptFD(timeout);
+    return std::shared_ptr<BlockingSocket>(new BlockingSocket(fd));
+  }
+
+  std::shared_ptr<AsyncSocket> acceptAsync(EventBase* evb, int timeout=50) {
+    int fd = acceptFD(timeout);
+    return AsyncSocket::newSocket(evb, fd);
+  }
+
+  /**
+   * Accept a connection, read data from it, and verify that it matches the
+   * data in the specified buffer.
+   */
+  void verifyConnection(const char* buf, size_t len) {
+    // accept a connection
+    std::shared_ptr<BlockingSocket> acceptedSocket = accept();
+    // read the data and compare it to the specified buffer
+    scoped_array<uint8_t> readbuf(new uint8_t[len]);
+    acceptedSocket->readAll(readbuf.get(), len);
+    CHECK_EQ(memcmp(buf, readbuf.get(), len), 0);
+    // make sure we get EOF next
+    uint32_t bytesRead = acceptedSocket->read(readbuf.get(), len);
+    CHECK_EQ(bytesRead, 0);
+  }
+
+ private:
+  int fd_;
+  folly::SocketAddress address_;
+};
+
+class DelayedWrite: public AsyncTimeout {
+ public:
+  DelayedWrite(const std::shared_ptr<AsyncSocket>& socket,
+      unique_ptr<IOBuf>&& bufs, AsyncTransportWrapper::WriteCallback* wcb,
+      bool cork, bool lastWrite = false):
+    AsyncTimeout(socket->getEventBase()),
+    socket_(socket),
+    bufs_(std::move(bufs)),
+    wcb_(wcb),
+    cork_(cork),
+    lastWrite_(lastWrite) {}
+
+ private:
+  void timeoutExpired() noexcept override {
+    WriteFlags flags = cork_ ? WriteFlags::CORK : WriteFlags::NONE;
+    socket_->writeChain(wcb_, std::move(bufs_), flags);
+    if (lastWrite_) {
+      socket_->shutdownWrite();
+    }
+  }
+
+  std::shared_ptr<AsyncSocket> socket_;
+  unique_ptr<IOBuf> bufs_;
+  AsyncTransportWrapper::WriteCallback* wcb_;
+  bool cork_;
+  bool lastWrite_;
+};
+
+///////////////////////////////////////////////////////////////////////////
+// connect() tests
+///////////////////////////////////////////////////////////////////////////
+
+/**
+ * Test connecting to a server
+ */
+TEST(AsyncSocketTest, Connect) {
+  // Start listening on a local port
+  TestServer server;
+
+  // Connect using a AsyncSocket
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+  ConnCallback cb;
+  socket->connect(&cb, server.getAddress(), 30);
+
+  evb.loop();
+
+  CHECK_EQ(cb.state, STATE_SUCCEEDED);
+}
+
+/**
+ * Test connecting to a server that isn't listening
+ */
+TEST(AsyncSocketTest, ConnectRefused) {
+  EventBase evb;
+
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+
+  // Hopefully nothing is actually listening on this address
+  folly::SocketAddress addr("127.0.0.1", 65535);
+  ConnCallback cb;
+  socket->connect(&cb, addr, 30);
+
+  evb.loop();
+
+  CHECK_EQ(cb.state, STATE_FAILED);
+  CHECK_EQ(cb.exception.getType(), AsyncSocketException::NOT_OPEN);
+}
+
+/**
+ * Test connection timeout
+ */
+TEST(AsyncSocketTest, ConnectTimeout) {
+  EventBase evb;
+
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+
+  // Try connecting to server that won't respond.
+  //
+  // This depends somewhat on the network where this test is run.
+  // Hopefully this IP will be routable but unresponsive.
+  // (Alternatively, we could try listening on a local raw socket, but that
+  // normally requires root privileges.)
+  folly::SocketAddress addr("8.8.8.8", 65535);
+  ConnCallback cb;
+  socket->connect(&cb, addr, 1); // also set a ridiculously small timeout
+
+  evb.loop();
+
+  CHECK_EQ(cb.state, STATE_FAILED);
+  CHECK_EQ(cb.exception.getType(), AsyncSocketException::TIMED_OUT);
+
+  // Verify that we can still get the peer address after a timeout.
+  // Use case is if the client was created from a client pool, and we want
+  // to log which peer failed.
+  folly::SocketAddress peer;
+  socket->getPeerAddress(&peer);
+  CHECK_EQ(peer, addr);
+}
+
+/**
+ * Test writing immediately after connecting, without waiting for connect
+ * to finish.
+ */
+TEST(AsyncSocketTest, ConnectAndWrite) {
+  TestServer server;
+
+  // connect()
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+  ConnCallback ccb;
+  socket->connect(&ccb, server.getAddress(), 30);
+
+  // write()
+  char buf[128];
+  memset(buf, 'a', sizeof(buf));
+  WriteCallback wcb;
+  socket->write(&wcb, buf, sizeof(buf));
+
+  // Loop.  We don't bother accepting on the server socket yet.
+  // The kernel should be able to buffer the write request so it can succeed.
+  evb.loop();
+
+  CHECK_EQ(ccb.state, STATE_SUCCEEDED);
+  CHECK_EQ(wcb.state, STATE_SUCCEEDED);
+
+  // Make sure the server got a connection and received the data
+  socket->close();
+  server.verifyConnection(buf, sizeof(buf));
+}
+
+/**
+ * Test connecting using a nullptr connect callback.
+ */
+TEST(AsyncSocketTest, ConnectNullCallback) {
+  TestServer server;
+
+  // connect()
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+  socket->connect(nullptr, server.getAddress(), 30);
+
+  // write some data, just so we have some way of verifing
+  // that the socket works correctly after connecting
+  char buf[128];
+  memset(buf, 'a', sizeof(buf));
+  WriteCallback wcb;
+  socket->write(&wcb, buf, sizeof(buf));
+
+  evb.loop();
+
+  CHECK_EQ(wcb.state, STATE_SUCCEEDED);
+
+  // Make sure the server got a connection and received the data
+  socket->close();
+  server.verifyConnection(buf, sizeof(buf));
+}
+
+/**
+ * Test calling both write() and close() immediately after connecting, without
+ * waiting for connect to finish.
+ *
+ * This exercises the STATE_CONNECTING_CLOSING code.
+ */
+TEST(AsyncSocketTest, ConnectWriteAndClose) {
+  TestServer server;
+
+  // connect()
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+  ConnCallback ccb;
+  socket->connect(&ccb, server.getAddress(), 30);
+
+  // write()
+  char buf[128];
+  memset(buf, 'a', sizeof(buf));
+  WriteCallback wcb;
+  socket->write(&wcb, buf, sizeof(buf));
+
+  // close()
+  socket->close();
+
+  // Loop.  We don't bother accepting on the server socket yet.
+  // The kernel should be able to buffer the write request so it can succeed.
+  evb.loop();
+
+  CHECK_EQ(ccb.state, STATE_SUCCEEDED);
+  CHECK_EQ(wcb.state, STATE_SUCCEEDED);
+
+  // Make sure the server got a connection and received the data
+  server.verifyConnection(buf, sizeof(buf));
+}
+
+/**
+ * Test calling close() immediately after connect()
+ */
+TEST(AsyncSocketTest, ConnectAndClose) {
+  TestServer server;
+
+  // Connect using a AsyncSocket
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+  ConnCallback ccb;
+  socket->connect(&ccb, server.getAddress(), 30);
+
+  // Hopefully the connect didn't succeed immediately.
+  // If it did, we can't exercise the close-while-connecting code path.
+  if (ccb.state == STATE_SUCCEEDED) {
+    LOG(INFO) << "connect() succeeded immediately; aborting test "
+                       "of close-during-connect behavior";
+    return;
+  }
+
+  socket->close();
+
+  // Loop, although there shouldn't be anything to do.
+  evb.loop();
+
+  // Make sure the connection was aborted
+  CHECK_EQ(ccb.state, STATE_FAILED);
+}
+
+/**
+ * Test calling closeNow() immediately after connect()
+ *
+ * This should be identical to the normal close behavior.
+ */
+TEST(AsyncSocketTest, ConnectAndCloseNow) {
+  TestServer server;
+
+  // Connect using a AsyncSocket
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+  ConnCallback ccb;
+  socket->connect(&ccb, server.getAddress(), 30);
+
+  // Hopefully the connect didn't succeed immediately.
+  // If it did, we can't exercise the close-while-connecting code path.
+  if (ccb.state == STATE_SUCCEEDED) {
+    LOG(INFO) << "connect() succeeded immediately; aborting test "
+                       "of closeNow()-during-connect behavior";
+    return;
+  }
+
+  socket->closeNow();
+
+  // Loop, although there shouldn't be anything to do.
+  evb.loop();
+
+  // Make sure the connection was aborted
+  CHECK_EQ(ccb.state, STATE_FAILED);
+}
+
+/**
+ * Test calling both write() and closeNow() immediately after connecting,
+ * without waiting for connect to finish.
+ *
+ * This should abort the pending write.
+ */
+TEST(AsyncSocketTest, ConnectWriteAndCloseNow) {
+  TestServer server;
+
+  // connect()
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+  ConnCallback ccb;
+  socket->connect(&ccb, server.getAddress(), 30);
+
+  // Hopefully the connect didn't succeed immediately.
+  // If it did, we can't exercise the close-while-connecting code path.
+  if (ccb.state == STATE_SUCCEEDED) {
+    LOG(INFO) << "connect() succeeded immediately; aborting test "
+                       "of write-during-connect behavior";
+    return;
+  }
+
+  // write()
+  char buf[128];
+  memset(buf, 'a', sizeof(buf));
+  WriteCallback wcb;
+  socket->write(&wcb, buf, sizeof(buf));
+
+  // close()
+  socket->closeNow();
+
+  // Loop, although there shouldn't be anything to do.
+  evb.loop();
+
+  CHECK_EQ(ccb.state, STATE_FAILED);
+  CHECK_EQ(wcb.state, STATE_FAILED);
+}
+
+/**
+ * Test installing a read callback immediately, before connect() finishes.
+ */
+TEST(AsyncSocketTest, ConnectAndRead) {
+  TestServer server;
+
+  // connect()
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+  ConnCallback ccb;
+  socket->connect(&ccb, server.getAddress(), 30);
+
+  ReadCallback rcb;
+  socket->setReadCB(&rcb);
+
+  // Even though we haven't looped yet, we should be able to accept
+  // the connection and send data to it.
+  std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
+  uint8_t buf[128];
+  memset(buf, 'a', sizeof(buf));
+  acceptedSocket->write(buf, sizeof(buf));
+  acceptedSocket->flush();
+  acceptedSocket->close();
+
+  // Loop, although there shouldn't be anything to do.
+  evb.loop();
+
+  CHECK_EQ(ccb.state, STATE_SUCCEEDED);
+  CHECK_EQ(rcb.state, STATE_SUCCEEDED);
+  CHECK_EQ(rcb.buffers.size(), 1);
+  CHECK_EQ(rcb.buffers[0].length, sizeof(buf));
+  CHECK_EQ(memcmp(rcb.buffers[0].buffer, buf, sizeof(buf)), 0);
+}
+
+/**
+ * Test installing a read callback and then closing immediately before the
+ * connect attempt finishes.
+ */
+TEST(AsyncSocketTest, ConnectReadAndClose) {
+  TestServer server;
+
+  // connect()
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+  ConnCallback ccb;
+  socket->connect(&ccb, server.getAddress(), 30);
+
+  // Hopefully the connect didn't succeed immediately.
+  // If it did, we can't exercise the close-while-connecting code path.
+  if (ccb.state == STATE_SUCCEEDED) {
+    LOG(INFO) << "connect() succeeded immediately; aborting test "
+                       "of read-during-connect behavior";
+    return;
+  }
+
+  ReadCallback rcb;
+  socket->setReadCB(&rcb);
+
+  // close()
+  socket->close();
+
+  // Loop, although there shouldn't be anything to do.
+  evb.loop();
+
+  CHECK_EQ(ccb.state, STATE_FAILED); // we aborted the close attempt
+  CHECK_EQ(rcb.buffers.size(), 0);
+  CHECK_EQ(rcb.state, STATE_SUCCEEDED); // this indicates EOF
+}
+
+/**
+ * Test both writing and installing a read callback immediately,
+ * before connect() finishes.
+ */
+TEST(AsyncSocketTest, ConnectWriteAndRead) {
+  TestServer server;
+
+  // connect()
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+  ConnCallback ccb;
+  socket->connect(&ccb, server.getAddress(), 30);
+
+  // write()
+  char buf1[128];
+  memset(buf1, 'a', sizeof(buf1));
+  WriteCallback wcb;
+  socket->write(&wcb, buf1, sizeof(buf1));
+
+  // set a read callback
+  ReadCallback rcb;
+  socket->setReadCB(&rcb);
+
+  // Even though we haven't looped yet, we should be able to accept
+  // the connection and send data to it.
+  std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
+  uint8_t buf2[128];
+  memset(buf2, 'b', sizeof(buf2));
+  acceptedSocket->write(buf2, sizeof(buf2));
+  acceptedSocket->flush();
+
+  // shut down the write half of acceptedSocket, so that the AsyncSocket
+  // will stop reading and we can break out of the event loop.
+  shutdown(acceptedSocket->getSocketFD(), SHUT_WR);
+
+  // Loop
+  evb.loop();
+
+  // Make sure the connect succeeded
+  CHECK_EQ(ccb.state, STATE_SUCCEEDED);
+
+  // Make sure the AsyncSocket read the data written by the accepted socket
+  CHECK_EQ(rcb.state, STATE_SUCCEEDED);
+  CHECK_EQ(rcb.buffers.size(), 1);
+  CHECK_EQ(rcb.buffers[0].length, sizeof(buf2));
+  CHECK_EQ(memcmp(rcb.buffers[0].buffer, buf2, sizeof(buf2)), 0);
+
+  // Close the AsyncSocket so we'll see EOF on acceptedSocket
+  socket->close();
+
+  // Make sure the accepted socket saw the data written by the AsyncSocket
+  uint8_t readbuf[sizeof(buf1)];
+  acceptedSocket->readAll(readbuf, sizeof(readbuf));
+  CHECK_EQ(memcmp(buf1, readbuf, sizeof(buf1)), 0);
+  uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
+  CHECK_EQ(bytesRead, 0);
+}
+
+/**
+ * Test writing to the socket then shutting down writes before the connect
+ * attempt finishes.
+ */
+TEST(AsyncSocketTest, ConnectWriteAndShutdownWrite) {
+  TestServer server;
+
+  // connect()
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+  ConnCallback ccb;
+  socket->connect(&ccb, server.getAddress(), 30);
+
+  // Hopefully the connect didn't succeed immediately.
+  // If it did, we can't exercise the write-while-connecting code path.
+  if (ccb.state == STATE_SUCCEEDED) {
+    LOG(INFO) << "connect() succeeded immediately; skipping test";
+    return;
+  }
+
+  // Ask to write some data
+  char wbuf[128];
+  memset(wbuf, 'a', sizeof(wbuf));
+  WriteCallback wcb;
+  socket->write(&wcb, wbuf, sizeof(wbuf));
+  socket->shutdownWrite();
+
+  // Shutdown writes
+  socket->shutdownWrite();
+
+  // Even though we haven't looped yet, we should be able to accept
+  // the connection.
+  std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
+
+  // Since the connection is still in progress, there should be no data to
+  // read yet.  Verify that the accepted socket is not readable.
+  struct pollfd fds[1];
+  fds[0].fd = acceptedSocket->getSocketFD();
+  fds[0].events = POLLIN;
+  fds[0].revents = 0;
+  int rc = poll(fds, 1, 0);
+  CHECK_EQ(rc, 0);
+
+  // Write data to the accepted socket
+  uint8_t acceptedWbuf[192];
+  memset(acceptedWbuf, 'b', sizeof(acceptedWbuf));
+  acceptedSocket->write(acceptedWbuf, sizeof(acceptedWbuf));
+  acceptedSocket->flush();
+
+  // Loop
+  evb.loop();
+
+  // The loop should have completed the connection, written the queued data,
+  // and shutdown writes on the socket.
+  //
+  // Check that the connection was completed successfully and that the write
+  // callback succeeded.
+  CHECK_EQ(ccb.state, STATE_SUCCEEDED);
+  CHECK_EQ(wcb.state, STATE_SUCCEEDED);
+
+  // Check that we can read the data that was written to the socket, and that
+  // we see an EOF, since its socket was half-shutdown.
+  uint8_t readbuf[sizeof(wbuf)];
+  acceptedSocket->readAll(readbuf, sizeof(readbuf));
+  CHECK_EQ(memcmp(wbuf, readbuf, sizeof(wbuf)), 0);
+  uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
+  CHECK_EQ(bytesRead, 0);
+
+  // Close the accepted socket.  This will cause it to see EOF
+  // and uninstall the read callback when we loop next.
+  acceptedSocket->close();
+
+  // Install a read callback, then loop again.
+  ReadCallback rcb;
+  socket->setReadCB(&rcb);
+  evb.loop();
+
+  // This loop should have read the data and seen the EOF
+  CHECK_EQ(rcb.state, STATE_SUCCEEDED);
+  CHECK_EQ(rcb.buffers.size(), 1);
+  CHECK_EQ(rcb.buffers[0].length, sizeof(acceptedWbuf));
+  CHECK_EQ(memcmp(rcb.buffers[0].buffer,
+                           acceptedWbuf, sizeof(acceptedWbuf)), 0);
+}
+
+/**
+ * Test reading, writing, and shutting down writes before the connect attempt
+ * finishes.
+ */
+TEST(AsyncSocketTest, ConnectReadWriteAndShutdownWrite) {
+  TestServer server;
+
+  // connect()
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+  ConnCallback ccb;
+  socket->connect(&ccb, server.getAddress(), 30);
+
+  // Hopefully the connect didn't succeed immediately.
+  // If it did, we can't exercise the write-while-connecting code path.
+  if (ccb.state == STATE_SUCCEEDED) {
+    LOG(INFO) << "connect() succeeded immediately; skipping test";
+    return;
+  }
+
+  // Install a read callback
+  ReadCallback rcb;
+  socket->setReadCB(&rcb);
+
+  // Ask to write some data
+  char wbuf[128];
+  memset(wbuf, 'a', sizeof(wbuf));
+  WriteCallback wcb;
+  socket->write(&wcb, wbuf, sizeof(wbuf));
+
+  // Shutdown writes
+  socket->shutdownWrite();
+
+  // Even though we haven't looped yet, we should be able to accept
+  // the connection.
+  std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
+
+  // Since the connection is still in progress, there should be no data to
+  // read yet.  Verify that the accepted socket is not readable.
+  struct pollfd fds[1];
+  fds[0].fd = acceptedSocket->getSocketFD();
+  fds[0].events = POLLIN;
+  fds[0].revents = 0;
+  int rc = poll(fds, 1, 0);
+  CHECK_EQ(rc, 0);
+
+  // Write data to the accepted socket
+  uint8_t acceptedWbuf[192];
+  memset(acceptedWbuf, 'b', sizeof(acceptedWbuf));
+  acceptedSocket->write(acceptedWbuf, sizeof(acceptedWbuf));
+  acceptedSocket->flush();
+  // Shutdown writes to the accepted socket.  This will cause it to see EOF
+  // and uninstall the read callback.
+  ::shutdown(acceptedSocket->getSocketFD(), SHUT_WR);
+
+  // Loop
+  evb.loop();
+
+  // The loop should have completed the connection, written the queued data,
+  // shutdown writes on the socket, read the data we wrote to it, and see the
+  // EOF.
+  //
+  // Check that the connection was completed successfully and that the read
+  // and write callbacks were invoked as expected.
+  CHECK_EQ(ccb.state, STATE_SUCCEEDED);
+  CHECK_EQ(rcb.state, STATE_SUCCEEDED);
+  CHECK_EQ(rcb.buffers.size(), 1);
+  CHECK_EQ(rcb.buffers[0].length, sizeof(acceptedWbuf));
+  CHECK_EQ(memcmp(rcb.buffers[0].buffer,
+                           acceptedWbuf, sizeof(acceptedWbuf)), 0);
+  CHECK_EQ(wcb.state, STATE_SUCCEEDED);
+
+  // Check that we can read the data that was written to the socket, and that
+  // we see an EOF, since its socket was half-shutdown.
+  uint8_t readbuf[sizeof(wbuf)];
+  acceptedSocket->readAll(readbuf, sizeof(readbuf));
+  CHECK_EQ(memcmp(wbuf, readbuf, sizeof(wbuf)), 0);
+  uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
+  CHECK_EQ(bytesRead, 0);
+
+  // Fully close both sockets
+  acceptedSocket->close();
+  socket->close();
+}
+
+/**
+ * Test reading, writing, and calling shutdownWriteNow() before the
+ * connect attempt finishes.
+ */
+TEST(AsyncSocketTest, ConnectReadWriteAndShutdownWriteNow) {
+  TestServer server;
+
+  // connect()
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+  ConnCallback ccb;
+  socket->connect(&ccb, server.getAddress(), 30);
+
+  // Hopefully the connect didn't succeed immediately.
+  // If it did, we can't exercise the write-while-connecting code path.
+  if (ccb.state == STATE_SUCCEEDED) {
+    LOG(INFO) << "connect() succeeded immediately; skipping test";
+    return;
+  }
+
+  // Install a read callback
+  ReadCallback rcb;
+  socket->setReadCB(&rcb);
+
+  // Ask to write some data
+  char wbuf[128];
+  memset(wbuf, 'a', sizeof(wbuf));
+  WriteCallback wcb;
+  socket->write(&wcb, wbuf, sizeof(wbuf));
+
+  // Shutdown writes immediately.
+  // This should immediately discard the data that we just tried to write.
+  socket->shutdownWriteNow();
+
+  // Verify that writeError() was invoked on the write callback.
+  CHECK_EQ(wcb.state, STATE_FAILED);
+  CHECK_EQ(wcb.bytesWritten, 0);
+
+  // Even though we haven't looped yet, we should be able to accept
+  // the connection.
+  std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
+
+  // Since the connection is still in progress, there should be no data to
+  // read yet.  Verify that the accepted socket is not readable.
+  struct pollfd fds[1];
+  fds[0].fd = acceptedSocket->getSocketFD();
+  fds[0].events = POLLIN;
+  fds[0].revents = 0;
+  int rc = poll(fds, 1, 0);
+  CHECK_EQ(rc, 0);
+
+  // Write data to the accepted socket
+  uint8_t acceptedWbuf[192];
+  memset(acceptedWbuf, 'b', sizeof(acceptedWbuf));
+  acceptedSocket->write(acceptedWbuf, sizeof(acceptedWbuf));
+  acceptedSocket->flush();
+  // Shutdown writes to the accepted socket.  This will cause it to see EOF
+  // and uninstall the read callback.
+  ::shutdown(acceptedSocket->getSocketFD(), SHUT_WR);
+
+  // Loop
+  evb.loop();
+
+  // The loop should have completed the connection, written the queued data,
+  // shutdown writes on the socket, read the data we wrote to it, and see the
+  // EOF.
+  //
+  // Check that the connection was completed successfully and that the read
+  // callback was invoked as expected.
+  CHECK_EQ(ccb.state, STATE_SUCCEEDED);
+  CHECK_EQ(rcb.state, STATE_SUCCEEDED);
+  CHECK_EQ(rcb.buffers.size(), 1);
+  CHECK_EQ(rcb.buffers[0].length, sizeof(acceptedWbuf));
+  CHECK_EQ(memcmp(rcb.buffers[0].buffer,
+                           acceptedWbuf, sizeof(acceptedWbuf)), 0);
+
+  // Since we used shutdownWriteNow(), it should have discarded all pending
+  // write data.  Verify we see an immediate EOF when reading from the accepted
+  // socket.
+  uint8_t readbuf[sizeof(wbuf)];
+  uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
+  CHECK_EQ(bytesRead, 0);
+
+  // Fully close both sockets
+  acceptedSocket->close();
+  socket->close();
+}
+
+// Helper function for use in testConnectOptWrite()
+// Temporarily disable the read callback
+void tmpDisableReads(AsyncSocket* socket, ReadCallback* rcb) {
+  // Uninstall the read callback
+  socket->setReadCB(nullptr);
+  // Schedule the read callback to be reinstalled after 1ms
+  socket->getEventBase()->runInLoop(
+      std::bind(&AsyncSocket::setReadCB, socket, rcb));
+}
+
+/**
+ * Test connect+write, then have the connect callback perform another write.
+ *
+ * This tests interaction of the optimistic writing after connect with
+ * additional write attempts that occur in the connect callback.
+ */
+void testConnectOptWrite(size_t size1, size_t size2, bool close = false) {
+  TestServer server;
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+
+  // connect()
+  ConnCallback ccb;
+  socket->connect(&ccb, server.getAddress(), 30);
+
+  // Hopefully the connect didn't succeed immediately.
+  // If it did, we can't exercise the optimistic write code path.
+  if (ccb.state == STATE_SUCCEEDED) {
+    LOG(INFO) << "connect() succeeded immediately; aborting test "
+                       "of optimistic write behavior";
+    return;
+  }
+
+  // Tell the connect callback to perform a write when the connect succeeds
+  WriteCallback wcb2;
+  scoped_array<char> buf2(new char[size2]);
+  memset(buf2.get(), 'b', size2);
+  if (size2 > 0) {
+    ccb.successCallback = [&] { socket->write(&wcb2, buf2.get(), size2); };
+    // Tell the second write callback to close the connection when it is done
+    wcb2.successCallback = [&] { socket->closeNow(); };
+  }
+
+  // Schedule one write() immediately, before the connect finishes
+  scoped_array<char> buf1(new char[size1]);
+  memset(buf1.get(), 'a', size1);
+  WriteCallback wcb1;
+  if (size1 > 0) {
+    socket->write(&wcb1, buf1.get(), size1);
+  }
+
+  if (close) {
+    // immediately perform a close, before connect() completes
+    socket->close();
+  }
+
+  // Start reading from the other endpoint after 10ms.
+  // If we're using large buffers, we have to read so that the writes don't
+  // block forever.
+  std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&evb);
+  ReadCallback rcb;
+  rcb.dataAvailableCallback = std::bind(tmpDisableReads,
+                                        acceptedSocket.get(), &rcb);
+  socket->getEventBase()->tryRunAfterDelay(
+      std::bind(&AsyncSocket::setReadCB, acceptedSocket.get(), &rcb),
+      10);
+
+  // Loop.  We don't bother accepting on the server socket yet.
+  // The kernel should be able to buffer the write request so it can succeed.
+  evb.loop();
+
+  CHECK_EQ(ccb.state, STATE_SUCCEEDED);
+  if (size1 > 0) {
+    CHECK_EQ(wcb1.state, STATE_SUCCEEDED);
+  }
+  if (size2 > 0) {
+    CHECK_EQ(wcb2.state, STATE_SUCCEEDED);
+  }
+
+  socket->close();
+
+  // Make sure the read callback received all of the data
+  size_t bytesRead = 0;
+  for (vector<ReadCallback::Buffer>::const_iterator it = rcb.buffers.begin();
+       it != rcb.buffers.end();
+       ++it) {
+    size_t start = bytesRead;
+    bytesRead += it->length;
+    size_t end = bytesRead;
+    if (start < size1) {
+      size_t cmpLen = min(size1, end) - start;
+      CHECK_EQ(memcmp(it->buffer, buf1.get() + start, cmpLen), 0);
+    }
+    if (end > size1 && end <= size1 + size2) {
+      size_t itOffset;
+      size_t buf2Offset;
+      size_t cmpLen;
+      if (start >= size1) {
+        itOffset = 0;
+        buf2Offset = start - size1;
+        cmpLen = end - start;
+      } else {
+        itOffset = size1 - start;
+        buf2Offset = 0;
+        cmpLen = end - size1;
+      }
+      CHECK_EQ(memcmp(it->buffer + itOffset, buf2.get() + buf2Offset,
+                               cmpLen),
+                        0);
+    }
+  }
+  CHECK_EQ(bytesRead, size1 + size2);
+}
+
+TEST(AsyncSocketTest, ConnectCallbackWrite) {
+  // Test using small writes that should both succeed immediately
+  testConnectOptWrite(100, 200);
+
+  // Test using a large buffer in the connect callback, that should block
+  const size_t largeSize = 8*1024*1024;
+  testConnectOptWrite(100, largeSize);
+
+  // Test using a large initial write
+  testConnectOptWrite(largeSize, 100);
+
+  // Test using two large buffers
+  testConnectOptWrite(largeSize, largeSize);
+
+  // Test a small write in the connect callback,
+  // but no immediate write before connect completes
+  testConnectOptWrite(0, 64);
+
+  // Test a large write in the connect callback,
+  // but no immediate write before connect completes
+  testConnectOptWrite(0, largeSize);
+
+  // Test connect, a small write, then immediately call close() before connect
+  // completes
+  testConnectOptWrite(211, 0, true);
+
+  // Test connect, a large immediate write (that will block), then immediately
+  // call close() before connect completes
+  testConnectOptWrite(largeSize, 0, true);
+}
+
+///////////////////////////////////////////////////////////////////////////
+// write() related tests
+///////////////////////////////////////////////////////////////////////////
+
+/**
+ * Test writing using a nullptr callback
+ */
+TEST(AsyncSocketTest, WriteNullCallback) {
+  TestServer server;
+
+  // connect()
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket =
+    AsyncSocket::newSocket(&evb, server.getAddress(), 30);
+  evb.loop(); // loop until the socket is connected
+
+  // write() with a nullptr callback
+  char buf[128];
+  memset(buf, 'a', sizeof(buf));
+  socket->write(nullptr, buf, sizeof(buf));
+
+  evb.loop(); // loop until the data is sent
+
+  // Make sure the server got a connection and received the data
+  socket->close();
+  server.verifyConnection(buf, sizeof(buf));
+}
+
+/**
+ * Test writing with a send timeout
+ */
+TEST(AsyncSocketTest, WriteTimeout) {
+  TestServer server;
+
+  // connect()
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket =
+    AsyncSocket::newSocket(&evb, server.getAddress(), 30);
+  evb.loop(); // loop until the socket is connected
+
+  // write() a large chunk of data, with no-one on the other end reading
+  size_t writeLength = 8*1024*1024;
+  uint32_t timeout = 200;
+  socket->setSendTimeout(timeout);
+  scoped_array<char> buf(new char[writeLength]);
+  memset(buf.get(), 'a', writeLength);
+  WriteCallback wcb;
+  socket->write(&wcb, buf.get(), writeLength);
+
+  TimePoint start;
+  evb.loop();
+  TimePoint end;
+
+  // Make sure the write attempt timed out as requested
+  CHECK_EQ(wcb.state, STATE_FAILED);
+  CHECK_EQ(wcb.exception.getType(), AsyncSocketException::TIMED_OUT);
+
+  // Check that the write timed out within a reasonable period of time.
+  // We don't check for exactly the specified timeout, since AsyncSocket only
+  // times out when it hasn't made progress for that period of time.
+  //
+  // On linux, the first write sends a few hundred kb of data, then blocks for
+  // writability, and then unblocks again after 40ms and is able to write
+  // another smaller of data before blocking permanently.  Therefore it doesn't
+  // time out until 40ms + timeout.
+  //
+  // I haven't fully verified the cause of this, but I believe it probably
+  // occurs because the receiving end delays sending an ack for up to 40ms.
+  // (This is the default value for TCP_DELACK_MIN.)  Once the sender receives
+  // the ack, it can send some more data.  However, after that point the
+  // receiver's kernel buffer is full.  This 40ms delay happens even with
+  // TCP_NODELAY and TCP_QUICKACK enabled on both endpoints.  However, the
+  // kernel may be automatically disabling TCP_QUICKACK after receiving some
+  // data.
+  //
+  // For now, we simply check that the timeout occurred within 160ms of
+  // the requested value.
+  T_CHECK_TIMEOUT(start, end, milliseconds(timeout), milliseconds(160));
+}
+
+/**
+ * Test writing to a socket that the remote endpoint has closed
+ */
+TEST(AsyncSocketTest, WritePipeError) {
+  TestServer server;
+
+  // connect()
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket =
+    AsyncSocket::newSocket(&evb, server.getAddress(), 30);
+  socket->setSendTimeout(1000);
+  evb.loop(); // loop until the socket is connected
+
+  // accept and immediately close the socket
+  std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
+  acceptedSocket.reset();
+
+  // write() a large chunk of data
+  size_t writeLength = 8*1024*1024;
+  scoped_array<char> buf(new char[writeLength]);
+  memset(buf.get(), 'a', writeLength);
+  WriteCallback wcb;
+  socket->write(&wcb, buf.get(), writeLength);
+
+  evb.loop();
+
+  // Make sure the write failed.
+  // It would be nice if AsyncSocketException could convey the errno value,
+  // so that we could check for EPIPE
+  CHECK_EQ(wcb.state, STATE_FAILED);
+  CHECK_EQ(wcb.exception.getType(),
+                    AsyncSocketException::INTERNAL_ERROR);
+}
+
+/**
+ * Test writing a mix of simple buffers and IOBufs
+ */
+TEST(AsyncSocketTest, WriteIOBuf) {
+  TestServer server;
+
+  // connect()
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+  ConnCallback ccb;
+  socket->connect(&ccb, server.getAddress(), 30);
+
+  // Accept the connection
+  std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&evb);
+  ReadCallback rcb;
+  acceptedSocket->setReadCB(&rcb);
+
+  // Write a simple buffer to the socket
+  size_t simpleBufLength = 5;
+  char simpleBuf[simpleBufLength];
+  memset(simpleBuf, 'a', simpleBufLength);
+  WriteCallback wcb;
+  socket->write(&wcb, simpleBuf, simpleBufLength);
+
+  // Write a single-element IOBuf chain
+  size_t buf1Length = 7;
+  unique_ptr<IOBuf> buf1(IOBuf::create(buf1Length));
+  memset(buf1->writableData(), 'b', buf1Length);
+  buf1->append(buf1Length);
+  unique_ptr<IOBuf> buf1Copy(buf1->clone());
+  WriteCallback wcb2;
+  socket->writeChain(&wcb2, std::move(buf1));
+
+  // Write a multiple-element IOBuf chain
+  size_t buf2Length = 11;
+  unique_ptr<IOBuf> buf2(IOBuf::create(buf2Length));
+  memset(buf2->writableData(), 'c', buf2Length);
+  buf2->append(buf2Length);
+  size_t buf3Length = 13;
+  unique_ptr<IOBuf> buf3(IOBuf::create(buf3Length));
+  memset(buf3->writableData(), 'd', buf3Length);
+  buf3->append(buf3Length);
+  buf2->appendChain(std::move(buf3));
+  unique_ptr<IOBuf> buf2Copy(buf2->clone());
+  buf2Copy->coalesce();
+  WriteCallback wcb3;
+  socket->writeChain(&wcb3, std::move(buf2));
+  socket->shutdownWrite();
+
+  // Let the reads and writes run to completion
+  evb.loop();
+
+  CHECK_EQ(wcb.state, STATE_SUCCEEDED);
+  CHECK_EQ(wcb2.state, STATE_SUCCEEDED);
+  CHECK_EQ(wcb3.state, STATE_SUCCEEDED);
+
+  // Make sure the reader got the right data in the right order
+  CHECK_EQ(rcb.state, STATE_SUCCEEDED);
+  CHECK_EQ(rcb.buffers.size(), 1);
+  CHECK_EQ(rcb.buffers[0].length,
+      simpleBufLength + buf1Length + buf2Length + buf3Length);
+  CHECK_EQ(
+      memcmp(rcb.buffers[0].buffer, simpleBuf, simpleBufLength), 0);
+  CHECK_EQ(
+      memcmp(rcb.buffers[0].buffer + simpleBufLength,
+          buf1Copy->data(), buf1Copy->length()), 0);
+  CHECK_EQ(
+      memcmp(rcb.buffers[0].buffer + simpleBufLength + buf1Length,
+          buf2Copy->data(), buf2Copy->length()), 0);
+
+  acceptedSocket->close();
+  socket->close();
+}
+
+TEST(AsyncSocketTest, WriteIOBufCorked) {
+  TestServer server;
+
+  // connect()
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+  ConnCallback ccb;
+  socket->connect(&ccb, server.getAddress(), 30);
+
+  // Accept the connection
+  std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&evb);
+  ReadCallback rcb;
+  acceptedSocket->setReadCB(&rcb);
+
+  // Do three writes, 100ms apart, with the "cork" flag set
+  // on the second write.  The reader should see the first write
+  // arrive by itself, followed by the second and third writes
+  // arriving together.
+  size_t buf1Length = 5;
+  unique_ptr<IOBuf> buf1(IOBuf::create(buf1Length));
+  memset(buf1->writableData(), 'a', buf1Length);
+  buf1->append(buf1Length);
+  size_t buf2Length = 7;
+  unique_ptr<IOBuf> buf2(IOBuf::create(buf2Length));
+  memset(buf2->writableData(), 'b', buf2Length);
+  buf2->append(buf2Length);
+  size_t buf3Length = 11;
+  unique_ptr<IOBuf> buf3(IOBuf::create(buf3Length));
+  memset(buf3->writableData(), 'c', buf3Length);
+  buf3->append(buf3Length);
+  WriteCallback wcb1;
+  socket->writeChain(&wcb1, std::move(buf1));
+  WriteCallback wcb2;
+  DelayedWrite write2(socket, std::move(buf2), &wcb2, true);
+  write2.scheduleTimeout(100);
+  WriteCallback wcb3;
+  DelayedWrite write3(socket, std::move(buf3), &wcb3, false, true);
+  write3.scheduleTimeout(200);
+
+  evb.loop();
+  CHECK_EQ(ccb.state, STATE_SUCCEEDED);
+  CHECK_EQ(wcb1.state, STATE_SUCCEEDED);
+  CHECK_EQ(wcb2.state, STATE_SUCCEEDED);
+  if (wcb3.state != STATE_SUCCEEDED) {
+    throw(wcb3.exception);
+  }
+  CHECK_EQ(wcb3.state, STATE_SUCCEEDED);
+
+  // Make sure the reader got the data with the right grouping
+  CHECK_EQ(rcb.state, STATE_SUCCEEDED);
+  CHECK_EQ(rcb.buffers.size(), 2);
+  CHECK_EQ(rcb.buffers[0].length, buf1Length);
+  CHECK_EQ(rcb.buffers[1].length, buf2Length + buf3Length);
+
+  acceptedSocket->close();
+  socket->close();
+}
+
+/**
+ * Test performing a zero-length write
+ */
+TEST(AsyncSocketTest, ZeroLengthWrite) {
+  TestServer server;
+
+  // connect()
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket =
+    AsyncSocket::newSocket(&evb, server.getAddress(), 30);
+  evb.loop(); // loop until the socket is connected
+
+  auto acceptedSocket = server.acceptAsync(&evb);
+  ReadCallback rcb;
+  acceptedSocket->setReadCB(&rcb);
+
+  size_t len1 = 1024*1024;
+  size_t len2 = 1024*1024;
+  std::unique_ptr<char[]> buf(new char[len1 + len2]);
+  memset(buf.get(), 'a', len1);
+  memset(buf.get(), 'b', len2);
+
+  WriteCallback wcb1;
+  WriteCallback wcb2;
+  WriteCallback wcb3;
+  WriteCallback wcb4;
+  socket->write(&wcb1, buf.get(), 0);
+  socket->write(&wcb2, buf.get(), len1);
+  socket->write(&wcb3, buf.get() + len1, 0);
+  socket->write(&wcb4, buf.get() + len1, len2);
+  socket->close();
+
+  evb.loop(); // loop until the data is sent
+
+  CHECK_EQ(wcb1.state, STATE_SUCCEEDED);
+  CHECK_EQ(wcb2.state, STATE_SUCCEEDED);
+  CHECK_EQ(wcb3.state, STATE_SUCCEEDED);
+  CHECK_EQ(wcb4.state, STATE_SUCCEEDED);
+  rcb.verifyData(buf.get(), len1 + len2);
+}
+
+TEST(AsyncSocketTest, ZeroLengthWritev) {
+  TestServer server;
+
+  // connect()
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket =
+    AsyncSocket::newSocket(&evb, server.getAddress(), 30);
+  evb.loop(); // loop until the socket is connected
+
+  auto acceptedSocket = server.acceptAsync(&evb);
+  ReadCallback rcb;
+  acceptedSocket->setReadCB(&rcb);
+
+  size_t len1 = 1024*1024;
+  size_t len2 = 1024*1024;
+  std::unique_ptr<char[]> buf(new char[len1 + len2]);
+  memset(buf.get(), 'a', len1);
+  memset(buf.get(), 'b', len2);
+
+  WriteCallback wcb;
+  size_t iovCount = 4;
+  struct iovec iov[iovCount];
+  iov[0].iov_base = buf.get();
+  iov[0].iov_len = len1;
+  iov[1].iov_base = buf.get() + len1;
+  iov[1].iov_len = 0;
+  iov[2].iov_base = buf.get() + len1;
+  iov[2].iov_len = len2;
+  iov[3].iov_base = buf.get() + len1 + len2;
+  iov[3].iov_len = 0;
+
+  socket->writev(&wcb, iov, iovCount);
+  socket->close();
+  evb.loop(); // loop until the data is sent
+
+  CHECK_EQ(wcb.state, STATE_SUCCEEDED);
+  rcb.verifyData(buf.get(), len1 + len2);
+}
+
+///////////////////////////////////////////////////////////////////////////
+// close() related tests
+///////////////////////////////////////////////////////////////////////////
+
+/**
+ * Test calling close() with pending writes when the socket is already closing.
+ */
+TEST(AsyncSocketTest, ClosePendingWritesWhileClosing) {
+  TestServer server;
+
+  // connect()
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+  ConnCallback ccb;
+  socket->connect(&ccb, server.getAddress(), 30);
+
+  // accept the socket on the server side
+  std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
+
+  // Loop to ensure the connect has completed
+  evb.loop();
+
+  // Make sure we are connected
+  CHECK_EQ(ccb.state, STATE_SUCCEEDED);
+
+  // Schedule pending writes, until several write attempts have blocked
+  char buf[128];
+  memset(buf, 'a', sizeof(buf));
+  typedef vector< std::shared_ptr<WriteCallback> > WriteCallbackVector;
+  WriteCallbackVector writeCallbacks;
+
+  writeCallbacks.reserve(5);
+  while (writeCallbacks.size() < 5) {
+    std::shared_ptr<WriteCallback> wcb(new WriteCallback);
+
+    socket->write(wcb.get(), buf, sizeof(buf));
+    if (wcb->state == STATE_SUCCEEDED) {
+      // Succeeded immediately.  Keep performing more writes
+      continue;
+    }
+
+    // This write is blocked.
+    // Have the write callback call close() when writeError() is invoked
+    wcb->errorCallback = std::bind(&AsyncSocket::close, socket.get());
+    writeCallbacks.push_back(wcb);
+  }
+
+  // Call closeNow() to immediately fail the pending writes
+  socket->closeNow();
+
+  // Make sure writeError() was invoked on all of the pending write callbacks
+  for (WriteCallbackVector::const_iterator it = writeCallbacks.begin();
+       it != writeCallbacks.end();
+       ++it) {
+    CHECK_EQ((*it)->state, STATE_FAILED);
+  }
+}
+
+
+// TODO:
+// - Test connect() and have the connect callback set the read callback
+// - Test connect() and have the connect callback unset the read callback
+// - Test reading/writing/closing/destroying the socket in the connect callback
+// - Test reading/writing/closing/destroying the socket in the read callback
+// - Test reading/writing/closing/destroying the socket in the write callback
+// - Test one-way shutdown behavior
+// - Test changing the EventBase
+//
+// - TODO: test multiple threads sharing a AsyncSocket, and detaching from it
+//   in connectSuccess(), readDataAvailable(), writeSuccess()
+
+
+///////////////////////////////////////////////////////////////////////////
+// AsyncServerSocket tests
+///////////////////////////////////////////////////////////////////////////
+
+/**
+ * Helper AcceptCallback class for the test code
+ * It records the callbacks that were invoked, and also supports calling
+ * generic std::function objects in each callback.
+ */
+class TestAcceptCallback : public AsyncServerSocket::AcceptCallback {
+ public:
+  enum EventType {
+    TYPE_START,
+    TYPE_ACCEPT,
+    TYPE_ERROR,
+    TYPE_STOP
+  };
+  struct EventInfo {
+    EventInfo(int fd, const folly::SocketAddress& addr)
+      : type(TYPE_ACCEPT),
+        fd(fd),
+        address(addr),
+        errorMsg() {}
+    explicit EventInfo(const std::string& msg)
+      : type(TYPE_ERROR),
+        fd(-1),
+        address(),
+        errorMsg(msg) {}
+    explicit EventInfo(EventType et)
+      : type(et),
+        fd(-1),
+        address(),
+        errorMsg() {}
+
+    EventType type;
+    int fd;  // valid for TYPE_ACCEPT
+    folly::SocketAddress address;  // valid for TYPE_ACCEPT
+    string errorMsg;  // valid for TYPE_ERROR
+  };
+  typedef std::deque<EventInfo> EventList;
+
+  TestAcceptCallback()
+    : connectionAcceptedFn_(),
+      acceptErrorFn_(),
+      acceptStoppedFn_(),
+      events_() {}
+
+  std::deque<EventInfo>* getEvents() {
+    return &events_;
+  }
+
+  void setConnectionAcceptedFn(
+      const std::function<void(int, const folly::SocketAddress&)>& fn) {
+    connectionAcceptedFn_ = fn;
+  }
+  void setAcceptErrorFn(const std::function<void(const std::exception&)>& fn) {
+    acceptErrorFn_ = fn;
+  }
+  void setAcceptStartedFn(const std::function<void()>& fn) {
+    acceptStartedFn_ = fn;
+  }
+  void setAcceptStoppedFn(const std::function<void()>& fn) {
+    acceptStoppedFn_ = fn;
+  }
+
+  void connectionAccepted(int fd, const folly::SocketAddress& clientAddr)
+      noexcept {
+    events_.push_back(EventInfo(fd, clientAddr));
+
+    if (connectionAcceptedFn_) {
+      connectionAcceptedFn_(fd, clientAddr);
+    }
+  }
+  void acceptError(const std::exception& ex) noexcept {
+    events_.push_back(EventInfo(ex.what()));
+
+    if (acceptErrorFn_) {
+      acceptErrorFn_(ex);
+    }
+  }
+  void acceptStarted() noexcept {
+    events_.push_back(EventInfo(TYPE_START));
+
+    if (acceptStartedFn_) {
+      acceptStartedFn_();
+    }
+  }
+  void acceptStopped() noexcept {
+    events_.push_back(EventInfo(TYPE_STOP));
+
+    if (acceptStoppedFn_) {
+      acceptStoppedFn_();
+    }
+  }
+
+ private:
+  std::function<void(int, const folly::SocketAddress&)> connectionAcceptedFn_;
+  std::function<void(const std::exception&)> acceptErrorFn_;
+  std::function<void()> acceptStartedFn_;
+  std::function<void()> acceptStoppedFn_;
+
+  std::deque<EventInfo> events_;
+};
+
+/**
+ * Make sure accepted sockets have O_NONBLOCK and TCP_NODELAY set
+ */
+TEST(AsyncSocketTest, ServerAcceptOptions) {
+  EventBase eventBase;
+
+  // Create a server socket
+  std::shared_ptr<AsyncServerSocket> serverSocket(
+      AsyncServerSocket::newSocket(&eventBase));
+  serverSocket->bind(0);
+  serverSocket->listen(16);
+  folly::SocketAddress serverAddress;
+  serverSocket->getAddress(&serverAddress);
+
+  // Add a callback to accept one connection then stop the loop
+  TestAcceptCallback acceptCallback;
+  acceptCallback.setConnectionAcceptedFn(
+    [&](int fd, const folly::SocketAddress& addr) {
+      serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
+    });
+  acceptCallback.setAcceptErrorFn([&](const std::exception& ex) {
+    serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
+  });
+  serverSocket->addAcceptCallback(&acceptCallback, nullptr);
+  serverSocket->startAccepting();
+
+  // Connect to the server socket
+  std::shared_ptr<AsyncSocket> socket(
+      AsyncSocket::newSocket(&eventBase, serverAddress));
+
+  eventBase.loop();
+
+  // Verify that the server accepted a connection
+  CHECK_EQ(acceptCallback.getEvents()->size(), 3);
+  CHECK_EQ(acceptCallback.getEvents()->at(0).type,
+                    TestAcceptCallback::TYPE_START);
+  CHECK_EQ(acceptCallback.getEvents()->at(1).type,
+                    TestAcceptCallback::TYPE_ACCEPT);
+  CHECK_EQ(acceptCallback.getEvents()->at(2).type,
+                    TestAcceptCallback::TYPE_STOP);
+  int fd = acceptCallback.getEvents()->at(1).fd;
+
+  // The accepted connection should already be in non-blocking mode
+  int flags = fcntl(fd, F_GETFL, 0);
+  CHECK_EQ(flags & O_NONBLOCK, O_NONBLOCK);
+
+#ifndef TCP_NOPUSH
+  // The accepted connection should already have TCP_NODELAY set
+  int value;
+  socklen_t valueLength = sizeof(value);
+  int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
+  CHECK_EQ(rc, 0);
+  CHECK_EQ(value, 1);
+#endif
+}
+
+/**
+ * Test AsyncServerSocket::removeAcceptCallback()
+ */
+TEST(AsyncSocketTest, RemoveAcceptCallback) {
+  // Create a new AsyncServerSocket
+  EventBase eventBase;
+  std::shared_ptr<AsyncServerSocket> serverSocket(
+      AsyncServerSocket::newSocket(&eventBase));
+  serverSocket->bind(0);
+  serverSocket->listen(16);
+  folly::SocketAddress serverAddress;
+  serverSocket->getAddress(&serverAddress);
+
+  // Add several accept callbacks
+  TestAcceptCallback cb1;
+  TestAcceptCallback cb2;
+  TestAcceptCallback cb3;
+  TestAcceptCallback cb4;
+  TestAcceptCallback cb5;
+  TestAcceptCallback cb6;
+  TestAcceptCallback cb7;
+
+  // Test having callbacks remove other callbacks before them on the list,
+  // after them on the list, or removing themselves.
+  //
+  // Have callback 2 remove callback 3 and callback 5 the first time it is
+  // called.
+  int cb2Count = 0;
+  cb1.setConnectionAcceptedFn([&](int fd, const folly::SocketAddress& addr){
+      std::shared_ptr<AsyncSocket> sock2(
+        AsyncSocket::newSocket(&eventBase, serverAddress)); // cb2: -cb3 -cb5
+      });
+  cb3.setConnectionAcceptedFn([&](int fd, const folly::SocketAddress& addr){
+    });
+  cb4.setConnectionAcceptedFn([&](int fd, const folly::SocketAddress& addr){
+      std::shared_ptr<AsyncSocket> sock3(
+        AsyncSocket::newSocket(&eventBase, serverAddress)); // cb4
+    });
+  cb5.setConnectionAcceptedFn([&](int fd, const folly::SocketAddress& addr){
+  std::shared_ptr<AsyncSocket> sock5(
+      AsyncSocket::newSocket(&eventBase, serverAddress)); // cb7: -cb7
+
+    });
+  cb2.setConnectionAcceptedFn(
+    [&](int fd, const folly::SocketAddress& addr) {
+      if (cb2Count == 0) {
+        serverSocket->removeAcceptCallback(&cb3, nullptr);
+        serverSocket->removeAcceptCallback(&cb5, nullptr);
+      }
+      ++cb2Count;
+    });
+  // Have callback 6 remove callback 4 the first time it is called,
+  // and destroy the server socket the second time it is called
+  int cb6Count = 0;
+  cb6.setConnectionAcceptedFn(
+    [&](int fd, const folly::SocketAddress& addr) {
+      if (cb6Count == 0) {
+        serverSocket->removeAcceptCallback(&cb4, nullptr);
+        std::shared_ptr<AsyncSocket> sock6(
+          AsyncSocket::newSocket(&eventBase, serverAddress)); // cb1
+        std::shared_ptr<AsyncSocket> sock7(
+          AsyncSocket::newSocket(&eventBase, serverAddress)); // cb2
+        std::shared_ptr<AsyncSocket> sock8(
+          AsyncSocket::newSocket(&eventBase, serverAddress)); // cb6: stop
+
+      } else {
+        serverSocket.reset();
+      }
+      ++cb6Count;
+    });
+  // Have callback 7 remove itself
+  cb7.setConnectionAcceptedFn(
+    [&](int fd, const folly::SocketAddress& addr) {
+      serverSocket->removeAcceptCallback(&cb7, nullptr);
+    });
+
+  serverSocket->addAcceptCallback(&cb1, nullptr);
+  serverSocket->addAcceptCallback(&cb2, nullptr);
+  serverSocket->addAcceptCallback(&cb3, nullptr);
+  serverSocket->addAcceptCallback(&cb4, nullptr);
+  serverSocket->addAcceptCallback(&cb5, nullptr);
+  serverSocket->addAcceptCallback(&cb6, nullptr);
+  serverSocket->addAcceptCallback(&cb7, nullptr);
+  serverSocket->startAccepting();
+
+  // Make several connections to the socket
+  std::shared_ptr<AsyncSocket> sock1(
+      AsyncSocket::newSocket(&eventBase, serverAddress)); // cb1
+  std::shared_ptr<AsyncSocket> sock4(
+      AsyncSocket::newSocket(&eventBase, serverAddress)); // cb6: -cb4
+
+  // Loop until we are stopped
+  eventBase.loop();
+
+  // Check to make sure that the expected callbacks were invoked.
+  //
+  // NOTE: This code depends on the AsyncServerSocket operating calling all of
+  // the AcceptCallbacks in round-robin fashion, in the order that they were
+  // added.  The code is implemented this way right now, but the API doesn't
+  // explicitly require it be done this way.  If we change the code not to be
+  // exactly round robin in the future, we can simplify the test checks here.
+  // (We'll also need to update the termination code, since we expect cb6 to
+  // get called twice to terminate the loop.)
+  CHECK_EQ(cb1.getEvents()->size(), 4);
+  CHECK_EQ(cb1.getEvents()->at(0).type,
+                    TestAcceptCallback::TYPE_START);
+  CHECK_EQ(cb1.getEvents()->at(1).type,
+                    TestAcceptCallback::TYPE_ACCEPT);
+  CHECK_EQ(cb1.getEvents()->at(2).type,
+                    TestAcceptCallback::TYPE_ACCEPT);
+  CHECK_EQ(cb1.getEvents()->at(3).type,
+                    TestAcceptCallback::TYPE_STOP);
+
+  CHECK_EQ(cb2.getEvents()->size(), 4);
+  CHECK_EQ(cb2.getEvents()->at(0).type,
+                    TestAcceptCallback::TYPE_START);
+  CHECK_EQ(cb2.getEvents()->at(1).type,
+                    TestAcceptCallback::TYPE_ACCEPT);
+  CHECK_EQ(cb2.getEvents()->at(2).type,
+                    TestAcceptCallback::TYPE_ACCEPT);
+  CHECK_EQ(cb2.getEvents()->at(3).type,
+                    TestAcceptCallback::TYPE_STOP);
+
+  CHECK_EQ(cb3.getEvents()->size(), 2);
+  CHECK_EQ(cb3.getEvents()->at(0).type,
+                    TestAcceptCallback::TYPE_START);
+  CHECK_EQ(cb3.getEvents()->at(1).type,
+                    TestAcceptCallback::TYPE_STOP);
+
+  CHECK_EQ(cb4.getEvents()->size(), 3);
+  CHECK_EQ(cb4.getEvents()->at(0).type,
+                    TestAcceptCallback::TYPE_START);
+  CHECK_EQ(cb4.getEvents()->at(1).type,
+                    TestAcceptCallback::TYPE_ACCEPT);
+  CHECK_EQ(cb4.getEvents()->at(2).type,
+                    TestAcceptCallback::TYPE_STOP);
+
+  CHECK_EQ(cb5.getEvents()->size(), 2);
+  CHECK_EQ(cb5.getEvents()->at(0).type,
+                    TestAcceptCallback::TYPE_START);
+  CHECK_EQ(cb5.getEvents()->at(1).type,
+                    TestAcceptCallback::TYPE_STOP);
+
+  CHECK_EQ(cb6.getEvents()->size(), 4);
+  CHECK_EQ(cb6.getEvents()->at(0).type,
+                    TestAcceptCallback::TYPE_START);
+  CHECK_EQ(cb6.getEvents()->at(1).type,
+                    TestAcceptCallback::TYPE_ACCEPT);
+  CHECK_EQ(cb6.getEvents()->at(2).type,
+                    TestAcceptCallback::TYPE_ACCEPT);
+  CHECK_EQ(cb6.getEvents()->at(3).type,
+                    TestAcceptCallback::TYPE_STOP);
+
+  CHECK_EQ(cb7.getEvents()->size(), 3);
+  CHECK_EQ(cb7.getEvents()->at(0).type,
+                    TestAcceptCallback::TYPE_START);
+  CHECK_EQ(cb7.getEvents()->at(1).type,
+                    TestAcceptCallback::TYPE_ACCEPT);
+  CHECK_EQ(cb7.getEvents()->at(2).type,
+                    TestAcceptCallback::TYPE_STOP);
+}
+
+/**
+ * Test AsyncServerSocket::removeAcceptCallback()
+ */
+TEST(AsyncSocketTest, OtherThreadAcceptCallback) {
+  // Create a new AsyncServerSocket
+  EventBase eventBase;
+  std::shared_ptr<AsyncServerSocket> serverSocket(
+      AsyncServerSocket::newSocket(&eventBase));
+  serverSocket->bind(0);
+  serverSocket->listen(16);
+  folly::SocketAddress serverAddress;
+  serverSocket->getAddress(&serverAddress);
+
+  // Add several accept callbacks
+  TestAcceptCallback cb1;
+  auto thread_id = pthread_self();
+  cb1.setAcceptStartedFn([&](){
+    CHECK_NE(thread_id, pthread_self());
+    thread_id = pthread_self();
+  });
+  cb1.setConnectionAcceptedFn([&](int fd, const folly::SocketAddress& addr){
+    CHECK_EQ(thread_id, pthread_self());
+    serverSocket->removeAcceptCallback(&cb1, nullptr);
+  });
+  cb1.setAcceptStoppedFn([&](){
+    CHECK_EQ(thread_id, pthread_self());
+  });
+
+  // Test having callbacks remove other callbacks before them on the list,
+  serverSocket->addAcceptCallback(&cb1, nullptr);
+  serverSocket->startAccepting();
+
+  // Make several connections to the socket
+  std::shared_ptr<AsyncSocket> sock1(
+      AsyncSocket::newSocket(&eventBase, serverAddress)); // cb1
+
+  // Loop in another thread
+  auto other = std::thread([&](){
+    eventBase.loop();
+  });
+  other.join();
+
+  // Check to make sure that the expected callbacks were invoked.
+  //
+  // NOTE: This code depends on the AsyncServerSocket operating calling all of
+  // the AcceptCallbacks in round-robin fashion, in the order that they were
+  // added.  The code is implemented this way right now, but the API doesn't
+  // explicitly require it be done this way.  If we change the code not to be
+  // exactly round robin in the future, we can simplify the test checks here.
+  // (We'll also need to update the termination code, since we expect cb6 to
+  // get called twice to terminate the loop.)
+  CHECK_EQ(cb1.getEvents()->size(), 3);
+  CHECK_EQ(cb1.getEvents()->at(0).type,
+                    TestAcceptCallback::TYPE_START);
+  CHECK_EQ(cb1.getEvents()->at(1).type,
+                    TestAcceptCallback::TYPE_ACCEPT);
+  CHECK_EQ(cb1.getEvents()->at(2).type,
+                    TestAcceptCallback::TYPE_STOP);
+
+}
+
+void serverSocketSanityTest(AsyncServerSocket* serverSocket) {
+  // Add a callback to accept one connection then stop accepting
+  TestAcceptCallback acceptCallback;
+  acceptCallback.setConnectionAcceptedFn(
+    [&](int fd, const folly::SocketAddress& addr) {
+      serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
+    });
+  acceptCallback.setAcceptErrorFn([&](const std::exception& ex) {
+    serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
+  });
+  serverSocket->addAcceptCallback(&acceptCallback, nullptr);
+  serverSocket->startAccepting();
+
+  // Connect to the server socket
+  EventBase* eventBase = serverSocket->getEventBase();
+  folly::SocketAddress serverAddress;
+  serverSocket->getAddress(&serverAddress);
+  AsyncSocket::UniquePtr socket(new AsyncSocket(eventBase, serverAddress));
+
+  // Loop to process all events
+  eventBase->loop();
+
+  // Verify that the server accepted a connection
+  CHECK_EQ(acceptCallback.getEvents()->size(), 3);
+  CHECK_EQ(acceptCallback.getEvents()->at(0).type,
+                    TestAcceptCallback::TYPE_START);
+  CHECK_EQ(acceptCallback.getEvents()->at(1).type,
+                    TestAcceptCallback::TYPE_ACCEPT);
+  CHECK_EQ(acceptCallback.getEvents()->at(2).type,
+                    TestAcceptCallback::TYPE_STOP);
+}
+
+/* Verify that we don't leak sockets if we are destroyed()
+ * and there are still writes pending
+ *
+ * If destroy() only calls close() instead of closeNow(),
+ * it would shutdown(writes) on the socket, but it would
+ * never be close()'d, and the socket would leak
+ */
+TEST(AsyncSocketTest, DestroyCloseTest) {
+  TestServer server;
+
+  // connect()
+  EventBase clientEB;
+  EventBase serverEB;
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&clientEB);
+  ConnCallback ccb;
+  socket->connect(&ccb, server.getAddress(), 30);
+
+  // Accept the connection
+  std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&serverEB);
+  ReadCallback rcb;
+  acceptedSocket->setReadCB(&rcb);
+
+  // Write a large buffer to the socket that is larger than kernel buffer
+  size_t simpleBufLength = 5000000;
+  char* simpleBuf = new char[simpleBufLength];
+  memset(simpleBuf, 'a', simpleBufLength);
+  WriteCallback wcb;
+
+  // Let the reads and writes run to completion
+  int fd = acceptedSocket->getFd();
+
+  acceptedSocket->write(&wcb, simpleBuf, simpleBufLength);
+  socket.reset();
+  acceptedSocket.reset();
+
+  // Test that server socket was closed
+  ssize_t sz = read(fd, simpleBuf, simpleBufLength);
+  CHECK_EQ(sz, -1);
+  CHECK_EQ(errno, 9);
+  delete[] simpleBuf;
+}
+
+/**
+ * Test AsyncServerSocket::useExistingSocket()
+ */
+TEST(AsyncSocketTest, ServerExistingSocket) {
+  EventBase eventBase;
+
+  // Test creating a socket, and letting AsyncServerSocket bind and listen
+  {
+    // Manually create a socket
+    int fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+    ASSERT_GE(fd, 0);
+
+    // Create a server socket
+    AsyncServerSocket::UniquePtr serverSocket(
+        new AsyncServerSocket(&eventBase));
+    serverSocket->useExistingSocket(fd);
+    folly::SocketAddress address;
+    serverSocket->getAddress(&address);
+    address.setPort(0);
+    serverSocket->bind(address);
+    serverSocket->listen(16);
+
+    // Make sure the socket works
+    serverSocketSanityTest(serverSocket.get());
+  }
+
+  // Test creating a socket and binding manually,
+  // then letting AsyncServerSocket listen
+  {
+    // Manually create a socket
+    int fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+    ASSERT_GE(fd, 0);
+    // bind
+    struct sockaddr_in addr;
+    addr.sin_family = AF_INET;
+    addr.sin_port = 0;
+    addr.sin_addr.s_addr = INADDR_ANY;
+    CHECK_EQ(bind(fd, reinterpret_cast<struct sockaddr*>(&addr),
+                             sizeof(addr)), 0);
+    // Look up the address that we bound to
+    folly::SocketAddress boundAddress;
+    boundAddress.setFromLocalAddress(fd);
+
+    // Create a server socket
+    AsyncServerSocket::UniquePtr serverSocket(
+        new AsyncServerSocket(&eventBase));
+    serverSocket->useExistingSocket(fd);
+    serverSocket->listen(16);
+
+    // Make sure AsyncServerSocket reports the same address that we bound to
+    folly::SocketAddress serverSocketAddress;
+    serverSocket->getAddress(&serverSocketAddress);
+    CHECK_EQ(boundAddress, serverSocketAddress);
+
+    // Make sure the socket works
+    serverSocketSanityTest(serverSocket.get());
+  }
+
+  // Test creating a socket, binding and listening manually,
+  // then giving it to AsyncServerSocket
+  {
+    // Manually create a socket
+    int fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+    ASSERT_GE(fd, 0);
+    // bind
+    struct sockaddr_in addr;
+    addr.sin_family = AF_INET;
+    addr.sin_port = 0;
+    addr.sin_addr.s_addr = INADDR_ANY;
+    CHECK_EQ(bind(fd, reinterpret_cast<struct sockaddr*>(&addr),
+                             sizeof(addr)), 0);
+    // Look up the address that we bound to
+    folly::SocketAddress boundAddress;
+    boundAddress.setFromLocalAddress(fd);
+    // listen
+    CHECK_EQ(listen(fd, 16), 0);
+
+    // Create a server socket
+    AsyncServerSocket::UniquePtr serverSocket(
+        new AsyncServerSocket(&eventBase));
+    serverSocket->useExistingSocket(fd);
+
+    // Make sure AsyncServerSocket reports the same address that we bound to
+    folly::SocketAddress serverSocketAddress;
+    serverSocket->getAddress(&serverSocketAddress);
+    CHECK_EQ(boundAddress, serverSocketAddress);
+
+    // Make sure the socket works
+    serverSocketSanityTest(serverSocket.get());
+  }
+}
+
+TEST(AsyncSocketTest, UnixDomainSocketTest) {
+  EventBase eventBase;
+
+  // Create a server socket
+  std::shared_ptr<AsyncServerSocket> serverSocket(
+      AsyncServerSocket::newSocket(&eventBase));
+  string path(1, 0);
+  path.append("/anonymous");
+  folly::SocketAddress serverAddress;
+  serverAddress.setFromPath(path);
+  serverSocket->bind(serverAddress);
+  serverSocket->listen(16);
+
+  // Add a callback to accept one connection then stop the loop
+  TestAcceptCallback acceptCallback;
+  acceptCallback.setConnectionAcceptedFn(
+    [&](int fd, const folly::SocketAddress& addr) {
+      serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
+    });
+  acceptCallback.setAcceptErrorFn([&](const std::exception& ex) {
+    serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
+  });
+  serverSocket->addAcceptCallback(&acceptCallback, nullptr);
+  serverSocket->startAccepting();
+
+  // Connect to the server socket
+  std::shared_ptr<AsyncSocket> socket(
+      AsyncSocket::newSocket(&eventBase, serverAddress));
+
+  eventBase.loop();
+
+  // Verify that the server accepted a connection
+  CHECK_EQ(acceptCallback.getEvents()->size(), 3);
+  CHECK_EQ(acceptCallback.getEvents()->at(0).type,
+                    TestAcceptCallback::TYPE_START);
+  CHECK_EQ(acceptCallback.getEvents()->at(1).type,
+                    TestAcceptCallback::TYPE_ACCEPT);
+  CHECK_EQ(acceptCallback.getEvents()->at(2).type,
+                    TestAcceptCallback::TYPE_STOP);
+  int fd = acceptCallback.getEvents()->at(1).fd;
+
+  // The accepted connection should already be in non-blocking mode
+  int flags = fcntl(fd, F_GETFL, 0);
+  CHECK_EQ(flags & O_NONBLOCK, O_NONBLOCK);
+}
diff --git a/folly/io/async/test/BlockingSocket.h b/folly/io/async/test/BlockingSocket.h
new file mode 100644 (file)
index 0000000..7d2ee45
--- /dev/null
@@ -0,0 +1,126 @@
+/*
+ * Copyright 2015 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <folly/Optional.h>
+#include <folly/io/async/SSLContext.h>
+#include <folly/io/async/AsyncSocket.h>
+#include <folly/io/async/AsyncSSLSocket.h>
+
+class BlockingSocket : public folly::AsyncSocket::ConnectCallback,
+                       public folly::AsyncTransportWrapper::ReadCallback,
+                       public folly::AsyncTransportWrapper::WriteCallback
+{
+ public:
+  explicit BlockingSocket(int fd)
+    : sock_(new folly::AsyncSocket(&eventBase_, fd)) {
+  }
+
+  BlockingSocket(folly::SocketAddress address,
+                 std::shared_ptr<folly::SSLContext> sslContext)
+    : sock_(sslContext ? new folly::AsyncSSLSocket(sslContext, &eventBase_) :
+            new folly::AsyncSocket(&eventBase_)),
+    address_(address) {}
+
+  void open() {
+    sock_->connect(this, address_);
+    eventBase_.loop();
+    if (err_.hasValue()) {
+      throw err_.value();
+    }
+  }
+  void close() {
+    sock_->close();
+  }
+
+  int32_t write(uint8_t const* buf, size_t len) {
+    sock_->write(this, buf, len);
+    eventBase_.loop();
+    if (err_.hasValue()) {
+      throw err_.value();
+    }
+    return len;
+  }
+
+  void flush() {}
+
+  int32_t readAll(uint8_t *buf, size_t len) {
+    return readHelper(buf, len, true);
+  }
+
+  int32_t read(uint8_t *buf, size_t len) {
+    return readHelper(buf, len, false);
+  }
+
+  int getSocketFD() const {
+    return sock_->getFd();
+  }
+
+ private:
+  folly::EventBase eventBase_;
+  folly::AsyncSocket::UniquePtr sock_;
+  folly::Optional<folly::AsyncSocketException> err_;
+  uint8_t *readBuf_{nullptr};
+  size_t readLen_{0};
+  folly::SocketAddress address_;
+
+  void connectSuccess() noexcept override {}
+  void connectErr(const folly::AsyncSocketException& ex) noexcept override {
+    err_ = ex;
+  }
+  void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+    *bufReturn = readBuf_;
+    *lenReturn = readLen_;
+  }
+  void readDataAvailable(size_t len) noexcept override {
+    readBuf_ += len;
+    readLen_ -= len;
+    if (readLen_ == 0) {
+      sock_->setReadCB(nullptr);
+    }
+  }
+  void readEOF() noexcept override {
+  }
+  void readErr(const folly::AsyncSocketException& ex) noexcept override {
+    err_ = ex;
+  }
+  void writeSuccess() noexcept override {}
+  void writeErr(size_t bytesWritten,
+                const folly::AsyncSocketException& ex) noexcept override {
+    err_ = ex;
+  }
+
+  int32_t readHelper(uint8_t *buf, size_t len, bool all) {
+    readBuf_ = buf;
+    readLen_ = len;
+    sock_->setReadCB(this);
+    while (!err_ && sock_->good() && readLen_ > 0) {
+      eventBase_.loop();
+      if (!all) {
+        break;
+      }
+    }
+    sock_->setReadCB(nullptr);
+    if (err_.hasValue()) {
+      throw err_.value();
+    }
+    if (all && readLen_ > 0) {
+      throw folly::AsyncSocketException(folly::AsyncSocketException::UNKNOWN,
+                                        "eof");
+    }
+    return len - readLen_;
+  }
+};
diff --git a/folly/io/async/test/certs/ca-cert.pem b/folly/io/async/test/certs/ca-cert.pem
new file mode 100644 (file)
index 0000000..1a4f22b
--- /dev/null
@@ -0,0 +1,21 @@
+-----BEGIN CERTIFICATE-----
+MIIDXTCCAkWgAwIBAgIJAKMZICGWUzawMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV
+BAYTAlVTMQ8wDQYDVQQKDAZUaHJpZnQxJTAjBgNVBAMMHFRocmlmdCBDZXJ0aWZp
+Y2F0ZSBBdXRob3JpdHkwHhcNMTQwNTE2MjAyODUyWhcNNDExMDAxMjAyODUyWjBF
+MQswCQYDVQQGEwJVUzEPMA0GA1UECgwGVGhyaWZ0MSUwIwYDVQQDDBxUaHJpZnQg
+Q2VydGlmaWNhdGUgQXV0aG9yaXR5MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB
+CgKCAQEA1Bx2vUvXZ8PrvEBxwdH5qM1F2Xo7UkeC1jzQ+OLUBEcCiEduyStitSvB
+NOAzAGdjt7NmHTP/7OJngp2vzQGjSQzm20XacyTieFUuPBuikUc0Ge3Tf+uQXtiU
+zZPh+xn6arHH+zBWtmUCt3cBrpgRqdnWUsbl8eqo5HsczY781FxQbDoT9VP6A+9R
+KGTsEhxxKbWJ1C7OngwLKc7Zv4DtTC1JFlFyKd8ryDtxP4s/GgsXJkoK0Hkpputr
+cMxMm6OGt77mFvzR2qRY1CpEK/9rjBB6Gqd8GakXsvoOsqL/37k2wVhN/JoS/Pde
+12Mp6TZ2rA8NW8vRujfWU0u55gnQnwIDAQABo1AwTjAdBgNVHQ4EFgQUQ00NGVmY
+NZ6LJg8UQUOVLZX1Gh8wHwYDVR0jBBgwFoAUQ00NGVmYNZ6LJg8UQUOVLZX1Gh8w
+DAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQUFAAOCAQEAdlxt5+z9uXCBr1Wt6r49
+4MmOYw9lOnEOG1JPMRo108TLpmwXEWReCAtjQuR7BitRJW0kJtlO1M6t3qoIh6GA
+sBkgsjQM1xNY3YEpx71MLt1V+JD+2WtSBKMyysj1TiOmIH66kkvXO3ptXzhjhZyX
+G6B+kxLtxrqkn9SJULyN55X8T+dkW28UIBZVLavoREDU+UPrYU9JgZeIVObtGSWi
+DvS4RIJZNjgG3vTrT00rfUGEfTlI54Vbcmv0cYvswP/nMsLtDStCdgI7c/ipyJve
+dfuI4CedjE240AxK5OFxFg/k/IfnB4a5oojbdIR9hKrTU57TPaUVD50Na9WA1aqX
+5Q==
+-----END CERTIFICATE-----
diff --git a/folly/io/async/test/certs/tests-cert.pem b/folly/io/async/test/certs/tests-cert.pem
new file mode 100644 (file)
index 0000000..894ba82
--- /dev/null
@@ -0,0 +1,19 @@
+-----BEGIN CERTIFICATE-----
+MIIDKzCCAhOgAwIBAgIBCjANBgkqhkiG9w0BAQUFADBFMQswCQYDVQQGEwJVUzEP
+MA0GA1UECgwGVGhyaWZ0MSUwIwYDVQQDDBxUaHJpZnQgQ2VydGlmaWNhdGUgQXV0
+aG9yaXR5MB4XDTE0MDUxNjIwMjg1MloXDTQxMTAwMTIwMjg1MlowRjELMAkGA1UE
+BhMCVVMxDTALBgNVBAgTBE9oaW8xETAPBgNVBAcTCEhpbGxpYXJkMRUwEwYDVQQD
+EwxBc294IENvbXBhbnkwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCz
+ZGrJ5XQHAuMYHlBgn32OOc9l0n3RjXccio2ceeWctXkSxDP3vFyZ4kBILF1lsY1g
+o8UTjMkqSDYcytCLK0qavrv9BZRLB9FcpqJ9o4V9feaI/HsHa8DYHyEs8qyNTTNG
+YQ3i4j+AA9iDSpezIYy/tyAOAjrSquUW1jI4tzKTBh8hk8MAMvR2/NPHPkrp4gI+
+EMH6u4vWdr4F9bbriLFWoU04T9mWOMk7G+h8BS9sgINg2+v5cWvl3BC4kLk5L1yJ
+FEyuofSSCEEe6dDf7uVh+RPKa4hEkIYo31AEOPFrN56d+pCj/5l67HTWXoQx3rjy
+dNXMvgU75urm6TQe8dB5AgMBAAGjJTAjMCEGA1UdEQQaMBiHBH8AAAGHEAAAAAAA
+AAAAAAAAAAAAAAEwDQYJKoZIhvcNAQEFBQADggEBAD26XYInaEvlWZJYgtl3yQyC
+3NRQc3LG7XxWg4aFdXCxYLPRAL2HLoarKYH8GPFso57t5xnhA8WfP7iJxmgsKdCS
+0pNIicOWsMmXvYLib0j9tMCFR+a8rn3f4n+clwnqas4w/vWBJUoMgyxtkP8NNNZO
+kIl02JKRhuyiFyPLilVp5tu0e+lmyUER+ak53WjLq2yoytYAlHkzkOpc4MZ/TNt5
+UTEtx/WVlZvlrPi3dsi7QikkjQgo1wCnm7owtuAHlPDMAB8wKk4+vvIOjsGM33T/
+8ffq/4X1HeYM0w0fM+SVlX1rwkXA1RW/jn48VWFHpWbE10+m196OdiToGfm2OJI=
+-----END CERTIFICATE-----
diff --git a/folly/io/async/test/certs/tests-key.pem b/folly/io/async/test/certs/tests-key.pem
new file mode 100644 (file)
index 0000000..caa9052
--- /dev/null
@@ -0,0 +1,27 @@
+-----BEGIN RSA PRIVATE KEY-----
+MIIEpAIBAAKCAQEAs2RqyeV0BwLjGB5QYJ99jjnPZdJ90Y13HIqNnHnlnLV5EsQz
+97xcmeJASCxdZbGNYKPFE4zJKkg2HMrQiytKmr67/QWUSwfRXKaifaOFfX3miPx7
+B2vA2B8hLPKsjU0zRmEN4uI/gAPYg0qXsyGMv7cgDgI60qrlFtYyOLcykwYfIZPD
+ADL0dvzTxz5K6eICPhDB+ruL1na+BfW264ixVqFNOE/ZljjJOxvofAUvbICDYNvr
++XFr5dwQuJC5OS9ciRRMrqH0kghBHunQ3+7lYfkTymuIRJCGKN9QBDjxazeenfqQ
+o/+Zeux01l6EMd648nTVzL4FO+bq5uk0HvHQeQIDAQABAoIBAQCSPcBYindF5/Kd
+jMjVm+9M7I/IYAo1tG9vkvvSngSy9bWXuN7sjF+pCyqAK7qP1mh8acWVJGYx0+BZ
+JHVRnp8Y+3hg0hWL/PmN4EICzjVakjJHZhwddpglF2uCKurD3jV4oFIjrXE6uOfe
+UAbO/wCwoWa+RM8TQkGzljYmyiGufCcXlgEKMNA7TIvbJ9TVx3VTCOQy6EjZ13jd
+M6X7byV/ZOFpZ2H0QV46LvZraw04riXQ/59gVmzizYdI+BwnxxapsCmalTJoV/Y0
+LMI2ylat4PTMVTxPF+ti7Nt+rUkkEx6kuiAgfc+bzE4BSD5X4wy3fdLVLccoxXYw
+4N3fOuQhAoGBAOLrMhiSCrzXGjDWTbPrwzxXDO0qm+wURELi3N5SXIkKUdG2/In6
+wNdpXdvqblOm7SASgPf9KCwUSADrNw6R6nbfrrir5EHg66YydI/OW42QzJKcBUFh
+5Q5na3fvoL/zRhsmh0gEymBg+OIfNel2LY69bl8aAko2y0R1kj7zb8X1AoGBAMph
+9hlnkIBSw60+pKHaOqo2t/kihNNMFyfOgJvh8960eFeMDhMIXgxPUR8yaPX0bBMb
+bCdEJJ2pmq7zUBPvxVJLedwkGMhywElA8yYVh+S6x4Cg+lYo4spIjrHQ/WTvJkHB
+GrDskxdq80lbXjwRd0dPJZkxhKJec1o0n8S03Mn1AoGAGarK5taWGlgmYUHMVj6j
+vc6G6si4DFMaiYpJu2gLiYC+Un9lP2I6r+L+N+LjidjG16rgJazf/2Rn5Jq2hpJg
+uAODKuZekkkTvp/UaXPJDVFEooy9V3DwTNnL4SwcvbmRw35vLOlFzvMJE+K94WN5
+sbyhoGY7vhNGmL7HxREaIoUCgYEAwpteVWFz3yE2ziF9l7FMVh7V23go9zGk1n9I
+xhyJL26khbLEWeLi5L1kiTYlHdUSE3F8F2n8N6s+ddq79t/KA29WV6xSNHW7lvUg
+mk975CMC8hpZfn5ETjVlGXGYJ/Wa+QGiE9z5ODx8gt6cB/DXnLdrtRqbqrJeA7C0
+rScpY/0CgYBCC1QeuAiwWHOqQn3BwsZo9JQBTyT0QvDqLH/F+h9QbXep+4HvyAxG
+nTMNDtGyfyKGDaDUn5hyeU7Oxvzq0K9P+eZD3MjQeaMEg/++GPGUPmDUTqyb2UT8
+5s0NIUobxfKnTD6IpgOIq7ffvVY6cKBMyuLmu/gSvscsbONHjKti3Q==
+-----END RSA PRIVATE KEY-----
diff --git a/folly/io/test/ShutdownSocketSetTest.cpp b/folly/io/test/ShutdownSocketSetTest.cpp
new file mode 100644 (file)
index 0000000..76a598f
--- /dev/null
@@ -0,0 +1,231 @@
+/*
+ * Copyright 2015 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <folly/io/ShutdownSocketSet.h>
+
+#include <atomic>
+#include <chrono>
+#include <thread>
+
+#include <netinet/in.h>
+#include <netinet/tcp.h>
+#include <sys/socket.h>
+
+#include <glog/logging.h>
+#include <gtest/gtest.h>
+
+using folly::ShutdownSocketSet;
+
+namespace folly { namespace test {
+
+ShutdownSocketSet shutdownSocketSet;
+
+class Server {
+ public:
+  Server();
+
+  void stop(bool abortive);
+  void join();
+  int port() const { return port_; }
+  int closeClients(bool abortive);
+
+ private:
+  int acceptSocket_;
+  int port_;
+  enum StopMode {
+    NO_STOP,
+    ORDERLY,
+    ABORTIVE
+  };
+  std::atomic<StopMode> stop_;
+  std::thread serverThread_;
+  std::vector<int> fds_;
+};
+
+Server::Server()
+  : acceptSocket_(-1),
+    port_(0),
+    stop_(NO_STOP) {
+  acceptSocket_ = socket(PF_INET, SOCK_STREAM, 0);
+  CHECK_ERR(acceptSocket_);
+  shutdownSocketSet.add(acceptSocket_);
+
+  sockaddr_in addr;
+  addr.sin_family = AF_INET;
+  addr.sin_port = 0;
+  addr.sin_addr.s_addr = INADDR_ANY;
+  CHECK_ERR(bind(acceptSocket_,
+                 reinterpret_cast<const sockaddr*>(&addr),
+                 sizeof(addr)));
+
+  CHECK_ERR(listen(acceptSocket_, 10));
+
+  socklen_t addrLen = sizeof(addr);
+  CHECK_ERR(getsockname(acceptSocket_,
+                        reinterpret_cast<sockaddr*>(&addr),
+                        &addrLen));
+
+  port_ = ntohs(addr.sin_port);
+
+  serverThread_ = std::thread([this] {
+    while (stop_ == NO_STOP) {
+      sockaddr_in peer;
+      socklen_t peerLen = sizeof(peer);
+      int fd = accept(acceptSocket_,
+                      reinterpret_cast<sockaddr*>(&peer),
+                      &peerLen);
+      if (fd == -1) {
+        if (errno == EINTR) {
+          continue;
+        }
+        if (errno == EINVAL || errno == ENOTSOCK) {  // socket broken
+          break;
+        }
+      }
+      CHECK_ERR(fd);
+      shutdownSocketSet.add(fd);
+      fds_.push_back(fd);
+    }
+
+    if (stop_ != NO_STOP) {
+      closeClients(stop_ == ABORTIVE);
+    }
+
+    shutdownSocketSet.close(acceptSocket_);
+    acceptSocket_ = -1;
+    port_ = 0;
+  });
+}
+
+int Server::closeClients(bool abortive) {
+  for (int fd : fds_) {
+    if (abortive) {
+      struct linger l = {1, 0};
+      CHECK_ERR(setsockopt(fd, SOL_SOCKET, SO_LINGER, &l, sizeof(l)));
+    }
+    shutdownSocketSet.close(fd);
+  }
+  int n = fds_.size();
+  fds_.clear();
+  return n;
+}
+
+void Server::stop(bool abortive) {
+  stop_ = abortive ? ABORTIVE : ORDERLY;
+  shutdown(acceptSocket_, SHUT_RDWR);
+}
+
+void Server::join() {
+  serverThread_.join();
+}
+
+int createConnectedSocket(int port) {
+  int sock = socket(PF_INET, SOCK_STREAM, 0);
+  CHECK_ERR(sock);
+  sockaddr_in addr;
+  addr.sin_family = AF_INET;
+  addr.sin_port = htons(port);
+  addr.sin_addr.s_addr = htonl((127 << 24) | 1);  // XXX
+  CHECK_ERR(connect(sock,
+                    reinterpret_cast<const sockaddr*>(&addr),
+                    sizeof(addr)));
+  return sock;
+}
+
+void runCloseTest(bool abortive) {
+  Server server;
+
+  int sock = createConnectedSocket(server.port());
+
+  std::thread stopper([&server, abortive] {
+    std::this_thread::sleep_for(std::chrono::milliseconds(200));
+    server.stop(abortive);
+    server.join();
+  });
+
+  char c;
+  int r = read(sock, &c, 1);
+  if (abortive) {
+    int e = errno;
+    EXPECT_EQ(-1, r);
+    EXPECT_EQ(ECONNRESET, e);
+  } else {
+    EXPECT_EQ(0, r);
+  }
+
+  close(sock);
+
+  stopper.join();
+
+  EXPECT_EQ(0, server.closeClients(false));  // closed by server when it exited
+}
+
+TEST(ShutdownSocketSetTest, OrderlyClose) {
+  runCloseTest(false);
+}
+
+TEST(ShutdownSocketSetTest, AbortiveClose) {
+  runCloseTest(true);
+}
+
+void runKillTest(bool abortive) {
+  Server server;
+
+  int sock = createConnectedSocket(server.port());
+
+  std::thread killer([&server, abortive] {
+    std::this_thread::sleep_for(std::chrono::milliseconds(200));
+    shutdownSocketSet.shutdownAll(abortive);
+    server.join();
+  });
+
+  char c;
+  int r = read(sock, &c, 1);
+
+  // "abortive" is just a hint for ShutdownSocketSet, so accept both
+  // behaviors
+  if (abortive) {
+    if (r == -1) {
+      EXPECT_EQ(ECONNRESET, errno);
+    } else {
+      EXPECT_EQ(r, 0);
+    }
+  } else {
+    EXPECT_EQ(0, r);
+  }
+
+  close(sock);
+
+  killer.join();
+
+  // NOT closed by server when it exited
+  EXPECT_EQ(1, server.closeClients(false));
+}
+
+TEST(ShutdownSocketSetTest, OrderlyKill) {
+  runKillTest(false);
+}
+
+TEST(ShutdownSocketSetTest, AbortiveKill) {
+  runKillTest(true);
+}
+
+}}  // namespaces
+
+int main(int argc, char *argv[]) {
+  testing::InitGoogleTest(&argc, argv);
+  google::ParseCommandLineFlags(&argc, &argv, true);
+  return RUN_ALL_TESTS();
+}