2 * Copyright 2017 Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
20 #include <folly/ExceptionWrapper.h>
21 #include <folly/SocketAddress.h>
22 #include <folly/experimental/TestUtil.h>
23 #include <folly/io/async/AsyncSSLSocket.h>
24 #include <folly/io/async/AsyncServerSocket.h>
25 #include <folly/io/async/AsyncSocket.h>
26 #include <folly/io/async/AsyncTimeout.h>
27 #include <folly/io/async/AsyncTransport.h>
28 #include <folly/io/async/EventBase.h>
29 #include <folly/io/async/ssl/SSLErrors.h>
30 #include <folly/io/async/test/TestSSLServer.h>
31 #include <folly/portability/GTest.h>
32 #include <folly/portability/PThread.h>
33 #include <folly/portability/Sockets.h>
34 #include <folly/portability/Unistd.h>
37 #include <sys/types.h>
38 #include <condition_variable>
44 // The destructors of all callback classes assert that the state is
45 // STATE_SUCCEEDED, for both positive and negative tests. The tests
46 // are responsible for setting the succeeded state properly before the
47 // destructors are called.
// NOTE(review): this listing appears to be a lossy extraction of a C++ test
// header: each line is prefixed with its original line number, indentation is
// gone, and many short lines (braces, access specifiers, some statements) are
// missing. Comments here describe only what the visible lines show.
//
// SendMsgParamsCallbackBase: pass-through SendMsgParamsCallback. setSocket()
// remembers the socket's currently installed callback in oldCallback_ and
// installs `this`; every hook then forwards to oldCallback_. Subclasses
// override individual hooks to inject test-specific values.
49 class SendMsgParamsCallbackBase :
50 public folly::AsyncSocket::SendMsgParamsCallback {
52 SendMsgParamsCallbackBase() {}
// setSocket(socket): save the socket's current params callback, then install
// `this` in its place. (The function-name line and the assignment to socket_
// are not visible in this view.)
55 const std::shared_ptr<AsyncSSLSocket> &socket) {
57 oldCallback_ = socket_->getSendMsgParamsCB();
58 socket_->setSendMsgParamCB(this);
// Ignores the defaultFlags argument and asks the previous callback for flags
// with zero-copy disabled.
61 int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
63 return oldCallback_->getFlags(flags, false /*zeroCopyEnabled*/);
66 void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
67 oldCallback_->getAncillaryData(flags, data);
70 uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
71 return oldCallback_->getAncillaryDataSize(flags);
// socket_ keeps the attached socket alive. oldCallback_ is non-owning and is
// nullptr until setSocket() runs, so the forwarding hooks above must not be
// invoked before then.
74 std::shared_ptr<AsyncSSLSocket> socket_;
75 folly::AsyncSocket::SendMsgParamsCallback* oldCallback_{nullptr};
// SendMsgFlagsCallback: lets a test override the value getFlagsImpl() reports,
// via resetFlags(). Two pieces are not visible in this view: the body of
// resetFlags() and the branch that returns the overridden value (presumably a
// stored flags_ member used when set; TODO confirm). The visible fallback
// defers to the previously installed callback with zero-copy disabled.
78 class SendMsgFlagsCallback : public SendMsgParamsCallbackBase {
80 SendMsgFlagsCallback() {}
82 void resetFlags(int flags) {
86 int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
91 return oldCallback_->getFlags(flags, false /*zeroCopyEnabled*/);
// SendMsgDataCallback: additionally lets a test inject raw ancillary (control
// message) bytes. While ancillaryData_ is non-empty, its size and contents are
// reported; otherwise the previously installed callback is consulted.
98 class SendMsgDataCallback : public SendMsgFlagsCallback {
100 SendMsgDataCallback() {}
// Takes the caller's bytes by swapping, so the caller's vector receives the
// previous contents of ancillaryData_.
102 void resetData(std::vector<char>&& data) {
103 ancillaryData_.swap(data);
// Copies ancillaryData_ into `data`.
// NOTE(review): assumes `data` has room for at least getAncillaryDataSize()
// bytes; confirm the caller sizes the buffer that way. The `else` that
// separates the copy from the fallback call is not visible in this view.
106 void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
107 if (ancillaryData_.size()) {
108 std::cerr << "getAncillaryData: copying data" << std::endl;
109 memcpy(data, ancillaryData_.data(), ancillaryData_.size());
111 oldCallback_->getAncillaryData(flags, data);
// Returns the size of the injected data, or the previous callback's answer
// when nothing was injected. The return narrows size_t to uint32_t.
115 uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
116 if (ancillaryData_.size()) {
117 std::cerr << "getAncillaryDataSize: returning size" << std::endl;
118 return ancillaryData_.size();
120 return oldCallback_->getAncillaryDataSize(flags);
124 std::vector<char> ancillaryData_;
// WriteCallbackBase: records the outcome of a write in `state`.
// - writeSuccess() sets STATE_SUCCEEDED.
// - The error hook sets STATE_FAILED and stores the byte count.
// The destructor expects STATE_SUCCEEDED. An optional SendMsgParamsCallbackBase
// (mcb) is attached to the socket in setSocket().
127 class WriteCallbackBase :
128 public AsyncTransportWrapper::WriteCallback {
130 explicit WriteCallbackBase(SendMsgParamsCallbackBase* mcb = nullptr)
131 : state(STATE_WAITING)
133 , exception(AsyncSocketException::UNKNOWN, "none")
136 ~WriteCallbackBase() override {
137 EXPECT_EQ(STATE_SUCCEEDED, state);
// Virtual so subclasses can configure the socket further.
// NOTE(review): mcb defaults to nullptr. The guard around mcb_->setSocket() is
// not visible in this view; confirm it exists.
140 virtual void setSocket(
141 const std::shared_ptr<AsyncSSLSocket> &socket) {
144 mcb_->setSocket(socket);
148 void writeSuccess() noexcept override {
149 std::cerr << "writeSuccess" << std::endl;
150 state = STATE_SUCCEEDED;
// Error hook (its name line is not visible; presumably writeErr()). It logs,
// marks the write failed and records how many bytes were written. Storing `ex`
// into `exception` presumably happens on a line not visible here.
154 size_t nBytesWritten,
155 const AsyncSocketException& ex) noexcept override {
156 std::cerr << "writeError: bytesWritten " << nBytesWritten
157 << ", exception " << ex.what() << std::endl;
159 state = STATE_FAILED;
160 this->bytesWritten = nBytesWritten;
165 std::shared_ptr<AsyncSSLSocket> socket_;
168 AsyncSocketException exception;
// Non-owning; may be nullptr.
169 SendMsgParamsCallbackBase* mcb_;
// ExpectWriteErrorCallback: write callback for negative tests. Its destructor
// requires all of the following:
// - the write failed (STATE_FAILED);
// - the exception type is NETWORK_ERROR;
// - errno_ is 22 (EINVAL on Linux).
// It then flips `state` so the base-class destructor's STATE_SUCCEEDED check
// passes.
172 class ExpectWriteErrorCallback :
173 public WriteCallbackBase {
175 explicit ExpectWriteErrorCallback(SendMsgParamsCallbackBase* mcb = nullptr)
176 : WriteCallbackBase(mcb) {}
178 ~ExpectWriteErrorCallback() override {
179 EXPECT_EQ(STATE_FAILED, state);
180 EXPECT_EQ(exception.type_,
181 AsyncSocketException::AsyncSocketExceptionType::NETWORK_ERROR);
// 22 == EINVAL on Linux.
182 EXPECT_EQ(exception.errno_, 22);
183 // Suppress the assert in ~WriteCallbackBase()
184 state = STATE_SUCCEEDED;
// Local copy of the Linux SO_TIMESTAMPING flag bits (presumably so the test
// builds without the kernel header). An unscoped enum is used so the values
// can be OR'ed into a plain int.
// NOTE(review): the preprocessor conditional that opens this region is not
// visible here. The region is closed by `#endif // MSG_ERRQUEUE` after
// WriteCheckTimestampCallback.
189 /* copied from include/uapi/linux/net_tstamp.h */
190 /* SO_TIMESTAMPING takes an integer bit field composed of these values */
191 enum SOF_TIMESTAMPING {
192 SOF_TIMESTAMPING_TX_SOFTWARE = (1 << 1),
193 SOF_TIMESTAMPING_SOFTWARE = (1 << 4),
194 SOF_TIMESTAMPING_OPT_ID = (1 << 7),
195 SOF_TIMESTAMPING_TX_SCHED = (1 << 8),
196 SOF_TIMESTAMPING_TX_ACK = (1 << 9),
197 SOF_TIMESTAMPING_OPT_TSONLY = (1 << 11),
// WriteCheckTimestampCallback: write callback that checks kernel TX
// timestamping.
// - setSocket() enables SO_TIMESTAMPING (OPT_ID | OPT_TSONLY | SOFTWARE) on the
//   socket's fd.
// - checkForTimestampNotifications() reads the socket error queue
//   (MSG_ERRQUEUE). It sets gotTimestamp_ on an SCM_TIMESTAMPING cmsg and
//   inspects IP_RECVERR / IPV6_RECVERR cmsgs (whose handling presumably sets
//   gotByteSeq_; that body is not visible here).
// - The destructor requires the write to have succeeded and both flags to be
//   set.
200 class WriteCheckTimestampCallback :
201 public WriteCallbackBase {
203 explicit WriteCheckTimestampCallback(SendMsgParamsCallbackBase* mcb = nullptr)
204 : WriteCallbackBase(mcb) {}
206 ~WriteCheckTimestampCallback() override {
207 EXPECT_EQ(STATE_SUCCEEDED, state);
208 EXPECT_TRUE(gotTimestamp_);
209 EXPECT_TRUE(gotByteSeq_);
213 const std::shared_ptr<AsyncSSLSocket> &socket) override {
214 WriteCallbackBase::setSocket(socket);
// NOTE(review): fd 0 is a valid descriptor. An unopened socket is usually
// reported as -1, so this check was probably meant to compare against -1;
// confirm.
216 EXPECT_NE(socket_->getFd(), 0);
217 int flags = SOF_TIMESTAMPING_OPT_ID
218 | SOF_TIMESTAMPING_OPT_TSONLY
219 | SOF_TIMESTAMPING_SOFTWARE;
220 AsyncSocket::OptionKey tstampingOpt = {SOL_SOCKET, SO_TIMESTAMPING};
221 int ret = tstampingOpt.apply(socket_->getFd(), flags);
// Reads error-queue messages into a 1 KiB control buffer. The loop structure
// and the ret < 0 check around the errno test are not visible in this view.
// errno is copied before the stream output below can clobber it.
225 void checkForTimestampNotifications() noexcept {
226 int fd = socket_->getFd();
227 std::vector<char> ctrl(1024, 0);
232 memset(&msg, 0, sizeof(msg));
233 entry.iov_base = &data;
234 entry.iov_len = sizeof(data);
235 msg.msg_iov = &entry;
237 msg.msg_control = ctrl.data();
238 msg.msg_controllen = ctrl.size();
242 ret = recvmsg(fd, &msg, MSG_ERRQUEUE);
244 if (errno != EAGAIN) {
245 auto errnoCopy = errno;
246 std::cerr << "::recvmsg exited with code " << ret
247 << ", errno: " << errnoCopy << std::endl;
248 AsyncSocketException ex(
249 AsyncSocketException::INTERNAL_ERROR,
// Walk every control message returned by recvmsg().
257 for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
258 cmsg != nullptr && cmsg->cmsg_len != 0;
259 cmsg = CMSG_NXTHDR(&msg, cmsg)) {
260 if (cmsg->cmsg_level == SOL_SOCKET &&
261 cmsg->cmsg_type == SCM_TIMESTAMPING) {
262 gotTimestamp_ = true;
266 if ((cmsg->cmsg_level == SOL_IP && cmsg->cmsg_type == IP_RECVERR) ||
267 (cmsg->cmsg_level == SOL_IPV6 && cmsg->cmsg_type == IPV6_RECVERR)) {
275 bool gotTimestamp_{false};
276 bool gotByteSeq_{false};
278 #endif // MSG_ERRQUEUE
// ReadCallbackBase: common read-side test callback. It holds the socket and a
// non-owning WriteCallbackBase used by subclasses that echo data back.
// readErr() logs and marks the callback failed. The destructor expects
// STATE_SUCCEEDED. The bodies of setSocket()/setState() and the rest of
// readEOF() are not visible in this view.
280 class ReadCallbackBase :
281 public AsyncTransportWrapper::ReadCallback {
283 explicit ReadCallbackBase(WriteCallbackBase* wcb)
284 : wcb_(wcb), state(STATE_WAITING) {}
286 ~ReadCallbackBase() override {
287 EXPECT_EQ(STATE_SUCCEEDED, state);
291 const std::shared_ptr<AsyncSSLSocket> &socket) {
295 void setState(StateEnum s) {
303 const AsyncSocketException& ex) noexcept override {
304 std::cerr << "readError " << ex.what() << std::endl;
305 state = STATE_FAILED;
309 void readEOF() noexcept override {
310 std::cerr << "readEOF" << std::endl;
315 std::shared_ptr<AsyncSSLSocket> socket_;
// Non-owning; may be nullptr (see subclasses).
316 WriteCallbackBase *wcb_;
// ReadCallback: echo server read callback.
// - Reads go into a 4096-byte malloc'ed buffer.
// - Each chunk is written back through wcb_, and the callback is marked
//   succeeded.
// - Every filled buffer is kept in `buffers` until destruction, because the
//   asynchronous write may still reference it.
320 class ReadCallback : public ReadCallbackBase {
322 explicit ReadCallback(WriteCallbackBase *wcb)
323 : ReadCallbackBase(wcb)
// Frees every retained buffer plus the in-progress one. The loop body is not
// visible here.
326 ~ReadCallback() override {
327 for (std::vector<Buffer>::iterator it = buffers.begin();
332 currentBuffer.free();
// Lazily allocates currentBuffer and hands its whole capacity to the socket.
335 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
336 if (!currentBuffer.buffer) {
337 currentBuffer.allocate(4096);
339 *bufReturn = currentBuffer.buffer;
340 *lenReturn = currentBuffer.length;
// NOTE(review): dereferences wcb_ unconditionally, so a subclass constructed
// with a null wcb must not receive data through this path.
343 void readDataAvailable(size_t len) noexcept override {
344 std::cerr << "readDataAvailable, len " << len << std::endl;
346 currentBuffer.length = len;
348 wcb_->setSocket(socket_);
350 // Write back the same data.
351 socket_->write(wcb_, currentBuffer.buffer, len);
// Buffer is a plain handle: the copy in `buffers` takes over the memory, and
// reset() presumably just clears currentBuffer's fields.
353 buffers.push_back(currentBuffer);
354 currentBuffer.reset();
355 state = STATE_SUCCEEDED;
360 Buffer() : buffer(nullptr), length(0) {}
361 Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
// NOTE(review): no null check on malloc's result is visible here.
367 void allocate(size_t len) {
368 assert(buffer == nullptr);
369 this->buffer = static_cast<char*>(malloc(len));
381 std::vector<Buffer> buffers;
382 Buffer currentBuffer;
// ReadErrorCallback: negative-test read callback. It supplies no read buffer
// so the socket reports an error, and treats that readErr() as success.
385 class ReadErrorCallback : public ReadCallbackBase {
387 explicit ReadErrorCallback(WriteCallbackBase *wcb)
388 : ReadCallbackBase(wcb) {}
390 // Return nullptr buffer to trigger readErr()
391 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
392 *bufReturn = nullptr;
396 void readDataAvailable(size_t /* len */) noexcept override {
397 // This should never be called.
// Delegates to the base (which logs and marks failure), then overrides the
// state to success because an error is exactly what this test expects.
402 const AsyncSocketException& ex) noexcept override {
403 ReadCallbackBase::readErr(ex);
404 std::cerr << "ReadErrorCallback::readError" << std::endl;
405 setState(STATE_SUCCEEDED);
// ReadEOFCallback: read callback that succeeds only when the peer closes the
// connection (readEOF). readErr() is not overridden, so the base class's
// failure handling applies to any read error.
409 class ReadEOFCallback : public ReadCallbackBase {
411 explicit ReadEOFCallback(WriteCallbackBase* wcb) : ReadCallbackBase(wcb) {}
413 // Return nullptr buffer so that any incoming data triggers readErr() rather than being consumed
414 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
415 *bufReturn = nullptr;
419 void readDataAvailable(size_t /* len */) noexcept override {
420 // This should never be called.
424 void readEOF() noexcept override {
425 ReadCallbackBase::readEOF();
426 setState(STATE_SUCCEEDED);
// WriteErrorCallback: echo callback that deliberately breaks the echo write.
// On the first data it does the following:
// 1. closes the raw fd underneath the AsyncSSLSocket;
// 2. attempts the echo write;
// 3. succeeds only if wcb_ reports STATE_FAILED.
// Read errors are expected (the fd is gone) and ignored.
430 class WriteErrorCallback : public ReadCallback {
432 explicit WriteErrorCallback(WriteCallbackBase *wcb)
433 : ReadCallback(wcb) {}
435 void readDataAvailable(size_t len) noexcept override {
436 std::cerr << "readDataAvailable, len " << len << std::endl;
438 currentBuffer.length = len;
440 // close the socket before writing to trigger writeErr().
441 ::close(socket_->getFd());
443 wcb_->setSocket(socket_);
445 // Write back the same data.
// The wrapper presumably keeps the MSVC CRT from aborting on the write to a
// closed fd; confirm against folly::test.
446 folly::test::msvcSuppressAbortOnInvalidParams([&] {
447 socket_->write(wcb_, currentBuffer.buffer, len);
// wcb_->state is checked right after write(), so this relies on the failure
// being reported synchronously.
450 if (wcb_->state == STATE_FAILED) {
451 setState(STATE_SUCCEEDED);
453 state = STATE_FAILED;
456 buffers.push_back(currentBuffer);
457 currentBuffer.reset();
460 void readErr(const AsyncSocketException& ex) noexcept override {
461 std::cerr << "readError " << ex.what() << std::endl;
462 // do nothing since this is expected
// EmptyReadCallback: read callback for connections that should carry no data.
// It succeeds on EOF and fails on a read error.
// NOTE(review): it passes a null WriteCallbackBase to ReadCallback, and no
// readDataAvailable() override is visible. If data did arrive,
// ReadCallback::readDataAvailable() would dereference the null wcb_.
466 class EmptyReadCallback : public ReadCallback {
468 explicit EmptyReadCallback()
469 : ReadCallback(nullptr) {}
471 void readErr(const AsyncSocketException& ex) noexcept override {
472 std::cerr << "readError " << ex.what() << std::endl;
473 state = STATE_FAILED;
479 void readEOF() noexcept override {
480 std::cerr << "readEOF" << std::endl;
484 state = STATE_SUCCEEDED;
487 std::shared_ptr<AsyncSocket> tcpSocket_;
// HandshakeCallback: server-side handshake observer.
// - On success it attaches the read callback rcb_ to the socket.
// - `state` records whether the outcome matched `expect`.
// - On an expected error it also marks rcb_ succeeded, since rcb_ will never
//   run.
// - waitForHandshake() blocks on cv_ until `state` leaves STATE_WAITING.
// NOTE(review): the cv_ notification is not visible in this view; confirm both
// handshake hooks notify cv_, or waitForHandshake() can block forever.
490 class HandshakeCallback :
491 public AsyncSSLSocket::HandshakeCB {
498 explicit HandshakeCallback(ReadCallbackBase *rcb,
499 ExpectType expect = EXPECT_SUCCESS):
500 state(STATE_WAITING),
505 const std::shared_ptr<AsyncSSLSocket> &socket) {
509 void setState(StateEnum s) {
514 // Overrides of AsyncSSLSocket::HandshakeCB
// NOTE(review): the log label says "connectionAccepted" although this is the
// handshake-success hook.
515 void handshakeSuc(AsyncSSLSocket *sock) noexcept override {
516 std::lock_guard<std::mutex> g(mutex_);
518 EXPECT_EQ(sock, socket_.get());
519 std::cerr << "HandshakeCallback::connectionAccepted" << std::endl;
520 rcb_->setSocket(socket_);
521 sock->setReadCB(rcb_);
522 state = (expect_ == EXPECT_SUCCESS) ? STATE_SUCCEEDED : STATE_FAILED;
524 void handshakeErr(AsyncSSLSocket* /* sock */,
525 const AsyncSocketException& ex) noexcept override {
526 std::lock_guard<std::mutex> g(mutex_);
528 std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl;
529 state = (expect_ == EXPECT_ERROR) ? STATE_SUCCEEDED : STATE_FAILED;
530 if (expect_ == EXPECT_ERROR) {
531 // rcb will never be invoked
532 rcb_->setState(STATE_SUCCEEDED);
534 errorString_ = ex.what();
537 void waitForHandshake() {
538 std::unique_lock<std::mutex> lock(mutex_);
539 cv_.wait(lock, [this] { return state != STATE_WAITING; });
542 ~HandshakeCallback() override {
543 EXPECT_EQ(STATE_SUCCEEDED, state);
548 state = STATE_SUCCEEDED;
551 std::shared_ptr<AsyncSSLSocket> getSocket() {
556 std::shared_ptr<AsyncSSLSocket> socket_;
// Non-owning.
557 ReadCallbackBase *rcb_;
560 std::condition_variable cv_;
561 std::string errorString_;
// SSLServerAcceptCallback: accept callback that starts the server-side TLS
// handshake on each accepted socket.
// - Uses hcb_ as the handshake callback, with an optional timeout in ms.
// - Checks that the socket entered STATE_ACCEPTING.
// - When a timeout was configured, the destructor expects the handshake to
//   have failed.
564 class SSLServerAcceptCallback: public SSLServerAcceptCallbackBase {
568 explicit SSLServerAcceptCallback(HandshakeCallback *hcb,
569 uint32_t timeout = 0):
570 SSLServerAcceptCallbackBase(hcb),
573 ~SSLServerAcceptCallback() override {
575 // if we set a timeout, we expect failure
576 EXPECT_EQ(hcb_->state, STATE_FAILED);
577 hcb_->setState(STATE_SUCCEEDED);
582 const std::shared_ptr<folly::AsyncSSLSocket> &s)
// The cast is to the same type, so it is effectively a shared_ptr copy.
584 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
585 std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
587 hcb_->setSocket(sock);
588 sock->sslAccept(hcb_, std::chrono::milliseconds(timeout_));
589 EXPECT_EQ(sock->getSSLState(),
590 AsyncSSLSocket::STATE_ACCEPTING);
592 state = STATE_SUCCEEDED;
// SSLServerAcceptCallbackDelay: before delegating to
// SSLServerAcceptCallback::connAccepted(), it does the following on the
// accepted fd:
// 1. reads TCP_NODELAY, which should already be set;
// 2. clears TCP_NODELAY;
// 3. reads it back.
// The assertions on rc/value are not visible in this view.
596 class SSLServerAcceptCallbackDelay: public SSLServerAcceptCallback {
598 explicit SSLServerAcceptCallbackDelay(HandshakeCallback *hcb):
599 SSLServerAcceptCallback(hcb) {}
602 const std::shared_ptr<folly::AsyncSSLSocket> &s)
605 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
607 std::cerr << "SSLServerAcceptCallbackDelay::connAccepted"
609 int fd = sock->getFd();
613 // The accepted connection should already have TCP_NODELAY set
615 socklen_t valueLength = sizeof(value);
616 int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
622 // Unset the TCP_NODELAY option.
624 socklen_t valueLength = sizeof(value);
625 int rc = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, valueLength);
628 rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
632 SSLServerAcceptCallback::connAccepted(sock);
// SSLServerAsyncCacheAcceptCallback: like SSLServerAcceptCallback, but after
// sslAccept() the socket may also be parked in STATE_CACHE_LOOKUP (presumably
// waiting on an asynchronous session-cache lookup) instead of STATE_ACCEPTING.
636 class SSLServerAsyncCacheAcceptCallback: public SSLServerAcceptCallback {
638 explicit SSLServerAsyncCacheAcceptCallback(HandshakeCallback *hcb,
639 uint32_t timeout = 0):
640 SSLServerAcceptCallback(hcb, timeout) {}
643 const std::shared_ptr<folly::AsyncSSLSocket> &s)
645 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
// Log under this class's own name so its output can be told apart from
// SSLServerAcceptCallback's.
647 std::cerr << "SSLServerAsyncCacheAcceptCallback::connAccepted" << std::endl;
649 hcb_->setSocket(sock);
650 sock->sslAccept(hcb_, std::chrono::milliseconds(timeout_));
// ASSERT_TRUE returns from connAccepted on failure, skipping the state update
// below.
651 ASSERT_TRUE((sock->getSSLState() ==
652 AsyncSSLSocket::STATE_ACCEPTING) ||
653 (sock->getSSLState() ==
654 AsyncSSLSocket::STATE_CACHE_LOOKUP));
656 state = STATE_SUCCEEDED;
// HandshakeErrorCallback: checks that calling sslAccept() a second time on the
// same socket is rejected.
// - The first call must leave the socket in STATE_ACCEPTING.
// - The second call, with a stack-local callback, must move the socket to
//   STATE_ERROR.
// - Both callbacks must then report STATE_FAILED, after which all states are
//   forced to success so the destructors' checks pass.
// callback2 lives on the stack. This is only safe because the second
// sslAccept() fails synchronously (checked immediately below), so the socket
// no longer refers to callback2 after this function returns.
661 class HandshakeErrorCallback: public SSLServerAcceptCallbackBase {
663 explicit HandshakeErrorCallback(HandshakeCallback *hcb):
664 SSLServerAcceptCallbackBase(hcb) {}
667 const std::shared_ptr<folly::AsyncSSLSocket> &s)
669 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
671 std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
673 // The first call to sslAccept() should succeed.
674 hcb_->setSocket(sock);
675 sock->sslAccept(hcb_);
676 EXPECT_EQ(sock->getSSLState(),
677 AsyncSSLSocket::STATE_ACCEPTING);
679 // The second call to sslAccept() should fail.
680 HandshakeCallback callback2(hcb_->rcb_);
681 callback2.setSocket(sock);
682 sock->sslAccept(&callback2);
683 EXPECT_EQ(sock->getSSLState(),
684 AsyncSSLSocket::STATE_ERROR);
686 // Both callbacks should be in the error state.
687 EXPECT_EQ(hcb_->state, STATE_FAILED);
688 EXPECT_EQ(callback2.state, STATE_FAILED);
690 state = STATE_SUCCEEDED;
691 hcb_->setState(STATE_SUCCEEDED);
692 callback2.setState(STATE_SUCCEEDED);
// HandshakeTimeoutCallback: defers sslAccept() via tryRunAfterDelay so the
// client has already given up. Inside the delayed lambda, the socket is
// checked to be in STATE_UNINIT before, and STATE_ACCEPTING right after,
// sslAccept() (which registers for an event); the accept is then expected to
// fail.
// NOTE(review): the [=] lambda implicitly captures `this` (it reads hcb_), so
// this object must outlive the delay. Implicit `this` capture via [=] is also
// deprecated in C++20.
696 class HandshakeTimeoutCallback: public SSLServerAcceptCallbackBase {
698 explicit HandshakeTimeoutCallback(HandshakeCallback *hcb):
699 SSLServerAcceptCallbackBase(hcb) {}
702 const std::shared_ptr<folly::AsyncSSLSocket> &s)
// Log under this class's own name; it previously reused
// HandshakeErrorCallback's label.
704 std::cerr << "HandshakeTimeoutCallback::connAccepted" << std::endl;
706 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
708 hcb_->setSocket(sock);
709 sock->getEventBase()->tryRunAfterDelay([=] {
710 std::cerr << "Delayed SSL accept, client will have close by now"
712 // SSL accept will fail
715 AsyncSSLSocket::STATE_UNINIT);
716 hcb_->socket_->sslAccept(hcb_);
717 // This registers for an event
720 AsyncSSLSocket::STATE_ACCEPTING);
722 state = STATE_SUCCEEDED;
// ConnectTimeoutCallback: server side for client connect-timeout tests. It
// never starts the TLS handshake; it just closes the accepted socket 100 ms
// later. It is marked succeeded up front because it may never be invoked.
727 class ConnectTimeoutCallback : public SSLServerAcceptCallbackBase {
729 ConnectTimeoutCallback() : SSLServerAcceptCallbackBase(nullptr) {
730 // We don't care if we get invoked or not.
731 // The client may time out and give up before connAccepted() is even
733 state = STATE_SUCCEEDED;
737 const std::shared_ptr<folly::AsyncSSLSocket>& s) noexcept override {
738 std::cerr << "ConnectTimeoutCallback::connAccepted" << std::endl;
740 // Just wait a while before closing the socket, so the client
741 // will time out waiting for the handshake to complete.
// [=] copies the shared_ptr `s`, keeping the socket alive until the delayed
// close runs.
742 s->getEventBase()->tryRunAfterDelay([=] { s->close(); }, 100);
// TestSSLAsyncCacheServer: TestSSLServer whose context is configured for a
// simulated asynchronous external session cache. Both of the following apply
// only when SSL_ERROR_WANT_SESS_CACHE_LOOKUP is available:
// - the internal cache is disabled;
// - getSessionCallback is installed as the lookup hook.
// NOTE(review): the counters and lookupDelay_ are static, so they are shared
// by every instance, and each constructor overwrites lookupDelay_. A negative
// `lookupDelay` would wrap when stored into the uint32_t.
746 class TestSSLAsyncCacheServer : public TestSSLServer {
748 explicit TestSSLAsyncCacheServer(SSLServerAcceptCallbackBase *acb,
749 int lookupDelay = 100) :
751 SSL_CTX *sslCtx = ctx_->getSSLCtx();
752 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
753 SSL_CTX_sess_set_get_cb(sslCtx,
754 TestSSLAsyncCacheServer::getSessionCallback);
756 SSL_CTX_set_session_cache_mode(
757 sslCtx, SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_SERVER);
760 lookupDelay_ = lookupDelay;
763 uint32_t getAsyncCallbacks() const { return asyncCallbacks_; }
764 uint32_t getAsyncLookups() const { return asyncLookups_; }
767 static uint32_t asyncCallbacks_;
768 static uint32_t asyncLookups_;
769 static uint32_t lookupDelay_;
// OpenSSL session-lookup hook. The updates to asyncCallbacks_/asyncLookups_
// are not visible in this view.
771 static SSL_SESSION* getSessionCallback(SSL* ssl,
772 unsigned char* /* sess_id */,
778 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
779 if (!SSL_want_sess_cache_lookup(ssl)) {
780 // libssl.so mismatch
781 std::cerr << "no async support" << std::endl;
785 AsyncSSLSocket *sslSocket =
786 AsyncSSLSocket::getFromSSL(ssl);
787 assert(sslSocket != nullptr);
788 // Simulate an async cache: depending on the parity of asyncCallbacks_, either return a miss (the retried lookup) or block and schedule restartSSLAccept() after lookupDelay_ ms
789 if (asyncCallbacks_ % 2 == 0) {
790 // This socket is already blocked on lookup, return miss
791 std::cerr << "returning miss" << std::endl;
793 // fresh meat - block it
794 std::cerr << "async lookup" << std::endl;
795 sslSocket->getEventBase()->tryRunAfterDelay(
796 std::bind(&AsyncSSLSocket::restartSSLAccept,
797 sslSocket), lookupDelay_);
798 *copyflag = SSL_SESSION_CB_WOULD_BLOCK;
// Test helpers defined elsewhere; their bodies are not part of this view.
// getfds() presumably fills `fds` with a connected socket pair (TODO confirm).
806 void getfds(int fds[2]);
// Helper taking client/server SSLContexts (its name line is not visible;
// presumably it configures them for the tests).
809 std::shared_ptr<folly::SSLContext> clientCtx,
810 std::shared_ptr<folly::SSLContext> serverCtx);
// Helper that produces a client/server AsyncSSLSocket pair on `eventBase`
// through the two out-parameters (its name line is not visible).
813 EventBase* eventBase,
814 AsyncSSLSocket::UniquePtr* clientSock,
815 AsyncSSLSocket::UniquePtr* serverSock);
// BlockingWriteClient: client side of the blocking-write test.
// - Allocates buf_ (bufLen_ bytes) and builds iovCount_ overlapping iovecs
//   into it.
// - Starts the TLS handshake with a 100 ms timeout.
// - In handshakeSuc() it issues a single writev() of all iovecs.
// - Handshake/write failures are reported with ADD_FAILURE().
// getIovec()/getIovecCount() expose the iovecs so the server can verify what
// it received.
817 class BlockingWriteClient :
818 private AsyncSSLSocket::HandshakeCB,
819 private AsyncTransportWrapper::WriteCallback {
821 explicit BlockingWriteClient(
822 AsyncSSLSocket::UniquePtr socket)
823 : socket_(std::move(socket)),
827 buf_.reset(new uint8_t[bufLen_]);
// Iterate over the whole allocation. The previous bound, sizeof(buf_), is the
// size of the std::unique_ptr object itself (pointer-sized), so all but the
// first few bytes of buf_ were left uninitialized before being sent. The loop
// body that writes buf_[n] is not visible in this view.
828 for (uint32_t n = 0; n < bufLen_; ++n) {
833 iov_.reset(new struct iovec[iovCount_]);
// iovec n starts at byte n of buf_, and its length is one of the two
// expressions below. The if/else lines that choose between them are not
// visible here.
// NOTE(review): with iov_len = n % bufLen_ an iovec ends at byte 2n, which
// lies past the end of buf_ once 2n > bufLen_. Confirm that the values of
// bufLen_ and iovCount_ (not visible here) keep every iovec in bounds.
834 for (uint32_t n = 0; n < iovCount_; ++n) {
835 iov_[n].iov_base = buf_.get() + n;
837 iov_[n].iov_len = n % bufLen_;
839 iov_[n].iov_len = bufLen_ - (n % bufLen_);
843 socket_->sslConn(this, std::chrono::milliseconds(100));
846 struct iovec* getIovec() const {
849 uint32_t getIovecCount() const {
// buf_ and iov_ are owned by this object, so they outlive the asynchronous
// writev().
854 void handshakeSuc(AsyncSSLSocket*) noexcept override {
855 socket_->writev(this, iov_.get(), iovCount_);
859 const AsyncSocketException& ex) noexcept override {
860 ADD_FAILURE() << "client handshake error: " << ex.what();
862 void writeSuccess() noexcept override {
867 const AsyncSocketException& ex) noexcept override {
868 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
872 AsyncSSLSocket::UniquePtr socket_;
875 std::unique_ptr<uint8_t[]> buf_;
876 std::unique_ptr<struct iovec[]> iov_;
// BlockingWriteServer: server side of the blocking-write test.
// - Accepts TLS with a 100 ms timeout.
// - Waits 10 ms before reading, so the client's writes initially block.
// - Reads into one large buffer, pausing the read callback for 2 ms after
//   every chunk.
// - checkBuffer() then verifies the received bytes equal the concatenation of
//   the client's iovecs.
// The lambdas capture `this`, so this object must outlive the scheduled
// callbacks.
879 class BlockingWriteServer :
880 private AsyncSSLSocket::HandshakeCB,
881 private AsyncTransportWrapper::ReadCallback {
883 explicit BlockingWriteServer(
884 AsyncSSLSocket::UniquePtr socket)
885 : socket_(std::move(socket)),
886 bufSize_(2500 * 2000),
888 buf_.reset(new uint8_t[bufSize_]);
889 socket_->sslAccept(this, std::chrono::milliseconds(100));
// Compares the received bytes with each iovec in order. FAIL() returns from
// checkBuffer(), so checking stops at the first mismatch or short read. Extra
// trailing data is reported with a non-fatal ADD_FAILURE().
892 void checkBuffer(struct iovec* iov, uint32_t count) const {
894 for (uint32_t n = 0; n < count; ++n) {
895 size_t bytesLeft = bytesRead_ - idx;
896 int rc = memcmp(buf_.get() + idx, iov[n].iov_base,
897 std::min(iov[n].iov_len, bytesLeft));
899 FAIL() << "buffer mismatch at iovec " << n << "/" << count
903 if (iov[n].iov_len > bytesLeft) {
904 FAIL() << "server did not read enough data: "
905 << "ended at byte " << bytesLeft << "/" << iov[n].iov_len
906 << " in iovec " << n << "/" << count;
909 idx += iov[n].iov_len;
911 if (idx != bytesRead_) {
912 ADD_FAILURE() << "server read extra data: " << bytesRead_
913 << " bytes read; expected " << idx;
918 void handshakeSuc(AsyncSSLSocket*) noexcept override {
919 // Wait 10ms before reading, so the client's writes will initially block.
920 socket_->getEventBase()->tryRunAfterDelay(
921 [this] { socket_->setReadCB(this); }, 10);
925 const AsyncSocketException& ex) noexcept override {
926 ADD_FAILURE() << "server handshake error: " << ex.what();
// Hands out the unused tail of buf_.
928 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
929 *bufReturn = buf_.get() + bytesRead_;
930 *lenReturn = bufSize_ - bytesRead_;
// Pauses reading, then re-arms it 2 ms later so the client's writes keep
// hitting back-pressure. (The bytesRead_ update is not visible here.)
932 void readDataAvailable(size_t len) noexcept override {
934 socket_->setReadCB(nullptr);
935 socket_->getEventBase()->tryRunAfterDelay(
936 [this] { socket_->setReadCB(this); }, 2);
938 void readEOF() noexcept override {
942 const AsyncSocketException& ex) noexcept override {
943 ADD_FAILURE() << "server read error: " << ex.what();
946 AsyncSSLSocket::UniquePtr socket_;
949 std::unique_ptr<uint8_t[]> buf_;
// Client-side next-protocol (NPN/ALPN) helper; its class-name line is not
// visible in this view. It starts the TLS handshake on construction and, on
// success, records the negotiated protocol in nextProto / nextProtoLength /
// protocolType for the test to inspect.
// NOTE(review): protocolType is not in the visible constructor initializer
// list, so it is indeterminate unless handshakeSuc() runs.
953 private AsyncSSLSocket::HandshakeCB,
954 private AsyncTransportWrapper::WriteCallback {
957 AsyncSSLSocket::UniquePtr socket)
958 : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
959 socket_->sslConn(this);
962 const unsigned char* nextProto;
963 unsigned nextProtoLength;
964 SSLContext::NextProtocolType protocolType;
// Presumably holds the handshake exception, if any (the assignment is not
// visible here).
965 folly::Optional<AsyncSocketException> except;
968 void handshakeSuc(AsyncSSLSocket*) noexcept override {
969 socket_->getSelectedNextProtocol(
970 &nextProto, &nextProtoLength, &protocolType);
974 const AsyncSocketException& ex) noexcept override {
977 void writeSuccess() noexcept override {
982 const AsyncSocketException& ex) noexcept override {
983 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
987 AsyncSSLSocket::UniquePtr socket_;
// NpnServer: server-side counterpart of the next-protocol helper. It accepts
// TLS on construction and, on success, records the negotiated protocol. Read
// errors are reported as test failures.
// NOTE(review): protocolType is not in the visible constructor initializer
// list, so it is indeterminate unless handshakeSuc() runs.
991 private AsyncSSLSocket::HandshakeCB,
992 private AsyncTransportWrapper::ReadCallback {
994 explicit NpnServer(AsyncSSLSocket::UniquePtr socket)
995 : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
996 socket_->sslAccept(this);
999 const unsigned char* nextProto;
1000 unsigned nextProtoLength;
1001 SSLContext::NextProtocolType protocolType;
1002 folly::Optional<AsyncSocketException> except;
1005 void handshakeSuc(AsyncSSLSocket*) noexcept override {
1006 socket_->getSelectedNextProtocol(
1007 &nextProto, &nextProtoLength, &protocolType);
1011 const AsyncSocketException& ex) noexcept override {
1014 void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
1017 void readDataAvailable(size_t /* len */) noexcept override {}
1018 void readEOF() noexcept override {
1022 const AsyncSocketException& ex) noexcept override {
1023 ADD_FAILURE() << "server read error: " << ex.what();
1026 AsyncSSLSocket::UniquePtr socket_;
// RenegotiatingServer: accepts TLS and starts reading into a 128-byte scratch
// buffer. It expects the client's renegotiation attempt to surface as a read
// error: an SSLException whose message contains the CLIENT_RENEGOTIATION text.
// When it does, it sets renegotiationError_. The ASSERT_NE checks return early
// on mismatch, leaving the flag false.
1029 class RenegotiatingServer : public AsyncSSLSocket::HandshakeCB,
1030 public AsyncTransportWrapper::ReadCallback {
1032 explicit RenegotiatingServer(AsyncSSLSocket::UniquePtr socket)
1033 : socket_(std::move(socket)) {
1034 socket_->sslAccept(this);
// Detach before destruction so the socket does not call back into a
// destroyed object.
1037 ~RenegotiatingServer() override {
1038 socket_->setReadCB(nullptr);
1041 void handshakeSuc(AsyncSSLSocket* /* socket */) noexcept override {
1042 LOG(INFO) << "Renegotiating server handshake success";
1043 socket_->setReadCB(this);
1047 const AsyncSocketException& ex) noexcept override {
1048 ADD_FAILURE() << "Renegotiating server handshake error: " << ex.what();
1050 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
1051 *lenReturn = sizeof(buf);
1054 void readDataAvailable(size_t /* len */) noexcept override {}
1055 void readEOF() noexcept override {}
1056 void readErr(const AsyncSocketException& ex) noexcept override {
1057 LOG(INFO) << "server got read error " << ex.what();
1058 auto exPtr = dynamic_cast<const SSLException*>(&ex);
1059 ASSERT_NE(nullptr, exPtr);
1060 std::string exStr(ex.what());
1061 SSLException sslEx(SSLError::CLIENT_RENEGOTIATION);
1062 ASSERT_NE(std::string::npos, exStr.find(sslEx.what()));
1063 renegotiationError_ = true;
1066 AsyncSSLSocket::UniquePtr socket_;
1067 unsigned char buf[128];
1068 bool renegotiationError_{false};
// SNI tests, compiled only when OpenSSL has TLS extensions. The matching
// #endif is not visible in this view.
1071 #ifndef OPENSSL_NO_TLSEXT
// Client-side SNI helper (its class-name line is not visible). It starts the
// TLS handshake on construction and, on success, records whether the server
// name matched (serverNameMatch). Handshake/write errors are test failures.
1073 private AsyncSSLSocket::HandshakeCB,
1074 private AsyncTransportWrapper::WriteCallback {
1077 AsyncSSLSocket::UniquePtr socket)
1078 : serverNameMatch(false), socket_(std::move(socket)) {
1079 socket_->sslConn(this);
1082 bool serverNameMatch;
1085 void handshakeSuc(AsyncSSLSocket*) noexcept override {
1086 serverNameMatch = socket_->isServerNameMatch();
1090 const AsyncSocketException& ex) noexcept override {
1091 ADD_FAILURE() << "client handshake error: " << ex.what();
1093 void writeSuccess() noexcept override {
1097 size_t bytesWritten,
1098 const AsyncSocketException& ex) noexcept override {
1099 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
1103 AsyncSSLSocket::UniquePtr socket_;
// SNIServer: server-side SNI helper. It installs serverNameCallback on `ctx`
// and then accepts TLS. When the client's server name equals
// expectedServerName (case-insensitive), the socket is switched to sniCtx and
// serverNameMatch is set.
// NOTE(review): `ctx` keeps a callback bound to `this`, so this object must
// outlive any handshake that uses `ctx`.
1107 private AsyncSSLSocket::HandshakeCB,
1108 private AsyncTransportWrapper::ReadCallback {
1111 AsyncSSLSocket::UniquePtr socket,
1112 const std::shared_ptr<folly::SSLContext>& ctx,
1113 const std::shared_ptr<folly::SSLContext>& sniCtx,
1114 const std::string& expectedServerName)
1115 : serverNameMatch(false), socket_(std::move(socket)), sniCtx_(sniCtx),
1116 expectedServerName_(expectedServerName) {
1117 ctx->setServerNameCallback(std::bind(&SNIServer::serverNameCallback, this,
1118 std::placeholders::_1));
1119 socket_->sslAccept(this);
1122 bool serverNameMatch;
1125 void handshakeSuc(AsyncSSLSocket* /* ssl */) noexcept override {}
1128 const AsyncSocketException& ex) noexcept override {
1129 ADD_FAILURE() << "server handshake error: " << ex.what();
1131 void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
1134 void readDataAvailable(size_t /* len */) noexcept override {}
1135 void readEOF() noexcept override {
1139 const AsyncSocketException& ex) noexcept override {
1140 ADD_FAILURE() << "server read error: " << ex.what();
// SSL_get_servername() may return null when the client sent no SNI.
// NOTE(review): the first half of the condition is not visible here; confirm
// it null-checks `sn` before strcasecmp().
1143 folly::SSLContext::ServerNameCallbackResult
1144 serverNameCallback(SSL *ssl) {
1145 const char *sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
1148 !strcasecmp(expectedServerName_.c_str(), sn)) {
1149 AsyncSSLSocket *sslSocket =
1150 AsyncSSLSocket::getFromSSL(ssl);
1151 sslSocket->switchServerSSLContext(sniCtx_);
1152 serverNameMatch = true;
1153 return folly::SSLContext::SERVER_NAME_FOUND;
1155 serverNameMatch = false;
1156 return folly::SSLContext::SERVER_NAME_NOT_FOUND;
1160 AsyncSSLSocket::UniquePtr socket_;
1161 std::shared_ptr<folly::SSLContext> sniCtx_;
1162 std::string expectedServerName_;
// SSLClient: session-resumption test client.
// - Each connect() creates a new AsyncSSLSocket, offering the cached session_
//   if there is one.
// - On connect it stores the negotiated session, writes buf_ ('a' bytes) and
//   reads the echo back into readbuf_ in small chunks.
// - It compares the echo and closes once sizeof(buf_) bytes have arrived.
// Session-reuse hits/misses and errors are counted; the increments and the
// reconnect driven by requests_ are on lines not visible here. The context
// disables session tickets (SSL_OP_NO_TICKET), presumably so resumption goes
// through the server's session cache.
1166 class SSLClient : public AsyncSocket::ConnectCallback,
1167 public AsyncTransportWrapper::WriteCallback,
1168 public AsyncTransportWrapper::ReadCallback
1171 EventBase *eventBase_;
1172 std::shared_ptr<AsyncSSLSocket> sslSocket_;
// Owned here: freed in the destructor and before being replaced in
// connectSuccess(). Presumably getSSLSession() returns a new reference.
1173 SSL_SESSION *session_;
1174 std::shared_ptr<folly::SSLContext> ctx_;
1176 folly::SocketAddress address_;
1180 uint32_t bytesRead_;
1184 uint32_t writeAfterConnectErrors_;
1186 // These settings test that we eventually drain the
1187 // socket, even if the kMaxReadsPerEvent limit is hit during
1188 // an event loop iteration.
1189 static constexpr size_t kMaxReadsPerEvent = 2;
1190 // 2 event loop iterations
// NOTE(review): std::min() in getReadBuffer() binds this constant by
// reference (an ODR-use). Before C++17 that requires an out-of-class
// definition; confirm one exists.
1191 static constexpr size_t kMaxReadBufferSz =
1192 sizeof(decltype(readbuf_)) / kMaxReadsPerEvent / 2;
1195 SSLClient(EventBase *eventBase,
1196 const folly::SocketAddress& address,
1198 uint32_t timeout = 0)
1199 : eventBase_(eventBase),
1201 requests_(requests),
1208 writeAfterConnectErrors_(0) {
1209 ctx_.reset(new folly::SSLContext());
1210 ctx_->setOptions(SSL_OP_NO_TICKET);
1211 ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1212 memset(buf_, 'a', sizeof(buf_));
1215 ~SSLClient() override {
1217 SSL_SESSION_free(session_);
1220 EXPECT_EQ(bytesRead_, sizeof(buf_));
1224 uint32_t getHit() const { return hit_; }
1226 uint32_t getMiss() const { return miss_; }
1228 uint32_t getErrors() const { return errors_; }
1230 uint32_t getWriteAfterConnectErrors() const {
1231 return writeAfterConnectErrors_;
// writeNow: also write immediately after starting the connect (used by an
// error test).
// NOTE(review): sslSocket_ was just assigned and dereferenced above, so the
// `sslSocket_ &&` check below is redundant.
1234 void connect(bool writeNow = false) {
1235 sslSocket_ = AsyncSSLSocket::newSocket(
1237 if (session_ != nullptr) {
1238 sslSocket_->setSSLSession(session_);
1241 sslSocket_->connect(this, address_, timeout_);
1242 if (sslSocket_ && writeNow) {
1243 // write some junk, used in an error test
1244 sslSocket_->write(this, buf_, sizeof(buf_));
1248 void connectSuccess() noexcept override {
1249 std::cerr << "client SSL socket connected" << std::endl;
1250 if (sslSocket_->getSSLSessionReused()) {
1254 if (session_ != nullptr) {
1255 SSL_SESSION_free(session_);
1257 session_ = sslSocket_->getSSLSession();
1261 sslSocket_->setMaxReadsPerEvent(kMaxReadsPerEvent);
1262 sslSocket_->write(this, buf_, sizeof(buf_));
1263 sslSocket_->setReadCB(this);
1264 memset(readbuf_, 'b', sizeof(readbuf_));
1269 const AsyncSocketException& ex) noexcept override {
1270 std::cerr << "SSLClient::connectError: " << ex.what() << std::endl;
1275 void writeSuccess() noexcept override {
1276 std::cerr << "client write success" << std::endl;
1279 void writeErr(size_t /* bytesWritten */,
1280 const AsyncSocketException& ex) noexcept override {
1281 std::cerr << "client writeError: " << ex.what() << std::endl;
1283 writeAfterConnectErrors_++;
// Caps each read at kMaxReadBufferSz so draining the echo takes several reads.
1287 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
1288 *bufReturn = readbuf_ + bytesRead_;
1289 *lenReturn = std::min(kMaxReadBufferSz, sizeof(readbuf_) - bytesRead_);
1292 void readEOF() noexcept override {
1293 std::cerr << "client readEOF" << std::endl;
1297 const AsyncSocketException& ex) noexcept override {
1298 std::cerr << "client readError: " << ex.what() << std::endl;
// Once the full echo has arrived, verify it and close the socket. The
// bytesRead_ update and the requests_ handling are not visible here.
1301 void readDataAvailable(size_t len) noexcept override {
1302 std::cerr << "client read data: " << len << std::endl;
1304 if (bytesRead_ == sizeof(buf_)) {
1305 EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
1306 sslSocket_->closeNow();
1308 if (requests_ != 0) {
// SSLHandshakeBase: shared state for the handshake-verification tests.
// - Flags record which handshake hooks fired.
// - handshakeVer() checks that OpenSSL's preverify result equals
//   preverifyResult and returns the configured verifyResult as the
//   verification decision.
// - handshakeTime captures socket_->getHandshakeTime() on success or error.
// NOTE(review): handshakeTime is not in the constructor initializer list, so
// it is indeterminate until one of those hooks runs.
1316 class SSLHandshakeBase :
1317 public AsyncSSLSocket::HandshakeCB,
1318 private AsyncTransportWrapper::WriteCallback {
1320 explicit SSLHandshakeBase(
1321 AsyncSSLSocket::UniquePtr socket,
1322 bool preverifyResult,
1323 bool verifyResult) :
1324 handshakeVerify_(false),
1325 handshakeSuccess_(false),
1326 handshakeError_(false),
1327 socket_(std::move(socket)),
1328 preverifyResult_(preverifyResult),
1329 verifyResult_(verifyResult) {
// Rvalue-qualified: callable only on an rvalue, e.g.
// std::move(obj).moveSocket(), which leaves this object without a socket.
1332 AsyncSSLSocket::UniquePtr moveSocket() && {
1333 return std::move(socket_);
1336 bool handshakeVerify_;
1337 bool handshakeSuccess_;
1338 bool handshakeError_;
1339 std::chrono::nanoseconds handshakeTime;
1342 AsyncSSLSocket::UniquePtr socket_;
1343 bool preverifyResult_;
1346 // HandshakeCallback
1347 bool handshakeVer(AsyncSSLSocket* /* sock */,
1349 X509_STORE_CTX* /* ctx */) noexcept override {
1350 handshakeVerify_ = true;
1352 EXPECT_EQ(preverifyResult_, preverifyOk);
1353 return verifyResult_;
1356 void handshakeSuc(AsyncSSLSocket*) noexcept override {
1357 LOG(INFO) << "Handshake success";
1358 handshakeSuccess_ = true;
1359 handshakeTime = socket_->getHandshakeTime();
1364 const AsyncSocketException& ex) noexcept override {
1365 LOG(INFO) << "Handshake error " << ex.what();
1366 handshakeError_ = true;
1367 handshakeTime = socket_->getHandshakeTime();
1371 void writeSuccess() noexcept override {
1376 size_t bytesWritten,
1377 const AsyncSocketException& ex) noexcept override {
1378 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
// SSLHandshakeClient: starts a client handshake immediately, with no timeout
// (zero) and the context's default peer-verification mode.
1383 class SSLHandshakeClient : public SSLHandshakeBase {
1386 AsyncSSLSocket::UniquePtr socket,
1387 bool preverifyResult,
1388 bool verifyResult) :
1389 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1390 socket_->sslConn(this, std::chrono::milliseconds::zero());
// SSLHandshakeClientNoVerify: client handshake with no timeout and peer
// verification explicitly disabled (NO_VERIFY).
1394 class SSLHandshakeClientNoVerify : public SSLHandshakeBase {
1396 SSLHandshakeClientNoVerify(
1397 AsyncSSLSocket::UniquePtr socket,
1398 bool preverifyResult,
1399 bool verifyResult) :
1400 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1403 std::chrono::milliseconds::zero(),
1404 folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
// SSLHandshakeClientDoVerify: client handshake with no timeout and peer
// verification explicitly enabled (VERIFY).
1408 class SSLHandshakeClientDoVerify : public SSLHandshakeBase {
1410 SSLHandshakeClientDoVerify(
1411 AsyncSSLSocket::UniquePtr socket,
1412 bool preverifyResult,
1413 bool verifyResult) :
1414 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1417 std::chrono::milliseconds::zero(),
1418 folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
// SSLHandshakeServer: starts a server handshake immediately, with no timeout
// and the context's default peer-verification mode.
1422 class SSLHandshakeServer : public SSLHandshakeBase {
1425 AsyncSSLSocket::UniquePtr socket,
1426 bool preverifyResult,
1428 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1429 socket_->sslAccept(this, std::chrono::milliseconds::zero());
// SSLHandshakeServerParseClientHello: server handshake with ClientHello
// parsing enabled. On success it captures the client's, server's and shared
// cipher lists plus the negotiated cipher name for the test to inspect.
// NOTE(review): unlike SSLHandshakeBase::handshakeSuc(), the visible lines of
// this override do not record handshakeTime.
1433 class SSLHandshakeServerParseClientHello : public SSLHandshakeBase {
1435 SSLHandshakeServerParseClientHello(
1436 AsyncSSLSocket::UniquePtr socket,
1437 bool preverifyResult,
1439 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
// Must be enabled before sslAccept() so the ClientHello is captured.
1440 socket_->enableClientHelloParsing();
1441 socket_->sslAccept(this, std::chrono::milliseconds::zero());
1444 std::string clientCiphers_, sharedCiphers_, serverCiphers_, chosenCipher_;
1447 void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
1448 handshakeSuccess_ = true;
1449 sock->getSSLSharedCiphers(sharedCiphers_);
1450 sock->getSSLServerCiphers(serverCiphers_);
1451 sock->getSSLClientCiphers(clientCiphers_);
1452 chosenCipher_ = sock->getNegotiatedCipherName();
// SSLHandshakeServerNoVerify: server handshake with no timeout and client
// verification explicitly disabled (NO_VERIFY).
1457 class SSLHandshakeServerNoVerify : public SSLHandshakeBase {
1459 SSLHandshakeServerNoVerify(
1460 AsyncSSLSocket::UniquePtr socket,
1461 bool preverifyResult,
1463 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1466 std::chrono::milliseconds::zero(),
1467 folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
// SSLHandshakeServerDoVerify: server handshake with no timeout that requires
// a client certificate (VERIFY_REQ_CLIENT_CERT).
1471 class SSLHandshakeServerDoVerify : public SSLHandshakeBase {
1473 SSLHandshakeServerDoVerify(
1474 AsyncSSLSocket::UniquePtr socket,
1475 bool preverifyResult,
1477 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1480 std::chrono::milliseconds::zero(),
1481 folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
// EventBaseAborter: test watchdog. It schedules an internal AsyncTimeout on
// `eventBase` at construction. If that timeout fires, it records a test
// failure and asks the event loop to stop, so a stuck test ends instead of
// hanging.
1485 class EventBaseAborter : public AsyncTimeout {
1487 EventBaseAborter(EventBase* eventBase,
1490 eventBase, AsyncTimeout::InternalEnum::INTERNAL)
1491 , eventBase_(eventBase) {
1492 scheduleTimeout(timeoutMS);
// Use ADD_FAILURE() rather than FAIL(). FAIL() expands to a `return`, which
// made the terminateLoopSoon() call below unreachable, so the loop was never
// asked to stop when the watchdog fired.
1495 void timeoutExpired() noexcept override {
1496 ADD_FAILURE() << "test timed out";
1497 eventBase_->terminateLoopSoon();
// Non-owning; the event loop must outlive this watchdog.
1501 EventBase* eventBase_;