io/async/Request.h \
io/async/SSLContext.h \
io/async/TimeoutManager.h \
+ io/async/test/AsyncSSLSocketTest.h \
+ io/async/test/BlockingSocket.h \
io/async/test/TimeUtil.h \
io/async/test/UndelayedDestruction.h \
io/async/test/Util.h \
--- /dev/null
+/*
+ * Copyright 2015 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <folly/io/async/test/AsyncSSLSocketTest.h>
+
+#include <signal.h>
+#include <pthread.h>
+
+#include <folly/io/async/AsyncSSLSocket.h>
+#include <folly/io/async/EventBase.h>
+#include <folly/SocketAddress.h>
+
+#include <folly/io/async/test/BlockingSocket.h>
+
+#include <gtest/gtest.h>
+#include <iostream>
+#include <list>
+#include <set>
+#include <unistd.h>
+#include <fcntl.h>
+#include <poll.h>
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <netinet/tcp.h>
+#include <folly/io/Cursor.h>
+
+using std::string;
+using std::vector;
+using std::min;
+using std::cerr;
+using std::endl;
+using std::list;
+
+namespace folly {
+uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
+uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
+uint32_t TestSSLAsyncCacheServer::lookupDelay_ = 0;
+
+const char* testCert = "folly/io/async/test/certs/tests-cert.pem";
+const char* testKey = "folly/io/async/test/certs/tests-key.pem";
+const char* testCA = "folly/io/async/test/certs/ca-cert.pem";
+
+TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase *acb) :
+ctx_(new folly::SSLContext),
+ acb_(acb),
+ socket_(new folly::AsyncServerSocket(&evb_)) {
+ // Set up the SSL context
+ ctx_->loadCertificate(testCert);
+ ctx_->loadPrivateKey(testKey);
+ ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+
+ acb_->ctx_ = ctx_;
+ acb_->base_ = &evb_;
+
+ //set up the listening socket
+ socket_->bind(0);
+ socket_->getAddress(&address_);
+ socket_->listen(100);
+ socket_->addAcceptCallback(acb_, &evb_);
+ socket_->startAccepting();
+
+ int ret = pthread_create(&thread_, nullptr, Main, this);
+ assert(ret == 0);
+
+ std::cerr << "Accepting connections on " << address_ << std::endl;
+}
+
+void getfds(int fds[2]) {
+ if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
+ FAIL() << "failed to create socketpair: " << strerror(errno);
+ }
+ for (int idx = 0; idx < 2; ++idx) {
+ int flags = fcntl(fds[idx], F_GETFL, 0);
+ if (flags == -1) {
+ FAIL() << "failed to get flags for socket " << idx << ": "
+ << strerror(errno);
+ }
+ if (fcntl(fds[idx], F_SETFL, flags | O_NONBLOCK) != 0) {
+ FAIL() << "failed to put socket " << idx << " in non-blocking mode: "
+ << strerror(errno);
+ }
+ }
+}
+
+void getctx(
+ std::shared_ptr<folly::SSLContext> clientCtx,
+ std::shared_ptr<folly::SSLContext> serverCtx) {
+ clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+
+ serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+ serverCtx->loadCertificate(
+ testCert);
+ serverCtx->loadPrivateKey(
+ testKey);
+}
+
+void sslsocketpair(
+ EventBase* eventBase,
+ AsyncSSLSocket::UniquePtr* clientSock,
+ AsyncSSLSocket::UniquePtr* serverSock) {
+ auto clientCtx = std::make_shared<folly::SSLContext>();
+ auto serverCtx = std::make_shared<folly::SSLContext>();
+ int fds[2];
+ getfds(fds);
+ getctx(clientCtx, serverCtx);
+ clientSock->reset(new AsyncSSLSocket(
+ clientCtx, eventBase, fds[0], false));
+ serverSock->reset(new AsyncSSLSocket(
+ serverCtx, eventBase, fds[1], true));
+
+ // (*clientSock)->setSendTimeout(100);
+ // (*serverSock)->setSendTimeout(100);
+}
+
+
+/**
+ * Test connecting to, writing to, reading from, and closing the
+ * connection to the SSL server.
+ */
+TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
+ // Start listening on a local port
+ WriteCallbackBase writeCallback;
+ ReadCallback readCallback(&writeCallback);
+ HandshakeCallback handshakeCallback(&readCallback);
+ SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+ TestSSLServer server(&acceptCallback);
+
+ // Set up SSL context.
+ std::shared_ptr<SSLContext> sslContext(new SSLContext());
+ sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+ //sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
+ //sslContext->authenticate(true, false);
+
+ // connect
+ auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
+ sslContext);
+ socket->open();
+
+ // write()
+ uint8_t buf[128];
+ memset(buf, 'a', sizeof(buf));
+ socket->write(buf, sizeof(buf));
+
+ // read()
+ uint8_t readbuf[128];
+ uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
+ EXPECT_EQ(bytesRead, 128);
+ EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
+
+ // close()
+ socket->close();
+
+ cerr << "ConnectWriteReadClose test completed" << endl;
+}
+
+/**
+ * Negative test for handshakeError().
+ */
+TEST(AsyncSSLSocketTest, HandshakeError) {
+ // Start listening on a local port
+ WriteCallbackBase writeCallback;
+ ReadCallback readCallback(&writeCallback);
+ HandshakeCallback handshakeCallback(&readCallback);
+ HandshakeErrorCallback acceptCallback(&handshakeCallback);
+ TestSSLServer server(&acceptCallback);
+
+ // Set up SSL context.
+ std::shared_ptr<SSLContext> sslContext(new SSLContext());
+ sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+
+ // connect
+ auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
+ sslContext);
+ // read()
+ bool ex = false;
+ try {
+ socket->open();
+
+ uint8_t readbuf[128];
+ uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
+ } catch (AsyncSocketException &e) {
+ ex = true;
+ }
+ EXPECT_TRUE(ex);
+
+ // close()
+ socket->close();
+ cerr << "HandshakeError test completed" << endl;
+}
+
+/**
+ * Negative test for readError().
+ */
+TEST(AsyncSSLSocketTest, ReadError) {
+ // Start listening on a local port
+ WriteCallbackBase writeCallback;
+ ReadErrorCallback readCallback(&writeCallback);
+ HandshakeCallback handshakeCallback(&readCallback);
+ SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+ TestSSLServer server(&acceptCallback);
+
+ // Set up SSL context.
+ std::shared_ptr<SSLContext> sslContext(new SSLContext());
+ sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+
+ // connect
+ auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
+ sslContext);
+ socket->open();
+
+ // write something to trigger ssl handshake
+ uint8_t buf[128];
+ memset(buf, 'a', sizeof(buf));
+ socket->write(buf, sizeof(buf));
+
+ socket->close();
+ cerr << "ReadError test completed" << endl;
+}
+
+/**
+ * Negative test for writeError().
+ */
+TEST(AsyncSSLSocketTest, WriteError) {
+ // Start listening on a local port
+ WriteCallbackBase writeCallback;
+ WriteErrorCallback readCallback(&writeCallback);
+ HandshakeCallback handshakeCallback(&readCallback);
+ SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+ TestSSLServer server(&acceptCallback);
+
+ // Set up SSL context.
+ std::shared_ptr<SSLContext> sslContext(new SSLContext());
+ sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+
+ // connect
+ auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
+ sslContext);
+ socket->open();
+
+ // write something to trigger ssl handshake
+ uint8_t buf[128];
+ memset(buf, 'a', sizeof(buf));
+ socket->write(buf, sizeof(buf));
+
+ socket->close();
+ cerr << "WriteError test completed" << endl;
+}
+
+/**
+ * Test a socket with TCP_NODELAY unset.
+ */
+TEST(AsyncSSLSocketTest, SocketWithDelay) {
+ // Start listening on a local port
+ WriteCallbackBase writeCallback;
+ ReadCallback readCallback(&writeCallback);
+ HandshakeCallback handshakeCallback(&readCallback);
+ SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
+ TestSSLServer server(&acceptCallback);
+
+ // Set up SSL context.
+ std::shared_ptr<SSLContext> sslContext(new SSLContext());
+ sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+
+ // connect
+ auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
+ sslContext);
+ socket->open();
+
+ // write()
+ uint8_t buf[128];
+ memset(buf, 'a', sizeof(buf));
+ socket->write(buf, sizeof(buf));
+
+ // read()
+ uint8_t readbuf[128];
+ uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
+ EXPECT_EQ(bytesRead, 128);
+ EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
+
+ // close()
+ socket->close();
+
+ cerr << "SocketWithDelay test completed" << endl;
+}
+
+TEST(AsyncSSLSocketTest, NpnTestOverlap) {
+ EventBase eventBase;
+ std::shared_ptr<SSLContext> clientCtx(new SSLContext);
+ std::shared_ptr<SSLContext> serverCtx(new SSLContext);;
+ int fds[2];
+ getfds(fds);
+ getctx(clientCtx, serverCtx);
+
+ clientCtx->setAdvertisedNextProtocols({"blub","baz"});
+ serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
+
+ AsyncSSLSocket::UniquePtr clientSock(
+ new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+ AsyncSSLSocket::UniquePtr serverSock(
+ new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
+ NpnClient client(std::move(clientSock));
+ NpnServer server(std::move(serverSock));
+
+ eventBase.loop();
+
+ EXPECT_TRUE(client.nextProtoLength != 0);
+ EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
+ EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
+ server.nextProtoLength), 0);
+ string selected((const char*)client.nextProto, client.nextProtoLength);
+ EXPECT_EQ(selected.compare("baz"), 0);
+}
+
+TEST(AsyncSSLSocketTest, NpnTestUnset) {
+ // Identical to above test, except that we want unset NPN before
+ // looping.
+ EventBase eventBase;
+ std::shared_ptr<SSLContext> clientCtx(new SSLContext);
+ std::shared_ptr<SSLContext> serverCtx(new SSLContext);;
+ int fds[2];
+ getfds(fds);
+ getctx(clientCtx, serverCtx);
+
+ clientCtx->setAdvertisedNextProtocols({"blub","baz"});
+ serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
+
+ AsyncSSLSocket::UniquePtr clientSock(
+ new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+ AsyncSSLSocket::UniquePtr serverSock(
+ new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
+
+ // unsetting NPN for any of [client, server] is enought to make NPN not
+ // work
+ clientCtx->unsetNextProtocols();
+
+ NpnClient client(std::move(clientSock));
+ NpnServer server(std::move(serverSock));
+
+ eventBase.loop();
+
+ EXPECT_TRUE(client.nextProtoLength == 0);
+ EXPECT_TRUE(server.nextProtoLength == 0);
+ EXPECT_TRUE(client.nextProto == nullptr);
+ EXPECT_TRUE(server.nextProto == nullptr);
+}
+
+TEST(AsyncSSLSocketTest, NpnTestNoOverlap) {
+ EventBase eventBase;
+ std::shared_ptr<SSLContext> clientCtx(new SSLContext);
+ std::shared_ptr<SSLContext> serverCtx(new SSLContext);;
+ int fds[2];
+ getfds(fds);
+ getctx(clientCtx, serverCtx);
+
+ clientCtx->setAdvertisedNextProtocols({"blub"});
+ serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
+
+ AsyncSSLSocket::UniquePtr clientSock(
+ new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+ AsyncSSLSocket::UniquePtr serverSock(
+ new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
+ NpnClient client(std::move(clientSock));
+ NpnServer server(std::move(serverSock));
+
+ eventBase.loop();
+
+ EXPECT_TRUE(client.nextProtoLength != 0);
+ EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
+ EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
+ server.nextProtoLength), 0);
+ string selected((const char*)client.nextProto, client.nextProtoLength);
+ EXPECT_EQ(selected.compare("blub"), 0);
+}
+
+TEST(AsyncSSLSocketTest, RandomizedNpnTest) {
+ // Probability that this test will fail is 2^-64, which could be considered
+ // as negligible.
+ const int kTries = 64;
+
+ std::set<string> selectedProtocols;
+ for (int i = 0; i < kTries; ++i) {
+ EventBase eventBase;
+ std::shared_ptr<SSLContext> clientCtx = std::make_shared<SSLContext>();
+ std::shared_ptr<SSLContext> serverCtx = std::make_shared<SSLContext>();
+ int fds[2];
+ getfds(fds);
+ getctx(clientCtx, serverCtx);
+
+ clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"});
+ serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}},
+ {1, {"bar"}}});
+
+
+ AsyncSSLSocket::UniquePtr clientSock(
+ new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+ AsyncSSLSocket::UniquePtr serverSock(
+ new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
+ NpnClient client(std::move(clientSock));
+ NpnServer server(std::move(serverSock));
+
+ eventBase.loop();
+
+ EXPECT_TRUE(client.nextProtoLength != 0);
+ EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
+ EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
+ server.nextProtoLength), 0);
+ string selected((const char*)client.nextProto, client.nextProtoLength);
+ selectedProtocols.insert(selected);
+ }
+ EXPECT_EQ(selectedProtocols.size(), 2);
+}
+
+
+#ifndef OPENSSL_NO_TLSEXT
+/**
+ * 1. Client sends TLSEXT_HOSTNAME in client hello.
+ * 2. Server found a match SSL_CTX and use this SSL_CTX to
+ * continue the SSL handshake.
+ * 3. Server sends back TLSEXT_HOSTNAME in server hello.
+ */
+TEST(AsyncSSLSocketTest, SNITestMatch) {
+ EventBase eventBase;
+ std::shared_ptr<SSLContext> clientCtx(new SSLContext);
+ std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
+ // Use the same SSLContext to continue the handshake after
+ // tlsext_hostname match.
+ std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
+ const std::string serverName("xyz.newdev.facebook.com");
+ int fds[2];
+ getfds(fds);
+ getctx(clientCtx, dfServerCtx);
+
+ AsyncSSLSocket::UniquePtr clientSock(
+ new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
+ AsyncSSLSocket::UniquePtr serverSock(
+ new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
+ SNIClient client(std::move(clientSock));
+ SNIServer server(std::move(serverSock),
+ dfServerCtx,
+ hskServerCtx,
+ serverName);
+
+ eventBase.loop();
+
+ EXPECT_TRUE(client.serverNameMatch);
+ EXPECT_TRUE(server.serverNameMatch);
+}
+
+/**
+ * 1. Client sends TLSEXT_HOSTNAME in client hello.
+ * 2. Server cannot find a matching SSL_CTX and continue to use
+ * the current SSL_CTX to do the handshake.
+ * 3. Server does not send back TLSEXT_HOSTNAME in server hello.
+ */
+TEST(AsyncSSLSocketTest, SNITestNotMatch) {
+ EventBase eventBase;
+ std::shared_ptr<SSLContext> clientCtx(new SSLContext);
+ std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
+ // Use the same SSLContext to continue the handshake after
+ // tlsext_hostname match.
+ std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
+ const std::string clientRequestingServerName("foo.com");
+ const std::string serverExpectedServerName("xyz.newdev.facebook.com");
+
+ int fds[2];
+ getfds(fds);
+ getctx(clientCtx, dfServerCtx);
+
+ AsyncSSLSocket::UniquePtr clientSock(
+ new AsyncSSLSocket(clientCtx,
+ &eventBase,
+ fds[0],
+ clientRequestingServerName));
+ AsyncSSLSocket::UniquePtr serverSock(
+ new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
+ SNIClient client(std::move(clientSock));
+ SNIServer server(std::move(serverSock),
+ dfServerCtx,
+ hskServerCtx,
+ serverExpectedServerName);
+
+ eventBase.loop();
+
+ EXPECT_TRUE(!client.serverNameMatch);
+ EXPECT_TRUE(!server.serverNameMatch);
+}
+
+/**
+ * 1. Client does not send TLSEXT_HOSTNAME in client hello.
+ * 2. Server does not send back TLSEXT_HOSTNAME in server hello.
+ */
+TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) {
+ EventBase eventBase;
+ std::shared_ptr<SSLContext> clientCtx(new SSLContext);
+ std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
+ // Use the same SSLContext to continue the handshake after
+ // tlsext_hostname match.
+ std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
+ const std::string serverExpectedServerName("xyz.newdev.facebook.com");
+
+ int fds[2];
+ getfds(fds);
+ getctx(clientCtx, dfServerCtx);
+
+ AsyncSSLSocket::UniquePtr clientSock(
+ new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+ AsyncSSLSocket::UniquePtr serverSock(
+ new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
+ SNIClient client(std::move(clientSock));
+ SNIServer server(std::move(serverSock),
+ dfServerCtx,
+ hskServerCtx,
+ serverExpectedServerName);
+
+ eventBase.loop();
+
+ EXPECT_TRUE(!client.serverNameMatch);
+ EXPECT_TRUE(!server.serverNameMatch);
+}
+
+#endif
+/**
+ * Test SSL client socket
+ */
+TEST(AsyncSSLSocketTest, SSLClientTest) {
+ // Start listening on a local port
+ WriteCallbackBase writeCallback;
+ ReadCallback readCallback(&writeCallback);
+ HandshakeCallback handshakeCallback(&readCallback);
+ SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
+ TestSSLServer server(&acceptCallback);
+
+ // Set up SSL client
+ EventBase eventBase;
+ std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
+ 1));
+
+ client->connect();
+ EventBaseAborter eba(&eventBase, 3000);
+ eventBase.loop();
+
+ EXPECT_EQ(client->getMiss(), 1);
+ EXPECT_EQ(client->getHit(), 0);
+
+ cerr << "SSLClientTest test completed" << endl;
+}
+
+
+/**
+ * Test SSL client socket session re-use
+ */
+TEST(AsyncSSLSocketTest, SSLClientTestReuse) {
+ // Start listening on a local port
+ WriteCallbackBase writeCallback;
+ ReadCallback readCallback(&writeCallback);
+ HandshakeCallback handshakeCallback(&readCallback);
+ SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
+ TestSSLServer server(&acceptCallback);
+
+ // Set up SSL client
+ EventBase eventBase;
+ std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
+ 10));
+
+ client->connect();
+ EventBaseAborter eba(&eventBase, 3000);
+ eventBase.loop();
+
+ EXPECT_EQ(client->getMiss(), 1);
+ EXPECT_EQ(client->getHit(), 9);
+
+ cerr << "SSLClientTestReuse test completed" << endl;
+}
+
+/**
+ * Test SSL client socket timeout
+ */
+TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
+ // Start listening on a local port
+ EmptyReadCallback readCallback;
+ HandshakeCallback handshakeCallback(&readCallback,
+ HandshakeCallback::EXPECT_ERROR);
+ HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
+ TestSSLServer server(&acceptCallback);
+
+ // Set up SSL client
+ EventBase eventBase;
+ std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
+ 1, 10));
+ client->connect(true /* write before connect completes */);
+ EventBaseAborter eba(&eventBase, 3000);
+ eventBase.loop();
+
+ usleep(100000);
+ // This is checking that the connectError callback precedes any queued
+ // writeError callbacks. This matches AsyncSocket's behavior
+ EXPECT_EQ(client->getWriteAfterConnectErrors(), 1);
+ EXPECT_EQ(client->getErrors(), 1);
+ EXPECT_EQ(client->getMiss(), 0);
+ EXPECT_EQ(client->getHit(), 0);
+
+ cerr << "SSLClientTimeoutTest test completed" << endl;
+}
+
+
+/**
+ * Test SSL server async cache
+ */
+TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTest) {
+ // Start listening on a local port
+ WriteCallbackBase writeCallback;
+ ReadCallback readCallback(&writeCallback);
+ HandshakeCallback handshakeCallback(&readCallback);
+ SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
+ TestSSLAsyncCacheServer server(&acceptCallback);
+
+ // Set up SSL client
+ EventBase eventBase;
+ std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
+ 10, 500));
+
+ client->connect();
+ EventBaseAborter eba(&eventBase, 3000);
+ eventBase.loop();
+
+ EXPECT_EQ(server.getAsyncCallbacks(), 18);
+ EXPECT_EQ(server.getAsyncLookups(), 9);
+ EXPECT_EQ(client->getMiss(), 10);
+ EXPECT_EQ(client->getHit(), 0);
+
+ cerr << "SSLServerAsyncCacheTest test completed" << endl;
+}
+
+
+/**
+ * Test SSL server accept timeout with cache path
+ */
+TEST(AsyncSSLSocketTest, SSLServerTimeoutTest) {
+ // Start listening on a local port
+ WriteCallbackBase writeCallback;
+ ReadCallback readCallback(&writeCallback);
+ EmptyReadCallback clientReadCallback;
+ HandshakeCallback handshakeCallback(&readCallback);
+ SSLServerAcceptCallback acceptCallback(&handshakeCallback, 50);
+ TestSSLAsyncCacheServer server(&acceptCallback);
+
+ // Set up SSL client
+ EventBase eventBase;
+ // only do a TCP connect
+ std::shared_ptr<AsyncSocket> sock = AsyncSocket::newSocket(&eventBase);
+ sock->connect(nullptr, server.getAddress());
+ clientReadCallback.tcpSocket_ = sock;
+ sock->setReadCB(&clientReadCallback);
+
+ EventBaseAborter eba(&eventBase, 3000);
+ eventBase.loop();
+
+ EXPECT_EQ(readCallback.state, STATE_WAITING);
+
+ cerr << "SSLServerTimeoutTest test completed" << endl;
+}
+
+/**
+ * Test SSL server accept timeout with cache path
+ */
+TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTimeoutTest) {
+ // Start listening on a local port
+ WriteCallbackBase writeCallback;
+ ReadCallback readCallback(&writeCallback);
+ HandshakeCallback handshakeCallback(&readCallback);
+ SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback, 50);
+ TestSSLAsyncCacheServer server(&acceptCallback);
+
+ // Set up SSL client
+ EventBase eventBase;
+ std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
+ 2));
+
+ client->connect();
+ EventBaseAborter eba(&eventBase, 3000);
+ eventBase.loop();
+
+ EXPECT_EQ(server.getAsyncCallbacks(), 1);
+ EXPECT_EQ(server.getAsyncLookups(), 1);
+ EXPECT_EQ(client->getErrors(), 1);
+ EXPECT_EQ(client->getMiss(), 1);
+ EXPECT_EQ(client->getHit(), 0);
+
+ cerr << "SSLServerAsyncCacheTimeoutTest test completed" << endl;
+}
+
+/**
+ * Test SSL server accept timeout with cache path
+ */
+TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
+ // Start listening on a local port
+ WriteCallbackBase writeCallback;
+ ReadCallback readCallback(&writeCallback);
+ HandshakeCallback handshakeCallback(&readCallback,
+ HandshakeCallback::EXPECT_ERROR);
+ SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
+ TestSSLAsyncCacheServer server(&acceptCallback, 500);
+
+ // Set up SSL client
+ EventBase eventBase;
+ std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
+ 2, 100));
+
+ client->connect();
+ EventBaseAborter eba(&eventBase, 3000);
+ eventBase.loop();
+
+ server.getEventBase().runInEventBaseThread([&handshakeCallback]{
+ handshakeCallback.closeSocket();});
+ // give time for the cache lookup to come back and find it closed
+ usleep(500000);
+
+ EXPECT_EQ(server.getAsyncCallbacks(), 1);
+ EXPECT_EQ(server.getAsyncLookups(), 1);
+ EXPECT_EQ(client->getErrors(), 1);
+ EXPECT_EQ(client->getMiss(), 1);
+ EXPECT_EQ(client->getHit(), 0);
+
+ cerr << "SSLServerCacheCloseTest test completed" << endl;
+}
+
+/**
+ * Verify Client Ciphers obtained using SSL MSG Callback.
+ */
+TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) {
+ EventBase eventBase;
+ auto clientCtx = std::make_shared<SSLContext>();
+ auto serverCtx = std::make_shared<SSLContext>();
+ serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
+ serverCtx->ciphers("RSA:!SHA:!NULL:!SHA256@STRENGTH");
+ serverCtx->loadPrivateKey(testKey);
+ serverCtx->loadCertificate(testCert);
+ serverCtx->loadTrustedCertificates(testCA);
+ serverCtx->loadClientCAList(testCA);
+
+ clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
+ clientCtx->ciphers("RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5");
+ clientCtx->loadPrivateKey(testKey);
+ clientCtx->loadCertificate(testCert);
+ clientCtx->loadTrustedCertificates(testCA);
+
+ int fds[2];
+ getfds(fds);
+
+ AsyncSSLSocket::UniquePtr clientSock(
+ new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+ AsyncSSLSocket::UniquePtr serverSock(
+ new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
+
+ SSLHandshakeClient client(std::move(clientSock), true, true);
+ SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);
+
+ eventBase.loop();
+
+ EXPECT_EQ(server.clientCiphers_,
+ "RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5:00ff");
+ EXPECT_TRUE(client.handshakeVerify_);
+ EXPECT_TRUE(client.handshakeSuccess_);
+ EXPECT_TRUE(!client.handshakeError_);
+ EXPECT_TRUE(server.handshakeVerify_);
+ EXPECT_TRUE(server.handshakeSuccess_);
+ EXPECT_TRUE(!server.handshakeError_);
+}
+
+TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) {
+ EventBase eventBase;
+ auto ctx = std::make_shared<SSLContext>();
+
+ int fds[2];
+ getfds(fds);
+
+ int bufLen = 42;
+ uint8_t majorVersion = 18;
+ uint8_t minorVersion = 25;
+
+ // Create callback buf
+ auto buf = IOBuf::create(bufLen);
+ buf->append(bufLen);
+ folly::io::RWPrivateCursor cursor(buf.get());
+ cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
+ cursor.write<uint16_t>(0);
+ cursor.write<uint8_t>(38);
+ cursor.write<uint8_t>(majorVersion);
+ cursor.write<uint8_t>(minorVersion);
+ cursor.skip(32);
+ cursor.write<uint32_t>(0);
+
+ SSL* ssl = ctx->createSSL();
+ AsyncSSLSocket::UniquePtr sock(
+ new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
+ sock->enableClientHelloParsing();
+
+ // Test client hello parsing in one packet
+ AsyncSSLSocket::clientHelloParsingCallback(
+ 0, 0, SSL3_RT_HANDSHAKE, buf->data(), buf->length(), ssl, sock.get());
+ buf.reset();
+
+ auto parsedClientHello = sock->getClientHelloInfo();
+ EXPECT_TRUE(parsedClientHello != nullptr);
+ EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
+ EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
+}
+
+TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) {
+ EventBase eventBase;
+ auto ctx = std::make_shared<SSLContext>();
+
+ int fds[2];
+ getfds(fds);
+
+ int bufLen = 42;
+ uint8_t majorVersion = 18;
+ uint8_t minorVersion = 25;
+
+ // Create callback buf
+ auto buf = IOBuf::create(bufLen);
+ buf->append(bufLen);
+ folly::io::RWPrivateCursor cursor(buf.get());
+ cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
+ cursor.write<uint16_t>(0);
+ cursor.write<uint8_t>(38);
+ cursor.write<uint8_t>(majorVersion);
+ cursor.write<uint8_t>(minorVersion);
+ cursor.skip(32);
+ cursor.write<uint32_t>(0);
+
+ SSL* ssl = ctx->createSSL();
+ AsyncSSLSocket::UniquePtr sock(
+ new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
+ sock->enableClientHelloParsing();
+
+ // Test parsing with two packets with first packet size < 3
+ auto bufCopy = folly::IOBuf::copyBuffer(buf->data(), 2);
+ AsyncSSLSocket::clientHelloParsingCallback(
+ 0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
+ ssl, sock.get());
+ bufCopy.reset();
+ bufCopy = folly::IOBuf::copyBuffer(buf->data() + 2, buf->length() - 2);
+ AsyncSSLSocket::clientHelloParsingCallback(
+ 0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
+ ssl, sock.get());
+ bufCopy.reset();
+
+ auto parsedClientHello = sock->getClientHelloInfo();
+ EXPECT_TRUE(parsedClientHello != nullptr);
+ EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
+ EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
+}
+
+TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) {
+ EventBase eventBase;
+ auto ctx = std::make_shared<SSLContext>();
+
+ int fds[2];
+ getfds(fds);
+
+ int bufLen = 42;
+ uint8_t majorVersion = 18;
+ uint8_t minorVersion = 25;
+
+ // Create callback buf
+ auto buf = IOBuf::create(bufLen);
+ buf->append(bufLen);
+ folly::io::RWPrivateCursor cursor(buf.get());
+ cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
+ cursor.write<uint16_t>(0);
+ cursor.write<uint8_t>(38);
+ cursor.write<uint8_t>(majorVersion);
+ cursor.write<uint8_t>(minorVersion);
+ cursor.skip(32);
+ cursor.write<uint32_t>(0);
+
+ SSL* ssl = ctx->createSSL();
+ AsyncSSLSocket::UniquePtr sock(
+ new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
+ sock->enableClientHelloParsing();
+
+ // Test parsing with multiple small packets
+ for (uint64_t i = 0; i < buf->length(); i += 3) {
+ auto bufCopy = folly::IOBuf::copyBuffer(
+ buf->data() + i, std::min((uint64_t)3, buf->length() - i));
+ AsyncSSLSocket::clientHelloParsingCallback(
+ 0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
+ ssl, sock.get());
+ bufCopy.reset();
+ }
+
+ auto parsedClientHello = sock->getClientHelloInfo();
+ EXPECT_TRUE(parsedClientHello != nullptr);
+ EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
+ EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
+}
+
+/**
+ * Verify sucessful behavior of SSL certificate validation.
+ */
+TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) {
+ EventBase eventBase;
+ auto clientCtx = std::make_shared<SSLContext>();
+ auto dfServerCtx = std::make_shared<SSLContext>();
+
+ int fds[2];
+ getfds(fds);
+ getctx(clientCtx, dfServerCtx);
+
+ clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
+ dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
+
+ AsyncSSLSocket::UniquePtr clientSock(
+ new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+ AsyncSSLSocket::UniquePtr serverSock(
+ new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
+
+ SSLHandshakeClient client(std::move(clientSock), true, true);
+ clientCtx->loadTrustedCertificates(testCA);
+
+ SSLHandshakeServer server(std::move(serverSock), true, true);
+
+ eventBase.loop();
+
+ EXPECT_TRUE(client.handshakeVerify_);
+ EXPECT_TRUE(client.handshakeSuccess_);
+ EXPECT_TRUE(!client.handshakeError_);
+ EXPECT_TRUE(!server.handshakeVerify_);
+ EXPECT_TRUE(server.handshakeSuccess_);
+ EXPECT_TRUE(!server.handshakeError_);
+}
+
+/**
+ * Verify that the client's verification callback is able to fail SSL
+ * connection establishment.
+ */
+TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) {
+ EventBase eventBase;
+ auto clientCtx = std::make_shared<SSLContext>();
+ auto dfServerCtx = std::make_shared<SSLContext>();
+
+ int fds[2];
+ getfds(fds);
+ getctx(clientCtx, dfServerCtx);
+
+ clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
+ dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
+
+ AsyncSSLSocket::UniquePtr clientSock(
+ new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+ AsyncSSLSocket::UniquePtr serverSock(
+ new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
+
+ SSLHandshakeClient client(std::move(clientSock), true, false);
+ clientCtx->loadTrustedCertificates(testCA);
+
+ SSLHandshakeServer server(std::move(serverSock), true, true);
+
+ eventBase.loop();
+
+ EXPECT_TRUE(client.handshakeVerify_);
+ EXPECT_TRUE(!client.handshakeSuccess_);
+ EXPECT_TRUE(client.handshakeError_);
+ EXPECT_TRUE(!server.handshakeVerify_);
+ EXPECT_TRUE(!server.handshakeSuccess_);
+ EXPECT_TRUE(server.handshakeError_);
+}
+
+/**
+ * Verify that the options in SSLContext can be overridden in
+ * sslConnect/Accept.i.e specifying that no validation should be performed
+ * allows an otherwise-invalid certificate to be accepted and doesn't fire
+ * the validation callback.
+ */
+TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) {
+ EventBase eventBase;
+ auto clientCtx = std::make_shared<SSLContext>();
+ auto dfServerCtx = std::make_shared<SSLContext>();
+
+ int fds[2];
+ getfds(fds);
+ getctx(clientCtx, dfServerCtx);
+
+ clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
+ dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
+
+ AsyncSSLSocket::UniquePtr clientSock(
+ new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+ AsyncSSLSocket::UniquePtr serverSock(
+ new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
+
+ SSLHandshakeClientNoVerify client(std::move(clientSock), false, false);
+ clientCtx->loadTrustedCertificates(testCA);
+
+ SSLHandshakeServerNoVerify server(std::move(serverSock), false, false);
+
+ eventBase.loop();
+
+ EXPECT_TRUE(!client.handshakeVerify_);
+ EXPECT_TRUE(client.handshakeSuccess_);
+ EXPECT_TRUE(!client.handshakeError_);
+ EXPECT_TRUE(!server.handshakeVerify_);
+ EXPECT_TRUE(server.handshakeSuccess_);
+ EXPECT_TRUE(!server.handshakeError_);
+}
+
+/**
+ * Verify that the options in SSLContext can be overridden in
+ * sslConnect/Accept. Enable verification even if context says otherwise.
+ * Test requireClientCert with client cert
+ */
+TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) {
+ EventBase eventBase;
+ auto clientCtx = std::make_shared<SSLContext>();
+ auto serverCtx = std::make_shared<SSLContext>();
+ serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
+ serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+ serverCtx->loadPrivateKey(testKey);
+ serverCtx->loadCertificate(testCert);
+ serverCtx->loadTrustedCertificates(testCA);
+ serverCtx->loadClientCAList(testCA);
+
+ clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
+ clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+ clientCtx->loadPrivateKey(testKey);
+ clientCtx->loadCertificate(testCert);
+ clientCtx->loadTrustedCertificates(testCA);
+
+ int fds[2];
+ getfds(fds);
+
+ AsyncSSLSocket::UniquePtr clientSock(
+ new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+ AsyncSSLSocket::UniquePtr serverSock(
+ new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
+
+ SSLHandshakeClientDoVerify client(std::move(clientSock), true, true);
+ SSLHandshakeServerDoVerify server(std::move(serverSock), true, true);
+
+ eventBase.loop();
+
+ EXPECT_TRUE(client.handshakeVerify_);
+ EXPECT_TRUE(client.handshakeSuccess_);
+ EXPECT_FALSE(client.handshakeError_);
+ EXPECT_TRUE(server.handshakeVerify_);
+ EXPECT_TRUE(server.handshakeSuccess_);
+ EXPECT_FALSE(server.handshakeError_);
+}
+
+/**
+ * Verify that the client's verification callback is able to override
+ * the preverification failure and allow a successful connection.
+ */
+TEST(AsyncSSLSocketTest, SSLHandshakeValidationOverride) {
+ EventBase eventBase;
+ auto clientCtx = std::make_shared<SSLContext>();
+ auto dfServerCtx = std::make_shared<SSLContext>();
+
+ int fds[2];
+ getfds(fds);
+ getctx(clientCtx, dfServerCtx);
+
+ clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
+ dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
+
+ AsyncSSLSocket::UniquePtr clientSock(
+ new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+ AsyncSSLSocket::UniquePtr serverSock(
+ new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
+
+ SSLHandshakeClient client(std::move(clientSock), false, true);
+ SSLHandshakeServer server(std::move(serverSock), true, true);
+
+ eventBase.loop();
+
+ EXPECT_TRUE(client.handshakeVerify_);
+ EXPECT_TRUE(client.handshakeSuccess_);
+ EXPECT_TRUE(!client.handshakeError_);
+ EXPECT_TRUE(!server.handshakeVerify_);
+ EXPECT_TRUE(server.handshakeSuccess_);
+ EXPECT_TRUE(!server.handshakeError_);
+}
+
+/**
+ * Verify that specifying that no validation should be performed allows an
+ * otherwise-invalid certificate to be accepted and doesn't fire the validation
+ * callback.
+ */
+TEST(AsyncSSLSocketTest, SSLHandshakeValidationSkip) {
+ EventBase eventBase;
+ auto clientCtx = std::make_shared<SSLContext>();
+ auto dfServerCtx = std::make_shared<SSLContext>();
+
+ int fds[2];
+ getfds(fds);
+ getctx(clientCtx, dfServerCtx);
+
+ clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
+ dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
+
+ AsyncSSLSocket::UniquePtr clientSock(
+ new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+ AsyncSSLSocket::UniquePtr serverSock(
+ new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
+
+ SSLHandshakeClient client(std::move(clientSock), false, false);
+ SSLHandshakeServer server(std::move(serverSock), false, false);
+
+ eventBase.loop();
+
+ EXPECT_TRUE(!client.handshakeVerify_);
+ EXPECT_TRUE(client.handshakeSuccess_);
+ EXPECT_TRUE(!client.handshakeError_);
+ EXPECT_TRUE(!server.handshakeVerify_);
+ EXPECT_TRUE(server.handshakeSuccess_);
+ EXPECT_TRUE(!server.handshakeError_);
+}
+
+/**
+ * Test requireClientCert with client cert
+ */
+TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
+ EventBase eventBase;
+ auto clientCtx = std::make_shared<SSLContext>();
+ auto serverCtx = std::make_shared<SSLContext>();
+ serverCtx->setVerificationOption(
+ SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
+ serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+ serverCtx->loadPrivateKey(testKey);
+ serverCtx->loadCertificate(testCert);
+ serverCtx->loadTrustedCertificates(testCA);
+ serverCtx->loadClientCAList(testCA);
+
+ clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
+ clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+ clientCtx->loadPrivateKey(testKey);
+ clientCtx->loadCertificate(testCert);
+ clientCtx->loadTrustedCertificates(testCA);
+
+ int fds[2];
+ getfds(fds);
+
+ AsyncSSLSocket::UniquePtr clientSock(
+ new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+ AsyncSSLSocket::UniquePtr serverSock(
+ new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
+
+ SSLHandshakeClient client(std::move(clientSock), true, true);
+ SSLHandshakeServer server(std::move(serverSock), true, true);
+
+ eventBase.loop();
+
+ EXPECT_TRUE(client.handshakeVerify_);
+ EXPECT_TRUE(client.handshakeSuccess_);
+ EXPECT_FALSE(client.handshakeError_);
+ EXPECT_TRUE(server.handshakeVerify_);
+ EXPECT_TRUE(server.handshakeSuccess_);
+ EXPECT_FALSE(server.handshakeError_);
+}
+
+
+/**
+ * Test requireClientCert with no client cert
+ */
+TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
+ EventBase eventBase;
+ auto clientCtx = std::make_shared<SSLContext>();
+ auto serverCtx = std::make_shared<SSLContext>();
+ serverCtx->setVerificationOption(
+ SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
+ serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+ serverCtx->loadPrivateKey(testKey);
+ serverCtx->loadCertificate(testCert);
+ serverCtx->loadTrustedCertificates(testCA);
+ serverCtx->loadClientCAList(testCA);
+ clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
+ clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+
+ int fds[2];
+ getfds(fds);
+
+ AsyncSSLSocket::UniquePtr clientSock(
+ new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+ AsyncSSLSocket::UniquePtr serverSock(
+ new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
+
+ SSLHandshakeClient client(std::move(clientSock), false, false);
+ SSLHandshakeServer server(std::move(serverSock), false, false);
+
+ eventBase.loop();
+
+ EXPECT_FALSE(server.handshakeVerify_);
+ EXPECT_FALSE(server.handshakeSuccess_);
+ EXPECT_TRUE(server.handshakeError_);
+}
+}
+
+///////////////////////////////////////////////////////////////////////////
+// init_unit_test_suite
+///////////////////////////////////////////////////////////////////////////
+namespace {
+struct Initializer {
+ Initializer() {
+ signal(SIGPIPE, SIG_IGN);
+ }
+};
+Initializer initializer;
+} // anonymous
--- /dev/null
+/*
+ * Copyright 2015 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <signal.h>
+#include <pthread.h>
+
+#include <folly/io/async/AsyncServerSocket.h>
+#include <folly/io/async/AsyncSSLSocket.h>
+#include <folly/io/async/AsyncSocket.h>
+#include <folly/io/async/AsyncTransport.h>
+#include <folly/io/async/EventBase.h>
+#include <folly/io/async/AsyncTimeout.h>
+#include <folly/SocketAddress.h>
+
+#include <gtest/gtest.h>
+#include <iostream>
+#include <list>
+#include <unistd.h>
+#include <fcntl.h>
+#include <poll.h>
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <netinet/tcp.h>
+
+namespace folly {
+
+enum StateEnum {
+ STATE_WAITING,
+ STATE_SUCCEEDED,
+ STATE_FAILED
+};
+
+// The destructors of all callback classes assert that the state is
+// STATE_SUCCEEDED, for both possitive and negative tests. The tests
+// are responsible for setting the succeeded state properly before the
+// destructors are called.
+
+class WriteCallbackBase :
+public AsyncTransportWrapper::WriteCallback {
+public:
+ WriteCallbackBase()
+ : state(STATE_WAITING)
+ , bytesWritten(0)
+ , exception(AsyncSocketException::UNKNOWN, "none") {}
+
+ ~WriteCallbackBase() {
+ EXPECT_EQ(state, STATE_SUCCEEDED);
+ }
+
+ void setSocket(
+ const std::shared_ptr<AsyncSSLSocket> &socket) {
+ socket_ = socket;
+ }
+
+ void writeSuccess() noexcept override {
+ std::cerr << "writeSuccess" << std::endl;
+ state = STATE_SUCCEEDED;
+ }
+
+ void writeErr(
+ size_t bytesWritten,
+ const AsyncSocketException& ex) noexcept override {
+ std::cerr << "writeError: bytesWritten " << bytesWritten
+ << ", exception " << ex.what() << std::endl;
+
+ state = STATE_FAILED;
+ this->bytesWritten = bytesWritten;
+ exception = ex;
+ socket_->close();
+ socket_->detachEventBase();
+ }
+
+ std::shared_ptr<AsyncSSLSocket> socket_;
+ StateEnum state;
+ size_t bytesWritten;
+ AsyncSocketException exception;
+};
+
+class ReadCallbackBase :
+public AsyncTransportWrapper::ReadCallback {
+public:
+ explicit ReadCallbackBase(WriteCallbackBase *wcb)
+ : wcb_(wcb)
+ , state(STATE_WAITING) {}
+
+ ~ReadCallbackBase() {
+ EXPECT_EQ(state, STATE_SUCCEEDED);
+ }
+
+ void setSocket(
+ const std::shared_ptr<AsyncSSLSocket> &socket) {
+ socket_ = socket;
+ }
+
+ void setState(StateEnum s) {
+ state = s;
+ if (wcb_) {
+ wcb_->state = s;
+ }
+ }
+
+ void readErr(
+ const AsyncSocketException& ex) noexcept override {
+ std::cerr << "readError " << ex.what() << std::endl;
+ state = STATE_FAILED;
+ socket_->close();
+ socket_->detachEventBase();
+ }
+
+ void readEOF() noexcept override {
+ std::cerr << "readEOF" << std::endl;
+
+ socket_->close();
+ socket_->detachEventBase();
+ }
+
+ std::shared_ptr<AsyncSSLSocket> socket_;
+ WriteCallbackBase *wcb_;
+ StateEnum state;
+};
+
+class ReadCallback : public ReadCallbackBase {
+public:
+ explicit ReadCallback(WriteCallbackBase *wcb)
+ : ReadCallbackBase(wcb)
+ , buffers() {}
+
+ ~ReadCallback() {
+ for (std::vector<Buffer>::iterator it = buffers.begin();
+ it != buffers.end();
+ ++it) {
+ it->free();
+ }
+ currentBuffer.free();
+ }
+
+ void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+ if (!currentBuffer.buffer) {
+ currentBuffer.allocate(4096);
+ }
+ *bufReturn = currentBuffer.buffer;
+ *lenReturn = currentBuffer.length;
+ }
+
+ void readDataAvailable(size_t len) noexcept override {
+ std::cerr << "readDataAvailable, len " << len << std::endl;
+
+ currentBuffer.length = len;
+
+ wcb_->setSocket(socket_);
+
+ // Write back the same data.
+ socket_->write(wcb_, currentBuffer.buffer, len);
+
+ buffers.push_back(currentBuffer);
+ currentBuffer.reset();
+ state = STATE_SUCCEEDED;
+ }
+
+ class Buffer {
+ public:
+ Buffer() : buffer(nullptr), length(0) {}
+ Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
+
+ void reset() {
+ buffer = nullptr;
+ length = 0;
+ }
+ void allocate(size_t length) {
+ assert(buffer == nullptr);
+ this->buffer = static_cast<char*>(malloc(length));
+ this->length = length;
+ }
+ void free() {
+ ::free(buffer);
+ reset();
+ }
+
+ char* buffer;
+ size_t length;
+ };
+
+ std::vector<Buffer> buffers;
+ Buffer currentBuffer;
+};
+
+class ReadErrorCallback : public ReadCallbackBase {
+public:
+ explicit ReadErrorCallback(WriteCallbackBase *wcb)
+ : ReadCallbackBase(wcb) {}
+
+ // Return nullptr buffer to trigger readError()
+ void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+ *bufReturn = nullptr;
+ *lenReturn = 0;
+ }
+
+ void readDataAvailable(size_t len) noexcept override {
+ // This should never to called.
+ FAIL();
+ }
+
+ void readErr(
+ const AsyncSocketException& ex) noexcept override {
+ ReadCallbackBase::readErr(ex);
+ std::cerr << "ReadErrorCallback::readError" << std::endl;
+ setState(STATE_SUCCEEDED);
+ }
+};
+
+class WriteErrorCallback : public ReadCallback {
+public:
+ explicit WriteErrorCallback(WriteCallbackBase *wcb)
+ : ReadCallback(wcb) {}
+
+ void readDataAvailable(size_t len) noexcept override {
+ std::cerr << "readDataAvailable, len " << len << std::endl;
+
+ currentBuffer.length = len;
+
+ // close the socket before writing to trigger writeError().
+ ::close(socket_->getFd());
+
+ wcb_->setSocket(socket_);
+
+ // Write back the same data.
+ socket_->write(wcb_, currentBuffer.buffer, len);
+
+ if (wcb_->state == STATE_FAILED) {
+ setState(STATE_SUCCEEDED);
+ } else {
+ state = STATE_FAILED;
+ }
+
+ buffers.push_back(currentBuffer);
+ currentBuffer.reset();
+ }
+
+ void readErr(const AsyncSocketException& ex) noexcept override {
+ std::cerr << "readError " << ex.what() << std::endl;
+ // do nothing since this is expected
+ }
+};
+
+class EmptyReadCallback : public ReadCallback {
+public:
+ explicit EmptyReadCallback()
+ : ReadCallback(nullptr) {}
+
+ void readErr(const AsyncSocketException& ex) noexcept override {
+ std::cerr << "readError " << ex.what() << std::endl;
+ state = STATE_FAILED;
+ tcpSocket_->close();
+ tcpSocket_->detachEventBase();
+ }
+
+ void readEOF() noexcept override {
+ std::cerr << "readEOF" << std::endl;
+
+ tcpSocket_->close();
+ tcpSocket_->detachEventBase();
+ state = STATE_SUCCEEDED;
+ }
+
+ std::shared_ptr<AsyncSocket> tcpSocket_;
+};
+
+class HandshakeCallback :
+public AsyncSSLSocket::HandshakeCB {
+public:
+ enum ExpectType {
+ EXPECT_SUCCESS,
+ EXPECT_ERROR
+ };
+
+ explicit HandshakeCallback(ReadCallbackBase *rcb,
+ ExpectType expect = EXPECT_SUCCESS):
+ state(STATE_WAITING),
+ rcb_(rcb),
+ expect_(expect) {}
+
+ void setSocket(
+ const std::shared_ptr<AsyncSSLSocket> &socket) {
+ socket_ = socket;
+ }
+
+ void setState(StateEnum s) {
+ state = s;
+ rcb_->setState(s);
+ }
+
+ // Functions inherited from AsyncSSLSocketHandshakeCallback
+ void handshakeSuc(AsyncSSLSocket *sock) noexcept override {
+ EXPECT_EQ(sock, socket_.get());
+ std::cerr << "HandshakeCallback::connectionAccepted" << std::endl;
+ rcb_->setSocket(socket_);
+ sock->setReadCB(rcb_);
+ state = (expect_ == EXPECT_SUCCESS) ? STATE_SUCCEEDED : STATE_FAILED;
+ }
+ void handshakeErr(
+ AsyncSSLSocket *sock,
+ const AsyncSocketException& ex) noexcept override {
+ std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl;
+ state = (expect_ == EXPECT_ERROR) ? STATE_SUCCEEDED : STATE_FAILED;
+ if (expect_ == EXPECT_ERROR) {
+ // rcb will never be invoked
+ rcb_->setState(STATE_SUCCEEDED);
+ }
+ }
+
+ ~HandshakeCallback() {
+ EXPECT_EQ(state, STATE_SUCCEEDED);
+ }
+
+ void closeSocket() {
+ socket_->close();
+ state = STATE_SUCCEEDED;
+ }
+
+ StateEnum state;
+ std::shared_ptr<AsyncSSLSocket> socket_;
+ ReadCallbackBase *rcb_;
+ ExpectType expect_;
+};
+
+class SSLServerAcceptCallbackBase:
+public folly::AsyncServerSocket::AcceptCallback {
+public:
+ explicit SSLServerAcceptCallbackBase(HandshakeCallback *hcb):
+ state(STATE_WAITING), hcb_(hcb) {}
+
+ ~SSLServerAcceptCallbackBase() {
+ EXPECT_EQ(state, STATE_SUCCEEDED);
+ }
+
+ void acceptError(const std::exception& ex) noexcept override {
+ std::cerr << "SSLServerAcceptCallbackBase::acceptError "
+ << ex.what() << std::endl;
+ state = STATE_FAILED;
+ }
+
+ void connectionAccepted(int fd, const folly::SocketAddress& clientAddr)
+ noexcept override{
+ printf("Connection accepted\n");
+ std::shared_ptr<AsyncSSLSocket> sslSock;
+ try {
+ // Create a AsyncSSLSocket object with the fd. The socket should be
+ // added to the event base and in the state of accepting SSL connection.
+ sslSock = AsyncSSLSocket::newSocket(ctx_, base_, fd);
+ } catch (const std::exception &e) {
+ LOG(ERROR) << "Exception %s caught while creating a AsyncSSLSocket "
+ "object with socket " << e.what() << fd;
+ ::close(fd);
+ acceptError(e);
+ return;
+ }
+
+ connAccepted(sslSock);
+ }
+
+ virtual void connAccepted(
+ const std::shared_ptr<folly::AsyncSSLSocket> &s) = 0;
+
+ StateEnum state;
+ HandshakeCallback *hcb_;
+ std::shared_ptr<folly::SSLContext> ctx_;
+ folly::EventBase* base_;
+};
+
+class SSLServerAcceptCallback: public SSLServerAcceptCallbackBase {
+public:
+ uint32_t timeout_;
+
+ explicit SSLServerAcceptCallback(HandshakeCallback *hcb,
+ uint32_t timeout = 0):
+ SSLServerAcceptCallbackBase(hcb),
+ timeout_(timeout) {}
+
+ virtual ~SSLServerAcceptCallback() {
+ if (timeout_ > 0) {
+ // if we set a timeout, we expect failure
+ EXPECT_EQ(hcb_->state, STATE_FAILED);
+ hcb_->setState(STATE_SUCCEEDED);
+ }
+ }
+
+ // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
+ void connAccepted(
+ const std::shared_ptr<folly::AsyncSSLSocket> &s)
+ noexcept override {
+ auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
+ std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
+
+ hcb_->setSocket(sock);
+ sock->sslAccept(hcb_, timeout_);
+ EXPECT_EQ(sock->getSSLState(),
+ AsyncSSLSocket::STATE_ACCEPTING);
+
+ state = STATE_SUCCEEDED;
+ }
+};
+
+class SSLServerAcceptCallbackDelay: public SSLServerAcceptCallback {
+public:
+ explicit SSLServerAcceptCallbackDelay(HandshakeCallback *hcb):
+ SSLServerAcceptCallback(hcb) {}
+
+ // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
+ void connAccepted(
+ const std::shared_ptr<folly::AsyncSSLSocket> &s)
+ noexcept override {
+
+ auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
+
+ std::cerr << "SSLServerAcceptCallbackDelay::connAccepted"
+ << std::endl;
+ int fd = sock->getFd();
+
+#ifndef TCP_NOPUSH
+ {
+ // The accepted connection should already have TCP_NODELAY set
+ int value;
+ socklen_t valueLength = sizeof(value);
+ int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
+ EXPECT_EQ(rc, 0);
+ EXPECT_EQ(value, 1);
+ }
+#endif
+
+ // Unset the TCP_NODELAY option.
+ int value = 0;
+ socklen_t valueLength = sizeof(value);
+ int rc = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, valueLength);
+ EXPECT_EQ(rc, 0);
+
+ rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
+ EXPECT_EQ(rc, 0);
+ EXPECT_EQ(value, 0);
+
+ SSLServerAcceptCallback::connAccepted(sock);
+ }
+};
+
+class SSLServerAsyncCacheAcceptCallback: public SSLServerAcceptCallback {
+public:
+ explicit SSLServerAsyncCacheAcceptCallback(HandshakeCallback *hcb,
+ uint32_t timeout = 0):
+ SSLServerAcceptCallback(hcb, timeout) {}
+
+ // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
+ void connAccepted(
+ const std::shared_ptr<folly::AsyncSSLSocket> &s)
+ noexcept override {
+ auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
+
+ std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
+
+ hcb_->setSocket(sock);
+ sock->sslAccept(hcb_, timeout_);
+ ASSERT_TRUE((sock->getSSLState() ==
+ AsyncSSLSocket::STATE_ACCEPTING) ||
+ (sock->getSSLState() ==
+ AsyncSSLSocket::STATE_CACHE_LOOKUP));
+
+ state = STATE_SUCCEEDED;
+ }
+};
+
+
+class HandshakeErrorCallback: public SSLServerAcceptCallbackBase {
+public:
+ explicit HandshakeErrorCallback(HandshakeCallback *hcb):
+ SSLServerAcceptCallbackBase(hcb) {}
+
+ // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
+ void connAccepted(
+ const std::shared_ptr<folly::AsyncSSLSocket> &s)
+ noexcept override {
+ auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
+
+ std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
+
+ // The first call to sslAccept() should succeed.
+ hcb_->setSocket(sock);
+ sock->sslAccept(hcb_);
+ EXPECT_EQ(sock->getSSLState(),
+ AsyncSSLSocket::STATE_ACCEPTING);
+
+ // The second call to sslAccept() should fail.
+ HandshakeCallback callback2(hcb_->rcb_);
+ callback2.setSocket(sock);
+ sock->sslAccept(&callback2);
+ EXPECT_EQ(sock->getSSLState(),
+ AsyncSSLSocket::STATE_ERROR);
+
+ // Both callbacks should be in the error state.
+ EXPECT_EQ(hcb_->state, STATE_FAILED);
+ EXPECT_EQ(callback2.state, STATE_FAILED);
+
+ sock->detachEventBase();
+
+ state = STATE_SUCCEEDED;
+ hcb_->setState(STATE_SUCCEEDED);
+ callback2.setState(STATE_SUCCEEDED);
+ }
+};
+
+class HandshakeTimeoutCallback: public SSLServerAcceptCallbackBase {
+public:
+ explicit HandshakeTimeoutCallback(HandshakeCallback *hcb):
+ SSLServerAcceptCallbackBase(hcb) {}
+
+ // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
+ void connAccepted(
+ const std::shared_ptr<folly::AsyncSSLSocket> &s)
+ noexcept override {
+ std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
+
+ auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
+
+ hcb_->setSocket(sock);
+ sock->getEventBase()->tryRunAfterDelay([=] {
+ std::cerr << "Delayed SSL accept, client will have close by now"
+ << std::endl;
+ // SSL accept will fail
+ EXPECT_EQ(
+ sock->getSSLState(),
+ AsyncSSLSocket::STATE_UNINIT);
+ hcb_->socket_->sslAccept(hcb_);
+ // This registers for an event
+ EXPECT_EQ(
+ sock->getSSLState(),
+ AsyncSSLSocket::STATE_ACCEPTING);
+
+ state = STATE_SUCCEEDED;
+ }, 100);
+ }
+};
+
+
+class TestSSLServer {
+ protected:
+ EventBase evb_;
+ std::shared_ptr<folly::SSLContext> ctx_;
+ SSLServerAcceptCallbackBase *acb_;
+ folly::AsyncServerSocket *socket_;
+ folly::SocketAddress address_;
+ pthread_t thread_;
+
+ static void *Main(void *ctx) {
+ TestSSLServer *self = static_cast<TestSSLServer*>(ctx);
+ self->evb_.loop();
+ std::cerr << "Server thread exited event loop" << std::endl;
+ return nullptr;
+ }
+
+ public:
+ // Create a TestSSLServer.
+ // This immediately starts listening on the given port.
+ explicit TestSSLServer(SSLServerAcceptCallbackBase *acb);
+
+ // Kill the thread.
+ ~TestSSLServer() {
+ evb_.runInEventBaseThread([&](){
+ socket_->stopAccepting();
+ });
+ std::cerr << "Waiting for server thread to exit" << std::endl;
+ pthread_join(thread_, nullptr);
+ }
+
+ EventBase &getEventBase() { return evb_; }
+
+ const folly::SocketAddress& getAddress() const {
+ return address_;
+ }
+};
+
+class TestSSLAsyncCacheServer : public TestSSLServer {
+ public:
+ explicit TestSSLAsyncCacheServer(SSLServerAcceptCallbackBase *acb,
+ int lookupDelay = 100) :
+ TestSSLServer(acb) {
+ SSL_CTX *sslCtx = ctx_->getSSLCtx();
+ SSL_CTX_sess_set_get_cb(sslCtx,
+ TestSSLAsyncCacheServer::getSessionCallback);
+ SSL_CTX_set_session_cache_mode(
+ sslCtx, SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_SERVER);
+ asyncCallbacks_ = 0;
+ asyncLookups_ = 0;
+ lookupDelay_ = lookupDelay;
+ }
+
+ uint32_t getAsyncCallbacks() const { return asyncCallbacks_; }
+ uint32_t getAsyncLookups() const { return asyncLookups_; }
+
+ private:
+ static uint32_t asyncCallbacks_;
+ static uint32_t asyncLookups_;
+ static uint32_t lookupDelay_;
+
+ static SSL_SESSION *getSessionCallback(SSL *ssl,
+ unsigned char *sess_id,
+ int id_len,
+ int *copyflag) {
+ *copyflag = 0;
+ asyncCallbacks_++;
+#ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
+ if (!SSL_want_sess_cache_lookup(ssl)) {
+ // libssl.so mismatch
+ std::cerr << "no async support" << std::endl;
+ return nullptr;
+ }
+
+ AsyncSSLSocket *sslSocket =
+ AsyncSSLSocket::getFromSSL(ssl);
+ assert(sslSocket != nullptr);
+ // Going to simulate an async cache by just running delaying the miss 100ms
+ if (asyncCallbacks_ % 2 == 0) {
+ // This socket is already blocked on lookup, return miss
+ std::cerr << "returning miss" << std::endl;
+ } else {
+ // fresh meat - block it
+ std::cerr << "async lookup" << std::endl;
+ sslSocket->getEventBase()->tryRunAfterDelay(
+ std::bind(&AsyncSSLSocket::restartSSLAccept,
+ sslSocket), lookupDelay_);
+ *copyflag = SSL_SESSION_CB_WOULD_BLOCK;
+ asyncLookups_++;
+ }
+#endif
+ return nullptr;
+ }
+};
+
+void getfds(int fds[2]);
+
+void getctx(
+ std::shared_ptr<folly::SSLContext> clientCtx,
+ std::shared_ptr<folly::SSLContext> serverCtx);
+
+void sslsocketpair(
+ EventBase* eventBase,
+ AsyncSSLSocket::UniquePtr* clientSock,
+ AsyncSSLSocket::UniquePtr* serverSock);
+
+class BlockingWriteClient :
+ private AsyncSSLSocket::HandshakeCB,
+ private AsyncTransportWrapper::WriteCallback {
+ public:
+ explicit BlockingWriteClient(
+ AsyncSSLSocket::UniquePtr socket)
+ : socket_(std::move(socket)),
+ bufLen_(2500),
+ iovCount_(2000) {
+ // Fill buf_
+ buf_.reset(new uint8_t[bufLen_]);
+ for (uint32_t n = 0; n < sizeof(buf_); ++n) {
+ buf_[n] = n % 0xff;
+ }
+
+ // Initialize iov_
+ iov_.reset(new struct iovec[iovCount_]);
+ for (uint32_t n = 0; n < iovCount_; ++n) {
+ iov_[n].iov_base = buf_.get() + n;
+ if (n & 0x1) {
+ iov_[n].iov_len = n % bufLen_;
+ } else {
+ iov_[n].iov_len = bufLen_ - (n % bufLen_);
+ }
+ }
+
+ socket_->sslConn(this, 100);
+ }
+
+ struct iovec* getIovec() const {
+ return iov_.get();
+ }
+ uint32_t getIovecCount() const {
+ return iovCount_;
+ }
+
+ private:
+ void handshakeSuc(AsyncSSLSocket*) noexcept override {
+ socket_->writev(this, iov_.get(), iovCount_);
+ }
+ void handshakeErr(
+ AsyncSSLSocket*,
+ const AsyncSocketException& ex) noexcept override {
+ ADD_FAILURE() << "client handshake error: " << ex.what();
+ }
+ void writeSuccess() noexcept override {
+ socket_->close();
+ }
+ void writeErr(
+ size_t bytesWritten,
+ const AsyncSocketException& ex) noexcept override {
+ ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
+ << ex.what();
+ }
+
+ AsyncSSLSocket::UniquePtr socket_;
+ uint32_t bufLen_;
+ uint32_t iovCount_;
+ std::unique_ptr<uint8_t[]> buf_;
+ std::unique_ptr<struct iovec[]> iov_;
+};
+
+class BlockingWriteServer :
+ private AsyncSSLSocket::HandshakeCB,
+ private AsyncTransportWrapper::ReadCallback {
+ public:
+ explicit BlockingWriteServer(
+ AsyncSSLSocket::UniquePtr socket)
+ : socket_(std::move(socket)),
+ bufSize_(2500 * 2000),
+ bytesRead_(0) {
+ buf_.reset(new uint8_t[bufSize_]);
+ socket_->sslAccept(this, 100);
+ }
+
+ void checkBuffer(struct iovec* iov, uint32_t count) const {
+ uint32_t idx = 0;
+ for (uint32_t n = 0; n < count; ++n) {
+ size_t bytesLeft = bytesRead_ - idx;
+ int rc = memcmp(buf_.get() + idx, iov[n].iov_base,
+ std::min(iov[n].iov_len, bytesLeft));
+ if (rc != 0) {
+ FAIL() << "buffer mismatch at iovec " << n << "/" << count
+ << ": rc=" << rc;
+
+ }
+ if (iov[n].iov_len > bytesLeft) {
+ FAIL() << "server did not read enough data: "
+ << "ended at byte " << bytesLeft << "/" << iov[n].iov_len
+ << " in iovec " << n << "/" << count;
+ }
+
+ idx += iov[n].iov_len;
+ }
+ if (idx != bytesRead_) {
+ ADD_FAILURE() << "server read extra data: " << bytesRead_
+ << " bytes read; expected " << idx;
+ }
+ }
+
+ private:
+ void handshakeSuc(AsyncSSLSocket*) noexcept override {
+ // Wait 10ms before reading, so the client's writes will initially block.
+ socket_->getEventBase()->tryRunAfterDelay(
+ [this] { socket_->setReadCB(this); }, 10);
+ }
+ void handshakeErr(
+ AsyncSSLSocket*,
+ const AsyncSocketException& ex) noexcept override {
+ ADD_FAILURE() << "server handshake error: " << ex.what();
+ }
+ void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+ *bufReturn = buf_.get() + bytesRead_;
+ *lenReturn = bufSize_ - bytesRead_;
+ }
+ void readDataAvailable(size_t len) noexcept override {
+ bytesRead_ += len;
+ socket_->setReadCB(nullptr);
+ socket_->getEventBase()->tryRunAfterDelay(
+ [this] { socket_->setReadCB(this); }, 2);
+ }
+ void readEOF() noexcept override {
+ socket_->close();
+ }
+ void readErr(
+ const AsyncSocketException& ex) noexcept override {
+ ADD_FAILURE() << "server read error: " << ex.what();
+ }
+
+ AsyncSSLSocket::UniquePtr socket_;
+ uint32_t bufSize_;
+ uint32_t bytesRead_;
+ std::unique_ptr<uint8_t[]> buf_;
+};
+
+class NpnClient :
+ private AsyncSSLSocket::HandshakeCB,
+ private AsyncTransportWrapper::WriteCallback {
+ public:
+ explicit NpnClient(
+ AsyncSSLSocket::UniquePtr socket)
+ : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
+ socket_->sslConn(this);
+ }
+
+ const unsigned char* nextProto;
+ unsigned nextProtoLength;
+ private:
+ void handshakeSuc(AsyncSSLSocket*) noexcept override {
+ socket_->getSelectedNextProtocol(&nextProto,
+ &nextProtoLength);
+ }
+ void handshakeErr(
+ AsyncSSLSocket*,
+ const AsyncSocketException& ex) noexcept override {
+ ADD_FAILURE() << "client handshake error: " << ex.what();
+ }
+ void writeSuccess() noexcept override {
+ socket_->close();
+ }
+ void writeErr(
+ size_t bytesWritten,
+ const AsyncSocketException& ex) noexcept override {
+ ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
+ << ex.what();
+ }
+
+ AsyncSSLSocket::UniquePtr socket_;
+};
+
+class NpnServer :
+ private AsyncSSLSocket::HandshakeCB,
+ private AsyncTransportWrapper::ReadCallback {
+ public:
+ explicit NpnServer(AsyncSSLSocket::UniquePtr socket)
+ : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
+ socket_->sslAccept(this);
+ }
+
+ const unsigned char* nextProto;
+ unsigned nextProtoLength;
+ private:
+ void handshakeSuc(AsyncSSLSocket*) noexcept override {
+ socket_->getSelectedNextProtocol(&nextProto,
+ &nextProtoLength);
+ }
+ void handshakeErr(
+ AsyncSSLSocket*,
+ const AsyncSocketException& ex) noexcept override {
+ ADD_FAILURE() << "server handshake error: " << ex.what();
+ }
+ void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+ *lenReturn = 0;
+ }
+ void readDataAvailable(size_t len) noexcept override {
+ }
+ void readEOF() noexcept override {
+ socket_->close();
+ }
+ void readErr(
+ const AsyncSocketException& ex) noexcept override {
+ ADD_FAILURE() << "server read error: " << ex.what();
+ }
+
+ AsyncSSLSocket::UniquePtr socket_;
+};
+
+#ifndef OPENSSL_NO_TLSEXT
+class SNIClient :
+ private AsyncSSLSocket::HandshakeCB,
+ private AsyncTransportWrapper::WriteCallback {
+ public:
+ explicit SNIClient(
+ AsyncSSLSocket::UniquePtr socket)
+ : serverNameMatch(false), socket_(std::move(socket)) {
+ socket_->sslConn(this);
+ }
+
+ bool serverNameMatch;
+
+ private:
+ void handshakeSuc(AsyncSSLSocket*) noexcept override {
+ serverNameMatch = socket_->isServerNameMatch();
+ }
+ void handshakeErr(
+ AsyncSSLSocket*,
+ const AsyncSocketException& ex) noexcept override {
+ ADD_FAILURE() << "client handshake error: " << ex.what();
+ }
+ void writeSuccess() noexcept override {
+ socket_->close();
+ }
+ void writeErr(
+ size_t bytesWritten,
+ const AsyncSocketException& ex) noexcept override {
+ ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
+ << ex.what();
+ }
+
+ AsyncSSLSocket::UniquePtr socket_;
+};
+
+class SNIServer :
+ private AsyncSSLSocket::HandshakeCB,
+ private AsyncTransportWrapper::ReadCallback {
+ public:
+ explicit SNIServer(
+ AsyncSSLSocket::UniquePtr socket,
+ const std::shared_ptr<folly::SSLContext>& ctx,
+ const std::shared_ptr<folly::SSLContext>& sniCtx,
+ const std::string& expectedServerName)
+ : serverNameMatch(false), socket_(std::move(socket)), sniCtx_(sniCtx),
+ expectedServerName_(expectedServerName) {
+ ctx->setServerNameCallback(std::bind(&SNIServer::serverNameCallback, this,
+ std::placeholders::_1));
+ socket_->sslAccept(this);
+ }
+
+ bool serverNameMatch;
+
+ private:
+ void handshakeSuc(AsyncSSLSocket* ssl) noexcept override {}
+ void handshakeErr(
+ AsyncSSLSocket*,
+ const AsyncSocketException& ex) noexcept override {
+ ADD_FAILURE() << "server handshake error: " << ex.what();
+ }
+ void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+ *lenReturn = 0;
+ }
+ void readDataAvailable(size_t len) noexcept override {
+ }
+ void readEOF() noexcept override {
+ socket_->close();
+ }
+ void readErr(
+ const AsyncSocketException& ex) noexcept override {
+ ADD_FAILURE() << "server read error: " << ex.what();
+ }
+
+ folly::SSLContext::ServerNameCallbackResult
+ serverNameCallback(SSL *ssl) {
+ const char *sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
+ if (sniCtx_ &&
+ sn &&
+ !strcasecmp(expectedServerName_.c_str(), sn)) {
+ AsyncSSLSocket *sslSocket =
+ AsyncSSLSocket::getFromSSL(ssl);
+ sslSocket->switchServerSSLContext(sniCtx_);
+ serverNameMatch = true;
+ return folly::SSLContext::SERVER_NAME_FOUND;
+ } else {
+ serverNameMatch = false;
+ return folly::SSLContext::SERVER_NAME_NOT_FOUND;
+ }
+ }
+
+ AsyncSSLSocket::UniquePtr socket_;
+ std::shared_ptr<folly::SSLContext> sniCtx_;
+ std::string expectedServerName_;
+};
+#endif
+
+class SSLClient : public AsyncSocket::ConnectCallback,
+ public AsyncTransportWrapper::WriteCallback,
+ public AsyncTransportWrapper::ReadCallback
+{
+ private:
+ EventBase *eventBase_;
+ std::shared_ptr<AsyncSSLSocket> sslSocket_;
+ SSL_SESSION *session_;
+ std::shared_ptr<folly::SSLContext> ctx_;
+ uint32_t requests_;
+ folly::SocketAddress address_;
+ uint32_t timeout_;
+ char buf_[128];
+ char readbuf_[128];
+ uint32_t bytesRead_;
+ uint32_t hit_;
+ uint32_t miss_;
+ uint32_t errors_;
+ uint32_t writeAfterConnectErrors_;
+
+ public:
+ SSLClient(EventBase *eventBase,
+ const folly::SocketAddress& address,
+ uint32_t requests, uint32_t timeout = 0)
+ : eventBase_(eventBase),
+ session_(nullptr),
+ requests_(requests),
+ address_(address),
+ timeout_(timeout),
+ bytesRead_(0),
+ hit_(0),
+ miss_(0),
+ errors_(0),
+ writeAfterConnectErrors_(0) {
+ ctx_.reset(new folly::SSLContext());
+ ctx_->setOptions(SSL_OP_NO_TICKET);
+ ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+ memset(buf_, 'a', sizeof(buf_));
+ }
+
+ ~SSLClient() {
+ if (session_) {
+ SSL_SESSION_free(session_);
+ }
+ if (errors_ == 0) {
+ EXPECT_EQ(bytesRead_, sizeof(buf_));
+ }
+ }
+
+ uint32_t getHit() const { return hit_; }
+
+ uint32_t getMiss() const { return miss_; }
+
+ uint32_t getErrors() const { return errors_; }
+
+ uint32_t getWriteAfterConnectErrors() const {
+ return writeAfterConnectErrors_;
+ }
+
+ void connect(bool writeNow = false) {
+ sslSocket_ = AsyncSSLSocket::newSocket(
+ ctx_, eventBase_);
+ if (session_ != nullptr) {
+ sslSocket_->setSSLSession(session_);
+ }
+ requests_--;
+ sslSocket_->connect(this, address_, timeout_);
+ if (sslSocket_ && writeNow) {
+ // write some junk, used in an error test
+ sslSocket_->write(this, buf_, sizeof(buf_));
+ }
+ }
+
+ void connectSuccess() noexcept override {
+ std::cerr << "client SSL socket connected" << std::endl;
+ if (sslSocket_->getSSLSessionReused()) {
+ hit_++;
+ } else {
+ miss_++;
+ if (session_ != nullptr) {
+ SSL_SESSION_free(session_);
+ }
+ session_ = sslSocket_->getSSLSession();
+ }
+
+ // write()
+ sslSocket_->write(this, buf_, sizeof(buf_));
+ sslSocket_->setReadCB(this);
+ memset(readbuf_, 'b', sizeof(readbuf_));
+ bytesRead_ = 0;
+ }
+
+ void connectErr(
+ const AsyncSocketException& ex) noexcept override {
+ std::cerr << "SSLClient::connectError: " << ex.what() << std::endl;
+ errors_++;
+ sslSocket_.reset();
+ }
+
+ void writeSuccess() noexcept override {
+ std::cerr << "client write success" << std::endl;
+ }
+
+ void writeErr(
+ size_t bytesWritten,
+ const AsyncSocketException& ex)
+ noexcept override {
+ std::cerr << "client writeError: " << ex.what() << std::endl;
+ if (!sslSocket_) {
+ writeAfterConnectErrors_++;
+ }
+ }
+
+ void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+ *bufReturn = readbuf_ + bytesRead_;
+ *lenReturn = sizeof(readbuf_) - bytesRead_;
+ }
+
+ void readEOF() noexcept override {
+ std::cerr << "client readEOF" << std::endl;
+ }
+
+ void readErr(
+ const AsyncSocketException& ex) noexcept override {
+ std::cerr << "client readError: " << ex.what() << std::endl;
+ }
+
+ void readDataAvailable(size_t len) noexcept override {
+ std::cerr << "client read data: " << len << std::endl;
+ bytesRead_ += len;
+ if (len == sizeof(buf_)) {
+ EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
+ sslSocket_->closeNow();
+ sslSocket_.reset();
+ if (requests_ != 0) {
+ connect();
+ }
+ }
+ }
+
+};
+
+class SSLHandshakeBase :
+ public AsyncSSLSocket::HandshakeCB,
+ private AsyncTransportWrapper::WriteCallback {
+ public:
+ explicit SSLHandshakeBase(
+ AsyncSSLSocket::UniquePtr socket,
+ bool preverifyResult,
+ bool verifyResult) :
+ handshakeVerify_(false),
+ handshakeSuccess_(false),
+ handshakeError_(false),
+ socket_(std::move(socket)),
+ preverifyResult_(preverifyResult),
+ verifyResult_(verifyResult) {
+ }
+
+ bool handshakeVerify_;
+ bool handshakeSuccess_;
+ bool handshakeError_;
+
+ protected:
+ AsyncSSLSocket::UniquePtr socket_;
+ bool preverifyResult_;
+ bool verifyResult_;
+
+ // HandshakeCallback
+ bool handshakeVer(
+ AsyncSSLSocket* sock,
+ bool preverifyOk,
+ X509_STORE_CTX* ctx) noexcept override {
+ handshakeVerify_ = true;
+
+ EXPECT_EQ(preverifyResult_, preverifyOk);
+ return verifyResult_;
+ }
+
+ void handshakeSuc(AsyncSSLSocket*) noexcept override {
+ handshakeSuccess_ = true;
+ }
+
+ void handshakeErr(
+ AsyncSSLSocket*,
+ const AsyncSocketException& ex) noexcept override {
+ handshakeError_ = true;
+ }
+
+ // WriteCallback
+ void writeSuccess() noexcept override {
+ socket_->close();
+ }
+
+ void writeErr(
+ size_t bytesWritten,
+ const AsyncSocketException& ex) noexcept override {
+ ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
+ << ex.what();
+ }
+};
+
+class SSLHandshakeClient : public SSLHandshakeBase {
+ public:
+ SSLHandshakeClient(
+ AsyncSSLSocket::UniquePtr socket,
+ bool preverifyResult,
+ bool verifyResult) :
+ SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
+ socket_->sslConn(this, 0);
+ }
+};
+
+class SSLHandshakeClientNoVerify : public SSLHandshakeBase {
+ public:
+ SSLHandshakeClientNoVerify(
+ AsyncSSLSocket::UniquePtr socket,
+ bool preverifyResult,
+ bool verifyResult) :
+ SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
+ socket_->sslConn(this, 0,
+ folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
+ }
+};
+
+class SSLHandshakeClientDoVerify : public SSLHandshakeBase {
+ public:
+ SSLHandshakeClientDoVerify(
+ AsyncSSLSocket::UniquePtr socket,
+ bool preverifyResult,
+ bool verifyResult) :
+ SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
+ socket_->sslConn(this, 0,
+ folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
+ }
+};
+
+class SSLHandshakeServer : public SSLHandshakeBase {
+ public:
+ SSLHandshakeServer(
+ AsyncSSLSocket::UniquePtr socket,
+ bool preverifyResult,
+ bool verifyResult)
+ : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
+ socket_->sslAccept(this, 0);
+ }
+};
+
+class SSLHandshakeServerParseClientHello : public SSLHandshakeBase {
+ public:
+ SSLHandshakeServerParseClientHello(
+ AsyncSSLSocket::UniquePtr socket,
+ bool preverifyResult,
+ bool verifyResult)
+ : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
+ socket_->enableClientHelloParsing();
+ socket_->sslAccept(this, 0);
+ }
+
+ std::string clientCiphers_, sharedCiphers_, serverCiphers_, chosenCipher_;
+
+ protected:
+ void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
+ handshakeSuccess_ = true;
+ sock->getSSLSharedCiphers(sharedCiphers_);
+ sock->getSSLServerCiphers(serverCiphers_);
+ sock->getSSLClientCiphers(clientCiphers_);
+ chosenCipher_ = sock->getNegotiatedCipherName();
+ }
+};
+
+
+class SSLHandshakeServerNoVerify : public SSLHandshakeBase {
+ public:
+ SSLHandshakeServerNoVerify(
+ AsyncSSLSocket::UniquePtr socket,
+ bool preverifyResult,
+ bool verifyResult)
+ : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
+ socket_->sslAccept(this, 0,
+ folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
+ }
+};
+
+class SSLHandshakeServerDoVerify : public SSLHandshakeBase {
+ public:
+ SSLHandshakeServerDoVerify(
+ AsyncSSLSocket::UniquePtr socket,
+ bool preverifyResult,
+ bool verifyResult)
+ : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
+ socket_->sslAccept(this, 0,
+ folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
+ }
+};
+
+class EventBaseAborter : public AsyncTimeout {
+ public:
+ EventBaseAborter(EventBase* eventBase,
+ uint32_t timeoutMS)
+ : AsyncTimeout(
+ eventBase, AsyncTimeout::InternalEnum::INTERNAL)
+ , eventBase_(eventBase) {
+ scheduleTimeout(timeoutMS);
+ }
+
+ void timeoutExpired() noexcept override {
+ FAIL() << "test timed out";
+ eventBase_->terminateLoopSoon();
+ }
+
+ private:
+ EventBase* eventBase_;
+};
+
+}
--- /dev/null
+/*
+ * Copyright 2015 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <folly/io/async/test/AsyncSSLSocketTest.h>
+
+#include <gtest/gtest.h>
+#include <pthread.h>
+
+#include <folly/io/async/AsyncSSLSocket.h>
+#include <folly/io/async/EventBase.h>
+
+using std::string;
+using std::vector;
+using std::min;
+using std::cerr;
+using std::endl;
+using std::list;
+
+namespace folly {
+
+class AttachDetachClient : public AsyncSocket::ConnectCallback,
+ public AsyncTransportWrapper::WriteCallback,
+ public AsyncTransportWrapper::ReadCallback {
+ private:
+ EventBase *eventBase_;
+ std::shared_ptr<AsyncSSLSocket> sslSocket_;
+ std::shared_ptr<SSLContext> ctx_;
+ folly::SocketAddress address_;
+ char buf_[128];
+ char readbuf_[128];
+ uint32_t bytesRead_;
+ public:
+ AttachDetachClient(EventBase *eventBase, const folly::SocketAddress& address)
+ : eventBase_(eventBase), address_(address), bytesRead_(0) {
+ ctx_.reset(new SSLContext());
+ ctx_->setOptions(SSL_OP_NO_TICKET);
+ ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+ }
+
+ void connect() {
+ sslSocket_ = AsyncSSLSocket::newSocket(ctx_, eventBase_);
+ sslSocket_->connect(this, address_);
+ }
+
+ void connectSuccess() noexcept override {
+ cerr << "client SSL socket connected" << endl;
+
+ for (int i = 0; i < 1000; ++i) {
+ sslSocket_->detachSSLContext();
+ sslSocket_->attachSSLContext(ctx_);
+ }
+
+ EXPECT_EQ(ctx_->getSSLCtx()->references, 2);
+
+ sslSocket_->write(this, buf_, sizeof(buf_));
+ sslSocket_->setReadCB(this);
+ memset(readbuf_, 'b', sizeof(readbuf_));
+ bytesRead_ = 0;
+ }
+
+ void connectErr(const AsyncSocketException& ex) noexcept override
+ {
+ cerr << "AttachDetachClient::connectError: " << ex.what() << endl;
+ sslSocket_.reset();
+ }
+
+ void writeSuccess() noexcept override {
+ cerr << "client write success" << endl;
+ }
+
+ void writeErr(size_t bytesWritten, const AsyncSocketException& ex)
+ noexcept override {
+ cerr << "client writeError: " << ex.what() << endl;
+ }
+
+ void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+ *bufReturn = readbuf_ + bytesRead_;
+ *lenReturn = sizeof(readbuf_) - bytesRead_;
+ }
+ void readEOF() noexcept override {
+ cerr << "client readEOF" << endl;
+ }
+
+ void readErr(const AsyncSocketException& ex) noexcept override {
+ cerr << "client readError: " << ex.what() << endl;
+ }
+
+ void readDataAvailable(size_t len) noexcept override {
+ cerr << "client read data: " << len << endl;
+ bytesRead_ += len;
+ if (len == sizeof(buf_)) {
+ EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
+ sslSocket_->closeNow();
+ }
+ }
+};
+
+/**
+ * Test passing contexts between threads
+ */
+TEST(AsyncSSLSocketTest2, AttachDetachSSLContext) {
+ // Start listening on a local port
+ WriteCallbackBase writeCallback;
+ ReadCallback readCallback(&writeCallback);
+ HandshakeCallback handshakeCallback(&readCallback);
+ SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
+ TestSSLServer server(&acceptCallback);
+
+ EventBase eventBase;
+ EventBaseAborter eba(&eventBase, 3000);
+ std::shared_ptr<AttachDetachClient> client(
+ new AttachDetachClient(&eventBase, server.getAddress()));
+
+ client->connect();
+ eventBase.loop();
+}
+
+}
+///////////////////////////////////////////////////////////////////////////
+// init_unit_test_suite
+///////////////////////////////////////////////////////////////////////////
+
+namespace {
+using folly::SSLContext;
+struct Initializer {
+ Initializer() {
+ signal(SIGPIPE, SIG_IGN);
+ SSLContext::setSSLLockTypes({
+ {CRYPTO_LOCK_EVP_PKEY, SSLContext::LOCK_NONE},
+ {CRYPTO_LOCK_SSL_SESSION, SSLContext::LOCK_SPINLOCK},
+ {CRYPTO_LOCK_SSL_CTX, SSLContext::LOCK_NONE}});
+ }
+};
+Initializer initializer;
+} // anonymous
--- /dev/null
+/*
+ * Copyright 2015 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <folly/Foreach.h>
+#include <folly/io/Cursor.h>
+#include <folly/io/async/AsyncSSLSocket.h>
+#include <folly/io/async/AsyncSocket.h>
+#include <folly/io/async/EventBase.h>
+
+#include <gtest/gtest.h>
+#include <gmock/gmock.h>
+#include <string>
+#include <vector>
+
+using std::string;
+using namespace testing;
+
+namespace folly {
+
+class MockAsyncSSLSocket : public AsyncSSLSocket{
+ public:
+ static std::shared_ptr<MockAsyncSSLSocket> newSocket(
+ const std::shared_ptr<SSLContext>& ctx,
+ EventBase* evb) {
+ auto sock = std::shared_ptr<MockAsyncSSLSocket>(
+ new MockAsyncSSLSocket(ctx, evb),
+ Destructor());
+ sock->ssl_ = SSL_new(ctx->getSSLCtx());
+ SSL_set_fd(sock->ssl_, -1);
+ return sock;
+ }
+
+ // Fake constructor sets the state to established without call to connect
+ // or accept
+ MockAsyncSSLSocket(const std::shared_ptr<SSLContext>& ctx,
+ EventBase* evb)
+ : AsyncSocket(evb), AsyncSSLSocket(ctx, evb) {
+ state_ = AsyncSocket::StateEnum::ESTABLISHED;
+ sslState_ = AsyncSSLSocket::SSLStateEnum::STATE_ESTABLISHED;
+ }
+
+ // mock the calls to SSL_write to see the buffer length and contents
+ MOCK_METHOD3(sslWriteImpl, int(SSL *ssl, const void *buf, int n));
+
+ // mock the calls to getRawBytesWritten()
+ MOCK_CONST_METHOD0(getRawBytesWritten, size_t());
+
+ // public wrapper for protected interface
+ ssize_t testPerformWrite(const iovec* vec, uint32_t count, WriteFlags flags,
+ uint32_t* countWritten, uint32_t* partialWritten) {
+ return performWrite(vec, count, flags, countWritten, partialWritten);
+ }
+
+ void checkEor(size_t appEor, size_t rawEor) {
+ EXPECT_EQ(appEor, appEorByteNo_);
+ EXPECT_EQ(rawEor, minEorRawByteNo_);
+ }
+
+ void setAppBytesWritten(size_t n) {
+ appBytesWritten_ = n;
+ }
+};
+
+class AsyncSSLSocketWriteTest : public testing::Test {
+ public:
+ AsyncSSLSocketWriteTest() :
+ sslContext_(new SSLContext()),
+ sock_(MockAsyncSSLSocket::newSocket(sslContext_, &eventBase_)) {
+ for (int i = 0; i < 500; i++) {
+ memcpy(source_ + i * 26, "abcdefghijklmnopqrstuvwxyz", 26);
+ }
+ }
+
+ // Make an iovec containing chunks of the reference text with requested sizes
+ // for each chunk
+ iovec *makeVec(std::vector<uint32_t> sizes) {
+ iovec *vec = new iovec[sizes.size()];
+ int i = 0;
+ int pos = 0;
+ for (auto size: sizes) {
+ vec[i].iov_base = (void *)(source_ + pos);
+ vec[i++].iov_len = size;
+ pos += size;
+ }
+ return vec;
+ }
+
+ // Verify that the given buf/pos matches the reference text
+ void verifyVec(const void *buf, int n, int pos) {
+ ASSERT_EQ(memcmp(source_ + pos, buf, n), 0);
+ }
+
+ // Update a vec on partial write
+ void consumeVec(iovec *vec, uint32_t countWritten, uint32_t partialWritten) {
+ vec[countWritten].iov_base =
+ ((char *)vec[countWritten].iov_base) + partialWritten;
+ vec[countWritten].iov_len -= partialWritten;
+ }
+
+ EventBase eventBase_;
+ std::shared_ptr<SSLContext> sslContext_;
+ std::shared_ptr<MockAsyncSSLSocket> sock_;
+ char source_[26 * 500];
+};
+
+
+// The entire vec fits in one packet
+TEST_F(AsyncSSLSocketWriteTest, write_coalescing1) {
+ int n = 3;
+ iovec *vec = makeVec({3, 3, 3});
+ EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 9))
+ .WillOnce(Invoke([this] (SSL *, const void *buf, int n) {
+ verifyVec(buf, n, 0);
+ return 9; }));
+ uint32_t countWritten = 0;
+ uint32_t partialWritten = 0;
+ sock_->testPerformWrite(vec, n, WriteFlags::NONE, &countWritten,
+ &partialWritten);
+ EXPECT_EQ(countWritten, n);
+ EXPECT_EQ(partialWritten, 0);
+}
+
+// First packet is full, second two go in one packet
+TEST_F(AsyncSSLSocketWriteTest, write_coalescing2) {
+ int n = 3;
+ iovec *vec = makeVec({1500, 3, 3});
+ int pos = 0;
+ EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
+ .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
+ verifyVec(buf, n, pos);
+ pos += n;
+ return n; }));
+ EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 6))
+ .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
+ verifyVec(buf, n, pos);
+ pos += n;
+ return n; }));
+ uint32_t countWritten = 0;
+ uint32_t partialWritten = 0;
+ sock_->testPerformWrite(vec, n, WriteFlags::NONE, &countWritten,
+ &partialWritten);
+ EXPECT_EQ(countWritten, n);
+ EXPECT_EQ(partialWritten, 0);
+}
+
+// Two exactly full packets (coalesce ends midway through second chunk)
+TEST_F(AsyncSSLSocketWriteTest, write_coalescing3) {
+ int n = 3;
+ iovec *vec = makeVec({1000, 1000, 1000});
+ int pos = 0;
+ EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
+ .Times(2)
+ .WillRepeatedly(Invoke([this, &pos] (SSL *, const void *buf, int n) {
+ verifyVec(buf, n, pos);
+ pos += n;
+ return n; }));
+ uint32_t countWritten = 0;
+ uint32_t partialWritten = 0;
+ sock_->testPerformWrite(vec, n, WriteFlags::NONE, &countWritten,
+ &partialWritten);
+ EXPECT_EQ(countWritten, n);
+ EXPECT_EQ(partialWritten, 0);
+}
+
+// Partial write success midway through a coalesced vec
+TEST_F(AsyncSSLSocketWriteTest, write_coalescing4) {
+ int n = 5;
+ iovec *vec = makeVec({300, 300, 300, 300, 300});
+ int pos = 0;
+ EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
+ .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
+ verifyVec(buf, n, pos);
+ pos += 1000;
+ return 1000; /* 500 bytes "pending" */ }));
+ uint32_t countWritten = 0;
+ uint32_t partialWritten = 0;
+ sock_->testPerformWrite(vec, n, WriteFlags::NONE, &countWritten,
+ &partialWritten);
+ EXPECT_EQ(countWritten, 3);
+ EXPECT_EQ(partialWritten, 100);
+ consumeVec(vec, countWritten, partialWritten);
+ EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 500))
+ .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
+ verifyVec(buf, n, pos);
+ pos += n;
+ return 500; }));
+ sock_->testPerformWrite(vec + countWritten, n - countWritten,
+ WriteFlags::NONE,
+ &countWritten, &partialWritten);
+ EXPECT_EQ(countWritten, 2);
+ EXPECT_EQ(partialWritten, 0);
+}
+
+// coalesce ends exactly on a buffer boundary
+TEST_F(AsyncSSLSocketWriteTest, write_coalescing5) {
+ int n = 3;
+ iovec *vec = makeVec({1000, 500, 500});
+ int pos = 0;
+ EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
+ .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
+ verifyVec(buf, n, pos);
+ pos += n;
+ return n; }));
+ EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 500))
+ .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
+ verifyVec(buf, n, pos);
+ pos += n;
+ return n; }));
+ uint32_t countWritten = 0;
+ uint32_t partialWritten = 0;
+ sock_->testPerformWrite(vec, n, WriteFlags::NONE, &countWritten,
+ &partialWritten);
+ EXPECT_EQ(countWritten, 3);
+ EXPECT_EQ(partialWritten, 0);
+}
+
+// partial write midway through first chunk
+TEST_F(AsyncSSLSocketWriteTest, write_coalescing6) {
+ int n = 2;
+ iovec *vec = makeVec({1000, 500});
+ int pos = 0;
+ EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
+ .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
+ verifyVec(buf, n, pos);
+ pos += 700;
+ return 700; }));
+ uint32_t countWritten = 0;
+ uint32_t partialWritten = 0;
+ sock_->testPerformWrite(vec, n, WriteFlags::NONE, &countWritten,
+ &partialWritten);
+ EXPECT_EQ(countWritten, 0);
+ EXPECT_EQ(partialWritten, 700);
+ consumeVec(vec, countWritten, partialWritten);
+ EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 800))
+ .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
+ verifyVec(buf, n, pos);
+ pos += n;
+ return n; }));
+ sock_->testPerformWrite(vec + countWritten, n - countWritten,
+ WriteFlags::NONE,
+ &countWritten, &partialWritten);
+ EXPECT_EQ(countWritten, 2);
+ EXPECT_EQ(partialWritten, 0);
+}
+
+// Repeat coalescing2 with WriteFlags::EOR
+TEST_F(AsyncSSLSocketWriteTest, write_with_eor1) {
+ int n = 3;
+ iovec *vec = makeVec({1500, 3, 3});
+ int pos = 0;
+ const size_t initAppBytesWritten = 500;
+ const size_t appEor = initAppBytesWritten + 1506;
+
+ sock_->setAppBytesWritten(initAppBytesWritten);
+ EXPECT_FALSE(sock_->isEorTrackingEnabled());
+ sock_->setEorTracking(true);
+ EXPECT_TRUE(sock_->isEorTrackingEnabled());
+
+ EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
+ // rawBytesWritten after writting initAppBytesWritten + 1500
+ // + some random SSL overhead
+ .WillOnce(Return(3600))
+ // rawBytesWritten after writting last 6 bytes
+ // + some random SSL overhead
+ .WillOnce(Return(3728));
+ EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
+ .WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int n) {
+ // the first 1500 does not have the EOR byte
+ sock_->checkEor(0, 0);
+ verifyVec(buf, n, pos);
+ pos += n;
+ return n; }));
+ EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 6))
+ .WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int n) {
+ sock_->checkEor(appEor, 3600 + n);
+ verifyVec(buf, n, pos);
+ pos += n;
+ return n; }));
+
+ uint32_t countWritten = 0;
+ uint32_t partialWritten = 0;
+ sock_->testPerformWrite(vec, n , WriteFlags::EOR,
+ &countWritten, &partialWritten);
+ EXPECT_EQ(countWritten, n);
+ EXPECT_EQ(partialWritten, 0);
+ sock_->checkEor(0, 0);
+}
+
+// coalescing with left over at the last chunk
+// WriteFlags::EOR turned on
+TEST_F(AsyncSSLSocketWriteTest, write_with_eor2) {
+ int n = 3;
+ iovec *vec = makeVec({600, 600, 600});
+ int pos = 0;
+ const size_t initAppBytesWritten = 500;
+ const size_t appEor = initAppBytesWritten + 1800;
+
+ sock_->setAppBytesWritten(initAppBytesWritten);
+ sock_->setEorTracking(true);
+
+ EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
+ // rawBytesWritten after writting initAppBytesWritten + 1500 bytes
+ // + some random SSL overhead
+ .WillOnce(Return(3600))
+ // rawBytesWritten after writting last 300 bytes
+ // + some random SSL overhead
+ .WillOnce(Return(4100));
+ EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
+ .WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int n) {
+ // the first 1500 does not have the EOR byte
+ sock_->checkEor(0, 0);
+ verifyVec(buf, n, pos);
+ pos += n;
+ return n; }));
+ EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 300))
+ .WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int n) {
+ sock_->checkEor(appEor, 3600 + n);
+ verifyVec(buf, n, pos);
+ pos += n;
+ return n; }));
+
+ uint32_t countWritten = 0;
+ uint32_t partialWritten = 0;
+ sock_->testPerformWrite(vec, n, WriteFlags::EOR,
+ &countWritten, &partialWritten);
+ EXPECT_EQ(countWritten, n);
+ EXPECT_EQ(partialWritten, 0);
+ sock_->checkEor(0, 0);
+}
+
+// WriteFlags::EOR set
+// One buf in iovec
+// Partial write at 1000-th byte
+TEST_F(AsyncSSLSocketWriteTest, write_with_eor3) {
+ int n = 1;
+ iovec *vec = makeVec({1600});
+ int pos = 0;
+ const size_t initAppBytesWritten = 500;
+ const size_t appEor = initAppBytesWritten + 1600;
+
+ sock_->setAppBytesWritten(initAppBytesWritten);
+ sock_->setEorTracking(true);
+
+ EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
+ // rawBytesWritten after the initAppBytesWritten
+ // + some random SSL overhead
+ .WillOnce(Return(2000))
+ // rawBytesWritten after the initAppBytesWritten + 1000 (with 100 overhead)
+ // + some random SSL overhead
+ .WillOnce(Return(3100));
+ EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1600))
+ .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
+ sock_->checkEor(appEor, 2000 + n);
+ verifyVec(buf, n, pos);
+ pos += 1000;
+ return 1000; }));
+
+ uint32_t countWritten = 0;
+ uint32_t partialWritten = 0;
+ sock_->testPerformWrite(vec, n, WriteFlags::EOR,
+ &countWritten, &partialWritten);
+ EXPECT_EQ(countWritten, 0);
+ EXPECT_EQ(partialWritten, 1000);
+ sock_->checkEor(appEor, 2000 + 1600);
+ consumeVec(vec, countWritten, partialWritten);
+
+ EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
+ .WillOnce(Return(3100))
+ .WillOnce(Return(3800));
+ EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 600))
+ .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
+ sock_->checkEor(appEor, 3100 + n);
+ verifyVec(buf, n, pos);
+ pos += n;
+ return n; }));
+ sock_->testPerformWrite(vec + countWritten, n - countWritten,
+ WriteFlags::EOR,
+ &countWritten, &partialWritten);
+ EXPECT_EQ(countWritten, n);
+ EXPECT_EQ(partialWritten, 0);
+ sock_->checkEor(0, 0);
+}
+
+}
--- /dev/null
+/*
+ * Copyright 2015 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <folly/io/async/AsyncServerSocket.h>
+#include <folly/io/async/AsyncSocket.h>
+#include <folly/io/async/AsyncTimeout.h>
+#include <folly/io/async/EventBase.h>
+#include <folly/SocketAddress.h>
+
+#include <folly/io/IOBuf.h>
+#include <folly/io/async/test/BlockingSocket.h>
+#include <folly/io/async/test/Util.h>
+
+#include <gtest/gtest.h>
+#include <boost/scoped_array.hpp>
+#include <iostream>
+#include <unistd.h>
+#include <fcntl.h>
+#include <poll.h>
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <netinet/tcp.h>
+#include <thread>
+
+using namespace boost;
+
+using std::string;
+using std::vector;
+using std::min;
+using std::cerr;
+using std::endl;
+using std::unique_ptr;
+using std::chrono::milliseconds;
+using boost::scoped_array;
+
+using namespace folly;
+
+enum StateEnum {
+ STATE_WAITING,
+ STATE_SUCCEEDED,
+ STATE_FAILED
+};
+
+typedef std::function<void()> VoidCallback;
+
+
+class ConnCallback : public AsyncSocket::ConnectCallback {
+ public:
+ ConnCallback()
+ : state(STATE_WAITING)
+ , exception(AsyncSocketException::UNKNOWN, "none") {}
+
+ void connectSuccess() noexcept override {
+ state = STATE_SUCCEEDED;
+ if (successCallback) {
+ successCallback();
+ }
+ }
+
+ void connectErr(const AsyncSocketException& ex) noexcept override {
+ state = STATE_FAILED;
+ exception = ex;
+ if (errorCallback) {
+ errorCallback();
+ }
+ }
+
+ StateEnum state;
+ AsyncSocketException exception;
+ VoidCallback successCallback;
+ VoidCallback errorCallback;
+};
+
+class WriteCallback : public AsyncTransportWrapper::WriteCallback {
+ public:
+ WriteCallback()
+ : state(STATE_WAITING)
+ , bytesWritten(0)
+ , exception(AsyncSocketException::UNKNOWN, "none") {}
+
+ void writeSuccess() noexcept override {
+ state = STATE_SUCCEEDED;
+ if (successCallback) {
+ successCallback();
+ }
+ }
+
+ void writeErr(size_t bytesWritten,
+ const AsyncSocketException& ex) noexcept override {
+ state = STATE_FAILED;
+ this->bytesWritten = bytesWritten;
+ exception = ex;
+ if (errorCallback) {
+ errorCallback();
+ }
+ }
+
+ StateEnum state;
+ size_t bytesWritten;
+ AsyncSocketException exception;
+ VoidCallback successCallback;
+ VoidCallback errorCallback;
+};
+
+class ReadCallback : public AsyncTransportWrapper::ReadCallback {
+ public:
+ ReadCallback()
+ : state(STATE_WAITING)
+ , exception(AsyncSocketException::UNKNOWN, "none")
+ , buffers() {}
+
+ ~ReadCallback() {
+ for (vector<Buffer>::iterator it = buffers.begin();
+ it != buffers.end();
+ ++it) {
+ it->free();
+ }
+ currentBuffer.free();
+ }
+
+ void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+ if (!currentBuffer.buffer) {
+ currentBuffer.allocate(4096);
+ }
+ *bufReturn = currentBuffer.buffer;
+ *lenReturn = currentBuffer.length;
+ }
+
+ void readDataAvailable(size_t len) noexcept override {
+ currentBuffer.length = len;
+ buffers.push_back(currentBuffer);
+ currentBuffer.reset();
+ if (dataAvailableCallback) {
+ dataAvailableCallback();
+ }
+ }
+
+ void readEOF() noexcept override {
+ state = STATE_SUCCEEDED;
+ }
+
+ void readErr(const AsyncSocketException& ex) noexcept override {
+ state = STATE_FAILED;
+ exception = ex;
+ }
+
+ void verifyData(const char* expected, size_t expectedLen) const {
+ size_t offset = 0;
+ for (size_t idx = 0; idx < buffers.size(); ++idx) {
+ const auto& buf = buffers[idx];
+ size_t cmpLen = std::min(buf.length, expectedLen - offset);
+ CHECK_EQ(memcmp(buf.buffer, expected + offset, cmpLen), 0);
+ CHECK_EQ(cmpLen, buf.length);
+ offset += cmpLen;
+ }
+ CHECK_EQ(offset, expectedLen);
+ }
+
+ class Buffer {
+ public:
+ Buffer() : buffer(nullptr), length(0) {}
+ Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
+
+ void reset() {
+ buffer = nullptr;
+ length = 0;
+ }
+ void allocate(size_t length) {
+ assert(buffer == nullptr);
+ this->buffer = static_cast<char*>(malloc(length));
+ this->length = length;
+ }
+ void free() {
+ ::free(buffer);
+ reset();
+ }
+
+ char* buffer;
+ size_t length;
+ };
+
+ StateEnum state;
+ AsyncSocketException exception;
+ vector<Buffer> buffers;
+ Buffer currentBuffer;
+ VoidCallback dataAvailableCallback;
+};
+
+class ReadVerifier {
+};
+
+class TestServer {
+ public:
+ // Create a TestServer.
+ // This immediately starts listening on an ephemeral port.
+ TestServer()
+ : fd_(-1) {
+ fd_ = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
+ if (fd_ < 0) {
+ throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
+ "failed to create test server socket", errno);
+ }
+ if (fcntl(fd_, F_SETFL, O_NONBLOCK) != 0) {
+ throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
+ "failed to put test server socket in "
+ "non-blocking mode", errno);
+ }
+ if (listen(fd_, 10) != 0) {
+ throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
+ "failed to listen on test server socket",
+ errno);
+ }
+
+ address_.setFromLocalAddress(fd_);
+ // The local address will contain 0.0.0.0.
+ // Change it to 127.0.0.1, so it can be used to connect to the server
+ address_.setFromIpPort("127.0.0.1", address_.getPort());
+ }
+
+ // Get the address for connecting to the server
+ const folly::SocketAddress& getAddress() const {
+ return address_;
+ }
+
+ int acceptFD(int timeout=50) {
+ struct pollfd pfd;
+ pfd.fd = fd_;
+ pfd.events = POLLIN;
+ int ret = poll(&pfd, 1, timeout);
+ if (ret == 0) {
+ throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
+ "test server accept() timed out");
+ } else if (ret < 0) {
+ throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
+ "test server accept() poll failed", errno);
+ }
+
+ int acceptedFd = ::accept(fd_, nullptr, nullptr);
+ if (acceptedFd < 0) {
+ throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
+ "test server accept() failed", errno);
+ }
+
+ return acceptedFd;
+ }
+
+ std::shared_ptr<BlockingSocket> accept(int timeout=50) {
+ int fd = acceptFD(timeout);
+ return std::shared_ptr<BlockingSocket>(new BlockingSocket(fd));
+ }
+
+ std::shared_ptr<AsyncSocket> acceptAsync(EventBase* evb, int timeout=50) {
+ int fd = acceptFD(timeout);
+ return AsyncSocket::newSocket(evb, fd);
+ }
+
+ /**
+ * Accept a connection, read data from it, and verify that it matches the
+ * data in the specified buffer.
+ */
+ void verifyConnection(const char* buf, size_t len) {
+ // accept a connection
+ std::shared_ptr<BlockingSocket> acceptedSocket = accept();
+ // read the data and compare it to the specified buffer
+ scoped_array<uint8_t> readbuf(new uint8_t[len]);
+ acceptedSocket->readAll(readbuf.get(), len);
+ CHECK_EQ(memcmp(buf, readbuf.get(), len), 0);
+ // make sure we get EOF next
+ uint32_t bytesRead = acceptedSocket->read(readbuf.get(), len);
+ CHECK_EQ(bytesRead, 0);
+ }
+
+ private:
+ int fd_;
+ folly::SocketAddress address_;
+};
+
+class DelayedWrite: public AsyncTimeout {
+ public:
+ DelayedWrite(const std::shared_ptr<AsyncSocket>& socket,
+ unique_ptr<IOBuf>&& bufs, AsyncTransportWrapper::WriteCallback* wcb,
+ bool cork, bool lastWrite = false):
+ AsyncTimeout(socket->getEventBase()),
+ socket_(socket),
+ bufs_(std::move(bufs)),
+ wcb_(wcb),
+ cork_(cork),
+ lastWrite_(lastWrite) {}
+
+ private:
+ void timeoutExpired() noexcept override {
+ WriteFlags flags = cork_ ? WriteFlags::CORK : WriteFlags::NONE;
+ socket_->writeChain(wcb_, std::move(bufs_), flags);
+ if (lastWrite_) {
+ socket_->shutdownWrite();
+ }
+ }
+
+ std::shared_ptr<AsyncSocket> socket_;
+ unique_ptr<IOBuf> bufs_;
+ AsyncTransportWrapper::WriteCallback* wcb_;
+ bool cork_;
+ bool lastWrite_;
+};
+
+///////////////////////////////////////////////////////////////////////////
+// connect() tests
+///////////////////////////////////////////////////////////////////////////
+
+/**
+ * Test connecting to a server
+ */
+TEST(AsyncSocketTest, Connect) {
+ // Start listening on a local port
+ TestServer server;
+
+ // Connect using a AsyncSocket
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+ ConnCallback cb;
+ socket->connect(&cb, server.getAddress(), 30);
+
+ evb.loop();
+
+ CHECK_EQ(cb.state, STATE_SUCCEEDED);
+}
+
+/**
+ * Test connecting to a server that isn't listening
+ */
+TEST(AsyncSocketTest, ConnectRefused) {
+ EventBase evb;
+
+ std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+
+ // Hopefully nothing is actually listening on this address
+ folly::SocketAddress addr("127.0.0.1", 65535);
+ ConnCallback cb;
+ socket->connect(&cb, addr, 30);
+
+ evb.loop();
+
+ CHECK_EQ(cb.state, STATE_FAILED);
+ CHECK_EQ(cb.exception.getType(), AsyncSocketException::NOT_OPEN);
+}
+
+/**
+ * Test connection timeout
+ */
+TEST(AsyncSocketTest, ConnectTimeout) {
+ EventBase evb;
+
+ std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+
+ // Try connecting to server that won't respond.
+ //
+ // This depends somewhat on the network where this test is run.
+ // Hopefully this IP will be routable but unresponsive.
+ // (Alternatively, we could try listening on a local raw socket, but that
+ // normally requires root privileges.)
+ folly::SocketAddress addr("8.8.8.8", 65535);
+ ConnCallback cb;
+ socket->connect(&cb, addr, 1); // also set a ridiculously small timeout
+
+ evb.loop();
+
+ CHECK_EQ(cb.state, STATE_FAILED);
+ CHECK_EQ(cb.exception.getType(), AsyncSocketException::TIMED_OUT);
+
+ // Verify that we can still get the peer address after a timeout.
+ // Use case is if the client was created from a client pool, and we want
+ // to log which peer failed.
+ folly::SocketAddress peer;
+ socket->getPeerAddress(&peer);
+ CHECK_EQ(peer, addr);
+}
+
+/**
+ * Test writing immediately after connecting, without waiting for connect
+ * to finish.
+ */
+TEST(AsyncSocketTest, ConnectAndWrite) {
+ TestServer server;
+
+ // connect()
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+ ConnCallback ccb;
+ socket->connect(&ccb, server.getAddress(), 30);
+
+ // write()
+ char buf[128];
+ memset(buf, 'a', sizeof(buf));
+ WriteCallback wcb;
+ socket->write(&wcb, buf, sizeof(buf));
+
+ // Loop. We don't bother accepting on the server socket yet.
+ // The kernel should be able to buffer the write request so it can succeed.
+ evb.loop();
+
+ CHECK_EQ(ccb.state, STATE_SUCCEEDED);
+ CHECK_EQ(wcb.state, STATE_SUCCEEDED);
+
+ // Make sure the server got a connection and received the data
+ socket->close();
+ server.verifyConnection(buf, sizeof(buf));
+}
+
+/**
+ * Test connecting using a nullptr connect callback.
+ */
+TEST(AsyncSocketTest, ConnectNullCallback) {
+ TestServer server;
+
+ // connect()
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+ socket->connect(nullptr, server.getAddress(), 30);
+
+ // write some data, just so we have some way of verifing
+ // that the socket works correctly after connecting
+ char buf[128];
+ memset(buf, 'a', sizeof(buf));
+ WriteCallback wcb;
+ socket->write(&wcb, buf, sizeof(buf));
+
+ evb.loop();
+
+ CHECK_EQ(wcb.state, STATE_SUCCEEDED);
+
+ // Make sure the server got a connection and received the data
+ socket->close();
+ server.verifyConnection(buf, sizeof(buf));
+}
+
+/**
+ * Test calling both write() and close() immediately after connecting, without
+ * waiting for connect to finish.
+ *
+ * This exercises the STATE_CONNECTING_CLOSING code.
+ */
+TEST(AsyncSocketTest, ConnectWriteAndClose) {
+ TestServer server;
+
+ // connect()
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+ ConnCallback ccb;
+ socket->connect(&ccb, server.getAddress(), 30);
+
+ // write()
+ char buf[128];
+ memset(buf, 'a', sizeof(buf));
+ WriteCallback wcb;
+ socket->write(&wcb, buf, sizeof(buf));
+
+ // close()
+ socket->close();
+
+ // Loop. We don't bother accepting on the server socket yet.
+ // The kernel should be able to buffer the write request so it can succeed.
+ evb.loop();
+
+ CHECK_EQ(ccb.state, STATE_SUCCEEDED);
+ CHECK_EQ(wcb.state, STATE_SUCCEEDED);
+
+ // Make sure the server got a connection and received the data
+ server.verifyConnection(buf, sizeof(buf));
+}
+
+/**
+ * Test calling close() immediately after connect()
+ */
+TEST(AsyncSocketTest, ConnectAndClose) {
+ TestServer server;
+
+ // Connect using a AsyncSocket
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+ ConnCallback ccb;
+ socket->connect(&ccb, server.getAddress(), 30);
+
+ // Hopefully the connect didn't succeed immediately.
+ // If it did, we can't exercise the close-while-connecting code path.
+ if (ccb.state == STATE_SUCCEEDED) {
+ LOG(INFO) << "connect() succeeded immediately; aborting test "
+ "of close-during-connect behavior";
+ return;
+ }
+
+ socket->close();
+
+ // Loop, although there shouldn't be anything to do.
+ evb.loop();
+
+ // Make sure the connection was aborted
+ CHECK_EQ(ccb.state, STATE_FAILED);
+}
+
+/**
+ * Test calling closeNow() immediately after connect()
+ *
+ * This should be identical to the normal close behavior.
+ */
+TEST(AsyncSocketTest, ConnectAndCloseNow) {
+ TestServer server;
+
+ // Connect using a AsyncSocket
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+ ConnCallback ccb;
+ socket->connect(&ccb, server.getAddress(), 30);
+
+ // Hopefully the connect didn't succeed immediately.
+ // If it did, we can't exercise the close-while-connecting code path.
+ if (ccb.state == STATE_SUCCEEDED) {
+ LOG(INFO) << "connect() succeeded immediately; aborting test "
+ "of closeNow()-during-connect behavior";
+ return;
+ }
+
+ socket->closeNow();
+
+ // Loop, although there shouldn't be anything to do.
+ evb.loop();
+
+ // Make sure the connection was aborted
+ CHECK_EQ(ccb.state, STATE_FAILED);
+}
+
+/**
+ * Test calling both write() and closeNow() immediately after connecting,
+ * without waiting for connect to finish.
+ *
+ * This should abort the pending write.
+ */
+TEST(AsyncSocketTest, ConnectWriteAndCloseNow) {
+ TestServer server;
+
+ // connect()
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+ ConnCallback ccb;
+ socket->connect(&ccb, server.getAddress(), 30);
+
+ // Hopefully the connect didn't succeed immediately.
+ // If it did, we can't exercise the close-while-connecting code path.
+ if (ccb.state == STATE_SUCCEEDED) {
+ LOG(INFO) << "connect() succeeded immediately; aborting test "
+ "of write-during-connect behavior";
+ return;
+ }
+
+ // write()
+ char buf[128];
+ memset(buf, 'a', sizeof(buf));
+ WriteCallback wcb;
+ socket->write(&wcb, buf, sizeof(buf));
+
+ // close()
+ socket->closeNow();
+
+ // Loop, although there shouldn't be anything to do.
+ evb.loop();
+
+ CHECK_EQ(ccb.state, STATE_FAILED);
+ CHECK_EQ(wcb.state, STATE_FAILED);
+}
+
+/**
+ * Test installing a read callback immediately, before connect() finishes.
+ */
+TEST(AsyncSocketTest, ConnectAndRead) {
+ TestServer server;
+
+ // connect()
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+ ConnCallback ccb;
+ socket->connect(&ccb, server.getAddress(), 30);
+
+ ReadCallback rcb;
+ socket->setReadCB(&rcb);
+
+ // Even though we haven't looped yet, we should be able to accept
+ // the connection and send data to it.
+ std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
+ uint8_t buf[128];
+ memset(buf, 'a', sizeof(buf));
+ acceptedSocket->write(buf, sizeof(buf));
+ acceptedSocket->flush();
+ acceptedSocket->close();
+
+ // Loop, although there shouldn't be anything to do.
+ evb.loop();
+
+ CHECK_EQ(ccb.state, STATE_SUCCEEDED);
+ CHECK_EQ(rcb.state, STATE_SUCCEEDED);
+ CHECK_EQ(rcb.buffers.size(), 1);
+ CHECK_EQ(rcb.buffers[0].length, sizeof(buf));
+ CHECK_EQ(memcmp(rcb.buffers[0].buffer, buf, sizeof(buf)), 0);
+}
+
+/**
+ * Test installing a read callback and then closing immediately before the
+ * connect attempt finishes.
+ */
+TEST(AsyncSocketTest, ConnectReadAndClose) {
+ TestServer server;
+
+ // connect()
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+ ConnCallback ccb;
+ socket->connect(&ccb, server.getAddress(), 30);
+
+ // Hopefully the connect didn't succeed immediately.
+ // If it did, we can't exercise the close-while-connecting code path.
+ if (ccb.state == STATE_SUCCEEDED) {
+ LOG(INFO) << "connect() succeeded immediately; aborting test "
+ "of read-during-connect behavior";
+ return;
+ }
+
+ ReadCallback rcb;
+ socket->setReadCB(&rcb);
+
+ // close()
+ socket->close();
+
+ // Loop, although there shouldn't be anything to do.
+ evb.loop();
+
+ CHECK_EQ(ccb.state, STATE_FAILED); // we aborted the close attempt
+ CHECK_EQ(rcb.buffers.size(), 0);
+ CHECK_EQ(rcb.state, STATE_SUCCEEDED); // this indicates EOF
+}
+
+/**
+ * Test both writing and installing a read callback immediately,
+ * before connect() finishes.
+ */
+TEST(AsyncSocketTest, ConnectWriteAndRead) {
+ TestServer server;
+
+ // connect()
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+ ConnCallback ccb;
+ socket->connect(&ccb, server.getAddress(), 30);
+
+ // write()
+ char buf1[128];
+ memset(buf1, 'a', sizeof(buf1));
+ WriteCallback wcb;
+ socket->write(&wcb, buf1, sizeof(buf1));
+
+ // set a read callback
+ ReadCallback rcb;
+ socket->setReadCB(&rcb);
+
+ // Even though we haven't looped yet, we should be able to accept
+ // the connection and send data to it.
+ std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
+ uint8_t buf2[128];
+ memset(buf2, 'b', sizeof(buf2));
+ acceptedSocket->write(buf2, sizeof(buf2));
+ acceptedSocket->flush();
+
+ // shut down the write half of acceptedSocket, so that the AsyncSocket
+ // will stop reading and we can break out of the event loop.
+ shutdown(acceptedSocket->getSocketFD(), SHUT_WR);
+
+ // Loop
+ evb.loop();
+
+ // Make sure the connect succeeded
+ CHECK_EQ(ccb.state, STATE_SUCCEEDED);
+
+ // Make sure the AsyncSocket read the data written by the accepted socket
+ CHECK_EQ(rcb.state, STATE_SUCCEEDED);
+ CHECK_EQ(rcb.buffers.size(), 1);
+ CHECK_EQ(rcb.buffers[0].length, sizeof(buf2));
+ CHECK_EQ(memcmp(rcb.buffers[0].buffer, buf2, sizeof(buf2)), 0);
+
+ // Close the AsyncSocket so we'll see EOF on acceptedSocket
+ socket->close();
+
+ // Make sure the accepted socket saw the data written by the AsyncSocket
+ uint8_t readbuf[sizeof(buf1)];
+ acceptedSocket->readAll(readbuf, sizeof(readbuf));
+ CHECK_EQ(memcmp(buf1, readbuf, sizeof(buf1)), 0);
+ uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
+ CHECK_EQ(bytesRead, 0);
+}
+
+/**
+ * Test writing to the socket then shutting down writes before the connect
+ * attempt finishes.
+ */
+TEST(AsyncSocketTest, ConnectWriteAndShutdownWrite) {
+ TestServer server;
+
+ // connect()
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+ ConnCallback ccb;
+ socket->connect(&ccb, server.getAddress(), 30);
+
+ // Hopefully the connect didn't succeed immediately.
+ // If it did, we can't exercise the write-while-connecting code path.
+ if (ccb.state == STATE_SUCCEEDED) {
+ LOG(INFO) << "connect() succeeded immediately; skipping test";
+ return;
+ }
+
+ // Ask to write some data
+ char wbuf[128];
+ memset(wbuf, 'a', sizeof(wbuf));
+ WriteCallback wcb;
+ socket->write(&wcb, wbuf, sizeof(wbuf));
+ socket->shutdownWrite();
+
+ // Shutdown writes
+ socket->shutdownWrite();
+
+ // Even though we haven't looped yet, we should be able to accept
+ // the connection.
+ std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
+
+ // Since the connection is still in progress, there should be no data to
+ // read yet. Verify that the accepted socket is not readable.
+ struct pollfd fds[1];
+ fds[0].fd = acceptedSocket->getSocketFD();
+ fds[0].events = POLLIN;
+ fds[0].revents = 0;
+ int rc = poll(fds, 1, 0);
+ CHECK_EQ(rc, 0);
+
+ // Write data to the accepted socket
+ uint8_t acceptedWbuf[192];
+ memset(acceptedWbuf, 'b', sizeof(acceptedWbuf));
+ acceptedSocket->write(acceptedWbuf, sizeof(acceptedWbuf));
+ acceptedSocket->flush();
+
+ // Loop
+ evb.loop();
+
+ // The loop should have completed the connection, written the queued data,
+ // and shutdown writes on the socket.
+ //
+ // Check that the connection was completed successfully and that the write
+ // callback succeeded.
+ CHECK_EQ(ccb.state, STATE_SUCCEEDED);
+ CHECK_EQ(wcb.state, STATE_SUCCEEDED);
+
+ // Check that we can read the data that was written to the socket, and that
+ // we see an EOF, since its socket was half-shutdown.
+ uint8_t readbuf[sizeof(wbuf)];
+ acceptedSocket->readAll(readbuf, sizeof(readbuf));
+ CHECK_EQ(memcmp(wbuf, readbuf, sizeof(wbuf)), 0);
+ uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
+ CHECK_EQ(bytesRead, 0);
+
+ // Close the accepted socket. This will cause it to see EOF
+ // and uninstall the read callback when we loop next.
+ acceptedSocket->close();
+
+ // Install a read callback, then loop again.
+ ReadCallback rcb;
+ socket->setReadCB(&rcb);
+ evb.loop();
+
+ // This loop should have read the data and seen the EOF
+ CHECK_EQ(rcb.state, STATE_SUCCEEDED);
+ CHECK_EQ(rcb.buffers.size(), 1);
+ CHECK_EQ(rcb.buffers[0].length, sizeof(acceptedWbuf));
+ CHECK_EQ(memcmp(rcb.buffers[0].buffer,
+ acceptedWbuf, sizeof(acceptedWbuf)), 0);
+}
+
+/**
+ * Test reading, writing, and shutting down writes before the connect attempt
+ * finishes.
+ */
+TEST(AsyncSocketTest, ConnectReadWriteAndShutdownWrite) {
+ TestServer server;
+
+ // connect()
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+ ConnCallback ccb;
+ socket->connect(&ccb, server.getAddress(), 30);
+
+ // Hopefully the connect didn't succeed immediately.
+ // If it did, we can't exercise the write-while-connecting code path.
+ if (ccb.state == STATE_SUCCEEDED) {
+ LOG(INFO) << "connect() succeeded immediately; skipping test";
+ return;
+ }
+
+ // Install a read callback
+ ReadCallback rcb;
+ socket->setReadCB(&rcb);
+
+ // Ask to write some data
+ char wbuf[128];
+ memset(wbuf, 'a', sizeof(wbuf));
+ WriteCallback wcb;
+ socket->write(&wcb, wbuf, sizeof(wbuf));
+
+ // Shutdown writes
+ socket->shutdownWrite();
+
+ // Even though we haven't looped yet, we should be able to accept
+ // the connection.
+ std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
+
+ // Since the connection is still in progress, there should be no data to
+ // read yet. Verify that the accepted socket is not readable.
+ struct pollfd fds[1];
+ fds[0].fd = acceptedSocket->getSocketFD();
+ fds[0].events = POLLIN;
+ fds[0].revents = 0;
+ int rc = poll(fds, 1, 0);
+ CHECK_EQ(rc, 0);
+
+ // Write data to the accepted socket
+ uint8_t acceptedWbuf[192];
+ memset(acceptedWbuf, 'b', sizeof(acceptedWbuf));
+ acceptedSocket->write(acceptedWbuf, sizeof(acceptedWbuf));
+ acceptedSocket->flush();
+ // Shutdown writes to the accepted socket. This will cause it to see EOF
+ // and uninstall the read callback.
+ ::shutdown(acceptedSocket->getSocketFD(), SHUT_WR);
+
+ // Loop
+ evb.loop();
+
+ // The loop should have completed the connection, written the queued data,
+ // shutdown writes on the socket, read the data we wrote to it, and see the
+ // EOF.
+ //
+ // Check that the connection was completed successfully and that the read
+ // and write callbacks were invoked as expected.
+ CHECK_EQ(ccb.state, STATE_SUCCEEDED);
+ CHECK_EQ(rcb.state, STATE_SUCCEEDED);
+ CHECK_EQ(rcb.buffers.size(), 1);
+ CHECK_EQ(rcb.buffers[0].length, sizeof(acceptedWbuf));
+ CHECK_EQ(memcmp(rcb.buffers[0].buffer,
+ acceptedWbuf, sizeof(acceptedWbuf)), 0);
+ CHECK_EQ(wcb.state, STATE_SUCCEEDED);
+
+ // Check that we can read the data that was written to the socket, and that
+ // we see an EOF, since its socket was half-shutdown.
+ uint8_t readbuf[sizeof(wbuf)];
+ acceptedSocket->readAll(readbuf, sizeof(readbuf));
+ CHECK_EQ(memcmp(wbuf, readbuf, sizeof(wbuf)), 0);
+ uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
+ CHECK_EQ(bytesRead, 0);
+
+ // Fully close both sockets
+ acceptedSocket->close();
+ socket->close();
+}
+
+/**
+ * Test reading, writing, and calling shutdownWriteNow() before the
+ * connect attempt finishes.
+ */
+TEST(AsyncSocketTest, ConnectReadWriteAndShutdownWriteNow) {
+ TestServer server;
+
+ // connect()
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+ ConnCallback ccb;
+ socket->connect(&ccb, server.getAddress(), 30);
+
+ // Hopefully the connect didn't succeed immediately.
+ // If it did, we can't exercise the write-while-connecting code path.
+ if (ccb.state == STATE_SUCCEEDED) {
+ LOG(INFO) << "connect() succeeded immediately; skipping test";
+ return;
+ }
+
+ // Install a read callback
+ ReadCallback rcb;
+ socket->setReadCB(&rcb);
+
+ // Ask to write some data
+ char wbuf[128];
+ memset(wbuf, 'a', sizeof(wbuf));
+ WriteCallback wcb;
+ socket->write(&wcb, wbuf, sizeof(wbuf));
+
+ // Shutdown writes immediately.
+ // This should immediately discard the data that we just tried to write.
+ socket->shutdownWriteNow();
+
+ // Verify that writeError() was invoked on the write callback.
+ CHECK_EQ(wcb.state, STATE_FAILED);
+ CHECK_EQ(wcb.bytesWritten, 0);
+
+ // Even though we haven't looped yet, we should be able to accept
+ // the connection.
+ std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
+
+ // Since the connection is still in progress, there should be no data to
+ // read yet. Verify that the accepted socket is not readable.
+ struct pollfd fds[1];
+ fds[0].fd = acceptedSocket->getSocketFD();
+ fds[0].events = POLLIN;
+ fds[0].revents = 0;
+ int rc = poll(fds, 1, 0);
+ CHECK_EQ(rc, 0);
+
+ // Write data to the accepted socket
+ uint8_t acceptedWbuf[192];
+ memset(acceptedWbuf, 'b', sizeof(acceptedWbuf));
+ acceptedSocket->write(acceptedWbuf, sizeof(acceptedWbuf));
+ acceptedSocket->flush();
+ // Shutdown writes to the accepted socket. This will cause it to see EOF
+ // and uninstall the read callback.
+ ::shutdown(acceptedSocket->getSocketFD(), SHUT_WR);
+
+ // Loop
+ evb.loop();
+
+ // The loop should have completed the connection, written the queued data,
+ // shutdown writes on the socket, read the data we wrote to it, and see the
+ // EOF.
+ //
+ // Check that the connection was completed successfully and that the read
+ // callback was invoked as expected.
+ CHECK_EQ(ccb.state, STATE_SUCCEEDED);
+ CHECK_EQ(rcb.state, STATE_SUCCEEDED);
+ CHECK_EQ(rcb.buffers.size(), 1);
+ CHECK_EQ(rcb.buffers[0].length, sizeof(acceptedWbuf));
+ CHECK_EQ(memcmp(rcb.buffers[0].buffer,
+ acceptedWbuf, sizeof(acceptedWbuf)), 0);
+
+ // Since we used shutdownWriteNow(), it should have discarded all pending
+ // write data. Verify we see an immediate EOF when reading from the accepted
+ // socket.
+ uint8_t readbuf[sizeof(wbuf)];
+ uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
+ CHECK_EQ(bytesRead, 0);
+
+ // Fully close both sockets
+ acceptedSocket->close();
+ socket->close();
+}
+
+// Helper function for use in testConnectOptWrite()
+// Temporarily disable the read callback
+void tmpDisableReads(AsyncSocket* socket, ReadCallback* rcb) {
+ // Uninstall the read callback
+ socket->setReadCB(nullptr);
+ // Schedule the read callback to be reinstalled after 1ms
+ socket->getEventBase()->runInLoop(
+ std::bind(&AsyncSocket::setReadCB, socket, rcb));
+}
+
+/**
+ * Test connect+write, then have the connect callback perform another write.
+ *
+ * This tests interaction of the optimistic writing after connect with
+ * additional write attempts that occur in the connect callback.
+ */
+void testConnectOptWrite(size_t size1, size_t size2, bool close = false) {
+ TestServer server;
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+
+ // connect()
+ ConnCallback ccb;
+ socket->connect(&ccb, server.getAddress(), 30);
+
+ // Hopefully the connect didn't succeed immediately.
+ // If it did, we can't exercise the optimistic write code path.
+ if (ccb.state == STATE_SUCCEEDED) {
+ LOG(INFO) << "connect() succeeded immediately; aborting test "
+ "of optimistic write behavior";
+ return;
+ }
+
+ // Tell the connect callback to perform a write when the connect succeeds
+ WriteCallback wcb2;
+ scoped_array<char> buf2(new char[size2]);
+ memset(buf2.get(), 'b', size2);
+ if (size2 > 0) {
+ ccb.successCallback = [&] { socket->write(&wcb2, buf2.get(), size2); };
+ // Tell the second write callback to close the connection when it is done
+ wcb2.successCallback = [&] { socket->closeNow(); };
+ }
+
+ // Schedule one write() immediately, before the connect finishes
+ scoped_array<char> buf1(new char[size1]);
+ memset(buf1.get(), 'a', size1);
+ WriteCallback wcb1;
+ if (size1 > 0) {
+ socket->write(&wcb1, buf1.get(), size1);
+ }
+
+ if (close) {
+ // immediately perform a close, before connect() completes
+ socket->close();
+ }
+
+ // Start reading from the other endpoint after 10ms.
+ // If we're using large buffers, we have to read so that the writes don't
+ // block forever.
+ std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&evb);
+ ReadCallback rcb;
+ rcb.dataAvailableCallback = std::bind(tmpDisableReads,
+ acceptedSocket.get(), &rcb);
+ socket->getEventBase()->tryRunAfterDelay(
+ std::bind(&AsyncSocket::setReadCB, acceptedSocket.get(), &rcb),
+ 10);
+
+ // Loop. We don't bother accepting on the server socket yet.
+ // The kernel should be able to buffer the write request so it can succeed.
+ evb.loop();
+
+ CHECK_EQ(ccb.state, STATE_SUCCEEDED);
+ if (size1 > 0) {
+ CHECK_EQ(wcb1.state, STATE_SUCCEEDED);
+ }
+ if (size2 > 0) {
+ CHECK_EQ(wcb2.state, STATE_SUCCEEDED);
+ }
+
+ socket->close();
+
+ // Make sure the read callback received all of the data
+ size_t bytesRead = 0;
+ for (vector<ReadCallback::Buffer>::const_iterator it = rcb.buffers.begin();
+ it != rcb.buffers.end();
+ ++it) {
+ size_t start = bytesRead;
+ bytesRead += it->length;
+ size_t end = bytesRead;
+ if (start < size1) {
+ size_t cmpLen = min(size1, end) - start;
+ CHECK_EQ(memcmp(it->buffer, buf1.get() + start, cmpLen), 0);
+ }
+ if (end > size1 && end <= size1 + size2) {
+ size_t itOffset;
+ size_t buf2Offset;
+ size_t cmpLen;
+ if (start >= size1) {
+ itOffset = 0;
+ buf2Offset = start - size1;
+ cmpLen = end - start;
+ } else {
+ itOffset = size1 - start;
+ buf2Offset = 0;
+ cmpLen = end - size1;
+ }
+ CHECK_EQ(memcmp(it->buffer + itOffset, buf2.get() + buf2Offset,
+ cmpLen),
+ 0);
+ }
+ }
+ CHECK_EQ(bytesRead, size1 + size2);
+}
+
+TEST(AsyncSocketTest, ConnectCallbackWrite) {
+ // Test using small writes that should both succeed immediately
+ testConnectOptWrite(100, 200);
+
+ // Test using a large buffer in the connect callback, that should block
+ const size_t largeSize = 8*1024*1024;
+ testConnectOptWrite(100, largeSize);
+
+ // Test using a large initial write
+ testConnectOptWrite(largeSize, 100);
+
+ // Test using two large buffers
+ testConnectOptWrite(largeSize, largeSize);
+
+ // Test a small write in the connect callback,
+ // but no immediate write before connect completes
+ testConnectOptWrite(0, 64);
+
+ // Test a large write in the connect callback,
+ // but no immediate write before connect completes
+ testConnectOptWrite(0, largeSize);
+
+ // Test connect, a small write, then immediately call close() before connect
+ // completes
+ testConnectOptWrite(211, 0, true);
+
+ // Test connect, a large immediate write (that will block), then immediately
+ // call close() before connect completes
+ testConnectOptWrite(largeSize, 0, true);
+}
+
+///////////////////////////////////////////////////////////////////////////
+// write() related tests
+///////////////////////////////////////////////////////////////////////////
+
+/**
+ * Test writing using a nullptr callback
+ */
+TEST(AsyncSocketTest, WriteNullCallback) {
+ TestServer server;
+
+ // connect()
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket =
+ AsyncSocket::newSocket(&evb, server.getAddress(), 30);
+ evb.loop(); // loop until the socket is connected
+
+ // write() with a nullptr callback
+ char buf[128];
+ memset(buf, 'a', sizeof(buf));
+ socket->write(nullptr, buf, sizeof(buf));
+
+ evb.loop(); // loop until the data is sent
+
+ // Make sure the server got a connection and received the data
+ socket->close();
+ server.verifyConnection(buf, sizeof(buf));
+}
+
+/**
+ * Test writing with a send timeout
+ */
+TEST(AsyncSocketTest, WriteTimeout) {
+ TestServer server;
+
+ // connect()
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket =
+ AsyncSocket::newSocket(&evb, server.getAddress(), 30);
+ evb.loop(); // loop until the socket is connected
+
+ // write() a large chunk of data, with no-one on the other end reading
+ size_t writeLength = 8*1024*1024;
+ uint32_t timeout = 200;
+ socket->setSendTimeout(timeout);
+ scoped_array<char> buf(new char[writeLength]);
+ memset(buf.get(), 'a', writeLength);
+ WriteCallback wcb;
+ socket->write(&wcb, buf.get(), writeLength);
+
+ TimePoint start;
+ evb.loop();
+ TimePoint end;
+
+ // Make sure the write attempt timed out as requested
+ CHECK_EQ(wcb.state, STATE_FAILED);
+ CHECK_EQ(wcb.exception.getType(), AsyncSocketException::TIMED_OUT);
+
+ // Check that the write timed out within a reasonable period of time.
+ // We don't check for exactly the specified timeout, since AsyncSocket only
+ // times out when it hasn't made progress for that period of time.
+ //
+ // On linux, the first write sends a few hundred kb of data, then blocks for
+ // writability, and then unblocks again after 40ms and is able to write
+ // another smaller of data before blocking permanently. Therefore it doesn't
+ // time out until 40ms + timeout.
+ //
+ // I haven't fully verified the cause of this, but I believe it probably
+ // occurs because the receiving end delays sending an ack for up to 40ms.
+ // (This is the default value for TCP_DELACK_MIN.) Once the sender receives
+ // the ack, it can send some more data. However, after that point the
+ // receiver's kernel buffer is full. This 40ms delay happens even with
+ // TCP_NODELAY and TCP_QUICKACK enabled on both endpoints. However, the
+ // kernel may be automatically disabling TCP_QUICKACK after receiving some
+ // data.
+ //
+ // For now, we simply check that the timeout occurred within 160ms of
+ // the requested value.
+ T_CHECK_TIMEOUT(start, end, milliseconds(timeout), milliseconds(160));
+}
+
+/**
+ * Test writing to a socket that the remote endpoint has closed
+ */
+TEST(AsyncSocketTest, WritePipeError) {
+ TestServer server;
+
+ // connect()
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket =
+ AsyncSocket::newSocket(&evb, server.getAddress(), 30);
+ socket->setSendTimeout(1000);
+ evb.loop(); // loop until the socket is connected
+
+ // accept and immediately close the socket
+ std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
+ acceptedSocket.reset();
+
+ // write() a large chunk of data
+ size_t writeLength = 8*1024*1024;
+ scoped_array<char> buf(new char[writeLength]);
+ memset(buf.get(), 'a', writeLength);
+ WriteCallback wcb;
+ socket->write(&wcb, buf.get(), writeLength);
+
+ evb.loop();
+
+ // Make sure the write failed.
+ // It would be nice if AsyncSocketException could convey the errno value,
+ // so that we could check for EPIPE
+ CHECK_EQ(wcb.state, STATE_FAILED);
+ CHECK_EQ(wcb.exception.getType(),
+ AsyncSocketException::INTERNAL_ERROR);
+}
+
+/**
+ * Test writing a mix of simple buffers and IOBufs
+ */
+TEST(AsyncSocketTest, WriteIOBuf) {
+ TestServer server;
+
+ // connect()
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+ ConnCallback ccb;
+ socket->connect(&ccb, server.getAddress(), 30);
+
+ // Accept the connection
+ std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&evb);
+ ReadCallback rcb;
+ acceptedSocket->setReadCB(&rcb);
+
+ // Write a simple buffer to the socket
+ size_t simpleBufLength = 5;
+ char simpleBuf[simpleBufLength];
+ memset(simpleBuf, 'a', simpleBufLength);
+ WriteCallback wcb;
+ socket->write(&wcb, simpleBuf, simpleBufLength);
+
+ // Write a single-element IOBuf chain
+ size_t buf1Length = 7;
+ unique_ptr<IOBuf> buf1(IOBuf::create(buf1Length));
+ memset(buf1->writableData(), 'b', buf1Length);
+ buf1->append(buf1Length);
+ unique_ptr<IOBuf> buf1Copy(buf1->clone());
+ WriteCallback wcb2;
+ socket->writeChain(&wcb2, std::move(buf1));
+
+ // Write a multiple-element IOBuf chain
+ size_t buf2Length = 11;
+ unique_ptr<IOBuf> buf2(IOBuf::create(buf2Length));
+ memset(buf2->writableData(), 'c', buf2Length);
+ buf2->append(buf2Length);
+ size_t buf3Length = 13;
+ unique_ptr<IOBuf> buf3(IOBuf::create(buf3Length));
+ memset(buf3->writableData(), 'd', buf3Length);
+ buf3->append(buf3Length);
+ buf2->appendChain(std::move(buf3));
+ unique_ptr<IOBuf> buf2Copy(buf2->clone());
+ buf2Copy->coalesce();
+ WriteCallback wcb3;
+ socket->writeChain(&wcb3, std::move(buf2));
+ socket->shutdownWrite();
+
+ // Let the reads and writes run to completion
+ evb.loop();
+
+ CHECK_EQ(wcb.state, STATE_SUCCEEDED);
+ CHECK_EQ(wcb2.state, STATE_SUCCEEDED);
+ CHECK_EQ(wcb3.state, STATE_SUCCEEDED);
+
+ // Make sure the reader got the right data in the right order
+ CHECK_EQ(rcb.state, STATE_SUCCEEDED);
+ CHECK_EQ(rcb.buffers.size(), 1);
+ CHECK_EQ(rcb.buffers[0].length,
+ simpleBufLength + buf1Length + buf2Length + buf3Length);
+ CHECK_EQ(
+ memcmp(rcb.buffers[0].buffer, simpleBuf, simpleBufLength), 0);
+ CHECK_EQ(
+ memcmp(rcb.buffers[0].buffer + simpleBufLength,
+ buf1Copy->data(), buf1Copy->length()), 0);
+ CHECK_EQ(
+ memcmp(rcb.buffers[0].buffer + simpleBufLength + buf1Length,
+ buf2Copy->data(), buf2Copy->length()), 0);
+
+ acceptedSocket->close();
+ socket->close();
+}
+
+TEST(AsyncSocketTest, WriteIOBufCorked) {
+ TestServer server;
+
+ // connect()
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+ ConnCallback ccb;
+ socket->connect(&ccb, server.getAddress(), 30);
+
+ // Accept the connection
+ std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&evb);
+ ReadCallback rcb;
+ acceptedSocket->setReadCB(&rcb);
+
+ // Do three writes, 100ms apart, with the "cork" flag set
+ // on the second write. The reader should see the first write
+ // arrive by itself, followed by the second and third writes
+ // arriving together.
+ size_t buf1Length = 5;
+ unique_ptr<IOBuf> buf1(IOBuf::create(buf1Length));
+ memset(buf1->writableData(), 'a', buf1Length);
+ buf1->append(buf1Length);
+ size_t buf2Length = 7;
+ unique_ptr<IOBuf> buf2(IOBuf::create(buf2Length));
+ memset(buf2->writableData(), 'b', buf2Length);
+ buf2->append(buf2Length);
+ size_t buf3Length = 11;
+ unique_ptr<IOBuf> buf3(IOBuf::create(buf3Length));
+ memset(buf3->writableData(), 'c', buf3Length);
+ buf3->append(buf3Length);
+ WriteCallback wcb1;
+ socket->writeChain(&wcb1, std::move(buf1));
+ WriteCallback wcb2;
+ DelayedWrite write2(socket, std::move(buf2), &wcb2, true);
+ write2.scheduleTimeout(100);
+ WriteCallback wcb3;
+ DelayedWrite write3(socket, std::move(buf3), &wcb3, false, true);
+ write3.scheduleTimeout(200);
+
+ evb.loop();
+ CHECK_EQ(ccb.state, STATE_SUCCEEDED);
+ CHECK_EQ(wcb1.state, STATE_SUCCEEDED);
+ CHECK_EQ(wcb2.state, STATE_SUCCEEDED);
+ if (wcb3.state != STATE_SUCCEEDED) {
+ throw(wcb3.exception);
+ }
+ CHECK_EQ(wcb3.state, STATE_SUCCEEDED);
+
+ // Make sure the reader got the data with the right grouping
+ CHECK_EQ(rcb.state, STATE_SUCCEEDED);
+ CHECK_EQ(rcb.buffers.size(), 2);
+ CHECK_EQ(rcb.buffers[0].length, buf1Length);
+ CHECK_EQ(rcb.buffers[1].length, buf2Length + buf3Length);
+
+ acceptedSocket->close();
+ socket->close();
+}
+
+/**
+ * Test performing a zero-length write
+ */
+TEST(AsyncSocketTest, ZeroLengthWrite) {
+ TestServer server;
+
+ // connect()
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket =
+ AsyncSocket::newSocket(&evb, server.getAddress(), 30);
+ evb.loop(); // loop until the socket is connected
+
+ auto acceptedSocket = server.acceptAsync(&evb);
+ ReadCallback rcb;
+ acceptedSocket->setReadCB(&rcb);
+
+ size_t len1 = 1024*1024;
+ size_t len2 = 1024*1024;
+ std::unique_ptr<char[]> buf(new char[len1 + len2]);
+ memset(buf.get(), 'a', len1);
+ memset(buf.get(), 'b', len2);
+
+ WriteCallback wcb1;
+ WriteCallback wcb2;
+ WriteCallback wcb3;
+ WriteCallback wcb4;
+ socket->write(&wcb1, buf.get(), 0);
+ socket->write(&wcb2, buf.get(), len1);
+ socket->write(&wcb3, buf.get() + len1, 0);
+ socket->write(&wcb4, buf.get() + len1, len2);
+ socket->close();
+
+ evb.loop(); // loop until the data is sent
+
+ CHECK_EQ(wcb1.state, STATE_SUCCEEDED);
+ CHECK_EQ(wcb2.state, STATE_SUCCEEDED);
+ CHECK_EQ(wcb3.state, STATE_SUCCEEDED);
+ CHECK_EQ(wcb4.state, STATE_SUCCEEDED);
+ rcb.verifyData(buf.get(), len1 + len2);
+}
+
+TEST(AsyncSocketTest, ZeroLengthWritev) {
+ TestServer server;
+
+ // connect()
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket =
+ AsyncSocket::newSocket(&evb, server.getAddress(), 30);
+ evb.loop(); // loop until the socket is connected
+
+ auto acceptedSocket = server.acceptAsync(&evb);
+ ReadCallback rcb;
+ acceptedSocket->setReadCB(&rcb);
+
+ size_t len1 = 1024*1024;
+ size_t len2 = 1024*1024;
+ std::unique_ptr<char[]> buf(new char[len1 + len2]);
+ memset(buf.get(), 'a', len1);
+ memset(buf.get(), 'b', len2);
+
+ WriteCallback wcb;
+ size_t iovCount = 4;
+ struct iovec iov[iovCount];
+ iov[0].iov_base = buf.get();
+ iov[0].iov_len = len1;
+ iov[1].iov_base = buf.get() + len1;
+ iov[1].iov_len = 0;
+ iov[2].iov_base = buf.get() + len1;
+ iov[2].iov_len = len2;
+ iov[3].iov_base = buf.get() + len1 + len2;
+ iov[3].iov_len = 0;
+
+ socket->writev(&wcb, iov, iovCount);
+ socket->close();
+ evb.loop(); // loop until the data is sent
+
+ CHECK_EQ(wcb.state, STATE_SUCCEEDED);
+ rcb.verifyData(buf.get(), len1 + len2);
+}
+
+///////////////////////////////////////////////////////////////////////////
+// close() related tests
+///////////////////////////////////////////////////////////////////////////
+
+/**
+ * Test calling close() with pending writes when the socket is already closing.
+ */
+TEST(AsyncSocketTest, ClosePendingWritesWhileClosing) {
+ TestServer server;
+
+ // connect()
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+ ConnCallback ccb;
+ socket->connect(&ccb, server.getAddress(), 30);
+
+ // accept the socket on the server side
+ std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
+
+ // Loop to ensure the connect has completed
+ evb.loop();
+
+ // Make sure we are connected
+ CHECK_EQ(ccb.state, STATE_SUCCEEDED);
+
+ // Schedule pending writes, until several write attempts have blocked
+ char buf[128];
+ memset(buf, 'a', sizeof(buf));
+ typedef vector< std::shared_ptr<WriteCallback> > WriteCallbackVector;
+ WriteCallbackVector writeCallbacks;
+
+ writeCallbacks.reserve(5);
+ while (writeCallbacks.size() < 5) {
+ std::shared_ptr<WriteCallback> wcb(new WriteCallback);
+
+ socket->write(wcb.get(), buf, sizeof(buf));
+ if (wcb->state == STATE_SUCCEEDED) {
+ // Succeeded immediately. Keep performing more writes
+ continue;
+ }
+
+ // This write is blocked.
+ // Have the write callback call close() when writeError() is invoked
+ wcb->errorCallback = std::bind(&AsyncSocket::close, socket.get());
+ writeCallbacks.push_back(wcb);
+ }
+
+ // Call closeNow() to immediately fail the pending writes
+ socket->closeNow();
+
+ // Make sure writeError() was invoked on all of the pending write callbacks
+ for (WriteCallbackVector::const_iterator it = writeCallbacks.begin();
+ it != writeCallbacks.end();
+ ++it) {
+ CHECK_EQ((*it)->state, STATE_FAILED);
+ }
+}
+
+
+// TODO:
+// - Test connect() and have the connect callback set the read callback
+// - Test connect() and have the connect callback unset the read callback
+// - Test reading/writing/closing/destroying the socket in the connect callback
+// - Test reading/writing/closing/destroying the socket in the read callback
+// - Test reading/writing/closing/destroying the socket in the write callback
+// - Test one-way shutdown behavior
+// - Test changing the EventBase
+//
+// - TODO: test multiple threads sharing a AsyncSocket, and detaching from it
+// in connectSuccess(), readDataAvailable(), writeSuccess()
+
+
+///////////////////////////////////////////////////////////////////////////
+// AsyncServerSocket tests
+///////////////////////////////////////////////////////////////////////////
+
+/**
+ * Helper AcceptCallback class for the test code
+ * It records the callbacks that were invoked, and also supports calling
+ * generic std::function objects in each callback.
+ */
+class TestAcceptCallback : public AsyncServerSocket::AcceptCallback {
+ public:
+ enum EventType {
+ TYPE_START,
+ TYPE_ACCEPT,
+ TYPE_ERROR,
+ TYPE_STOP
+ };
+ struct EventInfo {
+ EventInfo(int fd, const folly::SocketAddress& addr)
+ : type(TYPE_ACCEPT),
+ fd(fd),
+ address(addr),
+ errorMsg() {}
+ explicit EventInfo(const std::string& msg)
+ : type(TYPE_ERROR),
+ fd(-1),
+ address(),
+ errorMsg(msg) {}
+ explicit EventInfo(EventType et)
+ : type(et),
+ fd(-1),
+ address(),
+ errorMsg() {}
+
+ EventType type;
+ int fd; // valid for TYPE_ACCEPT
+ folly::SocketAddress address; // valid for TYPE_ACCEPT
+ string errorMsg; // valid for TYPE_ERROR
+ };
+ typedef std::deque<EventInfo> EventList;
+
+ TestAcceptCallback()
+ : connectionAcceptedFn_(),
+ acceptErrorFn_(),
+ acceptStoppedFn_(),
+ events_() {}
+
+ std::deque<EventInfo>* getEvents() {
+ return &events_;
+ }
+
+ void setConnectionAcceptedFn(
+ const std::function<void(int, const folly::SocketAddress&)>& fn) {
+ connectionAcceptedFn_ = fn;
+ }
+ void setAcceptErrorFn(const std::function<void(const std::exception&)>& fn) {
+ acceptErrorFn_ = fn;
+ }
+ void setAcceptStartedFn(const std::function<void()>& fn) {
+ acceptStartedFn_ = fn;
+ }
+ void setAcceptStoppedFn(const std::function<void()>& fn) {
+ acceptStoppedFn_ = fn;
+ }
+
+ void connectionAccepted(int fd, const folly::SocketAddress& clientAddr)
+ noexcept {
+ events_.push_back(EventInfo(fd, clientAddr));
+
+ if (connectionAcceptedFn_) {
+ connectionAcceptedFn_(fd, clientAddr);
+ }
+ }
+ void acceptError(const std::exception& ex) noexcept {
+ events_.push_back(EventInfo(ex.what()));
+
+ if (acceptErrorFn_) {
+ acceptErrorFn_(ex);
+ }
+ }
+ void acceptStarted() noexcept {
+ events_.push_back(EventInfo(TYPE_START));
+
+ if (acceptStartedFn_) {
+ acceptStartedFn_();
+ }
+ }
+ void acceptStopped() noexcept {
+ events_.push_back(EventInfo(TYPE_STOP));
+
+ if (acceptStoppedFn_) {
+ acceptStoppedFn_();
+ }
+ }
+
+ private:
+ std::function<void(int, const folly::SocketAddress&)> connectionAcceptedFn_;
+ std::function<void(const std::exception&)> acceptErrorFn_;
+ std::function<void()> acceptStartedFn_;
+ std::function<void()> acceptStoppedFn_;
+
+ std::deque<EventInfo> events_;
+};
+
+/**
+ * Make sure accepted sockets have O_NONBLOCK and TCP_NODELAY set
+ */
+TEST(AsyncSocketTest, ServerAcceptOptions) {
+ EventBase eventBase;
+
+ // Create a server socket
+ std::shared_ptr<AsyncServerSocket> serverSocket(
+ AsyncServerSocket::newSocket(&eventBase));
+ serverSocket->bind(0);
+ serverSocket->listen(16);
+ folly::SocketAddress serverAddress;
+ serverSocket->getAddress(&serverAddress);
+
+ // Add a callback to accept one connection then stop the loop
+ TestAcceptCallback acceptCallback;
+ acceptCallback.setConnectionAcceptedFn(
+ [&](int fd, const folly::SocketAddress& addr) {
+ serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
+ });
+ acceptCallback.setAcceptErrorFn([&](const std::exception& ex) {
+ serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
+ });
+ serverSocket->addAcceptCallback(&acceptCallback, nullptr);
+ serverSocket->startAccepting();
+
+ // Connect to the server socket
+ std::shared_ptr<AsyncSocket> socket(
+ AsyncSocket::newSocket(&eventBase, serverAddress));
+
+ eventBase.loop();
+
+ // Verify that the server accepted a connection
+ CHECK_EQ(acceptCallback.getEvents()->size(), 3);
+ CHECK_EQ(acceptCallback.getEvents()->at(0).type,
+ TestAcceptCallback::TYPE_START);
+ CHECK_EQ(acceptCallback.getEvents()->at(1).type,
+ TestAcceptCallback::TYPE_ACCEPT);
+ CHECK_EQ(acceptCallback.getEvents()->at(2).type,
+ TestAcceptCallback::TYPE_STOP);
+ int fd = acceptCallback.getEvents()->at(1).fd;
+
+ // The accepted connection should already be in non-blocking mode
+ int flags = fcntl(fd, F_GETFL, 0);
+ CHECK_EQ(flags & O_NONBLOCK, O_NONBLOCK);
+
+#ifndef TCP_NOPUSH
+ // The accepted connection should already have TCP_NODELAY set
+ int value;
+ socklen_t valueLength = sizeof(value);
+ int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
+ CHECK_EQ(rc, 0);
+ CHECK_EQ(value, 1);
+#endif
+}
+
+/**
+ * Test AsyncServerSocket::removeAcceptCallback()
+ */
+TEST(AsyncSocketTest, RemoveAcceptCallback) {
+ // Create a new AsyncServerSocket
+ EventBase eventBase;
+ std::shared_ptr<AsyncServerSocket> serverSocket(
+ AsyncServerSocket::newSocket(&eventBase));
+ serverSocket->bind(0);
+ serverSocket->listen(16);
+ folly::SocketAddress serverAddress;
+ serverSocket->getAddress(&serverAddress);
+
+ // Add several accept callbacks
+ TestAcceptCallback cb1;
+ TestAcceptCallback cb2;
+ TestAcceptCallback cb3;
+ TestAcceptCallback cb4;
+ TestAcceptCallback cb5;
+ TestAcceptCallback cb6;
+ TestAcceptCallback cb7;
+
+ // Test having callbacks remove other callbacks before them on the list,
+ // after them on the list, or removing themselves.
+ //
+ // Have callback 2 remove callback 3 and callback 5 the first time it is
+ // called.
+ int cb2Count = 0;
+ cb1.setConnectionAcceptedFn([&](int fd, const folly::SocketAddress& addr){
+ std::shared_ptr<AsyncSocket> sock2(
+ AsyncSocket::newSocket(&eventBase, serverAddress)); // cb2: -cb3 -cb5
+ });
+ cb3.setConnectionAcceptedFn([&](int fd, const folly::SocketAddress& addr){
+ });
+ cb4.setConnectionAcceptedFn([&](int fd, const folly::SocketAddress& addr){
+ std::shared_ptr<AsyncSocket> sock3(
+ AsyncSocket::newSocket(&eventBase, serverAddress)); // cb4
+ });
+ cb5.setConnectionAcceptedFn([&](int fd, const folly::SocketAddress& addr){
+ std::shared_ptr<AsyncSocket> sock5(
+ AsyncSocket::newSocket(&eventBase, serverAddress)); // cb7: -cb7
+
+ });
+ cb2.setConnectionAcceptedFn(
+ [&](int fd, const folly::SocketAddress& addr) {
+ if (cb2Count == 0) {
+ serverSocket->removeAcceptCallback(&cb3, nullptr);
+ serverSocket->removeAcceptCallback(&cb5, nullptr);
+ }
+ ++cb2Count;
+ });
+ // Have callback 6 remove callback 4 the first time it is called,
+ // and destroy the server socket the second time it is called
+ int cb6Count = 0;
+ cb6.setConnectionAcceptedFn(
+ [&](int fd, const folly::SocketAddress& addr) {
+ if (cb6Count == 0) {
+ serverSocket->removeAcceptCallback(&cb4, nullptr);
+ std::shared_ptr<AsyncSocket> sock6(
+ AsyncSocket::newSocket(&eventBase, serverAddress)); // cb1
+ std::shared_ptr<AsyncSocket> sock7(
+ AsyncSocket::newSocket(&eventBase, serverAddress)); // cb2
+ std::shared_ptr<AsyncSocket> sock8(
+ AsyncSocket::newSocket(&eventBase, serverAddress)); // cb6: stop
+
+ } else {
+ serverSocket.reset();
+ }
+ ++cb6Count;
+ });
+ // Have callback 7 remove itself
+ cb7.setConnectionAcceptedFn(
+ [&](int fd, const folly::SocketAddress& addr) {
+ serverSocket->removeAcceptCallback(&cb7, nullptr);
+ });
+
+ serverSocket->addAcceptCallback(&cb1, nullptr);
+ serverSocket->addAcceptCallback(&cb2, nullptr);
+ serverSocket->addAcceptCallback(&cb3, nullptr);
+ serverSocket->addAcceptCallback(&cb4, nullptr);
+ serverSocket->addAcceptCallback(&cb5, nullptr);
+ serverSocket->addAcceptCallback(&cb6, nullptr);
+ serverSocket->addAcceptCallback(&cb7, nullptr);
+ serverSocket->startAccepting();
+
+ // Make several connections to the socket
+ std::shared_ptr<AsyncSocket> sock1(
+ AsyncSocket::newSocket(&eventBase, serverAddress)); // cb1
+ std::shared_ptr<AsyncSocket> sock4(
+ AsyncSocket::newSocket(&eventBase, serverAddress)); // cb6: -cb4
+
+ // Loop until we are stopped
+ eventBase.loop();
+
+ // Check to make sure that the expected callbacks were invoked.
+ //
+ // NOTE: This code depends on the AsyncServerSocket operating calling all of
+ // the AcceptCallbacks in round-robin fashion, in the order that they were
+ // added. The code is implemented this way right now, but the API doesn't
+ // explicitly require it be done this way. If we change the code not to be
+ // exactly round robin in the future, we can simplify the test checks here.
+ // (We'll also need to update the termination code, since we expect cb6 to
+ // get called twice to terminate the loop.)
+ CHECK_EQ(cb1.getEvents()->size(), 4);
+ CHECK_EQ(cb1.getEvents()->at(0).type,
+ TestAcceptCallback::TYPE_START);
+ CHECK_EQ(cb1.getEvents()->at(1).type,
+ TestAcceptCallback::TYPE_ACCEPT);
+ CHECK_EQ(cb1.getEvents()->at(2).type,
+ TestAcceptCallback::TYPE_ACCEPT);
+ CHECK_EQ(cb1.getEvents()->at(3).type,
+ TestAcceptCallback::TYPE_STOP);
+
+ CHECK_EQ(cb2.getEvents()->size(), 4);
+ CHECK_EQ(cb2.getEvents()->at(0).type,
+ TestAcceptCallback::TYPE_START);
+ CHECK_EQ(cb2.getEvents()->at(1).type,
+ TestAcceptCallback::TYPE_ACCEPT);
+ CHECK_EQ(cb2.getEvents()->at(2).type,
+ TestAcceptCallback::TYPE_ACCEPT);
+ CHECK_EQ(cb2.getEvents()->at(3).type,
+ TestAcceptCallback::TYPE_STOP);
+
+ CHECK_EQ(cb3.getEvents()->size(), 2);
+ CHECK_EQ(cb3.getEvents()->at(0).type,
+ TestAcceptCallback::TYPE_START);
+ CHECK_EQ(cb3.getEvents()->at(1).type,
+ TestAcceptCallback::TYPE_STOP);
+
+ CHECK_EQ(cb4.getEvents()->size(), 3);
+ CHECK_EQ(cb4.getEvents()->at(0).type,
+ TestAcceptCallback::TYPE_START);
+ CHECK_EQ(cb4.getEvents()->at(1).type,
+ TestAcceptCallback::TYPE_ACCEPT);
+ CHECK_EQ(cb4.getEvents()->at(2).type,
+ TestAcceptCallback::TYPE_STOP);
+
+ CHECK_EQ(cb5.getEvents()->size(), 2);
+ CHECK_EQ(cb5.getEvents()->at(0).type,
+ TestAcceptCallback::TYPE_START);
+ CHECK_EQ(cb5.getEvents()->at(1).type,
+ TestAcceptCallback::TYPE_STOP);
+
+ CHECK_EQ(cb6.getEvents()->size(), 4);
+ CHECK_EQ(cb6.getEvents()->at(0).type,
+ TestAcceptCallback::TYPE_START);
+ CHECK_EQ(cb6.getEvents()->at(1).type,
+ TestAcceptCallback::TYPE_ACCEPT);
+ CHECK_EQ(cb6.getEvents()->at(2).type,
+ TestAcceptCallback::TYPE_ACCEPT);
+ CHECK_EQ(cb6.getEvents()->at(3).type,
+ TestAcceptCallback::TYPE_STOP);
+
+ CHECK_EQ(cb7.getEvents()->size(), 3);
+ CHECK_EQ(cb7.getEvents()->at(0).type,
+ TestAcceptCallback::TYPE_START);
+ CHECK_EQ(cb7.getEvents()->at(1).type,
+ TestAcceptCallback::TYPE_ACCEPT);
+ CHECK_EQ(cb7.getEvents()->at(2).type,
+ TestAcceptCallback::TYPE_STOP);
+}
+
+/**
+ * Test AsyncServerSocket::removeAcceptCallback()
+ */
+TEST(AsyncSocketTest, OtherThreadAcceptCallback) {
+ // Create a new AsyncServerSocket
+ EventBase eventBase;
+ std::shared_ptr<AsyncServerSocket> serverSocket(
+ AsyncServerSocket::newSocket(&eventBase));
+ serverSocket->bind(0);
+ serverSocket->listen(16);
+ folly::SocketAddress serverAddress;
+ serverSocket->getAddress(&serverAddress);
+
+ // Add several accept callbacks
+ TestAcceptCallback cb1;
+ auto thread_id = pthread_self();
+ cb1.setAcceptStartedFn([&](){
+ CHECK_NE(thread_id, pthread_self());
+ thread_id = pthread_self();
+ });
+ cb1.setConnectionAcceptedFn([&](int fd, const folly::SocketAddress& addr){
+ CHECK_EQ(thread_id, pthread_self());
+ serverSocket->removeAcceptCallback(&cb1, nullptr);
+ });
+ cb1.setAcceptStoppedFn([&](){
+ CHECK_EQ(thread_id, pthread_self());
+ });
+
+ // Test having callbacks remove other callbacks before them on the list,
+ serverSocket->addAcceptCallback(&cb1, nullptr);
+ serverSocket->startAccepting();
+
+ // Make several connections to the socket
+ std::shared_ptr<AsyncSocket> sock1(
+ AsyncSocket::newSocket(&eventBase, serverAddress)); // cb1
+
+ // Loop in another thread
+ auto other = std::thread([&](){
+ eventBase.loop();
+ });
+ other.join();
+
+ // Check to make sure that the expected callbacks were invoked.
+ //
+ // NOTE: This code depends on the AsyncServerSocket operating calling all of
+ // the AcceptCallbacks in round-robin fashion, in the order that they were
+ // added. The code is implemented this way right now, but the API doesn't
+ // explicitly require it be done this way. If we change the code not to be
+ // exactly round robin in the future, we can simplify the test checks here.
+ // (We'll also need to update the termination code, since we expect cb6 to
+ // get called twice to terminate the loop.)
+ CHECK_EQ(cb1.getEvents()->size(), 3);
+ CHECK_EQ(cb1.getEvents()->at(0).type,
+ TestAcceptCallback::TYPE_START);
+ CHECK_EQ(cb1.getEvents()->at(1).type,
+ TestAcceptCallback::TYPE_ACCEPT);
+ CHECK_EQ(cb1.getEvents()->at(2).type,
+ TestAcceptCallback::TYPE_STOP);
+
+}
+
+void serverSocketSanityTest(AsyncServerSocket* serverSocket) {
+ // Add a callback to accept one connection then stop accepting
+ TestAcceptCallback acceptCallback;
+ acceptCallback.setConnectionAcceptedFn(
+ [&](int fd, const folly::SocketAddress& addr) {
+ serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
+ });
+ acceptCallback.setAcceptErrorFn([&](const std::exception& ex) {
+ serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
+ });
+ serverSocket->addAcceptCallback(&acceptCallback, nullptr);
+ serverSocket->startAccepting();
+
+ // Connect to the server socket
+ EventBase* eventBase = serverSocket->getEventBase();
+ folly::SocketAddress serverAddress;
+ serverSocket->getAddress(&serverAddress);
+ AsyncSocket::UniquePtr socket(new AsyncSocket(eventBase, serverAddress));
+
+ // Loop to process all events
+ eventBase->loop();
+
+ // Verify that the server accepted a connection
+ CHECK_EQ(acceptCallback.getEvents()->size(), 3);
+ CHECK_EQ(acceptCallback.getEvents()->at(0).type,
+ TestAcceptCallback::TYPE_START);
+ CHECK_EQ(acceptCallback.getEvents()->at(1).type,
+ TestAcceptCallback::TYPE_ACCEPT);
+ CHECK_EQ(acceptCallback.getEvents()->at(2).type,
+ TestAcceptCallback::TYPE_STOP);
+}
+
+/* Verify that we don't leak sockets if we are destroyed()
+ * and there are still writes pending
+ *
+ * If destroy() only calls close() instead of closeNow(),
+ * it would shutdown(writes) on the socket, but it would
+ * never be close()'d, and the socket would leak
+ */
+TEST(AsyncSocketTest, DestroyCloseTest) {
+ TestServer server;
+
+ // connect()
+ EventBase clientEB;
+ EventBase serverEB;
+ std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&clientEB);
+ ConnCallback ccb;
+ socket->connect(&ccb, server.getAddress(), 30);
+
+ // Accept the connection
+ std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&serverEB);
+ ReadCallback rcb;
+ acceptedSocket->setReadCB(&rcb);
+
+ // Write a large buffer to the socket that is larger than kernel buffer
+ size_t simpleBufLength = 5000000;
+ char* simpleBuf = new char[simpleBufLength];
+ memset(simpleBuf, 'a', simpleBufLength);
+ WriteCallback wcb;
+
+ // Let the reads and writes run to completion
+ int fd = acceptedSocket->getFd();
+
+ acceptedSocket->write(&wcb, simpleBuf, simpleBufLength);
+ socket.reset();
+ acceptedSocket.reset();
+
+ // Test that server socket was closed
+ ssize_t sz = read(fd, simpleBuf, simpleBufLength);
+ CHECK_EQ(sz, -1);
+ CHECK_EQ(errno, 9);
+ delete[] simpleBuf;
+}
+
+/**
+ * Test AsyncServerSocket::useExistingSocket()
+ */
+TEST(AsyncSocketTest, ServerExistingSocket) {
+ EventBase eventBase;
+
+ // Test creating a socket, and letting AsyncServerSocket bind and listen
+ {
+ // Manually create a socket
+ int fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+ ASSERT_GE(fd, 0);
+
+ // Create a server socket
+ AsyncServerSocket::UniquePtr serverSocket(
+ new AsyncServerSocket(&eventBase));
+ serverSocket->useExistingSocket(fd);
+ folly::SocketAddress address;
+ serverSocket->getAddress(&address);
+ address.setPort(0);
+ serverSocket->bind(address);
+ serverSocket->listen(16);
+
+ // Make sure the socket works
+ serverSocketSanityTest(serverSocket.get());
+ }
+
+ // Test creating a socket and binding manually,
+ // then letting AsyncServerSocket listen
+ {
+ // Manually create a socket
+ int fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+ ASSERT_GE(fd, 0);
+ // bind
+ struct sockaddr_in addr;
+ addr.sin_family = AF_INET;
+ addr.sin_port = 0;
+ addr.sin_addr.s_addr = INADDR_ANY;
+ CHECK_EQ(bind(fd, reinterpret_cast<struct sockaddr*>(&addr),
+ sizeof(addr)), 0);
+ // Look up the address that we bound to
+ folly::SocketAddress boundAddress;
+ boundAddress.setFromLocalAddress(fd);
+
+ // Create a server socket
+ AsyncServerSocket::UniquePtr serverSocket(
+ new AsyncServerSocket(&eventBase));
+ serverSocket->useExistingSocket(fd);
+ serverSocket->listen(16);
+
+ // Make sure AsyncServerSocket reports the same address that we bound to
+ folly::SocketAddress serverSocketAddress;
+ serverSocket->getAddress(&serverSocketAddress);
+ CHECK_EQ(boundAddress, serverSocketAddress);
+
+ // Make sure the socket works
+ serverSocketSanityTest(serverSocket.get());
+ }
+
+ // Test creating a socket, binding and listening manually,
+ // then giving it to AsyncServerSocket
+ {
+ // Manually create a socket
+ int fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+ ASSERT_GE(fd, 0);
+ // bind
+ struct sockaddr_in addr;
+ addr.sin_family = AF_INET;
+ addr.sin_port = 0;
+ addr.sin_addr.s_addr = INADDR_ANY;
+ CHECK_EQ(bind(fd, reinterpret_cast<struct sockaddr*>(&addr),
+ sizeof(addr)), 0);
+ // Look up the address that we bound to
+ folly::SocketAddress boundAddress;
+ boundAddress.setFromLocalAddress(fd);
+ // listen
+ CHECK_EQ(listen(fd, 16), 0);
+
+ // Create a server socket
+ AsyncServerSocket::UniquePtr serverSocket(
+ new AsyncServerSocket(&eventBase));
+ serverSocket->useExistingSocket(fd);
+
+ // Make sure AsyncServerSocket reports the same address that we bound to
+ folly::SocketAddress serverSocketAddress;
+ serverSocket->getAddress(&serverSocketAddress);
+ CHECK_EQ(boundAddress, serverSocketAddress);
+
+ // Make sure the socket works
+ serverSocketSanityTest(serverSocket.get());
+ }
+}
+
+TEST(AsyncSocketTest, UnixDomainSocketTest) {
+ EventBase eventBase;
+
+ // Create a server socket
+ std::shared_ptr<AsyncServerSocket> serverSocket(
+ AsyncServerSocket::newSocket(&eventBase));
+ string path(1, 0);
+ path.append("/anonymous");
+ folly::SocketAddress serverAddress;
+ serverAddress.setFromPath(path);
+ serverSocket->bind(serverAddress);
+ serverSocket->listen(16);
+
+ // Add a callback to accept one connection then stop the loop
+ TestAcceptCallback acceptCallback;
+ acceptCallback.setConnectionAcceptedFn(
+ [&](int fd, const folly::SocketAddress& addr) {
+ serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
+ });
+ acceptCallback.setAcceptErrorFn([&](const std::exception& ex) {
+ serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
+ });
+ serverSocket->addAcceptCallback(&acceptCallback, nullptr);
+ serverSocket->startAccepting();
+
+ // Connect to the server socket
+ std::shared_ptr<AsyncSocket> socket(
+ AsyncSocket::newSocket(&eventBase, serverAddress));
+
+ eventBase.loop();
+
+ // Verify that the server accepted a connection
+ CHECK_EQ(acceptCallback.getEvents()->size(), 3);
+ CHECK_EQ(acceptCallback.getEvents()->at(0).type,
+ TestAcceptCallback::TYPE_START);
+ CHECK_EQ(acceptCallback.getEvents()->at(1).type,
+ TestAcceptCallback::TYPE_ACCEPT);
+ CHECK_EQ(acceptCallback.getEvents()->at(2).type,
+ TestAcceptCallback::TYPE_STOP);
+ int fd = acceptCallback.getEvents()->at(1).fd;
+
+ // The accepted connection should already be in non-blocking mode
+ int flags = fcntl(fd, F_GETFL, 0);
+ CHECK_EQ(flags & O_NONBLOCK, O_NONBLOCK);
+}
--- /dev/null
+/*
+ * Copyright 2015 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <folly/Optional.h>
+#include <folly/io/async/SSLContext.h>
+#include <folly/io/async/AsyncSocket.h>
+#include <folly/io/async/AsyncSSLSocket.h>
+
+class BlockingSocket : public folly::AsyncSocket::ConnectCallback,
+ public folly::AsyncTransportWrapper::ReadCallback,
+ public folly::AsyncTransportWrapper::WriteCallback
+{
+ public:
+ explicit BlockingSocket(int fd)
+ : sock_(new folly::AsyncSocket(&eventBase_, fd)) {
+ }
+
+ BlockingSocket(folly::SocketAddress address,
+ std::shared_ptr<folly::SSLContext> sslContext)
+ : sock_(sslContext ? new folly::AsyncSSLSocket(sslContext, &eventBase_) :
+ new folly::AsyncSocket(&eventBase_)),
+ address_(address) {}
+
+ void open() {
+ sock_->connect(this, address_);
+ eventBase_.loop();
+ if (err_.hasValue()) {
+ throw err_.value();
+ }
+ }
+ void close() {
+ sock_->close();
+ }
+
+ int32_t write(uint8_t const* buf, size_t len) {
+ sock_->write(this, buf, len);
+ eventBase_.loop();
+ if (err_.hasValue()) {
+ throw err_.value();
+ }
+ return len;
+ }
+
+ void flush() {}
+
+ int32_t readAll(uint8_t *buf, size_t len) {
+ return readHelper(buf, len, true);
+ }
+
+ int32_t read(uint8_t *buf, size_t len) {
+ return readHelper(buf, len, false);
+ }
+
+ int getSocketFD() const {
+ return sock_->getFd();
+ }
+
+ private:
+ folly::EventBase eventBase_;
+ folly::AsyncSocket::UniquePtr sock_;
+ folly::Optional<folly::AsyncSocketException> err_;
+ uint8_t *readBuf_{nullptr};
+ size_t readLen_{0};
+ folly::SocketAddress address_;
+
+ void connectSuccess() noexcept override {}
+ void connectErr(const folly::AsyncSocketException& ex) noexcept override {
+ err_ = ex;
+ }
+ void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+ *bufReturn = readBuf_;
+ *lenReturn = readLen_;
+ }
+ void readDataAvailable(size_t len) noexcept override {
+ readBuf_ += len;
+ readLen_ -= len;
+ if (readLen_ == 0) {
+ sock_->setReadCB(nullptr);
+ }
+ }
+ void readEOF() noexcept override {
+ }
+ void readErr(const folly::AsyncSocketException& ex) noexcept override {
+ err_ = ex;
+ }
+ void writeSuccess() noexcept override {}
+ void writeErr(size_t bytesWritten,
+ const folly::AsyncSocketException& ex) noexcept override {
+ err_ = ex;
+ }
+
+ int32_t readHelper(uint8_t *buf, size_t len, bool all) {
+ readBuf_ = buf;
+ readLen_ = len;
+ sock_->setReadCB(this);
+ while (!err_ && sock_->good() && readLen_ > 0) {
+ eventBase_.loop();
+ if (!all) {
+ break;
+ }
+ }
+ sock_->setReadCB(nullptr);
+ if (err_.hasValue()) {
+ throw err_.value();
+ }
+ if (all && readLen_ > 0) {
+ throw folly::AsyncSocketException(folly::AsyncSocketException::UNKNOWN,
+ "eof");
+ }
+ return len - readLen_;
+ }
+};
--- /dev/null
+-----BEGIN CERTIFICATE-----
+MIIDXTCCAkWgAwIBAgIJAKMZICGWUzawMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV
+BAYTAlVTMQ8wDQYDVQQKDAZUaHJpZnQxJTAjBgNVBAMMHFRocmlmdCBDZXJ0aWZp
+Y2F0ZSBBdXRob3JpdHkwHhcNMTQwNTE2MjAyODUyWhcNNDExMDAxMjAyODUyWjBF
+MQswCQYDVQQGEwJVUzEPMA0GA1UECgwGVGhyaWZ0MSUwIwYDVQQDDBxUaHJpZnQg
+Q2VydGlmaWNhdGUgQXV0aG9yaXR5MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB
+CgKCAQEA1Bx2vUvXZ8PrvEBxwdH5qM1F2Xo7UkeC1jzQ+OLUBEcCiEduyStitSvB
+NOAzAGdjt7NmHTP/7OJngp2vzQGjSQzm20XacyTieFUuPBuikUc0Ge3Tf+uQXtiU
+zZPh+xn6arHH+zBWtmUCt3cBrpgRqdnWUsbl8eqo5HsczY781FxQbDoT9VP6A+9R
+KGTsEhxxKbWJ1C7OngwLKc7Zv4DtTC1JFlFyKd8ryDtxP4s/GgsXJkoK0Hkpputr
+cMxMm6OGt77mFvzR2qRY1CpEK/9rjBB6Gqd8GakXsvoOsqL/37k2wVhN/JoS/Pde
+12Mp6TZ2rA8NW8vRujfWU0u55gnQnwIDAQABo1AwTjAdBgNVHQ4EFgQUQ00NGVmY
+NZ6LJg8UQUOVLZX1Gh8wHwYDVR0jBBgwFoAUQ00NGVmYNZ6LJg8UQUOVLZX1Gh8w
+DAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQUFAAOCAQEAdlxt5+z9uXCBr1Wt6r49
+4MmOYw9lOnEOG1JPMRo108TLpmwXEWReCAtjQuR7BitRJW0kJtlO1M6t3qoIh6GA
+sBkgsjQM1xNY3YEpx71MLt1V+JD+2WtSBKMyysj1TiOmIH66kkvXO3ptXzhjhZyX
+G6B+kxLtxrqkn9SJULyN55X8T+dkW28UIBZVLavoREDU+UPrYU9JgZeIVObtGSWi
+DvS4RIJZNjgG3vTrT00rfUGEfTlI54Vbcmv0cYvswP/nMsLtDStCdgI7c/ipyJve
+dfuI4CedjE240AxK5OFxFg/k/IfnB4a5oojbdIR9hKrTU57TPaUVD50Na9WA1aqX
+5Q==
+-----END CERTIFICATE-----
--- /dev/null
+-----BEGIN CERTIFICATE-----
+MIIDKzCCAhOgAwIBAgIBCjANBgkqhkiG9w0BAQUFADBFMQswCQYDVQQGEwJVUzEP
+MA0GA1UECgwGVGhyaWZ0MSUwIwYDVQQDDBxUaHJpZnQgQ2VydGlmaWNhdGUgQXV0
+aG9yaXR5MB4XDTE0MDUxNjIwMjg1MloXDTQxMTAwMTIwMjg1MlowRjELMAkGA1UE
+BhMCVVMxDTALBgNVBAgTBE9oaW8xETAPBgNVBAcTCEhpbGxpYXJkMRUwEwYDVQQD
+EwxBc294IENvbXBhbnkwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCz
+ZGrJ5XQHAuMYHlBgn32OOc9l0n3RjXccio2ceeWctXkSxDP3vFyZ4kBILF1lsY1g
+o8UTjMkqSDYcytCLK0qavrv9BZRLB9FcpqJ9o4V9feaI/HsHa8DYHyEs8qyNTTNG
+YQ3i4j+AA9iDSpezIYy/tyAOAjrSquUW1jI4tzKTBh8hk8MAMvR2/NPHPkrp4gI+
+EMH6u4vWdr4F9bbriLFWoU04T9mWOMk7G+h8BS9sgINg2+v5cWvl3BC4kLk5L1yJ
+FEyuofSSCEEe6dDf7uVh+RPKa4hEkIYo31AEOPFrN56d+pCj/5l67HTWXoQx3rjy
+dNXMvgU75urm6TQe8dB5AgMBAAGjJTAjMCEGA1UdEQQaMBiHBH8AAAGHEAAAAAAA
+AAAAAAAAAAAAAAEwDQYJKoZIhvcNAQEFBQADggEBAD26XYInaEvlWZJYgtl3yQyC
+3NRQc3LG7XxWg4aFdXCxYLPRAL2HLoarKYH8GPFso57t5xnhA8WfP7iJxmgsKdCS
+0pNIicOWsMmXvYLib0j9tMCFR+a8rn3f4n+clwnqas4w/vWBJUoMgyxtkP8NNNZO
+kIl02JKRhuyiFyPLilVp5tu0e+lmyUER+ak53WjLq2yoytYAlHkzkOpc4MZ/TNt5
+UTEtx/WVlZvlrPi3dsi7QikkjQgo1wCnm7owtuAHlPDMAB8wKk4+vvIOjsGM33T/
+8ffq/4X1HeYM0w0fM+SVlX1rwkXA1RW/jn48VWFHpWbE10+m196OdiToGfm2OJI=
+-----END CERTIFICATE-----
--- /dev/null
+-----BEGIN RSA PRIVATE KEY-----
+MIIEpAIBAAKCAQEAs2RqyeV0BwLjGB5QYJ99jjnPZdJ90Y13HIqNnHnlnLV5EsQz
+97xcmeJASCxdZbGNYKPFE4zJKkg2HMrQiytKmr67/QWUSwfRXKaifaOFfX3miPx7
+B2vA2B8hLPKsjU0zRmEN4uI/gAPYg0qXsyGMv7cgDgI60qrlFtYyOLcykwYfIZPD
+ADL0dvzTxz5K6eICPhDB+ruL1na+BfW264ixVqFNOE/ZljjJOxvofAUvbICDYNvr
++XFr5dwQuJC5OS9ciRRMrqH0kghBHunQ3+7lYfkTymuIRJCGKN9QBDjxazeenfqQ
+o/+Zeux01l6EMd648nTVzL4FO+bq5uk0HvHQeQIDAQABAoIBAQCSPcBYindF5/Kd
+jMjVm+9M7I/IYAo1tG9vkvvSngSy9bWXuN7sjF+pCyqAK7qP1mh8acWVJGYx0+BZ
+JHVRnp8Y+3hg0hWL/PmN4EICzjVakjJHZhwddpglF2uCKurD3jV4oFIjrXE6uOfe
+UAbO/wCwoWa+RM8TQkGzljYmyiGufCcXlgEKMNA7TIvbJ9TVx3VTCOQy6EjZ13jd
+M6X7byV/ZOFpZ2H0QV46LvZraw04riXQ/59gVmzizYdI+BwnxxapsCmalTJoV/Y0
+LMI2ylat4PTMVTxPF+ti7Nt+rUkkEx6kuiAgfc+bzE4BSD5X4wy3fdLVLccoxXYw
+4N3fOuQhAoGBAOLrMhiSCrzXGjDWTbPrwzxXDO0qm+wURELi3N5SXIkKUdG2/In6
+wNdpXdvqblOm7SASgPf9KCwUSADrNw6R6nbfrrir5EHg66YydI/OW42QzJKcBUFh
+5Q5na3fvoL/zRhsmh0gEymBg+OIfNel2LY69bl8aAko2y0R1kj7zb8X1AoGBAMph
+9hlnkIBSw60+pKHaOqo2t/kihNNMFyfOgJvh8960eFeMDhMIXgxPUR8yaPX0bBMb
+bCdEJJ2pmq7zUBPvxVJLedwkGMhywElA8yYVh+S6x4Cg+lYo4spIjrHQ/WTvJkHB
+GrDskxdq80lbXjwRd0dPJZkxhKJec1o0n8S03Mn1AoGAGarK5taWGlgmYUHMVj6j
+vc6G6si4DFMaiYpJu2gLiYC+Un9lP2I6r+L+N+LjidjG16rgJazf/2Rn5Jq2hpJg
+uAODKuZekkkTvp/UaXPJDVFEooy9V3DwTNnL4SwcvbmRw35vLOlFzvMJE+K94WN5
+sbyhoGY7vhNGmL7HxREaIoUCgYEAwpteVWFz3yE2ziF9l7FMVh7V23go9zGk1n9I
+xhyJL26khbLEWeLi5L1kiTYlHdUSE3F8F2n8N6s+ddq79t/KA29WV6xSNHW7lvUg
+mk975CMC8hpZfn5ETjVlGXGYJ/Wa+QGiE9z5ODx8gt6cB/DXnLdrtRqbqrJeA7C0
+rScpY/0CgYBCC1QeuAiwWHOqQn3BwsZo9JQBTyT0QvDqLH/F+h9QbXep+4HvyAxG
+nTMNDtGyfyKGDaDUn5hyeU7Oxvzq0K9P+eZD3MjQeaMEg/++GPGUPmDUTqyb2UT8
+5s0NIUobxfKnTD6IpgOIq7ffvVY6cKBMyuLmu/gSvscsbONHjKti3Q==
+-----END RSA PRIVATE KEY-----
--- /dev/null
+/*
+ * Copyright 2015 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <folly/io/ShutdownSocketSet.h>
+
+#include <atomic>
+#include <chrono>
+#include <thread>
+
+#include <netinet/in.h>
+#include <netinet/tcp.h>
+#include <sys/socket.h>
+
+#include <glog/logging.h>
+#include <gtest/gtest.h>
+
+using folly::ShutdownSocketSet;
+
+namespace folly { namespace test {
+
+ShutdownSocketSet shutdownSocketSet;
+
+class Server {
+ public:
+ Server();
+
+ void stop(bool abortive);
+ void join();
+ int port() const { return port_; }
+ int closeClients(bool abortive);
+
+ private:
+ int acceptSocket_;
+ int port_;
+ enum StopMode {
+ NO_STOP,
+ ORDERLY,
+ ABORTIVE
+ };
+ std::atomic<StopMode> stop_;
+ std::thread serverThread_;
+ std::vector<int> fds_;
+};
+
+Server::Server()
+ : acceptSocket_(-1),
+ port_(0),
+ stop_(NO_STOP) {
+ acceptSocket_ = socket(PF_INET, SOCK_STREAM, 0);
+ CHECK_ERR(acceptSocket_);
+ shutdownSocketSet.add(acceptSocket_);
+
+ sockaddr_in addr;
+ addr.sin_family = AF_INET;
+ addr.sin_port = 0;
+ addr.sin_addr.s_addr = INADDR_ANY;
+ CHECK_ERR(bind(acceptSocket_,
+ reinterpret_cast<const sockaddr*>(&addr),
+ sizeof(addr)));
+
+ CHECK_ERR(listen(acceptSocket_, 10));
+
+ socklen_t addrLen = sizeof(addr);
+ CHECK_ERR(getsockname(acceptSocket_,
+ reinterpret_cast<sockaddr*>(&addr),
+ &addrLen));
+
+ port_ = ntohs(addr.sin_port);
+
+ serverThread_ = std::thread([this] {
+ while (stop_ == NO_STOP) {
+ sockaddr_in peer;
+ socklen_t peerLen = sizeof(peer);
+ int fd = accept(acceptSocket_,
+ reinterpret_cast<sockaddr*>(&peer),
+ &peerLen);
+ if (fd == -1) {
+ if (errno == EINTR) {
+ continue;
+ }
+ if (errno == EINVAL || errno == ENOTSOCK) { // socket broken
+ break;
+ }
+ }
+ CHECK_ERR(fd);
+ shutdownSocketSet.add(fd);
+ fds_.push_back(fd);
+ }
+
+ if (stop_ != NO_STOP) {
+ closeClients(stop_ == ABORTIVE);
+ }
+
+ shutdownSocketSet.close(acceptSocket_);
+ acceptSocket_ = -1;
+ port_ = 0;
+ });
+}
+
+int Server::closeClients(bool abortive) {
+ for (int fd : fds_) {
+ if (abortive) {
+ struct linger l = {1, 0};
+ CHECK_ERR(setsockopt(fd, SOL_SOCKET, SO_LINGER, &l, sizeof(l)));
+ }
+ shutdownSocketSet.close(fd);
+ }
+ int n = fds_.size();
+ fds_.clear();
+ return n;
+}
+
+void Server::stop(bool abortive) {
+ stop_ = abortive ? ABORTIVE : ORDERLY;
+ shutdown(acceptSocket_, SHUT_RDWR);
+}
+
+void Server::join() {
+ serverThread_.join();
+}
+
+int createConnectedSocket(int port) {
+ int sock = socket(PF_INET, SOCK_STREAM, 0);
+ CHECK_ERR(sock);
+ sockaddr_in addr;
+ addr.sin_family = AF_INET;
+ addr.sin_port = htons(port);
+ addr.sin_addr.s_addr = htonl((127 << 24) | 1); // XXX
+ CHECK_ERR(connect(sock,
+ reinterpret_cast<const sockaddr*>(&addr),
+ sizeof(addr)));
+ return sock;
+}
+
+void runCloseTest(bool abortive) {
+ Server server;
+
+ int sock = createConnectedSocket(server.port());
+
+ std::thread stopper([&server, abortive] {
+ std::this_thread::sleep_for(std::chrono::milliseconds(200));
+ server.stop(abortive);
+ server.join();
+ });
+
+ char c;
+ int r = read(sock, &c, 1);
+ if (abortive) {
+ int e = errno;
+ EXPECT_EQ(-1, r);
+ EXPECT_EQ(ECONNRESET, e);
+ } else {
+ EXPECT_EQ(0, r);
+ }
+
+ close(sock);
+
+ stopper.join();
+
+ EXPECT_EQ(0, server.closeClients(false)); // closed by server when it exited
+}
+
+TEST(ShutdownSocketSetTest, OrderlyClose) {
+ runCloseTest(false);
+}
+
+TEST(ShutdownSocketSetTest, AbortiveClose) {
+ runCloseTest(true);
+}
+
+void runKillTest(bool abortive) {
+ Server server;
+
+ int sock = createConnectedSocket(server.port());
+
+ std::thread killer([&server, abortive] {
+ std::this_thread::sleep_for(std::chrono::milliseconds(200));
+ shutdownSocketSet.shutdownAll(abortive);
+ server.join();
+ });
+
+ char c;
+ int r = read(sock, &c, 1);
+
+ // "abortive" is just a hint for ShutdownSocketSet, so accept both
+ // behaviors
+ if (abortive) {
+ if (r == -1) {
+ EXPECT_EQ(ECONNRESET, errno);
+ } else {
+ EXPECT_EQ(r, 0);
+ }
+ } else {
+ EXPECT_EQ(0, r);
+ }
+
+ close(sock);
+
+ killer.join();
+
+ // NOT closed by server when it exited
+ EXPECT_EQ(1, server.closeClients(false));
+}
+
+TEST(ShutdownSocketSetTest, OrderlyKill) {
+ runKillTest(false);
+}
+
+TEST(ShutdownSocketSetTest, AbortiveKill) {
+ runKillTest(true);
+}
+
+}} // namespaces
+
+int main(int argc, char *argv[]) {
+ testing::InitGoogleTest(&argc, argv);
+ google::ParseCommandLineFlags(&argc, &argv, true);
+ return RUN_ALL_TESTS();
+}