From: Alan Frindell Date: Thu, 2 Apr 2015 17:20:49 +0000 (-0700) Subject: Move AsyncSocket tests from thrift to folly X-Git-Tag: v0.33.0~4 X-Git-Url: http://demsky.eecs.uci.edu/git/?a=commitdiff_plain;h=fe6985dd5fcf628083209274f0af8570556e96da;p=folly.git Move AsyncSocket tests from thrift to folly Summary: These tests belong with the code that they test. The old tests had a couple dependencies on TSocket/TSSLSocket, so I wrote a BlockingSocket wrapper for AsyncSocket/AsyncSSLSocket Test Plan: Ran the tests Reviewed By: alandau@fb.com Subscribers: doug, net-systems@, alandau, bmatheny, mshneer, folly-diffs@, yfeldblum, chalfant FB internal diff: D1959955 Signature: t1:1959955:1427917833:73d334846cf248f8bb215f3eb5b596df7f7cee4f --- diff --git a/folly/Makefile.am b/folly/Makefile.am index a5aec762..2dcf6771 100644 --- a/folly/Makefile.am +++ b/folly/Makefile.am @@ -162,6 +162,8 @@ nobase_follyinclude_HEADERS = \ io/async/Request.h \ io/async/SSLContext.h \ io/async/TimeoutManager.h \ + io/async/test/AsyncSSLSocketTest.h \ + io/async/test/BlockingSocket.h \ io/async/test/TimeUtil.h \ io/async/test/UndelayedDestruction.h \ io/async/test/Util.h \ diff --git a/folly/io/async/test/AsyncSSLSocketTest.cpp b/folly/io/async/test/AsyncSSLSocketTest.cpp new file mode 100644 index 00000000..bc899791 --- /dev/null +++ b/folly/io/async/test/AsyncSSLSocketTest.cpp @@ -0,0 +1,1221 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include +#include + +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using std::string; +using std::vector; +using std::min; +using std::cerr; +using std::endl; +using std::list; + +namespace folly { +uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0; +uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0; +uint32_t TestSSLAsyncCacheServer::lookupDelay_ = 0; + +const char* testCert = "folly/io/async/test/certs/tests-cert.pem"; +const char* testKey = "folly/io/async/test/certs/tests-key.pem"; +const char* testCA = "folly/io/async/test/certs/ca-cert.pem"; + +TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase *acb) : +ctx_(new folly::SSLContext), + acb_(acb), + socket_(new folly::AsyncServerSocket(&evb_)) { + // Set up the SSL context + ctx_->loadCertificate(testCert); + ctx_->loadPrivateKey(testKey); + ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); + + acb_->ctx_ = ctx_; + acb_->base_ = &evb_; + + //set up the listening socket + socket_->bind(0); + socket_->getAddress(&address_); + socket_->listen(100); + socket_->addAcceptCallback(acb_, &evb_); + socket_->startAccepting(); + + int ret = pthread_create(&thread_, nullptr, Main, this); + assert(ret == 0); + + std::cerr << "Accepting connections on " << address_ << std::endl; +} + +void getfds(int fds[2]) { + if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) { + FAIL() << "failed to create socketpair: " << strerror(errno); + } + for (int idx = 0; idx < 2; ++idx) { + int flags = fcntl(fds[idx], F_GETFL, 0); + if (flags == -1) { + FAIL() << "failed to get flags for socket " << idx << ": " + << strerror(errno); + } + if (fcntl(fds[idx], F_SETFL, flags | O_NONBLOCK) != 0) { + FAIL() << "failed to put socket " << idx << " in non-blocking mode: " + << strerror(errno); + } + } +} + +void getctx( + std::shared_ptr clientCtx, + std::shared_ptr serverCtx) { + clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); + + serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); + serverCtx->loadCertificate( + testCert); + serverCtx->loadPrivateKey( + testKey); +} + +void sslsocketpair( + EventBase* eventBase, + AsyncSSLSocket::UniquePtr* clientSock, + AsyncSSLSocket::UniquePtr* serverSock) { + auto clientCtx = std::make_shared(); + auto serverCtx = std::make_shared(); + int fds[2]; + getfds(fds); + getctx(clientCtx, serverCtx); + clientSock->reset(new AsyncSSLSocket( + clientCtx, eventBase, fds[0], false)); + serverSock->reset(new AsyncSSLSocket( + serverCtx, eventBase, fds[1], true)); + + // (*clientSock)->setSendTimeout(100); + // (*serverSock)->setSendTimeout(100); +} + + +/** + * Test connecting to, writing to, reading from, and closing the + * connection to the SSL server. + */ +TEST(AsyncSSLSocketTest, ConnectWriteReadClose) { + // Start listening on a local port + WriteCallbackBase writeCallback; + ReadCallback readCallback(&writeCallback); + HandshakeCallback handshakeCallback(&readCallback); + SSLServerAcceptCallback acceptCallback(&handshakeCallback); + TestSSLServer server(&acceptCallback); + + // Set up SSL context. + std::shared_ptr sslContext(new SSLContext()); + sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); + //sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem"); + //sslContext->authenticate(true, false); + + // connect + auto socket = std::make_shared(server.getAddress(), + sslContext); + socket->open(); + + // write() + uint8_t buf[128]; + memset(buf, 'a', sizeof(buf)); + socket->write(buf, sizeof(buf)); + + // read() + uint8_t readbuf[128]; + uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf)); + EXPECT_EQ(bytesRead, 128); + EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0); + + // close() + socket->close(); + + cerr << "ConnectWriteReadClose test completed" << endl; +} + +/** + * Negative test for handshakeError(). + */ +TEST(AsyncSSLSocketTest, HandshakeError) { + // Start listening on a local port + WriteCallbackBase writeCallback; + ReadCallback readCallback(&writeCallback); + HandshakeCallback handshakeCallback(&readCallback); + HandshakeErrorCallback acceptCallback(&handshakeCallback); + TestSSLServer server(&acceptCallback); + + // Set up SSL context. + std::shared_ptr sslContext(new SSLContext()); + sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); + + // connect + auto socket = std::make_shared(server.getAddress(), + sslContext); + // read() + bool ex = false; + try { + socket->open(); + + uint8_t readbuf[128]; + uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf)); + } catch (AsyncSocketException &e) { + ex = true; + } + EXPECT_TRUE(ex); + + // close() + socket->close(); + cerr << "HandshakeError test completed" << endl; +} + +/** + * Negative test for readError(). + */ +TEST(AsyncSSLSocketTest, ReadError) { + // Start listening on a local port + WriteCallbackBase writeCallback; + ReadErrorCallback readCallback(&writeCallback); + HandshakeCallback handshakeCallback(&readCallback); + SSLServerAcceptCallback acceptCallback(&handshakeCallback); + TestSSLServer server(&acceptCallback); + + // Set up SSL context. + std::shared_ptr sslContext(new SSLContext()); + sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); + + // connect + auto socket = std::make_shared(server.getAddress(), + sslContext); + socket->open(); + + // write something to trigger ssl handshake + uint8_t buf[128]; + memset(buf, 'a', sizeof(buf)); + socket->write(buf, sizeof(buf)); + + socket->close(); + cerr << "ReadError test completed" << endl; +} + +/** + * Negative test for writeError(). + */ +TEST(AsyncSSLSocketTest, WriteError) { + // Start listening on a local port + WriteCallbackBase writeCallback; + WriteErrorCallback readCallback(&writeCallback); + HandshakeCallback handshakeCallback(&readCallback); + SSLServerAcceptCallback acceptCallback(&handshakeCallback); + TestSSLServer server(&acceptCallback); + + // Set up SSL context. + std::shared_ptr sslContext(new SSLContext()); + sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); + + // connect + auto socket = std::make_shared(server.getAddress(), + sslContext); + socket->open(); + + // write something to trigger ssl handshake + uint8_t buf[128]; + memset(buf, 'a', sizeof(buf)); + socket->write(buf, sizeof(buf)); + + socket->close(); + cerr << "WriteError test completed" << endl; +} + +/** + * Test a socket with TCP_NODELAY unset. + */ +TEST(AsyncSSLSocketTest, SocketWithDelay) { + // Start listening on a local port + WriteCallbackBase writeCallback; + ReadCallback readCallback(&writeCallback); + HandshakeCallback handshakeCallback(&readCallback); + SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback); + TestSSLServer server(&acceptCallback); + + // Set up SSL context. + std::shared_ptr sslContext(new SSLContext()); + sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); + + // connect + auto socket = std::make_shared(server.getAddress(), + sslContext); + socket->open(); + + // write() + uint8_t buf[128]; + memset(buf, 'a', sizeof(buf)); + socket->write(buf, sizeof(buf)); + + // read() + uint8_t readbuf[128]; + uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf)); + EXPECT_EQ(bytesRead, 128); + EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0); + + // close() + socket->close(); + + cerr << "SocketWithDelay test completed" << endl; +} + +TEST(AsyncSSLSocketTest, NpnTestOverlap) { + EventBase eventBase; + std::shared_ptr clientCtx(new SSLContext); + std::shared_ptr serverCtx(new SSLContext);; + int fds[2]; + getfds(fds); + getctx(clientCtx, serverCtx); + + clientCtx->setAdvertisedNextProtocols({"blub","baz"}); + serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"}); + + AsyncSSLSocket::UniquePtr clientSock( + new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); + AsyncSSLSocket::UniquePtr serverSock( + new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true)); + NpnClient client(std::move(clientSock)); + NpnServer server(std::move(serverSock)); + + eventBase.loop(); + + EXPECT_TRUE(client.nextProtoLength != 0); + EXPECT_EQ(client.nextProtoLength, server.nextProtoLength); + EXPECT_EQ(memcmp(client.nextProto, server.nextProto, + server.nextProtoLength), 0); + string selected((const char*)client.nextProto, client.nextProtoLength); + EXPECT_EQ(selected.compare("baz"), 0); +} + +TEST(AsyncSSLSocketTest, NpnTestUnset) { + // Identical to above test, except that we want unset NPN before + // looping. + EventBase eventBase; + std::shared_ptr clientCtx(new SSLContext); + std::shared_ptr serverCtx(new SSLContext);; + int fds[2]; + getfds(fds); + getctx(clientCtx, serverCtx); + + clientCtx->setAdvertisedNextProtocols({"blub","baz"}); + serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"}); + + AsyncSSLSocket::UniquePtr clientSock( + new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); + AsyncSSLSocket::UniquePtr serverSock( + new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true)); + + // unsetting NPN for any of [client, server] is enought to make NPN not + // work + clientCtx->unsetNextProtocols(); + + NpnClient client(std::move(clientSock)); + NpnServer server(std::move(serverSock)); + + eventBase.loop(); + + EXPECT_TRUE(client.nextProtoLength == 0); + EXPECT_TRUE(server.nextProtoLength == 0); + EXPECT_TRUE(client.nextProto == nullptr); + EXPECT_TRUE(server.nextProto == nullptr); +} + +TEST(AsyncSSLSocketTest, NpnTestNoOverlap) { + EventBase eventBase; + std::shared_ptr clientCtx(new SSLContext); + std::shared_ptr serverCtx(new SSLContext);; + int fds[2]; + getfds(fds); + getctx(clientCtx, serverCtx); + + clientCtx->setAdvertisedNextProtocols({"blub"}); + serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"}); + + AsyncSSLSocket::UniquePtr clientSock( + new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); + AsyncSSLSocket::UniquePtr serverSock( + new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true)); + NpnClient client(std::move(clientSock)); + NpnServer server(std::move(serverSock)); + + eventBase.loop(); + + EXPECT_TRUE(client.nextProtoLength != 0); + EXPECT_EQ(client.nextProtoLength, server.nextProtoLength); + EXPECT_EQ(memcmp(client.nextProto, server.nextProto, + server.nextProtoLength), 0); + string selected((const char*)client.nextProto, client.nextProtoLength); + EXPECT_EQ(selected.compare("blub"), 0); +} + +TEST(AsyncSSLSocketTest, RandomizedNpnTest) { + // Probability that this test will fail is 2^-64, which could be considered + // as negligible. + const int kTries = 64; + + std::set selectedProtocols; + for (int i = 0; i < kTries; ++i) { + EventBase eventBase; + std::shared_ptr clientCtx = std::make_shared(); + std::shared_ptr serverCtx = std::make_shared(); + int fds[2]; + getfds(fds); + getctx(clientCtx, serverCtx); + + clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"}); + serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}}, + {1, {"bar"}}}); + + + AsyncSSLSocket::UniquePtr clientSock( + new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); + AsyncSSLSocket::UniquePtr serverSock( + new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true)); + NpnClient client(std::move(clientSock)); + NpnServer server(std::move(serverSock)); + + eventBase.loop(); + + EXPECT_TRUE(client.nextProtoLength != 0); + EXPECT_EQ(client.nextProtoLength, server.nextProtoLength); + EXPECT_EQ(memcmp(client.nextProto, server.nextProto, + server.nextProtoLength), 0); + string selected((const char*)client.nextProto, client.nextProtoLength); + selectedProtocols.insert(selected); + } + EXPECT_EQ(selectedProtocols.size(), 2); +} + + +#ifndef OPENSSL_NO_TLSEXT +/** + * 1. Client sends TLSEXT_HOSTNAME in client hello. + * 2. Server found a match SSL_CTX and use this SSL_CTX to + * continue the SSL handshake. + * 3. Server sends back TLSEXT_HOSTNAME in server hello. + */ +TEST(AsyncSSLSocketTest, SNITestMatch) { + EventBase eventBase; + std::shared_ptr clientCtx(new SSLContext); + std::shared_ptr dfServerCtx(new SSLContext); + // Use the same SSLContext to continue the handshake after + // tlsext_hostname match. + std::shared_ptr hskServerCtx(dfServerCtx); + const std::string serverName("xyz.newdev.facebook.com"); + int fds[2]; + getfds(fds); + getctx(clientCtx, dfServerCtx); + + AsyncSSLSocket::UniquePtr clientSock( + new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName)); + AsyncSSLSocket::UniquePtr serverSock( + new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true)); + SNIClient client(std::move(clientSock)); + SNIServer server(std::move(serverSock), + dfServerCtx, + hskServerCtx, + serverName); + + eventBase.loop(); + + EXPECT_TRUE(client.serverNameMatch); + EXPECT_TRUE(server.serverNameMatch); +} + +/** + * 1. Client sends TLSEXT_HOSTNAME in client hello. + * 2. Server cannot find a matching SSL_CTX and continue to use + * the current SSL_CTX to do the handshake. + * 3. Server does not send back TLSEXT_HOSTNAME in server hello. + */ +TEST(AsyncSSLSocketTest, SNITestNotMatch) { + EventBase eventBase; + std::shared_ptr clientCtx(new SSLContext); + std::shared_ptr dfServerCtx(new SSLContext); + // Use the same SSLContext to continue the handshake after + // tlsext_hostname match. + std::shared_ptr hskServerCtx(dfServerCtx); + const std::string clientRequestingServerName("foo.com"); + const std::string serverExpectedServerName("xyz.newdev.facebook.com"); + + int fds[2]; + getfds(fds); + getctx(clientCtx, dfServerCtx); + + AsyncSSLSocket::UniquePtr clientSock( + new AsyncSSLSocket(clientCtx, + &eventBase, + fds[0], + clientRequestingServerName)); + AsyncSSLSocket::UniquePtr serverSock( + new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true)); + SNIClient client(std::move(clientSock)); + SNIServer server(std::move(serverSock), + dfServerCtx, + hskServerCtx, + serverExpectedServerName); + + eventBase.loop(); + + EXPECT_TRUE(!client.serverNameMatch); + EXPECT_TRUE(!server.serverNameMatch); +} + +/** + * 1. Client does not send TLSEXT_HOSTNAME in client hello. + * 2. Server does not send back TLSEXT_HOSTNAME in server hello. + */ +TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) { + EventBase eventBase; + std::shared_ptr clientCtx(new SSLContext); + std::shared_ptr dfServerCtx(new SSLContext); + // Use the same SSLContext to continue the handshake after + // tlsext_hostname match. + std::shared_ptr hskServerCtx(dfServerCtx); + const std::string serverExpectedServerName("xyz.newdev.facebook.com"); + + int fds[2]; + getfds(fds); + getctx(clientCtx, dfServerCtx); + + AsyncSSLSocket::UniquePtr clientSock( + new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); + AsyncSSLSocket::UniquePtr serverSock( + new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true)); + SNIClient client(std::move(clientSock)); + SNIServer server(std::move(serverSock), + dfServerCtx, + hskServerCtx, + serverExpectedServerName); + + eventBase.loop(); + + EXPECT_TRUE(!client.serverNameMatch); + EXPECT_TRUE(!server.serverNameMatch); +} + +#endif +/** + * Test SSL client socket + */ +TEST(AsyncSSLSocketTest, SSLClientTest) { + // Start listening on a local port + WriteCallbackBase writeCallback; + ReadCallback readCallback(&writeCallback); + HandshakeCallback handshakeCallback(&readCallback); + SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback); + TestSSLServer server(&acceptCallback); + + // Set up SSL client + EventBase eventBase; + std::shared_ptr client(new SSLClient(&eventBase, server.getAddress(), + 1)); + + client->connect(); + EventBaseAborter eba(&eventBase, 3000); + eventBase.loop(); + + EXPECT_EQ(client->getMiss(), 1); + EXPECT_EQ(client->getHit(), 0); + + cerr << "SSLClientTest test completed" << endl; +} + + +/** + * Test SSL client socket session re-use + */ +TEST(AsyncSSLSocketTest, SSLClientTestReuse) { + // Start listening on a local port + WriteCallbackBase writeCallback; + ReadCallback readCallback(&writeCallback); + HandshakeCallback handshakeCallback(&readCallback); + SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback); + TestSSLServer server(&acceptCallback); + + // Set up SSL client + EventBase eventBase; + std::shared_ptr client(new SSLClient(&eventBase, server.getAddress(), + 10)); + + client->connect(); + EventBaseAborter eba(&eventBase, 3000); + eventBase.loop(); + + EXPECT_EQ(client->getMiss(), 1); + EXPECT_EQ(client->getHit(), 9); + + cerr << "SSLClientTestReuse test completed" << endl; +} + +/** + * Test SSL client socket timeout + */ +TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) { + // Start listening on a local port + EmptyReadCallback readCallback; + HandshakeCallback handshakeCallback(&readCallback, + HandshakeCallback::EXPECT_ERROR); + HandshakeTimeoutCallback acceptCallback(&handshakeCallback); + TestSSLServer server(&acceptCallback); + + // Set up SSL client + EventBase eventBase; + std::shared_ptr client(new SSLClient(&eventBase, server.getAddress(), + 1, 10)); + client->connect(true /* write before connect completes */); + EventBaseAborter eba(&eventBase, 3000); + eventBase.loop(); + + usleep(100000); + // This is checking that the connectError callback precedes any queued + // writeError callbacks. This matches AsyncSocket's behavior + EXPECT_EQ(client->getWriteAfterConnectErrors(), 1); + EXPECT_EQ(client->getErrors(), 1); + EXPECT_EQ(client->getMiss(), 0); + EXPECT_EQ(client->getHit(), 0); + + cerr << "SSLClientTimeoutTest test completed" << endl; +} + + +/** + * Test SSL server async cache + */ +TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTest) { + // Start listening on a local port + WriteCallbackBase writeCallback; + ReadCallback readCallback(&writeCallback); + HandshakeCallback handshakeCallback(&readCallback); + SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback); + TestSSLAsyncCacheServer server(&acceptCallback); + + // Set up SSL client + EventBase eventBase; + std::shared_ptr client(new SSLClient(&eventBase, server.getAddress(), + 10, 500)); + + client->connect(); + EventBaseAborter eba(&eventBase, 3000); + eventBase.loop(); + + EXPECT_EQ(server.getAsyncCallbacks(), 18); + EXPECT_EQ(server.getAsyncLookups(), 9); + EXPECT_EQ(client->getMiss(), 10); + EXPECT_EQ(client->getHit(), 0); + + cerr << "SSLServerAsyncCacheTest test completed" << endl; +} + + +/** + * Test SSL server accept timeout with cache path + */ +TEST(AsyncSSLSocketTest, SSLServerTimeoutTest) { + // Start listening on a local port + WriteCallbackBase writeCallback; + ReadCallback readCallback(&writeCallback); + EmptyReadCallback clientReadCallback; + HandshakeCallback handshakeCallback(&readCallback); + SSLServerAcceptCallback acceptCallback(&handshakeCallback, 50); + TestSSLAsyncCacheServer server(&acceptCallback); + + // Set up SSL client + EventBase eventBase; + // only do a TCP connect + std::shared_ptr sock = AsyncSocket::newSocket(&eventBase); + sock->connect(nullptr, server.getAddress()); + clientReadCallback.tcpSocket_ = sock; + sock->setReadCB(&clientReadCallback); + + EventBaseAborter eba(&eventBase, 3000); + eventBase.loop(); + + EXPECT_EQ(readCallback.state, STATE_WAITING); + + cerr << "SSLServerTimeoutTest test completed" << endl; +} + +/** + * Test SSL server accept timeout with cache path + */ +TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTimeoutTest) { + // Start listening on a local port + WriteCallbackBase writeCallback; + ReadCallback readCallback(&writeCallback); + HandshakeCallback handshakeCallback(&readCallback); + SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback, 50); + TestSSLAsyncCacheServer server(&acceptCallback); + + // Set up SSL client + EventBase eventBase; + std::shared_ptr client(new SSLClient(&eventBase, server.getAddress(), + 2)); + + client->connect(); + EventBaseAborter eba(&eventBase, 3000); + eventBase.loop(); + + EXPECT_EQ(server.getAsyncCallbacks(), 1); + EXPECT_EQ(server.getAsyncLookups(), 1); + EXPECT_EQ(client->getErrors(), 1); + EXPECT_EQ(client->getMiss(), 1); + EXPECT_EQ(client->getHit(), 0); + + cerr << "SSLServerAsyncCacheTimeoutTest test completed" << endl; +} + +/** + * Test SSL server accept timeout with cache path + */ +TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) { + // Start listening on a local port + WriteCallbackBase writeCallback; + ReadCallback readCallback(&writeCallback); + HandshakeCallback handshakeCallback(&readCallback, + HandshakeCallback::EXPECT_ERROR); + SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback); + TestSSLAsyncCacheServer server(&acceptCallback, 500); + + // Set up SSL client + EventBase eventBase; + std::shared_ptr client(new SSLClient(&eventBase, server.getAddress(), + 2, 100)); + + client->connect(); + EventBaseAborter eba(&eventBase, 3000); + eventBase.loop(); + + server.getEventBase().runInEventBaseThread([&handshakeCallback]{ + handshakeCallback.closeSocket();}); + // give time for the cache lookup to come back and find it closed + usleep(500000); + + EXPECT_EQ(server.getAsyncCallbacks(), 1); + EXPECT_EQ(server.getAsyncLookups(), 1); + EXPECT_EQ(client->getErrors(), 1); + EXPECT_EQ(client->getMiss(), 1); + EXPECT_EQ(client->getHit(), 0); + + cerr << "SSLServerCacheCloseTest test completed" << endl; +} + +/** + * Verify Client Ciphers obtained using SSL MSG Callback. + */ +TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) { + EventBase eventBase; + auto clientCtx = std::make_shared(); + auto serverCtx = std::make_shared(); + serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY); + serverCtx->ciphers("RSA:!SHA:!NULL:!SHA256@STRENGTH"); + serverCtx->loadPrivateKey(testKey); + serverCtx->loadCertificate(testCert); + serverCtx->loadTrustedCertificates(testCA); + serverCtx->loadClientCAList(testCA); + + clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY); + clientCtx->ciphers("RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5"); + clientCtx->loadPrivateKey(testKey); + clientCtx->loadCertificate(testCert); + clientCtx->loadTrustedCertificates(testCA); + + int fds[2]; + getfds(fds); + + AsyncSSLSocket::UniquePtr clientSock( + new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); + AsyncSSLSocket::UniquePtr serverSock( + new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true)); + + SSLHandshakeClient client(std::move(clientSock), true, true); + SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true); + + eventBase.loop(); + + EXPECT_EQ(server.clientCiphers_, + "RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5:00ff"); + EXPECT_TRUE(client.handshakeVerify_); + EXPECT_TRUE(client.handshakeSuccess_); + EXPECT_TRUE(!client.handshakeError_); + EXPECT_TRUE(server.handshakeVerify_); + EXPECT_TRUE(server.handshakeSuccess_); + EXPECT_TRUE(!server.handshakeError_); +} + +TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) { + EventBase eventBase; + auto ctx = std::make_shared(); + + int fds[2]; + getfds(fds); + + int bufLen = 42; + uint8_t majorVersion = 18; + uint8_t minorVersion = 25; + + // Create callback buf + auto buf = IOBuf::create(bufLen); + buf->append(bufLen); + folly::io::RWPrivateCursor cursor(buf.get()); + cursor.write(SSL3_MT_CLIENT_HELLO); + cursor.write(0); + cursor.write(38); + cursor.write(majorVersion); + cursor.write(minorVersion); + cursor.skip(32); + cursor.write(0); + + SSL* ssl = ctx->createSSL(); + AsyncSSLSocket::UniquePtr sock( + new AsyncSSLSocket(ctx, &eventBase, fds[0], true)); + sock->enableClientHelloParsing(); + + // Test client hello parsing in one packet + AsyncSSLSocket::clientHelloParsingCallback( + 0, 0, SSL3_RT_HANDSHAKE, buf->data(), buf->length(), ssl, sock.get()); + buf.reset(); + + auto parsedClientHello = sock->getClientHelloInfo(); + EXPECT_TRUE(parsedClientHello != nullptr); + EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion); + EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion); +} + +TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) { + EventBase eventBase; + auto ctx = std::make_shared(); + + int fds[2]; + getfds(fds); + + int bufLen = 42; + uint8_t majorVersion = 18; + uint8_t minorVersion = 25; + + // Create callback buf + auto buf = IOBuf::create(bufLen); + buf->append(bufLen); + folly::io::RWPrivateCursor cursor(buf.get()); + cursor.write(SSL3_MT_CLIENT_HELLO); + cursor.write(0); + cursor.write(38); + cursor.write(majorVersion); + cursor.write(minorVersion); + cursor.skip(32); + cursor.write(0); + + SSL* ssl = ctx->createSSL(); + AsyncSSLSocket::UniquePtr sock( + new AsyncSSLSocket(ctx, &eventBase, fds[0], true)); + sock->enableClientHelloParsing(); + + // Test parsing with two packets with first packet size < 3 + auto bufCopy = folly::IOBuf::copyBuffer(buf->data(), 2); + AsyncSSLSocket::clientHelloParsingCallback( + 0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(), + ssl, sock.get()); + bufCopy.reset(); + bufCopy = folly::IOBuf::copyBuffer(buf->data() + 2, buf->length() - 2); + AsyncSSLSocket::clientHelloParsingCallback( + 0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(), + ssl, sock.get()); + bufCopy.reset(); + + auto parsedClientHello = sock->getClientHelloInfo(); + EXPECT_TRUE(parsedClientHello != nullptr); + EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion); + EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion); +} + +TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) { + EventBase eventBase; + auto ctx = std::make_shared(); + + int fds[2]; + getfds(fds); + + int bufLen = 42; + uint8_t majorVersion = 18; + uint8_t minorVersion = 25; + + // Create callback buf + auto buf = IOBuf::create(bufLen); + buf->append(bufLen); + folly::io::RWPrivateCursor cursor(buf.get()); + cursor.write(SSL3_MT_CLIENT_HELLO); + cursor.write(0); + cursor.write(38); + cursor.write(majorVersion); + cursor.write(minorVersion); + cursor.skip(32); + cursor.write(0); + + SSL* ssl = ctx->createSSL(); + AsyncSSLSocket::UniquePtr sock( + new AsyncSSLSocket(ctx, &eventBase, fds[0], true)); + sock->enableClientHelloParsing(); + + // Test parsing with multiple small packets + for (uint64_t i = 0; i < buf->length(); i += 3) { + auto bufCopy = folly::IOBuf::copyBuffer( + buf->data() + i, std::min((uint64_t)3, buf->length() - i)); + AsyncSSLSocket::clientHelloParsingCallback( + 0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(), + ssl, sock.get()); + bufCopy.reset(); + } + + auto parsedClientHello = sock->getClientHelloInfo(); + EXPECT_TRUE(parsedClientHello != nullptr); + EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion); + EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion); +} + +/** + * Verify sucessful behavior of SSL certificate validation. + */ +TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) { + EventBase eventBase; + auto clientCtx = std::make_shared(); + auto dfServerCtx = std::make_shared(); + + int fds[2]; + getfds(fds); + getctx(clientCtx, dfServerCtx); + + clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY); + dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY); + + AsyncSSLSocket::UniquePtr clientSock( + new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); + AsyncSSLSocket::UniquePtr serverSock( + new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true)); + + SSLHandshakeClient client(std::move(clientSock), true, true); + clientCtx->loadTrustedCertificates(testCA); + + SSLHandshakeServer server(std::move(serverSock), true, true); + + eventBase.loop(); + + EXPECT_TRUE(client.handshakeVerify_); + EXPECT_TRUE(client.handshakeSuccess_); + EXPECT_TRUE(!client.handshakeError_); + EXPECT_TRUE(!server.handshakeVerify_); + EXPECT_TRUE(server.handshakeSuccess_); + EXPECT_TRUE(!server.handshakeError_); +} + +/** + * Verify that the client's verification callback is able to fail SSL + * connection establishment. + */ +TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) { + EventBase eventBase; + auto clientCtx = std::make_shared(); + auto dfServerCtx = std::make_shared(); + + int fds[2]; + getfds(fds); + getctx(clientCtx, dfServerCtx); + + clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY); + dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY); + + AsyncSSLSocket::UniquePtr clientSock( + new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); + AsyncSSLSocket::UniquePtr serverSock( + new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true)); + + SSLHandshakeClient client(std::move(clientSock), true, false); + clientCtx->loadTrustedCertificates(testCA); + + SSLHandshakeServer server(std::move(serverSock), true, true); + + eventBase.loop(); + + EXPECT_TRUE(client.handshakeVerify_); + EXPECT_TRUE(!client.handshakeSuccess_); + EXPECT_TRUE(client.handshakeError_); + EXPECT_TRUE(!server.handshakeVerify_); + EXPECT_TRUE(!server.handshakeSuccess_); + EXPECT_TRUE(server.handshakeError_); +} + +/** + * Verify that the options in SSLContext can be overridden in + * sslConnect/Accept.i.e specifying that no validation should be performed + * allows an otherwise-invalid certificate to be accepted and doesn't fire + * the validation callback. + */ +TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) { + EventBase eventBase; + auto clientCtx = std::make_shared(); + auto dfServerCtx = std::make_shared(); + + int fds[2]; + getfds(fds); + getctx(clientCtx, dfServerCtx); + + clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY); + dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY); + + AsyncSSLSocket::UniquePtr clientSock( + new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); + AsyncSSLSocket::UniquePtr serverSock( + new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true)); + + SSLHandshakeClientNoVerify client(std::move(clientSock), false, false); + clientCtx->loadTrustedCertificates(testCA); + + SSLHandshakeServerNoVerify server(std::move(serverSock), false, false); + + eventBase.loop(); + + EXPECT_TRUE(!client.handshakeVerify_); + EXPECT_TRUE(client.handshakeSuccess_); + EXPECT_TRUE(!client.handshakeError_); + EXPECT_TRUE(!server.handshakeVerify_); + EXPECT_TRUE(server.handshakeSuccess_); + EXPECT_TRUE(!server.handshakeError_); +} + +/** + * Verify that the options in SSLContext can be overridden in + * sslConnect/Accept. Enable verification even if context says otherwise. + * Test requireClientCert with client cert + */ +TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) { + EventBase eventBase; + auto clientCtx = std::make_shared(); + auto serverCtx = std::make_shared(); + serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY); + serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); + serverCtx->loadPrivateKey(testKey); + serverCtx->loadCertificate(testCert); + serverCtx->loadTrustedCertificates(testCA); + serverCtx->loadClientCAList(testCA); + + clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY); + clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); + clientCtx->loadPrivateKey(testKey); + clientCtx->loadCertificate(testCert); + clientCtx->loadTrustedCertificates(testCA); + + int fds[2]; + getfds(fds); + + AsyncSSLSocket::UniquePtr clientSock( + new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); + AsyncSSLSocket::UniquePtr serverSock( + new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true)); + + SSLHandshakeClientDoVerify client(std::move(clientSock), true, true); + SSLHandshakeServerDoVerify server(std::move(serverSock), true, true); + + eventBase.loop(); + + EXPECT_TRUE(client.handshakeVerify_); + EXPECT_TRUE(client.handshakeSuccess_); + EXPECT_FALSE(client.handshakeError_); + EXPECT_TRUE(server.handshakeVerify_); + EXPECT_TRUE(server.handshakeSuccess_); + EXPECT_FALSE(server.handshakeError_); +} + +/** + * Verify that the client's verification callback is able to override + * the preverification failure and allow a successful connection. + */ +TEST(AsyncSSLSocketTest, SSLHandshakeValidationOverride) { + EventBase eventBase; + auto clientCtx = std::make_shared(); + auto dfServerCtx = std::make_shared(); + + int fds[2]; + getfds(fds); + getctx(clientCtx, dfServerCtx); + + clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY); + dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY); + + AsyncSSLSocket::UniquePtr clientSock( + new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); + AsyncSSLSocket::UniquePtr serverSock( + new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true)); + + SSLHandshakeClient client(std::move(clientSock), false, true); + SSLHandshakeServer server(std::move(serverSock), true, true); + + eventBase.loop(); + + EXPECT_TRUE(client.handshakeVerify_); + EXPECT_TRUE(client.handshakeSuccess_); + EXPECT_TRUE(!client.handshakeError_); + EXPECT_TRUE(!server.handshakeVerify_); + EXPECT_TRUE(server.handshakeSuccess_); + EXPECT_TRUE(!server.handshakeError_); +} + +/** + * Verify that specifying that no validation should be performed allows an + * otherwise-invalid certificate to be accepted and doesn't fire the validation + * callback. + */ +TEST(AsyncSSLSocketTest, SSLHandshakeValidationSkip) { + EventBase eventBase; + auto clientCtx = std::make_shared(); + auto dfServerCtx = std::make_shared(); + + int fds[2]; + getfds(fds); + getctx(clientCtx, dfServerCtx); + + clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY); + dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY); + + AsyncSSLSocket::UniquePtr clientSock( + new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); + AsyncSSLSocket::UniquePtr serverSock( + new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true)); + + SSLHandshakeClient client(std::move(clientSock), false, false); + SSLHandshakeServer server(std::move(serverSock), false, false); + + eventBase.loop(); + + EXPECT_TRUE(!client.handshakeVerify_); + EXPECT_TRUE(client.handshakeSuccess_); + EXPECT_TRUE(!client.handshakeError_); + EXPECT_TRUE(!server.handshakeVerify_); + EXPECT_TRUE(server.handshakeSuccess_); + EXPECT_TRUE(!server.handshakeError_); +} + +/** + * Test requireClientCert with client cert + */ +TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) { + EventBase eventBase; + auto clientCtx = std::make_shared(); + auto serverCtx = std::make_shared(); + serverCtx->setVerificationOption( + SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT); + serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); + serverCtx->loadPrivateKey(testKey); + serverCtx->loadCertificate(testCert); + serverCtx->loadTrustedCertificates(testCA); + serverCtx->loadClientCAList(testCA); + + clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY); + clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); + clientCtx->loadPrivateKey(testKey); + clientCtx->loadCertificate(testCert); + clientCtx->loadTrustedCertificates(testCA); + + int fds[2]; + getfds(fds); + + AsyncSSLSocket::UniquePtr clientSock( + new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); + AsyncSSLSocket::UniquePtr serverSock( + new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true)); + + SSLHandshakeClient client(std::move(clientSock), true, true); + SSLHandshakeServer server(std::move(serverSock), true, true); + + eventBase.loop(); + + EXPECT_TRUE(client.handshakeVerify_); + EXPECT_TRUE(client.handshakeSuccess_); + EXPECT_FALSE(client.handshakeError_); + EXPECT_TRUE(server.handshakeVerify_); + EXPECT_TRUE(server.handshakeSuccess_); + EXPECT_FALSE(server.handshakeError_); +} + + +/** + * Test requireClientCert with no client cert + */ +TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) { + EventBase eventBase; + auto clientCtx = std::make_shared(); + auto serverCtx = std::make_shared(); + serverCtx->setVerificationOption( + SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT); + serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); + serverCtx->loadPrivateKey(testKey); + serverCtx->loadCertificate(testCert); + serverCtx->loadTrustedCertificates(testCA); + serverCtx->loadClientCAList(testCA); + clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY); + clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); + + int fds[2]; + getfds(fds); + + AsyncSSLSocket::UniquePtr clientSock( + new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); + AsyncSSLSocket::UniquePtr serverSock( + new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true)); + + SSLHandshakeClient client(std::move(clientSock), false, false); + SSLHandshakeServer server(std::move(serverSock), false, false); + + eventBase.loop(); + + EXPECT_FALSE(server.handshakeVerify_); + EXPECT_FALSE(server.handshakeSuccess_); + EXPECT_TRUE(server.handshakeError_); +} +} + +/////////////////////////////////////////////////////////////////////////// +// init_unit_test_suite +/////////////////////////////////////////////////////////////////////////// +namespace { +struct Initializer { + Initializer() { + signal(SIGPIPE, SIG_IGN); + } +}; +Initializer initializer; +} // anonymous diff --git a/folly/io/async/test/AsyncSSLSocketTest.h b/folly/io/async/test/AsyncSSLSocketTest.h new file mode 100644 index 00000000..78623f77 --- /dev/null +++ b/folly/io/async/test/AsyncSSLSocketTest.h @@ -0,0 +1,1277 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace folly { + +enum StateEnum { + STATE_WAITING, + STATE_SUCCEEDED, + STATE_FAILED +}; + +// The destructors of all callback classes assert that the state is +// STATE_SUCCEEDED, for both possitive and negative tests. The tests +// are responsible for setting the succeeded state properly before the +// destructors are called. + +class WriteCallbackBase : +public AsyncTransportWrapper::WriteCallback { +public: + WriteCallbackBase() + : state(STATE_WAITING) + , bytesWritten(0) + , exception(AsyncSocketException::UNKNOWN, "none") {} + + ~WriteCallbackBase() { + EXPECT_EQ(state, STATE_SUCCEEDED); + } + + void setSocket( + const std::shared_ptr &socket) { + socket_ = socket; + } + + void writeSuccess() noexcept override { + std::cerr << "writeSuccess" << std::endl; + state = STATE_SUCCEEDED; + } + + void writeErr( + size_t bytesWritten, + const AsyncSocketException& ex) noexcept override { + std::cerr << "writeError: bytesWritten " << bytesWritten + << ", exception " << ex.what() << std::endl; + + state = STATE_FAILED; + this->bytesWritten = bytesWritten; + exception = ex; + socket_->close(); + socket_->detachEventBase(); + } + + std::shared_ptr socket_; + StateEnum state; + size_t bytesWritten; + AsyncSocketException exception; +}; + +class ReadCallbackBase : +public AsyncTransportWrapper::ReadCallback { +public: + explicit ReadCallbackBase(WriteCallbackBase *wcb) + : wcb_(wcb) + , state(STATE_WAITING) {} + + ~ReadCallbackBase() { + EXPECT_EQ(state, STATE_SUCCEEDED); + } + + void setSocket( + const std::shared_ptr &socket) { + socket_ = socket; + } + + void setState(StateEnum s) { + state = s; + if (wcb_) { + wcb_->state = s; + } + } + + void readErr( + const AsyncSocketException& ex) noexcept override { + std::cerr << "readError " << ex.what() << std::endl; + state = STATE_FAILED; + socket_->close(); + socket_->detachEventBase(); + } + + void readEOF() noexcept override { + std::cerr << "readEOF" << std::endl; + + socket_->close(); + socket_->detachEventBase(); + } + + std::shared_ptr socket_; + WriteCallbackBase *wcb_; + StateEnum state; +}; + +class ReadCallback : public ReadCallbackBase { +public: + explicit ReadCallback(WriteCallbackBase *wcb) + : ReadCallbackBase(wcb) + , buffers() {} + + ~ReadCallback() { + for (std::vector::iterator it = buffers.begin(); + it != buffers.end(); + ++it) { + it->free(); + } + currentBuffer.free(); + } + + void getReadBuffer(void** bufReturn, size_t* lenReturn) override { + if (!currentBuffer.buffer) { + currentBuffer.allocate(4096); + } + *bufReturn = currentBuffer.buffer; + *lenReturn = currentBuffer.length; + } + + void readDataAvailable(size_t len) noexcept override { + std::cerr << "readDataAvailable, len " << len << std::endl; + + currentBuffer.length = len; + + wcb_->setSocket(socket_); + + // Write back the same data. + socket_->write(wcb_, currentBuffer.buffer, len); + + buffers.push_back(currentBuffer); + currentBuffer.reset(); + state = STATE_SUCCEEDED; + } + + class Buffer { + public: + Buffer() : buffer(nullptr), length(0) {} + Buffer(char* buf, size_t len) : buffer(buf), length(len) {} + + void reset() { + buffer = nullptr; + length = 0; + } + void allocate(size_t length) { + assert(buffer == nullptr); + this->buffer = static_cast(malloc(length)); + this->length = length; + } + void free() { + ::free(buffer); + reset(); + } + + char* buffer; + size_t length; + }; + + std::vector buffers; + Buffer currentBuffer; +}; + +class ReadErrorCallback : public ReadCallbackBase { +public: + explicit ReadErrorCallback(WriteCallbackBase *wcb) + : ReadCallbackBase(wcb) {} + + // Return nullptr buffer to trigger readError() + void getReadBuffer(void** bufReturn, size_t* lenReturn) override { + *bufReturn = nullptr; + *lenReturn = 0; + } + + void readDataAvailable(size_t len) noexcept override { + // This should never to called. + FAIL(); + } + + void readErr( + const AsyncSocketException& ex) noexcept override { + ReadCallbackBase::readErr(ex); + std::cerr << "ReadErrorCallback::readError" << std::endl; + setState(STATE_SUCCEEDED); + } +}; + +class WriteErrorCallback : public ReadCallback { +public: + explicit WriteErrorCallback(WriteCallbackBase *wcb) + : ReadCallback(wcb) {} + + void readDataAvailable(size_t len) noexcept override { + std::cerr << "readDataAvailable, len " << len << std::endl; + + currentBuffer.length = len; + + // close the socket before writing to trigger writeError(). + ::close(socket_->getFd()); + + wcb_->setSocket(socket_); + + // Write back the same data. + socket_->write(wcb_, currentBuffer.buffer, len); + + if (wcb_->state == STATE_FAILED) { + setState(STATE_SUCCEEDED); + } else { + state = STATE_FAILED; + } + + buffers.push_back(currentBuffer); + currentBuffer.reset(); + } + + void readErr(const AsyncSocketException& ex) noexcept override { + std::cerr << "readError " << ex.what() << std::endl; + // do nothing since this is expected + } +}; + +class EmptyReadCallback : public ReadCallback { +public: + explicit EmptyReadCallback() + : ReadCallback(nullptr) {} + + void readErr(const AsyncSocketException& ex) noexcept override { + std::cerr << "readError " << ex.what() << std::endl; + state = STATE_FAILED; + tcpSocket_->close(); + tcpSocket_->detachEventBase(); + } + + void readEOF() noexcept override { + std::cerr << "readEOF" << std::endl; + + tcpSocket_->close(); + tcpSocket_->detachEventBase(); + state = STATE_SUCCEEDED; + } + + std::shared_ptr tcpSocket_; +}; + +class HandshakeCallback : +public AsyncSSLSocket::HandshakeCB { +public: + enum ExpectType { + EXPECT_SUCCESS, + EXPECT_ERROR + }; + + explicit HandshakeCallback(ReadCallbackBase *rcb, + ExpectType expect = EXPECT_SUCCESS): + state(STATE_WAITING), + rcb_(rcb), + expect_(expect) {} + + void setSocket( + const std::shared_ptr &socket) { + socket_ = socket; + } + + void setState(StateEnum s) { + state = s; + rcb_->setState(s); + } + + // Functions inherited from AsyncSSLSocketHandshakeCallback + void handshakeSuc(AsyncSSLSocket *sock) noexcept override { + EXPECT_EQ(sock, socket_.get()); + std::cerr << "HandshakeCallback::connectionAccepted" << std::endl; + rcb_->setSocket(socket_); + sock->setReadCB(rcb_); + state = (expect_ == EXPECT_SUCCESS) ? STATE_SUCCEEDED : STATE_FAILED; + } + void handshakeErr( + AsyncSSLSocket *sock, + const AsyncSocketException& ex) noexcept override { + std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl; + state = (expect_ == EXPECT_ERROR) ? STATE_SUCCEEDED : STATE_FAILED; + if (expect_ == EXPECT_ERROR) { + // rcb will never be invoked + rcb_->setState(STATE_SUCCEEDED); + } + } + + ~HandshakeCallback() { + EXPECT_EQ(state, STATE_SUCCEEDED); + } + + void closeSocket() { + socket_->close(); + state = STATE_SUCCEEDED; + } + + StateEnum state; + std::shared_ptr socket_; + ReadCallbackBase *rcb_; + ExpectType expect_; +}; + +class SSLServerAcceptCallbackBase: +public folly::AsyncServerSocket::AcceptCallback { +public: + explicit SSLServerAcceptCallbackBase(HandshakeCallback *hcb): + state(STATE_WAITING), hcb_(hcb) {} + + ~SSLServerAcceptCallbackBase() { + EXPECT_EQ(state, STATE_SUCCEEDED); + } + + void acceptError(const std::exception& ex) noexcept override { + std::cerr << "SSLServerAcceptCallbackBase::acceptError " + << ex.what() << std::endl; + state = STATE_FAILED; + } + + void connectionAccepted(int fd, const folly::SocketAddress& clientAddr) + noexcept override{ + printf("Connection accepted\n"); + std::shared_ptr sslSock; + try { + // Create a AsyncSSLSocket object with the fd. The socket should be + // added to the event base and in the state of accepting SSL connection. + sslSock = AsyncSSLSocket::newSocket(ctx_, base_, fd); + } catch (const std::exception &e) { + LOG(ERROR) << "Exception %s caught while creating a AsyncSSLSocket " + "object with socket " << e.what() << fd; + ::close(fd); + acceptError(e); + return; + } + + connAccepted(sslSock); + } + + virtual void connAccepted( + const std::shared_ptr &s) = 0; + + StateEnum state; + HandshakeCallback *hcb_; + std::shared_ptr ctx_; + folly::EventBase* base_; +}; + +class SSLServerAcceptCallback: public SSLServerAcceptCallbackBase { +public: + uint32_t timeout_; + + explicit SSLServerAcceptCallback(HandshakeCallback *hcb, + uint32_t timeout = 0): + SSLServerAcceptCallbackBase(hcb), + timeout_(timeout) {} + + virtual ~SSLServerAcceptCallback() { + if (timeout_ > 0) { + // if we set a timeout, we expect failure + EXPECT_EQ(hcb_->state, STATE_FAILED); + hcb_->setState(STATE_SUCCEEDED); + } + } + + // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback + void connAccepted( + const std::shared_ptr &s) + noexcept override { + auto sock = std::static_pointer_cast(s); + std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl; + + hcb_->setSocket(sock); + sock->sslAccept(hcb_, timeout_); + EXPECT_EQ(sock->getSSLState(), + AsyncSSLSocket::STATE_ACCEPTING); + + state = STATE_SUCCEEDED; + } +}; + +class SSLServerAcceptCallbackDelay: public SSLServerAcceptCallback { +public: + explicit SSLServerAcceptCallbackDelay(HandshakeCallback *hcb): + SSLServerAcceptCallback(hcb) {} + + // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback + void connAccepted( + const std::shared_ptr &s) + noexcept override { + + auto sock = std::static_pointer_cast(s); + + std::cerr << "SSLServerAcceptCallbackDelay::connAccepted" + << std::endl; + int fd = sock->getFd(); + +#ifndef TCP_NOPUSH + { + // The accepted connection should already have TCP_NODELAY set + int value; + socklen_t valueLength = sizeof(value); + int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength); + EXPECT_EQ(rc, 0); + EXPECT_EQ(value, 1); + } +#endif + + // Unset the TCP_NODELAY option. + int value = 0; + socklen_t valueLength = sizeof(value); + int rc = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, valueLength); + EXPECT_EQ(rc, 0); + + rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength); + EXPECT_EQ(rc, 0); + EXPECT_EQ(value, 0); + + SSLServerAcceptCallback::connAccepted(sock); + } +}; + +class SSLServerAsyncCacheAcceptCallback: public SSLServerAcceptCallback { +public: + explicit SSLServerAsyncCacheAcceptCallback(HandshakeCallback *hcb, + uint32_t timeout = 0): + SSLServerAcceptCallback(hcb, timeout) {} + + // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback + void connAccepted( + const std::shared_ptr &s) + noexcept override { + auto sock = std::static_pointer_cast(s); + + std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl; + + hcb_->setSocket(sock); + sock->sslAccept(hcb_, timeout_); + ASSERT_TRUE((sock->getSSLState() == + AsyncSSLSocket::STATE_ACCEPTING) || + (sock->getSSLState() == + AsyncSSLSocket::STATE_CACHE_LOOKUP)); + + state = STATE_SUCCEEDED; + } +}; + + +class HandshakeErrorCallback: public SSLServerAcceptCallbackBase { +public: + explicit HandshakeErrorCallback(HandshakeCallback *hcb): + SSLServerAcceptCallbackBase(hcb) {} + + // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback + void connAccepted( + const std::shared_ptr &s) + noexcept override { + auto sock = std::static_pointer_cast(s); + + std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl; + + // The first call to sslAccept() should succeed. + hcb_->setSocket(sock); + sock->sslAccept(hcb_); + EXPECT_EQ(sock->getSSLState(), + AsyncSSLSocket::STATE_ACCEPTING); + + // The second call to sslAccept() should fail. + HandshakeCallback callback2(hcb_->rcb_); + callback2.setSocket(sock); + sock->sslAccept(&callback2); + EXPECT_EQ(sock->getSSLState(), + AsyncSSLSocket::STATE_ERROR); + + // Both callbacks should be in the error state. + EXPECT_EQ(hcb_->state, STATE_FAILED); + EXPECT_EQ(callback2.state, STATE_FAILED); + + sock->detachEventBase(); + + state = STATE_SUCCEEDED; + hcb_->setState(STATE_SUCCEEDED); + callback2.setState(STATE_SUCCEEDED); + } +}; + +class HandshakeTimeoutCallback: public SSLServerAcceptCallbackBase { +public: + explicit HandshakeTimeoutCallback(HandshakeCallback *hcb): + SSLServerAcceptCallbackBase(hcb) {} + + // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback + void connAccepted( + const std::shared_ptr &s) + noexcept override { + std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl; + + auto sock = std::static_pointer_cast(s); + + hcb_->setSocket(sock); + sock->getEventBase()->tryRunAfterDelay([=] { + std::cerr << "Delayed SSL accept, client will have close by now" + << std::endl; + // SSL accept will fail + EXPECT_EQ( + sock->getSSLState(), + AsyncSSLSocket::STATE_UNINIT); + hcb_->socket_->sslAccept(hcb_); + // This registers for an event + EXPECT_EQ( + sock->getSSLState(), + AsyncSSLSocket::STATE_ACCEPTING); + + state = STATE_SUCCEEDED; + }, 100); + } +}; + + +class TestSSLServer { + protected: + EventBase evb_; + std::shared_ptr ctx_; + SSLServerAcceptCallbackBase *acb_; + folly::AsyncServerSocket *socket_; + folly::SocketAddress address_; + pthread_t thread_; + + static void *Main(void *ctx) { + TestSSLServer *self = static_cast(ctx); + self->evb_.loop(); + std::cerr << "Server thread exited event loop" << std::endl; + return nullptr; + } + + public: + // Create a TestSSLServer. + // This immediately starts listening on the given port. + explicit TestSSLServer(SSLServerAcceptCallbackBase *acb); + + // Kill the thread. + ~TestSSLServer() { + evb_.runInEventBaseThread([&](){ + socket_->stopAccepting(); + }); + std::cerr << "Waiting for server thread to exit" << std::endl; + pthread_join(thread_, nullptr); + } + + EventBase &getEventBase() { return evb_; } + + const folly::SocketAddress& getAddress() const { + return address_; + } +}; + +class TestSSLAsyncCacheServer : public TestSSLServer { + public: + explicit TestSSLAsyncCacheServer(SSLServerAcceptCallbackBase *acb, + int lookupDelay = 100) : + TestSSLServer(acb) { + SSL_CTX *sslCtx = ctx_->getSSLCtx(); + SSL_CTX_sess_set_get_cb(sslCtx, + TestSSLAsyncCacheServer::getSessionCallback); + SSL_CTX_set_session_cache_mode( + sslCtx, SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_SERVER); + asyncCallbacks_ = 0; + asyncLookups_ = 0; + lookupDelay_ = lookupDelay; + } + + uint32_t getAsyncCallbacks() const { return asyncCallbacks_; } + uint32_t getAsyncLookups() const { return asyncLookups_; } + + private: + static uint32_t asyncCallbacks_; + static uint32_t asyncLookups_; + static uint32_t lookupDelay_; + + static SSL_SESSION *getSessionCallback(SSL *ssl, + unsigned char *sess_id, + int id_len, + int *copyflag) { + *copyflag = 0; + asyncCallbacks_++; +#ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP + if (!SSL_want_sess_cache_lookup(ssl)) { + // libssl.so mismatch + std::cerr << "no async support" << std::endl; + return nullptr; + } + + AsyncSSLSocket *sslSocket = + AsyncSSLSocket::getFromSSL(ssl); + assert(sslSocket != nullptr); + // Going to simulate an async cache by just running delaying the miss 100ms + if (asyncCallbacks_ % 2 == 0) { + // This socket is already blocked on lookup, return miss + std::cerr << "returning miss" << std::endl; + } else { + // fresh meat - block it + std::cerr << "async lookup" << std::endl; + sslSocket->getEventBase()->tryRunAfterDelay( + std::bind(&AsyncSSLSocket::restartSSLAccept, + sslSocket), lookupDelay_); + *copyflag = SSL_SESSION_CB_WOULD_BLOCK; + asyncLookups_++; + } +#endif + return nullptr; + } +}; + +void getfds(int fds[2]); + +void getctx( + std::shared_ptr clientCtx, + std::shared_ptr serverCtx); + +void sslsocketpair( + EventBase* eventBase, + AsyncSSLSocket::UniquePtr* clientSock, + AsyncSSLSocket::UniquePtr* serverSock); + +class BlockingWriteClient : + private AsyncSSLSocket::HandshakeCB, + private AsyncTransportWrapper::WriteCallback { + public: + explicit BlockingWriteClient( + AsyncSSLSocket::UniquePtr socket) + : socket_(std::move(socket)), + bufLen_(2500), + iovCount_(2000) { + // Fill buf_ + buf_.reset(new uint8_t[bufLen_]); + for (uint32_t n = 0; n < sizeof(buf_); ++n) { + buf_[n] = n % 0xff; + } + + // Initialize iov_ + iov_.reset(new struct iovec[iovCount_]); + for (uint32_t n = 0; n < iovCount_; ++n) { + iov_[n].iov_base = buf_.get() + n; + if (n & 0x1) { + iov_[n].iov_len = n % bufLen_; + } else { + iov_[n].iov_len = bufLen_ - (n % bufLen_); + } + } + + socket_->sslConn(this, 100); + } + + struct iovec* getIovec() const { + return iov_.get(); + } + uint32_t getIovecCount() const { + return iovCount_; + } + + private: + void handshakeSuc(AsyncSSLSocket*) noexcept override { + socket_->writev(this, iov_.get(), iovCount_); + } + void handshakeErr( + AsyncSSLSocket*, + const AsyncSocketException& ex) noexcept override { + ADD_FAILURE() << "client handshake error: " << ex.what(); + } + void writeSuccess() noexcept override { + socket_->close(); + } + void writeErr( + size_t bytesWritten, + const AsyncSocketException& ex) noexcept override { + ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: " + << ex.what(); + } + + AsyncSSLSocket::UniquePtr socket_; + uint32_t bufLen_; + uint32_t iovCount_; + std::unique_ptr buf_; + std::unique_ptr iov_; +}; + +class BlockingWriteServer : + private AsyncSSLSocket::HandshakeCB, + private AsyncTransportWrapper::ReadCallback { + public: + explicit BlockingWriteServer( + AsyncSSLSocket::UniquePtr socket) + : socket_(std::move(socket)), + bufSize_(2500 * 2000), + bytesRead_(0) { + buf_.reset(new uint8_t[bufSize_]); + socket_->sslAccept(this, 100); + } + + void checkBuffer(struct iovec* iov, uint32_t count) const { + uint32_t idx = 0; + for (uint32_t n = 0; n < count; ++n) { + size_t bytesLeft = bytesRead_ - idx; + int rc = memcmp(buf_.get() + idx, iov[n].iov_base, + std::min(iov[n].iov_len, bytesLeft)); + if (rc != 0) { + FAIL() << "buffer mismatch at iovec " << n << "/" << count + << ": rc=" << rc; + + } + if (iov[n].iov_len > bytesLeft) { + FAIL() << "server did not read enough data: " + << "ended at byte " << bytesLeft << "/" << iov[n].iov_len + << " in iovec " << n << "/" << count; + } + + idx += iov[n].iov_len; + } + if (idx != bytesRead_) { + ADD_FAILURE() << "server read extra data: " << bytesRead_ + << " bytes read; expected " << idx; + } + } + + private: + void handshakeSuc(AsyncSSLSocket*) noexcept override { + // Wait 10ms before reading, so the client's writes will initially block. + socket_->getEventBase()->tryRunAfterDelay( + [this] { socket_->setReadCB(this); }, 10); + } + void handshakeErr( + AsyncSSLSocket*, + const AsyncSocketException& ex) noexcept override { + ADD_FAILURE() << "server handshake error: " << ex.what(); + } + void getReadBuffer(void** bufReturn, size_t* lenReturn) override { + *bufReturn = buf_.get() + bytesRead_; + *lenReturn = bufSize_ - bytesRead_; + } + void readDataAvailable(size_t len) noexcept override { + bytesRead_ += len; + socket_->setReadCB(nullptr); + socket_->getEventBase()->tryRunAfterDelay( + [this] { socket_->setReadCB(this); }, 2); + } + void readEOF() noexcept override { + socket_->close(); + } + void readErr( + const AsyncSocketException& ex) noexcept override { + ADD_FAILURE() << "server read error: " << ex.what(); + } + + AsyncSSLSocket::UniquePtr socket_; + uint32_t bufSize_; + uint32_t bytesRead_; + std::unique_ptr buf_; +}; + +class NpnClient : + private AsyncSSLSocket::HandshakeCB, + private AsyncTransportWrapper::WriteCallback { + public: + explicit NpnClient( + AsyncSSLSocket::UniquePtr socket) + : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) { + socket_->sslConn(this); + } + + const unsigned char* nextProto; + unsigned nextProtoLength; + private: + void handshakeSuc(AsyncSSLSocket*) noexcept override { + socket_->getSelectedNextProtocol(&nextProto, + &nextProtoLength); + } + void handshakeErr( + AsyncSSLSocket*, + const AsyncSocketException& ex) noexcept override { + ADD_FAILURE() << "client handshake error: " << ex.what(); + } + void writeSuccess() noexcept override { + socket_->close(); + } + void writeErr( + size_t bytesWritten, + const AsyncSocketException& ex) noexcept override { + ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: " + << ex.what(); + } + + AsyncSSLSocket::UniquePtr socket_; +}; + +class NpnServer : + private AsyncSSLSocket::HandshakeCB, + private AsyncTransportWrapper::ReadCallback { + public: + explicit NpnServer(AsyncSSLSocket::UniquePtr socket) + : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) { + socket_->sslAccept(this); + } + + const unsigned char* nextProto; + unsigned nextProtoLength; + private: + void handshakeSuc(AsyncSSLSocket*) noexcept override { + socket_->getSelectedNextProtocol(&nextProto, + &nextProtoLength); + } + void handshakeErr( + AsyncSSLSocket*, + const AsyncSocketException& ex) noexcept override { + ADD_FAILURE() << "server handshake error: " << ex.what(); + } + void getReadBuffer(void** bufReturn, size_t* lenReturn) override { + *lenReturn = 0; + } + void readDataAvailable(size_t len) noexcept override { + } + void readEOF() noexcept override { + socket_->close(); + } + void readErr( + const AsyncSocketException& ex) noexcept override { + ADD_FAILURE() << "server read error: " << ex.what(); + } + + AsyncSSLSocket::UniquePtr socket_; +}; + +#ifndef OPENSSL_NO_TLSEXT +class SNIClient : + private AsyncSSLSocket::HandshakeCB, + private AsyncTransportWrapper::WriteCallback { + public: + explicit SNIClient( + AsyncSSLSocket::UniquePtr socket) + : serverNameMatch(false), socket_(std::move(socket)) { + socket_->sslConn(this); + } + + bool serverNameMatch; + + private: + void handshakeSuc(AsyncSSLSocket*) noexcept override { + serverNameMatch = socket_->isServerNameMatch(); + } + void handshakeErr( + AsyncSSLSocket*, + const AsyncSocketException& ex) noexcept override { + ADD_FAILURE() << "client handshake error: " << ex.what(); + } + void writeSuccess() noexcept override { + socket_->close(); + } + void writeErr( + size_t bytesWritten, + const AsyncSocketException& ex) noexcept override { + ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: " + << ex.what(); + } + + AsyncSSLSocket::UniquePtr socket_; +}; + +class SNIServer : + private AsyncSSLSocket::HandshakeCB, + private AsyncTransportWrapper::ReadCallback { + public: + explicit SNIServer( + AsyncSSLSocket::UniquePtr socket, + const std::shared_ptr& ctx, + const std::shared_ptr& sniCtx, + const std::string& expectedServerName) + : serverNameMatch(false), socket_(std::move(socket)), sniCtx_(sniCtx), + expectedServerName_(expectedServerName) { + ctx->setServerNameCallback(std::bind(&SNIServer::serverNameCallback, this, + std::placeholders::_1)); + socket_->sslAccept(this); + } + + bool serverNameMatch; + + private: + void handshakeSuc(AsyncSSLSocket* ssl) noexcept override {} + void handshakeErr( + AsyncSSLSocket*, + const AsyncSocketException& ex) noexcept override { + ADD_FAILURE() << "server handshake error: " << ex.what(); + } + void getReadBuffer(void** bufReturn, size_t* lenReturn) override { + *lenReturn = 0; + } + void readDataAvailable(size_t len) noexcept override { + } + void readEOF() noexcept override { + socket_->close(); + } + void readErr( + const AsyncSocketException& ex) noexcept override { + ADD_FAILURE() << "server read error: " << ex.what(); + } + + folly::SSLContext::ServerNameCallbackResult + serverNameCallback(SSL *ssl) { + const char *sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name); + if (sniCtx_ && + sn && + !strcasecmp(expectedServerName_.c_str(), sn)) { + AsyncSSLSocket *sslSocket = + AsyncSSLSocket::getFromSSL(ssl); + sslSocket->switchServerSSLContext(sniCtx_); + serverNameMatch = true; + return folly::SSLContext::SERVER_NAME_FOUND; + } else { + serverNameMatch = false; + return folly::SSLContext::SERVER_NAME_NOT_FOUND; + } + } + + AsyncSSLSocket::UniquePtr socket_; + std::shared_ptr sniCtx_; + std::string expectedServerName_; +}; +#endif + +class SSLClient : public AsyncSocket::ConnectCallback, + public AsyncTransportWrapper::WriteCallback, + public AsyncTransportWrapper::ReadCallback +{ + private: + EventBase *eventBase_; + std::shared_ptr sslSocket_; + SSL_SESSION *session_; + std::shared_ptr ctx_; + uint32_t requests_; + folly::SocketAddress address_; + uint32_t timeout_; + char buf_[128]; + char readbuf_[128]; + uint32_t bytesRead_; + uint32_t hit_; + uint32_t miss_; + uint32_t errors_; + uint32_t writeAfterConnectErrors_; + + public: + SSLClient(EventBase *eventBase, + const folly::SocketAddress& address, + uint32_t requests, uint32_t timeout = 0) + : eventBase_(eventBase), + session_(nullptr), + requests_(requests), + address_(address), + timeout_(timeout), + bytesRead_(0), + hit_(0), + miss_(0), + errors_(0), + writeAfterConnectErrors_(0) { + ctx_.reset(new folly::SSLContext()); + ctx_->setOptions(SSL_OP_NO_TICKET); + ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); + memset(buf_, 'a', sizeof(buf_)); + } + + ~SSLClient() { + if (session_) { + SSL_SESSION_free(session_); + } + if (errors_ == 0) { + EXPECT_EQ(bytesRead_, sizeof(buf_)); + } + } + + uint32_t getHit() const { return hit_; } + + uint32_t getMiss() const { return miss_; } + + uint32_t getErrors() const { return errors_; } + + uint32_t getWriteAfterConnectErrors() const { + return writeAfterConnectErrors_; + } + + void connect(bool writeNow = false) { + sslSocket_ = AsyncSSLSocket::newSocket( + ctx_, eventBase_); + if (session_ != nullptr) { + sslSocket_->setSSLSession(session_); + } + requests_--; + sslSocket_->connect(this, address_, timeout_); + if (sslSocket_ && writeNow) { + // write some junk, used in an error test + sslSocket_->write(this, buf_, sizeof(buf_)); + } + } + + void connectSuccess() noexcept override { + std::cerr << "client SSL socket connected" << std::endl; + if (sslSocket_->getSSLSessionReused()) { + hit_++; + } else { + miss_++; + if (session_ != nullptr) { + SSL_SESSION_free(session_); + } + session_ = sslSocket_->getSSLSession(); + } + + // write() + sslSocket_->write(this, buf_, sizeof(buf_)); + sslSocket_->setReadCB(this); + memset(readbuf_, 'b', sizeof(readbuf_)); + bytesRead_ = 0; + } + + void connectErr( + const AsyncSocketException& ex) noexcept override { + std::cerr << "SSLClient::connectError: " << ex.what() << std::endl; + errors_++; + sslSocket_.reset(); + } + + void writeSuccess() noexcept override { + std::cerr << "client write success" << std::endl; + } + + void writeErr( + size_t bytesWritten, + const AsyncSocketException& ex) + noexcept override { + std::cerr << "client writeError: " << ex.what() << std::endl; + if (!sslSocket_) { + writeAfterConnectErrors_++; + } + } + + void getReadBuffer(void** bufReturn, size_t* lenReturn) override { + *bufReturn = readbuf_ + bytesRead_; + *lenReturn = sizeof(readbuf_) - bytesRead_; + } + + void readEOF() noexcept override { + std::cerr << "client readEOF" << std::endl; + } + + void readErr( + const AsyncSocketException& ex) noexcept override { + std::cerr << "client readError: " << ex.what() << std::endl; + } + + void readDataAvailable(size_t len) noexcept override { + std::cerr << "client read data: " << len << std::endl; + bytesRead_ += len; + if (len == sizeof(buf_)) { + EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0); + sslSocket_->closeNow(); + sslSocket_.reset(); + if (requests_ != 0) { + connect(); + } + } + } + +}; + +class SSLHandshakeBase : + public AsyncSSLSocket::HandshakeCB, + private AsyncTransportWrapper::WriteCallback { + public: + explicit SSLHandshakeBase( + AsyncSSLSocket::UniquePtr socket, + bool preverifyResult, + bool verifyResult) : + handshakeVerify_(false), + handshakeSuccess_(false), + handshakeError_(false), + socket_(std::move(socket)), + preverifyResult_(preverifyResult), + verifyResult_(verifyResult) { + } + + bool handshakeVerify_; + bool handshakeSuccess_; + bool handshakeError_; + + protected: + AsyncSSLSocket::UniquePtr socket_; + bool preverifyResult_; + bool verifyResult_; + + // HandshakeCallback + bool handshakeVer( + AsyncSSLSocket* sock, + bool preverifyOk, + X509_STORE_CTX* ctx) noexcept override { + handshakeVerify_ = true; + + EXPECT_EQ(preverifyResult_, preverifyOk); + return verifyResult_; + } + + void handshakeSuc(AsyncSSLSocket*) noexcept override { + handshakeSuccess_ = true; + } + + void handshakeErr( + AsyncSSLSocket*, + const AsyncSocketException& ex) noexcept override { + handshakeError_ = true; + } + + // WriteCallback + void writeSuccess() noexcept override { + socket_->close(); + } + + void writeErr( + size_t bytesWritten, + const AsyncSocketException& ex) noexcept override { + ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: " + << ex.what(); + } +}; + +class SSLHandshakeClient : public SSLHandshakeBase { + public: + SSLHandshakeClient( + AsyncSSLSocket::UniquePtr socket, + bool preverifyResult, + bool verifyResult) : + SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) { + socket_->sslConn(this, 0); + } +}; + +class SSLHandshakeClientNoVerify : public SSLHandshakeBase { + public: + SSLHandshakeClientNoVerify( + AsyncSSLSocket::UniquePtr socket, + bool preverifyResult, + bool verifyResult) : + SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) { + socket_->sslConn(this, 0, + folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY); + } +}; + +class SSLHandshakeClientDoVerify : public SSLHandshakeBase { + public: + SSLHandshakeClientDoVerify( + AsyncSSLSocket::UniquePtr socket, + bool preverifyResult, + bool verifyResult) : + SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) { + socket_->sslConn(this, 0, + folly::SSLContext::SSLVerifyPeerEnum::VERIFY); + } +}; + +class SSLHandshakeServer : public SSLHandshakeBase { + public: + SSLHandshakeServer( + AsyncSSLSocket::UniquePtr socket, + bool preverifyResult, + bool verifyResult) + : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) { + socket_->sslAccept(this, 0); + } +}; + +class SSLHandshakeServerParseClientHello : public SSLHandshakeBase { + public: + SSLHandshakeServerParseClientHello( + AsyncSSLSocket::UniquePtr socket, + bool preverifyResult, + bool verifyResult) + : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) { + socket_->enableClientHelloParsing(); + socket_->sslAccept(this, 0); + } + + std::string clientCiphers_, sharedCiphers_, serverCiphers_, chosenCipher_; + + protected: + void handshakeSuc(AsyncSSLSocket* sock) noexcept override { + handshakeSuccess_ = true; + sock->getSSLSharedCiphers(sharedCiphers_); + sock->getSSLServerCiphers(serverCiphers_); + sock->getSSLClientCiphers(clientCiphers_); + chosenCipher_ = sock->getNegotiatedCipherName(); + } +}; + + +class SSLHandshakeServerNoVerify : public SSLHandshakeBase { + public: + SSLHandshakeServerNoVerify( + AsyncSSLSocket::UniquePtr socket, + bool preverifyResult, + bool verifyResult) + : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) { + socket_->sslAccept(this, 0, + folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY); + } +}; + +class SSLHandshakeServerDoVerify : public SSLHandshakeBase { + public: + SSLHandshakeServerDoVerify( + AsyncSSLSocket::UniquePtr socket, + bool preverifyResult, + bool verifyResult) + : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) { + socket_->sslAccept(this, 0, + folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT); + } +}; + +class EventBaseAborter : public AsyncTimeout { + public: + EventBaseAborter(EventBase* eventBase, + uint32_t timeoutMS) + : AsyncTimeout( + eventBase, AsyncTimeout::InternalEnum::INTERNAL) + , eventBase_(eventBase) { + scheduleTimeout(timeoutMS); + } + + void timeoutExpired() noexcept override { + FAIL() << "test timed out"; + eventBase_->terminateLoopSoon(); + } + + private: + EventBase* eventBase_; +}; + +} diff --git a/folly/io/async/test/AsyncSSLSocketTest2.cpp b/folly/io/async/test/AsyncSSLSocketTest2.cpp new file mode 100644 index 00000000..5f4818ee --- /dev/null +++ b/folly/io/async/test/AsyncSSLSocketTest2.cpp @@ -0,0 +1,147 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include +#include + +#include +#include + +using std::string; +using std::vector; +using std::min; +using std::cerr; +using std::endl; +using std::list; + +namespace folly { + +class AttachDetachClient : public AsyncSocket::ConnectCallback, + public AsyncTransportWrapper::WriteCallback, + public AsyncTransportWrapper::ReadCallback { + private: + EventBase *eventBase_; + std::shared_ptr sslSocket_; + std::shared_ptr ctx_; + folly::SocketAddress address_; + char buf_[128]; + char readbuf_[128]; + uint32_t bytesRead_; + public: + AttachDetachClient(EventBase *eventBase, const folly::SocketAddress& address) + : eventBase_(eventBase), address_(address), bytesRead_(0) { + ctx_.reset(new SSLContext()); + ctx_->setOptions(SSL_OP_NO_TICKET); + ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); + } + + void connect() { + sslSocket_ = AsyncSSLSocket::newSocket(ctx_, eventBase_); + sslSocket_->connect(this, address_); + } + + void connectSuccess() noexcept override { + cerr << "client SSL socket connected" << endl; + + for (int i = 0; i < 1000; ++i) { + sslSocket_->detachSSLContext(); + sslSocket_->attachSSLContext(ctx_); + } + + EXPECT_EQ(ctx_->getSSLCtx()->references, 2); + + sslSocket_->write(this, buf_, sizeof(buf_)); + sslSocket_->setReadCB(this); + memset(readbuf_, 'b', sizeof(readbuf_)); + bytesRead_ = 0; + } + + void connectErr(const AsyncSocketException& ex) noexcept override + { + cerr << "AttachDetachClient::connectError: " << ex.what() << endl; + sslSocket_.reset(); + } + + void writeSuccess() noexcept override { + cerr << "client write success" << endl; + } + + void writeErr(size_t bytesWritten, const AsyncSocketException& ex) + noexcept override { + cerr << "client writeError: " << ex.what() << endl; + } + + void getReadBuffer(void** bufReturn, size_t* lenReturn) override { + *bufReturn = readbuf_ + bytesRead_; + *lenReturn = sizeof(readbuf_) - bytesRead_; + } + void readEOF() noexcept override { + cerr << "client readEOF" << endl; + } + + void readErr(const AsyncSocketException& ex) noexcept override { + cerr << "client readError: " << ex.what() << endl; + } + + void readDataAvailable(size_t len) noexcept override { + cerr << "client read data: " << len << endl; + bytesRead_ += len; + if (len == sizeof(buf_)) { + EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0); + sslSocket_->closeNow(); + } + } +}; + +/** + * Test passing contexts between threads + */ +TEST(AsyncSSLSocketTest2, AttachDetachSSLContext) { + // Start listening on a local port + WriteCallbackBase writeCallback; + ReadCallback readCallback(&writeCallback); + HandshakeCallback handshakeCallback(&readCallback); + SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback); + TestSSLServer server(&acceptCallback); + + EventBase eventBase; + EventBaseAborter eba(&eventBase, 3000); + std::shared_ptr client( + new AttachDetachClient(&eventBase, server.getAddress())); + + client->connect(); + eventBase.loop(); +} + +} +/////////////////////////////////////////////////////////////////////////// +// init_unit_test_suite +/////////////////////////////////////////////////////////////////////////// + +namespace { +using folly::SSLContext; +struct Initializer { + Initializer() { + signal(SIGPIPE, SIG_IGN); + SSLContext::setSSLLockTypes({ + {CRYPTO_LOCK_EVP_PKEY, SSLContext::LOCK_NONE}, + {CRYPTO_LOCK_SSL_SESSION, SSLContext::LOCK_SPINLOCK}, + {CRYPTO_LOCK_SSL_CTX, SSLContext::LOCK_NONE}}); + } +}; +Initializer initializer; +} // anonymous diff --git a/folly/io/async/test/AsyncSSLSocketWriteTest.cpp b/folly/io/async/test/AsyncSSLSocketWriteTest.cpp new file mode 100644 index 00000000..fd9dac73 --- /dev/null +++ b/folly/io/async/test/AsyncSSLSocketWriteTest.cpp @@ -0,0 +1,396 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +using std::string; +using namespace testing; + +namespace folly { + +class MockAsyncSSLSocket : public AsyncSSLSocket{ + public: + static std::shared_ptr newSocket( + const std::shared_ptr& ctx, + EventBase* evb) { + auto sock = std::shared_ptr( + new MockAsyncSSLSocket(ctx, evb), + Destructor()); + sock->ssl_ = SSL_new(ctx->getSSLCtx()); + SSL_set_fd(sock->ssl_, -1); + return sock; + } + + // Fake constructor sets the state to established without call to connect + // or accept + MockAsyncSSLSocket(const std::shared_ptr& ctx, + EventBase* evb) + : AsyncSocket(evb), AsyncSSLSocket(ctx, evb) { + state_ = AsyncSocket::StateEnum::ESTABLISHED; + sslState_ = AsyncSSLSocket::SSLStateEnum::STATE_ESTABLISHED; + } + + // mock the calls to SSL_write to see the buffer length and contents + MOCK_METHOD3(sslWriteImpl, int(SSL *ssl, const void *buf, int n)); + + // mock the calls to getRawBytesWritten() + MOCK_CONST_METHOD0(getRawBytesWritten, size_t()); + + // public wrapper for protected interface + ssize_t testPerformWrite(const iovec* vec, uint32_t count, WriteFlags flags, + uint32_t* countWritten, uint32_t* partialWritten) { + return performWrite(vec, count, flags, countWritten, partialWritten); + } + + void checkEor(size_t appEor, size_t rawEor) { + EXPECT_EQ(appEor, appEorByteNo_); + EXPECT_EQ(rawEor, minEorRawByteNo_); + } + + void setAppBytesWritten(size_t n) { + appBytesWritten_ = n; + } +}; + +class AsyncSSLSocketWriteTest : public testing::Test { + public: + AsyncSSLSocketWriteTest() : + sslContext_(new SSLContext()), + sock_(MockAsyncSSLSocket::newSocket(sslContext_, &eventBase_)) { + for (int i = 0; i < 500; i++) { + memcpy(source_ + i * 26, "abcdefghijklmnopqrstuvwxyz", 26); + } + } + + // Make an iovec containing chunks of the reference text with requested sizes + // for each chunk + iovec *makeVec(std::vector sizes) { + iovec *vec = new iovec[sizes.size()]; + int i = 0; + int pos = 0; + for (auto size: sizes) { + vec[i].iov_base = (void *)(source_ + pos); + vec[i++].iov_len = size; + pos += size; + } + return vec; + } + + // Verify that the given buf/pos matches the reference text + void verifyVec(const void *buf, int n, int pos) { + ASSERT_EQ(memcmp(source_ + pos, buf, n), 0); + } + + // Update a vec on partial write + void consumeVec(iovec *vec, uint32_t countWritten, uint32_t partialWritten) { + vec[countWritten].iov_base = + ((char *)vec[countWritten].iov_base) + partialWritten; + vec[countWritten].iov_len -= partialWritten; + } + + EventBase eventBase_; + std::shared_ptr sslContext_; + std::shared_ptr sock_; + char source_[26 * 500]; +}; + + +// The entire vec fits in one packet +TEST_F(AsyncSSLSocketWriteTest, write_coalescing1) { + int n = 3; + iovec *vec = makeVec({3, 3, 3}); + EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 9)) + .WillOnce(Invoke([this] (SSL *, const void *buf, int n) { + verifyVec(buf, n, 0); + return 9; })); + uint32_t countWritten = 0; + uint32_t partialWritten = 0; + sock_->testPerformWrite(vec, n, WriteFlags::NONE, &countWritten, + &partialWritten); + EXPECT_EQ(countWritten, n); + EXPECT_EQ(partialWritten, 0); +} + +// First packet is full, second two go in one packet +TEST_F(AsyncSSLSocketWriteTest, write_coalescing2) { + int n = 3; + iovec *vec = makeVec({1500, 3, 3}); + int pos = 0; + EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500)) + .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) { + verifyVec(buf, n, pos); + pos += n; + return n; })); + EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 6)) + .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) { + verifyVec(buf, n, pos); + pos += n; + return n; })); + uint32_t countWritten = 0; + uint32_t partialWritten = 0; + sock_->testPerformWrite(vec, n, WriteFlags::NONE, &countWritten, + &partialWritten); + EXPECT_EQ(countWritten, n); + EXPECT_EQ(partialWritten, 0); +} + +// Two exactly full packets (coalesce ends midway through second chunk) +TEST_F(AsyncSSLSocketWriteTest, write_coalescing3) { + int n = 3; + iovec *vec = makeVec({1000, 1000, 1000}); + int pos = 0; + EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500)) + .Times(2) + .WillRepeatedly(Invoke([this, &pos] (SSL *, const void *buf, int n) { + verifyVec(buf, n, pos); + pos += n; + return n; })); + uint32_t countWritten = 0; + uint32_t partialWritten = 0; + sock_->testPerformWrite(vec, n, WriteFlags::NONE, &countWritten, + &partialWritten); + EXPECT_EQ(countWritten, n); + EXPECT_EQ(partialWritten, 0); +} + +// Partial write success midway through a coalesced vec +TEST_F(AsyncSSLSocketWriteTest, write_coalescing4) { + int n = 5; + iovec *vec = makeVec({300, 300, 300, 300, 300}); + int pos = 0; + EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500)) + .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) { + verifyVec(buf, n, pos); + pos += 1000; + return 1000; /* 500 bytes "pending" */ })); + uint32_t countWritten = 0; + uint32_t partialWritten = 0; + sock_->testPerformWrite(vec, n, WriteFlags::NONE, &countWritten, + &partialWritten); + EXPECT_EQ(countWritten, 3); + EXPECT_EQ(partialWritten, 100); + consumeVec(vec, countWritten, partialWritten); + EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 500)) + .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) { + verifyVec(buf, n, pos); + pos += n; + return 500; })); + sock_->testPerformWrite(vec + countWritten, n - countWritten, + WriteFlags::NONE, + &countWritten, &partialWritten); + EXPECT_EQ(countWritten, 2); + EXPECT_EQ(partialWritten, 0); +} + +// coalesce ends exactly on a buffer boundary +TEST_F(AsyncSSLSocketWriteTest, write_coalescing5) { + int n = 3; + iovec *vec = makeVec({1000, 500, 500}); + int pos = 0; + EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500)) + .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) { + verifyVec(buf, n, pos); + pos += n; + return n; })); + EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 500)) + .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) { + verifyVec(buf, n, pos); + pos += n; + return n; })); + uint32_t countWritten = 0; + uint32_t partialWritten = 0; + sock_->testPerformWrite(vec, n, WriteFlags::NONE, &countWritten, + &partialWritten); + EXPECT_EQ(countWritten, 3); + EXPECT_EQ(partialWritten, 0); +} + +// partial write midway through first chunk +TEST_F(AsyncSSLSocketWriteTest, write_coalescing6) { + int n = 2; + iovec *vec = makeVec({1000, 500}); + int pos = 0; + EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500)) + .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) { + verifyVec(buf, n, pos); + pos += 700; + return 700; })); + uint32_t countWritten = 0; + uint32_t partialWritten = 0; + sock_->testPerformWrite(vec, n, WriteFlags::NONE, &countWritten, + &partialWritten); + EXPECT_EQ(countWritten, 0); + EXPECT_EQ(partialWritten, 700); + consumeVec(vec, countWritten, partialWritten); + EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 800)) + .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) { + verifyVec(buf, n, pos); + pos += n; + return n; })); + sock_->testPerformWrite(vec + countWritten, n - countWritten, + WriteFlags::NONE, + &countWritten, &partialWritten); + EXPECT_EQ(countWritten, 2); + EXPECT_EQ(partialWritten, 0); +} + +// Repeat coalescing2 with WriteFlags::EOR +TEST_F(AsyncSSLSocketWriteTest, write_with_eor1) { + int n = 3; + iovec *vec = makeVec({1500, 3, 3}); + int pos = 0; + const size_t initAppBytesWritten = 500; + const size_t appEor = initAppBytesWritten + 1506; + + sock_->setAppBytesWritten(initAppBytesWritten); + EXPECT_FALSE(sock_->isEorTrackingEnabled()); + sock_->setEorTracking(true); + EXPECT_TRUE(sock_->isEorTrackingEnabled()); + + EXPECT_CALL(*(sock_.get()), getRawBytesWritten()) + // rawBytesWritten after writting initAppBytesWritten + 1500 + // + some random SSL overhead + .WillOnce(Return(3600)) + // rawBytesWritten after writting last 6 bytes + // + some random SSL overhead + .WillOnce(Return(3728)); + EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500)) + .WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int n) { + // the first 1500 does not have the EOR byte + sock_->checkEor(0, 0); + verifyVec(buf, n, pos); + pos += n; + return n; })); + EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 6)) + .WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int n) { + sock_->checkEor(appEor, 3600 + n); + verifyVec(buf, n, pos); + pos += n; + return n; })); + + uint32_t countWritten = 0; + uint32_t partialWritten = 0; + sock_->testPerformWrite(vec, n , WriteFlags::EOR, + &countWritten, &partialWritten); + EXPECT_EQ(countWritten, n); + EXPECT_EQ(partialWritten, 0); + sock_->checkEor(0, 0); +} + +// coalescing with left over at the last chunk +// WriteFlags::EOR turned on +TEST_F(AsyncSSLSocketWriteTest, write_with_eor2) { + int n = 3; + iovec *vec = makeVec({600, 600, 600}); + int pos = 0; + const size_t initAppBytesWritten = 500; + const size_t appEor = initAppBytesWritten + 1800; + + sock_->setAppBytesWritten(initAppBytesWritten); + sock_->setEorTracking(true); + + EXPECT_CALL(*(sock_.get()), getRawBytesWritten()) + // rawBytesWritten after writting initAppBytesWritten + 1500 bytes + // + some random SSL overhead + .WillOnce(Return(3600)) + // rawBytesWritten after writting last 300 bytes + // + some random SSL overhead + .WillOnce(Return(4100)); + EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500)) + .WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int n) { + // the first 1500 does not have the EOR byte + sock_->checkEor(0, 0); + verifyVec(buf, n, pos); + pos += n; + return n; })); + EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 300)) + .WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int n) { + sock_->checkEor(appEor, 3600 + n); + verifyVec(buf, n, pos); + pos += n; + return n; })); + + uint32_t countWritten = 0; + uint32_t partialWritten = 0; + sock_->testPerformWrite(vec, n, WriteFlags::EOR, + &countWritten, &partialWritten); + EXPECT_EQ(countWritten, n); + EXPECT_EQ(partialWritten, 0); + sock_->checkEor(0, 0); +} + +// WriteFlags::EOR set +// One buf in iovec +// Partial write at 1000-th byte +TEST_F(AsyncSSLSocketWriteTest, write_with_eor3) { + int n = 1; + iovec *vec = makeVec({1600}); + int pos = 0; + const size_t initAppBytesWritten = 500; + const size_t appEor = initAppBytesWritten + 1600; + + sock_->setAppBytesWritten(initAppBytesWritten); + sock_->setEorTracking(true); + + EXPECT_CALL(*(sock_.get()), getRawBytesWritten()) + // rawBytesWritten after the initAppBytesWritten + // + some random SSL overhead + .WillOnce(Return(2000)) + // rawBytesWritten after the initAppBytesWritten + 1000 (with 100 overhead) + // + some random SSL overhead + .WillOnce(Return(3100)); + EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1600)) + .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) { + sock_->checkEor(appEor, 2000 + n); + verifyVec(buf, n, pos); + pos += 1000; + return 1000; })); + + uint32_t countWritten = 0; + uint32_t partialWritten = 0; + sock_->testPerformWrite(vec, n, WriteFlags::EOR, + &countWritten, &partialWritten); + EXPECT_EQ(countWritten, 0); + EXPECT_EQ(partialWritten, 1000); + sock_->checkEor(appEor, 2000 + 1600); + consumeVec(vec, countWritten, partialWritten); + + EXPECT_CALL(*(sock_.get()), getRawBytesWritten()) + .WillOnce(Return(3100)) + .WillOnce(Return(3800)); + EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 600)) + .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) { + sock_->checkEor(appEor, 3100 + n); + verifyVec(buf, n, pos); + pos += n; + return n; })); + sock_->testPerformWrite(vec + countWritten, n - countWritten, + WriteFlags::EOR, + &countWritten, &partialWritten); + EXPECT_EQ(countWritten, n); + EXPECT_EQ(partialWritten, 0); + sock_->checkEor(0, 0); +} + +} diff --git a/folly/io/async/test/AsyncSocketTest2.cpp b/folly/io/async/test/AsyncSocketTest2.cpp new file mode 100644 index 00000000..147bec94 --- /dev/null +++ b/folly/io/async/test/AsyncSocketTest2.cpp @@ -0,0 +1,2103 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace boost; + +using std::string; +using std::vector; +using std::min; +using std::cerr; +using std::endl; +using std::unique_ptr; +using std::chrono::milliseconds; +using boost::scoped_array; + +using namespace folly; + +enum StateEnum { + STATE_WAITING, + STATE_SUCCEEDED, + STATE_FAILED +}; + +typedef std::function VoidCallback; + + +class ConnCallback : public AsyncSocket::ConnectCallback { + public: + ConnCallback() + : state(STATE_WAITING) + , exception(AsyncSocketException::UNKNOWN, "none") {} + + void connectSuccess() noexcept override { + state = STATE_SUCCEEDED; + if (successCallback) { + successCallback(); + } + } + + void connectErr(const AsyncSocketException& ex) noexcept override { + state = STATE_FAILED; + exception = ex; + if (errorCallback) { + errorCallback(); + } + } + + StateEnum state; + AsyncSocketException exception; + VoidCallback successCallback; + VoidCallback errorCallback; +}; + +class WriteCallback : public AsyncTransportWrapper::WriteCallback { + public: + WriteCallback() + : state(STATE_WAITING) + , bytesWritten(0) + , exception(AsyncSocketException::UNKNOWN, "none") {} + + void writeSuccess() noexcept override { + state = STATE_SUCCEEDED; + if (successCallback) { + successCallback(); + } + } + + void writeErr(size_t bytesWritten, + const AsyncSocketException& ex) noexcept override { + state = STATE_FAILED; + this->bytesWritten = bytesWritten; + exception = ex; + if (errorCallback) { + errorCallback(); + } + } + + StateEnum state; + size_t bytesWritten; + AsyncSocketException exception; + VoidCallback successCallback; + VoidCallback errorCallback; +}; + +class ReadCallback : public AsyncTransportWrapper::ReadCallback { + public: + ReadCallback() + : state(STATE_WAITING) + , exception(AsyncSocketException::UNKNOWN, "none") + , buffers() {} + + ~ReadCallback() { + for (vector::iterator it = buffers.begin(); + it != buffers.end(); + ++it) { + it->free(); + } + currentBuffer.free(); + } + + void getReadBuffer(void** bufReturn, size_t* lenReturn) override { + if (!currentBuffer.buffer) { + currentBuffer.allocate(4096); + } + *bufReturn = currentBuffer.buffer; + *lenReturn = currentBuffer.length; + } + + void readDataAvailable(size_t len) noexcept override { + currentBuffer.length = len; + buffers.push_back(currentBuffer); + currentBuffer.reset(); + if (dataAvailableCallback) { + dataAvailableCallback(); + } + } + + void readEOF() noexcept override { + state = STATE_SUCCEEDED; + } + + void readErr(const AsyncSocketException& ex) noexcept override { + state = STATE_FAILED; + exception = ex; + } + + void verifyData(const char* expected, size_t expectedLen) const { + size_t offset = 0; + for (size_t idx = 0; idx < buffers.size(); ++idx) { + const auto& buf = buffers[idx]; + size_t cmpLen = std::min(buf.length, expectedLen - offset); + CHECK_EQ(memcmp(buf.buffer, expected + offset, cmpLen), 0); + CHECK_EQ(cmpLen, buf.length); + offset += cmpLen; + } + CHECK_EQ(offset, expectedLen); + } + + class Buffer { + public: + Buffer() : buffer(nullptr), length(0) {} + Buffer(char* buf, size_t len) : buffer(buf), length(len) {} + + void reset() { + buffer = nullptr; + length = 0; + } + void allocate(size_t length) { + assert(buffer == nullptr); + this->buffer = static_cast(malloc(length)); + this->length = length; + } + void free() { + ::free(buffer); + reset(); + } + + char* buffer; + size_t length; + }; + + StateEnum state; + AsyncSocketException exception; + vector buffers; + Buffer currentBuffer; + VoidCallback dataAvailableCallback; +}; + +class ReadVerifier { +}; + +class TestServer { + public: + // Create a TestServer. + // This immediately starts listening on an ephemeral port. + TestServer() + : fd_(-1) { + fd_ = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP); + if (fd_ < 0) { + throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR, + "failed to create test server socket", errno); + } + if (fcntl(fd_, F_SETFL, O_NONBLOCK) != 0) { + throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR, + "failed to put test server socket in " + "non-blocking mode", errno); + } + if (listen(fd_, 10) != 0) { + throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR, + "failed to listen on test server socket", + errno); + } + + address_.setFromLocalAddress(fd_); + // The local address will contain 0.0.0.0. + // Change it to 127.0.0.1, so it can be used to connect to the server + address_.setFromIpPort("127.0.0.1", address_.getPort()); + } + + // Get the address for connecting to the server + const folly::SocketAddress& getAddress() const { + return address_; + } + + int acceptFD(int timeout=50) { + struct pollfd pfd; + pfd.fd = fd_; + pfd.events = POLLIN; + int ret = poll(&pfd, 1, timeout); + if (ret == 0) { + throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR, + "test server accept() timed out"); + } else if (ret < 0) { + throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR, + "test server accept() poll failed", errno); + } + + int acceptedFd = ::accept(fd_, nullptr, nullptr); + if (acceptedFd < 0) { + throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR, + "test server accept() failed", errno); + } + + return acceptedFd; + } + + std::shared_ptr accept(int timeout=50) { + int fd = acceptFD(timeout); + return std::shared_ptr(new BlockingSocket(fd)); + } + + std::shared_ptr acceptAsync(EventBase* evb, int timeout=50) { + int fd = acceptFD(timeout); + return AsyncSocket::newSocket(evb, fd); + } + + /** + * Accept a connection, read data from it, and verify that it matches the + * data in the specified buffer. + */ + void verifyConnection(const char* buf, size_t len) { + // accept a connection + std::shared_ptr acceptedSocket = accept(); + // read the data and compare it to the specified buffer + scoped_array readbuf(new uint8_t[len]); + acceptedSocket->readAll(readbuf.get(), len); + CHECK_EQ(memcmp(buf, readbuf.get(), len), 0); + // make sure we get EOF next + uint32_t bytesRead = acceptedSocket->read(readbuf.get(), len); + CHECK_EQ(bytesRead, 0); + } + + private: + int fd_; + folly::SocketAddress address_; +}; + +class DelayedWrite: public AsyncTimeout { + public: + DelayedWrite(const std::shared_ptr& socket, + unique_ptr&& bufs, AsyncTransportWrapper::WriteCallback* wcb, + bool cork, bool lastWrite = false): + AsyncTimeout(socket->getEventBase()), + socket_(socket), + bufs_(std::move(bufs)), + wcb_(wcb), + cork_(cork), + lastWrite_(lastWrite) {} + + private: + void timeoutExpired() noexcept override { + WriteFlags flags = cork_ ? WriteFlags::CORK : WriteFlags::NONE; + socket_->writeChain(wcb_, std::move(bufs_), flags); + if (lastWrite_) { + socket_->shutdownWrite(); + } + } + + std::shared_ptr socket_; + unique_ptr bufs_; + AsyncTransportWrapper::WriteCallback* wcb_; + bool cork_; + bool lastWrite_; +}; + +/////////////////////////////////////////////////////////////////////////// +// connect() tests +/////////////////////////////////////////////////////////////////////////// + +/** + * Test connecting to a server + */ +TEST(AsyncSocketTest, Connect) { + // Start listening on a local port + TestServer server; + + // Connect using a AsyncSocket + EventBase evb; + std::shared_ptr socket = AsyncSocket::newSocket(&evb); + ConnCallback cb; + socket->connect(&cb, server.getAddress(), 30); + + evb.loop(); + + CHECK_EQ(cb.state, STATE_SUCCEEDED); +} + +/** + * Test connecting to a server that isn't listening + */ +TEST(AsyncSocketTest, ConnectRefused) { + EventBase evb; + + std::shared_ptr socket = AsyncSocket::newSocket(&evb); + + // Hopefully nothing is actually listening on this address + folly::SocketAddress addr("127.0.0.1", 65535); + ConnCallback cb; + socket->connect(&cb, addr, 30); + + evb.loop(); + + CHECK_EQ(cb.state, STATE_FAILED); + CHECK_EQ(cb.exception.getType(), AsyncSocketException::NOT_OPEN); +} + +/** + * Test connection timeout + */ +TEST(AsyncSocketTest, ConnectTimeout) { + EventBase evb; + + std::shared_ptr socket = AsyncSocket::newSocket(&evb); + + // Try connecting to server that won't respond. + // + // This depends somewhat on the network where this test is run. + // Hopefully this IP will be routable but unresponsive. + // (Alternatively, we could try listening on a local raw socket, but that + // normally requires root privileges.) + folly::SocketAddress addr("8.8.8.8", 65535); + ConnCallback cb; + socket->connect(&cb, addr, 1); // also set a ridiculously small timeout + + evb.loop(); + + CHECK_EQ(cb.state, STATE_FAILED); + CHECK_EQ(cb.exception.getType(), AsyncSocketException::TIMED_OUT); + + // Verify that we can still get the peer address after a timeout. + // Use case is if the client was created from a client pool, and we want + // to log which peer failed. + folly::SocketAddress peer; + socket->getPeerAddress(&peer); + CHECK_EQ(peer, addr); +} + +/** + * Test writing immediately after connecting, without waiting for connect + * to finish. + */ +TEST(AsyncSocketTest, ConnectAndWrite) { + TestServer server; + + // connect() + EventBase evb; + std::shared_ptr socket = AsyncSocket::newSocket(&evb); + ConnCallback ccb; + socket->connect(&ccb, server.getAddress(), 30); + + // write() + char buf[128]; + memset(buf, 'a', sizeof(buf)); + WriteCallback wcb; + socket->write(&wcb, buf, sizeof(buf)); + + // Loop. We don't bother accepting on the server socket yet. + // The kernel should be able to buffer the write request so it can succeed. + evb.loop(); + + CHECK_EQ(ccb.state, STATE_SUCCEEDED); + CHECK_EQ(wcb.state, STATE_SUCCEEDED); + + // Make sure the server got a connection and received the data + socket->close(); + server.verifyConnection(buf, sizeof(buf)); +} + +/** + * Test connecting using a nullptr connect callback. + */ +TEST(AsyncSocketTest, ConnectNullCallback) { + TestServer server; + + // connect() + EventBase evb; + std::shared_ptr socket = AsyncSocket::newSocket(&evb); + socket->connect(nullptr, server.getAddress(), 30); + + // write some data, just so we have some way of verifing + // that the socket works correctly after connecting + char buf[128]; + memset(buf, 'a', sizeof(buf)); + WriteCallback wcb; + socket->write(&wcb, buf, sizeof(buf)); + + evb.loop(); + + CHECK_EQ(wcb.state, STATE_SUCCEEDED); + + // Make sure the server got a connection and received the data + socket->close(); + server.verifyConnection(buf, sizeof(buf)); +} + +/** + * Test calling both write() and close() immediately after connecting, without + * waiting for connect to finish. + * + * This exercises the STATE_CONNECTING_CLOSING code. + */ +TEST(AsyncSocketTest, ConnectWriteAndClose) { + TestServer server; + + // connect() + EventBase evb; + std::shared_ptr socket = AsyncSocket::newSocket(&evb); + ConnCallback ccb; + socket->connect(&ccb, server.getAddress(), 30); + + // write() + char buf[128]; + memset(buf, 'a', sizeof(buf)); + WriteCallback wcb; + socket->write(&wcb, buf, sizeof(buf)); + + // close() + socket->close(); + + // Loop. We don't bother accepting on the server socket yet. + // The kernel should be able to buffer the write request so it can succeed. + evb.loop(); + + CHECK_EQ(ccb.state, STATE_SUCCEEDED); + CHECK_EQ(wcb.state, STATE_SUCCEEDED); + + // Make sure the server got a connection and received the data + server.verifyConnection(buf, sizeof(buf)); +} + +/** + * Test calling close() immediately after connect() + */ +TEST(AsyncSocketTest, ConnectAndClose) { + TestServer server; + + // Connect using a AsyncSocket + EventBase evb; + std::shared_ptr socket = AsyncSocket::newSocket(&evb); + ConnCallback ccb; + socket->connect(&ccb, server.getAddress(), 30); + + // Hopefully the connect didn't succeed immediately. + // If it did, we can't exercise the close-while-connecting code path. + if (ccb.state == STATE_SUCCEEDED) { + LOG(INFO) << "connect() succeeded immediately; aborting test " + "of close-during-connect behavior"; + return; + } + + socket->close(); + + // Loop, although there shouldn't be anything to do. + evb.loop(); + + // Make sure the connection was aborted + CHECK_EQ(ccb.state, STATE_FAILED); +} + +/** + * Test calling closeNow() immediately after connect() + * + * This should be identical to the normal close behavior. + */ +TEST(AsyncSocketTest, ConnectAndCloseNow) { + TestServer server; + + // Connect using a AsyncSocket + EventBase evb; + std::shared_ptr socket = AsyncSocket::newSocket(&evb); + ConnCallback ccb; + socket->connect(&ccb, server.getAddress(), 30); + + // Hopefully the connect didn't succeed immediately. + // If it did, we can't exercise the close-while-connecting code path. + if (ccb.state == STATE_SUCCEEDED) { + LOG(INFO) << "connect() succeeded immediately; aborting test " + "of closeNow()-during-connect behavior"; + return; + } + + socket->closeNow(); + + // Loop, although there shouldn't be anything to do. + evb.loop(); + + // Make sure the connection was aborted + CHECK_EQ(ccb.state, STATE_FAILED); +} + +/** + * Test calling both write() and closeNow() immediately after connecting, + * without waiting for connect to finish. + * + * This should abort the pending write. + */ +TEST(AsyncSocketTest, ConnectWriteAndCloseNow) { + TestServer server; + + // connect() + EventBase evb; + std::shared_ptr socket = AsyncSocket::newSocket(&evb); + ConnCallback ccb; + socket->connect(&ccb, server.getAddress(), 30); + + // Hopefully the connect didn't succeed immediately. + // If it did, we can't exercise the close-while-connecting code path. + if (ccb.state == STATE_SUCCEEDED) { + LOG(INFO) << "connect() succeeded immediately; aborting test " + "of write-during-connect behavior"; + return; + } + + // write() + char buf[128]; + memset(buf, 'a', sizeof(buf)); + WriteCallback wcb; + socket->write(&wcb, buf, sizeof(buf)); + + // close() + socket->closeNow(); + + // Loop, although there shouldn't be anything to do. + evb.loop(); + + CHECK_EQ(ccb.state, STATE_FAILED); + CHECK_EQ(wcb.state, STATE_FAILED); +} + +/** + * Test installing a read callback immediately, before connect() finishes. + */ +TEST(AsyncSocketTest, ConnectAndRead) { + TestServer server; + + // connect() + EventBase evb; + std::shared_ptr socket = AsyncSocket::newSocket(&evb); + ConnCallback ccb; + socket->connect(&ccb, server.getAddress(), 30); + + ReadCallback rcb; + socket->setReadCB(&rcb); + + // Even though we haven't looped yet, we should be able to accept + // the connection and send data to it. + std::shared_ptr acceptedSocket = server.accept(); + uint8_t buf[128]; + memset(buf, 'a', sizeof(buf)); + acceptedSocket->write(buf, sizeof(buf)); + acceptedSocket->flush(); + acceptedSocket->close(); + + // Loop, although there shouldn't be anything to do. + evb.loop(); + + CHECK_EQ(ccb.state, STATE_SUCCEEDED); + CHECK_EQ(rcb.state, STATE_SUCCEEDED); + CHECK_EQ(rcb.buffers.size(), 1); + CHECK_EQ(rcb.buffers[0].length, sizeof(buf)); + CHECK_EQ(memcmp(rcb.buffers[0].buffer, buf, sizeof(buf)), 0); +} + +/** + * Test installing a read callback and then closing immediately before the + * connect attempt finishes. + */ +TEST(AsyncSocketTest, ConnectReadAndClose) { + TestServer server; + + // connect() + EventBase evb; + std::shared_ptr socket = AsyncSocket::newSocket(&evb); + ConnCallback ccb; + socket->connect(&ccb, server.getAddress(), 30); + + // Hopefully the connect didn't succeed immediately. + // If it did, we can't exercise the close-while-connecting code path. + if (ccb.state == STATE_SUCCEEDED) { + LOG(INFO) << "connect() succeeded immediately; aborting test " + "of read-during-connect behavior"; + return; + } + + ReadCallback rcb; + socket->setReadCB(&rcb); + + // close() + socket->close(); + + // Loop, although there shouldn't be anything to do. + evb.loop(); + + CHECK_EQ(ccb.state, STATE_FAILED); // we aborted the close attempt + CHECK_EQ(rcb.buffers.size(), 0); + CHECK_EQ(rcb.state, STATE_SUCCEEDED); // this indicates EOF +} + +/** + * Test both writing and installing a read callback immediately, + * before connect() finishes. + */ +TEST(AsyncSocketTest, ConnectWriteAndRead) { + TestServer server; + + // connect() + EventBase evb; + std::shared_ptr socket = AsyncSocket::newSocket(&evb); + ConnCallback ccb; + socket->connect(&ccb, server.getAddress(), 30); + + // write() + char buf1[128]; + memset(buf1, 'a', sizeof(buf1)); + WriteCallback wcb; + socket->write(&wcb, buf1, sizeof(buf1)); + + // set a read callback + ReadCallback rcb; + socket->setReadCB(&rcb); + + // Even though we haven't looped yet, we should be able to accept + // the connection and send data to it. + std::shared_ptr acceptedSocket = server.accept(); + uint8_t buf2[128]; + memset(buf2, 'b', sizeof(buf2)); + acceptedSocket->write(buf2, sizeof(buf2)); + acceptedSocket->flush(); + + // shut down the write half of acceptedSocket, so that the AsyncSocket + // will stop reading and we can break out of the event loop. + shutdown(acceptedSocket->getSocketFD(), SHUT_WR); + + // Loop + evb.loop(); + + // Make sure the connect succeeded + CHECK_EQ(ccb.state, STATE_SUCCEEDED); + + // Make sure the AsyncSocket read the data written by the accepted socket + CHECK_EQ(rcb.state, STATE_SUCCEEDED); + CHECK_EQ(rcb.buffers.size(), 1); + CHECK_EQ(rcb.buffers[0].length, sizeof(buf2)); + CHECK_EQ(memcmp(rcb.buffers[0].buffer, buf2, sizeof(buf2)), 0); + + // Close the AsyncSocket so we'll see EOF on acceptedSocket + socket->close(); + + // Make sure the accepted socket saw the data written by the AsyncSocket + uint8_t readbuf[sizeof(buf1)]; + acceptedSocket->readAll(readbuf, sizeof(readbuf)); + CHECK_EQ(memcmp(buf1, readbuf, sizeof(buf1)), 0); + uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf)); + CHECK_EQ(bytesRead, 0); +} + +/** + * Test writing to the socket then shutting down writes before the connect + * attempt finishes. + */ +TEST(AsyncSocketTest, ConnectWriteAndShutdownWrite) { + TestServer server; + + // connect() + EventBase evb; + std::shared_ptr socket = AsyncSocket::newSocket(&evb); + ConnCallback ccb; + socket->connect(&ccb, server.getAddress(), 30); + + // Hopefully the connect didn't succeed immediately. + // If it did, we can't exercise the write-while-connecting code path. + if (ccb.state == STATE_SUCCEEDED) { + LOG(INFO) << "connect() succeeded immediately; skipping test"; + return; + } + + // Ask to write some data + char wbuf[128]; + memset(wbuf, 'a', sizeof(wbuf)); + WriteCallback wcb; + socket->write(&wcb, wbuf, sizeof(wbuf)); + socket->shutdownWrite(); + + // Shutdown writes + socket->shutdownWrite(); + + // Even though we haven't looped yet, we should be able to accept + // the connection. + std::shared_ptr acceptedSocket = server.accept(); + + // Since the connection is still in progress, there should be no data to + // read yet. Verify that the accepted socket is not readable. + struct pollfd fds[1]; + fds[0].fd = acceptedSocket->getSocketFD(); + fds[0].events = POLLIN; + fds[0].revents = 0; + int rc = poll(fds, 1, 0); + CHECK_EQ(rc, 0); + + // Write data to the accepted socket + uint8_t acceptedWbuf[192]; + memset(acceptedWbuf, 'b', sizeof(acceptedWbuf)); + acceptedSocket->write(acceptedWbuf, sizeof(acceptedWbuf)); + acceptedSocket->flush(); + + // Loop + evb.loop(); + + // The loop should have completed the connection, written the queued data, + // and shutdown writes on the socket. + // + // Check that the connection was completed successfully and that the write + // callback succeeded. + CHECK_EQ(ccb.state, STATE_SUCCEEDED); + CHECK_EQ(wcb.state, STATE_SUCCEEDED); + + // Check that we can read the data that was written to the socket, and that + // we see an EOF, since its socket was half-shutdown. + uint8_t readbuf[sizeof(wbuf)]; + acceptedSocket->readAll(readbuf, sizeof(readbuf)); + CHECK_EQ(memcmp(wbuf, readbuf, sizeof(wbuf)), 0); + uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf)); + CHECK_EQ(bytesRead, 0); + + // Close the accepted socket. This will cause it to see EOF + // and uninstall the read callback when we loop next. + acceptedSocket->close(); + + // Install a read callback, then loop again. + ReadCallback rcb; + socket->setReadCB(&rcb); + evb.loop(); + + // This loop should have read the data and seen the EOF + CHECK_EQ(rcb.state, STATE_SUCCEEDED); + CHECK_EQ(rcb.buffers.size(), 1); + CHECK_EQ(rcb.buffers[0].length, sizeof(acceptedWbuf)); + CHECK_EQ(memcmp(rcb.buffers[0].buffer, + acceptedWbuf, sizeof(acceptedWbuf)), 0); +} + +/** + * Test reading, writing, and shutting down writes before the connect attempt + * finishes. + */ +TEST(AsyncSocketTest, ConnectReadWriteAndShutdownWrite) { + TestServer server; + + // connect() + EventBase evb; + std::shared_ptr socket = AsyncSocket::newSocket(&evb); + ConnCallback ccb; + socket->connect(&ccb, server.getAddress(), 30); + + // Hopefully the connect didn't succeed immediately. + // If it did, we can't exercise the write-while-connecting code path. + if (ccb.state == STATE_SUCCEEDED) { + LOG(INFO) << "connect() succeeded immediately; skipping test"; + return; + } + + // Install a read callback + ReadCallback rcb; + socket->setReadCB(&rcb); + + // Ask to write some data + char wbuf[128]; + memset(wbuf, 'a', sizeof(wbuf)); + WriteCallback wcb; + socket->write(&wcb, wbuf, sizeof(wbuf)); + + // Shutdown writes + socket->shutdownWrite(); + + // Even though we haven't looped yet, we should be able to accept + // the connection. + std::shared_ptr acceptedSocket = server.accept(); + + // Since the connection is still in progress, there should be no data to + // read yet. Verify that the accepted socket is not readable. + struct pollfd fds[1]; + fds[0].fd = acceptedSocket->getSocketFD(); + fds[0].events = POLLIN; + fds[0].revents = 0; + int rc = poll(fds, 1, 0); + CHECK_EQ(rc, 0); + + // Write data to the accepted socket + uint8_t acceptedWbuf[192]; + memset(acceptedWbuf, 'b', sizeof(acceptedWbuf)); + acceptedSocket->write(acceptedWbuf, sizeof(acceptedWbuf)); + acceptedSocket->flush(); + // Shutdown writes to the accepted socket. This will cause it to see EOF + // and uninstall the read callback. + ::shutdown(acceptedSocket->getSocketFD(), SHUT_WR); + + // Loop + evb.loop(); + + // The loop should have completed the connection, written the queued data, + // shutdown writes on the socket, read the data we wrote to it, and see the + // EOF. + // + // Check that the connection was completed successfully and that the read + // and write callbacks were invoked as expected. + CHECK_EQ(ccb.state, STATE_SUCCEEDED); + CHECK_EQ(rcb.state, STATE_SUCCEEDED); + CHECK_EQ(rcb.buffers.size(), 1); + CHECK_EQ(rcb.buffers[0].length, sizeof(acceptedWbuf)); + CHECK_EQ(memcmp(rcb.buffers[0].buffer, + acceptedWbuf, sizeof(acceptedWbuf)), 0); + CHECK_EQ(wcb.state, STATE_SUCCEEDED); + + // Check that we can read the data that was written to the socket, and that + // we see an EOF, since its socket was half-shutdown. + uint8_t readbuf[sizeof(wbuf)]; + acceptedSocket->readAll(readbuf, sizeof(readbuf)); + CHECK_EQ(memcmp(wbuf, readbuf, sizeof(wbuf)), 0); + uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf)); + CHECK_EQ(bytesRead, 0); + + // Fully close both sockets + acceptedSocket->close(); + socket->close(); +} + +/** + * Test reading, writing, and calling shutdownWriteNow() before the + * connect attempt finishes. + */ +TEST(AsyncSocketTest, ConnectReadWriteAndShutdownWriteNow) { + TestServer server; + + // connect() + EventBase evb; + std::shared_ptr socket = AsyncSocket::newSocket(&evb); + ConnCallback ccb; + socket->connect(&ccb, server.getAddress(), 30); + + // Hopefully the connect didn't succeed immediately. + // If it did, we can't exercise the write-while-connecting code path. + if (ccb.state == STATE_SUCCEEDED) { + LOG(INFO) << "connect() succeeded immediately; skipping test"; + return; + } + + // Install a read callback + ReadCallback rcb; + socket->setReadCB(&rcb); + + // Ask to write some data + char wbuf[128]; + memset(wbuf, 'a', sizeof(wbuf)); + WriteCallback wcb; + socket->write(&wcb, wbuf, sizeof(wbuf)); + + // Shutdown writes immediately. + // This should immediately discard the data that we just tried to write. + socket->shutdownWriteNow(); + + // Verify that writeError() was invoked on the write callback. + CHECK_EQ(wcb.state, STATE_FAILED); + CHECK_EQ(wcb.bytesWritten, 0); + + // Even though we haven't looped yet, we should be able to accept + // the connection. + std::shared_ptr acceptedSocket = server.accept(); + + // Since the connection is still in progress, there should be no data to + // read yet. Verify that the accepted socket is not readable. + struct pollfd fds[1]; + fds[0].fd = acceptedSocket->getSocketFD(); + fds[0].events = POLLIN; + fds[0].revents = 0; + int rc = poll(fds, 1, 0); + CHECK_EQ(rc, 0); + + // Write data to the accepted socket + uint8_t acceptedWbuf[192]; + memset(acceptedWbuf, 'b', sizeof(acceptedWbuf)); + acceptedSocket->write(acceptedWbuf, sizeof(acceptedWbuf)); + acceptedSocket->flush(); + // Shutdown writes to the accepted socket. This will cause it to see EOF + // and uninstall the read callback. + ::shutdown(acceptedSocket->getSocketFD(), SHUT_WR); + + // Loop + evb.loop(); + + // The loop should have completed the connection, written the queued data, + // shutdown writes on the socket, read the data we wrote to it, and see the + // EOF. + // + // Check that the connection was completed successfully and that the read + // callback was invoked as expected. + CHECK_EQ(ccb.state, STATE_SUCCEEDED); + CHECK_EQ(rcb.state, STATE_SUCCEEDED); + CHECK_EQ(rcb.buffers.size(), 1); + CHECK_EQ(rcb.buffers[0].length, sizeof(acceptedWbuf)); + CHECK_EQ(memcmp(rcb.buffers[0].buffer, + acceptedWbuf, sizeof(acceptedWbuf)), 0); + + // Since we used shutdownWriteNow(), it should have discarded all pending + // write data. Verify we see an immediate EOF when reading from the accepted + // socket. + uint8_t readbuf[sizeof(wbuf)]; + uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf)); + CHECK_EQ(bytesRead, 0); + + // Fully close both sockets + acceptedSocket->close(); + socket->close(); +} + +// Helper function for use in testConnectOptWrite() +// Temporarily disable the read callback +void tmpDisableReads(AsyncSocket* socket, ReadCallback* rcb) { + // Uninstall the read callback + socket->setReadCB(nullptr); + // Schedule the read callback to be reinstalled after 1ms + socket->getEventBase()->runInLoop( + std::bind(&AsyncSocket::setReadCB, socket, rcb)); +} + +/** + * Test connect+write, then have the connect callback perform another write. + * + * This tests interaction of the optimistic writing after connect with + * additional write attempts that occur in the connect callback. + */ +void testConnectOptWrite(size_t size1, size_t size2, bool close = false) { + TestServer server; + EventBase evb; + std::shared_ptr socket = AsyncSocket::newSocket(&evb); + + // connect() + ConnCallback ccb; + socket->connect(&ccb, server.getAddress(), 30); + + // Hopefully the connect didn't succeed immediately. + // If it did, we can't exercise the optimistic write code path. + if (ccb.state == STATE_SUCCEEDED) { + LOG(INFO) << "connect() succeeded immediately; aborting test " + "of optimistic write behavior"; + return; + } + + // Tell the connect callback to perform a write when the connect succeeds + WriteCallback wcb2; + scoped_array buf2(new char[size2]); + memset(buf2.get(), 'b', size2); + if (size2 > 0) { + ccb.successCallback = [&] { socket->write(&wcb2, buf2.get(), size2); }; + // Tell the second write callback to close the connection when it is done + wcb2.successCallback = [&] { socket->closeNow(); }; + } + + // Schedule one write() immediately, before the connect finishes + scoped_array buf1(new char[size1]); + memset(buf1.get(), 'a', size1); + WriteCallback wcb1; + if (size1 > 0) { + socket->write(&wcb1, buf1.get(), size1); + } + + if (close) { + // immediately perform a close, before connect() completes + socket->close(); + } + + // Start reading from the other endpoint after 10ms. + // If we're using large buffers, we have to read so that the writes don't + // block forever. + std::shared_ptr acceptedSocket = server.acceptAsync(&evb); + ReadCallback rcb; + rcb.dataAvailableCallback = std::bind(tmpDisableReads, + acceptedSocket.get(), &rcb); + socket->getEventBase()->tryRunAfterDelay( + std::bind(&AsyncSocket::setReadCB, acceptedSocket.get(), &rcb), + 10); + + // Loop. We don't bother accepting on the server socket yet. + // The kernel should be able to buffer the write request so it can succeed. + evb.loop(); + + CHECK_EQ(ccb.state, STATE_SUCCEEDED); + if (size1 > 0) { + CHECK_EQ(wcb1.state, STATE_SUCCEEDED); + } + if (size2 > 0) { + CHECK_EQ(wcb2.state, STATE_SUCCEEDED); + } + + socket->close(); + + // Make sure the read callback received all of the data + size_t bytesRead = 0; + for (vector::const_iterator it = rcb.buffers.begin(); + it != rcb.buffers.end(); + ++it) { + size_t start = bytesRead; + bytesRead += it->length; + size_t end = bytesRead; + if (start < size1) { + size_t cmpLen = min(size1, end) - start; + CHECK_EQ(memcmp(it->buffer, buf1.get() + start, cmpLen), 0); + } + if (end > size1 && end <= size1 + size2) { + size_t itOffset; + size_t buf2Offset; + size_t cmpLen; + if (start >= size1) { + itOffset = 0; + buf2Offset = start - size1; + cmpLen = end - start; + } else { + itOffset = size1 - start; + buf2Offset = 0; + cmpLen = end - size1; + } + CHECK_EQ(memcmp(it->buffer + itOffset, buf2.get() + buf2Offset, + cmpLen), + 0); + } + } + CHECK_EQ(bytesRead, size1 + size2); +} + +TEST(AsyncSocketTest, ConnectCallbackWrite) { + // Test using small writes that should both succeed immediately + testConnectOptWrite(100, 200); + + // Test using a large buffer in the connect callback, that should block + const size_t largeSize = 8*1024*1024; + testConnectOptWrite(100, largeSize); + + // Test using a large initial write + testConnectOptWrite(largeSize, 100); + + // Test using two large buffers + testConnectOptWrite(largeSize, largeSize); + + // Test a small write in the connect callback, + // but no immediate write before connect completes + testConnectOptWrite(0, 64); + + // Test a large write in the connect callback, + // but no immediate write before connect completes + testConnectOptWrite(0, largeSize); + + // Test connect, a small write, then immediately call close() before connect + // completes + testConnectOptWrite(211, 0, true); + + // Test connect, a large immediate write (that will block), then immediately + // call close() before connect completes + testConnectOptWrite(largeSize, 0, true); +} + +/////////////////////////////////////////////////////////////////////////// +// write() related tests +/////////////////////////////////////////////////////////////////////////// + +/** + * Test writing using a nullptr callback + */ +TEST(AsyncSocketTest, WriteNullCallback) { + TestServer server; + + // connect() + EventBase evb; + std::shared_ptr socket = + AsyncSocket::newSocket(&evb, server.getAddress(), 30); + evb.loop(); // loop until the socket is connected + + // write() with a nullptr callback + char buf[128]; + memset(buf, 'a', sizeof(buf)); + socket->write(nullptr, buf, sizeof(buf)); + + evb.loop(); // loop until the data is sent + + // Make sure the server got a connection and received the data + socket->close(); + server.verifyConnection(buf, sizeof(buf)); +} + +/** + * Test writing with a send timeout + */ +TEST(AsyncSocketTest, WriteTimeout) { + TestServer server; + + // connect() + EventBase evb; + std::shared_ptr socket = + AsyncSocket::newSocket(&evb, server.getAddress(), 30); + evb.loop(); // loop until the socket is connected + + // write() a large chunk of data, with no-one on the other end reading + size_t writeLength = 8*1024*1024; + uint32_t timeout = 200; + socket->setSendTimeout(timeout); + scoped_array buf(new char[writeLength]); + memset(buf.get(), 'a', writeLength); + WriteCallback wcb; + socket->write(&wcb, buf.get(), writeLength); + + TimePoint start; + evb.loop(); + TimePoint end; + + // Make sure the write attempt timed out as requested + CHECK_EQ(wcb.state, STATE_FAILED); + CHECK_EQ(wcb.exception.getType(), AsyncSocketException::TIMED_OUT); + + // Check that the write timed out within a reasonable period of time. + // We don't check for exactly the specified timeout, since AsyncSocket only + // times out when it hasn't made progress for that period of time. + // + // On linux, the first write sends a few hundred kb of data, then blocks for + // writability, and then unblocks again after 40ms and is able to write + // another smaller of data before blocking permanently. Therefore it doesn't + // time out until 40ms + timeout. + // + // I haven't fully verified the cause of this, but I believe it probably + // occurs because the receiving end delays sending an ack for up to 40ms. + // (This is the default value for TCP_DELACK_MIN.) Once the sender receives + // the ack, it can send some more data. However, after that point the + // receiver's kernel buffer is full. This 40ms delay happens even with + // TCP_NODELAY and TCP_QUICKACK enabled on both endpoints. However, the + // kernel may be automatically disabling TCP_QUICKACK after receiving some + // data. + // + // For now, we simply check that the timeout occurred within 160ms of + // the requested value. + T_CHECK_TIMEOUT(start, end, milliseconds(timeout), milliseconds(160)); +} + +/** + * Test writing to a socket that the remote endpoint has closed + */ +TEST(AsyncSocketTest, WritePipeError) { + TestServer server; + + // connect() + EventBase evb; + std::shared_ptr socket = + AsyncSocket::newSocket(&evb, server.getAddress(), 30); + socket->setSendTimeout(1000); + evb.loop(); // loop until the socket is connected + + // accept and immediately close the socket + std::shared_ptr acceptedSocket = server.accept(); + acceptedSocket.reset(); + + // write() a large chunk of data + size_t writeLength = 8*1024*1024; + scoped_array buf(new char[writeLength]); + memset(buf.get(), 'a', writeLength); + WriteCallback wcb; + socket->write(&wcb, buf.get(), writeLength); + + evb.loop(); + + // Make sure the write failed. + // It would be nice if AsyncSocketException could convey the errno value, + // so that we could check for EPIPE + CHECK_EQ(wcb.state, STATE_FAILED); + CHECK_EQ(wcb.exception.getType(), + AsyncSocketException::INTERNAL_ERROR); +} + +/** + * Test writing a mix of simple buffers and IOBufs + */ +TEST(AsyncSocketTest, WriteIOBuf) { + TestServer server; + + // connect() + EventBase evb; + std::shared_ptr socket = AsyncSocket::newSocket(&evb); + ConnCallback ccb; + socket->connect(&ccb, server.getAddress(), 30); + + // Accept the connection + std::shared_ptr acceptedSocket = server.acceptAsync(&evb); + ReadCallback rcb; + acceptedSocket->setReadCB(&rcb); + + // Write a simple buffer to the socket + size_t simpleBufLength = 5; + char simpleBuf[simpleBufLength]; + memset(simpleBuf, 'a', simpleBufLength); + WriteCallback wcb; + socket->write(&wcb, simpleBuf, simpleBufLength); + + // Write a single-element IOBuf chain + size_t buf1Length = 7; + unique_ptr buf1(IOBuf::create(buf1Length)); + memset(buf1->writableData(), 'b', buf1Length); + buf1->append(buf1Length); + unique_ptr buf1Copy(buf1->clone()); + WriteCallback wcb2; + socket->writeChain(&wcb2, std::move(buf1)); + + // Write a multiple-element IOBuf chain + size_t buf2Length = 11; + unique_ptr buf2(IOBuf::create(buf2Length)); + memset(buf2->writableData(), 'c', buf2Length); + buf2->append(buf2Length); + size_t buf3Length = 13; + unique_ptr buf3(IOBuf::create(buf3Length)); + memset(buf3->writableData(), 'd', buf3Length); + buf3->append(buf3Length); + buf2->appendChain(std::move(buf3)); + unique_ptr buf2Copy(buf2->clone()); + buf2Copy->coalesce(); + WriteCallback wcb3; + socket->writeChain(&wcb3, std::move(buf2)); + socket->shutdownWrite(); + + // Let the reads and writes run to completion + evb.loop(); + + CHECK_EQ(wcb.state, STATE_SUCCEEDED); + CHECK_EQ(wcb2.state, STATE_SUCCEEDED); + CHECK_EQ(wcb3.state, STATE_SUCCEEDED); + + // Make sure the reader got the right data in the right order + CHECK_EQ(rcb.state, STATE_SUCCEEDED); + CHECK_EQ(rcb.buffers.size(), 1); + CHECK_EQ(rcb.buffers[0].length, + simpleBufLength + buf1Length + buf2Length + buf3Length); + CHECK_EQ( + memcmp(rcb.buffers[0].buffer, simpleBuf, simpleBufLength), 0); + CHECK_EQ( + memcmp(rcb.buffers[0].buffer + simpleBufLength, + buf1Copy->data(), buf1Copy->length()), 0); + CHECK_EQ( + memcmp(rcb.buffers[0].buffer + simpleBufLength + buf1Length, + buf2Copy->data(), buf2Copy->length()), 0); + + acceptedSocket->close(); + socket->close(); +} + +TEST(AsyncSocketTest, WriteIOBufCorked) { + TestServer server; + + // connect() + EventBase evb; + std::shared_ptr socket = AsyncSocket::newSocket(&evb); + ConnCallback ccb; + socket->connect(&ccb, server.getAddress(), 30); + + // Accept the connection + std::shared_ptr acceptedSocket = server.acceptAsync(&evb); + ReadCallback rcb; + acceptedSocket->setReadCB(&rcb); + + // Do three writes, 100ms apart, with the "cork" flag set + // on the second write. The reader should see the first write + // arrive by itself, followed by the second and third writes + // arriving together. + size_t buf1Length = 5; + unique_ptr buf1(IOBuf::create(buf1Length)); + memset(buf1->writableData(), 'a', buf1Length); + buf1->append(buf1Length); + size_t buf2Length = 7; + unique_ptr buf2(IOBuf::create(buf2Length)); + memset(buf2->writableData(), 'b', buf2Length); + buf2->append(buf2Length); + size_t buf3Length = 11; + unique_ptr buf3(IOBuf::create(buf3Length)); + memset(buf3->writableData(), 'c', buf3Length); + buf3->append(buf3Length); + WriteCallback wcb1; + socket->writeChain(&wcb1, std::move(buf1)); + WriteCallback wcb2; + DelayedWrite write2(socket, std::move(buf2), &wcb2, true); + write2.scheduleTimeout(100); + WriteCallback wcb3; + DelayedWrite write3(socket, std::move(buf3), &wcb3, false, true); + write3.scheduleTimeout(200); + + evb.loop(); + CHECK_EQ(ccb.state, STATE_SUCCEEDED); + CHECK_EQ(wcb1.state, STATE_SUCCEEDED); + CHECK_EQ(wcb2.state, STATE_SUCCEEDED); + if (wcb3.state != STATE_SUCCEEDED) { + throw(wcb3.exception); + } + CHECK_EQ(wcb3.state, STATE_SUCCEEDED); + + // Make sure the reader got the data with the right grouping + CHECK_EQ(rcb.state, STATE_SUCCEEDED); + CHECK_EQ(rcb.buffers.size(), 2); + CHECK_EQ(rcb.buffers[0].length, buf1Length); + CHECK_EQ(rcb.buffers[1].length, buf2Length + buf3Length); + + acceptedSocket->close(); + socket->close(); +} + +/** + * Test performing a zero-length write + */ +TEST(AsyncSocketTest, ZeroLengthWrite) { + TestServer server; + + // connect() + EventBase evb; + std::shared_ptr socket = + AsyncSocket::newSocket(&evb, server.getAddress(), 30); + evb.loop(); // loop until the socket is connected + + auto acceptedSocket = server.acceptAsync(&evb); + ReadCallback rcb; + acceptedSocket->setReadCB(&rcb); + + size_t len1 = 1024*1024; + size_t len2 = 1024*1024; + std::unique_ptr buf(new char[len1 + len2]); + memset(buf.get(), 'a', len1); + memset(buf.get(), 'b', len2); + + WriteCallback wcb1; + WriteCallback wcb2; + WriteCallback wcb3; + WriteCallback wcb4; + socket->write(&wcb1, buf.get(), 0); + socket->write(&wcb2, buf.get(), len1); + socket->write(&wcb3, buf.get() + len1, 0); + socket->write(&wcb4, buf.get() + len1, len2); + socket->close(); + + evb.loop(); // loop until the data is sent + + CHECK_EQ(wcb1.state, STATE_SUCCEEDED); + CHECK_EQ(wcb2.state, STATE_SUCCEEDED); + CHECK_EQ(wcb3.state, STATE_SUCCEEDED); + CHECK_EQ(wcb4.state, STATE_SUCCEEDED); + rcb.verifyData(buf.get(), len1 + len2); +} + +TEST(AsyncSocketTest, ZeroLengthWritev) { + TestServer server; + + // connect() + EventBase evb; + std::shared_ptr socket = + AsyncSocket::newSocket(&evb, server.getAddress(), 30); + evb.loop(); // loop until the socket is connected + + auto acceptedSocket = server.acceptAsync(&evb); + ReadCallback rcb; + acceptedSocket->setReadCB(&rcb); + + size_t len1 = 1024*1024; + size_t len2 = 1024*1024; + std::unique_ptr buf(new char[len1 + len2]); + memset(buf.get(), 'a', len1); + memset(buf.get(), 'b', len2); + + WriteCallback wcb; + size_t iovCount = 4; + struct iovec iov[iovCount]; + iov[0].iov_base = buf.get(); + iov[0].iov_len = len1; + iov[1].iov_base = buf.get() + len1; + iov[1].iov_len = 0; + iov[2].iov_base = buf.get() + len1; + iov[2].iov_len = len2; + iov[3].iov_base = buf.get() + len1 + len2; + iov[3].iov_len = 0; + + socket->writev(&wcb, iov, iovCount); + socket->close(); + evb.loop(); // loop until the data is sent + + CHECK_EQ(wcb.state, STATE_SUCCEEDED); + rcb.verifyData(buf.get(), len1 + len2); +} + +/////////////////////////////////////////////////////////////////////////// +// close() related tests +/////////////////////////////////////////////////////////////////////////// + +/** + * Test calling close() with pending writes when the socket is already closing. + */ +TEST(AsyncSocketTest, ClosePendingWritesWhileClosing) { + TestServer server; + + // connect() + EventBase evb; + std::shared_ptr socket = AsyncSocket::newSocket(&evb); + ConnCallback ccb; + socket->connect(&ccb, server.getAddress(), 30); + + // accept the socket on the server side + std::shared_ptr acceptedSocket = server.accept(); + + // Loop to ensure the connect has completed + evb.loop(); + + // Make sure we are connected + CHECK_EQ(ccb.state, STATE_SUCCEEDED); + + // Schedule pending writes, until several write attempts have blocked + char buf[128]; + memset(buf, 'a', sizeof(buf)); + typedef vector< std::shared_ptr > WriteCallbackVector; + WriteCallbackVector writeCallbacks; + + writeCallbacks.reserve(5); + while (writeCallbacks.size() < 5) { + std::shared_ptr wcb(new WriteCallback); + + socket->write(wcb.get(), buf, sizeof(buf)); + if (wcb->state == STATE_SUCCEEDED) { + // Succeeded immediately. Keep performing more writes + continue; + } + + // This write is blocked. + // Have the write callback call close() when writeError() is invoked + wcb->errorCallback = std::bind(&AsyncSocket::close, socket.get()); + writeCallbacks.push_back(wcb); + } + + // Call closeNow() to immediately fail the pending writes + socket->closeNow(); + + // Make sure writeError() was invoked on all of the pending write callbacks + for (WriteCallbackVector::const_iterator it = writeCallbacks.begin(); + it != writeCallbacks.end(); + ++it) { + CHECK_EQ((*it)->state, STATE_FAILED); + } +} + + +// TODO: +// - Test connect() and have the connect callback set the read callback +// - Test connect() and have the connect callback unset the read callback +// - Test reading/writing/closing/destroying the socket in the connect callback +// - Test reading/writing/closing/destroying the socket in the read callback +// - Test reading/writing/closing/destroying the socket in the write callback +// - Test one-way shutdown behavior +// - Test changing the EventBase +// +// - TODO: test multiple threads sharing a AsyncSocket, and detaching from it +// in connectSuccess(), readDataAvailable(), writeSuccess() + + +/////////////////////////////////////////////////////////////////////////// +// AsyncServerSocket tests +/////////////////////////////////////////////////////////////////////////// + +/** + * Helper AcceptCallback class for the test code + * It records the callbacks that were invoked, and also supports calling + * generic std::function objects in each callback. + */ +class TestAcceptCallback : public AsyncServerSocket::AcceptCallback { + public: + enum EventType { + TYPE_START, + TYPE_ACCEPT, + TYPE_ERROR, + TYPE_STOP + }; + struct EventInfo { + EventInfo(int fd, const folly::SocketAddress& addr) + : type(TYPE_ACCEPT), + fd(fd), + address(addr), + errorMsg() {} + explicit EventInfo(const std::string& msg) + : type(TYPE_ERROR), + fd(-1), + address(), + errorMsg(msg) {} + explicit EventInfo(EventType et) + : type(et), + fd(-1), + address(), + errorMsg() {} + + EventType type; + int fd; // valid for TYPE_ACCEPT + folly::SocketAddress address; // valid for TYPE_ACCEPT + string errorMsg; // valid for TYPE_ERROR + }; + typedef std::deque EventList; + + TestAcceptCallback() + : connectionAcceptedFn_(), + acceptErrorFn_(), + acceptStoppedFn_(), + events_() {} + + std::deque* getEvents() { + return &events_; + } + + void setConnectionAcceptedFn( + const std::function& fn) { + connectionAcceptedFn_ = fn; + } + void setAcceptErrorFn(const std::function& fn) { + acceptErrorFn_ = fn; + } + void setAcceptStartedFn(const std::function& fn) { + acceptStartedFn_ = fn; + } + void setAcceptStoppedFn(const std::function& fn) { + acceptStoppedFn_ = fn; + } + + void connectionAccepted(int fd, const folly::SocketAddress& clientAddr) + noexcept { + events_.push_back(EventInfo(fd, clientAddr)); + + if (connectionAcceptedFn_) { + connectionAcceptedFn_(fd, clientAddr); + } + } + void acceptError(const std::exception& ex) noexcept { + events_.push_back(EventInfo(ex.what())); + + if (acceptErrorFn_) { + acceptErrorFn_(ex); + } + } + void acceptStarted() noexcept { + events_.push_back(EventInfo(TYPE_START)); + + if (acceptStartedFn_) { + acceptStartedFn_(); + } + } + void acceptStopped() noexcept { + events_.push_back(EventInfo(TYPE_STOP)); + + if (acceptStoppedFn_) { + acceptStoppedFn_(); + } + } + + private: + std::function connectionAcceptedFn_; + std::function acceptErrorFn_; + std::function acceptStartedFn_; + std::function acceptStoppedFn_; + + std::deque events_; +}; + +/** + * Make sure accepted sockets have O_NONBLOCK and TCP_NODELAY set + */ +TEST(AsyncSocketTest, ServerAcceptOptions) { + EventBase eventBase; + + // Create a server socket + std::shared_ptr serverSocket( + AsyncServerSocket::newSocket(&eventBase)); + serverSocket->bind(0); + serverSocket->listen(16); + folly::SocketAddress serverAddress; + serverSocket->getAddress(&serverAddress); + + // Add a callback to accept one connection then stop the loop + TestAcceptCallback acceptCallback; + acceptCallback.setConnectionAcceptedFn( + [&](int fd, const folly::SocketAddress& addr) { + serverSocket->removeAcceptCallback(&acceptCallback, nullptr); + }); + acceptCallback.setAcceptErrorFn([&](const std::exception& ex) { + serverSocket->removeAcceptCallback(&acceptCallback, nullptr); + }); + serverSocket->addAcceptCallback(&acceptCallback, nullptr); + serverSocket->startAccepting(); + + // Connect to the server socket + std::shared_ptr socket( + AsyncSocket::newSocket(&eventBase, serverAddress)); + + eventBase.loop(); + + // Verify that the server accepted a connection + CHECK_EQ(acceptCallback.getEvents()->size(), 3); + CHECK_EQ(acceptCallback.getEvents()->at(0).type, + TestAcceptCallback::TYPE_START); + CHECK_EQ(acceptCallback.getEvents()->at(1).type, + TestAcceptCallback::TYPE_ACCEPT); + CHECK_EQ(acceptCallback.getEvents()->at(2).type, + TestAcceptCallback::TYPE_STOP); + int fd = acceptCallback.getEvents()->at(1).fd; + + // The accepted connection should already be in non-blocking mode + int flags = fcntl(fd, F_GETFL, 0); + CHECK_EQ(flags & O_NONBLOCK, O_NONBLOCK); + +#ifndef TCP_NOPUSH + // The accepted connection should already have TCP_NODELAY set + int value; + socklen_t valueLength = sizeof(value); + int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength); + CHECK_EQ(rc, 0); + CHECK_EQ(value, 1); +#endif +} + +/** + * Test AsyncServerSocket::removeAcceptCallback() + */ +TEST(AsyncSocketTest, RemoveAcceptCallback) { + // Create a new AsyncServerSocket + EventBase eventBase; + std::shared_ptr serverSocket( + AsyncServerSocket::newSocket(&eventBase)); + serverSocket->bind(0); + serverSocket->listen(16); + folly::SocketAddress serverAddress; + serverSocket->getAddress(&serverAddress); + + // Add several accept callbacks + TestAcceptCallback cb1; + TestAcceptCallback cb2; + TestAcceptCallback cb3; + TestAcceptCallback cb4; + TestAcceptCallback cb5; + TestAcceptCallback cb6; + TestAcceptCallback cb7; + + // Test having callbacks remove other callbacks before them on the list, + // after them on the list, or removing themselves. + // + // Have callback 2 remove callback 3 and callback 5 the first time it is + // called. + int cb2Count = 0; + cb1.setConnectionAcceptedFn([&](int fd, const folly::SocketAddress& addr){ + std::shared_ptr sock2( + AsyncSocket::newSocket(&eventBase, serverAddress)); // cb2: -cb3 -cb5 + }); + cb3.setConnectionAcceptedFn([&](int fd, const folly::SocketAddress& addr){ + }); + cb4.setConnectionAcceptedFn([&](int fd, const folly::SocketAddress& addr){ + std::shared_ptr sock3( + AsyncSocket::newSocket(&eventBase, serverAddress)); // cb4 + }); + cb5.setConnectionAcceptedFn([&](int fd, const folly::SocketAddress& addr){ + std::shared_ptr sock5( + AsyncSocket::newSocket(&eventBase, serverAddress)); // cb7: -cb7 + + }); + cb2.setConnectionAcceptedFn( + [&](int fd, const folly::SocketAddress& addr) { + if (cb2Count == 0) { + serverSocket->removeAcceptCallback(&cb3, nullptr); + serverSocket->removeAcceptCallback(&cb5, nullptr); + } + ++cb2Count; + }); + // Have callback 6 remove callback 4 the first time it is called, + // and destroy the server socket the second time it is called + int cb6Count = 0; + cb6.setConnectionAcceptedFn( + [&](int fd, const folly::SocketAddress& addr) { + if (cb6Count == 0) { + serverSocket->removeAcceptCallback(&cb4, nullptr); + std::shared_ptr sock6( + AsyncSocket::newSocket(&eventBase, serverAddress)); // cb1 + std::shared_ptr sock7( + AsyncSocket::newSocket(&eventBase, serverAddress)); // cb2 + std::shared_ptr sock8( + AsyncSocket::newSocket(&eventBase, serverAddress)); // cb6: stop + + } else { + serverSocket.reset(); + } + ++cb6Count; + }); + // Have callback 7 remove itself + cb7.setConnectionAcceptedFn( + [&](int fd, const folly::SocketAddress& addr) { + serverSocket->removeAcceptCallback(&cb7, nullptr); + }); + + serverSocket->addAcceptCallback(&cb1, nullptr); + serverSocket->addAcceptCallback(&cb2, nullptr); + serverSocket->addAcceptCallback(&cb3, nullptr); + serverSocket->addAcceptCallback(&cb4, nullptr); + serverSocket->addAcceptCallback(&cb5, nullptr); + serverSocket->addAcceptCallback(&cb6, nullptr); + serverSocket->addAcceptCallback(&cb7, nullptr); + serverSocket->startAccepting(); + + // Make several connections to the socket + std::shared_ptr sock1( + AsyncSocket::newSocket(&eventBase, serverAddress)); // cb1 + std::shared_ptr sock4( + AsyncSocket::newSocket(&eventBase, serverAddress)); // cb6: -cb4 + + // Loop until we are stopped + eventBase.loop(); + + // Check to make sure that the expected callbacks were invoked. + // + // NOTE: This code depends on the AsyncServerSocket operating calling all of + // the AcceptCallbacks in round-robin fashion, in the order that they were + // added. The code is implemented this way right now, but the API doesn't + // explicitly require it be done this way. If we change the code not to be + // exactly round robin in the future, we can simplify the test checks here. + // (We'll also need to update the termination code, since we expect cb6 to + // get called twice to terminate the loop.) + CHECK_EQ(cb1.getEvents()->size(), 4); + CHECK_EQ(cb1.getEvents()->at(0).type, + TestAcceptCallback::TYPE_START); + CHECK_EQ(cb1.getEvents()->at(1).type, + TestAcceptCallback::TYPE_ACCEPT); + CHECK_EQ(cb1.getEvents()->at(2).type, + TestAcceptCallback::TYPE_ACCEPT); + CHECK_EQ(cb1.getEvents()->at(3).type, + TestAcceptCallback::TYPE_STOP); + + CHECK_EQ(cb2.getEvents()->size(), 4); + CHECK_EQ(cb2.getEvents()->at(0).type, + TestAcceptCallback::TYPE_START); + CHECK_EQ(cb2.getEvents()->at(1).type, + TestAcceptCallback::TYPE_ACCEPT); + CHECK_EQ(cb2.getEvents()->at(2).type, + TestAcceptCallback::TYPE_ACCEPT); + CHECK_EQ(cb2.getEvents()->at(3).type, + TestAcceptCallback::TYPE_STOP); + + CHECK_EQ(cb3.getEvents()->size(), 2); + CHECK_EQ(cb3.getEvents()->at(0).type, + TestAcceptCallback::TYPE_START); + CHECK_EQ(cb3.getEvents()->at(1).type, + TestAcceptCallback::TYPE_STOP); + + CHECK_EQ(cb4.getEvents()->size(), 3); + CHECK_EQ(cb4.getEvents()->at(0).type, + TestAcceptCallback::TYPE_START); + CHECK_EQ(cb4.getEvents()->at(1).type, + TestAcceptCallback::TYPE_ACCEPT); + CHECK_EQ(cb4.getEvents()->at(2).type, + TestAcceptCallback::TYPE_STOP); + + CHECK_EQ(cb5.getEvents()->size(), 2); + CHECK_EQ(cb5.getEvents()->at(0).type, + TestAcceptCallback::TYPE_START); + CHECK_EQ(cb5.getEvents()->at(1).type, + TestAcceptCallback::TYPE_STOP); + + CHECK_EQ(cb6.getEvents()->size(), 4); + CHECK_EQ(cb6.getEvents()->at(0).type, + TestAcceptCallback::TYPE_START); + CHECK_EQ(cb6.getEvents()->at(1).type, + TestAcceptCallback::TYPE_ACCEPT); + CHECK_EQ(cb6.getEvents()->at(2).type, + TestAcceptCallback::TYPE_ACCEPT); + CHECK_EQ(cb6.getEvents()->at(3).type, + TestAcceptCallback::TYPE_STOP); + + CHECK_EQ(cb7.getEvents()->size(), 3); + CHECK_EQ(cb7.getEvents()->at(0).type, + TestAcceptCallback::TYPE_START); + CHECK_EQ(cb7.getEvents()->at(1).type, + TestAcceptCallback::TYPE_ACCEPT); + CHECK_EQ(cb7.getEvents()->at(2).type, + TestAcceptCallback::TYPE_STOP); +} + +/** + * Test AsyncServerSocket::removeAcceptCallback() + */ +TEST(AsyncSocketTest, OtherThreadAcceptCallback) { + // Create a new AsyncServerSocket + EventBase eventBase; + std::shared_ptr serverSocket( + AsyncServerSocket::newSocket(&eventBase)); + serverSocket->bind(0); + serverSocket->listen(16); + folly::SocketAddress serverAddress; + serverSocket->getAddress(&serverAddress); + + // Add several accept callbacks + TestAcceptCallback cb1; + auto thread_id = pthread_self(); + cb1.setAcceptStartedFn([&](){ + CHECK_NE(thread_id, pthread_self()); + thread_id = pthread_self(); + }); + cb1.setConnectionAcceptedFn([&](int fd, const folly::SocketAddress& addr){ + CHECK_EQ(thread_id, pthread_self()); + serverSocket->removeAcceptCallback(&cb1, nullptr); + }); + cb1.setAcceptStoppedFn([&](){ + CHECK_EQ(thread_id, pthread_self()); + }); + + // Test having callbacks remove other callbacks before them on the list, + serverSocket->addAcceptCallback(&cb1, nullptr); + serverSocket->startAccepting(); + + // Make several connections to the socket + std::shared_ptr sock1( + AsyncSocket::newSocket(&eventBase, serverAddress)); // cb1 + + // Loop in another thread + auto other = std::thread([&](){ + eventBase.loop(); + }); + other.join(); + + // Check to make sure that the expected callbacks were invoked. + // + // NOTE: This code depends on the AsyncServerSocket operating calling all of + // the AcceptCallbacks in round-robin fashion, in the order that they were + // added. The code is implemented this way right now, but the API doesn't + // explicitly require it be done this way. If we change the code not to be + // exactly round robin in the future, we can simplify the test checks here. + // (We'll also need to update the termination code, since we expect cb6 to + // get called twice to terminate the loop.) + CHECK_EQ(cb1.getEvents()->size(), 3); + CHECK_EQ(cb1.getEvents()->at(0).type, + TestAcceptCallback::TYPE_START); + CHECK_EQ(cb1.getEvents()->at(1).type, + TestAcceptCallback::TYPE_ACCEPT); + CHECK_EQ(cb1.getEvents()->at(2).type, + TestAcceptCallback::TYPE_STOP); + +} + +void serverSocketSanityTest(AsyncServerSocket* serverSocket) { + // Add a callback to accept one connection then stop accepting + TestAcceptCallback acceptCallback; + acceptCallback.setConnectionAcceptedFn( + [&](int fd, const folly::SocketAddress& addr) { + serverSocket->removeAcceptCallback(&acceptCallback, nullptr); + }); + acceptCallback.setAcceptErrorFn([&](const std::exception& ex) { + serverSocket->removeAcceptCallback(&acceptCallback, nullptr); + }); + serverSocket->addAcceptCallback(&acceptCallback, nullptr); + serverSocket->startAccepting(); + + // Connect to the server socket + EventBase* eventBase = serverSocket->getEventBase(); + folly::SocketAddress serverAddress; + serverSocket->getAddress(&serverAddress); + AsyncSocket::UniquePtr socket(new AsyncSocket(eventBase, serverAddress)); + + // Loop to process all events + eventBase->loop(); + + // Verify that the server accepted a connection + CHECK_EQ(acceptCallback.getEvents()->size(), 3); + CHECK_EQ(acceptCallback.getEvents()->at(0).type, + TestAcceptCallback::TYPE_START); + CHECK_EQ(acceptCallback.getEvents()->at(1).type, + TestAcceptCallback::TYPE_ACCEPT); + CHECK_EQ(acceptCallback.getEvents()->at(2).type, + TestAcceptCallback::TYPE_STOP); +} + +/* Verify that we don't leak sockets if we are destroyed() + * and there are still writes pending + * + * If destroy() only calls close() instead of closeNow(), + * it would shutdown(writes) on the socket, but it would + * never be close()'d, and the socket would leak + */ +TEST(AsyncSocketTest, DestroyCloseTest) { + TestServer server; + + // connect() + EventBase clientEB; + EventBase serverEB; + std::shared_ptr socket = AsyncSocket::newSocket(&clientEB); + ConnCallback ccb; + socket->connect(&ccb, server.getAddress(), 30); + + // Accept the connection + std::shared_ptr acceptedSocket = server.acceptAsync(&serverEB); + ReadCallback rcb; + acceptedSocket->setReadCB(&rcb); + + // Write a large buffer to the socket that is larger than kernel buffer + size_t simpleBufLength = 5000000; + char* simpleBuf = new char[simpleBufLength]; + memset(simpleBuf, 'a', simpleBufLength); + WriteCallback wcb; + + // Let the reads and writes run to completion + int fd = acceptedSocket->getFd(); + + acceptedSocket->write(&wcb, simpleBuf, simpleBufLength); + socket.reset(); + acceptedSocket.reset(); + + // Test that server socket was closed + ssize_t sz = read(fd, simpleBuf, simpleBufLength); + CHECK_EQ(sz, -1); + CHECK_EQ(errno, 9); + delete[] simpleBuf; +} + +/** + * Test AsyncServerSocket::useExistingSocket() + */ +TEST(AsyncSocketTest, ServerExistingSocket) { + EventBase eventBase; + + // Test creating a socket, and letting AsyncServerSocket bind and listen + { + // Manually create a socket + int fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + ASSERT_GE(fd, 0); + + // Create a server socket + AsyncServerSocket::UniquePtr serverSocket( + new AsyncServerSocket(&eventBase)); + serverSocket->useExistingSocket(fd); + folly::SocketAddress address; + serverSocket->getAddress(&address); + address.setPort(0); + serverSocket->bind(address); + serverSocket->listen(16); + + // Make sure the socket works + serverSocketSanityTest(serverSocket.get()); + } + + // Test creating a socket and binding manually, + // then letting AsyncServerSocket listen + { + // Manually create a socket + int fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + ASSERT_GE(fd, 0); + // bind + struct sockaddr_in addr; + addr.sin_family = AF_INET; + addr.sin_port = 0; + addr.sin_addr.s_addr = INADDR_ANY; + CHECK_EQ(bind(fd, reinterpret_cast(&addr), + sizeof(addr)), 0); + // Look up the address that we bound to + folly::SocketAddress boundAddress; + boundAddress.setFromLocalAddress(fd); + + // Create a server socket + AsyncServerSocket::UniquePtr serverSocket( + new AsyncServerSocket(&eventBase)); + serverSocket->useExistingSocket(fd); + serverSocket->listen(16); + + // Make sure AsyncServerSocket reports the same address that we bound to + folly::SocketAddress serverSocketAddress; + serverSocket->getAddress(&serverSocketAddress); + CHECK_EQ(boundAddress, serverSocketAddress); + + // Make sure the socket works + serverSocketSanityTest(serverSocket.get()); + } + + // Test creating a socket, binding and listening manually, + // then giving it to AsyncServerSocket + { + // Manually create a socket + int fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + ASSERT_GE(fd, 0); + // bind + struct sockaddr_in addr; + addr.sin_family = AF_INET; + addr.sin_port = 0; + addr.sin_addr.s_addr = INADDR_ANY; + CHECK_EQ(bind(fd, reinterpret_cast(&addr), + sizeof(addr)), 0); + // Look up the address that we bound to + folly::SocketAddress boundAddress; + boundAddress.setFromLocalAddress(fd); + // listen + CHECK_EQ(listen(fd, 16), 0); + + // Create a server socket + AsyncServerSocket::UniquePtr serverSocket( + new AsyncServerSocket(&eventBase)); + serverSocket->useExistingSocket(fd); + + // Make sure AsyncServerSocket reports the same address that we bound to + folly::SocketAddress serverSocketAddress; + serverSocket->getAddress(&serverSocketAddress); + CHECK_EQ(boundAddress, serverSocketAddress); + + // Make sure the socket works + serverSocketSanityTest(serverSocket.get()); + } +} + +TEST(AsyncSocketTest, UnixDomainSocketTest) { + EventBase eventBase; + + // Create a server socket + std::shared_ptr serverSocket( + AsyncServerSocket::newSocket(&eventBase)); + string path(1, 0); + path.append("/anonymous"); + folly::SocketAddress serverAddress; + serverAddress.setFromPath(path); + serverSocket->bind(serverAddress); + serverSocket->listen(16); + + // Add a callback to accept one connection then stop the loop + TestAcceptCallback acceptCallback; + acceptCallback.setConnectionAcceptedFn( + [&](int fd, const folly::SocketAddress& addr) { + serverSocket->removeAcceptCallback(&acceptCallback, nullptr); + }); + acceptCallback.setAcceptErrorFn([&](const std::exception& ex) { + serverSocket->removeAcceptCallback(&acceptCallback, nullptr); + }); + serverSocket->addAcceptCallback(&acceptCallback, nullptr); + serverSocket->startAccepting(); + + // Connect to the server socket + std::shared_ptr socket( + AsyncSocket::newSocket(&eventBase, serverAddress)); + + eventBase.loop(); + + // Verify that the server accepted a connection + CHECK_EQ(acceptCallback.getEvents()->size(), 3); + CHECK_EQ(acceptCallback.getEvents()->at(0).type, + TestAcceptCallback::TYPE_START); + CHECK_EQ(acceptCallback.getEvents()->at(1).type, + TestAcceptCallback::TYPE_ACCEPT); + CHECK_EQ(acceptCallback.getEvents()->at(2).type, + TestAcceptCallback::TYPE_STOP); + int fd = acceptCallback.getEvents()->at(1).fd; + + // The accepted connection should already be in non-blocking mode + int flags = fcntl(fd, F_GETFL, 0); + CHECK_EQ(flags & O_NONBLOCK, O_NONBLOCK); +} diff --git a/folly/io/async/test/BlockingSocket.h b/folly/io/async/test/BlockingSocket.h new file mode 100644 index 00000000..7d2ee454 --- /dev/null +++ b/folly/io/async/test/BlockingSocket.h @@ -0,0 +1,126 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include + +class BlockingSocket : public folly::AsyncSocket::ConnectCallback, + public folly::AsyncTransportWrapper::ReadCallback, + public folly::AsyncTransportWrapper::WriteCallback +{ + public: + explicit BlockingSocket(int fd) + : sock_(new folly::AsyncSocket(&eventBase_, fd)) { + } + + BlockingSocket(folly::SocketAddress address, + std::shared_ptr sslContext) + : sock_(sslContext ? new folly::AsyncSSLSocket(sslContext, &eventBase_) : + new folly::AsyncSocket(&eventBase_)), + address_(address) {} + + void open() { + sock_->connect(this, address_); + eventBase_.loop(); + if (err_.hasValue()) { + throw err_.value(); + } + } + void close() { + sock_->close(); + } + + int32_t write(uint8_t const* buf, size_t len) { + sock_->write(this, buf, len); + eventBase_.loop(); + if (err_.hasValue()) { + throw err_.value(); + } + return len; + } + + void flush() {} + + int32_t readAll(uint8_t *buf, size_t len) { + return readHelper(buf, len, true); + } + + int32_t read(uint8_t *buf, size_t len) { + return readHelper(buf, len, false); + } + + int getSocketFD() const { + return sock_->getFd(); + } + + private: + folly::EventBase eventBase_; + folly::AsyncSocket::UniquePtr sock_; + folly::Optional err_; + uint8_t *readBuf_{nullptr}; + size_t readLen_{0}; + folly::SocketAddress address_; + + void connectSuccess() noexcept override {} + void connectErr(const folly::AsyncSocketException& ex) noexcept override { + err_ = ex; + } + void getReadBuffer(void** bufReturn, size_t* lenReturn) override { + *bufReturn = readBuf_; + *lenReturn = readLen_; + } + void readDataAvailable(size_t len) noexcept override { + readBuf_ += len; + readLen_ -= len; + if (readLen_ == 0) { + sock_->setReadCB(nullptr); + } + } + void readEOF() noexcept override { + } + void readErr(const folly::AsyncSocketException& ex) noexcept override { + err_ = ex; + } + void writeSuccess() noexcept override {} + void writeErr(size_t bytesWritten, + const folly::AsyncSocketException& ex) noexcept override { + err_ = ex; + } + + int32_t readHelper(uint8_t *buf, size_t len, bool all) { + readBuf_ = buf; + readLen_ = len; + sock_->setReadCB(this); + while (!err_ && sock_->good() && readLen_ > 0) { + eventBase_.loop(); + if (!all) { + break; + } + } + sock_->setReadCB(nullptr); + if (err_.hasValue()) { + throw err_.value(); + } + if (all && readLen_ > 0) { + throw folly::AsyncSocketException(folly::AsyncSocketException::UNKNOWN, + "eof"); + } + return len - readLen_; + } +}; diff --git a/folly/io/async/test/certs/ca-cert.pem b/folly/io/async/test/certs/ca-cert.pem new file mode 100644 index 00000000..1a4f22b1 --- /dev/null +++ b/folly/io/async/test/certs/ca-cert.pem @@ -0,0 +1,21 @@ +-----BEGIN CERTIFICATE----- +MIIDXTCCAkWgAwIBAgIJAKMZICGWUzawMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV +BAYTAlVTMQ8wDQYDVQQKDAZUaHJpZnQxJTAjBgNVBAMMHFRocmlmdCBDZXJ0aWZp +Y2F0ZSBBdXRob3JpdHkwHhcNMTQwNTE2MjAyODUyWhcNNDExMDAxMjAyODUyWjBF +MQswCQYDVQQGEwJVUzEPMA0GA1UECgwGVGhyaWZ0MSUwIwYDVQQDDBxUaHJpZnQg +Q2VydGlmaWNhdGUgQXV0aG9yaXR5MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB +CgKCAQEA1Bx2vUvXZ8PrvEBxwdH5qM1F2Xo7UkeC1jzQ+OLUBEcCiEduyStitSvB +NOAzAGdjt7NmHTP/7OJngp2vzQGjSQzm20XacyTieFUuPBuikUc0Ge3Tf+uQXtiU +zZPh+xn6arHH+zBWtmUCt3cBrpgRqdnWUsbl8eqo5HsczY781FxQbDoT9VP6A+9R +KGTsEhxxKbWJ1C7OngwLKc7Zv4DtTC1JFlFyKd8ryDtxP4s/GgsXJkoK0Hkpputr +cMxMm6OGt77mFvzR2qRY1CpEK/9rjBB6Gqd8GakXsvoOsqL/37k2wVhN/JoS/Pde +12Mp6TZ2rA8NW8vRujfWU0u55gnQnwIDAQABo1AwTjAdBgNVHQ4EFgQUQ00NGVmY +NZ6LJg8UQUOVLZX1Gh8wHwYDVR0jBBgwFoAUQ00NGVmYNZ6LJg8UQUOVLZX1Gh8w +DAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQUFAAOCAQEAdlxt5+z9uXCBr1Wt6r49 +4MmOYw9lOnEOG1JPMRo108TLpmwXEWReCAtjQuR7BitRJW0kJtlO1M6t3qoIh6GA +sBkgsjQM1xNY3YEpx71MLt1V+JD+2WtSBKMyysj1TiOmIH66kkvXO3ptXzhjhZyX +G6B+kxLtxrqkn9SJULyN55X8T+dkW28UIBZVLavoREDU+UPrYU9JgZeIVObtGSWi +DvS4RIJZNjgG3vTrT00rfUGEfTlI54Vbcmv0cYvswP/nMsLtDStCdgI7c/ipyJve +dfuI4CedjE240AxK5OFxFg/k/IfnB4a5oojbdIR9hKrTU57TPaUVD50Na9WA1aqX +5Q== +-----END CERTIFICATE----- diff --git a/folly/io/async/test/certs/tests-cert.pem b/folly/io/async/test/certs/tests-cert.pem new file mode 100644 index 00000000..894ba82c --- /dev/null +++ b/folly/io/async/test/certs/tests-cert.pem @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDKzCCAhOgAwIBAgIBCjANBgkqhkiG9w0BAQUFADBFMQswCQYDVQQGEwJVUzEP +MA0GA1UECgwGVGhyaWZ0MSUwIwYDVQQDDBxUaHJpZnQgQ2VydGlmaWNhdGUgQXV0 +aG9yaXR5MB4XDTE0MDUxNjIwMjg1MloXDTQxMTAwMTIwMjg1MlowRjELMAkGA1UE +BhMCVVMxDTALBgNVBAgTBE9oaW8xETAPBgNVBAcTCEhpbGxpYXJkMRUwEwYDVQQD +EwxBc294IENvbXBhbnkwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCz +ZGrJ5XQHAuMYHlBgn32OOc9l0n3RjXccio2ceeWctXkSxDP3vFyZ4kBILF1lsY1g +o8UTjMkqSDYcytCLK0qavrv9BZRLB9FcpqJ9o4V9feaI/HsHa8DYHyEs8qyNTTNG +YQ3i4j+AA9iDSpezIYy/tyAOAjrSquUW1jI4tzKTBh8hk8MAMvR2/NPHPkrp4gI+ +EMH6u4vWdr4F9bbriLFWoU04T9mWOMk7G+h8BS9sgINg2+v5cWvl3BC4kLk5L1yJ +FEyuofSSCEEe6dDf7uVh+RPKa4hEkIYo31AEOPFrN56d+pCj/5l67HTWXoQx3rjy +dNXMvgU75urm6TQe8dB5AgMBAAGjJTAjMCEGA1UdEQQaMBiHBH8AAAGHEAAAAAAA +AAAAAAAAAAAAAAEwDQYJKoZIhvcNAQEFBQADggEBAD26XYInaEvlWZJYgtl3yQyC +3NRQc3LG7XxWg4aFdXCxYLPRAL2HLoarKYH8GPFso57t5xnhA8WfP7iJxmgsKdCS +0pNIicOWsMmXvYLib0j9tMCFR+a8rn3f4n+clwnqas4w/vWBJUoMgyxtkP8NNNZO +kIl02JKRhuyiFyPLilVp5tu0e+lmyUER+ak53WjLq2yoytYAlHkzkOpc4MZ/TNt5 +UTEtx/WVlZvlrPi3dsi7QikkjQgo1wCnm7owtuAHlPDMAB8wKk4+vvIOjsGM33T/ +8ffq/4X1HeYM0w0fM+SVlX1rwkXA1RW/jn48VWFHpWbE10+m196OdiToGfm2OJI= +-----END CERTIFICATE----- diff --git a/folly/io/async/test/certs/tests-key.pem b/folly/io/async/test/certs/tests-key.pem new file mode 100644 index 00000000..caa90522 --- /dev/null +++ b/folly/io/async/test/certs/tests-key.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEAs2RqyeV0BwLjGB5QYJ99jjnPZdJ90Y13HIqNnHnlnLV5EsQz +97xcmeJASCxdZbGNYKPFE4zJKkg2HMrQiytKmr67/QWUSwfRXKaifaOFfX3miPx7 +B2vA2B8hLPKsjU0zRmEN4uI/gAPYg0qXsyGMv7cgDgI60qrlFtYyOLcykwYfIZPD +ADL0dvzTxz5K6eICPhDB+ruL1na+BfW264ixVqFNOE/ZljjJOxvofAUvbICDYNvr ++XFr5dwQuJC5OS9ciRRMrqH0kghBHunQ3+7lYfkTymuIRJCGKN9QBDjxazeenfqQ +o/+Zeux01l6EMd648nTVzL4FO+bq5uk0HvHQeQIDAQABAoIBAQCSPcBYindF5/Kd +jMjVm+9M7I/IYAo1tG9vkvvSngSy9bWXuN7sjF+pCyqAK7qP1mh8acWVJGYx0+BZ +JHVRnp8Y+3hg0hWL/PmN4EICzjVakjJHZhwddpglF2uCKurD3jV4oFIjrXE6uOfe +UAbO/wCwoWa+RM8TQkGzljYmyiGufCcXlgEKMNA7TIvbJ9TVx3VTCOQy6EjZ13jd +M6X7byV/ZOFpZ2H0QV46LvZraw04riXQ/59gVmzizYdI+BwnxxapsCmalTJoV/Y0 +LMI2ylat4PTMVTxPF+ti7Nt+rUkkEx6kuiAgfc+bzE4BSD5X4wy3fdLVLccoxXYw +4N3fOuQhAoGBAOLrMhiSCrzXGjDWTbPrwzxXDO0qm+wURELi3N5SXIkKUdG2/In6 +wNdpXdvqblOm7SASgPf9KCwUSADrNw6R6nbfrrir5EHg66YydI/OW42QzJKcBUFh +5Q5na3fvoL/zRhsmh0gEymBg+OIfNel2LY69bl8aAko2y0R1kj7zb8X1AoGBAMph +9hlnkIBSw60+pKHaOqo2t/kihNNMFyfOgJvh8960eFeMDhMIXgxPUR8yaPX0bBMb +bCdEJJ2pmq7zUBPvxVJLedwkGMhywElA8yYVh+S6x4Cg+lYo4spIjrHQ/WTvJkHB +GrDskxdq80lbXjwRd0dPJZkxhKJec1o0n8S03Mn1AoGAGarK5taWGlgmYUHMVj6j +vc6G6si4DFMaiYpJu2gLiYC+Un9lP2I6r+L+N+LjidjG16rgJazf/2Rn5Jq2hpJg +uAODKuZekkkTvp/UaXPJDVFEooy9V3DwTNnL4SwcvbmRw35vLOlFzvMJE+K94WN5 +sbyhoGY7vhNGmL7HxREaIoUCgYEAwpteVWFz3yE2ziF9l7FMVh7V23go9zGk1n9I +xhyJL26khbLEWeLi5L1kiTYlHdUSE3F8F2n8N6s+ddq79t/KA29WV6xSNHW7lvUg +mk975CMC8hpZfn5ETjVlGXGYJ/Wa+QGiE9z5ODx8gt6cB/DXnLdrtRqbqrJeA7C0 +rScpY/0CgYBCC1QeuAiwWHOqQn3BwsZo9JQBTyT0QvDqLH/F+h9QbXep+4HvyAxG +nTMNDtGyfyKGDaDUn5hyeU7Oxvzq0K9P+eZD3MjQeaMEg/++GPGUPmDUTqyb2UT8 +5s0NIUobxfKnTD6IpgOIq7ffvVY6cKBMyuLmu/gSvscsbONHjKti3Q== +-----END RSA PRIVATE KEY----- diff --git a/folly/io/test/ShutdownSocketSetTest.cpp b/folly/io/test/ShutdownSocketSetTest.cpp new file mode 100644 index 00000000..76a598fa --- /dev/null +++ b/folly/io/test/ShutdownSocketSetTest.cpp @@ -0,0 +1,231 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include +#include +#include + +#include +#include +#include + +#include +#include + +using folly::ShutdownSocketSet; + +namespace folly { namespace test { + +ShutdownSocketSet shutdownSocketSet; + +class Server { + public: + Server(); + + void stop(bool abortive); + void join(); + int port() const { return port_; } + int closeClients(bool abortive); + + private: + int acceptSocket_; + int port_; + enum StopMode { + NO_STOP, + ORDERLY, + ABORTIVE + }; + std::atomic stop_; + std::thread serverThread_; + std::vector fds_; +}; + +Server::Server() + : acceptSocket_(-1), + port_(0), + stop_(NO_STOP) { + acceptSocket_ = socket(PF_INET, SOCK_STREAM, 0); + CHECK_ERR(acceptSocket_); + shutdownSocketSet.add(acceptSocket_); + + sockaddr_in addr; + addr.sin_family = AF_INET; + addr.sin_port = 0; + addr.sin_addr.s_addr = INADDR_ANY; + CHECK_ERR(bind(acceptSocket_, + reinterpret_cast(&addr), + sizeof(addr))); + + CHECK_ERR(listen(acceptSocket_, 10)); + + socklen_t addrLen = sizeof(addr); + CHECK_ERR(getsockname(acceptSocket_, + reinterpret_cast(&addr), + &addrLen)); + + port_ = ntohs(addr.sin_port); + + serverThread_ = std::thread([this] { + while (stop_ == NO_STOP) { + sockaddr_in peer; + socklen_t peerLen = sizeof(peer); + int fd = accept(acceptSocket_, + reinterpret_cast(&peer), + &peerLen); + if (fd == -1) { + if (errno == EINTR) { + continue; + } + if (errno == EINVAL || errno == ENOTSOCK) { // socket broken + break; + } + } + CHECK_ERR(fd); + shutdownSocketSet.add(fd); + fds_.push_back(fd); + } + + if (stop_ != NO_STOP) { + closeClients(stop_ == ABORTIVE); + } + + shutdownSocketSet.close(acceptSocket_); + acceptSocket_ = -1; + port_ = 0; + }); +} + +int Server::closeClients(bool abortive) { + for (int fd : fds_) { + if (abortive) { + struct linger l = {1, 0}; + CHECK_ERR(setsockopt(fd, SOL_SOCKET, SO_LINGER, &l, sizeof(l))); + } + shutdownSocketSet.close(fd); + } + int n = fds_.size(); + fds_.clear(); + return n; +} + +void Server::stop(bool abortive) { + stop_ = abortive ? ABORTIVE : ORDERLY; + shutdown(acceptSocket_, SHUT_RDWR); +} + +void Server::join() { + serverThread_.join(); +} + +int createConnectedSocket(int port) { + int sock = socket(PF_INET, SOCK_STREAM, 0); + CHECK_ERR(sock); + sockaddr_in addr; + addr.sin_family = AF_INET; + addr.sin_port = htons(port); + addr.sin_addr.s_addr = htonl((127 << 24) | 1); // XXX + CHECK_ERR(connect(sock, + reinterpret_cast(&addr), + sizeof(addr))); + return sock; +} + +void runCloseTest(bool abortive) { + Server server; + + int sock = createConnectedSocket(server.port()); + + std::thread stopper([&server, abortive] { + std::this_thread::sleep_for(std::chrono::milliseconds(200)); + server.stop(abortive); + server.join(); + }); + + char c; + int r = read(sock, &c, 1); + if (abortive) { + int e = errno; + EXPECT_EQ(-1, r); + EXPECT_EQ(ECONNRESET, e); + } else { + EXPECT_EQ(0, r); + } + + close(sock); + + stopper.join(); + + EXPECT_EQ(0, server.closeClients(false)); // closed by server when it exited +} + +TEST(ShutdownSocketSetTest, OrderlyClose) { + runCloseTest(false); +} + +TEST(ShutdownSocketSetTest, AbortiveClose) { + runCloseTest(true); +} + +void runKillTest(bool abortive) { + Server server; + + int sock = createConnectedSocket(server.port()); + + std::thread killer([&server, abortive] { + std::this_thread::sleep_for(std::chrono::milliseconds(200)); + shutdownSocketSet.shutdownAll(abortive); + server.join(); + }); + + char c; + int r = read(sock, &c, 1); + + // "abortive" is just a hint for ShutdownSocketSet, so accept both + // behaviors + if (abortive) { + if (r == -1) { + EXPECT_EQ(ECONNRESET, errno); + } else { + EXPECT_EQ(r, 0); + } + } else { + EXPECT_EQ(0, r); + } + + close(sock); + + killer.join(); + + // NOT closed by server when it exited + EXPECT_EQ(1, server.closeClients(false)); +} + +TEST(ShutdownSocketSetTest, OrderlyKill) { + runKillTest(false); +} + +TEST(ShutdownSocketSetTest, AbortiveKill) { + runKillTest(true); +} + +}} // namespaces + +int main(int argc, char *argv[]) { + testing::InitGoogleTest(&argc, argv); + google::ParseCommandLineFlags(&argc, &argv, true); + return RUN_ALL_TESTS(); +}