folly copyright 2015 -> copyright 2016
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.h
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17
18 #include <signal.h>
19 #include <pthread.h>
20
21 #include <folly/io/async/AsyncServerSocket.h>
22 #include <folly/io/async/AsyncSSLSocket.h>
23 #include <folly/io/async/AsyncSocket.h>
24 #include <folly/io/async/AsyncTransport.h>
25 #include <folly/io/async/EventBase.h>
26 #include <folly/io/async/AsyncTimeout.h>
27 #include <folly/SocketAddress.h>
28
29 #include <gtest/gtest.h>
30 #include <iostream>
31 #include <list>
32 #include <unistd.h>
33 #include <fcntl.h>
34 #include <poll.h>
35 #include <sys/types.h>
36 #include <sys/socket.h>
37 #include <netinet/tcp.h>
38
39 namespace folly {
40
41 enum StateEnum {
42   STATE_WAITING,
43   STATE_SUCCEEDED,
44   STATE_FAILED
45 };
46
47 // The destructors of all callback classes assert that the state is
48 // STATE_SUCCEEDED, for both possitive and negative tests. The tests
49 // are responsible for setting the succeeded state properly before the
50 // destructors are called.
51
52 class WriteCallbackBase :
53 public AsyncTransportWrapper::WriteCallback {
54 public:
55   WriteCallbackBase()
56       : state(STATE_WAITING)
57       , bytesWritten(0)
58       , exception(AsyncSocketException::UNKNOWN, "none") {}
59
60   ~WriteCallbackBase() {
61     EXPECT_EQ(state, STATE_SUCCEEDED);
62   }
63
64   void setSocket(
65     const std::shared_ptr<AsyncSSLSocket> &socket) {
66     socket_ = socket;
67   }
68
69   void writeSuccess() noexcept override {
70     std::cerr << "writeSuccess" << std::endl;
71     state = STATE_SUCCEEDED;
72   }
73
74   void writeErr(
75     size_t bytesWritten,
76     const AsyncSocketException& ex) noexcept override {
77     std::cerr << "writeError: bytesWritten " << bytesWritten
78          << ", exception " << ex.what() << std::endl;
79
80     state = STATE_FAILED;
81     this->bytesWritten = bytesWritten;
82     exception = ex;
83     socket_->close();
84     socket_->detachEventBase();
85   }
86
87   std::shared_ptr<AsyncSSLSocket> socket_;
88   StateEnum state;
89   size_t bytesWritten;
90   AsyncSocketException exception;
91 };
92
93 class ReadCallbackBase :
94 public AsyncTransportWrapper::ReadCallback {
95 public:
96   explicit ReadCallbackBase(WriteCallbackBase *wcb)
97       : wcb_(wcb)
98       , state(STATE_WAITING) {}
99
100   ~ReadCallbackBase() {
101     EXPECT_EQ(state, STATE_SUCCEEDED);
102   }
103
104   void setSocket(
105     const std::shared_ptr<AsyncSSLSocket> &socket) {
106     socket_ = socket;
107   }
108
109   void setState(StateEnum s) {
110     state = s;
111     if (wcb_) {
112       wcb_->state = s;
113     }
114   }
115
116   void readErr(
117     const AsyncSocketException& ex) noexcept override {
118     std::cerr << "readError " << ex.what() << std::endl;
119     state = STATE_FAILED;
120     socket_->close();
121     socket_->detachEventBase();
122   }
123
124   void readEOF() noexcept override {
125     std::cerr << "readEOF" << std::endl;
126
127     socket_->close();
128     socket_->detachEventBase();
129   }
130
131   std::shared_ptr<AsyncSSLSocket> socket_;
132   WriteCallbackBase *wcb_;
133   StateEnum state;
134 };
135
136 class ReadCallback : public ReadCallbackBase {
137 public:
138   explicit ReadCallback(WriteCallbackBase *wcb)
139       : ReadCallbackBase(wcb)
140       , buffers() {}
141
142   ~ReadCallback() {
143     for (std::vector<Buffer>::iterator it = buffers.begin();
144          it != buffers.end();
145          ++it) {
146       it->free();
147     }
148     currentBuffer.free();
149   }
150
151   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
152     if (!currentBuffer.buffer) {
153       currentBuffer.allocate(4096);
154     }
155     *bufReturn = currentBuffer.buffer;
156     *lenReturn = currentBuffer.length;
157   }
158
159   void readDataAvailable(size_t len) noexcept override {
160     std::cerr << "readDataAvailable, len " << len << std::endl;
161
162     currentBuffer.length = len;
163
164     wcb_->setSocket(socket_);
165
166     // Write back the same data.
167     socket_->write(wcb_, currentBuffer.buffer, len);
168
169     buffers.push_back(currentBuffer);
170     currentBuffer.reset();
171     state = STATE_SUCCEEDED;
172   }
173
174   class Buffer {
175   public:
176     Buffer() : buffer(nullptr), length(0) {}
177     Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
178
179     void reset() {
180       buffer = nullptr;
181       length = 0;
182     }
183     void allocate(size_t length) {
184       assert(buffer == nullptr);
185       this->buffer = static_cast<char*>(malloc(length));
186       this->length = length;
187     }
188     void free() {
189       ::free(buffer);
190       reset();
191     }
192
193     char* buffer;
194     size_t length;
195   };
196
197   std::vector<Buffer> buffers;
198   Buffer currentBuffer;
199 };
200
201 class ReadErrorCallback : public ReadCallbackBase {
202 public:
203   explicit ReadErrorCallback(WriteCallbackBase *wcb)
204       : ReadCallbackBase(wcb) {}
205
206   // Return nullptr buffer to trigger readError()
207   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
208     *bufReturn = nullptr;
209     *lenReturn = 0;
210   }
211
212   void readDataAvailable(size_t /* len */) noexcept override {
213     // This should never to called.
214     FAIL();
215   }
216
217   void readErr(
218     const AsyncSocketException& ex) noexcept override {
219     ReadCallbackBase::readErr(ex);
220     std::cerr << "ReadErrorCallback::readError" << std::endl;
221     setState(STATE_SUCCEEDED);
222   }
223 };
224
225 class WriteErrorCallback : public ReadCallback {
226 public:
227   explicit WriteErrorCallback(WriteCallbackBase *wcb)
228       : ReadCallback(wcb) {}
229
230   void readDataAvailable(size_t len) noexcept override {
231     std::cerr << "readDataAvailable, len " << len << std::endl;
232
233     currentBuffer.length = len;
234
235     // close the socket before writing to trigger writeError().
236     ::close(socket_->getFd());
237
238     wcb_->setSocket(socket_);
239
240     // Write back the same data.
241     socket_->write(wcb_, currentBuffer.buffer, len);
242
243     if (wcb_->state == STATE_FAILED) {
244       setState(STATE_SUCCEEDED);
245     } else {
246       state = STATE_FAILED;
247     }
248
249     buffers.push_back(currentBuffer);
250     currentBuffer.reset();
251   }
252
253   void readErr(const AsyncSocketException& ex) noexcept override {
254     std::cerr << "readError " << ex.what() << std::endl;
255     // do nothing since this is expected
256   }
257 };
258
259 class EmptyReadCallback : public ReadCallback {
260 public:
261   explicit EmptyReadCallback()
262       : ReadCallback(nullptr) {}
263
264   void readErr(const AsyncSocketException& ex) noexcept override {
265     std::cerr << "readError " << ex.what() << std::endl;
266     state = STATE_FAILED;
267     tcpSocket_->close();
268     tcpSocket_->detachEventBase();
269   }
270
271   void readEOF() noexcept override {
272     std::cerr << "readEOF" << std::endl;
273
274     tcpSocket_->close();
275     tcpSocket_->detachEventBase();
276     state = STATE_SUCCEEDED;
277   }
278
279   std::shared_ptr<AsyncSocket> tcpSocket_;
280 };
281
282 class HandshakeCallback :
283 public AsyncSSLSocket::HandshakeCB {
284 public:
285   enum ExpectType {
286     EXPECT_SUCCESS,
287     EXPECT_ERROR
288   };
289
290   explicit HandshakeCallback(ReadCallbackBase *rcb,
291                              ExpectType expect = EXPECT_SUCCESS):
292       state(STATE_WAITING),
293       rcb_(rcb),
294       expect_(expect) {}
295
296   void setSocket(
297     const std::shared_ptr<AsyncSSLSocket> &socket) {
298     socket_ = socket;
299   }
300
301   void setState(StateEnum s) {
302     state = s;
303     rcb_->setState(s);
304   }
305
306   // Functions inherited from AsyncSSLSocketHandshakeCallback
307   void handshakeSuc(AsyncSSLSocket *sock) noexcept override {
308     EXPECT_EQ(sock, socket_.get());
309     std::cerr << "HandshakeCallback::connectionAccepted" << std::endl;
310     rcb_->setSocket(socket_);
311     sock->setReadCB(rcb_);
312     state = (expect_ == EXPECT_SUCCESS) ? STATE_SUCCEEDED : STATE_FAILED;
313   }
314   void handshakeErr(AsyncSSLSocket* /* sock */,
315                     const AsyncSocketException& ex) noexcept override {
316     std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl;
317     state = (expect_ == EXPECT_ERROR) ? STATE_SUCCEEDED : STATE_FAILED;
318     if (expect_ == EXPECT_ERROR) {
319       // rcb will never be invoked
320       rcb_->setState(STATE_SUCCEEDED);
321     }
322   }
323
324   ~HandshakeCallback() {
325     EXPECT_EQ(state, STATE_SUCCEEDED);
326   }
327
328   void closeSocket() {
329     socket_->close();
330     state = STATE_SUCCEEDED;
331   }
332
333   StateEnum state;
334   std::shared_ptr<AsyncSSLSocket> socket_;
335   ReadCallbackBase *rcb_;
336   ExpectType expect_;
337 };
338
339 class SSLServerAcceptCallbackBase:
340 public folly::AsyncServerSocket::AcceptCallback {
341 public:
342   explicit SSLServerAcceptCallbackBase(HandshakeCallback *hcb):
343   state(STATE_WAITING), hcb_(hcb) {}
344
345   ~SSLServerAcceptCallbackBase() {
346     EXPECT_EQ(state, STATE_SUCCEEDED);
347   }
348
349   void acceptError(const std::exception& ex) noexcept override {
350     std::cerr << "SSLServerAcceptCallbackBase::acceptError "
351               << ex.what() << std::endl;
352     state = STATE_FAILED;
353   }
354
355   void connectionAccepted(
356       int fd, const folly::SocketAddress& /* clientAddr */) noexcept override {
357     printf("Connection accepted\n");
358     std::shared_ptr<AsyncSSLSocket> sslSock;
359     try {
360       // Create a AsyncSSLSocket object with the fd. The socket should be
361       // added to the event base and in the state of accepting SSL connection.
362       sslSock = AsyncSSLSocket::newSocket(ctx_, base_, fd);
363     } catch (const std::exception &e) {
364       LOG(ERROR) << "Exception %s caught while creating a AsyncSSLSocket "
365         "object with socket " << e.what() << fd;
366       ::close(fd);
367       acceptError(e);
368       return;
369     }
370
371     connAccepted(sslSock);
372   }
373
374   virtual void connAccepted(
375     const std::shared_ptr<folly::AsyncSSLSocket> &s) = 0;
376
377   StateEnum state;
378   HandshakeCallback *hcb_;
379   std::shared_ptr<folly::SSLContext> ctx_;
380   folly::EventBase* base_;
381 };
382
383 class SSLServerAcceptCallback: public SSLServerAcceptCallbackBase {
384 public:
385   uint32_t timeout_;
386
387   explicit SSLServerAcceptCallback(HandshakeCallback *hcb,
388                                    uint32_t timeout = 0):
389       SSLServerAcceptCallbackBase(hcb),
390       timeout_(timeout) {}
391
392   virtual ~SSLServerAcceptCallback() {
393     if (timeout_ > 0) {
394       // if we set a timeout, we expect failure
395       EXPECT_EQ(hcb_->state, STATE_FAILED);
396       hcb_->setState(STATE_SUCCEEDED);
397     }
398   }
399
400   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
401   void connAccepted(
402     const std::shared_ptr<folly::AsyncSSLSocket> &s)
403     noexcept override {
404     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
405     std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
406
407     hcb_->setSocket(sock);
408     sock->sslAccept(hcb_, timeout_);
409     EXPECT_EQ(sock->getSSLState(),
410                       AsyncSSLSocket::STATE_ACCEPTING);
411
412     state = STATE_SUCCEEDED;
413   }
414 };
415
416 class SSLServerAcceptCallbackDelay: public SSLServerAcceptCallback {
417 public:
418   explicit SSLServerAcceptCallbackDelay(HandshakeCallback *hcb):
419       SSLServerAcceptCallback(hcb) {}
420
421   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
422   void connAccepted(
423     const std::shared_ptr<folly::AsyncSSLSocket> &s)
424     noexcept override {
425
426     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
427
428     std::cerr << "SSLServerAcceptCallbackDelay::connAccepted"
429               << std::endl;
430     int fd = sock->getFd();
431
432 #ifndef TCP_NOPUSH
433     {
434     // The accepted connection should already have TCP_NODELAY set
435     int value;
436     socklen_t valueLength = sizeof(value);
437     int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
438     EXPECT_EQ(rc, 0);
439     EXPECT_EQ(value, 1);
440     }
441 #endif
442
443     // Unset the TCP_NODELAY option.
444     int value = 0;
445     socklen_t valueLength = sizeof(value);
446     int rc = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, valueLength);
447     EXPECT_EQ(rc, 0);
448
449     rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
450     EXPECT_EQ(rc, 0);
451     EXPECT_EQ(value, 0);
452
453     SSLServerAcceptCallback::connAccepted(sock);
454   }
455 };
456
457 class SSLServerAsyncCacheAcceptCallback: public SSLServerAcceptCallback {
458 public:
459   explicit SSLServerAsyncCacheAcceptCallback(HandshakeCallback *hcb,
460                                              uint32_t timeout = 0):
461     SSLServerAcceptCallback(hcb, timeout) {}
462
463   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
464   void connAccepted(
465     const std::shared_ptr<folly::AsyncSSLSocket> &s)
466     noexcept override {
467     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
468
469     std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
470
471     hcb_->setSocket(sock);
472     sock->sslAccept(hcb_, timeout_);
473     ASSERT_TRUE((sock->getSSLState() ==
474                  AsyncSSLSocket::STATE_ACCEPTING) ||
475                 (sock->getSSLState() ==
476                  AsyncSSLSocket::STATE_CACHE_LOOKUP));
477
478     state = STATE_SUCCEEDED;
479   }
480 };
481
482
483 class HandshakeErrorCallback: public SSLServerAcceptCallbackBase {
484 public:
485   explicit HandshakeErrorCallback(HandshakeCallback *hcb):
486   SSLServerAcceptCallbackBase(hcb)  {}
487
488   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
489   void connAccepted(
490     const std::shared_ptr<folly::AsyncSSLSocket> &s)
491     noexcept override {
492     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
493
494     std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
495
496     // The first call to sslAccept() should succeed.
497     hcb_->setSocket(sock);
498     sock->sslAccept(hcb_);
499     EXPECT_EQ(sock->getSSLState(),
500                       AsyncSSLSocket::STATE_ACCEPTING);
501
502     // The second call to sslAccept() should fail.
503     HandshakeCallback callback2(hcb_->rcb_);
504     callback2.setSocket(sock);
505     sock->sslAccept(&callback2);
506     EXPECT_EQ(sock->getSSLState(),
507                       AsyncSSLSocket::STATE_ERROR);
508
509     // Both callbacks should be in the error state.
510     EXPECT_EQ(hcb_->state, STATE_FAILED);
511     EXPECT_EQ(callback2.state, STATE_FAILED);
512
513     sock->detachEventBase();
514
515     state = STATE_SUCCEEDED;
516     hcb_->setState(STATE_SUCCEEDED);
517     callback2.setState(STATE_SUCCEEDED);
518   }
519 };
520
521 class HandshakeTimeoutCallback: public SSLServerAcceptCallbackBase {
522 public:
523   explicit HandshakeTimeoutCallback(HandshakeCallback *hcb):
524   SSLServerAcceptCallbackBase(hcb)  {}
525
526   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
527   void connAccepted(
528     const std::shared_ptr<folly::AsyncSSLSocket> &s)
529     noexcept override {
530     std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
531
532     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
533
534     hcb_->setSocket(sock);
535     sock->getEventBase()->tryRunAfterDelay([=] {
536         std::cerr << "Delayed SSL accept, client will have close by now"
537                   << std::endl;
538         // SSL accept will fail
539         EXPECT_EQ(
540           sock->getSSLState(),
541           AsyncSSLSocket::STATE_UNINIT);
542         hcb_->socket_->sslAccept(hcb_);
543         // This registers for an event
544         EXPECT_EQ(
545           sock->getSSLState(),
546           AsyncSSLSocket::STATE_ACCEPTING);
547
548         state = STATE_SUCCEEDED;
549       }, 100);
550   }
551 };
552
553
554 class TestSSLServer {
555  protected:
556   EventBase evb_;
557   std::shared_ptr<folly::SSLContext> ctx_;
558   SSLServerAcceptCallbackBase *acb_;
559   std::shared_ptr<folly::AsyncServerSocket> socket_;
560   folly::SocketAddress address_;
561   pthread_t thread_;
562
563   static void *Main(void *ctx) {
564     TestSSLServer *self = static_cast<TestSSLServer*>(ctx);
565     self->evb_.loop();
566     std::cerr << "Server thread exited event loop" << std::endl;
567     return nullptr;
568   }
569
570  public:
571   // Create a TestSSLServer.
572   // This immediately starts listening on the given port.
573   explicit TestSSLServer(SSLServerAcceptCallbackBase *acb);
574
575   // Kill the thread.
576   ~TestSSLServer() {
577     evb_.runInEventBaseThread([&](){
578       socket_->stopAccepting();
579     });
580     std::cerr << "Waiting for server thread to exit" << std::endl;
581     pthread_join(thread_, nullptr);
582   }
583
584   EventBase &getEventBase() { return evb_; }
585
586   const folly::SocketAddress& getAddress() const {
587     return address_;
588   }
589 };
590
591 class TestSSLAsyncCacheServer : public TestSSLServer {
592  public:
593   explicit TestSSLAsyncCacheServer(SSLServerAcceptCallbackBase *acb,
594         int lookupDelay = 100) :
595       TestSSLServer(acb) {
596     SSL_CTX *sslCtx = ctx_->getSSLCtx();
597     SSL_CTX_sess_set_get_cb(sslCtx,
598                             TestSSLAsyncCacheServer::getSessionCallback);
599     SSL_CTX_set_session_cache_mode(
600       sslCtx, SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_SERVER);
601     asyncCallbacks_ = 0;
602     asyncLookups_ = 0;
603     lookupDelay_ = lookupDelay;
604   }
605
606   uint32_t getAsyncCallbacks() const { return asyncCallbacks_; }
607   uint32_t getAsyncLookups() const { return asyncLookups_; }
608
609  private:
610   static uint32_t asyncCallbacks_;
611   static uint32_t asyncLookups_;
612   static uint32_t lookupDelay_;
613
614   static SSL_SESSION* getSessionCallback(SSL* ssl,
615                                          unsigned char* /* sess_id */,
616                                          int /* id_len */,
617                                          int* copyflag) {
618     *copyflag = 0;
619     asyncCallbacks_++;
620 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
621     if (!SSL_want_sess_cache_lookup(ssl)) {
622       // libssl.so mismatch
623       std::cerr << "no async support" << std::endl;
624       return nullptr;
625     }
626
627     AsyncSSLSocket *sslSocket =
628         AsyncSSLSocket::getFromSSL(ssl);
629     assert(sslSocket != nullptr);
630     // Going to simulate an async cache by just running delaying the miss 100ms
631     if (asyncCallbacks_ % 2 == 0) {
632       // This socket is already blocked on lookup, return miss
633       std::cerr << "returning miss" << std::endl;
634     } else {
635       // fresh meat - block it
636       std::cerr << "async lookup" << std::endl;
637       sslSocket->getEventBase()->tryRunAfterDelay(
638         std::bind(&AsyncSSLSocket::restartSSLAccept,
639                   sslSocket), lookupDelay_);
640       *copyflag = SSL_SESSION_CB_WOULD_BLOCK;
641       asyncLookups_++;
642     }
643 #endif
644     return nullptr;
645   }
646 };
647
648 void getfds(int fds[2]);
649
650 void getctx(
651   std::shared_ptr<folly::SSLContext> clientCtx,
652   std::shared_ptr<folly::SSLContext> serverCtx);
653
654 void sslsocketpair(
655   EventBase* eventBase,
656   AsyncSSLSocket::UniquePtr* clientSock,
657   AsyncSSLSocket::UniquePtr* serverSock);
658
659 class BlockingWriteClient :
660   private AsyncSSLSocket::HandshakeCB,
661   private AsyncTransportWrapper::WriteCallback {
662  public:
663   explicit BlockingWriteClient(
664     AsyncSSLSocket::UniquePtr socket)
665     : socket_(std::move(socket)),
666       bufLen_(2500),
667       iovCount_(2000) {
668     // Fill buf_
669     buf_.reset(new uint8_t[bufLen_]);
670     for (uint32_t n = 0; n < sizeof(buf_); ++n) {
671       buf_[n] = n % 0xff;
672     }
673
674     // Initialize iov_
675     iov_.reset(new struct iovec[iovCount_]);
676     for (uint32_t n = 0; n < iovCount_; ++n) {
677       iov_[n].iov_base = buf_.get() + n;
678       if (n & 0x1) {
679         iov_[n].iov_len = n % bufLen_;
680       } else {
681         iov_[n].iov_len = bufLen_ - (n % bufLen_);
682       }
683     }
684
685     socket_->sslConn(this, 100);
686   }
687
688   struct iovec* getIovec() const {
689     return iov_.get();
690   }
691   uint32_t getIovecCount() const {
692     return iovCount_;
693   }
694
695  private:
696   void handshakeSuc(AsyncSSLSocket*) noexcept override {
697     socket_->writev(this, iov_.get(), iovCount_);
698   }
699   void handshakeErr(
700     AsyncSSLSocket*,
701     const AsyncSocketException& ex) noexcept override {
702     ADD_FAILURE() << "client handshake error: " << ex.what();
703   }
704   void writeSuccess() noexcept override {
705     socket_->close();
706   }
707   void writeErr(
708     size_t bytesWritten,
709     const AsyncSocketException& ex) noexcept override {
710     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
711                   << ex.what();
712   }
713
714   AsyncSSLSocket::UniquePtr socket_;
715   uint32_t bufLen_;
716   uint32_t iovCount_;
717   std::unique_ptr<uint8_t[]> buf_;
718   std::unique_ptr<struct iovec[]> iov_;
719 };
720
721 class BlockingWriteServer :
722     private AsyncSSLSocket::HandshakeCB,
723     private AsyncTransportWrapper::ReadCallback {
724  public:
725   explicit BlockingWriteServer(
726     AsyncSSLSocket::UniquePtr socket)
727     : socket_(std::move(socket)),
728       bufSize_(2500 * 2000),
729       bytesRead_(0) {
730     buf_.reset(new uint8_t[bufSize_]);
731     socket_->sslAccept(this, 100);
732   }
733
734   void checkBuffer(struct iovec* iov, uint32_t count) const {
735     uint32_t idx = 0;
736     for (uint32_t n = 0; n < count; ++n) {
737       size_t bytesLeft = bytesRead_ - idx;
738       int rc = memcmp(buf_.get() + idx, iov[n].iov_base,
739                       std::min(iov[n].iov_len, bytesLeft));
740       if (rc != 0) {
741         FAIL() << "buffer mismatch at iovec " << n << "/" << count
742                << ": rc=" << rc;
743
744       }
745       if (iov[n].iov_len > bytesLeft) {
746         FAIL() << "server did not read enough data: "
747                << "ended at byte " << bytesLeft << "/" << iov[n].iov_len
748                << " in iovec " << n << "/" << count;
749       }
750
751       idx += iov[n].iov_len;
752     }
753     if (idx != bytesRead_) {
754       ADD_FAILURE() << "server read extra data: " << bytesRead_
755                     << " bytes read; expected " << idx;
756     }
757   }
758
759  private:
760   void handshakeSuc(AsyncSSLSocket*) noexcept override {
761     // Wait 10ms before reading, so the client's writes will initially block.
762     socket_->getEventBase()->tryRunAfterDelay(
763         [this] { socket_->setReadCB(this); }, 10);
764   }
765   void handshakeErr(
766     AsyncSSLSocket*,
767     const AsyncSocketException& ex) noexcept override {
768     ADD_FAILURE() << "server handshake error: " << ex.what();
769   }
770   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
771     *bufReturn = buf_.get() + bytesRead_;
772     *lenReturn = bufSize_ - bytesRead_;
773   }
774   void readDataAvailable(size_t len) noexcept override {
775     bytesRead_ += len;
776     socket_->setReadCB(nullptr);
777     socket_->getEventBase()->tryRunAfterDelay(
778         [this] { socket_->setReadCB(this); }, 2);
779   }
780   void readEOF() noexcept override {
781     socket_->close();
782   }
783   void readErr(
784     const AsyncSocketException& ex) noexcept override {
785     ADD_FAILURE() << "server read error: " << ex.what();
786   }
787
788   AsyncSSLSocket::UniquePtr socket_;
789   uint32_t bufSize_;
790   uint32_t bytesRead_;
791   std::unique_ptr<uint8_t[]> buf_;
792 };
793
794 class NpnClient :
795   private AsyncSSLSocket::HandshakeCB,
796   private AsyncTransportWrapper::WriteCallback {
797  public:
798   explicit NpnClient(
799     AsyncSSLSocket::UniquePtr socket)
800       : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
801     socket_->sslConn(this);
802   }
803
804   const unsigned char* nextProto;
805   unsigned nextProtoLength;
806   SSLContext::NextProtocolType protocolType;
807
808  private:
809   void handshakeSuc(AsyncSSLSocket*) noexcept override {
810     socket_->getSelectedNextProtocol(
811         &nextProto, &nextProtoLength, &protocolType);
812   }
813   void handshakeErr(
814     AsyncSSLSocket*,
815     const AsyncSocketException& ex) noexcept override {
816     ADD_FAILURE() << "client handshake error: " << ex.what();
817   }
818   void writeSuccess() noexcept override {
819     socket_->close();
820   }
821   void writeErr(
822     size_t bytesWritten,
823     const AsyncSocketException& ex) noexcept override {
824     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
825                   << ex.what();
826   }
827
828   AsyncSSLSocket::UniquePtr socket_;
829 };
830
831 class NpnServer :
832     private AsyncSSLSocket::HandshakeCB,
833     private AsyncTransportWrapper::ReadCallback {
834  public:
835   explicit NpnServer(AsyncSSLSocket::UniquePtr socket)
836       : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
837     socket_->sslAccept(this);
838   }
839
840   const unsigned char* nextProto;
841   unsigned nextProtoLength;
842   SSLContext::NextProtocolType protocolType;
843
844  private:
845   void handshakeSuc(AsyncSSLSocket*) noexcept override {
846     socket_->getSelectedNextProtocol(
847         &nextProto, &nextProtoLength, &protocolType);
848   }
849   void handshakeErr(
850     AsyncSSLSocket*,
851     const AsyncSocketException& ex) noexcept override {
852     ADD_FAILURE() << "server handshake error: " << ex.what();
853   }
854   void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
855     *lenReturn = 0;
856   }
857   void readDataAvailable(size_t /* len */) noexcept override {}
858   void readEOF() noexcept override {
859     socket_->close();
860   }
861   void readErr(
862     const AsyncSocketException& ex) noexcept override {
863     ADD_FAILURE() << "server read error: " << ex.what();
864   }
865
866   AsyncSSLSocket::UniquePtr socket_;
867 };
868
869 #ifndef OPENSSL_NO_TLSEXT
870 class SNIClient :
871   private AsyncSSLSocket::HandshakeCB,
872   private AsyncTransportWrapper::WriteCallback {
873  public:
874   explicit SNIClient(
875     AsyncSSLSocket::UniquePtr socket)
876       : serverNameMatch(false), socket_(std::move(socket)) {
877     socket_->sslConn(this);
878   }
879
880   bool serverNameMatch;
881
882  private:
883   void handshakeSuc(AsyncSSLSocket*) noexcept override {
884     serverNameMatch = socket_->isServerNameMatch();
885   }
886   void handshakeErr(
887     AsyncSSLSocket*,
888     const AsyncSocketException& ex) noexcept override {
889     ADD_FAILURE() << "client handshake error: " << ex.what();
890   }
891   void writeSuccess() noexcept override {
892     socket_->close();
893   }
894   void writeErr(
895     size_t bytesWritten,
896     const AsyncSocketException& ex) noexcept override {
897     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
898                   << ex.what();
899   }
900
901   AsyncSSLSocket::UniquePtr socket_;
902 };
903
904 class SNIServer :
905     private AsyncSSLSocket::HandshakeCB,
906     private AsyncTransportWrapper::ReadCallback {
907  public:
908   explicit SNIServer(
909     AsyncSSLSocket::UniquePtr socket,
910     const std::shared_ptr<folly::SSLContext>& ctx,
911     const std::shared_ptr<folly::SSLContext>& sniCtx,
912     const std::string& expectedServerName)
913       : serverNameMatch(false), socket_(std::move(socket)), sniCtx_(sniCtx),
914         expectedServerName_(expectedServerName) {
915     ctx->setServerNameCallback(std::bind(&SNIServer::serverNameCallback, this,
916                                          std::placeholders::_1));
917     socket_->sslAccept(this);
918   }
919
920   bool serverNameMatch;
921
922  private:
923   void handshakeSuc(AsyncSSLSocket* /* ssl */) noexcept override {}
924   void handshakeErr(
925     AsyncSSLSocket*,
926     const AsyncSocketException& ex) noexcept override {
927     ADD_FAILURE() << "server handshake error: " << ex.what();
928   }
929   void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
930     *lenReturn = 0;
931   }
932   void readDataAvailable(size_t /* len */) noexcept override {}
933   void readEOF() noexcept override {
934     socket_->close();
935   }
936   void readErr(
937     const AsyncSocketException& ex) noexcept override {
938     ADD_FAILURE() << "server read error: " << ex.what();
939   }
940
941   folly::SSLContext::ServerNameCallbackResult
942     serverNameCallback(SSL *ssl) {
943     const char *sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
944     if (sniCtx_ &&
945         sn &&
946         !strcasecmp(expectedServerName_.c_str(), sn)) {
947       AsyncSSLSocket *sslSocket =
948           AsyncSSLSocket::getFromSSL(ssl);
949       sslSocket->switchServerSSLContext(sniCtx_);
950       serverNameMatch = true;
951       return folly::SSLContext::SERVER_NAME_FOUND;
952     } else {
953       serverNameMatch = false;
954       return folly::SSLContext::SERVER_NAME_NOT_FOUND;
955     }
956   }
957
958   AsyncSSLSocket::UniquePtr socket_;
959   std::shared_ptr<folly::SSLContext> sniCtx_;
960   std::string expectedServerName_;
961 };
962 #endif
963
964 class SSLClient : public AsyncSocket::ConnectCallback,
965                   public AsyncTransportWrapper::WriteCallback,
966                   public AsyncTransportWrapper::ReadCallback
967 {
968  private:
969   EventBase *eventBase_;
970   std::shared_ptr<AsyncSSLSocket> sslSocket_;
971   SSL_SESSION *session_;
972   std::shared_ptr<folly::SSLContext> ctx_;
973   uint32_t requests_;
974   folly::SocketAddress address_;
975   uint32_t timeout_;
976   char buf_[128];
977   char readbuf_[128];
978   uint32_t bytesRead_;
979   uint32_t hit_;
980   uint32_t miss_;
981   uint32_t errors_;
982   uint32_t writeAfterConnectErrors_;
983
984   // These settings test that we eventually drain the
985   // socket, even if the maxReadsPerEvent_ is hit during
986   // a event loop iteration.
987   static constexpr size_t kMaxReadsPerEvent = 2;
988   static constexpr size_t kMaxReadBufferSz =
989     sizeof(readbuf_) / kMaxReadsPerEvent / 2;  // 2 event loop iterations
990
991  public:
992   SSLClient(EventBase *eventBase,
993             const folly::SocketAddress& address,
994             uint32_t requests,
995             uint32_t timeout = 0)
996       : eventBase_(eventBase),
997         session_(nullptr),
998         requests_(requests),
999         address_(address),
1000         timeout_(timeout),
1001         bytesRead_(0),
1002         hit_(0),
1003         miss_(0),
1004         errors_(0),
1005         writeAfterConnectErrors_(0) {
1006     ctx_.reset(new folly::SSLContext());
1007     ctx_->setOptions(SSL_OP_NO_TICKET);
1008     ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1009     memset(buf_, 'a', sizeof(buf_));
1010   }
1011
1012   ~SSLClient() {
1013     if (session_) {
1014       SSL_SESSION_free(session_);
1015     }
1016     if (errors_ == 0) {
1017       EXPECT_EQ(bytesRead_, sizeof(buf_));
1018     }
1019   }
1020
1021   uint32_t getHit() const { return hit_; }
1022
1023   uint32_t getMiss() const { return miss_; }
1024
1025   uint32_t getErrors() const { return errors_; }
1026
1027   uint32_t getWriteAfterConnectErrors() const {
1028     return writeAfterConnectErrors_;
1029   }
1030
1031   void connect(bool writeNow = false) {
1032     sslSocket_ = AsyncSSLSocket::newSocket(
1033       ctx_, eventBase_);
1034     if (session_ != nullptr) {
1035       sslSocket_->setSSLSession(session_);
1036     }
1037     requests_--;
1038     sslSocket_->connect(this, address_, timeout_);
1039     if (sslSocket_ && writeNow) {
1040       // write some junk, used in an error test
1041       sslSocket_->write(this, buf_, sizeof(buf_));
1042     }
1043   }
1044
1045   void connectSuccess() noexcept override {
1046     std::cerr << "client SSL socket connected" << std::endl;
1047     if (sslSocket_->getSSLSessionReused()) {
1048       hit_++;
1049     } else {
1050       miss_++;
1051       if (session_ != nullptr) {
1052         SSL_SESSION_free(session_);
1053       }
1054       session_ = sslSocket_->getSSLSession();
1055     }
1056
1057     // write()
1058     sslSocket_->setMaxReadsPerEvent(kMaxReadsPerEvent);
1059     sslSocket_->write(this, buf_, sizeof(buf_));
1060     sslSocket_->setReadCB(this);
1061     memset(readbuf_, 'b', sizeof(readbuf_));
1062     bytesRead_ = 0;
1063   }
1064
1065   void connectErr(
1066     const AsyncSocketException& ex) noexcept override {
1067     std::cerr << "SSLClient::connectError: " << ex.what() << std::endl;
1068     errors_++;
1069     sslSocket_.reset();
1070   }
1071
1072   void writeSuccess() noexcept override {
1073     std::cerr << "client write success" << std::endl;
1074   }
1075
1076   void writeErr(size_t /* bytesWritten */,
1077                 const AsyncSocketException& ex) noexcept override {
1078     std::cerr << "client writeError: " << ex.what() << std::endl;
1079     if (!sslSocket_) {
1080       writeAfterConnectErrors_++;
1081     }
1082   }
1083
1084   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
1085     *bufReturn = readbuf_ + bytesRead_;
1086     *lenReturn = std::min(kMaxReadBufferSz, sizeof(readbuf_) - bytesRead_);
1087   }
1088
1089   void readEOF() noexcept override {
1090     std::cerr << "client readEOF" << std::endl;
1091   }
1092
1093   void readErr(
1094     const AsyncSocketException& ex) noexcept override {
1095     std::cerr << "client readError: " << ex.what() << std::endl;
1096   }
1097
1098   void readDataAvailable(size_t len) noexcept override {
1099     std::cerr << "client read data: " << len << std::endl;
1100     bytesRead_ += len;
1101     if (bytesRead_ == sizeof(buf_)) {
1102       EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
1103       sslSocket_->closeNow();
1104       sslSocket_.reset();
1105       if (requests_ != 0) {
1106         connect();
1107       }
1108     }
1109   }
1110
1111 };
1112
1113 class SSLHandshakeBase :
1114   public AsyncSSLSocket::HandshakeCB,
1115   private AsyncTransportWrapper::WriteCallback {
1116  public:
1117   explicit SSLHandshakeBase(
1118    AsyncSSLSocket::UniquePtr socket,
1119    bool preverifyResult,
1120    bool verifyResult) :
1121     handshakeVerify_(false),
1122     handshakeSuccess_(false),
1123     handshakeError_(false),
1124     socket_(std::move(socket)),
1125     preverifyResult_(preverifyResult),
1126     verifyResult_(verifyResult) {
1127   }
1128
1129   bool handshakeVerify_;
1130   bool handshakeSuccess_;
1131   bool handshakeError_;
1132   std::chrono::nanoseconds handshakeTime;
1133
1134  protected:
1135   AsyncSSLSocket::UniquePtr socket_;
1136   bool preverifyResult_;
1137   bool verifyResult_;
1138
1139   // HandshakeCallback
1140   bool handshakeVer(AsyncSSLSocket* /* sock */,
1141                     bool preverifyOk,
1142                     X509_STORE_CTX* /* ctx */) noexcept override {
1143     handshakeVerify_ = true;
1144
1145     EXPECT_EQ(preverifyResult_, preverifyOk);
1146     return verifyResult_;
1147   }
1148
1149   void handshakeSuc(AsyncSSLSocket*) noexcept override {
1150     handshakeSuccess_ = true;
1151     handshakeTime = socket_->getHandshakeTime();
1152   }
1153
1154   void handshakeErr(AsyncSSLSocket*,
1155                     const AsyncSocketException& /* ex */) noexcept override {
1156     handshakeError_ = true;
1157     handshakeTime = socket_->getHandshakeTime();
1158   }
1159
1160   // WriteCallback
1161   void writeSuccess() noexcept override {
1162     socket_->close();
1163   }
1164
1165   void writeErr(
1166    size_t bytesWritten,
1167    const AsyncSocketException& ex) noexcept override {
1168     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
1169                   << ex.what();
1170   }
1171 };
1172
1173 class SSLHandshakeClient : public SSLHandshakeBase {
1174  public:
1175   SSLHandshakeClient(
1176    AsyncSSLSocket::UniquePtr socket,
1177    bool preverifyResult,
1178    bool verifyResult) :
1179     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1180     socket_->sslConn(this, 0);
1181   }
1182 };
1183
1184 class SSLHandshakeClientNoVerify : public SSLHandshakeBase {
1185  public:
1186   SSLHandshakeClientNoVerify(
1187    AsyncSSLSocket::UniquePtr socket,
1188    bool preverifyResult,
1189    bool verifyResult) :
1190     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1191     socket_->sslConn(this, 0,
1192       folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1193   }
1194 };
1195
1196 class SSLHandshakeClientDoVerify : public SSLHandshakeBase {
1197  public:
1198   SSLHandshakeClientDoVerify(
1199    AsyncSSLSocket::UniquePtr socket,
1200    bool preverifyResult,
1201    bool verifyResult) :
1202     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1203     socket_->sslConn(this, 0,
1204       folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
1205   }
1206 };
1207
1208 class SSLHandshakeServer : public SSLHandshakeBase {
1209  public:
1210   SSLHandshakeServer(
1211       AsyncSSLSocket::UniquePtr socket,
1212       bool preverifyResult,
1213       bool verifyResult)
1214     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1215     socket_->sslAccept(this, 0);
1216   }
1217 };
1218
1219 class SSLHandshakeServerParseClientHello : public SSLHandshakeBase {
1220  public:
1221   SSLHandshakeServerParseClientHello(
1222       AsyncSSLSocket::UniquePtr socket,
1223       bool preverifyResult,
1224       bool verifyResult)
1225       : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1226     socket_->enableClientHelloParsing();
1227     socket_->sslAccept(this, 0);
1228   }
1229
1230   std::string clientCiphers_, sharedCiphers_, serverCiphers_, chosenCipher_;
1231
1232  protected:
1233   void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
1234     handshakeSuccess_ = true;
1235     sock->getSSLSharedCiphers(sharedCiphers_);
1236     sock->getSSLServerCiphers(serverCiphers_);
1237     sock->getSSLClientCiphers(clientCiphers_);
1238     chosenCipher_ = sock->getNegotiatedCipherName();
1239   }
1240 };
1241
1242
1243 class SSLHandshakeServerNoVerify : public SSLHandshakeBase {
1244  public:
1245   SSLHandshakeServerNoVerify(
1246       AsyncSSLSocket::UniquePtr socket,
1247       bool preverifyResult,
1248       bool verifyResult)
1249     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1250     socket_->sslAccept(this, 0,
1251       folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1252   }
1253 };
1254
1255 class SSLHandshakeServerDoVerify : public SSLHandshakeBase {
1256  public:
1257   SSLHandshakeServerDoVerify(
1258       AsyncSSLSocket::UniquePtr socket,
1259       bool preverifyResult,
1260       bool verifyResult)
1261     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1262     socket_->sslAccept(this, 0,
1263       folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1264   }
1265 };
1266
1267 class EventBaseAborter : public AsyncTimeout {
1268  public:
1269   EventBaseAborter(EventBase* eventBase,
1270                    uint32_t timeoutMS)
1271     : AsyncTimeout(
1272       eventBase, AsyncTimeout::InternalEnum::INTERNAL)
1273     , eventBase_(eventBase) {
1274     scheduleTimeout(timeoutMS);
1275   }
1276
1277   void timeoutExpired() noexcept override {
1278     FAIL() << "test timed out";
1279     eventBase_->terminateLoopSoon();
1280   }
1281
1282  private:
1283   EventBase* eventBase_;
1284 };
1285
1286 }