2 * Copyright 2016 Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
21 #include <folly/ExceptionWrapper.h>
22 #include <folly/SocketAddress.h>
23 #include <folly/io/async/AsyncSSLSocket.h>
24 #include <folly/io/async/AsyncServerSocket.h>
25 #include <folly/io/async/AsyncSocket.h>
26 #include <folly/io/async/AsyncTimeout.h>
27 #include <folly/io/async/AsyncTransport.h>
28 #include <folly/io/async/EventBase.h>
29 #include <folly/io/async/ssl/SSLErrors.h>
31 #include <gtest/gtest.h>
37 #include <sys/types.h>
38 #include <sys/socket.h>
39 #include <netinet/tcp.h>
49 // The destructors of all callback classes assert that the state is
50 // STATE_SUCCEEDED, for both positive and negative tests. The tests
51 // are responsible for setting the succeeded state properly before the
52 // destructors are called.
54 class WriteCallbackBase :
55 public AsyncTransportWrapper::WriteCallback {
58 : state(STATE_WAITING)
60 , exception(AsyncSocketException::UNKNOWN, "none") {}
62 ~WriteCallbackBase() {
63 EXPECT_EQ(STATE_SUCCEEDED, state);
67 const std::shared_ptr<AsyncSSLSocket> &socket) {
71 void writeSuccess() noexcept override {
72 std::cerr << "writeSuccess" << std::endl;
73 state = STATE_SUCCEEDED;
78 const AsyncSocketException& ex) noexcept override {
79 std::cerr << "writeError: bytesWritten " << bytesWritten
80 << ", exception " << ex.what() << std::endl;
83 this->bytesWritten = bytesWritten;
86 socket_->detachEventBase();
89 std::shared_ptr<AsyncSSLSocket> socket_;
92 AsyncSocketException exception;
95 class ReadCallbackBase :
96 public AsyncTransportWrapper::ReadCallback {
98 explicit ReadCallbackBase(WriteCallbackBase* wcb)
99 : wcb_(wcb), state(STATE_WAITING) {}
101 ~ReadCallbackBase() {
102 EXPECT_EQ(state, STATE_SUCCEEDED);
106 const std::shared_ptr<AsyncSSLSocket> &socket) {
110 void setState(StateEnum s) {
118 const AsyncSocketException& ex) noexcept override {
119 std::cerr << "readError " << ex.what() << std::endl;
120 state = STATE_FAILED;
122 socket_->detachEventBase();
125 void readEOF() noexcept override {
126 std::cerr << "readEOF" << std::endl;
129 socket_->detachEventBase();
132 std::shared_ptr<AsyncSSLSocket> socket_;
133 WriteCallbackBase *wcb_;
137 class ReadCallback : public ReadCallbackBase {
139 explicit ReadCallback(WriteCallbackBase *wcb)
140 : ReadCallbackBase(wcb)
144 for (std::vector<Buffer>::iterator it = buffers.begin();
149 currentBuffer.free();
152 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
153 if (!currentBuffer.buffer) {
154 currentBuffer.allocate(4096);
156 *bufReturn = currentBuffer.buffer;
157 *lenReturn = currentBuffer.length;
160 void readDataAvailable(size_t len) noexcept override {
161 std::cerr << "readDataAvailable, len " << len << std::endl;
163 currentBuffer.length = len;
165 wcb_->setSocket(socket_);
167 // Write back the same data.
168 socket_->write(wcb_, currentBuffer.buffer, len);
170 buffers.push_back(currentBuffer);
171 currentBuffer.reset();
172 state = STATE_SUCCEEDED;
177 Buffer() : buffer(nullptr), length(0) {}
178 Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
184 void allocate(size_t length) {
185 assert(buffer == nullptr);
186 this->buffer = static_cast<char*>(malloc(length));
187 this->length = length;
198 std::vector<Buffer> buffers;
199 Buffer currentBuffer;
202 class ReadErrorCallback : public ReadCallbackBase {
204 explicit ReadErrorCallback(WriteCallbackBase *wcb)
205 : ReadCallbackBase(wcb) {}
207 // Return nullptr buffer to trigger readError()
208 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
209 *bufReturn = nullptr;
213 void readDataAvailable(size_t /* len */) noexcept override {
214 // This should never be called.
219 const AsyncSocketException& ex) noexcept override {
220 ReadCallbackBase::readErr(ex);
221 std::cerr << "ReadErrorCallback::readError" << std::endl;
222 setState(STATE_SUCCEEDED);
226 class ReadEOFCallback : public ReadCallbackBase {
228 explicit ReadEOFCallback(WriteCallbackBase* wcb) : ReadCallbackBase(wcb) {}
230 // Return nullptr buffer to trigger readError()
231 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
232 *bufReturn = nullptr;
236 void readDataAvailable(size_t /* len */) noexcept override {
237 // This should never be called.
241 void readEOF() noexcept override {
242 ReadCallbackBase::readEOF();
243 setState(STATE_SUCCEEDED);
247 class WriteErrorCallback : public ReadCallback {
249 explicit WriteErrorCallback(WriteCallbackBase *wcb)
250 : ReadCallback(wcb) {}
252 void readDataAvailable(size_t len) noexcept override {
253 std::cerr << "readDataAvailable, len " << len << std::endl;
255 currentBuffer.length = len;
257 // close the socket before writing to trigger writeError().
258 ::close(socket_->getFd());
260 wcb_->setSocket(socket_);
262 // Write back the same data.
263 socket_->write(wcb_, currentBuffer.buffer, len);
265 if (wcb_->state == STATE_FAILED) {
266 setState(STATE_SUCCEEDED);
268 state = STATE_FAILED;
271 buffers.push_back(currentBuffer);
272 currentBuffer.reset();
275 void readErr(const AsyncSocketException& ex) noexcept override {
276 std::cerr << "readError " << ex.what() << std::endl;
277 // do nothing since this is expected
281 class EmptyReadCallback : public ReadCallback {
283 explicit EmptyReadCallback()
284 : ReadCallback(nullptr) {}
286 void readErr(const AsyncSocketException& ex) noexcept override {
287 std::cerr << "readError " << ex.what() << std::endl;
288 state = STATE_FAILED;
290 tcpSocket_->detachEventBase();
293 void readEOF() noexcept override {
294 std::cerr << "readEOF" << std::endl;
297 tcpSocket_->detachEventBase();
298 state = STATE_SUCCEEDED;
301 std::shared_ptr<AsyncSocket> tcpSocket_;
304 class HandshakeCallback :
305 public AsyncSSLSocket::HandshakeCB {
312 explicit HandshakeCallback(ReadCallbackBase *rcb,
313 ExpectType expect = EXPECT_SUCCESS):
314 state(STATE_WAITING),
319 const std::shared_ptr<AsyncSSLSocket> &socket) {
323 void setState(StateEnum s) {
328 // Functions inherited from AsyncSSLSocket::HandshakeCB
329 void handshakeSuc(AsyncSSLSocket *sock) noexcept override {
330 std::lock_guard<std::mutex> g(mutex_);
332 EXPECT_EQ(sock, socket_.get());
333 std::cerr << "HandshakeCallback::connectionAccepted" << std::endl;
334 rcb_->setSocket(socket_);
335 sock->setReadCB(rcb_);
336 state = (expect_ == EXPECT_SUCCESS) ? STATE_SUCCEEDED : STATE_FAILED;
338 void handshakeErr(AsyncSSLSocket* /* sock */,
339 const AsyncSocketException& ex) noexcept override {
340 std::lock_guard<std::mutex> g(mutex_);
342 std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl;
343 state = (expect_ == EXPECT_ERROR) ? STATE_SUCCEEDED : STATE_FAILED;
344 if (expect_ == EXPECT_ERROR) {
345 // rcb will never be invoked
346 rcb_->setState(STATE_SUCCEEDED);
348 errorString_ = ex.what();
351 void waitForHandshake() {
352 std::unique_lock<std::mutex> lock(mutex_);
353 cv_.wait(lock, [this] { return state != STATE_WAITING; });
356 ~HandshakeCallback() {
357 EXPECT_EQ(state, STATE_SUCCEEDED);
362 state = STATE_SUCCEEDED;
365 std::shared_ptr<AsyncSSLSocket> getSocket() {
370 std::shared_ptr<AsyncSSLSocket> socket_;
371 ReadCallbackBase *rcb_;
374 std::condition_variable cv_;
375 std::string errorString_;
378 class SSLServerAcceptCallbackBase:
379 public folly::AsyncServerSocket::AcceptCallback {
381 explicit SSLServerAcceptCallbackBase(HandshakeCallback *hcb):
382 state(STATE_WAITING), hcb_(hcb) {}
384 ~SSLServerAcceptCallbackBase() {
385 EXPECT_EQ(state, STATE_SUCCEEDED);
388 void acceptError(const std::exception& ex) noexcept override {
389 std::cerr << "SSLServerAcceptCallbackBase::acceptError "
390 << ex.what() << std::endl;
391 state = STATE_FAILED;
394 void connectionAccepted(
395 int fd, const folly::SocketAddress& /* clientAddr */) noexcept override {
396 printf("Connection accepted\n");
397 std::shared_ptr<AsyncSSLSocket> sslSock;
399 // Create an AsyncSSLSocket object with the fd. The socket should be
400 // added to the event base and be in the state of accepting an SSL connection.
401 sslSock = AsyncSSLSocket::newSocket(ctx_, base_, fd);
402 } catch (const std::exception &e) {
403 LOG(ERROR) << "Exception %s caught while creating a AsyncSSLSocket "
404 "object with socket " << e.what() << fd;
410 connAccepted(sslSock);
413 virtual void connAccepted(
414 const std::shared_ptr<folly::AsyncSSLSocket> &s) = 0;
417 HandshakeCallback *hcb_;
418 std::shared_ptr<folly::SSLContext> ctx_;
419 folly::EventBase* base_;
422 class SSLServerAcceptCallback: public SSLServerAcceptCallbackBase {
426 explicit SSLServerAcceptCallback(HandshakeCallback *hcb,
427 uint32_t timeout = 0):
428 SSLServerAcceptCallbackBase(hcb),
431 virtual ~SSLServerAcceptCallback() {
433 // if we set a timeout, we expect failure
434 EXPECT_EQ(hcb_->state, STATE_FAILED);
435 hcb_->setState(STATE_SUCCEEDED);
439 // Overrides connAccepted() declared in SSLServerAcceptCallbackBase
441 const std::shared_ptr<folly::AsyncSSLSocket> &s)
443 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
444 std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
446 hcb_->setSocket(sock);
447 sock->sslAccept(hcb_, timeout_);
448 EXPECT_EQ(sock->getSSLState(),
449 AsyncSSLSocket::STATE_ACCEPTING);
451 state = STATE_SUCCEEDED;
455 class SSLServerAcceptCallbackDelay: public SSLServerAcceptCallback {
457 explicit SSLServerAcceptCallbackDelay(HandshakeCallback *hcb):
458 SSLServerAcceptCallback(hcb) {}
460 // Overrides connAccepted() declared in SSLServerAcceptCallbackBase
462 const std::shared_ptr<folly::AsyncSSLSocket> &s)
465 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
467 std::cerr << "SSLServerAcceptCallbackDelay::connAccepted"
469 int fd = sock->getFd();
473 // The accepted connection should already have TCP_NODELAY set
475 socklen_t valueLength = sizeof(value);
476 int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
482 // Unset the TCP_NODELAY option.
484 socklen_t valueLength = sizeof(value);
485 int rc = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, valueLength);
488 rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
492 SSLServerAcceptCallback::connAccepted(sock);
496 class SSLServerAsyncCacheAcceptCallback: public SSLServerAcceptCallback {
498 explicit SSLServerAsyncCacheAcceptCallback(HandshakeCallback *hcb,
499 uint32_t timeout = 0):
500 SSLServerAcceptCallback(hcb, timeout) {}
502 // Overrides connAccepted() declared in SSLServerAcceptCallbackBase
504 const std::shared_ptr<folly::AsyncSSLSocket> &s)
506 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
508 std::cerr << "SSLServerAsyncCacheAcceptCallback::connAccepted" << std::endl; // tag with this class, not the base SSLServerAcceptCallback
510 hcb_->setSocket(sock);
511 sock->sslAccept(hcb_, timeout_);
512 ASSERT_TRUE((sock->getSSLState() ==
513 AsyncSSLSocket::STATE_ACCEPTING) ||
514 (sock->getSSLState() ==
515 AsyncSSLSocket::STATE_CACHE_LOOKUP));
517 state = STATE_SUCCEEDED;
522 class HandshakeErrorCallback: public SSLServerAcceptCallbackBase {
524 explicit HandshakeErrorCallback(HandshakeCallback *hcb):
525 SSLServerAcceptCallbackBase(hcb) {}
527 // Overrides connAccepted() declared in SSLServerAcceptCallbackBase
529 const std::shared_ptr<folly::AsyncSSLSocket> &s)
531 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
533 std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
535 // The first call to sslAccept() should succeed.
536 hcb_->setSocket(sock);
537 sock->sslAccept(hcb_);
538 EXPECT_EQ(sock->getSSLState(),
539 AsyncSSLSocket::STATE_ACCEPTING);
541 // The second call to sslAccept() should fail.
542 HandshakeCallback callback2(hcb_->rcb_);
543 callback2.setSocket(sock);
544 sock->sslAccept(&callback2);
545 EXPECT_EQ(sock->getSSLState(),
546 AsyncSSLSocket::STATE_ERROR);
548 // Both callbacks should be in the error state.
549 EXPECT_EQ(hcb_->state, STATE_FAILED);
550 EXPECT_EQ(callback2.state, STATE_FAILED);
552 sock->detachEventBase();
554 state = STATE_SUCCEEDED;
555 hcb_->setState(STATE_SUCCEEDED);
556 callback2.setState(STATE_SUCCEEDED);
560 class HandshakeTimeoutCallback: public SSLServerAcceptCallbackBase {
562 explicit HandshakeTimeoutCallback(HandshakeCallback *hcb):
563 SSLServerAcceptCallbackBase(hcb) {}
565 // Overrides connAccepted() declared in SSLServerAcceptCallbackBase
567 const std::shared_ptr<folly::AsyncSSLSocket> &s)
569 std::cerr << "HandshakeTimeoutCallback::connAccepted" << std::endl; // tag with this class; previously printed HandshakeErrorCallback
571 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
573 hcb_->setSocket(sock);
574 sock->getEventBase()->tryRunAfterDelay([=] {
575 std::cerr << "Delayed SSL accept, client will have close by now"
577 // SSL accept will fail
580 AsyncSSLSocket::STATE_UNINIT);
581 hcb_->socket_->sslAccept(hcb_);
582 // This registers for an event
585 AsyncSSLSocket::STATE_ACCEPTING);
587 state = STATE_SUCCEEDED;
593 class TestSSLServer {
596 std::shared_ptr<folly::SSLContext> ctx_;
597 SSLServerAcceptCallbackBase *acb_;
598 std::shared_ptr<folly::AsyncServerSocket> socket_;
599 folly::SocketAddress address_;
602 static void *Main(void *ctx) {
603 TestSSLServer *self = static_cast<TestSSLServer*>(ctx);
605 std::cerr << "Server thread exited event loop" << std::endl;
610 // Create a TestSSLServer.
611 // This immediately starts listening on the given port.
612 explicit TestSSLServer(SSLServerAcceptCallbackBase *acb);
616 evb_.runInEventBaseThread([&](){
617 socket_->stopAccepting();
619 std::cerr << "Waiting for server thread to exit" << std::endl;
620 pthread_join(thread_, nullptr);
623 EventBase &getEventBase() { return evb_; }
625 const folly::SocketAddress& getAddress() const {
630 class TestSSLAsyncCacheServer : public TestSSLServer {
632 explicit TestSSLAsyncCacheServer(SSLServerAcceptCallbackBase *acb,
633 int lookupDelay = 100) :
635 SSL_CTX *sslCtx = ctx_->getSSLCtx();
636 SSL_CTX_sess_set_get_cb(sslCtx,
637 TestSSLAsyncCacheServer::getSessionCallback);
638 SSL_CTX_set_session_cache_mode(
639 sslCtx, SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_SERVER);
642 lookupDelay_ = lookupDelay;
645 uint32_t getAsyncCallbacks() const { return asyncCallbacks_; } // current value of the static (class-wide) asyncCallbacks_ counter
646 uint32_t getAsyncLookups() const { return asyncLookups_; } // current value of the static (class-wide) asyncLookups_ counter
649 static uint32_t asyncCallbacks_;
650 static uint32_t asyncLookups_;
651 static uint32_t lookupDelay_;
653 static SSL_SESSION* getSessionCallback(SSL* ssl,
654 unsigned char* /* sess_id */,
659 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
660 if (!SSL_want_sess_cache_lookup(ssl)) {
661 // libssl.so mismatch
662 std::cerr << "no async support" << std::endl;
666 AsyncSSLSocket *sslSocket =
667 AsyncSSLSocket::getFromSSL(ssl);
668 assert(sslSocket != nullptr);
669 // Simulate an async cache by delaying the miss by lookupDelay_ ms (100 by default)
670 if (asyncCallbacks_ % 2 == 0) {
671 // This socket is already blocked on lookup, return miss
672 std::cerr << "returning miss" << std::endl;
674 // fresh meat - block it
675 std::cerr << "async lookup" << std::endl;
676 sslSocket->getEventBase()->tryRunAfterDelay(
677 std::bind(&AsyncSSLSocket::restartSSLAccept,
678 sslSocket), lookupDelay_);
679 *copyflag = SSL_SESSION_CB_WOULD_BLOCK;
687 void getfds(int fds[2]);
690 std::shared_ptr<folly::SSLContext> clientCtx,
691 std::shared_ptr<folly::SSLContext> serverCtx);
694 EventBase* eventBase,
695 AsyncSSLSocket::UniquePtr* clientSock,
696 AsyncSSLSocket::UniquePtr* serverSock);
698 class BlockingWriteClient :
699 private AsyncSSLSocket::HandshakeCB,
700 private AsyncTransportWrapper::WriteCallback {
702 explicit BlockingWriteClient(
703 AsyncSSLSocket::UniquePtr socket)
704 : socket_(std::move(socket)),
708 buf_.reset(new uint8_t[bufLen_]);
709 for (uint32_t n = 0; n < bufLen_; ++n) { // fill all bufLen_ bytes; sizeof(buf_) is the size of the unique_ptr object, not of the array
714 iov_.reset(new struct iovec[iovCount_]);
715 for (uint32_t n = 0; n < iovCount_; ++n) {
716 iov_[n].iov_base = buf_.get() + n;
718 iov_[n].iov_len = n % bufLen_;
720 iov_[n].iov_len = bufLen_ - (n % bufLen_);
724 socket_->sslConn(this, 100);
727 struct iovec* getIovec() const {
730 uint32_t getIovecCount() const {
735 void handshakeSuc(AsyncSSLSocket*) noexcept override {
736 socket_->writev(this, iov_.get(), iovCount_);
740 const AsyncSocketException& ex) noexcept override {
741 ADD_FAILURE() << "client handshake error: " << ex.what();
743 void writeSuccess() noexcept override {
748 const AsyncSocketException& ex) noexcept override {
749 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
753 AsyncSSLSocket::UniquePtr socket_;
756 std::unique_ptr<uint8_t[]> buf_;
757 std::unique_ptr<struct iovec[]> iov_;
760 class BlockingWriteServer :
761 private AsyncSSLSocket::HandshakeCB,
762 private AsyncTransportWrapper::ReadCallback {
764 explicit BlockingWriteServer(
765 AsyncSSLSocket::UniquePtr socket)
766 : socket_(std::move(socket)),
767 bufSize_(2500 * 2000),
769 buf_.reset(new uint8_t[bufSize_]);
770 socket_->sslAccept(this, 100);
773 void checkBuffer(struct iovec* iov, uint32_t count) const {
775 for (uint32_t n = 0; n < count; ++n) {
776 size_t bytesLeft = bytesRead_ - idx;
777 int rc = memcmp(buf_.get() + idx, iov[n].iov_base,
778 std::min(iov[n].iov_len, bytesLeft));
780 FAIL() << "buffer mismatch at iovec " << n << "/" << count
784 if (iov[n].iov_len > bytesLeft) {
785 FAIL() << "server did not read enough data: "
786 << "ended at byte " << bytesLeft << "/" << iov[n].iov_len
787 << " in iovec " << n << "/" << count;
790 idx += iov[n].iov_len;
792 if (idx != bytesRead_) {
793 ADD_FAILURE() << "server read extra data: " << bytesRead_
794 << " bytes read; expected " << idx;
799 void handshakeSuc(AsyncSSLSocket*) noexcept override {
800 // Wait 10ms before reading, so the client's writes will initially block.
801 socket_->getEventBase()->tryRunAfterDelay(
802 [this] { socket_->setReadCB(this); }, 10);
806 const AsyncSocketException& ex) noexcept override {
807 ADD_FAILURE() << "server handshake error: " << ex.what();
809 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
810 *bufReturn = buf_.get() + bytesRead_;
811 *lenReturn = bufSize_ - bytesRead_;
813 void readDataAvailable(size_t len) noexcept override {
815 socket_->setReadCB(nullptr);
816 socket_->getEventBase()->tryRunAfterDelay(
817 [this] { socket_->setReadCB(this); }, 2);
819 void readEOF() noexcept override {
823 const AsyncSocketException& ex) noexcept override {
824 ADD_FAILURE() << "server read error: " << ex.what();
827 AsyncSSLSocket::UniquePtr socket_;
830 std::unique_ptr<uint8_t[]> buf_;
834 private AsyncSSLSocket::HandshakeCB,
835 private AsyncTransportWrapper::WriteCallback {
838 AsyncSSLSocket::UniquePtr socket)
839 : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
840 socket_->sslConn(this);
843 const unsigned char* nextProto;
844 unsigned nextProtoLength;
845 SSLContext::NextProtocolType protocolType;
848 void handshakeSuc(AsyncSSLSocket*) noexcept override {
849 socket_->getSelectedNextProtocol(
850 &nextProto, &nextProtoLength, &protocolType);
854 const AsyncSocketException& ex) noexcept override {
855 ADD_FAILURE() << "client handshake error: " << ex.what();
857 void writeSuccess() noexcept override {
862 const AsyncSocketException& ex) noexcept override {
863 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
867 AsyncSSLSocket::UniquePtr socket_;
871 private AsyncSSLSocket::HandshakeCB,
872 private AsyncTransportWrapper::ReadCallback {
874 explicit NpnServer(AsyncSSLSocket::UniquePtr socket)
875 : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
876 socket_->sslAccept(this);
879 const unsigned char* nextProto;
880 unsigned nextProtoLength;
881 SSLContext::NextProtocolType protocolType;
884 void handshakeSuc(AsyncSSLSocket*) noexcept override {
885 socket_->getSelectedNextProtocol(
886 &nextProto, &nextProtoLength, &protocolType);
890 const AsyncSocketException& ex) noexcept override {
891 ADD_FAILURE() << "server handshake error: " << ex.what();
893 void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
896 void readDataAvailable(size_t /* len */) noexcept override {}
897 void readEOF() noexcept override {
901 const AsyncSocketException& ex) noexcept override {
902 ADD_FAILURE() << "server read error: " << ex.what();
905 AsyncSSLSocket::UniquePtr socket_;
908 class RenegotiatingServer : public AsyncSSLSocket::HandshakeCB,
909 public AsyncTransportWrapper::ReadCallback {
911 explicit RenegotiatingServer(AsyncSSLSocket::UniquePtr socket)
912 : socket_(std::move(socket)) {
913 socket_->sslAccept(this);
916 ~RenegotiatingServer() {
917 socket_->setReadCB(nullptr);
920 void handshakeSuc(AsyncSSLSocket* /* socket */) noexcept override {
921 LOG(INFO) << "Renegotiating server handshake success";
922 socket_->setReadCB(this);
926 const AsyncSocketException& ex) noexcept override {
927 ADD_FAILURE() << "Renegotiating server handshake error: " << ex.what();
929 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
930 *lenReturn = sizeof(buf);
933 void readDataAvailable(size_t /* len */) noexcept override {}
934 void readEOF() noexcept override {}
935 void readErr(const AsyncSocketException& ex) noexcept override {
936 LOG(INFO) << "server got read error " << ex.what();
937 auto exPtr = dynamic_cast<const SSLException*>(&ex);
938 ASSERT_NE(nullptr, exPtr);
939 std::string exStr(ex.what());
940 SSLException sslEx(SSLError::CLIENT_RENEGOTIATION);
941 ASSERT_NE(std::string::npos, exStr.find(sslEx.what()));
942 renegotiationError_ = true;
945 AsyncSSLSocket::UniquePtr socket_;
946 unsigned char buf[128];
947 bool renegotiationError_{false};
950 #ifndef OPENSSL_NO_TLSEXT
952 private AsyncSSLSocket::HandshakeCB,
953 private AsyncTransportWrapper::WriteCallback {
956 AsyncSSLSocket::UniquePtr socket)
957 : serverNameMatch(false), socket_(std::move(socket)) {
958 socket_->sslConn(this);
961 bool serverNameMatch;
964 void handshakeSuc(AsyncSSLSocket*) noexcept override {
965 serverNameMatch = socket_->isServerNameMatch();
969 const AsyncSocketException& ex) noexcept override {
970 ADD_FAILURE() << "client handshake error: " << ex.what();
972 void writeSuccess() noexcept override {
977 const AsyncSocketException& ex) noexcept override {
978 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
982 AsyncSSLSocket::UniquePtr socket_;
986 private AsyncSSLSocket::HandshakeCB,
987 private AsyncTransportWrapper::ReadCallback {
990 AsyncSSLSocket::UniquePtr socket,
991 const std::shared_ptr<folly::SSLContext>& ctx,
992 const std::shared_ptr<folly::SSLContext>& sniCtx,
993 const std::string& expectedServerName)
994 : serverNameMatch(false), socket_(std::move(socket)), sniCtx_(sniCtx),
995 expectedServerName_(expectedServerName) {
996 ctx->setServerNameCallback(std::bind(&SNIServer::serverNameCallback, this,
997 std::placeholders::_1));
998 socket_->sslAccept(this);
1001 bool serverNameMatch;
1004 void handshakeSuc(AsyncSSLSocket* /* ssl */) noexcept override {}
1007 const AsyncSocketException& ex) noexcept override {
1008 ADD_FAILURE() << "server handshake error: " << ex.what();
1010 void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
1013 void readDataAvailable(size_t /* len */) noexcept override {}
1014 void readEOF() noexcept override {
1018 const AsyncSocketException& ex) noexcept override {
1019 ADD_FAILURE() << "server read error: " << ex.what();
1022 folly::SSLContext::ServerNameCallbackResult
1023 serverNameCallback(SSL *ssl) {
1024 const char *sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
1027 !strcasecmp(expectedServerName_.c_str(), sn)) {
1028 AsyncSSLSocket *sslSocket =
1029 AsyncSSLSocket::getFromSSL(ssl);
1030 sslSocket->switchServerSSLContext(sniCtx_);
1031 serverNameMatch = true;
1032 return folly::SSLContext::SERVER_NAME_FOUND;
1034 serverNameMatch = false;
1035 return folly::SSLContext::SERVER_NAME_NOT_FOUND;
1039 AsyncSSLSocket::UniquePtr socket_;
1040 std::shared_ptr<folly::SSLContext> sniCtx_;
1041 std::string expectedServerName_;
1045 class SSLClient : public AsyncSocket::ConnectCallback,
1046 public AsyncTransportWrapper::WriteCallback,
1047 public AsyncTransportWrapper::ReadCallback
1050 EventBase *eventBase_;
1051 std::shared_ptr<AsyncSSLSocket> sslSocket_;
1052 SSL_SESSION *session_;
1053 std::shared_ptr<folly::SSLContext> ctx_;
1055 folly::SocketAddress address_;
1059 uint32_t bytesRead_;
1063 uint32_t writeAfterConnectErrors_;
1065 // These settings test that we eventually drain the
1066 // socket, even if the maxReadsPerEvent_ is hit during
1067 // an event loop iteration.
1068 static constexpr size_t kMaxReadsPerEvent = 2;
1069 static constexpr size_t kMaxReadBufferSz =
1070 sizeof(readbuf_) / kMaxReadsPerEvent / 2; // 2 event loop iterations
1073 SSLClient(EventBase *eventBase,
1074 const folly::SocketAddress& address,
1076 uint32_t timeout = 0)
1077 : eventBase_(eventBase),
1079 requests_(requests),
1086 writeAfterConnectErrors_(0) {
1087 ctx_.reset(new folly::SSLContext());
1088 ctx_->setOptions(SSL_OP_NO_TICKET);
1089 ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1090 memset(buf_, 'a', sizeof(buf_));
1095 SSL_SESSION_free(session_);
1098 EXPECT_EQ(bytesRead_, sizeof(buf_));
1102 uint32_t getHit() const { return hit_; }
1104 uint32_t getMiss() const { return miss_; }
1106 uint32_t getErrors() const { return errors_; }
1108 uint32_t getWriteAfterConnectErrors() const {
1109 return writeAfterConnectErrors_;
1112 void connect(bool writeNow = false) {
1113 sslSocket_ = AsyncSSLSocket::newSocket(
1115 if (session_ != nullptr) {
1116 sslSocket_->setSSLSession(session_);
1119 sslSocket_->connect(this, address_, timeout_);
1120 if (sslSocket_ && writeNow) {
1121 // write some junk, used in an error test
1122 sslSocket_->write(this, buf_, sizeof(buf_));
1126 void connectSuccess() noexcept override {
1127 std::cerr << "client SSL socket connected" << std::endl;
1128 if (sslSocket_->getSSLSessionReused()) {
1132 if (session_ != nullptr) {
1133 SSL_SESSION_free(session_);
1135 session_ = sslSocket_->getSSLSession();
1139 sslSocket_->setMaxReadsPerEvent(kMaxReadsPerEvent);
1140 sslSocket_->write(this, buf_, sizeof(buf_));
1141 sslSocket_->setReadCB(this);
1142 memset(readbuf_, 'b', sizeof(readbuf_));
1147 const AsyncSocketException& ex) noexcept override {
1148 std::cerr << "SSLClient::connectError: " << ex.what() << std::endl;
1153 void writeSuccess() noexcept override {
1154 std::cerr << "client write success" << std::endl;
1157 void writeErr(size_t /* bytesWritten */,
1158 const AsyncSocketException& ex) noexcept override {
1159 std::cerr << "client writeError: " << ex.what() << std::endl;
1161 writeAfterConnectErrors_++;
1165 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
1166 *bufReturn = readbuf_ + bytesRead_;
1167 *lenReturn = std::min(kMaxReadBufferSz, sizeof(readbuf_) - bytesRead_);
1170 void readEOF() noexcept override {
1171 std::cerr << "client readEOF" << std::endl;
1175 const AsyncSocketException& ex) noexcept override {
1176 std::cerr << "client readError: " << ex.what() << std::endl;
1179 void readDataAvailable(size_t len) noexcept override {
1180 std::cerr << "client read data: " << len << std::endl;
1182 if (bytesRead_ == sizeof(buf_)) {
1183 EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
1184 sslSocket_->closeNow();
1186 if (requests_ != 0) {
1194 class SSLHandshakeBase :
1195 public AsyncSSLSocket::HandshakeCB,
1196 private AsyncTransportWrapper::WriteCallback {
1198 explicit SSLHandshakeBase(
1199 AsyncSSLSocket::UniquePtr socket,
1200 bool preverifyResult,
1201 bool verifyResult) :
1202 handshakeVerify_(false),
1203 handshakeSuccess_(false),
1204 handshakeError_(false),
1205 socket_(std::move(socket)),
1206 preverifyResult_(preverifyResult),
1207 verifyResult_(verifyResult) {
1210 AsyncSSLSocket::UniquePtr moveSocket() && {
1211 return std::move(socket_);
1214 bool handshakeVerify_;
1215 bool handshakeSuccess_;
1216 bool handshakeError_;
1217 std::chrono::nanoseconds handshakeTime;
1220 AsyncSSLSocket::UniquePtr socket_;
1221 bool preverifyResult_;
1224 // HandshakeCallback
1225 bool handshakeVer(AsyncSSLSocket* /* sock */,
1227 X509_STORE_CTX* /* ctx */) noexcept override {
1228 handshakeVerify_ = true;
1230 EXPECT_EQ(preverifyResult_, preverifyOk);
1231 return verifyResult_;
1234 void handshakeSuc(AsyncSSLSocket*) noexcept override {
1235 LOG(INFO) << "Handshake success";
1236 handshakeSuccess_ = true;
1237 handshakeTime = socket_->getHandshakeTime();
1242 const AsyncSocketException& ex) noexcept override {
1243 LOG(INFO) << "Handshake error " << ex.what();
1244 handshakeError_ = true;
1245 handshakeTime = socket_->getHandshakeTime();
1249 void writeSuccess() noexcept override {
1254 size_t bytesWritten,
1255 const AsyncSocketException& ex) noexcept override {
1256 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
1261 class SSLHandshakeClient : public SSLHandshakeBase {
1264 AsyncSSLSocket::UniquePtr socket,
1265 bool preverifyResult,
1266 bool verifyResult) :
1267 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1268 socket_->sslConn(this, 0);
1272 class SSLHandshakeClientNoVerify : public SSLHandshakeBase {
1274 SSLHandshakeClientNoVerify(
1275 AsyncSSLSocket::UniquePtr socket,
1276 bool preverifyResult,
1277 bool verifyResult) :
1278 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1279 socket_->sslConn(this, 0,
1280 folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1284 class SSLHandshakeClientDoVerify : public SSLHandshakeBase {
1286 SSLHandshakeClientDoVerify(
1287 AsyncSSLSocket::UniquePtr socket,
1288 bool preverifyResult,
1289 bool verifyResult) :
1290 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1291 socket_->sslConn(this, 0,
1292 folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
1296 class SSLHandshakeServer : public SSLHandshakeBase {
1299 AsyncSSLSocket::UniquePtr socket,
1300 bool preverifyResult,
1302 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1303 socket_->sslAccept(this, 0);
1307 class SSLHandshakeServerParseClientHello : public SSLHandshakeBase {
1309 SSLHandshakeServerParseClientHello(
1310 AsyncSSLSocket::UniquePtr socket,
1311 bool preverifyResult,
1313 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1314 socket_->enableClientHelloParsing();
1315 socket_->sslAccept(this, 0);
1318 std::string clientCiphers_, sharedCiphers_, serverCiphers_, chosenCipher_;
1321 void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
1322 handshakeSuccess_ = true;
1323 sock->getSSLSharedCiphers(sharedCiphers_);
1324 sock->getSSLServerCiphers(serverCiphers_);
1325 sock->getSSLClientCiphers(clientCiphers_);
1326 chosenCipher_ = sock->getNegotiatedCipherName();
1331 class SSLHandshakeServerNoVerify : public SSLHandshakeBase {
1333 SSLHandshakeServerNoVerify(
1334 AsyncSSLSocket::UniquePtr socket,
1335 bool preverifyResult,
1337 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1338 socket_->sslAccept(this, 0,
1339 folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1343 class SSLHandshakeServerDoVerify : public SSLHandshakeBase {
1345 SSLHandshakeServerDoVerify(
1346 AsyncSSLSocket::UniquePtr socket,
1347 bool preverifyResult,
1349 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1350 socket_->sslAccept(this, 0,
1351 folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1355 class EventBaseAborter : public AsyncTimeout {
1357 EventBaseAborter(EventBase* eventBase,
1360 eventBase, AsyncTimeout::InternalEnum::INTERNAL)
1361 , eventBase_(eventBase) {
1362 scheduleTimeout(timeoutMS);
1365 void timeoutExpired() noexcept override { // timeout fired: record a test failure, then stop the event loop
1366 ADD_FAILURE() << "test timed out"; // non-fatal on purpose: FAIL() expands to a `return`, which skipped the call below
1367 eventBase_->terminateLoopSoon();
1371 EventBase* eventBase_;