2 * Copyright 2016 Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
21 #include <folly/ExceptionWrapper.h>
22 #include <folly/SocketAddress.h>
23 #include <folly/experimental/TestUtil.h>
24 #include <folly/io/async/AsyncSSLSocket.h>
25 #include <folly/io/async/AsyncServerSocket.h>
26 #include <folly/io/async/AsyncSocket.h>
27 #include <folly/io/async/AsyncTimeout.h>
28 #include <folly/io/async/AsyncTransport.h>
29 #include <folly/io/async/EventBase.h>
30 #include <folly/io/async/ssl/SSLErrors.h>
31 #include <folly/portability/GTest.h>
32 #include <folly/portability/Sockets.h>
33 #include <folly/portability/Unistd.h>
36 #include <sys/types.h>
37 #include <condition_variable>
49 // The destructors of all callback classes assert that the state is
50 // STATE_SUCCEEDED, for both possitive and negative tests. The tests
51 // are responsible for setting the succeeded state properly before the
52 // destructors are called.
54 class WriteCallbackBase :
55 public AsyncTransportWrapper::WriteCallback {
58 : state(STATE_WAITING)
60 , exception(AsyncSocketException::UNKNOWN, "none") {}
62 ~WriteCallbackBase() {
63 EXPECT_EQ(STATE_SUCCEEDED, state);
67 const std::shared_ptr<AsyncSSLSocket> &socket) {
71 void writeSuccess() noexcept override {
72 std::cerr << "writeSuccess" << std::endl;
73 state = STATE_SUCCEEDED;
78 const AsyncSocketException& ex) noexcept override {
79 std::cerr << "writeError: bytesWritten " << nBytesWritten
80 << ", exception " << ex.what() << std::endl;
83 this->bytesWritten = nBytesWritten;
88 std::shared_ptr<AsyncSSLSocket> socket_;
91 AsyncSocketException exception;
94 class ReadCallbackBase :
95 public AsyncTransportWrapper::ReadCallback {
97 explicit ReadCallbackBase(WriteCallbackBase* wcb)
98 : wcb_(wcb), state(STATE_WAITING) {}
100 ~ReadCallbackBase() {
101 EXPECT_EQ(STATE_SUCCEEDED, state);
105 const std::shared_ptr<AsyncSSLSocket> &socket) {
109 void setState(StateEnum s) {
117 const AsyncSocketException& ex) noexcept override {
118 std::cerr << "readError " << ex.what() << std::endl;
119 state = STATE_FAILED;
123 void readEOF() noexcept override {
124 std::cerr << "readEOF" << std::endl;
129 std::shared_ptr<AsyncSSLSocket> socket_;
130 WriteCallbackBase *wcb_;
134 class ReadCallback : public ReadCallbackBase {
136 explicit ReadCallback(WriteCallbackBase *wcb)
137 : ReadCallbackBase(wcb)
141 for (std::vector<Buffer>::iterator it = buffers.begin();
146 currentBuffer.free();
149 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
150 if (!currentBuffer.buffer) {
151 currentBuffer.allocate(4096);
153 *bufReturn = currentBuffer.buffer;
154 *lenReturn = currentBuffer.length;
157 void readDataAvailable(size_t len) noexcept override {
158 std::cerr << "readDataAvailable, len " << len << std::endl;
160 currentBuffer.length = len;
162 wcb_->setSocket(socket_);
164 // Write back the same data.
165 socket_->write(wcb_, currentBuffer.buffer, len);
167 buffers.push_back(currentBuffer);
168 currentBuffer.reset();
169 state = STATE_SUCCEEDED;
174 Buffer() : buffer(nullptr), length(0) {}
175 Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
181 void allocate(size_t len) {
182 assert(buffer == nullptr);
183 this->buffer = static_cast<char*>(malloc(len));
195 std::vector<Buffer> buffers;
196 Buffer currentBuffer;
199 class ReadErrorCallback : public ReadCallbackBase {
201 explicit ReadErrorCallback(WriteCallbackBase *wcb)
202 : ReadCallbackBase(wcb) {}
204 // Return nullptr buffer to trigger readError()
205 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
206 *bufReturn = nullptr;
210 void readDataAvailable(size_t /* len */) noexcept override {
211 // This should never to called.
216 const AsyncSocketException& ex) noexcept override {
217 ReadCallbackBase::readErr(ex);
218 std::cerr << "ReadErrorCallback::readError" << std::endl;
219 setState(STATE_SUCCEEDED);
223 class ReadEOFCallback : public ReadCallbackBase {
225 explicit ReadEOFCallback(WriteCallbackBase* wcb) : ReadCallbackBase(wcb) {}
227 // Return nullptr buffer to trigger readError()
228 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
229 *bufReturn = nullptr;
233 void readDataAvailable(size_t /* len */) noexcept override {
234 // This should never to called.
238 void readEOF() noexcept override {
239 ReadCallbackBase::readEOF();
240 setState(STATE_SUCCEEDED);
244 class WriteErrorCallback : public ReadCallback {
246 explicit WriteErrorCallback(WriteCallbackBase *wcb)
247 : ReadCallback(wcb) {}
249 void readDataAvailable(size_t len) noexcept override {
250 std::cerr << "readDataAvailable, len " << len << std::endl;
252 currentBuffer.length = len;
254 // close the socket before writing to trigger writeError().
255 ::close(socket_->getFd());
257 wcb_->setSocket(socket_);
259 // Write back the same data.
260 folly::test::msvcSuppressAbortOnInvalidParams([&] {
261 socket_->write(wcb_, currentBuffer.buffer, len);
264 if (wcb_->state == STATE_FAILED) {
265 setState(STATE_SUCCEEDED);
267 state = STATE_FAILED;
270 buffers.push_back(currentBuffer);
271 currentBuffer.reset();
274 void readErr(const AsyncSocketException& ex) noexcept override {
275 std::cerr << "readError " << ex.what() << std::endl;
276 // do nothing since this is expected
280 class EmptyReadCallback : public ReadCallback {
282 explicit EmptyReadCallback()
283 : ReadCallback(nullptr) {}
285 void readErr(const AsyncSocketException& ex) noexcept override {
286 std::cerr << "readError " << ex.what() << std::endl;
287 state = STATE_FAILED;
293 void readEOF() noexcept override {
294 std::cerr << "readEOF" << std::endl;
298 state = STATE_SUCCEEDED;
301 std::shared_ptr<AsyncSocket> tcpSocket_;
304 class HandshakeCallback :
305 public AsyncSSLSocket::HandshakeCB {
312 explicit HandshakeCallback(ReadCallbackBase *rcb,
313 ExpectType expect = EXPECT_SUCCESS):
314 state(STATE_WAITING),
319 const std::shared_ptr<AsyncSSLSocket> &socket) {
323 void setState(StateEnum s) {
328 // Functions inherited from AsyncSSLSocketHandshakeCallback
329 void handshakeSuc(AsyncSSLSocket *sock) noexcept override {
330 std::lock_guard<std::mutex> g(mutex_);
332 EXPECT_EQ(sock, socket_.get());
333 std::cerr << "HandshakeCallback::connectionAccepted" << std::endl;
334 rcb_->setSocket(socket_);
335 sock->setReadCB(rcb_);
336 state = (expect_ == EXPECT_SUCCESS) ? STATE_SUCCEEDED : STATE_FAILED;
338 void handshakeErr(AsyncSSLSocket* /* sock */,
339 const AsyncSocketException& ex) noexcept override {
340 std::lock_guard<std::mutex> g(mutex_);
342 std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl;
343 state = (expect_ == EXPECT_ERROR) ? STATE_SUCCEEDED : STATE_FAILED;
344 if (expect_ == EXPECT_ERROR) {
345 // rcb will never be invoked
346 rcb_->setState(STATE_SUCCEEDED);
348 errorString_ = ex.what();
351 void waitForHandshake() {
352 std::unique_lock<std::mutex> lock(mutex_);
353 cv_.wait(lock, [this] { return state != STATE_WAITING; });
356 ~HandshakeCallback() {
357 EXPECT_EQ(STATE_SUCCEEDED, state);
362 state = STATE_SUCCEEDED;
365 std::shared_ptr<AsyncSSLSocket> getSocket() {
370 std::shared_ptr<AsyncSSLSocket> socket_;
371 ReadCallbackBase *rcb_;
374 std::condition_variable cv_;
375 std::string errorString_;
378 class SSLServerAcceptCallbackBase:
379 public folly::AsyncServerSocket::AcceptCallback {
381 explicit SSLServerAcceptCallbackBase(HandshakeCallback *hcb):
382 state(STATE_WAITING), hcb_(hcb) {}
384 ~SSLServerAcceptCallbackBase() {
385 EXPECT_EQ(STATE_SUCCEEDED, state);
388 void acceptError(const std::exception& ex) noexcept override {
389 std::cerr << "SSLServerAcceptCallbackBase::acceptError "
390 << ex.what() << std::endl;
391 state = STATE_FAILED;
394 void connectionAccepted(
395 int fd, const folly::SocketAddress& /* clientAddr */) noexcept override {
397 socket_->detachEventBase();
399 printf("Connection accepted\n");
401 // Create a AsyncSSLSocket object with the fd. The socket should be
402 // added to the event base and in the state of accepting SSL connection.
403 socket_ = AsyncSSLSocket::newSocket(ctx_, base_, fd);
404 } catch (const std::exception &e) {
405 LOG(ERROR) << "Exception %s caught while creating a AsyncSSLSocket "
406 "object with socket " << e.what() << fd;
412 connAccepted(socket_);
415 virtual void connAccepted(
416 const std::shared_ptr<folly::AsyncSSLSocket> &s) = 0;
419 socket_->detachEventBase();
423 HandshakeCallback *hcb_;
424 std::shared_ptr<folly::SSLContext> ctx_;
425 std::shared_ptr<AsyncSSLSocket> socket_;
426 folly::EventBase* base_;
429 class SSLServerAcceptCallback: public SSLServerAcceptCallbackBase {
433 explicit SSLServerAcceptCallback(HandshakeCallback *hcb,
434 uint32_t timeout = 0):
435 SSLServerAcceptCallbackBase(hcb),
438 virtual ~SSLServerAcceptCallback() {
440 // if we set a timeout, we expect failure
441 EXPECT_EQ(hcb_->state, STATE_FAILED);
442 hcb_->setState(STATE_SUCCEEDED);
446 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
448 const std::shared_ptr<folly::AsyncSSLSocket> &s)
450 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
451 std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
453 hcb_->setSocket(sock);
454 sock->sslAccept(hcb_, std::chrono::milliseconds(timeout_));
455 EXPECT_EQ(sock->getSSLState(),
456 AsyncSSLSocket::STATE_ACCEPTING);
458 state = STATE_SUCCEEDED;
462 class SSLServerAcceptCallbackDelay: public SSLServerAcceptCallback {
464 explicit SSLServerAcceptCallbackDelay(HandshakeCallback *hcb):
465 SSLServerAcceptCallback(hcb) {}
467 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
469 const std::shared_ptr<folly::AsyncSSLSocket> &s)
472 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
474 std::cerr << "SSLServerAcceptCallbackDelay::connAccepted"
476 int fd = sock->getFd();
480 // The accepted connection should already have TCP_NODELAY set
482 socklen_t valueLength = sizeof(value);
483 int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
489 // Unset the TCP_NODELAY option.
491 socklen_t valueLength = sizeof(value);
492 int rc = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, valueLength);
495 rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
499 SSLServerAcceptCallback::connAccepted(sock);
503 class SSLServerAsyncCacheAcceptCallback: public SSLServerAcceptCallback {
505 explicit SSLServerAsyncCacheAcceptCallback(HandshakeCallback *hcb,
506 uint32_t timeout = 0):
507 SSLServerAcceptCallback(hcb, timeout) {}
509 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
511 const std::shared_ptr<folly::AsyncSSLSocket> &s)
513 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
515 std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
517 hcb_->setSocket(sock);
518 sock->sslAccept(hcb_, std::chrono::milliseconds(timeout_));
519 ASSERT_TRUE((sock->getSSLState() ==
520 AsyncSSLSocket::STATE_ACCEPTING) ||
521 (sock->getSSLState() ==
522 AsyncSSLSocket::STATE_CACHE_LOOKUP));
524 state = STATE_SUCCEEDED;
529 class HandshakeErrorCallback: public SSLServerAcceptCallbackBase {
531 explicit HandshakeErrorCallback(HandshakeCallback *hcb):
532 SSLServerAcceptCallbackBase(hcb) {}
534 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
536 const std::shared_ptr<folly::AsyncSSLSocket> &s)
538 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
540 std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
542 // The first call to sslAccept() should succeed.
543 hcb_->setSocket(sock);
544 sock->sslAccept(hcb_);
545 EXPECT_EQ(sock->getSSLState(),
546 AsyncSSLSocket::STATE_ACCEPTING);
548 // The second call to sslAccept() should fail.
549 HandshakeCallback callback2(hcb_->rcb_);
550 callback2.setSocket(sock);
551 sock->sslAccept(&callback2);
552 EXPECT_EQ(sock->getSSLState(),
553 AsyncSSLSocket::STATE_ERROR);
555 // Both callbacks should be in the error state.
556 EXPECT_EQ(hcb_->state, STATE_FAILED);
557 EXPECT_EQ(callback2.state, STATE_FAILED);
559 state = STATE_SUCCEEDED;
560 hcb_->setState(STATE_SUCCEEDED);
561 callback2.setState(STATE_SUCCEEDED);
565 class HandshakeTimeoutCallback: public SSLServerAcceptCallbackBase {
567 explicit HandshakeTimeoutCallback(HandshakeCallback *hcb):
568 SSLServerAcceptCallbackBase(hcb) {}
570 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
572 const std::shared_ptr<folly::AsyncSSLSocket> &s)
574 std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
576 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
578 hcb_->setSocket(sock);
579 sock->getEventBase()->tryRunAfterDelay([=] {
580 std::cerr << "Delayed SSL accept, client will have close by now"
582 // SSL accept will fail
585 AsyncSSLSocket::STATE_UNINIT);
586 hcb_->socket_->sslAccept(hcb_);
587 // This registers for an event
590 AsyncSSLSocket::STATE_ACCEPTING);
592 state = STATE_SUCCEEDED;
597 class ConnectTimeoutCallback : public SSLServerAcceptCallbackBase {
599 ConnectTimeoutCallback() : SSLServerAcceptCallbackBase(nullptr) {
600 // We don't care if we get invoked or not.
601 // The client may time out and give up before connAccepted() is even
603 state = STATE_SUCCEEDED;
606 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
608 const std::shared_ptr<folly::AsyncSSLSocket>& s) noexcept override {
609 std::cerr << "ConnectTimeoutCallback::connAccepted" << std::endl;
611 // Just wait a while before closing the socket, so the client
612 // will time out waiting for the handshake to complete.
613 s->getEventBase()->tryRunAfterDelay([=] { s->close(); }, 100);
617 class TestSSLServer {
620 std::shared_ptr<folly::SSLContext> ctx_;
621 SSLServerAcceptCallbackBase *acb_;
622 std::shared_ptr<folly::AsyncServerSocket> socket_;
623 folly::SocketAddress address_;
626 static void *Main(void *ctx) {
627 TestSSLServer *self = static_cast<TestSSLServer*>(ctx);
629 self->acb_->detach();
630 std::cerr << "Server thread exited event loop" << std::endl;
635 // Create a TestSSLServer.
636 // This immediately starts listening on the given port.
637 explicit TestSSLServer(
638 SSLServerAcceptCallbackBase* acb,
639 bool enableTFO = false);
643 evb_.runInEventBaseThread([&](){
644 socket_->stopAccepting();
646 std::cerr << "Waiting for server thread to exit" << std::endl;
647 pthread_join(thread_, nullptr);
650 EventBase &getEventBase() { return evb_; }
652 const folly::SocketAddress& getAddress() const {
657 class TestSSLAsyncCacheServer : public TestSSLServer {
659 explicit TestSSLAsyncCacheServer(SSLServerAcceptCallbackBase *acb,
660 int lookupDelay = 100) :
662 SSL_CTX *sslCtx = ctx_->getSSLCtx();
663 SSL_CTX_sess_set_get_cb(sslCtx,
664 TestSSLAsyncCacheServer::getSessionCallback);
665 SSL_CTX_set_session_cache_mode(
666 sslCtx, SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_SERVER);
669 lookupDelay_ = lookupDelay;
672 uint32_t getAsyncCallbacks() const { return asyncCallbacks_; }
673 uint32_t getAsyncLookups() const { return asyncLookups_; }
676 static uint32_t asyncCallbacks_;
677 static uint32_t asyncLookups_;
678 static uint32_t lookupDelay_;
680 static SSL_SESSION* getSessionCallback(SSL* ssl,
681 unsigned char* /* sess_id */,
687 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
688 if (!SSL_want_sess_cache_lookup(ssl)) {
689 // libssl.so mismatch
690 std::cerr << "no async support" << std::endl;
694 AsyncSSLSocket *sslSocket =
695 AsyncSSLSocket::getFromSSL(ssl);
696 assert(sslSocket != nullptr);
697 // Going to simulate an async cache by just running delaying the miss 100ms
698 if (asyncCallbacks_ % 2 == 0) {
699 // This socket is already blocked on lookup, return miss
700 std::cerr << "returning miss" << std::endl;
702 // fresh meat - block it
703 std::cerr << "async lookup" << std::endl;
704 sslSocket->getEventBase()->tryRunAfterDelay(
705 std::bind(&AsyncSSLSocket::restartSSLAccept,
706 sslSocket), lookupDelay_);
707 *copyflag = SSL_SESSION_CB_WOULD_BLOCK;
715 void getfds(int fds[2]);
718 std::shared_ptr<folly::SSLContext> clientCtx,
719 std::shared_ptr<folly::SSLContext> serverCtx);
722 EventBase* eventBase,
723 AsyncSSLSocket::UniquePtr* clientSock,
724 AsyncSSLSocket::UniquePtr* serverSock);
726 class BlockingWriteClient :
727 private AsyncSSLSocket::HandshakeCB,
728 private AsyncTransportWrapper::WriteCallback {
730 explicit BlockingWriteClient(
731 AsyncSSLSocket::UniquePtr socket)
732 : socket_(std::move(socket)),
736 buf_.reset(new uint8_t[bufLen_]);
737 for (uint32_t n = 0; n < sizeof(buf_); ++n) {
742 iov_.reset(new struct iovec[iovCount_]);
743 for (uint32_t n = 0; n < iovCount_; ++n) {
744 iov_[n].iov_base = buf_.get() + n;
746 iov_[n].iov_len = n % bufLen_;
748 iov_[n].iov_len = bufLen_ - (n % bufLen_);
752 socket_->sslConn(this, std::chrono::milliseconds(100));
755 struct iovec* getIovec() const {
758 uint32_t getIovecCount() const {
763 void handshakeSuc(AsyncSSLSocket*) noexcept override {
764 socket_->writev(this, iov_.get(), iovCount_);
768 const AsyncSocketException& ex) noexcept override {
769 ADD_FAILURE() << "client handshake error: " << ex.what();
771 void writeSuccess() noexcept override {
776 const AsyncSocketException& ex) noexcept override {
777 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
781 AsyncSSLSocket::UniquePtr socket_;
784 std::unique_ptr<uint8_t[]> buf_;
785 std::unique_ptr<struct iovec[]> iov_;
788 class BlockingWriteServer :
789 private AsyncSSLSocket::HandshakeCB,
790 private AsyncTransportWrapper::ReadCallback {
792 explicit BlockingWriteServer(
793 AsyncSSLSocket::UniquePtr socket)
794 : socket_(std::move(socket)),
795 bufSize_(2500 * 2000),
797 buf_.reset(new uint8_t[bufSize_]);
798 socket_->sslAccept(this, std::chrono::milliseconds(100));
801 void checkBuffer(struct iovec* iov, uint32_t count) const {
803 for (uint32_t n = 0; n < count; ++n) {
804 size_t bytesLeft = bytesRead_ - idx;
805 int rc = memcmp(buf_.get() + idx, iov[n].iov_base,
806 std::min(iov[n].iov_len, bytesLeft));
808 FAIL() << "buffer mismatch at iovec " << n << "/" << count
812 if (iov[n].iov_len > bytesLeft) {
813 FAIL() << "server did not read enough data: "
814 << "ended at byte " << bytesLeft << "/" << iov[n].iov_len
815 << " in iovec " << n << "/" << count;
818 idx += iov[n].iov_len;
820 if (idx != bytesRead_) {
821 ADD_FAILURE() << "server read extra data: " << bytesRead_
822 << " bytes read; expected " << idx;
827 void handshakeSuc(AsyncSSLSocket*) noexcept override {
828 // Wait 10ms before reading, so the client's writes will initially block.
829 socket_->getEventBase()->tryRunAfterDelay(
830 [this] { socket_->setReadCB(this); }, 10);
834 const AsyncSocketException& ex) noexcept override {
835 ADD_FAILURE() << "server handshake error: " << ex.what();
837 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
838 *bufReturn = buf_.get() + bytesRead_;
839 *lenReturn = bufSize_ - bytesRead_;
841 void readDataAvailable(size_t len) noexcept override {
843 socket_->setReadCB(nullptr);
844 socket_->getEventBase()->tryRunAfterDelay(
845 [this] { socket_->setReadCB(this); }, 2);
847 void readEOF() noexcept override {
851 const AsyncSocketException& ex) noexcept override {
852 ADD_FAILURE() << "server read error: " << ex.what();
855 AsyncSSLSocket::UniquePtr socket_;
858 std::unique_ptr<uint8_t[]> buf_;
862 private AsyncSSLSocket::HandshakeCB,
863 private AsyncTransportWrapper::WriteCallback {
866 AsyncSSLSocket::UniquePtr socket)
867 : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
868 socket_->sslConn(this);
871 const unsigned char* nextProto;
872 unsigned nextProtoLength;
873 SSLContext::NextProtocolType protocolType;
876 void handshakeSuc(AsyncSSLSocket*) noexcept override {
877 socket_->getSelectedNextProtocol(
878 &nextProto, &nextProtoLength, &protocolType);
882 const AsyncSocketException& ex) noexcept override {
883 ADD_FAILURE() << "client handshake error: " << ex.what();
885 void writeSuccess() noexcept override {
890 const AsyncSocketException& ex) noexcept override {
891 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
895 AsyncSSLSocket::UniquePtr socket_;
899 private AsyncSSLSocket::HandshakeCB,
900 private AsyncTransportWrapper::ReadCallback {
902 explicit NpnServer(AsyncSSLSocket::UniquePtr socket)
903 : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
904 socket_->sslAccept(this);
907 const unsigned char* nextProto;
908 unsigned nextProtoLength;
909 SSLContext::NextProtocolType protocolType;
912 void handshakeSuc(AsyncSSLSocket*) noexcept override {
913 socket_->getSelectedNextProtocol(
914 &nextProto, &nextProtoLength, &protocolType);
918 const AsyncSocketException& ex) noexcept override {
919 ADD_FAILURE() << "server handshake error: " << ex.what();
921 void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
924 void readDataAvailable(size_t /* len */) noexcept override {}
925 void readEOF() noexcept override {
929 const AsyncSocketException& ex) noexcept override {
930 ADD_FAILURE() << "server read error: " << ex.what();
933 AsyncSSLSocket::UniquePtr socket_;
936 class RenegotiatingServer : public AsyncSSLSocket::HandshakeCB,
937 public AsyncTransportWrapper::ReadCallback {
939 explicit RenegotiatingServer(AsyncSSLSocket::UniquePtr socket)
940 : socket_(std::move(socket)) {
941 socket_->sslAccept(this);
944 ~RenegotiatingServer() {
945 socket_->setReadCB(nullptr);
948 void handshakeSuc(AsyncSSLSocket* /* socket */) noexcept override {
949 LOG(INFO) << "Renegotiating server handshake success";
950 socket_->setReadCB(this);
954 const AsyncSocketException& ex) noexcept override {
955 ADD_FAILURE() << "Renegotiating server handshake error: " << ex.what();
957 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
958 *lenReturn = sizeof(buf);
961 void readDataAvailable(size_t /* len */) noexcept override {}
962 void readEOF() noexcept override {}
963 void readErr(const AsyncSocketException& ex) noexcept override {
964 LOG(INFO) << "server got read error " << ex.what();
965 auto exPtr = dynamic_cast<const SSLException*>(&ex);
966 ASSERT_NE(nullptr, exPtr);
967 std::string exStr(ex.what());
968 SSLException sslEx(SSLError::CLIENT_RENEGOTIATION);
969 ASSERT_NE(std::string::npos, exStr.find(sslEx.what()));
970 renegotiationError_ = true;
973 AsyncSSLSocket::UniquePtr socket_;
974 unsigned char buf[128];
975 bool renegotiationError_{false};
978 #ifndef OPENSSL_NO_TLSEXT
980 private AsyncSSLSocket::HandshakeCB,
981 private AsyncTransportWrapper::WriteCallback {
984 AsyncSSLSocket::UniquePtr socket)
985 : serverNameMatch(false), socket_(std::move(socket)) {
986 socket_->sslConn(this);
989 bool serverNameMatch;
992 void handshakeSuc(AsyncSSLSocket*) noexcept override {
993 serverNameMatch = socket_->isServerNameMatch();
997 const AsyncSocketException& ex) noexcept override {
998 ADD_FAILURE() << "client handshake error: " << ex.what();
1000 void writeSuccess() noexcept override {
1004 size_t bytesWritten,
1005 const AsyncSocketException& ex) noexcept override {
1006 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
1010 AsyncSSLSocket::UniquePtr socket_;
1014 private AsyncSSLSocket::HandshakeCB,
1015 private AsyncTransportWrapper::ReadCallback {
1018 AsyncSSLSocket::UniquePtr socket,
1019 const std::shared_ptr<folly::SSLContext>& ctx,
1020 const std::shared_ptr<folly::SSLContext>& sniCtx,
1021 const std::string& expectedServerName)
1022 : serverNameMatch(false), socket_(std::move(socket)), sniCtx_(sniCtx),
1023 expectedServerName_(expectedServerName) {
1024 ctx->setServerNameCallback(std::bind(&SNIServer::serverNameCallback, this,
1025 std::placeholders::_1));
1026 socket_->sslAccept(this);
1029 bool serverNameMatch;
1032 void handshakeSuc(AsyncSSLSocket* /* ssl */) noexcept override {}
1035 const AsyncSocketException& ex) noexcept override {
1036 ADD_FAILURE() << "server handshake error: " << ex.what();
1038 void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
1041 void readDataAvailable(size_t /* len */) noexcept override {}
1042 void readEOF() noexcept override {
1046 const AsyncSocketException& ex) noexcept override {
1047 ADD_FAILURE() << "server read error: " << ex.what();
1050 folly::SSLContext::ServerNameCallbackResult
1051 serverNameCallback(SSL *ssl) {
1052 const char *sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
1055 !strcasecmp(expectedServerName_.c_str(), sn)) {
1056 AsyncSSLSocket *sslSocket =
1057 AsyncSSLSocket::getFromSSL(ssl);
1058 sslSocket->switchServerSSLContext(sniCtx_);
1059 serverNameMatch = true;
1060 return folly::SSLContext::SERVER_NAME_FOUND;
1062 serverNameMatch = false;
1063 return folly::SSLContext::SERVER_NAME_NOT_FOUND;
1067 AsyncSSLSocket::UniquePtr socket_;
1068 std::shared_ptr<folly::SSLContext> sniCtx_;
1069 std::string expectedServerName_;
1073 class SSLClient : public AsyncSocket::ConnectCallback,
1074 public AsyncTransportWrapper::WriteCallback,
1075 public AsyncTransportWrapper::ReadCallback
1078 EventBase *eventBase_;
1079 std::shared_ptr<AsyncSSLSocket> sslSocket_;
1080 SSL_SESSION *session_;
1081 std::shared_ptr<folly::SSLContext> ctx_;
1083 folly::SocketAddress address_;
1087 uint32_t bytesRead_;
1091 uint32_t writeAfterConnectErrors_;
1093 // These settings test that we eventually drain the
1094 // socket, even if the maxReadsPerEvent_ is hit during
1095 // a event loop iteration.
1096 static constexpr size_t kMaxReadsPerEvent = 2;
1097 // 2 event loop iterations
1098 static constexpr size_t kMaxReadBufferSz =
1099 sizeof(decltype(readbuf_)) / kMaxReadsPerEvent / 2;
1102 SSLClient(EventBase *eventBase,
1103 const folly::SocketAddress& address,
1105 uint32_t timeout = 0)
1106 : eventBase_(eventBase),
1108 requests_(requests),
1115 writeAfterConnectErrors_(0) {
1116 ctx_.reset(new folly::SSLContext());
1117 ctx_->setOptions(SSL_OP_NO_TICKET);
1118 ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1119 memset(buf_, 'a', sizeof(buf_));
1124 SSL_SESSION_free(session_);
1127 EXPECT_EQ(bytesRead_, sizeof(buf_));
1131 uint32_t getHit() const { return hit_; }
1133 uint32_t getMiss() const { return miss_; }
1135 uint32_t getErrors() const { return errors_; }
1137 uint32_t getWriteAfterConnectErrors() const {
1138 return writeAfterConnectErrors_;
1141 void connect(bool writeNow = false) {
1142 sslSocket_ = AsyncSSLSocket::newSocket(
1144 if (session_ != nullptr) {
1145 sslSocket_->setSSLSession(session_);
1148 sslSocket_->connect(this, address_, timeout_);
1149 if (sslSocket_ && writeNow) {
1150 // write some junk, used in an error test
1151 sslSocket_->write(this, buf_, sizeof(buf_));
1155 void connectSuccess() noexcept override {
1156 std::cerr << "client SSL socket connected" << std::endl;
1157 if (sslSocket_->getSSLSessionReused()) {
1161 if (session_ != nullptr) {
1162 SSL_SESSION_free(session_);
1164 session_ = sslSocket_->getSSLSession();
1168 sslSocket_->setMaxReadsPerEvent(kMaxReadsPerEvent);
1169 sslSocket_->write(this, buf_, sizeof(buf_));
1170 sslSocket_->setReadCB(this);
1171 memset(readbuf_, 'b', sizeof(readbuf_));
1176 const AsyncSocketException& ex) noexcept override {
1177 std::cerr << "SSLClient::connectError: " << ex.what() << std::endl;
1182 void writeSuccess() noexcept override {
1183 std::cerr << "client write success" << std::endl;
1186 void writeErr(size_t /* bytesWritten */,
1187 const AsyncSocketException& ex) noexcept override {
1188 std::cerr << "client writeError: " << ex.what() << std::endl;
1190 writeAfterConnectErrors_++;
1194 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
1195 *bufReturn = readbuf_ + bytesRead_;
1196 *lenReturn = std::min(kMaxReadBufferSz, sizeof(readbuf_) - bytesRead_);
1199 void readEOF() noexcept override {
1200 std::cerr << "client readEOF" << std::endl;
1204 const AsyncSocketException& ex) noexcept override {
1205 std::cerr << "client readError: " << ex.what() << std::endl;
1208 void readDataAvailable(size_t len) noexcept override {
1209 std::cerr << "client read data: " << len << std::endl;
1211 if (bytesRead_ == sizeof(buf_)) {
1212 EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
1213 sslSocket_->closeNow();
1215 if (requests_ != 0) {
1223 class SSLHandshakeBase :
1224 public AsyncSSLSocket::HandshakeCB,
1225 private AsyncTransportWrapper::WriteCallback {
1227 explicit SSLHandshakeBase(
1228 AsyncSSLSocket::UniquePtr socket,
1229 bool preverifyResult,
1230 bool verifyResult) :
1231 handshakeVerify_(false),
1232 handshakeSuccess_(false),
1233 handshakeError_(false),
1234 socket_(std::move(socket)),
1235 preverifyResult_(preverifyResult),
1236 verifyResult_(verifyResult) {
1239 AsyncSSLSocket::UniquePtr moveSocket() && {
1240 return std::move(socket_);
1243 bool handshakeVerify_;
1244 bool handshakeSuccess_;
1245 bool handshakeError_;
1246 std::chrono::nanoseconds handshakeTime;
1249 AsyncSSLSocket::UniquePtr socket_;
1250 bool preverifyResult_;
1253 // HandshakeCallback
1254 bool handshakeVer(AsyncSSLSocket* /* sock */,
1256 X509_STORE_CTX* /* ctx */) noexcept override {
1257 handshakeVerify_ = true;
1259 EXPECT_EQ(preverifyResult_, preverifyOk);
1260 return verifyResult_;
1263 void handshakeSuc(AsyncSSLSocket*) noexcept override {
1264 LOG(INFO) << "Handshake success";
1265 handshakeSuccess_ = true;
1266 handshakeTime = socket_->getHandshakeTime();
1271 const AsyncSocketException& ex) noexcept override {
1272 LOG(INFO) << "Handshake error " << ex.what();
1273 handshakeError_ = true;
1274 handshakeTime = socket_->getHandshakeTime();
1278 void writeSuccess() noexcept override {
1283 size_t bytesWritten,
1284 const AsyncSocketException& ex) noexcept override {
1285 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
1290 class SSLHandshakeClient : public SSLHandshakeBase {
1293 AsyncSSLSocket::UniquePtr socket,
1294 bool preverifyResult,
1295 bool verifyResult) :
1296 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1297 socket_->sslConn(this, std::chrono::milliseconds::zero());
1301 class SSLHandshakeClientNoVerify : public SSLHandshakeBase {
1303 SSLHandshakeClientNoVerify(
1304 AsyncSSLSocket::UniquePtr socket,
1305 bool preverifyResult,
1306 bool verifyResult) :
1307 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1310 std::chrono::milliseconds::zero(),
1311 folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1315 class SSLHandshakeClientDoVerify : public SSLHandshakeBase {
1317 SSLHandshakeClientDoVerify(
1318 AsyncSSLSocket::UniquePtr socket,
1319 bool preverifyResult,
1320 bool verifyResult) :
1321 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1324 std::chrono::milliseconds::zero(),
1325 folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
1329 class SSLHandshakeServer : public SSLHandshakeBase {
1332 AsyncSSLSocket::UniquePtr socket,
1333 bool preverifyResult,
1335 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1336 socket_->sslAccept(this, std::chrono::milliseconds::zero());
1340 class SSLHandshakeServerParseClientHello : public SSLHandshakeBase {
1342 SSLHandshakeServerParseClientHello(
1343 AsyncSSLSocket::UniquePtr socket,
1344 bool preverifyResult,
1346 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1347 socket_->enableClientHelloParsing();
1348 socket_->sslAccept(this, std::chrono::milliseconds::zero());
1351 std::string clientCiphers_, sharedCiphers_, serverCiphers_, chosenCipher_;
1354 void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
1355 handshakeSuccess_ = true;
1356 sock->getSSLSharedCiphers(sharedCiphers_);
1357 sock->getSSLServerCiphers(serverCiphers_);
1358 sock->getSSLClientCiphers(clientCiphers_);
1359 chosenCipher_ = sock->getNegotiatedCipherName();
1364 class SSLHandshakeServerNoVerify : public SSLHandshakeBase {
1366 SSLHandshakeServerNoVerify(
1367 AsyncSSLSocket::UniquePtr socket,
1368 bool preverifyResult,
1370 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1373 std::chrono::milliseconds::zero(),
1374 folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1378 class SSLHandshakeServerDoVerify : public SSLHandshakeBase {
1380 SSLHandshakeServerDoVerify(
1381 AsyncSSLSocket::UniquePtr socket,
1382 bool preverifyResult,
1384 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1387 std::chrono::milliseconds::zero(),
1388 folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1392 class EventBaseAborter : public AsyncTimeout {
1394 EventBaseAborter(EventBase* eventBase,
1397 eventBase, AsyncTimeout::InternalEnum::INTERNAL)
1398 , eventBase_(eventBase) {
1399 scheduleTimeout(timeoutMS);
1402 void timeoutExpired() noexcept override {
1403 FAIL() << "test timed out";
1404 eventBase_->terminateLoopSoon();
1408 EventBase* eventBase_;