2 * Copyright 2016 Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
21 #include <folly/io/async/AsyncServerSocket.h>
22 #include <folly/io/async/AsyncSSLSocket.h>
23 #include <folly/io/async/AsyncSocket.h>
24 #include <folly/io/async/AsyncTransport.h>
25 #include <folly/io/async/EventBase.h>
26 #include <folly/io/async/AsyncTimeout.h>
27 #include <folly/SocketAddress.h>
29 #include <gtest/gtest.h>
35 #include <sys/types.h>
36 #include <sys/socket.h>
37 #include <netinet/tcp.h>
47 // The destructors of all callback classes assert that the state is
48 // STATE_SUCCEEDED, for both positive and negative tests. The tests
49 // are responsible for setting the succeeded state properly before the
50 // destructors are called.
// Write callback used by the test read callbacks. It records the outcome of a
// write in `state` and, on error, the number of bytes reported as written.
// NOTE(review): several lines of this class (constructor header, setSocket
// and writeErr signatures, closing braces) are not present in this excerpt.
52 class WriteCallbackBase :
53 public AsyncTransportWrapper::WriteCallback {
// Starts in STATE_WAITING with a placeholder "none" exception.
56 : state(STATE_WAITING)
58 , exception(AsyncSocketException::UNKNOWN, "none") {}
// The destructor checks that the test left this callback in STATE_SUCCEEDED.
60 ~WriteCallbackBase() {
61 EXPECT_EQ(state, STATE_SUCCEEDED);
// Parameter of the socket setter (used as wcb_->setSocket(...) elsewhere).
65 const std::shared_ptr<AsyncSSLSocket> &socket) {
// Successful write: log and mark the callback as succeeded.
69 void writeSuccess() noexcept override {
70 std::cerr << "writeSuccess" << std::endl;
71 state = STATE_SUCCEEDED;
// Failed write: log the error, remember the byte count and detach the socket
// from its event base.
76 const AsyncSocketException& ex) noexcept override {
77 std::cerr << "writeError: bytesWritten " << bytesWritten
78 << ", exception " << ex.what() << std::endl;
81 this->bytesWritten = bytesWritten;
84 socket_->detachEventBase();
// Socket the write was issued on.
87 std::shared_ptr<AsyncSSLSocket> socket_;
// Exception slot; initialized to the "none" placeholder in the constructor.
90 AsyncSocketException exception;
// Common base for the test read callbacks. It keeps the socket being read,
// the associated write callback and a `state` that the destructor checks.
// NOTE(review): some signature lines and closing braces are not present in
// this excerpt.
93 class ReadCallbackBase :
94 public AsyncTransportWrapper::ReadCallback {
// wcb: write callback handed to derived classes through wcb_ (not owned here,
// as far as this excerpt shows).
96 explicit ReadCallbackBase(WriteCallbackBase *wcb)
98 , state(STATE_WAITING) {}
// The destructor checks that the test left this callback in STATE_SUCCEEDED.
100 ~ReadCallbackBase() {
101 EXPECT_EQ(state, STATE_SUCCEEDED);
// Parameter of the socket setter (called as rcb_->setSocket(...) elsewhere).
105 const std::shared_ptr<AsyncSSLSocket> &socket) {
// Lets other callbacks override the recorded state.
109 void setState(StateEnum s) {
// Read error: log, mark STATE_FAILED and detach the socket from its event base.
117 const AsyncSocketException& ex) noexcept override {
118 std::cerr << "readError " << ex.what() << std::endl;
119 state = STATE_FAILED;
121 socket_->detachEventBase();
// EOF: log and detach the socket from its event base.
124 void readEOF() noexcept override {
125 std::cerr << "readEOF" << std::endl;
128 socket_->detachEventBase();
131 std::shared_ptr<AsyncSSLSocket> socket_;
132 WriteCallbackBase *wcb_;
// Echoing read callback: every chunk received is written straight back on the
// same socket through wcb_, and the chunk is kept in `buffers`.
// NOTE(review): parts of this class (destructor header, Buffer struct header,
// reset()/free() bodies, closing braces) are not present in this excerpt.
136 class ReadCallback : public ReadCallbackBase {
138 explicit ReadCallback(WriteCallbackBase *wcb)
139 : ReadCallbackBase(wcb)
// Walks the stored buffers; presumably releases them (loop body not shown),
// then frees the buffer currently in use.
143 for (std::vector<Buffer>::iterator it = buffers.begin();
148 currentBuffer.free();
// Hands out the current buffer, allocating 4096 bytes on first use.
151 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
152 if (!currentBuffer.buffer) {
153 currentBuffer.allocate(4096);
155 *bufReturn = currentBuffer.buffer;
156 *lenReturn = currentBuffer.length;
// Records how much was read, echoes it back, archives the buffer and marks
// the callback as succeeded.
159 void readDataAvailable(size_t len) noexcept override {
160 std::cerr << "readDataAvailable, len " << len << std::endl;
162 currentBuffer.length = len;
164 wcb_->setSocket(socket_);
166 // Write back the same data.
167 socket_->write(wcb_, currentBuffer.buffer, len);
169 buffers.push_back(currentBuffer);
170 currentBuffer.reset();
171 state = STATE_SUCCEEDED;
// Buffer: raw pointer plus length.
176 Buffer() : buffer(nullptr), length(0) {}
177 Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
// Allocates `length` bytes with malloc; must only be called on an empty Buffer.
// NOTE(review): the malloc result is not checked for nullptr.
183 void allocate(size_t length) {
184 assert(buffer == nullptr);
185 this->buffer = static_cast<char*>(malloc(length));
186 this->length = length;
// Buffers already filled and echoed, and the one currently being filled.
197 std::vector<Buffer> buffers;
198 Buffer currentBuffer;
// Negative-test read callback: it deliberately provokes a read error and
// treats that error as the successful outcome.
// NOTE(review): some lines (e.g. the readErr signature start, closing braces)
// are not present in this excerpt.
201 class ReadErrorCallback : public ReadCallbackBase {
203 explicit ReadErrorCallback(WriteCallbackBase *wcb)
204 : ReadCallbackBase(wcb) {}
206 // Return nullptr buffer to trigger readError()
207 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
208 *bufReturn = nullptr;
212 void readDataAvailable(size_t /* len */) noexcept override {
213 // This should never be called.
// Runs the base handling first (which sets STATE_FAILED), then overrides the
// state to STATE_SUCCEEDED because the error is the expected result.
218 const AsyncSocketException& ex) noexcept override {
219 ReadCallbackBase::readErr(ex);
220 std::cerr << "ReadErrorCallback::readError" << std::endl;
221 setState(STATE_SUCCEEDED);
// Negative-test read callback: it closes the socket's file descriptor before
// echoing the data so that the write fails, and treats that failure as success.
// NOTE(review): some lines (e.g. an else branch, closing braces) are not
// present in this excerpt.
225 class WriteErrorCallback : public ReadCallback {
227 explicit WriteErrorCallback(WriteCallbackBase *wcb)
228 : ReadCallback(wcb) {}
230 void readDataAvailable(size_t len) noexcept override {
231 std::cerr << "readDataAvailable, len " << len << std::endl;
233 currentBuffer.length = len;
235 // close the socket before writing to trigger writeError().
// The raw fd is closed behind the AsyncSSLSocket's back on purpose.
236 ::close(socket_->getFd());
238 wcb_->setSocket(socket_);
240 // Write back the same data.
241 socket_->write(wcb_, currentBuffer.buffer, len);
// The test passes only if the write callback already observed a failure.
243 if (wcb_->state == STATE_FAILED) {
244 setState(STATE_SUCCEEDED);
246 state = STATE_FAILED;
249 buffers.push_back(currentBuffer);
250 currentBuffer.reset();
// Read errors are expected after the fd was closed; only log them.
253 void readErr(const AsyncSocketException& ex) noexcept override {
254 std::cerr << "readError " << ex.what() << std::endl;
255 // do nothing since this is expected
// Read callback for a plain TCP socket (tcpSocket_) with no write callback:
// EOF counts as success, a read error as failure.
// NOTE(review): some lines (closing braces and possibly more) are not present
// in this excerpt.
259 class EmptyReadCallback : public ReadCallback {
// No write callback is attached (nullptr is passed to the base).
261 explicit EmptyReadCallback()
262 : ReadCallback(nullptr) {}
// Error: log, mark STATE_FAILED and detach the TCP socket from its event base.
264 void readErr(const AsyncSocketException& ex) noexcept override {
265 std::cerr << "readError " << ex.what() << std::endl;
266 state = STATE_FAILED;
268 tcpSocket_->detachEventBase();
// EOF: log, detach the TCP socket and mark STATE_SUCCEEDED.
271 void readEOF() noexcept override {
272 std::cerr << "readEOF" << std::endl;
275 tcpSocket_->detachEventBase();
276 state = STATE_SUCCEEDED;
279 std::shared_ptr<AsyncSocket> tcpSocket_;
// Server-side SSL handshake callback. It is built with the outcome the test
// expects (EXPECT_SUCCESS or EXPECT_ERROR) and sets `state` to STATE_SUCCEEDED
// only when the actual outcome matches that expectation.
// NOTE(review): some lines (enum declaration, remaining initializers, setter
// bodies, closing braces) are not present in this excerpt.
282 class HandshakeCallback :
283 public AsyncSSLSocket::HandshakeCB {
// rcb: read callback installed on the socket once the handshake succeeds.
// expect: outcome the test expects; defaults to EXPECT_SUCCESS.
290 explicit HandshakeCallback(ReadCallbackBase *rcb,
291 ExpectType expect = EXPECT_SUCCESS):
292 state(STATE_WAITING),
// Parameter of the socket setter.
297 const std::shared_ptr<AsyncSSLSocket> &socket) {
301 void setState(StateEnum s) {
306 // Functions inherited from AsyncSSLSocketHandshakeCallback
// Success: verify it is the expected socket, wire up the read callback and
// record whether success was the expected outcome.
307 void handshakeSuc(AsyncSSLSocket *sock) noexcept override {
308 EXPECT_EQ(sock, socket_.get());
309 std::cerr << "HandshakeCallback::connectionAccepted" << std::endl;
310 rcb_->setSocket(socket_);
311 sock->setReadCB(rcb_);
312 state = (expect_ == EXPECT_SUCCESS) ? STATE_SUCCEEDED : STATE_FAILED;
// Failure: record whether failure was the expected outcome.
314 void handshakeErr(AsyncSSLSocket* /* sock */,
315 const AsyncSocketException& ex) noexcept override {
316 std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl;
317 state = (expect_ == EXPECT_ERROR) ? STATE_SUCCEEDED : STATE_FAILED;
318 if (expect_ == EXPECT_ERROR) {
319 // rcb will never be invoked
// Mark the read callback as succeeded so its destructor check passes.
320 rcb_->setState(STATE_SUCCEEDED);
// The destructor checks that the callback ended in STATE_SUCCEEDED.
324 ~HandshakeCallback() {
325 EXPECT_EQ(state, STATE_SUCCEEDED);
// NOTE(review): the enclosing function of this assignment is not shown.
330 state = STATE_SUCCEEDED;
334 std::shared_ptr<AsyncSSLSocket> socket_;
335 ReadCallbackBase *rcb_;
// Base accept callback for the test SSL server. It wraps each accepted fd in
// an AsyncSSLSocket and forwards it to the pure-virtual connAccepted().
// NOTE(review): some lines (the try opener, the catch body, closing braces)
// are not present in this excerpt.
339 class SSLServerAcceptCallbackBase:
340 public folly::AsyncServerSocket::AcceptCallback {
// hcb: handshake callback that derived classes pass to sslAccept().
342 explicit SSLServerAcceptCallbackBase(HandshakeCallback *hcb):
343 state(STATE_WAITING), hcb_(hcb) {}
// The destructor checks that the callback ended in STATE_SUCCEEDED.
345 ~SSLServerAcceptCallbackBase() {
346 EXPECT_EQ(state, STATE_SUCCEEDED);
// Accept failure: log and mark STATE_FAILED.
349 void acceptError(const std::exception& ex) noexcept override {
350 std::cerr << "SSLServerAcceptCallbackBase::acceptError "
351 << ex.what() << std::endl;
352 state = STATE_FAILED;
// New TCP connection: build an AsyncSSLSocket on fd using ctx_ and base_,
// then hand it to connAccepted().
355 void connectionAccepted(
356 int fd, const folly::SocketAddress& /* clientAddr */) noexcept override {
357 printf("Connection accepted\n");
358 std::shared_ptr<AsyncSSLSocket> sslSock;
360 // Create an AsyncSSLSocket object with the fd. The socket should be
361 // added to the event base and in the state of accepting SSL connection.
362 sslSock = AsyncSSLSocket::newSocket(ctx_, base_, fd);
363 } catch (const std::exception &e) {
// NOTE(review): the message holds a printf-style "%s" but is streamed, so
// "%s" is printed literally and e.what() runs straight into fd with no
// separator.
364 LOG(ERROR) << "Exception %s caught while creating a AsyncSSLSocket "
365 "object with socket " << e.what() << fd;
371 connAccepted(sslSock);
// Hook implemented by derived classes to start (or mis-start) the handshake.
374 virtual void connAccepted(
375 const std::shared_ptr<folly::AsyncSSLSocket> &s) = 0;
378 HandshakeCallback *hcb_;
379 std::shared_ptr<folly::SSLContext> ctx_;
380 folly::EventBase* base_;
// Standard accept callback: starts an SSL accept on the new socket with an
// optional timeout.
// NOTE(review): some lines (timeout_ initializer, the destructor's condition,
// the connAccepted signature start, closing braces) are not present in this
// excerpt.
383 class SSLServerAcceptCallback: public SSLServerAcceptCallbackBase {
// timeout: passed to sslAccept(); defaults to 0.
387 explicit SSLServerAcceptCallback(HandshakeCallback *hcb,
388 uint32_t timeout = 0):
389 SSLServerAcceptCallbackBase(hcb),
392 virtual ~SSLServerAcceptCallback() {
394 // if we set a timeout, we expect failure
// The handshake callback must have failed; it is then flipped to
// STATE_SUCCEEDED so its own destructor check passes. The guarding condition
// is not shown here; presumably it tests the timeout (confirm).
395 EXPECT_EQ(hcb_->state, STATE_FAILED);
396 hcb_->setState(STATE_SUCCEEDED);
400 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
// Starts the SSL accept and checks the socket entered STATE_ACCEPTING.
402 const std::shared_ptr<folly::AsyncSSLSocket> &s)
404 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
405 std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
407 hcb_->setSocket(sock);
408 sock->sslAccept(hcb_, timeout_);
409 EXPECT_EQ(sock->getSSLState(),
410 AsyncSSLSocket::STATE_ACCEPTING);
412 state = STATE_SUCCEEDED;
// Accept callback that inspects and clears TCP_NODELAY on the accepted fd
// before delegating to SSLServerAcceptCallback::connAccepted().
// NOTE(review): some lines (declarations of `value`, the checks on `rc`,
// braces) are not present in this excerpt.
416 class SSLServerAcceptCallbackDelay: public SSLServerAcceptCallback {
418 explicit SSLServerAcceptCallbackDelay(HandshakeCallback *hcb):
419 SSLServerAcceptCallback(hcb) {}
421 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
423 const std::shared_ptr<folly::AsyncSSLSocket> &s)
426 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
428 std::cerr << "SSLServerAcceptCallbackDelay::connAccepted"
430 int fd = sock->getFd();
434 // The accepted connection should already have TCP_NODELAY set
436 socklen_t valueLength = sizeof(value);
437 int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
443 // Unset the TCP_NODELAY option.
445 socklen_t valueLength = sizeof(value);
446 int rc = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, valueLength);
// Read the option back after the change.
449 rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
// Continue with the normal SSL accept.
453 SSLServerAcceptCallback::connAccepted(sock);
// Accept callback for the async-session-cache tests: like
// SSLServerAcceptCallback, but also accepts STATE_CACHE_LOOKUP right after
// sslAccept().
// NOTE(review): the connAccepted signature start and closing braces are not
// present in this excerpt.
457 class SSLServerAsyncCacheAcceptCallback: public SSLServerAcceptCallback {
459 explicit SSLServerAsyncCacheAcceptCallback(HandshakeCallback *hcb,
460 uint32_t timeout = 0):
461 SSLServerAcceptCallback(hcb, timeout) {}
463 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
465 const std::shared_ptr<folly::AsyncSSLSocket> &s)
467 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
469 std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
471 hcb_->setSocket(sock);
472 sock->sslAccept(hcb_, timeout_);
// The socket is either still accepting or parked on a session-cache lookup.
473 ASSERT_TRUE((sock->getSSLState() ==
474 AsyncSSLSocket::STATE_ACCEPTING) ||
475 (sock->getSSLState() ==
476 AsyncSSLSocket::STATE_CACHE_LOOKUP));
478 state = STATE_SUCCEEDED;
// Negative-test accept callback: calls sslAccept() twice on the same socket
// and expects the second call to push the socket and both handshake callbacks
// into the error state.
// NOTE(review): the connAccepted signature start and closing braces are not
// present in this excerpt.
483 class HandshakeErrorCallback: public SSLServerAcceptCallbackBase {
485 explicit HandshakeErrorCallback(HandshakeCallback *hcb):
486 SSLServerAcceptCallbackBase(hcb) {}
488 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
490 const std::shared_ptr<folly::AsyncSSLSocket> &s)
492 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
494 std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
496 // The first call to sslAccept() should succeed.
497 hcb_->setSocket(sock);
498 sock->sslAccept(hcb_);
499 EXPECT_EQ(sock->getSSLState(),
500 AsyncSSLSocket::STATE_ACCEPTING);
502 // The second call to sslAccept() should fail.
// callback2 lives on the stack and shares hcb_'s read callback.
503 HandshakeCallback callback2(hcb_->rcb_);
504 callback2.setSocket(sock);
505 sock->sslAccept(&callback2);
506 EXPECT_EQ(sock->getSSLState(),
507 AsyncSSLSocket::STATE_ERROR);
509 // Both callbacks should be in the error state.
510 EXPECT_EQ(hcb_->state, STATE_FAILED);
511 EXPECT_EQ(callback2.state, STATE_FAILED);
513 sock->detachEventBase();
// The failures were expected, so flip every state to STATE_SUCCEEDED to
// satisfy the destructor checks.
515 state = STATE_SUCCEEDED;
516 hcb_->setState(STATE_SUCCEEDED);
517 callback2.setState(STATE_SUCCEEDED);
// Negative-test accept callback: defers sslAccept() with tryRunAfterDelay so
// that, per the log text below, the client has already closed by the time the
// accept runs.
// NOTE(review): several lines (the delay value, the EXPECT_EQ openers, the
// closing of the lambda, braces) are not present in this excerpt.
521 class HandshakeTimeoutCallback: public SSLServerAcceptCallbackBase {
523 explicit HandshakeTimeoutCallback(HandshakeCallback *hcb):
524 SSLServerAcceptCallbackBase(hcb) {}
526 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
528 const std::shared_ptr<folly::AsyncSSLSocket> &s)
// NOTE(review): the log text names HandshakeErrorCallback although this is
// HandshakeTimeoutCallback; looks copy-pasted.
530 std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
532 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
534 hcb_->setSocket(sock);
// The lambda captures by copy ([=]); it reaches hcb_ and state through the
// implicitly captured `this`.
535 sock->getEventBase()->tryRunAfterDelay([=] {
536 std::cerr << "Delayed SSL accept, client will have close by now"
538 // SSL accept will fail
541 AsyncSSLSocket::STATE_UNINIT);
542 hcb_->socket_->sslAccept(hcb_);
543 // This registers for an event
546 AsyncSSLSocket::STATE_ACCEPTING);
548 state = STATE_SUCCEEDED;
// Test SSL server: owns an SSL context, a listening AsyncServerSocket and an
// event base, and is joined via pthread_join() on a thread handle (thread_).
// NOTE(review): several lines (evb_/thread_ declarations, the body of Main,
// the destructor header, braces) are not present in this excerpt.
554 class TestSSLServer {
557 std::shared_ptr<folly::SSLContext> ctx_;
558 SSLServerAcceptCallbackBase *acb_;
559 std::shared_ptr<folly::AsyncServerSocket> socket_;
560 folly::SocketAddress address_;
// Thread entry point; ctx is the TestSSLServer instance.
563 static void *Main(void *ctx) {
564 TestSSLServer *self = static_cast<TestSSLServer*>(ctx);
566 std::cerr << "Server thread exited event loop" << std::endl;
571 // Create a TestSSLServer.
572 // This immediately starts listening on the given port.
573 explicit TestSSLServer(SSLServerAcceptCallbackBase *acb);
// Shutdown: stop accepting from inside the event-base thread, then wait for
// the server thread to finish.
577 evb_.runInEventBaseThread([&](){
578 socket_->stopAccepting();
580 std::cerr << "Waiting for server thread to exit" << std::endl;
581 pthread_join(thread_, nullptr);
584 EventBase &getEventBase() { return evb_; }
586 const folly::SocketAddress& getAddress() const {
// TestSSLServer variant that installs a session "get" callback simulating an
// asynchronous external session cache.
// NOTE(review): several lines (base-class initializer, counter updates,
// remaining callback parameters, returns, braces) are not present in this
// excerpt.
591 class TestSSLAsyncCacheServer : public TestSSLServer {
// lookupDelay: delay passed to tryRunAfterDelay() for the simulated lookup;
// defaults to 100.
593 explicit TestSSLAsyncCacheServer(SSLServerAcceptCallbackBase *acb,
594 int lookupDelay = 100) :
// Register the lookup callback and disable OpenSSL's internal session cache
// so that lookups go through getSessionCallback.
596 SSL_CTX *sslCtx = ctx_->getSSLCtx();
597 SSL_CTX_sess_set_get_cb(sslCtx,
598 TestSSLAsyncCacheServer::getSessionCallback);
599 SSL_CTX_set_session_cache_mode(
600 sslCtx, SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_SERVER);
603 lookupDelay_ = lookupDelay;
606 uint32_t getAsyncCallbacks() const { return asyncCallbacks_; }
607 uint32_t getAsyncLookups() const { return asyncLookups_; }
// Static members, so they are shared by all instances of this class.
610 static uint32_t asyncCallbacks_;
611 static uint32_t asyncLookups_;
612 static uint32_t lookupDelay_;
// Session-lookup callback registered with SSL_CTX_sess_set_get_cb above.
614 static SSL_SESSION* getSessionCallback(SSL* ssl,
615 unsigned char* /* sess_id */,
620 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
621 if (!SSL_want_sess_cache_lookup(ssl)) {
622 // libssl.so mismatch
623 std::cerr << "no async support" << std::endl;
627 AsyncSSLSocket *sslSocket =
628 AsyncSSLSocket::getFromSSL(ssl);
629 assert(sslSocket != nullptr);
630 // Simulate an async cache by delaying the miss by lookupDelay_ (100 by default)
631 if (asyncCallbacks_ % 2 == 0) {
632 // This socket is already blocked on lookup, return miss
633 std::cerr << "returning miss" << std::endl;
635 // fresh meat - block it
636 std::cerr << "async lookup" << std::endl;
// Schedule restartSSLAccept() on the socket after lookupDelay_ and report
// "would block" to OpenSSL through copyflag.
637 sslSocket->getEventBase()->tryRunAfterDelay(
638 std::bind(&AsyncSSLSocket::restartSSLAccept,
639 sslSocket), lookupDelay_);
640 *copyflag = SSL_SESSION_CB_WOULD_BLOCK;
// Free-function declarations for test helpers; the definitions are not in
// this excerpt.
// getfds: fills a caller-provided two-element fd array (presumably a
// connected pair; confirm against the definition).
648 void getfds(int fds[2]);
// NOTE(review): the next two lines are the parameter list of a declaration
// whose first line is not shown; it takes client and server SSL contexts.
651 std::shared_ptr<folly::SSLContext> clientCtx,
652 std::shared_ptr<folly::SSLContext> serverCtx);
// NOTE(review): likewise a partial declaration; it takes an event base and
// out-pointers for a client and a server AsyncSSLSocket.
655 EventBase* eventBase,
656 AsyncSSLSocket::UniquePtr* clientSock,
657 AsyncSSLSocket::UniquePtr* serverSock);
// Client half of the blocking-write test: after the handshake it issues one
// writev() over iovCount_ iovecs that all point into buf_.
// NOTE(review): several lines (bufLen_/iovCount_ initializers, loop bodies,
// the if/else around iov_len, getter bodies, braces) are not present in this
// excerpt.
659 class BlockingWriteClient :
660 private AsyncSSLSocket::HandshakeCB,
661 private AsyncTransportWrapper::WriteCallback {
663 explicit BlockingWriteClient(
664 AsyncSSLSocket::UniquePtr socket)
665 : socket_(std::move(socket)),
669 buf_.reset(new uint8_t[bufLen_]);
// NOTE(review): buf_ is a std::unique_ptr<uint8_t[]>, so sizeof(buf_) is the
// size of the smart pointer, not bufLen_. Only the first few bytes are
// visited; this almost certainly should be `n < bufLen_` (confirm).
670 for (uint32_t n = 0; n < sizeof(buf_); ++n) {
// Build the iovec array; entry n starts n bytes into buf_.
675 iov_.reset(new struct iovec[iovCount_]);
676 for (uint32_t n = 0; n < iovCount_; ++n) {
677 iov_[n].iov_base = buf_.get() + n;
// Two alternative lengths are assigned; the selecting condition is not shown.
679 iov_[n].iov_len = n % bufLen_;
681 iov_[n].iov_len = bufLen_ - (n % bufLen_);
// Start the client handshake with a timeout argument of 100.
685 socket_->sslConn(this, 100);
688 struct iovec* getIovec() const {
691 uint32_t getIovecCount() const {
// Handshake done: send all iovecs in one writev().
696 void handshakeSuc(AsyncSSLSocket*) noexcept override {
697 socket_->writev(this, iov_.get(), iovCount_);
// Handshake and write errors are reported as non-fatal test failures.
701 const AsyncSocketException& ex) noexcept override {
702 ADD_FAILURE() << "client handshake error: " << ex.what();
704 void writeSuccess() noexcept override {
709 const AsyncSocketException& ex) noexcept override {
710 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
714 AsyncSSLSocket::UniquePtr socket_;
717 std::unique_ptr<uint8_t[]> buf_;
718 std::unique_ptr<struct iovec[]> iov_;
// Server half of the blocking-write test: reads slowly (with pauses between
// reads) into one large buffer; checkBuffer() then compares what was read
// against the client's iovecs.
// NOTE(review): several lines (bytesRead_ initializer and updates, the
// `if (rc != 0)` check, returns, braces) are not present in this excerpt.
721 class BlockingWriteServer :
722 private AsyncSSLSocket::HandshakeCB,
723 private AsyncTransportWrapper::ReadCallback {
725 explicit BlockingWriteServer(
726 AsyncSSLSocket::UniquePtr socket)
727 : socket_(std::move(socket)),
// Receive buffer size: 2500 * 2000 bytes.
728 bufSize_(2500 * 2000),
730 buf_.reset(new uint8_t[bufSize_]);
// Start the server handshake with a timeout argument of 100.
731 socket_->sslAccept(this, 100);
// Verifies that the received bytes equal the concatenation of the `count`
// iovecs in `iov`; reports mismatches, short reads and extra data through
// gtest failures.
734 void checkBuffer(struct iovec* iov, uint32_t count) const {
736 for (uint32_t n = 0; n < count; ++n) {
737 size_t bytesLeft = bytesRead_ - idx;
// Compare at most the bytes actually received for this iovec.
738 int rc = memcmp(buf_.get() + idx, iov[n].iov_base,
739 std::min(iov[n].iov_len, bytesLeft));
741 FAIL() << "buffer mismatch at iovec " << n << "/" << count
745 if (iov[n].iov_len > bytesLeft) {
746 FAIL() << "server did not read enough data: "
747 << "ended at byte " << bytesLeft << "/" << iov[n].iov_len
748 << " in iovec " << n << "/" << count;
751 idx += iov[n].iov_len;
753 if (idx != bytesRead_) {
754 ADD_FAILURE() << "server read extra data: " << bytesRead_
755 << " bytes read; expected " << idx;
760 void handshakeSuc(AsyncSSLSocket*) noexcept override {
761 // Wait 10ms before reading, so the client's writes will initially block.
762 socket_->getEventBase()->tryRunAfterDelay(
763 [this] { socket_->setReadCB(this); }, 10);
767 const AsyncSocketException& ex) noexcept override {
768 ADD_FAILURE() << "server handshake error: " << ex.what();
// Hands out the unused tail of buf_.
770 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
771 *bufReturn = buf_.get() + bytesRead_;
772 *lenReturn = bufSize_ - bytesRead_;
// After each read, pause reading and re-arm the read callback after a delay
// of 2, which throttles the reader.
774 void readDataAvailable(size_t len) noexcept override {
776 socket_->setReadCB(nullptr);
777 socket_->getEventBase()->tryRunAfterDelay(
778 [this] { socket_->setReadCB(this); }, 2);
780 void readEOF() noexcept override {
784 const AsyncSocketException& ex) noexcept override {
785 ADD_FAILURE() << "server read error: " << ex.what();
788 AsyncSSLSocket::UniquePtr socket_;
791 std::unique_ptr<uint8_t[]> buf_;
// NOTE(review): the `class ...` header line and the constructor name are not
// present in this excerpt; this looks like the client-side counterpart of
// NpnServer (confirm the name against the full file).
// Client that starts an SSL connect and, on handshake success, stores the
// selected next protocol in the public fields below.
795 private AsyncSSLSocket::HandshakeCB,
796 private AsyncTransportWrapper::WriteCallback {
799 AsyncSSLSocket::UniquePtr socket)
800 : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
801 socket_->sslConn(this);
// Filled in by handshakeSuc(); nextProto is nullptr and nextProtoLength is 0
// until then. protocolType has no initializer in the visible lines.
804 const unsigned char* nextProto;
805 unsigned nextProtoLength;
806 SSLContext::NextProtocolType protocolType;
809 void handshakeSuc(AsyncSSLSocket*) noexcept override {
810 socket_->getSelectedNextProtocol(
811 &nextProto, &nextProtoLength, &protocolType);
// Handshake and write errors are reported as non-fatal test failures.
815 const AsyncSocketException& ex) noexcept override {
816 ADD_FAILURE() << "client handshake error: " << ex.what();
818 void writeSuccess() noexcept override {
823 const AsyncSocketException& ex) noexcept override {
824 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
828 AsyncSSLSocket::UniquePtr socket_;
// Server that starts an SSL accept and, on handshake success, stores the
// selected next protocol in the public fields below.
// NOTE(review): the `class ...` header line, some method bodies and closing
// braces are not present in this excerpt.
832 private AsyncSSLSocket::HandshakeCB,
833 private AsyncTransportWrapper::ReadCallback {
835 explicit NpnServer(AsyncSSLSocket::UniquePtr socket)
836 : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
837 socket_->sslAccept(this);
// Filled in by handshakeSuc(); nextProto is nullptr and nextProtoLength is 0
// until then. protocolType has no initializer in the visible lines.
840 const unsigned char* nextProto;
841 unsigned nextProtoLength;
842 SSLContext::NextProtocolType protocolType;
845 void handshakeSuc(AsyncSSLSocket*) noexcept override {
846 socket_->getSelectedNextProtocol(
847 &nextProto, &nextProtoLength, &protocolType);
851 const AsyncSocketException& ex) noexcept override {
852 ADD_FAILURE() << "server handshake error: " << ex.what();
// The read-callback methods exist to satisfy the interface; received data is
// ignored.
854 void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
857 void readDataAvailable(size_t /* len */) noexcept override {}
858 void readEOF() noexcept override {
862 const AsyncSocketException& ex) noexcept override {
863 ADD_FAILURE() << "server read error: " << ex.what();
866 AsyncSSLSocket::UniquePtr socket_;
869 #ifndef OPENSSL_NO_TLSEXT
// NOTE(review): the `class ...` header line and the constructor name are not
// present in this excerpt; this looks like the client-side counterpart of
// SNIServer (confirm the name against the full file).
// Client that starts an SSL connect and, on handshake success, records the
// socket's isServerNameMatch() result.
871 private AsyncSSLSocket::HandshakeCB,
872 private AsyncTransportWrapper::WriteCallback {
875 AsyncSSLSocket::UniquePtr socket)
876 : serverNameMatch(false), socket_(std::move(socket)) {
877 socket_->sslConn(this);
// False until handshakeSuc() runs.
880 bool serverNameMatch;
883 void handshakeSuc(AsyncSSLSocket*) noexcept override {
884 serverNameMatch = socket_->isServerNameMatch();
// Handshake and write errors are reported as non-fatal test failures.
888 const AsyncSocketException& ex) noexcept override {
889 ADD_FAILURE() << "client handshake error: " << ex.what();
891 void writeSuccess() noexcept override {
896 const AsyncSocketException& ex) noexcept override {
897 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
901 AsyncSSLSocket::UniquePtr socket_;
// Server for the SNI tests: registers a server-name callback on `ctx` and
// switches the connection to `sniCtx` when the client asks for the expected
// server name.
// NOTE(review): the `class ...` header line, the constructor name, part of
// the condition in serverNameCallback and closing braces are not present in
// this excerpt.
905 private AsyncSSLSocket::HandshakeCB,
906 private AsyncTransportWrapper::ReadCallback {
// socket: accepted connection. ctx: context that receives the server-name
// callback. sniCtx: context switched to on a match. expectedServerName: name
// compared case-insensitively against the client's request.
909 AsyncSSLSocket::UniquePtr socket,
910 const std::shared_ptr<folly::SSLContext>& ctx,
911 const std::shared_ptr<folly::SSLContext>& sniCtx,
912 const std::string& expectedServerName)
913 : serverNameMatch(false), socket_(std::move(socket)), sniCtx_(sniCtx),
914 expectedServerName_(expectedServerName) {
// The callback binds `this`, so this object must outlive its use by ctx.
915 ctx->setServerNameCallback(std::bind(&SNIServer::serverNameCallback, this,
916 std::placeholders::_1));
917 socket_->sslAccept(this);
// Set by serverNameCallback().
920 bool serverNameMatch;
923 void handshakeSuc(AsyncSSLSocket* /* ssl */) noexcept override {}
926 const AsyncSocketException& ex) noexcept override {
927 ADD_FAILURE() << "server handshake error: " << ex.what();
// The read-callback methods exist to satisfy the interface; received data is
// ignored.
929 void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
932 void readDataAvailable(size_t /* len */) noexcept override {}
933 void readEOF() noexcept override {
937 const AsyncSocketException& ex) noexcept override {
938 ADD_FAILURE() << "server read error: " << ex.what();
// Reads the requested host name from the SSL object. On a case-insensitive
// match it switches the socket to sniCtx_ and returns SERVER_NAME_FOUND;
// otherwise it returns SERVER_NAME_NOT_FOUND.
941 folly::SSLContext::ServerNameCallbackResult
942 serverNameCallback(SSL *ssl) {
943 const char *sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
946 !strcasecmp(expectedServerName_.c_str(), sn)) {
947 AsyncSSLSocket *sslSocket =
948 AsyncSSLSocket::getFromSSL(ssl);
949 sslSocket->switchServerSSLContext(sniCtx_);
950 serverNameMatch = true;
951 return folly::SSLContext::SERVER_NAME_FOUND;
953 serverNameMatch = false;
954 return folly::SSLContext::SERVER_NAME_NOT_FOUND;
958 AsyncSSLSocket::UniquePtr socket_;
959 std::shared_ptr<folly::SSLContext> sniCtx_;
960 std::string expectedServerName_;
// Test SSL client: connects to `address`, writes buf_ (filled with 'a'),
// reads the echo into readbuf_ and compares the two. It keeps the SSL session
// between connects so session reuse can be observed, and counts write errors
// in writeAfterConnectErrors_.
// NOTE(review): many lines (buffer and counter declarations, remaining
// initializers, counter updates, braces) are not present in this excerpt.
964 class SSLClient : public AsyncSocket::ConnectCallback,
965 public AsyncTransportWrapper::WriteCallback,
966 public AsyncTransportWrapper::ReadCallback
969 EventBase *eventBase_;
970 std::shared_ptr<AsyncSSLSocket> sslSocket_;
// Session saved from the previous connection; released with
// SSL_SESSION_free().
971 SSL_SESSION *session_;
972 std::shared_ptr<folly::SSLContext> ctx_;
974 folly::SocketAddress address_;
982 uint32_t writeAfterConnectErrors_;
984 // These settings test that we eventually drain the
985 // socket, even if the maxReadsPerEvent_ is hit during
986 // an event loop iteration.
987 static constexpr size_t kMaxReadsPerEvent = 2;
// Each read is capped so that filling readbuf_ needs kMaxReadsPerEvent * 2
// reads.
988 static constexpr size_t kMaxReadBufferSz =
989 sizeof(readbuf_) / kMaxReadsPerEvent / 2; // 2 event loop iterations
// eventBase: event base the socket is created on. address: server to connect
// to. timeout: passed to connect(); defaults to 0.
992 SSLClient(EventBase *eventBase,
993 const folly::SocketAddress& address,
995 uint32_t timeout = 0)
996 : eventBase_(eventBase),
1005 writeAfterConnectErrors_(0) {
// Fresh client context with session tickets disabled and the given cipher
// list.
1006 ctx_.reset(new folly::SSLContext());
1007 ctx_->setOptions(SSL_OP_NO_TICKET);
1008 ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1009 memset(buf_, 'a', sizeof(buf_));
// Destructor (header not shown): releases the saved session and checks that
// the whole echo was received.
1014 SSL_SESSION_free(session_);
1017 EXPECT_EQ(bytesRead_, sizeof(buf_));
1021 uint32_t getHit() const { return hit_; }
1023 uint32_t getMiss() const { return miss_; }
1025 uint32_t getErrors() const { return errors_; }
1027 uint32_t getWriteAfterConnectErrors() const {
1028 return writeAfterConnectErrors_;
// Creates a new socket, offers the saved session (if any) for reuse and
// starts connecting. With writeNow, a write is issued immediately, before the
// connection is established.
1031 void connect(bool writeNow = false) {
1032 sslSocket_ = AsyncSSLSocket::newSocket(
1034 if (session_ != nullptr) {
1035 sslSocket_->setSSLSession(session_);
1038 sslSocket_->connect(this, address_, timeout_);
1039 if (sslSocket_ && writeNow) {
1040 // write some junk, used in an error test
1041 sslSocket_->write(this, buf_, sizeof(buf_));
// Connected: check for session reuse, replace the saved session with the
// current one, then send buf_ and start reading the echo.
1045 void connectSuccess() noexcept override {
1046 std::cerr << "client SSL socket connected" << std::endl;
1047 if (sslSocket_->getSSLSessionReused()) {
1051 if (session_ != nullptr) {
1052 SSL_SESSION_free(session_);
1054 session_ = sslSocket_->getSSLSession();
1058 sslSocket_->setMaxReadsPerEvent(kMaxReadsPerEvent);
1059 sslSocket_->write(this, buf_, sizeof(buf_));
1060 sslSocket_->setReadCB(this);
// Pre-fill readbuf_ with 'b' so that stale contents cannot equal buf_.
1061 memset(readbuf_, 'b', sizeof(readbuf_));
1066 const AsyncSocketException& ex) noexcept override {
1067 std::cerr << "SSLClient::connectError: " << ex.what() << std::endl;
1072 void writeSuccess() noexcept override {
1073 std::cerr << "client write success" << std::endl;
// Write failure: log and count it.
1076 void writeErr(size_t /* bytesWritten */,
1077 const AsyncSocketException& ex) noexcept override {
1078 std::cerr << "client writeError: " << ex.what() << std::endl;
1080 writeAfterConnectErrors_++;
// Hands out the unread tail of readbuf_, capped at kMaxReadBufferSz.
1084 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
1085 *bufReturn = readbuf_ + bytesRead_;
1086 *lenReturn = std::min(kMaxReadBufferSz, sizeof(readbuf_) - bytesRead_);
1089 void readEOF() noexcept override {
1090 std::cerr << "client readEOF" << std::endl;
1094 const AsyncSocketException& ex) noexcept override {
1095 std::cerr << "client readError: " << ex.what() << std::endl;
// Once the full echo has arrived, verify it and close the socket; the
// requests_ check is presumably what triggers another round (body not shown).
1098 void readDataAvailable(size_t len) noexcept override {
1099 std::cerr << "client read data: " << len << std::endl;
1101 if (bytesRead_ == sizeof(buf_)) {
1102 EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
1103 sslSocket_->closeNow();
1105 if (requests_ != 0) {
// Base for the handshake-verification helpers. It records which handshake
// callbacks fired (verify / success / error) and the handshake time, and
// returns a caller-chosen result from the verify callback.
// NOTE(review): some lines (verifyResult_ declaration, the preverifyOk
// parameter, writeErr signature start, braces) are not present in this
// excerpt.
1113 class SSLHandshakeBase :
1114 public AsyncSSLSocket::HandshakeCB,
1115 private AsyncTransportWrapper::WriteCallback {
// socket: socket whose handshake is observed. preverifyResult: value the
// verify callback expects as its pre-verification input. verifyResult: value
// the verify callback returns.
1117 explicit SSLHandshakeBase(
1118 AsyncSSLSocket::UniquePtr socket,
1119 bool preverifyResult,
1120 bool verifyResult) :
1121 handshakeVerify_(false),
1122 handshakeSuccess_(false),
1123 handshakeError_(false),
1124 socket_(std::move(socket)),
1125 preverifyResult_(preverifyResult),
1126 verifyResult_(verifyResult) {
// Flags set by the callbacks below; all start false.
1129 bool handshakeVerify_;
1130 bool handshakeSuccess_;
1131 bool handshakeError_;
// Taken from socket_->getHandshakeTime() when the handshake ends.
1132 std::chrono::nanoseconds handshakeTime;
1135 AsyncSSLSocket::UniquePtr socket_;
1136 bool preverifyResult_;
1139 // HandshakeCallback
// Verify hook: records that it ran, checks the pre-verification value against
// the expectation and returns the configured result.
1140 bool handshakeVer(AsyncSSLSocket* /* sock */,
1142 X509_STORE_CTX* /* ctx */) noexcept override {
1143 handshakeVerify_ = true;
1145 EXPECT_EQ(preverifyResult_, preverifyOk);
1146 return verifyResult_;
1149 void handshakeSuc(AsyncSSLSocket*) noexcept override {
1150 handshakeSuccess_ = true;
1151 handshakeTime = socket_->getHandshakeTime();
1154 void handshakeErr(AsyncSSLSocket*,
1155 const AsyncSocketException& /* ex */) noexcept override {
1156 handshakeError_ = true;
1157 handshakeTime = socket_->getHandshakeTime();
1161 void writeSuccess() noexcept override {
// Write errors are reported as non-fatal test failures.
1166 size_t bytesWritten,
1167 const AsyncSocketException& ex) noexcept override {
1168 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
// Client handshake helper: starts sslConn() with a timeout argument of 0 and
// no explicit verify-peer argument.
// NOTE(review): the constructor name line and closing braces are not present
// in this excerpt.
1173 class SSLHandshakeClient : public SSLHandshakeBase {
1176 AsyncSSLSocket::UniquePtr socket,
1177 bool preverifyResult,
1178 bool verifyResult) :
1179 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1180 socket_->sslConn(this, 0);
// Client handshake helper that passes NO_VERIFY as the verify-peer argument
// of sslConn().
// NOTE(review): closing braces are not present in this excerpt.
1184 class SSLHandshakeClientNoVerify : public SSLHandshakeBase {
1186 SSLHandshakeClientNoVerify(
1187 AsyncSSLSocket::UniquePtr socket,
1188 bool preverifyResult,
1189 bool verifyResult) :
1190 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1191 socket_->sslConn(this, 0,
1192 folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
// Client handshake helper that passes VERIFY as the verify-peer argument of
// sslConn().
// NOTE(review): closing braces are not present in this excerpt.
1196 class SSLHandshakeClientDoVerify : public SSLHandshakeBase {
1198 SSLHandshakeClientDoVerify(
1199 AsyncSSLSocket::UniquePtr socket,
1200 bool preverifyResult,
1201 bool verifyResult) :
1202 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1203 socket_->sslConn(this, 0,
1204 folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
// Server handshake helper: starts sslAccept() with a timeout argument of 0
// and no explicit verify-peer argument.
// NOTE(review): the constructor name line, the verifyResult parameter line
// and closing braces are not present in this excerpt.
1208 class SSLHandshakeServer : public SSLHandshakeBase {
1211 AsyncSSLSocket::UniquePtr socket,
1212 bool preverifyResult,
1214 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1215 socket_->sslAccept(this, 0);
// Server handshake helper that enables ClientHello parsing before accepting
// and, on success, captures the client, shared, server and negotiated cipher
// names.
// NOTE(review): the verifyResult parameter line and closing braces are not
// present in this excerpt.
1219 class SSLHandshakeServerParseClientHello : public SSLHandshakeBase {
1221 SSLHandshakeServerParseClientHello(
1222 AsyncSSLSocket::UniquePtr socket,
1223 bool preverifyResult,
1225 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
// Parsing is enabled before sslAccept() is called.
1226 socket_->enableClientHelloParsing();
1227 socket_->sslAccept(this, 0);
// Filled in by handshakeSuc().
1230 std::string clientCiphers_, sharedCiphers_, serverCiphers_, chosenCipher_;
// Overrides the base handler; unlike the base version, the visible lines do
// not record handshakeTime.
1233 void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
1234 handshakeSuccess_ = true;
1235 sock->getSSLSharedCiphers(sharedCiphers_);
1236 sock->getSSLServerCiphers(serverCiphers_);
1237 sock->getSSLClientCiphers(clientCiphers_);
1238 chosenCipher_ = sock->getNegotiatedCipherName();
// Server handshake helper that passes NO_VERIFY as the verify-peer argument
// of sslAccept().
// NOTE(review): the verifyResult parameter line and closing braces are not
// present in this excerpt.
1243 class SSLHandshakeServerNoVerify : public SSLHandshakeBase {
1245 SSLHandshakeServerNoVerify(
1246 AsyncSSLSocket::UniquePtr socket,
1247 bool preverifyResult,
1249 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1250 socket_->sslAccept(this, 0,
1251 folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
// Server handshake helper that passes VERIFY_REQ_CLIENT_CERT as the
// verify-peer argument of sslAccept().
// NOTE(review): the verifyResult parameter line and closing braces are not
// present in this excerpt.
1255 class SSLHandshakeServerDoVerify : public SSLHandshakeBase {
1257 SSLHandshakeServerDoVerify(
1258 AsyncSSLSocket::UniquePtr socket,
1259 bool preverifyResult,
1261 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1262 socket_->sslAccept(this, 0,
1263 folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
// Watchdog timeout for tests: schedules itself for timeoutMS on construction
// and, when it fires, fails the test and asks the event base to stop looping.
// NOTE(review): some lines (the timeoutMS parameter, the AsyncTimeout
// base-initializer opener, closing braces) are not present in this excerpt.
1267 class EventBaseAborter : public AsyncTimeout {
1269 EventBaseAborter(EventBase* eventBase,
// Registered as an INTERNAL timeout on the given event base.
1272 eventBase, AsyncTimeout::InternalEnum::INTERNAL)
1273 , eventBase_(eventBase) {
1274 scheduleTimeout(timeoutMS);
1277 void timeoutExpired() noexcept override {
// NOTE(review): gtest's FAIL() returns from the enclosing function, so the
// terminateLoopSoon() call below looks unreachable and the loop would keep
// running after the timeout. Consider ADD_FAILURE() here (confirm).
1278 FAIL() << "test timed out";
1279 eventBase_->terminateLoopSoon();
// Event base to stop on timeout.
1283 EventBase* eventBase_;