2 * Copyright 2016 Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
21 #include <folly/ExceptionWrapper.h>
22 #include <folly/SocketAddress.h>
23 #include <folly/io/async/AsyncSSLSocket.h>
24 #include <folly/io/async/AsyncServerSocket.h>
25 #include <folly/io/async/AsyncSocket.h>
26 #include <folly/io/async/AsyncTimeout.h>
27 #include <folly/io/async/AsyncTransport.h>
28 #include <folly/io/async/EventBase.h>
29 #include <folly/io/async/ssl/SSLErrors.h>
30 #include <folly/portability/Sockets.h>
31 #include <folly/portability/Unistd.h>
33 #include <gtest/gtest.h>
37 #include <sys/types.h>
// The destructors of all callback classes assert that the state is
// STATE_SUCCEEDED, for both positive and negative tests. The tests
// are responsible for setting the succeeded state properly before the
// destructors are called.
52 class WriteCallbackBase :
53 public AsyncTransportWrapper::WriteCallback {
56 : state(STATE_WAITING)
58 , exception(AsyncSocketException::UNKNOWN, "none") {}
60 ~WriteCallbackBase() {
61 EXPECT_EQ(STATE_SUCCEEDED, state);
65 const std::shared_ptr<AsyncSSLSocket> &socket) {
69 void writeSuccess() noexcept override {
70 std::cerr << "writeSuccess" << std::endl;
71 state = STATE_SUCCEEDED;
76 const AsyncSocketException& ex) noexcept override {
77 std::cerr << "writeError: bytesWritten " << bytesWritten
78 << ", exception " << ex.what() << std::endl;
81 this->bytesWritten = bytesWritten;
84 socket_->detachEventBase();
87 std::shared_ptr<AsyncSSLSocket> socket_;
90 AsyncSocketException exception;
93 class ReadCallbackBase :
94 public AsyncTransportWrapper::ReadCallback {
96 explicit ReadCallbackBase(WriteCallbackBase* wcb)
97 : wcb_(wcb), state(STATE_WAITING) {}
100 EXPECT_EQ(state, STATE_SUCCEEDED);
104 const std::shared_ptr<AsyncSSLSocket> &socket) {
108 void setState(StateEnum s) {
116 const AsyncSocketException& ex) noexcept override {
117 std::cerr << "readError " << ex.what() << std::endl;
118 state = STATE_FAILED;
120 socket_->detachEventBase();
123 void readEOF() noexcept override {
124 std::cerr << "readEOF" << std::endl;
127 socket_->detachEventBase();
130 std::shared_ptr<AsyncSSLSocket> socket_;
131 WriteCallbackBase *wcb_;
135 class ReadCallback : public ReadCallbackBase {
137 explicit ReadCallback(WriteCallbackBase *wcb)
138 : ReadCallbackBase(wcb)
142 for (std::vector<Buffer>::iterator it = buffers.begin();
147 currentBuffer.free();
150 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
151 if (!currentBuffer.buffer) {
152 currentBuffer.allocate(4096);
154 *bufReturn = currentBuffer.buffer;
155 *lenReturn = currentBuffer.length;
158 void readDataAvailable(size_t len) noexcept override {
159 std::cerr << "readDataAvailable, len " << len << std::endl;
161 currentBuffer.length = len;
163 wcb_->setSocket(socket_);
165 // Write back the same data.
166 socket_->write(wcb_, currentBuffer.buffer, len);
168 buffers.push_back(currentBuffer);
169 currentBuffer.reset();
170 state = STATE_SUCCEEDED;
175 Buffer() : buffer(nullptr), length(0) {}
176 Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
182 void allocate(size_t length) {
183 assert(buffer == nullptr);
184 this->buffer = static_cast<char*>(malloc(length));
185 this->length = length;
196 std::vector<Buffer> buffers;
197 Buffer currentBuffer;
200 class ReadErrorCallback : public ReadCallbackBase {
202 explicit ReadErrorCallback(WriteCallbackBase *wcb)
203 : ReadCallbackBase(wcb) {}
205 // Return nullptr buffer to trigger readError()
206 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
207 *bufReturn = nullptr;
211 void readDataAvailable(size_t /* len */) noexcept override {
// This should never be called.
217 const AsyncSocketException& ex) noexcept override {
218 ReadCallbackBase::readErr(ex);
219 std::cerr << "ReadErrorCallback::readError" << std::endl;
220 setState(STATE_SUCCEEDED);
224 class ReadEOFCallback : public ReadCallbackBase {
226 explicit ReadEOFCallback(WriteCallbackBase* wcb) : ReadCallbackBase(wcb) {}
228 // Return nullptr buffer to trigger readError()
229 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
230 *bufReturn = nullptr;
234 void readDataAvailable(size_t /* len */) noexcept override {
// This should never be called.
239 void readEOF() noexcept override {
240 ReadCallbackBase::readEOF();
241 setState(STATE_SUCCEEDED);
245 class WriteErrorCallback : public ReadCallback {
247 explicit WriteErrorCallback(WriteCallbackBase *wcb)
248 : ReadCallback(wcb) {}
250 void readDataAvailable(size_t len) noexcept override {
251 std::cerr << "readDataAvailable, len " << len << std::endl;
253 currentBuffer.length = len;
255 // close the socket before writing to trigger writeError().
256 ::close(socket_->getFd());
258 wcb_->setSocket(socket_);
260 // Write back the same data.
261 socket_->write(wcb_, currentBuffer.buffer, len);
263 if (wcb_->state == STATE_FAILED) {
264 setState(STATE_SUCCEEDED);
266 state = STATE_FAILED;
269 buffers.push_back(currentBuffer);
270 currentBuffer.reset();
273 void readErr(const AsyncSocketException& ex) noexcept override {
274 std::cerr << "readError " << ex.what() << std::endl;
275 // do nothing since this is expected
279 class EmptyReadCallback : public ReadCallback {
281 explicit EmptyReadCallback()
282 : ReadCallback(nullptr) {}
284 void readErr(const AsyncSocketException& ex) noexcept override {
285 std::cerr << "readError " << ex.what() << std::endl;
286 state = STATE_FAILED;
288 tcpSocket_->detachEventBase();
291 void readEOF() noexcept override {
292 std::cerr << "readEOF" << std::endl;
295 tcpSocket_->detachEventBase();
296 state = STATE_SUCCEEDED;
299 std::shared_ptr<AsyncSocket> tcpSocket_;
302 class HandshakeCallback :
303 public AsyncSSLSocket::HandshakeCB {
310 explicit HandshakeCallback(ReadCallbackBase *rcb,
311 ExpectType expect = EXPECT_SUCCESS):
312 state(STATE_WAITING),
317 const std::shared_ptr<AsyncSSLSocket> &socket) {
321 void setState(StateEnum s) {
326 // Functions inherited from AsyncSSLSocketHandshakeCallback
327 void handshakeSuc(AsyncSSLSocket *sock) noexcept override {
328 std::lock_guard<std::mutex> g(mutex_);
330 EXPECT_EQ(sock, socket_.get());
331 std::cerr << "HandshakeCallback::connectionAccepted" << std::endl;
332 rcb_->setSocket(socket_);
333 sock->setReadCB(rcb_);
334 state = (expect_ == EXPECT_SUCCESS) ? STATE_SUCCEEDED : STATE_FAILED;
336 void handshakeErr(AsyncSSLSocket* /* sock */,
337 const AsyncSocketException& ex) noexcept override {
338 std::lock_guard<std::mutex> g(mutex_);
340 std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl;
341 state = (expect_ == EXPECT_ERROR) ? STATE_SUCCEEDED : STATE_FAILED;
342 if (expect_ == EXPECT_ERROR) {
343 // rcb will never be invoked
344 rcb_->setState(STATE_SUCCEEDED);
346 errorString_ = ex.what();
349 void waitForHandshake() {
350 std::unique_lock<std::mutex> lock(mutex_);
351 cv_.wait(lock, [this] { return state != STATE_WAITING; });
354 ~HandshakeCallback() {
355 EXPECT_EQ(state, STATE_SUCCEEDED);
360 state = STATE_SUCCEEDED;
363 std::shared_ptr<AsyncSSLSocket> getSocket() {
368 std::shared_ptr<AsyncSSLSocket> socket_;
369 ReadCallbackBase *rcb_;
372 std::condition_variable cv_;
373 std::string errorString_;
376 class SSLServerAcceptCallbackBase:
377 public folly::AsyncServerSocket::AcceptCallback {
379 explicit SSLServerAcceptCallbackBase(HandshakeCallback *hcb):
380 state(STATE_WAITING), hcb_(hcb) {}
382 ~SSLServerAcceptCallbackBase() {
383 EXPECT_EQ(state, STATE_SUCCEEDED);
386 void acceptError(const std::exception& ex) noexcept override {
387 std::cerr << "SSLServerAcceptCallbackBase::acceptError "
388 << ex.what() << std::endl;
389 state = STATE_FAILED;
392 void connectionAccepted(
393 int fd, const folly::SocketAddress& /* clientAddr */) noexcept override {
394 printf("Connection accepted\n");
395 std::shared_ptr<AsyncSSLSocket> sslSock;
397 // Create a AsyncSSLSocket object with the fd. The socket should be
398 // added to the event base and in the state of accepting SSL connection.
399 sslSock = AsyncSSLSocket::newSocket(ctx_, base_, fd);
400 } catch (const std::exception &e) {
401 LOG(ERROR) << "Exception %s caught while creating a AsyncSSLSocket "
402 "object with socket " << e.what() << fd;
408 connAccepted(sslSock);
411 virtual void connAccepted(
412 const std::shared_ptr<folly::AsyncSSLSocket> &s) = 0;
415 HandshakeCallback *hcb_;
416 std::shared_ptr<folly::SSLContext> ctx_;
417 folly::EventBase* base_;
420 class SSLServerAcceptCallback: public SSLServerAcceptCallbackBase {
424 explicit SSLServerAcceptCallback(HandshakeCallback *hcb,
425 uint32_t timeout = 0):
426 SSLServerAcceptCallbackBase(hcb),
429 virtual ~SSLServerAcceptCallback() {
431 // if we set a timeout, we expect failure
432 EXPECT_EQ(hcb_->state, STATE_FAILED);
433 hcb_->setState(STATE_SUCCEEDED);
437 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
439 const std::shared_ptr<folly::AsyncSSLSocket> &s)
441 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
442 std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
444 hcb_->setSocket(sock);
445 sock->sslAccept(hcb_, timeout_);
446 EXPECT_EQ(sock->getSSLState(),
447 AsyncSSLSocket::STATE_ACCEPTING);
449 state = STATE_SUCCEEDED;
453 class SSLServerAcceptCallbackDelay: public SSLServerAcceptCallback {
455 explicit SSLServerAcceptCallbackDelay(HandshakeCallback *hcb):
456 SSLServerAcceptCallback(hcb) {}
458 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
460 const std::shared_ptr<folly::AsyncSSLSocket> &s)
463 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
465 std::cerr << "SSLServerAcceptCallbackDelay::connAccepted"
467 int fd = sock->getFd();
471 // The accepted connection should already have TCP_NODELAY set
473 socklen_t valueLength = sizeof(value);
474 int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
480 // Unset the TCP_NODELAY option.
482 socklen_t valueLength = sizeof(value);
483 int rc = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, valueLength);
486 rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
490 SSLServerAcceptCallback::connAccepted(sock);
494 class SSLServerAsyncCacheAcceptCallback: public SSLServerAcceptCallback {
496 explicit SSLServerAsyncCacheAcceptCallback(HandshakeCallback *hcb,
497 uint32_t timeout = 0):
498 SSLServerAcceptCallback(hcb, timeout) {}
500 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
502 const std::shared_ptr<folly::AsyncSSLSocket> &s)
504 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
506 std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
508 hcb_->setSocket(sock);
509 sock->sslAccept(hcb_, timeout_);
510 ASSERT_TRUE((sock->getSSLState() ==
511 AsyncSSLSocket::STATE_ACCEPTING) ||
512 (sock->getSSLState() ==
513 AsyncSSLSocket::STATE_CACHE_LOOKUP));
515 state = STATE_SUCCEEDED;
520 class HandshakeErrorCallback: public SSLServerAcceptCallbackBase {
522 explicit HandshakeErrorCallback(HandshakeCallback *hcb):
523 SSLServerAcceptCallbackBase(hcb) {}
525 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
527 const std::shared_ptr<folly::AsyncSSLSocket> &s)
529 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
531 std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
533 // The first call to sslAccept() should succeed.
534 hcb_->setSocket(sock);
535 sock->sslAccept(hcb_);
536 EXPECT_EQ(sock->getSSLState(),
537 AsyncSSLSocket::STATE_ACCEPTING);
539 // The second call to sslAccept() should fail.
540 HandshakeCallback callback2(hcb_->rcb_);
541 callback2.setSocket(sock);
542 sock->sslAccept(&callback2);
543 EXPECT_EQ(sock->getSSLState(),
544 AsyncSSLSocket::STATE_ERROR);
546 // Both callbacks should be in the error state.
547 EXPECT_EQ(hcb_->state, STATE_FAILED);
548 EXPECT_EQ(callback2.state, STATE_FAILED);
550 sock->detachEventBase();
552 state = STATE_SUCCEEDED;
553 hcb_->setState(STATE_SUCCEEDED);
554 callback2.setState(STATE_SUCCEEDED);
558 class HandshakeTimeoutCallback: public SSLServerAcceptCallbackBase {
560 explicit HandshakeTimeoutCallback(HandshakeCallback *hcb):
561 SSLServerAcceptCallbackBase(hcb) {}
563 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
565 const std::shared_ptr<folly::AsyncSSLSocket> &s)
567 std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
569 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
571 hcb_->setSocket(sock);
572 sock->getEventBase()->tryRunAfterDelay([=] {
573 std::cerr << "Delayed SSL accept, client will have close by now"
575 // SSL accept will fail
578 AsyncSSLSocket::STATE_UNINIT);
579 hcb_->socket_->sslAccept(hcb_);
580 // This registers for an event
583 AsyncSSLSocket::STATE_ACCEPTING);
585 state = STATE_SUCCEEDED;
591 class TestSSLServer {
594 std::shared_ptr<folly::SSLContext> ctx_;
595 SSLServerAcceptCallbackBase *acb_;
596 std::shared_ptr<folly::AsyncServerSocket> socket_;
597 folly::SocketAddress address_;
600 static void *Main(void *ctx) {
601 TestSSLServer *self = static_cast<TestSSLServer*>(ctx);
603 std::cerr << "Server thread exited event loop" << std::endl;
608 // Create a TestSSLServer.
609 // This immediately starts listening on the given port.
610 explicit TestSSLServer(SSLServerAcceptCallbackBase *acb);
614 evb_.runInEventBaseThread([&](){
615 socket_->stopAccepting();
617 std::cerr << "Waiting for server thread to exit" << std::endl;
618 pthread_join(thread_, nullptr);
621 EventBase &getEventBase() { return evb_; }
623 const folly::SocketAddress& getAddress() const {
628 class TestSSLAsyncCacheServer : public TestSSLServer {
630 explicit TestSSLAsyncCacheServer(SSLServerAcceptCallbackBase *acb,
631 int lookupDelay = 100) :
633 SSL_CTX *sslCtx = ctx_->getSSLCtx();
634 SSL_CTX_sess_set_get_cb(sslCtx,
635 TestSSLAsyncCacheServer::getSessionCallback);
636 SSL_CTX_set_session_cache_mode(
637 sslCtx, SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_SERVER);
640 lookupDelay_ = lookupDelay;
643 uint32_t getAsyncCallbacks() const { return asyncCallbacks_; }
644 uint32_t getAsyncLookups() const { return asyncLookups_; }
647 static uint32_t asyncCallbacks_;
648 static uint32_t asyncLookups_;
649 static uint32_t lookupDelay_;
651 static SSL_SESSION* getSessionCallback(SSL* ssl,
652 unsigned char* /* sess_id */,
657 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
658 if (!SSL_want_sess_cache_lookup(ssl)) {
659 // libssl.so mismatch
660 std::cerr << "no async support" << std::endl;
664 AsyncSSLSocket *sslSocket =
665 AsyncSSLSocket::getFromSSL(ssl);
666 assert(sslSocket != nullptr);
// Simulate an async cache by delaying the miss by lookupDelay_ milliseconds
668 if (asyncCallbacks_ % 2 == 0) {
669 // This socket is already blocked on lookup, return miss
670 std::cerr << "returning miss" << std::endl;
672 // fresh meat - block it
673 std::cerr << "async lookup" << std::endl;
674 sslSocket->getEventBase()->tryRunAfterDelay(
675 std::bind(&AsyncSSLSocket::restartSSLAccept,
676 sslSocket), lookupDelay_);
677 *copyflag = SSL_SESSION_CB_WOULD_BLOCK;
685 void getfds(int fds[2]);
688 std::shared_ptr<folly::SSLContext> clientCtx,
689 std::shared_ptr<folly::SSLContext> serverCtx);
692 EventBase* eventBase,
693 AsyncSSLSocket::UniquePtr* clientSock,
694 AsyncSSLSocket::UniquePtr* serverSock);
696 class BlockingWriteClient :
697 private AsyncSSLSocket::HandshakeCB,
698 private AsyncTransportWrapper::WriteCallback {
700 explicit BlockingWriteClient(
701 AsyncSSLSocket::UniquePtr socket)
702 : socket_(std::move(socket)),
706 buf_.reset(new uint8_t[bufLen_]);
707 for (uint32_t n = 0; n < sizeof(buf_); ++n) {
712 iov_.reset(new struct iovec[iovCount_]);
713 for (uint32_t n = 0; n < iovCount_; ++n) {
714 iov_[n].iov_base = buf_.get() + n;
716 iov_[n].iov_len = n % bufLen_;
718 iov_[n].iov_len = bufLen_ - (n % bufLen_);
722 socket_->sslConn(this, 100);
725 struct iovec* getIovec() const {
728 uint32_t getIovecCount() const {
733 void handshakeSuc(AsyncSSLSocket*) noexcept override {
734 socket_->writev(this, iov_.get(), iovCount_);
738 const AsyncSocketException& ex) noexcept override {
739 ADD_FAILURE() << "client handshake error: " << ex.what();
741 void writeSuccess() noexcept override {
746 const AsyncSocketException& ex) noexcept override {
747 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
751 AsyncSSLSocket::UniquePtr socket_;
754 std::unique_ptr<uint8_t[]> buf_;
755 std::unique_ptr<struct iovec[]> iov_;
758 class BlockingWriteServer :
759 private AsyncSSLSocket::HandshakeCB,
760 private AsyncTransportWrapper::ReadCallback {
762 explicit BlockingWriteServer(
763 AsyncSSLSocket::UniquePtr socket)
764 : socket_(std::move(socket)),
765 bufSize_(2500 * 2000),
767 buf_.reset(new uint8_t[bufSize_]);
768 socket_->sslAccept(this, 100);
771 void checkBuffer(struct iovec* iov, uint32_t count) const {
773 for (uint32_t n = 0; n < count; ++n) {
774 size_t bytesLeft = bytesRead_ - idx;
775 int rc = memcmp(buf_.get() + idx, iov[n].iov_base,
776 std::min(iov[n].iov_len, bytesLeft));
778 FAIL() << "buffer mismatch at iovec " << n << "/" << count
782 if (iov[n].iov_len > bytesLeft) {
783 FAIL() << "server did not read enough data: "
784 << "ended at byte " << bytesLeft << "/" << iov[n].iov_len
785 << " in iovec " << n << "/" << count;
788 idx += iov[n].iov_len;
790 if (idx != bytesRead_) {
791 ADD_FAILURE() << "server read extra data: " << bytesRead_
792 << " bytes read; expected " << idx;
797 void handshakeSuc(AsyncSSLSocket*) noexcept override {
798 // Wait 10ms before reading, so the client's writes will initially block.
799 socket_->getEventBase()->tryRunAfterDelay(
800 [this] { socket_->setReadCB(this); }, 10);
804 const AsyncSocketException& ex) noexcept override {
805 ADD_FAILURE() << "server handshake error: " << ex.what();
807 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
808 *bufReturn = buf_.get() + bytesRead_;
809 *lenReturn = bufSize_ - bytesRead_;
811 void readDataAvailable(size_t len) noexcept override {
813 socket_->setReadCB(nullptr);
814 socket_->getEventBase()->tryRunAfterDelay(
815 [this] { socket_->setReadCB(this); }, 2);
817 void readEOF() noexcept override {
821 const AsyncSocketException& ex) noexcept override {
822 ADD_FAILURE() << "server read error: " << ex.what();
825 AsyncSSLSocket::UniquePtr socket_;
828 std::unique_ptr<uint8_t[]> buf_;
832 private AsyncSSLSocket::HandshakeCB,
833 private AsyncTransportWrapper::WriteCallback {
836 AsyncSSLSocket::UniquePtr socket)
837 : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
838 socket_->sslConn(this);
841 const unsigned char* nextProto;
842 unsigned nextProtoLength;
843 SSLContext::NextProtocolType protocolType;
846 void handshakeSuc(AsyncSSLSocket*) noexcept override {
847 socket_->getSelectedNextProtocol(
848 &nextProto, &nextProtoLength, &protocolType);
852 const AsyncSocketException& ex) noexcept override {
853 ADD_FAILURE() << "client handshake error: " << ex.what();
855 void writeSuccess() noexcept override {
860 const AsyncSocketException& ex) noexcept override {
861 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
865 AsyncSSLSocket::UniquePtr socket_;
869 private AsyncSSLSocket::HandshakeCB,
870 private AsyncTransportWrapper::ReadCallback {
872 explicit NpnServer(AsyncSSLSocket::UniquePtr socket)
873 : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
874 socket_->sslAccept(this);
877 const unsigned char* nextProto;
878 unsigned nextProtoLength;
879 SSLContext::NextProtocolType protocolType;
882 void handshakeSuc(AsyncSSLSocket*) noexcept override {
883 socket_->getSelectedNextProtocol(
884 &nextProto, &nextProtoLength, &protocolType);
888 const AsyncSocketException& ex) noexcept override {
889 ADD_FAILURE() << "server handshake error: " << ex.what();
891 void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
894 void readDataAvailable(size_t /* len */) noexcept override {}
895 void readEOF() noexcept override {
899 const AsyncSocketException& ex) noexcept override {
900 ADD_FAILURE() << "server read error: " << ex.what();
903 AsyncSSLSocket::UniquePtr socket_;
906 class RenegotiatingServer : public AsyncSSLSocket::HandshakeCB,
907 public AsyncTransportWrapper::ReadCallback {
909 explicit RenegotiatingServer(AsyncSSLSocket::UniquePtr socket)
910 : socket_(std::move(socket)) {
911 socket_->sslAccept(this);
914 ~RenegotiatingServer() {
915 socket_->setReadCB(nullptr);
918 void handshakeSuc(AsyncSSLSocket* /* socket */) noexcept override {
919 LOG(INFO) << "Renegotiating server handshake success";
920 socket_->setReadCB(this);
924 const AsyncSocketException& ex) noexcept override {
925 ADD_FAILURE() << "Renegotiating server handshake error: " << ex.what();
927 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
928 *lenReturn = sizeof(buf);
931 void readDataAvailable(size_t /* len */) noexcept override {}
932 void readEOF() noexcept override {}
933 void readErr(const AsyncSocketException& ex) noexcept override {
934 LOG(INFO) << "server got read error " << ex.what();
935 auto exPtr = dynamic_cast<const SSLException*>(&ex);
936 ASSERT_NE(nullptr, exPtr);
937 std::string exStr(ex.what());
938 SSLException sslEx(SSLError::CLIENT_RENEGOTIATION);
939 ASSERT_NE(std::string::npos, exStr.find(sslEx.what()));
940 renegotiationError_ = true;
943 AsyncSSLSocket::UniquePtr socket_;
944 unsigned char buf[128];
945 bool renegotiationError_{false};
948 #ifndef OPENSSL_NO_TLSEXT
950 private AsyncSSLSocket::HandshakeCB,
951 private AsyncTransportWrapper::WriteCallback {
954 AsyncSSLSocket::UniquePtr socket)
955 : serverNameMatch(false), socket_(std::move(socket)) {
956 socket_->sslConn(this);
959 bool serverNameMatch;
962 void handshakeSuc(AsyncSSLSocket*) noexcept override {
963 serverNameMatch = socket_->isServerNameMatch();
967 const AsyncSocketException& ex) noexcept override {
968 ADD_FAILURE() << "client handshake error: " << ex.what();
970 void writeSuccess() noexcept override {
975 const AsyncSocketException& ex) noexcept override {
976 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
980 AsyncSSLSocket::UniquePtr socket_;
984 private AsyncSSLSocket::HandshakeCB,
985 private AsyncTransportWrapper::ReadCallback {
988 AsyncSSLSocket::UniquePtr socket,
989 const std::shared_ptr<folly::SSLContext>& ctx,
990 const std::shared_ptr<folly::SSLContext>& sniCtx,
991 const std::string& expectedServerName)
992 : serverNameMatch(false), socket_(std::move(socket)), sniCtx_(sniCtx),
993 expectedServerName_(expectedServerName) {
994 ctx->setServerNameCallback(std::bind(&SNIServer::serverNameCallback, this,
995 std::placeholders::_1));
996 socket_->sslAccept(this);
999 bool serverNameMatch;
1002 void handshakeSuc(AsyncSSLSocket* /* ssl */) noexcept override {}
1005 const AsyncSocketException& ex) noexcept override {
1006 ADD_FAILURE() << "server handshake error: " << ex.what();
1008 void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
1011 void readDataAvailable(size_t /* len */) noexcept override {}
1012 void readEOF() noexcept override {
1016 const AsyncSocketException& ex) noexcept override {
1017 ADD_FAILURE() << "server read error: " << ex.what();
1020 folly::SSLContext::ServerNameCallbackResult
1021 serverNameCallback(SSL *ssl) {
1022 const char *sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
1025 !strcasecmp(expectedServerName_.c_str(), sn)) {
1026 AsyncSSLSocket *sslSocket =
1027 AsyncSSLSocket::getFromSSL(ssl);
1028 sslSocket->switchServerSSLContext(sniCtx_);
1029 serverNameMatch = true;
1030 return folly::SSLContext::SERVER_NAME_FOUND;
1032 serverNameMatch = false;
1033 return folly::SSLContext::SERVER_NAME_NOT_FOUND;
1037 AsyncSSLSocket::UniquePtr socket_;
1038 std::shared_ptr<folly::SSLContext> sniCtx_;
1039 std::string expectedServerName_;
1043 class SSLClient : public AsyncSocket::ConnectCallback,
1044 public AsyncTransportWrapper::WriteCallback,
1045 public AsyncTransportWrapper::ReadCallback
1048 EventBase *eventBase_;
1049 std::shared_ptr<AsyncSSLSocket> sslSocket_;
1050 SSL_SESSION *session_;
1051 std::shared_ptr<folly::SSLContext> ctx_;
1053 folly::SocketAddress address_;
1057 uint32_t bytesRead_;
1061 uint32_t writeAfterConnectErrors_;
1063 // These settings test that we eventually drain the
1064 // socket, even if the maxReadsPerEvent_ is hit during
1065 // a event loop iteration.
1066 static constexpr size_t kMaxReadsPerEvent = 2;
1067 static constexpr size_t kMaxReadBufferSz =
1068 sizeof(readbuf_) / kMaxReadsPerEvent / 2; // 2 event loop iterations
1071 SSLClient(EventBase *eventBase,
1072 const folly::SocketAddress& address,
1074 uint32_t timeout = 0)
1075 : eventBase_(eventBase),
1077 requests_(requests),
1084 writeAfterConnectErrors_(0) {
1085 ctx_.reset(new folly::SSLContext());
1086 ctx_->setOptions(SSL_OP_NO_TICKET);
1087 ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1088 memset(buf_, 'a', sizeof(buf_));
1093 SSL_SESSION_free(session_);
1096 EXPECT_EQ(bytesRead_, sizeof(buf_));
1100 uint32_t getHit() const { return hit_; }
1102 uint32_t getMiss() const { return miss_; }
1104 uint32_t getErrors() const { return errors_; }
1106 uint32_t getWriteAfterConnectErrors() const {
1107 return writeAfterConnectErrors_;
1110 void connect(bool writeNow = false) {
1111 sslSocket_ = AsyncSSLSocket::newSocket(
1113 if (session_ != nullptr) {
1114 sslSocket_->setSSLSession(session_);
1117 sslSocket_->connect(this, address_, timeout_);
1118 if (sslSocket_ && writeNow) {
1119 // write some junk, used in an error test
1120 sslSocket_->write(this, buf_, sizeof(buf_));
1124 void connectSuccess() noexcept override {
1125 std::cerr << "client SSL socket connected" << std::endl;
1126 if (sslSocket_->getSSLSessionReused()) {
1130 if (session_ != nullptr) {
1131 SSL_SESSION_free(session_);
1133 session_ = sslSocket_->getSSLSession();
1137 sslSocket_->setMaxReadsPerEvent(kMaxReadsPerEvent);
1138 sslSocket_->write(this, buf_, sizeof(buf_));
1139 sslSocket_->setReadCB(this);
1140 memset(readbuf_, 'b', sizeof(readbuf_));
1145 const AsyncSocketException& ex) noexcept override {
1146 std::cerr << "SSLClient::connectError: " << ex.what() << std::endl;
1151 void writeSuccess() noexcept override {
1152 std::cerr << "client write success" << std::endl;
1155 void writeErr(size_t /* bytesWritten */,
1156 const AsyncSocketException& ex) noexcept override {
1157 std::cerr << "client writeError: " << ex.what() << std::endl;
1159 writeAfterConnectErrors_++;
1163 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
1164 *bufReturn = readbuf_ + bytesRead_;
1165 *lenReturn = std::min(kMaxReadBufferSz, sizeof(readbuf_) - bytesRead_);
1168 void readEOF() noexcept override {
1169 std::cerr << "client readEOF" << std::endl;
1173 const AsyncSocketException& ex) noexcept override {
1174 std::cerr << "client readError: " << ex.what() << std::endl;
1177 void readDataAvailable(size_t len) noexcept override {
1178 std::cerr << "client read data: " << len << std::endl;
1180 if (bytesRead_ == sizeof(buf_)) {
1181 EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
1182 sslSocket_->closeNow();
1184 if (requests_ != 0) {
1192 class SSLHandshakeBase :
1193 public AsyncSSLSocket::HandshakeCB,
1194 private AsyncTransportWrapper::WriteCallback {
1196 explicit SSLHandshakeBase(
1197 AsyncSSLSocket::UniquePtr socket,
1198 bool preverifyResult,
1199 bool verifyResult) :
1200 handshakeVerify_(false),
1201 handshakeSuccess_(false),
1202 handshakeError_(false),
1203 socket_(std::move(socket)),
1204 preverifyResult_(preverifyResult),
1205 verifyResult_(verifyResult) {
1208 AsyncSSLSocket::UniquePtr moveSocket() && {
1209 return std::move(socket_);
1212 bool handshakeVerify_;
1213 bool handshakeSuccess_;
1214 bool handshakeError_;
1215 std::chrono::nanoseconds handshakeTime;
1218 AsyncSSLSocket::UniquePtr socket_;
1219 bool preverifyResult_;
1222 // HandshakeCallback
1223 bool handshakeVer(AsyncSSLSocket* /* sock */,
1225 X509_STORE_CTX* /* ctx */) noexcept override {
1226 handshakeVerify_ = true;
1228 EXPECT_EQ(preverifyResult_, preverifyOk);
1229 return verifyResult_;
1232 void handshakeSuc(AsyncSSLSocket*) noexcept override {
1233 LOG(INFO) << "Handshake success";
1234 handshakeSuccess_ = true;
1235 handshakeTime = socket_->getHandshakeTime();
1240 const AsyncSocketException& ex) noexcept override {
1241 LOG(INFO) << "Handshake error " << ex.what();
1242 handshakeError_ = true;
1243 handshakeTime = socket_->getHandshakeTime();
1247 void writeSuccess() noexcept override {
1252 size_t bytesWritten,
1253 const AsyncSocketException& ex) noexcept override {
1254 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
1259 class SSLHandshakeClient : public SSLHandshakeBase {
1262 AsyncSSLSocket::UniquePtr socket,
1263 bool preverifyResult,
1264 bool verifyResult) :
1265 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1266 socket_->sslConn(this, 0);
1270 class SSLHandshakeClientNoVerify : public SSLHandshakeBase {
1272 SSLHandshakeClientNoVerify(
1273 AsyncSSLSocket::UniquePtr socket,
1274 bool preverifyResult,
1275 bool verifyResult) :
1276 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1277 socket_->sslConn(this, 0,
1278 folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1282 class SSLHandshakeClientDoVerify : public SSLHandshakeBase {
1284 SSLHandshakeClientDoVerify(
1285 AsyncSSLSocket::UniquePtr socket,
1286 bool preverifyResult,
1287 bool verifyResult) :
1288 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1289 socket_->sslConn(this, 0,
1290 folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
1294 class SSLHandshakeServer : public SSLHandshakeBase {
1297 AsyncSSLSocket::UniquePtr socket,
1298 bool preverifyResult,
1300 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1301 socket_->sslAccept(this, 0);
1305 class SSLHandshakeServerParseClientHello : public SSLHandshakeBase {
1307 SSLHandshakeServerParseClientHello(
1308 AsyncSSLSocket::UniquePtr socket,
1309 bool preverifyResult,
1311 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1312 socket_->enableClientHelloParsing();
1313 socket_->sslAccept(this, 0);
1316 std::string clientCiphers_, sharedCiphers_, serverCiphers_, chosenCipher_;
1319 void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
1320 handshakeSuccess_ = true;
1321 sock->getSSLSharedCiphers(sharedCiphers_);
1322 sock->getSSLServerCiphers(serverCiphers_);
1323 sock->getSSLClientCiphers(clientCiphers_);
1324 chosenCipher_ = sock->getNegotiatedCipherName();
1329 class SSLHandshakeServerNoVerify : public SSLHandshakeBase {
1331 SSLHandshakeServerNoVerify(
1332 AsyncSSLSocket::UniquePtr socket,
1333 bool preverifyResult,
1335 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1336 socket_->sslAccept(this, 0,
1337 folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1341 class SSLHandshakeServerDoVerify : public SSLHandshakeBase {
1343 SSLHandshakeServerDoVerify(
1344 AsyncSSLSocket::UniquePtr socket,
1345 bool preverifyResult,
1347 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1348 socket_->sslAccept(this, 0,
1349 folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1353 class EventBaseAborter : public AsyncTimeout {
1355 EventBaseAborter(EventBase* eventBase,
1358 eventBase, AsyncTimeout::InternalEnum::INTERNAL)
1359 , eventBase_(eventBase) {
1360 scheduleTimeout(timeoutMS);
1363 void timeoutExpired() noexcept override {
1364 FAIL() << "test timed out";
1365 eventBase_->terminateLoopSoon();
1369 EventBase* eventBase_;