2 * Copyright 2015 Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
21 #include <folly/io/async/AsyncServerSocket.h>
22 #include <folly/io/async/AsyncSSLSocket.h>
23 #include <folly/io/async/AsyncSocket.h>
24 #include <folly/io/async/AsyncTransport.h>
25 #include <folly/io/async/EventBase.h>
26 #include <folly/io/async/AsyncTimeout.h>
27 #include <folly/SocketAddress.h>
29 #include <gtest/gtest.h>
35 #include <sys/types.h>
36 #include <sys/socket.h>
37 #include <netinet/tcp.h>
47 // The destructors of all callback classes assert that the state is
48 // STATE_SUCCEEDED, for both positive and negative tests. The tests
49 // are responsible for setting the succeeded state properly before the
50 // destructors are called.
52 class WriteCallbackBase :
53 public AsyncTransportWrapper::WriteCallback {
56 : state(STATE_WAITING)
58 , exception(AsyncSocketException::UNKNOWN, "none") {}
60 ~WriteCallbackBase() {
61 EXPECT_EQ(state, STATE_SUCCEEDED);
65 const std::shared_ptr<AsyncSSLSocket> &socket) {
69 void writeSuccess() noexcept override {
70 std::cerr << "writeSuccess" << std::endl;
71 state = STATE_SUCCEEDED;
76 const AsyncSocketException& ex) noexcept override {
77 std::cerr << "writeError: bytesWritten " << bytesWritten
78 << ", exception " << ex.what() << std::endl;
81 this->bytesWritten = bytesWritten;
84 socket_->detachEventBase();
87 std::shared_ptr<AsyncSSLSocket> socket_;
90 AsyncSocketException exception;
93 class ReadCallbackBase :
94 public AsyncTransportWrapper::ReadCallback {
96 explicit ReadCallbackBase(WriteCallbackBase *wcb)
98 , state(STATE_WAITING) {}
100 ~ReadCallbackBase() {
101 EXPECT_EQ(state, STATE_SUCCEEDED);
105 const std::shared_ptr<AsyncSSLSocket> &socket) {
109 void setState(StateEnum s) {
117 const AsyncSocketException& ex) noexcept override {
118 std::cerr << "readError " << ex.what() << std::endl;
119 state = STATE_FAILED;
121 socket_->detachEventBase();
124 void readEOF() noexcept override {
125 std::cerr << "readEOF" << std::endl;
128 socket_->detachEventBase();
131 std::shared_ptr<AsyncSSLSocket> socket_;
132 WriteCallbackBase *wcb_;
136 class ReadCallback : public ReadCallbackBase {
138 explicit ReadCallback(WriteCallbackBase *wcb)
139 : ReadCallbackBase(wcb)
143 for (std::vector<Buffer>::iterator it = buffers.begin();
148 currentBuffer.free();
151 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
152 if (!currentBuffer.buffer) {
153 currentBuffer.allocate(4096);
155 *bufReturn = currentBuffer.buffer;
156 *lenReturn = currentBuffer.length;
159 void readDataAvailable(size_t len) noexcept override {
160 std::cerr << "readDataAvailable, len " << len << std::endl;
162 currentBuffer.length = len;
164 wcb_->setSocket(socket_);
166 // Write back the same data.
167 socket_->write(wcb_, currentBuffer.buffer, len);
169 buffers.push_back(currentBuffer);
170 currentBuffer.reset();
171 state = STATE_SUCCEEDED;
176 Buffer() : buffer(nullptr), length(0) {}
177 Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
183 void allocate(size_t length) {
184 assert(buffer == nullptr);
185 this->buffer = static_cast<char*>(malloc(length));
186 this->length = length;
197 std::vector<Buffer> buffers;
198 Buffer currentBuffer;
201 class ReadErrorCallback : public ReadCallbackBase {
203 explicit ReadErrorCallback(WriteCallbackBase *wcb)
204 : ReadCallbackBase(wcb) {}
206 // Return nullptr buffer to trigger readError()
207 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
208 *bufReturn = nullptr;
212 void readDataAvailable(size_t len) noexcept override {
213 // This should never to called.
218 const AsyncSocketException& ex) noexcept override {
219 ReadCallbackBase::readErr(ex);
220 std::cerr << "ReadErrorCallback::readError" << std::endl;
221 setState(STATE_SUCCEEDED);
225 class WriteErrorCallback : public ReadCallback {
227 explicit WriteErrorCallback(WriteCallbackBase *wcb)
228 : ReadCallback(wcb) {}
230 void readDataAvailable(size_t len) noexcept override {
231 std::cerr << "readDataAvailable, len " << len << std::endl;
233 currentBuffer.length = len;
235 // close the socket before writing to trigger writeError().
236 ::close(socket_->getFd());
238 wcb_->setSocket(socket_);
240 // Write back the same data.
241 socket_->write(wcb_, currentBuffer.buffer, len);
243 if (wcb_->state == STATE_FAILED) {
244 setState(STATE_SUCCEEDED);
246 state = STATE_FAILED;
249 buffers.push_back(currentBuffer);
250 currentBuffer.reset();
253 void readErr(const AsyncSocketException& ex) noexcept override {
254 std::cerr << "readError " << ex.what() << std::endl;
255 // do nothing since this is expected
259 class EmptyReadCallback : public ReadCallback {
261 explicit EmptyReadCallback()
262 : ReadCallback(nullptr) {}
264 void readErr(const AsyncSocketException& ex) noexcept override {
265 std::cerr << "readError " << ex.what() << std::endl;
266 state = STATE_FAILED;
268 tcpSocket_->detachEventBase();
271 void readEOF() noexcept override {
272 std::cerr << "readEOF" << std::endl;
275 tcpSocket_->detachEventBase();
276 state = STATE_SUCCEEDED;
279 std::shared_ptr<AsyncSocket> tcpSocket_;
282 class HandshakeCallback :
283 public AsyncSSLSocket::HandshakeCB {
290 explicit HandshakeCallback(ReadCallbackBase *rcb,
291 ExpectType expect = EXPECT_SUCCESS):
292 state(STATE_WAITING),
297 const std::shared_ptr<AsyncSSLSocket> &socket) {
301 void setState(StateEnum s) {
306 // Functions inherited from AsyncSSLSocketHandshakeCallback
307 void handshakeSuc(AsyncSSLSocket *sock) noexcept override {
308 EXPECT_EQ(sock, socket_.get());
309 std::cerr << "HandshakeCallback::connectionAccepted" << std::endl;
310 rcb_->setSocket(socket_);
311 sock->setReadCB(rcb_);
312 state = (expect_ == EXPECT_SUCCESS) ? STATE_SUCCEEDED : STATE_FAILED;
315 AsyncSSLSocket *sock,
316 const AsyncSocketException& ex) noexcept override {
317 std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl;
318 state = (expect_ == EXPECT_ERROR) ? STATE_SUCCEEDED : STATE_FAILED;
319 if (expect_ == EXPECT_ERROR) {
320 // rcb will never be invoked
321 rcb_->setState(STATE_SUCCEEDED);
325 ~HandshakeCallback() {
326 EXPECT_EQ(state, STATE_SUCCEEDED);
331 state = STATE_SUCCEEDED;
335 std::shared_ptr<AsyncSSLSocket> socket_;
336 ReadCallbackBase *rcb_;
340 class SSLServerAcceptCallbackBase:
341 public folly::AsyncServerSocket::AcceptCallback {
343 explicit SSLServerAcceptCallbackBase(HandshakeCallback *hcb):
344 state(STATE_WAITING), hcb_(hcb) {}
346 ~SSLServerAcceptCallbackBase() {
347 EXPECT_EQ(state, STATE_SUCCEEDED);
350 void acceptError(const std::exception& ex) noexcept override {
351 std::cerr << "SSLServerAcceptCallbackBase::acceptError "
352 << ex.what() << std::endl;
353 state = STATE_FAILED;
356 void connectionAccepted(int fd, const folly::SocketAddress& clientAddr)
358 printf("Connection accepted\n");
359 std::shared_ptr<AsyncSSLSocket> sslSock;
361 // Create a AsyncSSLSocket object with the fd. The socket should be
362 // added to the event base and in the state of accepting SSL connection.
363 sslSock = AsyncSSLSocket::newSocket(ctx_, base_, fd);
364 } catch (const std::exception &e) {
365 LOG(ERROR) << "Exception %s caught while creating a AsyncSSLSocket "
366 "object with socket " << e.what() << fd;
372 connAccepted(sslSock);
375 virtual void connAccepted(
376 const std::shared_ptr<folly::AsyncSSLSocket> &s) = 0;
379 HandshakeCallback *hcb_;
380 std::shared_ptr<folly::SSLContext> ctx_;
381 folly::EventBase* base_;
384 class SSLServerAcceptCallback: public SSLServerAcceptCallbackBase {
388 explicit SSLServerAcceptCallback(HandshakeCallback *hcb,
389 uint32_t timeout = 0):
390 SSLServerAcceptCallbackBase(hcb),
393 virtual ~SSLServerAcceptCallback() {
395 // if we set a timeout, we expect failure
396 EXPECT_EQ(hcb_->state, STATE_FAILED);
397 hcb_->setState(STATE_SUCCEEDED);
401 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
403 const std::shared_ptr<folly::AsyncSSLSocket> &s)
405 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
406 std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
408 hcb_->setSocket(sock);
409 sock->sslAccept(hcb_, timeout_);
410 EXPECT_EQ(sock->getSSLState(),
411 AsyncSSLSocket::STATE_ACCEPTING);
413 state = STATE_SUCCEEDED;
417 class SSLServerAcceptCallbackDelay: public SSLServerAcceptCallback {
419 explicit SSLServerAcceptCallbackDelay(HandshakeCallback *hcb):
420 SSLServerAcceptCallback(hcb) {}
422 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
424 const std::shared_ptr<folly::AsyncSSLSocket> &s)
427 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
429 std::cerr << "SSLServerAcceptCallbackDelay::connAccepted"
431 int fd = sock->getFd();
435 // The accepted connection should already have TCP_NODELAY set
437 socklen_t valueLength = sizeof(value);
438 int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
444 // Unset the TCP_NODELAY option.
446 socklen_t valueLength = sizeof(value);
447 int rc = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, valueLength);
450 rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
454 SSLServerAcceptCallback::connAccepted(sock);
458 class SSLServerAsyncCacheAcceptCallback: public SSLServerAcceptCallback {
460 explicit SSLServerAsyncCacheAcceptCallback(HandshakeCallback *hcb,
461 uint32_t timeout = 0):
462 SSLServerAcceptCallback(hcb, timeout) {}
464 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
466 const std::shared_ptr<folly::AsyncSSLSocket> &s)
468 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
470 std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
472 hcb_->setSocket(sock);
473 sock->sslAccept(hcb_, timeout_);
474 ASSERT_TRUE((sock->getSSLState() ==
475 AsyncSSLSocket::STATE_ACCEPTING) ||
476 (sock->getSSLState() ==
477 AsyncSSLSocket::STATE_CACHE_LOOKUP));
479 state = STATE_SUCCEEDED;
484 class HandshakeErrorCallback: public SSLServerAcceptCallbackBase {
486 explicit HandshakeErrorCallback(HandshakeCallback *hcb):
487 SSLServerAcceptCallbackBase(hcb) {}
489 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
491 const std::shared_ptr<folly::AsyncSSLSocket> &s)
493 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
495 std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
497 // The first call to sslAccept() should succeed.
498 hcb_->setSocket(sock);
499 sock->sslAccept(hcb_);
500 EXPECT_EQ(sock->getSSLState(),
501 AsyncSSLSocket::STATE_ACCEPTING);
503 // The second call to sslAccept() should fail.
504 HandshakeCallback callback2(hcb_->rcb_);
505 callback2.setSocket(sock);
506 sock->sslAccept(&callback2);
507 EXPECT_EQ(sock->getSSLState(),
508 AsyncSSLSocket::STATE_ERROR);
510 // Both callbacks should be in the error state.
511 EXPECT_EQ(hcb_->state, STATE_FAILED);
512 EXPECT_EQ(callback2.state, STATE_FAILED);
514 sock->detachEventBase();
516 state = STATE_SUCCEEDED;
517 hcb_->setState(STATE_SUCCEEDED);
518 callback2.setState(STATE_SUCCEEDED);
522 class HandshakeTimeoutCallback: public SSLServerAcceptCallbackBase {
524 explicit HandshakeTimeoutCallback(HandshakeCallback *hcb):
525 SSLServerAcceptCallbackBase(hcb) {}
527 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
529 const std::shared_ptr<folly::AsyncSSLSocket> &s)
531 std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
533 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
535 hcb_->setSocket(sock);
536 sock->getEventBase()->tryRunAfterDelay([=] {
537 std::cerr << "Delayed SSL accept, client will have close by now"
539 // SSL accept will fail
542 AsyncSSLSocket::STATE_UNINIT);
543 hcb_->socket_->sslAccept(hcb_);
544 // This registers for an event
547 AsyncSSLSocket::STATE_ACCEPTING);
549 state = STATE_SUCCEEDED;
555 class TestSSLServer {
558 std::shared_ptr<folly::SSLContext> ctx_;
559 SSLServerAcceptCallbackBase *acb_;
560 std::shared_ptr<folly::AsyncServerSocket> socket_;
561 folly::SocketAddress address_;
564 static void *Main(void *ctx) {
565 TestSSLServer *self = static_cast<TestSSLServer*>(ctx);
567 std::cerr << "Server thread exited event loop" << std::endl;
572 // Create a TestSSLServer.
573 // This immediately starts listening on the given port.
574 explicit TestSSLServer(SSLServerAcceptCallbackBase *acb);
578 evb_.runInEventBaseThread([&](){
579 socket_->stopAccepting();
581 std::cerr << "Waiting for server thread to exit" << std::endl;
582 pthread_join(thread_, nullptr);
585 EventBase &getEventBase() { return evb_; }
587 const folly::SocketAddress& getAddress() const {
592 class TestSSLAsyncCacheServer : public TestSSLServer {
594 explicit TestSSLAsyncCacheServer(SSLServerAcceptCallbackBase *acb,
595 int lookupDelay = 100) :
597 SSL_CTX *sslCtx = ctx_->getSSLCtx();
598 SSL_CTX_sess_set_get_cb(sslCtx,
599 TestSSLAsyncCacheServer::getSessionCallback);
600 SSL_CTX_set_session_cache_mode(
601 sslCtx, SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_SERVER);
604 lookupDelay_ = lookupDelay;
607 uint32_t getAsyncCallbacks() const { return asyncCallbacks_; }
608 uint32_t getAsyncLookups() const { return asyncLookups_; }
611 static uint32_t asyncCallbacks_;
612 static uint32_t asyncLookups_;
613 static uint32_t lookupDelay_;
615 static SSL_SESSION *getSessionCallback(SSL *ssl,
616 unsigned char *sess_id,
621 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
622 if (!SSL_want_sess_cache_lookup(ssl)) {
623 // libssl.so mismatch
624 std::cerr << "no async support" << std::endl;
628 AsyncSSLSocket *sslSocket =
629 AsyncSSLSocket::getFromSSL(ssl);
630 assert(sslSocket != nullptr);
631 // Going to simulate an async cache by just running delaying the miss 100ms
632 if (asyncCallbacks_ % 2 == 0) {
633 // This socket is already blocked on lookup, return miss
634 std::cerr << "returning miss" << std::endl;
636 // fresh meat - block it
637 std::cerr << "async lookup" << std::endl;
638 sslSocket->getEventBase()->tryRunAfterDelay(
639 std::bind(&AsyncSSLSocket::restartSSLAccept,
640 sslSocket), lookupDelay_);
641 *copyflag = SSL_SESSION_CB_WOULD_BLOCK;
649 void getfds(int fds[2]);
652 std::shared_ptr<folly::SSLContext> clientCtx,
653 std::shared_ptr<folly::SSLContext> serverCtx);
656 EventBase* eventBase,
657 AsyncSSLSocket::UniquePtr* clientSock,
658 AsyncSSLSocket::UniquePtr* serverSock);
660 class BlockingWriteClient :
661 private AsyncSSLSocket::HandshakeCB,
662 private AsyncTransportWrapper::WriteCallback {
664 explicit BlockingWriteClient(
665 AsyncSSLSocket::UniquePtr socket)
666 : socket_(std::move(socket)),
670 buf_.reset(new uint8_t[bufLen_]);
671 for (uint32_t n = 0; n < sizeof(buf_); ++n) {
676 iov_.reset(new struct iovec[iovCount_]);
677 for (uint32_t n = 0; n < iovCount_; ++n) {
678 iov_[n].iov_base = buf_.get() + n;
680 iov_[n].iov_len = n % bufLen_;
682 iov_[n].iov_len = bufLen_ - (n % bufLen_);
686 socket_->sslConn(this, 100);
689 struct iovec* getIovec() const {
692 uint32_t getIovecCount() const {
697 void handshakeSuc(AsyncSSLSocket*) noexcept override {
698 socket_->writev(this, iov_.get(), iovCount_);
702 const AsyncSocketException& ex) noexcept override {
703 ADD_FAILURE() << "client handshake error: " << ex.what();
705 void writeSuccess() noexcept override {
710 const AsyncSocketException& ex) noexcept override {
711 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
715 AsyncSSLSocket::UniquePtr socket_;
718 std::unique_ptr<uint8_t[]> buf_;
719 std::unique_ptr<struct iovec[]> iov_;
722 class BlockingWriteServer :
723 private AsyncSSLSocket::HandshakeCB,
724 private AsyncTransportWrapper::ReadCallback {
726 explicit BlockingWriteServer(
727 AsyncSSLSocket::UniquePtr socket)
728 : socket_(std::move(socket)),
729 bufSize_(2500 * 2000),
731 buf_.reset(new uint8_t[bufSize_]);
732 socket_->sslAccept(this, 100);
735 void checkBuffer(struct iovec* iov, uint32_t count) const {
737 for (uint32_t n = 0; n < count; ++n) {
738 size_t bytesLeft = bytesRead_ - idx;
739 int rc = memcmp(buf_.get() + idx, iov[n].iov_base,
740 std::min(iov[n].iov_len, bytesLeft));
742 FAIL() << "buffer mismatch at iovec " << n << "/" << count
746 if (iov[n].iov_len > bytesLeft) {
747 FAIL() << "server did not read enough data: "
748 << "ended at byte " << bytesLeft << "/" << iov[n].iov_len
749 << " in iovec " << n << "/" << count;
752 idx += iov[n].iov_len;
754 if (idx != bytesRead_) {
755 ADD_FAILURE() << "server read extra data: " << bytesRead_
756 << " bytes read; expected " << idx;
761 void handshakeSuc(AsyncSSLSocket*) noexcept override {
762 // Wait 10ms before reading, so the client's writes will initially block.
763 socket_->getEventBase()->tryRunAfterDelay(
764 [this] { socket_->setReadCB(this); }, 10);
768 const AsyncSocketException& ex) noexcept override {
769 ADD_FAILURE() << "server handshake error: " << ex.what();
771 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
772 *bufReturn = buf_.get() + bytesRead_;
773 *lenReturn = bufSize_ - bytesRead_;
775 void readDataAvailable(size_t len) noexcept override {
777 socket_->setReadCB(nullptr);
778 socket_->getEventBase()->tryRunAfterDelay(
779 [this] { socket_->setReadCB(this); }, 2);
781 void readEOF() noexcept override {
785 const AsyncSocketException& ex) noexcept override {
786 ADD_FAILURE() << "server read error: " << ex.what();
789 AsyncSSLSocket::UniquePtr socket_;
792 std::unique_ptr<uint8_t[]> buf_;
796 private AsyncSSLSocket::HandshakeCB,
797 private AsyncTransportWrapper::WriteCallback {
800 AsyncSSLSocket::UniquePtr socket)
801 : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
802 socket_->sslConn(this);
805 const unsigned char* nextProto;
806 unsigned nextProtoLength;
808 void handshakeSuc(AsyncSSLSocket*) noexcept override {
809 socket_->getSelectedNextProtocol(&nextProto,
814 const AsyncSocketException& ex) noexcept override {
815 ADD_FAILURE() << "client handshake error: " << ex.what();
817 void writeSuccess() noexcept override {
822 const AsyncSocketException& ex) noexcept override {
823 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
827 AsyncSSLSocket::UniquePtr socket_;
831 private AsyncSSLSocket::HandshakeCB,
832 private AsyncTransportWrapper::ReadCallback {
834 explicit NpnServer(AsyncSSLSocket::UniquePtr socket)
835 : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
836 socket_->sslAccept(this);
839 const unsigned char* nextProto;
840 unsigned nextProtoLength;
842 void handshakeSuc(AsyncSSLSocket*) noexcept override {
843 socket_->getSelectedNextProtocol(&nextProto,
848 const AsyncSocketException& ex) noexcept override {
849 ADD_FAILURE() << "server handshake error: " << ex.what();
851 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
854 void readDataAvailable(size_t len) noexcept override {
856 void readEOF() noexcept override {
860 const AsyncSocketException& ex) noexcept override {
861 ADD_FAILURE() << "server read error: " << ex.what();
864 AsyncSSLSocket::UniquePtr socket_;
867 #ifndef OPENSSL_NO_TLSEXT
869 private AsyncSSLSocket::HandshakeCB,
870 private AsyncTransportWrapper::WriteCallback {
873 AsyncSSLSocket::UniquePtr socket)
874 : serverNameMatch(false), socket_(std::move(socket)) {
875 socket_->sslConn(this);
878 bool serverNameMatch;
881 void handshakeSuc(AsyncSSLSocket*) noexcept override {
882 serverNameMatch = socket_->isServerNameMatch();
886 const AsyncSocketException& ex) noexcept override {
887 ADD_FAILURE() << "client handshake error: " << ex.what();
889 void writeSuccess() noexcept override {
894 const AsyncSocketException& ex) noexcept override {
895 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
899 AsyncSSLSocket::UniquePtr socket_;
903 private AsyncSSLSocket::HandshakeCB,
904 private AsyncTransportWrapper::ReadCallback {
907 AsyncSSLSocket::UniquePtr socket,
908 const std::shared_ptr<folly::SSLContext>& ctx,
909 const std::shared_ptr<folly::SSLContext>& sniCtx,
910 const std::string& expectedServerName)
911 : serverNameMatch(false), socket_(std::move(socket)), sniCtx_(sniCtx),
912 expectedServerName_(expectedServerName) {
913 ctx->setServerNameCallback(std::bind(&SNIServer::serverNameCallback, this,
914 std::placeholders::_1));
915 socket_->sslAccept(this);
918 bool serverNameMatch;
921 void handshakeSuc(AsyncSSLSocket* ssl) noexcept override {}
924 const AsyncSocketException& ex) noexcept override {
925 ADD_FAILURE() << "server handshake error: " << ex.what();
927 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
930 void readDataAvailable(size_t len) noexcept override {
932 void readEOF() noexcept override {
936 const AsyncSocketException& ex) noexcept override {
937 ADD_FAILURE() << "server read error: " << ex.what();
940 folly::SSLContext::ServerNameCallbackResult
941 serverNameCallback(SSL *ssl) {
942 const char *sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
945 !strcasecmp(expectedServerName_.c_str(), sn)) {
946 AsyncSSLSocket *sslSocket =
947 AsyncSSLSocket::getFromSSL(ssl);
948 sslSocket->switchServerSSLContext(sniCtx_);
949 serverNameMatch = true;
950 return folly::SSLContext::SERVER_NAME_FOUND;
952 serverNameMatch = false;
953 return folly::SSLContext::SERVER_NAME_NOT_FOUND;
957 AsyncSSLSocket::UniquePtr socket_;
958 std::shared_ptr<folly::SSLContext> sniCtx_;
959 std::string expectedServerName_;
963 class SSLClient : public AsyncSocket::ConnectCallback,
964 public AsyncTransportWrapper::WriteCallback,
965 public AsyncTransportWrapper::ReadCallback
968 EventBase *eventBase_;
969 std::shared_ptr<AsyncSSLSocket> sslSocket_;
970 SSL_SESSION *session_;
971 std::shared_ptr<folly::SSLContext> ctx_;
973 folly::SocketAddress address_;
981 uint32_t writeAfterConnectErrors_;
983 // These settings test that we eventually drain the
984 // socket, even if the maxReadsPerEvent_ is hit during
985 // a event loop iteration.
986 static constexpr size_t kMaxReadsPerEvent = 2;
987 static constexpr size_t kMaxReadBufferSz =
988 sizeof(readbuf_) / kMaxReadsPerEvent / 2; // 2 event loop iterations
991 SSLClient(EventBase *eventBase,
992 const folly::SocketAddress& address,
994 uint32_t timeout = 0)
995 : eventBase_(eventBase),
1004 writeAfterConnectErrors_(0) {
1005 ctx_.reset(new folly::SSLContext());
1006 ctx_->setOptions(SSL_OP_NO_TICKET);
1007 ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1008 memset(buf_, 'a', sizeof(buf_));
1013 SSL_SESSION_free(session_);
1016 EXPECT_EQ(bytesRead_, sizeof(buf_));
1020 uint32_t getHit() const { return hit_; }
1022 uint32_t getMiss() const { return miss_; }
1024 uint32_t getErrors() const { return errors_; }
1026 uint32_t getWriteAfterConnectErrors() const {
1027 return writeAfterConnectErrors_;
1030 void connect(bool writeNow = false) {
1031 sslSocket_ = AsyncSSLSocket::newSocket(
1033 if (session_ != nullptr) {
1034 sslSocket_->setSSLSession(session_);
1037 sslSocket_->connect(this, address_, timeout_);
1038 if (sslSocket_ && writeNow) {
1039 // write some junk, used in an error test
1040 sslSocket_->write(this, buf_, sizeof(buf_));
1044 void connectSuccess() noexcept override {
1045 std::cerr << "client SSL socket connected" << std::endl;
1046 if (sslSocket_->getSSLSessionReused()) {
1050 if (session_ != nullptr) {
1051 SSL_SESSION_free(session_);
1053 session_ = sslSocket_->getSSLSession();
1057 sslSocket_->setMaxReadsPerEvent(kMaxReadsPerEvent);
1058 sslSocket_->write(this, buf_, sizeof(buf_));
1059 sslSocket_->setReadCB(this);
1060 memset(readbuf_, 'b', sizeof(readbuf_));
1065 const AsyncSocketException& ex) noexcept override {
1066 std::cerr << "SSLClient::connectError: " << ex.what() << std::endl;
1071 void writeSuccess() noexcept override {
1072 std::cerr << "client write success" << std::endl;
1076 size_t bytesWritten,
1077 const AsyncSocketException& ex)
1079 std::cerr << "client writeError: " << ex.what() << std::endl;
1081 writeAfterConnectErrors_++;
1085 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
1086 *bufReturn = readbuf_ + bytesRead_;
1087 *lenReturn = std::min(kMaxReadBufferSz, sizeof(readbuf_) - bytesRead_);
1090 void readEOF() noexcept override {
1091 std::cerr << "client readEOF" << std::endl;
1095 const AsyncSocketException& ex) noexcept override {
1096 std::cerr << "client readError: " << ex.what() << std::endl;
1099 void readDataAvailable(size_t len) noexcept override {
1100 std::cerr << "client read data: " << len << std::endl;
1102 if (bytesRead_ == sizeof(buf_)) {
1103 EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
1104 sslSocket_->closeNow();
1106 if (requests_ != 0) {
1114 class SSLHandshakeBase :
1115 public AsyncSSLSocket::HandshakeCB,
1116 private AsyncTransportWrapper::WriteCallback {
1118 explicit SSLHandshakeBase(
1119 AsyncSSLSocket::UniquePtr socket,
1120 bool preverifyResult,
1121 bool verifyResult) :
1122 handshakeVerify_(false),
1123 handshakeSuccess_(false),
1124 handshakeError_(false),
1125 socket_(std::move(socket)),
1126 preverifyResult_(preverifyResult),
1127 verifyResult_(verifyResult) {
1130 bool handshakeVerify_;
1131 bool handshakeSuccess_;
1132 bool handshakeError_;
1133 std::chrono::nanoseconds handshakeTime;
1136 AsyncSSLSocket::UniquePtr socket_;
1137 bool preverifyResult_;
1140 // HandshakeCallback
1142 AsyncSSLSocket* sock,
1144 X509_STORE_CTX* ctx) noexcept override {
1145 handshakeVerify_ = true;
1147 EXPECT_EQ(preverifyResult_, preverifyOk);
1148 return verifyResult_;
1151 void handshakeSuc(AsyncSSLSocket*) noexcept override {
1152 handshakeSuccess_ = true;
1153 handshakeTime = socket_->getHandshakeTime();
1158 const AsyncSocketException& ex) noexcept override {
1159 handshakeError_ = true;
1160 handshakeTime = socket_->getHandshakeTime();
1164 void writeSuccess() noexcept override {
1169 size_t bytesWritten,
1170 const AsyncSocketException& ex) noexcept override {
1171 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
1176 class SSLHandshakeClient : public SSLHandshakeBase {
1179 AsyncSSLSocket::UniquePtr socket,
1180 bool preverifyResult,
1181 bool verifyResult) :
1182 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1183 socket_->sslConn(this, 0);
1187 class SSLHandshakeClientNoVerify : public SSLHandshakeBase {
1189 SSLHandshakeClientNoVerify(
1190 AsyncSSLSocket::UniquePtr socket,
1191 bool preverifyResult,
1192 bool verifyResult) :
1193 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1194 socket_->sslConn(this, 0,
1195 folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1199 class SSLHandshakeClientDoVerify : public SSLHandshakeBase {
1201 SSLHandshakeClientDoVerify(
1202 AsyncSSLSocket::UniquePtr socket,
1203 bool preverifyResult,
1204 bool verifyResult) :
1205 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1206 socket_->sslConn(this, 0,
1207 folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
1211 class SSLHandshakeServer : public SSLHandshakeBase {
1214 AsyncSSLSocket::UniquePtr socket,
1215 bool preverifyResult,
1217 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1218 socket_->sslAccept(this, 0);
1222 class SSLHandshakeServerParseClientHello : public SSLHandshakeBase {
1224 SSLHandshakeServerParseClientHello(
1225 AsyncSSLSocket::UniquePtr socket,
1226 bool preverifyResult,
1228 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1229 socket_->enableClientHelloParsing();
1230 socket_->sslAccept(this, 0);
1233 std::string clientCiphers_, sharedCiphers_, serverCiphers_, chosenCipher_;
1236 void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
1237 handshakeSuccess_ = true;
1238 sock->getSSLSharedCiphers(sharedCiphers_);
1239 sock->getSSLServerCiphers(serverCiphers_);
1240 sock->getSSLClientCiphers(clientCiphers_);
1241 chosenCipher_ = sock->getNegotiatedCipherName();
1246 class SSLHandshakeServerNoVerify : public SSLHandshakeBase {
1248 SSLHandshakeServerNoVerify(
1249 AsyncSSLSocket::UniquePtr socket,
1250 bool preverifyResult,
1252 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1253 socket_->sslAccept(this, 0,
1254 folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1258 class SSLHandshakeServerDoVerify : public SSLHandshakeBase {
1260 SSLHandshakeServerDoVerify(
1261 AsyncSSLSocket::UniquePtr socket,
1262 bool preverifyResult,
1264 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1265 socket_->sslAccept(this, 0,
1266 folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1270 class EventBaseAborter : public AsyncTimeout {
1272 EventBaseAborter(EventBase* eventBase,
1275 eventBase, AsyncTimeout::InternalEnum::INTERNAL)
1276 , eventBase_(eventBase) {
1277 scheduleTimeout(timeoutMS);
1280 void timeoutExpired() noexcept override {
1281 FAIL() << "test timed out";
1282 eventBase_->terminateLoopSoon();
1286 EventBase* eventBase_;