Refer to nullptr not NULL
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/io/async/test/AsyncSSLSocketTest.h>
17
18 #include <folly/SocketAddress.h>
19 #include <folly/io/Cursor.h>
20 #include <folly/io/async/AsyncSSLSocket.h>
21 #include <folly/io/async/EventBase.h>
22 #include <folly/portability/GMock.h>
23 #include <folly/portability/GTest.h>
24 #include <folly/portability/OpenSSL.h>
25 #include <folly/portability/Sockets.h>
26 #include <folly/portability/Unistd.h>
27
28 #include <folly/io/async/test/BlockingSocket.h>
29
30 #include <fcntl.h>
31 #include <signal.h>
32 #include <sys/types.h>
33 #include <sys/utsname.h>
34
35 #include <fstream>
36 #include <iostream>
37 #include <list>
38 #include <set>
39 #include <thread>
40
41 using std::string;
42 using std::vector;
43 using std::min;
44 using std::cerr;
45 using std::endl;
46 using std::list;
47
48 using namespace testing;
49
50 namespace folly {
51 uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
52 uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
53 uint32_t TestSSLAsyncCacheServer::lookupDelay_ = 0;
54
55 constexpr size_t SSLClient::kMaxReadBufferSz;
56 constexpr size_t SSLClient::kMaxReadsPerEvent;
57
58 void getfds(int fds[2]) {
59   if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
60     FAIL() << "failed to create socketpair: " << strerror(errno);
61   }
62   for (int idx = 0; idx < 2; ++idx) {
63     int flags = fcntl(fds[idx], F_GETFL, 0);
64     if (flags == -1) {
65       FAIL() << "failed to get flags for socket " << idx << ": "
66              << strerror(errno);
67     }
68     if (fcntl(fds[idx], F_SETFL, flags | O_NONBLOCK) != 0) {
69       FAIL() << "failed to put socket " << idx << " in non-blocking mode: "
70              << strerror(errno);
71     }
72   }
73 }
74
75 void getctx(
76   std::shared_ptr<folly::SSLContext> clientCtx,
77   std::shared_ptr<folly::SSLContext> serverCtx) {
78   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
79
80   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
81   serverCtx->loadCertificate(kTestCert);
82   serverCtx->loadPrivateKey(kTestKey);
83 }
84
85 void sslsocketpair(
86   EventBase* eventBase,
87   AsyncSSLSocket::UniquePtr* clientSock,
88   AsyncSSLSocket::UniquePtr* serverSock) {
89   auto clientCtx = std::make_shared<folly::SSLContext>();
90   auto serverCtx = std::make_shared<folly::SSLContext>();
91   int fds[2];
92   getfds(fds);
93   getctx(clientCtx, serverCtx);
94   clientSock->reset(new AsyncSSLSocket(
95                       clientCtx, eventBase, fds[0], false));
96   serverSock->reset(new AsyncSSLSocket(
97                       serverCtx, eventBase, fds[1], true));
98
99   // (*clientSock)->setSendTimeout(100);
100   // (*serverSock)->setSendTimeout(100);
101 }
102
103 // client protocol filters
104 bool clientProtoFilterPickPony(unsigned char** client,
105   unsigned int* client_len, const unsigned char*, unsigned int ) {
106   //the protocol string in length prefixed byte string. the
107   //length byte is not included in the length
108   static unsigned char p[7] = {6,'p','o','n','i','e','s'};
109   *client = p;
110   *client_len = 7;
111   return true;
112 }
113
114 bool clientProtoFilterPickNone(unsigned char**, unsigned int*,
115   const unsigned char*, unsigned int) {
116   return false;
117 }
118
119 std::string getFileAsBuf(const char* fileName) {
120   std::string buffer;
121   folly::readFile(fileName, buffer);
122   return buffer;
123 }
124
125 std::string getCommonName(X509* cert) {
126   X509_NAME* subject = X509_get_subject_name(cert);
127   std::string cn;
128   cn.resize(ub_common_name);
129   X509_NAME_get_text_by_NID(
130       subject, NID_commonName, const_cast<char*>(cn.data()), ub_common_name);
131   return cn;
132 }
133
134 /**
135  * Test connecting to, writing to, reading from, and closing the
136  * connection to the SSL server.
137  */
138 TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
139   // Start listening on a local port
140   WriteCallbackBase writeCallback;
141   ReadCallback readCallback(&writeCallback);
142   HandshakeCallback handshakeCallback(&readCallback);
143   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
144   TestSSLServer server(&acceptCallback);
145
146   // Set up SSL context.
147   std::shared_ptr<SSLContext> sslContext(new SSLContext());
148   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
149   //sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
150   //sslContext->authenticate(true, false);
151
152   // connect
153   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
154                                                  sslContext);
155   socket->open(std::chrono::milliseconds(10000));
156
157   // write()
158   uint8_t buf[128];
159   memset(buf, 'a', sizeof(buf));
160   socket->write(buf, sizeof(buf));
161
162   // read()
163   uint8_t readbuf[128];
164   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
165   EXPECT_EQ(bytesRead, 128);
166   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
167
168   // close()
169   socket->close();
170
171   cerr << "ConnectWriteReadClose test completed" << endl;
172   EXPECT_EQ(socket->getSSLSocket()->getTotalConnectTimeout().count(), 10000);
173 }
174
175 /**
176  * Test reading after server close.
177  */
178 TEST(AsyncSSLSocketTest, ReadAfterClose) {
179   // Start listening on a local port
180   WriteCallbackBase writeCallback;
181   ReadEOFCallback readCallback(&writeCallback);
182   HandshakeCallback handshakeCallback(&readCallback);
183   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
184   auto server = std::make_unique<TestSSLServer>(&acceptCallback);
185
186   // Set up SSL context.
187   auto sslContext = std::make_shared<SSLContext>();
188   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
189
190   auto socket =
191       std::make_shared<BlockingSocket>(server->getAddress(), sslContext);
192   socket->open();
193
194   // This should trigger an EOF on the client.
195   auto evb = handshakeCallback.getSocket()->getEventBase();
196   evb->runInEventBaseThreadAndWait([&]() { handshakeCallback.closeSocket(); });
197   std::array<uint8_t, 128> readbuf;
198   auto bytesRead = socket->read(readbuf.data(), readbuf.size());
199   EXPECT_EQ(0, bytesRead);
200 }
201
202 /**
203  * Test bad renegotiation
204  */
205 #if !defined(OPENSSL_IS_BORINGSSL)
206 TEST(AsyncSSLSocketTest, Renegotiate) {
207   EventBase eventBase;
208   auto clientCtx = std::make_shared<SSLContext>();
209   auto dfServerCtx = std::make_shared<SSLContext>();
210   std::array<int, 2> fds;
211   getfds(fds.data());
212   getctx(clientCtx, dfServerCtx);
213
214   AsyncSSLSocket::UniquePtr clientSock(
215       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
216   AsyncSSLSocket::UniquePtr serverSock(
217       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
218   SSLHandshakeClient client(std::move(clientSock), true, true);
219   RenegotiatingServer server(std::move(serverSock));
220
221   while (!client.handshakeSuccess_ && !client.handshakeError_) {
222     eventBase.loopOnce();
223   }
224
225   ASSERT_TRUE(client.handshakeSuccess_);
226
227   auto sslSock = std::move(client).moveSocket();
228   sslSock->detachEventBase();
229   // This is nasty, however we don't want to add support for
230   // renegotiation in AsyncSSLSocket.
231   SSL_renegotiate(const_cast<SSL*>(sslSock->getSSL()));
232
233   auto socket = std::make_shared<BlockingSocket>(std::move(sslSock));
234
235   std::thread t([&]() { eventBase.loopForever(); });
236
237   // Trigger the renegotiation.
238   std::array<uint8_t, 128> buf;
239   memset(buf.data(), 'a', buf.size());
240   try {
241     socket->write(buf.data(), buf.size());
242   } catch (AsyncSocketException& e) {
243     LOG(INFO) << "client got error " << e.what();
244   }
245   eventBase.terminateLoopSoon();
246   t.join();
247
248   eventBase.loop();
249   ASSERT_TRUE(server.renegotiationError_);
250 }
251 #endif
252
253 /**
254  * Negative test for handshakeError().
255  */
256 TEST(AsyncSSLSocketTest, HandshakeError) {
257   // Start listening on a local port
258   WriteCallbackBase writeCallback;
259   WriteErrorCallback readCallback(&writeCallback);
260   HandshakeCallback handshakeCallback(&readCallback);
261   HandshakeErrorCallback acceptCallback(&handshakeCallback);
262   TestSSLServer server(&acceptCallback);
263
264   // Set up SSL context.
265   std::shared_ptr<SSLContext> sslContext(new SSLContext());
266   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
267
268   // connect
269   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
270                                                  sslContext);
271   // read()
272   bool ex = false;
273   try {
274     socket->open();
275
276     uint8_t readbuf[128];
277     uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
278     LOG(ERROR) << "readAll returned " << bytesRead << " instead of throwing";
279   } catch (AsyncSocketException&) {
280     ex = true;
281   }
282   EXPECT_TRUE(ex);
283
284   // close()
285   socket->close();
286   cerr << "HandshakeError test completed" << endl;
287 }
288
289 /**
290  * Negative test for readError().
291  */
292 TEST(AsyncSSLSocketTest, ReadError) {
293   // Start listening on a local port
294   WriteCallbackBase writeCallback;
295   ReadErrorCallback readCallback(&writeCallback);
296   HandshakeCallback handshakeCallback(&readCallback);
297   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
298   TestSSLServer server(&acceptCallback);
299
300   // Set up SSL context.
301   std::shared_ptr<SSLContext> sslContext(new SSLContext());
302   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
303
304   // connect
305   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
306                                                  sslContext);
307   socket->open();
308
309   // write something to trigger ssl handshake
310   uint8_t buf[128];
311   memset(buf, 'a', sizeof(buf));
312   socket->write(buf, sizeof(buf));
313
314   socket->close();
315   cerr << "ReadError test completed" << endl;
316 }
317
318 /**
319  * Negative test for writeError().
320  */
321 TEST(AsyncSSLSocketTest, WriteError) {
322   // Start listening on a local port
323   WriteCallbackBase writeCallback;
324   WriteErrorCallback readCallback(&writeCallback);
325   HandshakeCallback handshakeCallback(&readCallback);
326   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
327   TestSSLServer server(&acceptCallback);
328
329   // Set up SSL context.
330   std::shared_ptr<SSLContext> sslContext(new SSLContext());
331   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
332
333   // connect
334   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
335                                                  sslContext);
336   socket->open();
337
338   // write something to trigger ssl handshake
339   uint8_t buf[128];
340   memset(buf, 'a', sizeof(buf));
341   socket->write(buf, sizeof(buf));
342
343   socket->close();
344   cerr << "WriteError test completed" << endl;
345 }
346
347 /**
348  * Test a socket with TCP_NODELAY unset.
349  */
350 TEST(AsyncSSLSocketTest, SocketWithDelay) {
351   // Start listening on a local port
352   WriteCallbackBase writeCallback;
353   ReadCallback readCallback(&writeCallback);
354   HandshakeCallback handshakeCallback(&readCallback);
355   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
356   TestSSLServer server(&acceptCallback);
357
358   // Set up SSL context.
359   std::shared_ptr<SSLContext> sslContext(new SSLContext());
360   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
361
362   // connect
363   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
364                                                  sslContext);
365   socket->open();
366
367   // write()
368   uint8_t buf[128];
369   memset(buf, 'a', sizeof(buf));
370   socket->write(buf, sizeof(buf));
371
372   // read()
373   uint8_t readbuf[128];
374   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
375   EXPECT_EQ(bytesRead, 128);
376   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
377
378   // close()
379   socket->close();
380
381   cerr << "SocketWithDelay test completed" << endl;
382 }
383
384 using NextProtocolTypePair =
385     std::pair<SSLContext::NextProtocolType, SSLContext::NextProtocolType>;
386
387 class NextProtocolTest : public testing::TestWithParam<NextProtocolTypePair> {
388   // For matching protos
389  public:
390   void SetUp() override { getctx(clientCtx, serverCtx); }
391
392   void connect(bool unset = false) {
393     getfds(fds);
394
395     if (unset) {
396       // unsetting NPN for any of [client, server] is enough to make NPN not
397       // work
398       clientCtx->unsetNextProtocols();
399     }
400
401     AsyncSSLSocket::UniquePtr clientSock(
402       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
403     AsyncSSLSocket::UniquePtr serverSock(
404       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
405     client = std::make_unique<NpnClient>(std::move(clientSock));
406     server = std::make_unique<NpnServer>(std::move(serverSock));
407
408     eventBase.loop();
409   }
410
411   void expectProtocol(const std::string& proto) {
412     expectHandshakeSuccess();
413     EXPECT_NE(client->nextProtoLength, 0);
414     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
415     EXPECT_EQ(
416         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
417         0);
418     string selected((const char*)client->nextProto, client->nextProtoLength);
419     EXPECT_EQ(proto, selected);
420   }
421
422   void expectNoProtocol() {
423     expectHandshakeSuccess();
424     EXPECT_EQ(client->nextProtoLength, 0);
425     EXPECT_EQ(server->nextProtoLength, 0);
426     EXPECT_EQ(client->nextProto, nullptr);
427     EXPECT_EQ(server->nextProto, nullptr);
428   }
429
430   void expectProtocolType() {
431     expectHandshakeSuccess();
432     if (GetParam().first == SSLContext::NextProtocolType::ANY &&
433         GetParam().second == SSLContext::NextProtocolType::ANY) {
434       EXPECT_EQ(client->protocolType, server->protocolType);
435     } else if (GetParam().first == SSLContext::NextProtocolType::ANY ||
436                GetParam().second == SSLContext::NextProtocolType::ANY) {
437       // Well not much we can say
438     } else {
439       expectProtocolType(GetParam());
440     }
441   }
442
443   void expectProtocolType(NextProtocolTypePair expected) {
444     expectHandshakeSuccess();
445     EXPECT_EQ(client->protocolType, expected.first);
446     EXPECT_EQ(server->protocolType, expected.second);
447   }
448
449   void expectHandshakeSuccess() {
450     EXPECT_FALSE(client->except.hasValue())
451         << "client handshake error: " << client->except->what();
452     EXPECT_FALSE(server->except.hasValue())
453         << "server handshake error: " << server->except->what();
454   }
455
456   void expectHandshakeError() {
457     EXPECT_TRUE(client->except.hasValue())
458         << "Expected client handshake error!";
459     EXPECT_TRUE(server->except.hasValue())
460         << "Expected server handshake error!";
461   }
462
463   EventBase eventBase;
464   std::shared_ptr<SSLContext> clientCtx{std::make_shared<SSLContext>()};
465   std::shared_ptr<SSLContext> serverCtx{std::make_shared<SSLContext>()};
466   int fds[2];
467   std::unique_ptr<NpnClient> client;
468   std::unique_ptr<NpnServer> server;
469 };
470
471 class NextProtocolTLSExtTest : public NextProtocolTest {
472   // For extended TLS protos
473 };
474
475 class NextProtocolNPNOnlyTest : public NextProtocolTest {
476   // For mismatching protos
477 };
478
479 class NextProtocolMismatchTest : public NextProtocolTest {
480   // For mismatching protos
481 };
482
483 TEST_P(NextProtocolTest, NpnTestOverlap) {
484   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
485   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
486                                         GetParam().second);
487
488   connect();
489
490   expectProtocol("baz");
491   expectProtocolType();
492 }
493
494 TEST_P(NextProtocolTest, NpnTestUnset) {
495   // Identical to above test, except that we want unset NPN before
496   // looping.
497   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
498   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
499                                         GetParam().second);
500
501   connect(true /* unset */);
502
503   // if alpn negotiation fails, type will appear as npn
504   expectNoProtocol();
505   EXPECT_EQ(client->protocolType, server->protocolType);
506 }
507
508 TEST_P(NextProtocolMismatchTest, NpnAlpnTestNoOverlap) {
509   clientCtx->setAdvertisedNextProtocols({"foo"}, GetParam().first);
510   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
511                                         GetParam().second);
512
513   connect();
514
515   expectNoProtocol();
516   expectProtocolType(
517       {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
518 }
519
520 // Note: the behavior changed in the ANY/ANY case in OpenSSL 1.0.2h, this test
521 // will fail on 1.0.2 before that.
522 TEST_P(NextProtocolTest, NpnTestNoOverlap) {
523   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
524   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
525                                         GetParam().second);
526   connect();
527
528   if (GetParam().first == SSLContext::NextProtocolType::ALPN ||
529       GetParam().second == SSLContext::NextProtocolType::ALPN) {
530     // This is arguably incorrect behavior since RFC7301 states an ALPN protocol
531     // mismatch should result in a fatal alert, but this is the current behavior
532     // on all OpenSSL versions/variants, and we want to know if it changes.
533     expectNoProtocol();
534   }
535 #if FOLLY_OPENSSL_IS_110 || defined(OPENSSL_IS_BORINGSSL)
536   else if (
537       GetParam().first == SSLContext::NextProtocolType::ANY &&
538       GetParam().second == SSLContext::NextProtocolType::ANY) {
539 # if FOLLY_OPENSSL_IS_110
540     // OpenSSL 1.1.0 sends a fatal alert on mismatch, which is probavbly the
541     // correct behavior per RFC7301
542     expectHandshakeError();
543 # else
544     // BoringSSL also doesn't fatal on mismatch but behaves slightly differently
545     // from OpenSSL 1.0.2h+ - it doesn't select a protocol if both ends support
546     // NPN *and* ALPN
547     expectNoProtocol();
548 # endif
549   }
550 #endif
551    else {
552     expectProtocol("blub");
553     expectProtocolType(
554         {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
555   }
556 }
557
558 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterHit) {
559   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
560   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony);
561   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
562                                         GetParam().second);
563
564   connect();
565
566   expectProtocol("ponies");
567   expectProtocolType();
568 }
569
570 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterMiss) {
571   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
572   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone);
573   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
574                                         GetParam().second);
575
576   connect();
577
578   expectProtocol("blub");
579   expectProtocolType();
580 }
581
582 TEST_P(NextProtocolTest, RandomizedNpnTest) {
583   // Probability that this test will fail is 2^-64, which could be considered
584   // as negligible.
585   const int kTries = 64;
586
587   clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
588                                         GetParam().first);
589   serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}}, {1, {"bar"}}},
590                                                   GetParam().second);
591
592   std::set<string> selectedProtocols;
593   for (int i = 0; i < kTries; ++i) {
594     connect();
595
596     EXPECT_NE(client->nextProtoLength, 0);
597     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
598     EXPECT_EQ(
599         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
600         0);
601     string selected((const char*)client->nextProto, client->nextProtoLength);
602     selectedProtocols.insert(selected);
603     expectProtocolType();
604   }
605   EXPECT_EQ(selectedProtocols.size(), 2);
606 }
607
608 INSTANTIATE_TEST_CASE_P(
609     AsyncSSLSocketTest,
610     NextProtocolTest,
611     ::testing::Values(
612         NextProtocolTypePair(
613             SSLContext::NextProtocolType::NPN,
614             SSLContext::NextProtocolType::NPN),
615         NextProtocolTypePair(
616             SSLContext::NextProtocolType::NPN,
617             SSLContext::NextProtocolType::ANY),
618         NextProtocolTypePair(
619             SSLContext::NextProtocolType::ANY,
620             SSLContext::NextProtocolType::ANY)));
621
622 #if FOLLY_OPENSSL_HAS_ALPN
623 INSTANTIATE_TEST_CASE_P(
624     AsyncSSLSocketTest,
625     NextProtocolTLSExtTest,
626     ::testing::Values(
627         NextProtocolTypePair(
628             SSLContext::NextProtocolType::ALPN,
629             SSLContext::NextProtocolType::ALPN),
630         NextProtocolTypePair(
631             SSLContext::NextProtocolType::ALPN,
632             SSLContext::NextProtocolType::ANY),
633         NextProtocolTypePair(
634             SSLContext::NextProtocolType::ANY,
635             SSLContext::NextProtocolType::ALPN)));
636 #endif
637
638 INSTANTIATE_TEST_CASE_P(
639     AsyncSSLSocketTest,
640     NextProtocolNPNOnlyTest,
641     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
642                                            SSLContext::NextProtocolType::NPN)));
643
644 #if FOLLY_OPENSSL_HAS_ALPN
645 INSTANTIATE_TEST_CASE_P(
646     AsyncSSLSocketTest,
647     NextProtocolMismatchTest,
648     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
649                                            SSLContext::NextProtocolType::ALPN),
650                       NextProtocolTypePair(SSLContext::NextProtocolType::ALPN,
651                                            SSLContext::NextProtocolType::NPN)));
652 #endif
653
654 #ifndef OPENSSL_NO_TLSEXT
655 /**
656  * 1. Client sends TLSEXT_HOSTNAME in client hello.
657  * 2. Server found a match SSL_CTX and use this SSL_CTX to
658  *    continue the SSL handshake.
659  * 3. Server sends back TLSEXT_HOSTNAME in server hello.
660  */
661 TEST(AsyncSSLSocketTest, SNITestMatch) {
662   EventBase eventBase;
663   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
664   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
665   // Use the same SSLContext to continue the handshake after
666   // tlsext_hostname match.
667   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
668   const std::string serverName("xyz.newdev.facebook.com");
669   int fds[2];
670   getfds(fds);
671   getctx(clientCtx, dfServerCtx);
672
673   AsyncSSLSocket::UniquePtr clientSock(
674     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
675   AsyncSSLSocket::UniquePtr serverSock(
676     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
677   SNIClient client(std::move(clientSock));
678   SNIServer server(std::move(serverSock),
679                    dfServerCtx,
680                    hskServerCtx,
681                    serverName);
682
683   eventBase.loop();
684
685   EXPECT_TRUE(client.serverNameMatch);
686   EXPECT_TRUE(server.serverNameMatch);
687 }
688
689 /**
690  * 1. Client sends TLSEXT_HOSTNAME in client hello.
691  * 2. Server cannot find a matching SSL_CTX and continue to use
692  *    the current SSL_CTX to do the handshake.
693  * 3. Server does not send back TLSEXT_HOSTNAME in server hello.
694  */
695 TEST(AsyncSSLSocketTest, SNITestNotMatch) {
696   EventBase eventBase;
697   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
698   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
699   // Use the same SSLContext to continue the handshake after
700   // tlsext_hostname match.
701   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
702   const std::string clientRequestingServerName("foo.com");
703   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
704
705   int fds[2];
706   getfds(fds);
707   getctx(clientCtx, dfServerCtx);
708
709   AsyncSSLSocket::UniquePtr clientSock(
710     new AsyncSSLSocket(clientCtx,
711                         &eventBase,
712                         fds[0],
713                         clientRequestingServerName));
714   AsyncSSLSocket::UniquePtr serverSock(
715     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
716   SNIClient client(std::move(clientSock));
717   SNIServer server(std::move(serverSock),
718                    dfServerCtx,
719                    hskServerCtx,
720                    serverExpectedServerName);
721
722   eventBase.loop();
723
724   EXPECT_TRUE(!client.serverNameMatch);
725   EXPECT_TRUE(!server.serverNameMatch);
726 }
727 /**
728  * 1. Client sends TLSEXT_HOSTNAME in client hello.
729  * 2. We then change the serverName.
730  * 3. We expect that we get 'false' as the result for serNameMatch.
731  */
732
733 TEST(AsyncSSLSocketTest, SNITestChangeServerName) {
734    EventBase eventBase;
735   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
736   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
737   // Use the same SSLContext to continue the handshake after
738   // tlsext_hostname match.
739   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
740   const std::string serverName("xyz.newdev.facebook.com");
741   int fds[2];
742   getfds(fds);
743   getctx(clientCtx, dfServerCtx);
744
745   AsyncSSLSocket::UniquePtr clientSock(
746     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
747   //Change the server name
748   std::string newName("new.com");
749   clientSock->setServerName(newName);
750   AsyncSSLSocket::UniquePtr serverSock(
751     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
752   SNIClient client(std::move(clientSock));
753   SNIServer server(std::move(serverSock),
754                    dfServerCtx,
755                    hskServerCtx,
756                    serverName);
757
758   eventBase.loop();
759
760   EXPECT_TRUE(!client.serverNameMatch);
761 }
762
763 /**
764  * 1. Client does not send TLSEXT_HOSTNAME in client hello.
765  * 2. Server does not send back TLSEXT_HOSTNAME in server hello.
766  */
767 TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) {
768   EventBase eventBase;
769   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
770   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
771   // Use the same SSLContext to continue the handshake after
772   // tlsext_hostname match.
773   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
774   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
775
776   int fds[2];
777   getfds(fds);
778   getctx(clientCtx, dfServerCtx);
779
780   AsyncSSLSocket::UniquePtr clientSock(
781     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
782   AsyncSSLSocket::UniquePtr serverSock(
783     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
784   SNIClient client(std::move(clientSock));
785   SNIServer server(std::move(serverSock),
786                    dfServerCtx,
787                    hskServerCtx,
788                    serverExpectedServerName);
789
790   eventBase.loop();
791
792   EXPECT_TRUE(!client.serverNameMatch);
793   EXPECT_TRUE(!server.serverNameMatch);
794 }
795
796 #endif
797 /**
798  * Test SSL client socket
799  */
800 TEST(AsyncSSLSocketTest, SSLClientTest) {
801   // Start listening on a local port
802   WriteCallbackBase writeCallback;
803   ReadCallback readCallback(&writeCallback);
804   HandshakeCallback handshakeCallback(&readCallback);
805   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
806   TestSSLServer server(&acceptCallback);
807
808   // Set up SSL client
809   EventBase eventBase;
810   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1);
811
812   client->connect();
813   EventBaseAborter eba(&eventBase, 3000);
814   eventBase.loop();
815
816   EXPECT_EQ(client->getMiss(), 1);
817   EXPECT_EQ(client->getHit(), 0);
818
819   cerr << "SSLClientTest test completed" << endl;
820 }
821
822
823 /**
824  * Test SSL client socket session re-use
825  */
826 TEST(AsyncSSLSocketTest, SSLClientTestReuse) {
827   // Start listening on a local port
828   WriteCallbackBase writeCallback;
829   ReadCallback readCallback(&writeCallback);
830   HandshakeCallback handshakeCallback(&readCallback);
831   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
832   TestSSLServer server(&acceptCallback);
833
834   // Set up SSL client
835   EventBase eventBase;
836   auto client =
837       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10);
838
839   client->connect();
840   EventBaseAborter eba(&eventBase, 3000);
841   eventBase.loop();
842
843   EXPECT_EQ(client->getMiss(), 1);
844   EXPECT_EQ(client->getHit(), 9);
845
846   cerr << "SSLClientTestReuse test completed" << endl;
847 }
848
849 /**
850  * Test SSL client socket timeout
851  */
852 TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
853   // Start listening on a local port
854   EmptyReadCallback readCallback;
855   HandshakeCallback handshakeCallback(&readCallback,
856                                       HandshakeCallback::EXPECT_ERROR);
857   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
858   TestSSLServer server(&acceptCallback);
859
860   // Set up SSL client
861   EventBase eventBase;
862   auto client =
863       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1, 10);
864   client->connect(true /* write before connect completes */);
865   EventBaseAborter eba(&eventBase, 3000);
866   eventBase.loop();
867
868   usleep(100000);
869   // This is checking that the connectError callback precedes any queued
870   // writeError callbacks.  This matches AsyncSocket's behavior
871   EXPECT_EQ(client->getWriteAfterConnectErrors(), 1);
872   EXPECT_EQ(client->getErrors(), 1);
873   EXPECT_EQ(client->getMiss(), 0);
874   EXPECT_EQ(client->getHit(), 0);
875
876   cerr << "SSLClientTimeoutTest test completed" << endl;
877 }
878
879 // The next 3 tests need an FB-only extension, and will fail without it
880 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
881 /**
882  * Test SSL server async cache
883  */
884 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTest) {
885   // Start listening on a local port
886   WriteCallbackBase writeCallback;
887   ReadCallback readCallback(&writeCallback);
888   HandshakeCallback handshakeCallback(&readCallback);
889   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
890   TestSSLAsyncCacheServer server(&acceptCallback);
891
892   // Set up SSL client
893   EventBase eventBase;
894   auto client =
895       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10, 500);
896
897   client->connect();
898   EventBaseAborter eba(&eventBase, 3000);
899   eventBase.loop();
900
901   EXPECT_EQ(server.getAsyncCallbacks(), 18);
902   EXPECT_EQ(server.getAsyncLookups(), 9);
903   EXPECT_EQ(client->getMiss(), 10);
904   EXPECT_EQ(client->getHit(), 0);
905
906   cerr << "SSLServerAsyncCacheTest test completed" << endl;
907 }
908
909 /**
910  * Test SSL server accept timeout with cache path
911  */
912 TEST(AsyncSSLSocketTest, SSLServerTimeoutTest) {
913   // Start listening on a local port
914   WriteCallbackBase writeCallback;
915   ReadCallback readCallback(&writeCallback);
916   HandshakeCallback handshakeCallback(&readCallback);
917   SSLServerAcceptCallback acceptCallback(&handshakeCallback, 50);
918   TestSSLAsyncCacheServer server(&acceptCallback);
919
920   // Set up SSL client
921   EventBase eventBase;
922   // only do a TCP connect
923   std::shared_ptr<AsyncSocket> sock = AsyncSocket::newSocket(&eventBase);
924   sock->connect(nullptr, server.getAddress());
925
926   EmptyReadCallback clientReadCallback;
927   clientReadCallback.tcpSocket_ = sock;
928   sock->setReadCB(&clientReadCallback);
929
930   EventBaseAborter eba(&eventBase, 3000);
931   eventBase.loop();
932
933   EXPECT_EQ(readCallback.state, STATE_WAITING);
934
935   cerr << "SSLServerTimeoutTest test completed" << endl;
936 }
937
938 /**
939  * Test SSL server accept timeout with cache path
940  */
941 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTimeoutTest) {
942   // Start listening on a local port
943   WriteCallbackBase writeCallback;
944   ReadCallback readCallback(&writeCallback);
945   HandshakeCallback handshakeCallback(&readCallback);
946   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback, 50);
947   TestSSLAsyncCacheServer server(&acceptCallback);
948
949   // Set up SSL client
950   EventBase eventBase;
951   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2);
952
953   client->connect();
954   EventBaseAborter eba(&eventBase, 3000);
955   eventBase.loop();
956
957   EXPECT_EQ(server.getAsyncCallbacks(), 1);
958   EXPECT_EQ(server.getAsyncLookups(), 1);
959   EXPECT_EQ(client->getErrors(), 1);
960   EXPECT_EQ(client->getMiss(), 1);
961   EXPECT_EQ(client->getHit(), 0);
962
963   cerr << "SSLServerAsyncCacheTimeoutTest test completed" << endl;
964 }
965
966 /**
967  * Test SSL server accept timeout with cache path
968  */
969 TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
970   // Start listening on a local port
971   WriteCallbackBase writeCallback;
972   ReadCallback readCallback(&writeCallback);
973   HandshakeCallback handshakeCallback(&readCallback,
974                                       HandshakeCallback::EXPECT_ERROR);
975   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
976   TestSSLAsyncCacheServer server(&acceptCallback, 500);
977
978   // Set up SSL client
979   EventBase eventBase;
980   auto client =
981       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2, 100);
982
983   client->connect();
984   EventBaseAborter eba(&eventBase, 3000);
985   eventBase.loop();
986
987   server.getEventBase().runInEventBaseThread([&handshakeCallback]{
988       handshakeCallback.closeSocket();});
989   // give time for the cache lookup to come back and find it closed
990   handshakeCallback.waitForHandshake();
991
992   EXPECT_EQ(server.getAsyncCallbacks(), 1);
993   EXPECT_EQ(server.getAsyncLookups(), 1);
994   EXPECT_EQ(client->getErrors(), 1);
995   EXPECT_EQ(client->getMiss(), 1);
996   EXPECT_EQ(client->getHit(), 0);
997
998   cerr << "SSLServerCacheCloseTest test completed" << endl;
999 }
1000 #endif // !SSL_ERROR_WANT_SESS_CACHE_LOOKUP
1001
1002 /**
1003  * Verify Client Ciphers obtained using SSL MSG Callback.
1004  */
1005 TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) {
1006   EventBase eventBase;
1007   auto clientCtx = std::make_shared<SSLContext>();
1008   auto serverCtx = std::make_shared<SSLContext>();
1009   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1010   serverCtx->ciphers("ECDHE-RSA-AES128-SHA:AES128-SHA:AES256-SHA");
1011   serverCtx->loadPrivateKey(kTestKey);
1012   serverCtx->loadCertificate(kTestCert);
1013   serverCtx->loadTrustedCertificates(kTestCA);
1014   serverCtx->loadClientCAList(kTestCA);
1015
1016   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1017   clientCtx->ciphers("AES256-SHA:AES128-SHA");
1018   clientCtx->loadPrivateKey(kTestKey);
1019   clientCtx->loadCertificate(kTestCert);
1020   clientCtx->loadTrustedCertificates(kTestCA);
1021
1022   int fds[2];
1023   getfds(fds);
1024
1025   AsyncSSLSocket::UniquePtr clientSock(
1026       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1027   AsyncSSLSocket::UniquePtr serverSock(
1028       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1029
1030   SSLHandshakeClient client(std::move(clientSock), true, true);
1031   SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);
1032
1033   eventBase.loop();
1034
1035 #if defined(OPENSSL_IS_BORINGSSL)
1036   EXPECT_EQ(server.clientCiphers_, "AES256-SHA:AES128-SHA");
1037 #else
1038   EXPECT_EQ(server.clientCiphers_, "AES256-SHA:AES128-SHA:00ff");
1039 #endif
1040   EXPECT_EQ(server.chosenCipher_, "AES256-SHA");
1041   EXPECT_TRUE(client.handshakeVerify_);
1042   EXPECT_TRUE(client.handshakeSuccess_);
1043   EXPECT_TRUE(!client.handshakeError_);
1044   EXPECT_TRUE(server.handshakeVerify_);
1045   EXPECT_TRUE(server.handshakeSuccess_);
1046   EXPECT_TRUE(!server.handshakeError_);
1047 }
1048
1049 TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) {
1050   EventBase eventBase;
1051   auto ctx = std::make_shared<SSLContext>();
1052
1053   int fds[2];
1054   getfds(fds);
1055
1056   int bufLen = 42;
1057   uint8_t majorVersion = 18;
1058   uint8_t minorVersion = 25;
1059
1060   // Create callback buf
1061   auto buf = IOBuf::create(bufLen);
1062   buf->append(bufLen);
1063   folly::io::RWPrivateCursor cursor(buf.get());
1064   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1065   cursor.write<uint16_t>(0);
1066   cursor.write<uint8_t>(38);
1067   cursor.write<uint8_t>(majorVersion);
1068   cursor.write<uint8_t>(minorVersion);
1069   cursor.skip(32);
1070   cursor.write<uint32_t>(0);
1071
1072   SSL* ssl = ctx->createSSL();
1073   SCOPE_EXIT { SSL_free(ssl); };
1074   AsyncSSLSocket::UniquePtr sock(
1075       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1076   sock->enableClientHelloParsing();
1077
1078   // Test client hello parsing in one packet
1079   AsyncSSLSocket::clientHelloParsingCallback(
1080       0, 0, SSL3_RT_HANDSHAKE, buf->data(), buf->length(), ssl, sock.get());
1081   buf.reset();
1082
1083   auto parsedClientHello = sock->getClientHelloInfo();
1084   EXPECT_TRUE(parsedClientHello != nullptr);
1085   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1086   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1087 }
1088
1089 TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) {
1090   EventBase eventBase;
1091   auto ctx = std::make_shared<SSLContext>();
1092
1093   int fds[2];
1094   getfds(fds);
1095
1096   int bufLen = 42;
1097   uint8_t majorVersion = 18;
1098   uint8_t minorVersion = 25;
1099
1100   // Create callback buf
1101   auto buf = IOBuf::create(bufLen);
1102   buf->append(bufLen);
1103   folly::io::RWPrivateCursor cursor(buf.get());
1104   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1105   cursor.write<uint16_t>(0);
1106   cursor.write<uint8_t>(38);
1107   cursor.write<uint8_t>(majorVersion);
1108   cursor.write<uint8_t>(minorVersion);
1109   cursor.skip(32);
1110   cursor.write<uint32_t>(0);
1111
1112   SSL* ssl = ctx->createSSL();
1113   SCOPE_EXIT { SSL_free(ssl); };
1114   AsyncSSLSocket::UniquePtr sock(
1115       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1116   sock->enableClientHelloParsing();
1117
1118   // Test parsing with two packets with first packet size < 3
1119   auto bufCopy = folly::IOBuf::copyBuffer(buf->data(), 2);
1120   AsyncSSLSocket::clientHelloParsingCallback(
1121       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1122       ssl, sock.get());
1123   bufCopy.reset();
1124   bufCopy = folly::IOBuf::copyBuffer(buf->data() + 2, buf->length() - 2);
1125   AsyncSSLSocket::clientHelloParsingCallback(
1126       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1127       ssl, sock.get());
1128   bufCopy.reset();
1129
1130   auto parsedClientHello = sock->getClientHelloInfo();
1131   EXPECT_TRUE(parsedClientHello != nullptr);
1132   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1133   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1134 }
1135
1136 TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) {
1137   EventBase eventBase;
1138   auto ctx = std::make_shared<SSLContext>();
1139
1140   int fds[2];
1141   getfds(fds);
1142
1143   int bufLen = 42;
1144   uint8_t majorVersion = 18;
1145   uint8_t minorVersion = 25;
1146
1147   // Create callback buf
1148   auto buf = IOBuf::create(bufLen);
1149   buf->append(bufLen);
1150   folly::io::RWPrivateCursor cursor(buf.get());
1151   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1152   cursor.write<uint16_t>(0);
1153   cursor.write<uint8_t>(38);
1154   cursor.write<uint8_t>(majorVersion);
1155   cursor.write<uint8_t>(minorVersion);
1156   cursor.skip(32);
1157   cursor.write<uint32_t>(0);
1158
1159   SSL* ssl = ctx->createSSL();
1160   SCOPE_EXIT { SSL_free(ssl); };
1161   AsyncSSLSocket::UniquePtr sock(
1162       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1163   sock->enableClientHelloParsing();
1164
1165   // Test parsing with multiple small packets
1166   for (uint64_t i = 0; i < buf->length(); i += 3) {
1167     auto bufCopy = folly::IOBuf::copyBuffer(
1168         buf->data() + i, std::min((uint64_t)3, buf->length() - i));
1169     AsyncSSLSocket::clientHelloParsingCallback(
1170         0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1171         ssl, sock.get());
1172     bufCopy.reset();
1173   }
1174
1175   auto parsedClientHello = sock->getClientHelloInfo();
1176   EXPECT_TRUE(parsedClientHello != nullptr);
1177   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1178   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1179 }
1180
1181 /**
1182  * Verify sucessful behavior of SSL certificate validation.
1183  */
1184 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) {
1185   EventBase eventBase;
1186   auto clientCtx = std::make_shared<SSLContext>();
1187   auto dfServerCtx = std::make_shared<SSLContext>();
1188
1189   int fds[2];
1190   getfds(fds);
1191   getctx(clientCtx, dfServerCtx);
1192
1193   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1194   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1195
1196   AsyncSSLSocket::UniquePtr clientSock(
1197     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1198   AsyncSSLSocket::UniquePtr serverSock(
1199     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1200
1201   SSLHandshakeClient client(std::move(clientSock), true, true);
1202   clientCtx->loadTrustedCertificates(kTestCA);
1203
1204   SSLHandshakeServer server(std::move(serverSock), true, true);
1205
1206   eventBase.loop();
1207
1208   EXPECT_TRUE(client.handshakeVerify_);
1209   EXPECT_TRUE(client.handshakeSuccess_);
1210   EXPECT_TRUE(!client.handshakeError_);
1211   EXPECT_LE(0, client.handshakeTime.count());
1212   EXPECT_TRUE(!server.handshakeVerify_);
1213   EXPECT_TRUE(server.handshakeSuccess_);
1214   EXPECT_TRUE(!server.handshakeError_);
1215   EXPECT_LE(0, server.handshakeTime.count());
1216 }
1217
1218 /**
1219  * Verify that the client's verification callback is able to fail SSL
1220  * connection establishment.
1221  */
1222 TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) {
1223   EventBase eventBase;
1224   auto clientCtx = std::make_shared<SSLContext>();
1225   auto dfServerCtx = std::make_shared<SSLContext>();
1226
1227   int fds[2];
1228   getfds(fds);
1229   getctx(clientCtx, dfServerCtx);
1230
1231   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1232   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1233
1234   AsyncSSLSocket::UniquePtr clientSock(
1235     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1236   AsyncSSLSocket::UniquePtr serverSock(
1237     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1238
1239   SSLHandshakeClient client(std::move(clientSock), true, false);
1240   clientCtx->loadTrustedCertificates(kTestCA);
1241
1242   SSLHandshakeServer server(std::move(serverSock), true, true);
1243
1244   eventBase.loop();
1245
1246   EXPECT_TRUE(client.handshakeVerify_);
1247   EXPECT_TRUE(!client.handshakeSuccess_);
1248   EXPECT_TRUE(client.handshakeError_);
1249   EXPECT_LE(0, client.handshakeTime.count());
1250   EXPECT_TRUE(!server.handshakeVerify_);
1251   EXPECT_TRUE(!server.handshakeSuccess_);
1252   EXPECT_TRUE(server.handshakeError_);
1253   EXPECT_LE(0, server.handshakeTime.count());
1254 }
1255
1256 /**
1257  * Verify that the options in SSLContext can be overridden in
1258  * sslConnect/Accept.i.e specifying that no validation should be performed
1259  * allows an otherwise-invalid certificate to be accepted and doesn't fire
1260  * the validation callback.
1261  */
1262 TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) {
1263   EventBase eventBase;
1264   auto clientCtx = std::make_shared<SSLContext>();
1265   auto dfServerCtx = std::make_shared<SSLContext>();
1266
1267   int fds[2];
1268   getfds(fds);
1269   getctx(clientCtx, dfServerCtx);
1270
1271   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1272   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1273
1274   AsyncSSLSocket::UniquePtr clientSock(
1275     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1276   AsyncSSLSocket::UniquePtr serverSock(
1277     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1278
1279   SSLHandshakeClientNoVerify client(std::move(clientSock), false, false);
1280   clientCtx->loadTrustedCertificates(kTestCA);
1281
1282   SSLHandshakeServerNoVerify server(std::move(serverSock), false, false);
1283
1284   eventBase.loop();
1285
1286   EXPECT_TRUE(!client.handshakeVerify_);
1287   EXPECT_TRUE(client.handshakeSuccess_);
1288   EXPECT_TRUE(!client.handshakeError_);
1289   EXPECT_LE(0, client.handshakeTime.count());
1290   EXPECT_TRUE(!server.handshakeVerify_);
1291   EXPECT_TRUE(server.handshakeSuccess_);
1292   EXPECT_TRUE(!server.handshakeError_);
1293   EXPECT_LE(0, server.handshakeTime.count());
1294 }
1295
1296 /**
1297  * Verify that the options in SSLContext can be overridden in
1298  * sslConnect/Accept. Enable verification even if context says otherwise.
1299  * Test requireClientCert with client cert
1300  */
1301 TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) {
1302   EventBase eventBase;
1303   auto clientCtx = std::make_shared<SSLContext>();
1304   auto serverCtx = std::make_shared<SSLContext>();
1305   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1306   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1307   serverCtx->loadPrivateKey(kTestKey);
1308   serverCtx->loadCertificate(kTestCert);
1309   serverCtx->loadTrustedCertificates(kTestCA);
1310   serverCtx->loadClientCAList(kTestCA);
1311
1312   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1313   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1314   clientCtx->loadPrivateKey(kTestKey);
1315   clientCtx->loadCertificate(kTestCert);
1316   clientCtx->loadTrustedCertificates(kTestCA);
1317
1318   int fds[2];
1319   getfds(fds);
1320
1321   AsyncSSLSocket::UniquePtr clientSock(
1322       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1323   AsyncSSLSocket::UniquePtr serverSock(
1324       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1325
1326   SSLHandshakeClientDoVerify client(std::move(clientSock), true, true);
1327   SSLHandshakeServerDoVerify server(std::move(serverSock), true, true);
1328
1329   eventBase.loop();
1330
1331   EXPECT_TRUE(client.handshakeVerify_);
1332   EXPECT_TRUE(client.handshakeSuccess_);
1333   EXPECT_FALSE(client.handshakeError_);
1334   EXPECT_LE(0, client.handshakeTime.count());
1335   EXPECT_TRUE(server.handshakeVerify_);
1336   EXPECT_TRUE(server.handshakeSuccess_);
1337   EXPECT_FALSE(server.handshakeError_);
1338   EXPECT_LE(0, server.handshakeTime.count());
1339 }
1340
1341 /**
1342  * Verify that the client's verification callback is able to override
1343  * the preverification failure and allow a successful connection.
1344  */
1345 TEST(AsyncSSLSocketTest, SSLHandshakeValidationOverride) {
1346   EventBase eventBase;
1347   auto clientCtx = std::make_shared<SSLContext>();
1348   auto dfServerCtx = std::make_shared<SSLContext>();
1349
1350   int fds[2];
1351   getfds(fds);
1352   getctx(clientCtx, dfServerCtx);
1353
1354   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1355   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1356
1357   AsyncSSLSocket::UniquePtr clientSock(
1358     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1359   AsyncSSLSocket::UniquePtr serverSock(
1360     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1361
1362   SSLHandshakeClient client(std::move(clientSock), false, true);
1363   SSLHandshakeServer server(std::move(serverSock), true, true);
1364
1365   eventBase.loop();
1366
1367   EXPECT_TRUE(client.handshakeVerify_);
1368   EXPECT_TRUE(client.handshakeSuccess_);
1369   EXPECT_TRUE(!client.handshakeError_);
1370   EXPECT_LE(0, client.handshakeTime.count());
1371   EXPECT_TRUE(!server.handshakeVerify_);
1372   EXPECT_TRUE(server.handshakeSuccess_);
1373   EXPECT_TRUE(!server.handshakeError_);
1374   EXPECT_LE(0, server.handshakeTime.count());
1375 }
1376
1377 /**
1378  * Verify that specifying that no validation should be performed allows an
1379  * otherwise-invalid certificate to be accepted and doesn't fire the validation
1380  * callback.
1381  */
1382 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSkip) {
1383   EventBase eventBase;
1384   auto clientCtx = std::make_shared<SSLContext>();
1385   auto dfServerCtx = std::make_shared<SSLContext>();
1386
1387   int fds[2];
1388   getfds(fds);
1389   getctx(clientCtx, dfServerCtx);
1390
1391   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1392   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1393
1394   AsyncSSLSocket::UniquePtr clientSock(
1395     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1396   AsyncSSLSocket::UniquePtr serverSock(
1397     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1398
1399   SSLHandshakeClient client(std::move(clientSock), false, false);
1400   SSLHandshakeServer server(std::move(serverSock), false, false);
1401
1402   eventBase.loop();
1403
1404   EXPECT_TRUE(!client.handshakeVerify_);
1405   EXPECT_TRUE(client.handshakeSuccess_);
1406   EXPECT_TRUE(!client.handshakeError_);
1407   EXPECT_LE(0, client.handshakeTime.count());
1408   EXPECT_TRUE(!server.handshakeVerify_);
1409   EXPECT_TRUE(server.handshakeSuccess_);
1410   EXPECT_TRUE(!server.handshakeError_);
1411   EXPECT_LE(0, server.handshakeTime.count());
1412 }
1413
1414 /**
1415  * Test requireClientCert with client cert
1416  */
1417 TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
1418   EventBase eventBase;
1419   auto clientCtx = std::make_shared<SSLContext>();
1420   auto serverCtx = std::make_shared<SSLContext>();
1421   serverCtx->setVerificationOption(
1422       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1423   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1424   serverCtx->loadPrivateKey(kTestKey);
1425   serverCtx->loadCertificate(kTestCert);
1426   serverCtx->loadTrustedCertificates(kTestCA);
1427   serverCtx->loadClientCAList(kTestCA);
1428
1429   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1430   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1431   clientCtx->loadPrivateKey(kTestKey);
1432   clientCtx->loadCertificate(kTestCert);
1433   clientCtx->loadTrustedCertificates(kTestCA);
1434
1435   int fds[2];
1436   getfds(fds);
1437
1438   AsyncSSLSocket::UniquePtr clientSock(
1439       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1440   AsyncSSLSocket::UniquePtr serverSock(
1441       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1442
1443   SSLHandshakeClient client(std::move(clientSock), true, true);
1444   SSLHandshakeServer server(std::move(serverSock), true, true);
1445
1446   eventBase.loop();
1447
1448   EXPECT_TRUE(client.handshakeVerify_);
1449   EXPECT_TRUE(client.handshakeSuccess_);
1450   EXPECT_FALSE(client.handshakeError_);
1451   EXPECT_LE(0, client.handshakeTime.count());
1452   EXPECT_TRUE(server.handshakeVerify_);
1453   EXPECT_TRUE(server.handshakeSuccess_);
1454   EXPECT_FALSE(server.handshakeError_);
1455   EXPECT_LE(0, server.handshakeTime.count());
1456 }
1457
1458
1459 /**
1460  * Test requireClientCert with no client cert
1461  */
1462 TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
1463   EventBase eventBase;
1464   auto clientCtx = std::make_shared<SSLContext>();
1465   auto serverCtx = std::make_shared<SSLContext>();
1466   serverCtx->setVerificationOption(
1467       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1468   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1469   serverCtx->loadPrivateKey(kTestKey);
1470   serverCtx->loadCertificate(kTestCert);
1471   serverCtx->loadTrustedCertificates(kTestCA);
1472   serverCtx->loadClientCAList(kTestCA);
1473   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1474   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1475
1476   int fds[2];
1477   getfds(fds);
1478
1479   AsyncSSLSocket::UniquePtr clientSock(
1480       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1481   AsyncSSLSocket::UniquePtr serverSock(
1482       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1483
1484   SSLHandshakeClient client(std::move(clientSock), false, false);
1485   SSLHandshakeServer server(std::move(serverSock), false, false);
1486
1487   eventBase.loop();
1488
1489   EXPECT_FALSE(server.handshakeVerify_);
1490   EXPECT_FALSE(server.handshakeSuccess_);
1491   EXPECT_TRUE(server.handshakeError_);
1492   EXPECT_LE(0, client.handshakeTime.count());
1493   EXPECT_LE(0, server.handshakeTime.count());
1494 }
1495
1496 TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
1497   auto cert = getFileAsBuf(kTestCert);
1498   auto key = getFileAsBuf(kTestKey);
1499
1500   ssl::BioUniquePtr certBio(BIO_new(BIO_s_mem()));
1501   BIO_write(certBio.get(), cert.data(), cert.size());
1502   ssl::BioUniquePtr keyBio(BIO_new(BIO_s_mem()));
1503   BIO_write(keyBio.get(), key.data(), key.size());
1504
1505   // Create SSL structs from buffers to get properties
1506   ssl::X509UniquePtr certStruct(
1507       PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr));
1508   ssl::EvpPkeyUniquePtr keyStruct(
1509       PEM_read_bio_PrivateKey(keyBio.get(), nullptr, nullptr, nullptr));
1510   certBio = nullptr;
1511   keyBio = nullptr;
1512
1513   auto origCommonName = getCommonName(certStruct.get());
1514   auto origKeySize = EVP_PKEY_bits(keyStruct.get());
1515   certStruct = nullptr;
1516   keyStruct = nullptr;
1517
1518   auto ctx = std::make_shared<SSLContext>();
1519   ctx->loadPrivateKeyFromBufferPEM(key);
1520   ctx->loadCertificateFromBufferPEM(cert);
1521   ctx->loadTrustedCertificates(kTestCA);
1522
1523   ssl::SSLUniquePtr ssl(ctx->createSSL());
1524
1525   auto newCert = SSL_get_certificate(ssl.get());
1526   auto newKey = SSL_get_privatekey(ssl.get());
1527
1528   // Get properties from SSL struct
1529   auto newCommonName = getCommonName(newCert);
1530   auto newKeySize = EVP_PKEY_bits(newKey);
1531
1532   // Check that the key and cert have the expected properties
1533   EXPECT_EQ(origCommonName, newCommonName);
1534   EXPECT_EQ(origKeySize, newKeySize);
1535 }
1536
1537 TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
1538   EventBase eb;
1539
1540   // Set up SSL context.
1541   auto sslContext = std::make_shared<SSLContext>();
1542   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1543
1544   // create SSL socket
1545   AsyncSSLSocket::UniquePtr socket(new AsyncSSLSocket(sslContext, &eb));
1546
1547   EXPECT_EQ(1500, socket->getMinWriteSize());
1548
1549   socket->setMinWriteSize(0);
1550   EXPECT_EQ(0, socket->getMinWriteSize());
1551   socket->setMinWriteSize(50000);
1552   EXPECT_EQ(50000, socket->getMinWriteSize());
1553 }
1554
1555 class ReadCallbackTerminator : public ReadCallback {
1556  public:
1557   ReadCallbackTerminator(EventBase* base, WriteCallbackBase *wcb)
1558       : ReadCallback(wcb)
1559       , base_(base) {}
1560
1561   // Do not write data back, terminate the loop.
1562   void readDataAvailable(size_t len) noexcept override {
1563     std::cerr << "readDataAvailable, len " << len << std::endl;
1564
1565     currentBuffer.length = len;
1566
1567     buffers.push_back(currentBuffer);
1568     currentBuffer.reset();
1569     state = STATE_SUCCEEDED;
1570
1571     socket_->setReadCB(nullptr);
1572     base_->terminateLoopSoon();
1573   }
1574  private:
1575   EventBase* base_;
1576 };
1577
1578
1579 /**
1580  * Test a full unencrypted codepath
1581  */
1582 TEST(AsyncSSLSocketTest, UnencryptedTest) {
1583   EventBase base;
1584
1585   auto clientCtx = std::make_shared<folly::SSLContext>();
1586   auto serverCtx = std::make_shared<folly::SSLContext>();
1587   int fds[2];
1588   getfds(fds);
1589   getctx(clientCtx, serverCtx);
1590   auto client = AsyncSSLSocket::newSocket(
1591                   clientCtx, &base, fds[0], false, true);
1592   auto server = AsyncSSLSocket::newSocket(
1593                   serverCtx, &base, fds[1], true, true);
1594
1595   ReadCallbackTerminator readCallback(&base, nullptr);
1596   server->setReadCB(&readCallback);
1597   readCallback.setSocket(server);
1598
1599   uint8_t buf[128];
1600   memset(buf, 'a', sizeof(buf));
1601   client->write(nullptr, buf, sizeof(buf));
1602
1603   // Check that bytes are unencrypted
1604   char c;
1605   EXPECT_EQ(1, recv(fds[1], &c, 1, MSG_PEEK));
1606   EXPECT_EQ('a', c);
1607
1608   EventBaseAborter eba(&base, 3000);
1609   base.loop();
1610
1611   EXPECT_EQ(1, readCallback.buffers.size());
1612   EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState());
1613
1614   server->setReadCB(&readCallback);
1615
1616   // Unencrypted
1617   server->sslAccept(nullptr);
1618   client->sslConn(nullptr);
1619
1620   // Do NOT wait for handshake, writing should be queued and happen after
1621
1622   client->write(nullptr, buf, sizeof(buf));
1623
1624   // Check that bytes are *not* unencrypted
1625   char c2;
1626   EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK));
1627   EXPECT_NE('a', c2);
1628
1629
1630   base.loop();
1631
1632   EXPECT_EQ(2, readCallback.buffers.size());
1633   EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
1634 }
1635
1636 TEST(AsyncSSLSocketTest, ConnResetErrorString) {
1637   // Start listening on a local port
1638   WriteCallbackBase writeCallback;
1639   WriteErrorCallback readCallback(&writeCallback);
1640   HandshakeCallback handshakeCallback(&readCallback,
1641                                       HandshakeCallback::EXPECT_ERROR);
1642   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1643   TestSSLServer server(&acceptCallback);
1644
1645   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1646   socket->open();
1647   uint8_t buf[3] = {0x16, 0x03, 0x01};
1648   socket->write(buf, sizeof(buf));
1649   socket->closeWithReset();
1650
1651   handshakeCallback.waitForHandshake();
1652   EXPECT_NE(
1653       handshakeCallback.errorString_.find("Network error"), std::string::npos);
1654   EXPECT_NE(handshakeCallback.errorString_.find("104"), std::string::npos);
1655 }
1656
1657 TEST(AsyncSSLSocketTest, ConnEOFErrorString) {
1658   // Start listening on a local port
1659   WriteCallbackBase writeCallback;
1660   WriteErrorCallback readCallback(&writeCallback);
1661   HandshakeCallback handshakeCallback(&readCallback,
1662                                       HandshakeCallback::EXPECT_ERROR);
1663   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1664   TestSSLServer server(&acceptCallback);
1665
1666   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1667   socket->open();
1668   uint8_t buf[3] = {0x16, 0x03, 0x01};
1669   socket->write(buf, sizeof(buf));
1670   socket->close();
1671
1672   handshakeCallback.waitForHandshake();
1673 #if FOLLY_OPENSSL_IS_110
1674   EXPECT_NE(
1675       handshakeCallback.errorString_.find("Network error"), std::string::npos);
1676 #else
1677   EXPECT_NE(
1678       handshakeCallback.errorString_.find("Connection EOF"), std::string::npos);
1679 #endif
1680 }
1681
1682 TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) {
1683   // Start listening on a local port
1684   WriteCallbackBase writeCallback;
1685   WriteErrorCallback readCallback(&writeCallback);
1686   HandshakeCallback handshakeCallback(&readCallback,
1687                                       HandshakeCallback::EXPECT_ERROR);
1688   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1689   TestSSLServer server(&acceptCallback);
1690
1691   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1692   socket->open();
1693   uint8_t buf[256] = {0x16, 0x03};
1694   memset(buf + 2, 'a', sizeof(buf) - 2);
1695   socket->write(buf, sizeof(buf));
1696   socket->close();
1697
1698   handshakeCallback.waitForHandshake();
1699   EXPECT_NE(handshakeCallback.errorString_.find("SSL routines"),
1700             std::string::npos);
1701 #if defined(OPENSSL_IS_BORINGSSL)
1702   EXPECT_NE(
1703       handshakeCallback.errorString_.find("ENCRYPTED_LENGTH_TOO_LONG"),
1704       std::string::npos);
1705 #elif FOLLY_OPENSSL_IS_110
1706   EXPECT_NE(handshakeCallback.errorString_.find("packet length too long"),
1707             std::string::npos);
1708 #else
1709   EXPECT_NE(handshakeCallback.errorString_.find("unknown protocol"),
1710             std::string::npos);
1711 #endif
1712 }
1713
1714 TEST(AsyncSSLSocketTest, TestSSLCipherCodeToNameMap) {
1715   using folly::ssl::OpenSSLUtils;
1716   EXPECT_EQ(
1717       OpenSSLUtils::getCipherName(0xc02c), "ECDHE-ECDSA-AES256-GCM-SHA384");
1718   // TLS_DHE_RSA_WITH_DES_CBC_SHA - We shouldn't be building with this
1719   EXPECT_EQ(OpenSSLUtils::getCipherName(0x0015), "");
1720   // This indicates TLS_EMPTY_RENEGOTIATION_INFO_SCSV, no name expected
1721   EXPECT_EQ(OpenSSLUtils::getCipherName(0x00ff), "");
1722 }
1723
1724 #if FOLLY_ALLOW_TFO
1725
1726 class MockAsyncTFOSSLSocket : public AsyncSSLSocket {
1727  public:
1728   using UniquePtr = std::unique_ptr<MockAsyncTFOSSLSocket, Destructor>;
1729
1730   explicit MockAsyncTFOSSLSocket(
1731       std::shared_ptr<folly::SSLContext> sslCtx,
1732       EventBase* evb)
1733       : AsyncSocket(evb), AsyncSSLSocket(sslCtx, evb) {}
1734
1735   MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags));
1736 };
1737
1738 /**
1739  * Test connecting to, writing to, reading from, and closing the
1740  * connection to the SSL server with TFO.
1741  */
1742 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFO) {
1743   // Start listening on a local port
1744   WriteCallbackBase writeCallback;
1745   ReadCallback readCallback(&writeCallback);
1746   HandshakeCallback handshakeCallback(&readCallback);
1747   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1748   TestSSLServer server(&acceptCallback, true);
1749
1750   // Set up SSL context.
1751   auto sslContext = std::make_shared<SSLContext>();
1752
1753   // connect
1754   auto socket =
1755       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1756   socket->enableTFO();
1757   socket->open();
1758
1759   // write()
1760   std::array<uint8_t, 128> buf;
1761   memset(buf.data(), 'a', buf.size());
1762   socket->write(buf.data(), buf.size());
1763
1764   // read()
1765   std::array<uint8_t, 128> readbuf;
1766   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
1767   EXPECT_EQ(bytesRead, 128);
1768   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1769
1770   // close()
1771   socket->close();
1772 }
1773
1774 /**
1775  * Test connecting to, writing to, reading from, and closing the
1776  * connection to the SSL server with TFO.
1777  */
1778 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOWithTFOServerDisabled) {
1779   // Start listening on a local port
1780   WriteCallbackBase writeCallback;
1781   ReadCallback readCallback(&writeCallback);
1782   HandshakeCallback handshakeCallback(&readCallback);
1783   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1784   TestSSLServer server(&acceptCallback, false);
1785
1786   // Set up SSL context.
1787   auto sslContext = std::make_shared<SSLContext>();
1788
1789   // connect
1790   auto socket =
1791       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1792   socket->enableTFO();
1793   socket->open();
1794
1795   // write()
1796   std::array<uint8_t, 128> buf;
1797   memset(buf.data(), 'a', buf.size());
1798   socket->write(buf.data(), buf.size());
1799
1800   // read()
1801   std::array<uint8_t, 128> readbuf;
1802   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
1803   EXPECT_EQ(bytesRead, 128);
1804   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1805
1806   // close()
1807   socket->close();
1808 }
1809
1810 class ConnCallback : public AsyncSocket::ConnectCallback {
1811  public:
1812   void connectSuccess() noexcept override {
1813     state = State::SUCCESS;
1814   }
1815
1816   void connectErr(const AsyncSocketException& ex) noexcept override {
1817     state = State::ERROR;
1818     error = ex.what();
1819   }
1820
1821   enum class State { WAITING, SUCCESS, ERROR };
1822
1823   State state{State::WAITING};
1824   std::string error;
1825 };
1826
1827 template <class Cardinality>
1828 MockAsyncTFOSSLSocket::UniquePtr setupSocketWithFallback(
1829     EventBase* evb,
1830     const SocketAddress& address,
1831     Cardinality cardinality) {
1832   // Set up SSL context.
1833   auto sslContext = std::make_shared<SSLContext>();
1834
1835   // connect
1836   auto socket = MockAsyncTFOSSLSocket::UniquePtr(
1837       new MockAsyncTFOSSLSocket(sslContext, evb));
1838   socket->enableTFO();
1839
1840   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
1841       .Times(cardinality)
1842       .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
1843         sockaddr_storage addr;
1844         auto len = address.getAddress(&addr);
1845         return connect(fd, (const struct sockaddr*)&addr, len);
1846       }));
1847   return socket;
1848 }
1849
1850 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOFallback) {
1851   // Start listening on a local port
1852   WriteCallbackBase writeCallback;
1853   ReadCallback readCallback(&writeCallback);
1854   HandshakeCallback handshakeCallback(&readCallback);
1855   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1856   TestSSLServer server(&acceptCallback, true);
1857
1858   EventBase evb;
1859
1860   auto socket = setupSocketWithFallback(&evb, server.getAddress(), 1);
1861   ConnCallback ccb;
1862   socket->connect(&ccb, server.getAddress(), 30);
1863
1864   evb.loop();
1865   EXPECT_EQ(ConnCallback::State::SUCCESS, ccb.state);
1866
1867   evb.runInEventBaseThread([&] { socket->detachEventBase(); });
1868   evb.loop();
1869
1870   BlockingSocket sock(std::move(socket));
1871   // write()
1872   std::array<uint8_t, 128> buf;
1873   memset(buf.data(), 'a', buf.size());
1874   sock.write(buf.data(), buf.size());
1875
1876   // read()
1877   std::array<uint8_t, 128> readbuf;
1878   uint32_t bytesRead = sock.readAll(readbuf.data(), readbuf.size());
1879   EXPECT_EQ(bytesRead, 128);
1880   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1881
1882   // close()
1883   sock.close();
1884 }
1885
1886 #if !defined(OPENSSL_IS_BORINGSSL)
1887 TEST(AsyncSSLSocketTest, ConnectTFOTimeout) {
1888   // Start listening on a local port
1889   ConnectTimeoutCallback acceptCallback;
1890   TestSSLServer server(&acceptCallback, true);
1891
1892   // Set up SSL context.
1893   auto sslContext = std::make_shared<SSLContext>();
1894
1895   // connect
1896   auto socket =
1897       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1898   socket->enableTFO();
1899   EXPECT_THROW(
1900       socket->open(std::chrono::milliseconds(20)), AsyncSocketException);
1901 }
1902 #endif
1903
1904 #if !defined(OPENSSL_IS_BORINGSSL)
1905 TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) {
1906   // Start listening on a local port
1907   ConnectTimeoutCallback acceptCallback;
1908   TestSSLServer server(&acceptCallback, true);
1909
1910   EventBase evb;
1911
1912   auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
1913   ConnCallback ccb;
1914   // Set a short timeout
1915   socket->connect(&ccb, server.getAddress(), 1);
1916
1917   evb.loop();
1918   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
1919 }
1920 #endif
1921
1922 TEST(AsyncSSLSocketTest, HandshakeTFOFallbackTimeout) {
1923   // Start listening on a local port
1924   EmptyReadCallback readCallback;
1925   HandshakeCallback handshakeCallback(
1926       &readCallback, HandshakeCallback::EXPECT_ERROR);
1927   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
1928   TestSSLServer server(&acceptCallback, true);
1929
1930   EventBase evb;
1931
1932   auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
1933   ConnCallback ccb;
1934   socket->connect(&ccb, server.getAddress(), 100);
1935
1936   evb.loop();
1937   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
1938   EXPECT_THAT(ccb.error, testing::HasSubstr("SSL connect timed out"));
1939 }
1940
1941 TEST(AsyncSSLSocketTest, HandshakeTFORefused) {
1942   // Start listening on a local port
1943   EventBase evb;
1944
1945   // Hopefully nothing is listening on this address
1946   SocketAddress addr("127.0.0.1", 65535);
1947   auto socket = setupSocketWithFallback(&evb, addr, AtMost(1));
1948   ConnCallback ccb;
1949   socket->connect(&ccb, addr, 100);
1950
1951   evb.loop();
1952   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
1953   EXPECT_THAT(ccb.error, testing::HasSubstr("refused"));
1954 }
1955
1956 TEST(AsyncSSLSocketTest, TestPreReceivedData) {
1957   EventBase clientEventBase;
1958   EventBase serverEventBase;
1959   auto clientCtx = std::make_shared<SSLContext>();
1960   auto dfServerCtx = std::make_shared<SSLContext>();
1961   std::array<int, 2> fds;
1962   getfds(fds.data());
1963   getctx(clientCtx, dfServerCtx);
1964
1965   AsyncSSLSocket::UniquePtr clientSockPtr(
1966       new AsyncSSLSocket(clientCtx, &clientEventBase, fds[0], false));
1967   AsyncSSLSocket::UniquePtr serverSockPtr(
1968       new AsyncSSLSocket(dfServerCtx, &serverEventBase, fds[1], true));
1969   auto clientSock = clientSockPtr.get();
1970   auto serverSock = serverSockPtr.get();
1971   SSLHandshakeClient client(std::move(clientSockPtr), true, true);
1972
1973   // Steal some data from the server.
1974   clientEventBase.loopOnce();
1975   std::array<uint8_t, 10> buf;
1976   recv(fds[1], buf.data(), buf.size(), 0);
1977
1978   serverSock->setPreReceivedData(IOBuf::wrapBuffer(range(buf)));
1979   SSLHandshakeServer server(std::move(serverSockPtr), true, true);
1980   while (!client.handshakeSuccess_ && !client.handshakeError_) {
1981     serverEventBase.loopOnce();
1982     clientEventBase.loopOnce();
1983   }
1984
1985   EXPECT_TRUE(client.handshakeSuccess_);
1986   EXPECT_TRUE(server.handshakeSuccess_);
1987   EXPECT_EQ(
1988       serverSock->getRawBytesReceived(), clientSock->getRawBytesWritten());
1989 }
1990
1991 TEST(AsyncSSLSocketTest, TestMoveFromAsyncSocket) {
1992   EventBase clientEventBase;
1993   EventBase serverEventBase;
1994   auto clientCtx = std::make_shared<SSLContext>();
1995   auto dfServerCtx = std::make_shared<SSLContext>();
1996   std::array<int, 2> fds;
1997   getfds(fds.data());
1998   getctx(clientCtx, dfServerCtx);
1999
2000   AsyncSSLSocket::UniquePtr clientSockPtr(
2001       new AsyncSSLSocket(clientCtx, &clientEventBase, fds[0], false));
2002   AsyncSocket::UniquePtr serverSockPtr(
2003       new AsyncSocket(&serverEventBase, fds[1]));
2004   auto clientSock = clientSockPtr.get();
2005   auto serverSock = serverSockPtr.get();
2006   SSLHandshakeClient client(std::move(clientSockPtr), true, true);
2007
2008   // Steal some data from the server.
2009   clientEventBase.loopOnce();
2010   std::array<uint8_t, 10> buf;
2011   recv(fds[1], buf.data(), buf.size(), 0);
2012   serverSock->setPreReceivedData(IOBuf::wrapBuffer(range(buf)));
2013   AsyncSSLSocket::UniquePtr serverSSLSockPtr(
2014       new AsyncSSLSocket(dfServerCtx, std::move(serverSockPtr), true));
2015   auto serverSSLSock = serverSSLSockPtr.get();
2016   SSLHandshakeServer server(std::move(serverSSLSockPtr), true, true);
2017   while (!client.handshakeSuccess_ && !client.handshakeError_) {
2018     serverEventBase.loopOnce();
2019     clientEventBase.loopOnce();
2020   }
2021
2022   EXPECT_TRUE(client.handshakeSuccess_);
2023   EXPECT_TRUE(server.handshakeSuccess_);
2024   EXPECT_EQ(
2025       serverSSLSock->getRawBytesReceived(), clientSock->getRawBytesWritten());
2026 }
2027
2028 /**
2029  * Test overriding the flags passed to "sendmsg()" system call,
2030  * and verifying that write requests fail properly.
2031  */
2032 TEST(AsyncSSLSocketTest, SendMsgParamsCallback) {
2033   // Start listening on a local port
2034   SendMsgFlagsCallback msgCallback;
2035   ExpectWriteErrorCallback writeCallback(&msgCallback);
2036   ReadCallback readCallback(&writeCallback);
2037   HandshakeCallback handshakeCallback(&readCallback);
2038   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2039   TestSSLServer server(&acceptCallback);
2040
2041   // Set up SSL context.
2042   auto sslContext = std::make_shared<SSLContext>();
2043   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
2044
2045   // connect
2046   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
2047                                                  sslContext);
2048   socket->open();
2049
2050   // Setting flags to "-1" to trigger "Invalid argument" error
2051   // on attempt to use this flags in sendmsg() system call.
2052   msgCallback.resetFlags(-1);
2053
2054   // write()
2055   std::vector<uint8_t> buf(128, 'a');
2056   ASSERT_EQ(socket->write(buf.data(), buf.size()), buf.size());
2057
2058   // close()
2059   socket->close();
2060
2061   cerr << "SendMsgParamsCallback test completed" << endl;
2062 }
2063
2064 #ifdef MSG_ERRQUEUE
2065 /**
2066  * Test connecting to, writing to, reading from, and closing the
2067  * connection to the SSL server.
2068  */
2069 TEST(AsyncSSLSocketTest, SendMsgDataCallback) {
2070   // This test requires Linux kernel v4.6 or later
2071   struct utsname s_uname;
2072   memset(&s_uname, 0, sizeof(s_uname));
2073   ASSERT_EQ(uname(&s_uname), 0);
2074   int major, minor;
2075   folly::StringPiece extra;
2076   if (folly::split<false>(
2077         '.', std::string(s_uname.release) + ".", major, minor, extra)) {
2078     if (major < 4 || (major == 4 && minor < 6)) {
2079       LOG(INFO) << "Kernel version: 4.6 and newer required for this test ("
2080                 << "kernel ver. " << s_uname.release << " detected).";
2081       return;
2082     }
2083   }
2084
2085   // Start listening on a local port
2086   SendMsgDataCallback msgCallback;
2087   WriteCheckTimestampCallback writeCallback(&msgCallback);
2088   ReadCallback readCallback(&writeCallback);
2089   HandshakeCallback handshakeCallback(&readCallback);
2090   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2091   TestSSLServer server(&acceptCallback);
2092
2093   // Set up SSL context.
2094   auto sslContext = std::make_shared<SSLContext>();
2095   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
2096
2097   // connect
2098   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
2099                                                  sslContext);
2100   socket->open();
2101
2102   // Adding MSG_EOR flag to the message flags - it'll trigger
2103   // timestamp generation for the last byte of the message.
2104   msgCallback.resetFlags(MSG_DONTWAIT|MSG_NOSIGNAL|MSG_EOR);
2105
2106   // Init ancillary data buffer to trigger timestamp notification
2107   union {
2108     uint8_t ctrl_data[CMSG_LEN(sizeof(uint32_t))];
2109     struct cmsghdr cmsg;
2110   } u;
2111   u.cmsg.cmsg_level = SOL_SOCKET;
2112   u.cmsg.cmsg_type = SO_TIMESTAMPING;
2113   u.cmsg.cmsg_len = CMSG_LEN(sizeof(uint32_t));
2114   uint32_t flags =
2115       SOF_TIMESTAMPING_TX_SCHED |
2116       SOF_TIMESTAMPING_TX_SOFTWARE |
2117       SOF_TIMESTAMPING_TX_ACK;
2118   memcpy(CMSG_DATA(&u.cmsg), &flags, sizeof(uint32_t));
2119   std::vector<char> ctrl(CMSG_LEN(sizeof(uint32_t)));
2120   memcpy(ctrl.data(), u.ctrl_data, CMSG_LEN(sizeof(uint32_t)));
2121   msgCallback.resetData(std::move(ctrl));
2122
2123   // write()
2124   std::vector<uint8_t> buf(128, 'a');
2125   socket->write(buf.data(), buf.size());
2126
2127   // read()
2128   std::vector<uint8_t> readbuf(buf.size());
2129   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
2130   EXPECT_EQ(bytesRead, buf.size());
2131   EXPECT_TRUE(std::equal(buf.begin(), buf.end(), readbuf.begin()));
2132
2133   writeCallback.checkForTimestampNotifications();
2134
2135   // close()
2136   socket->close();
2137
2138   cerr << "SendMsgDataCallback test completed" << endl;
2139 }
2140 #endif // MSG_ERRQUEUE
2141
2142 #endif
2143
2144 } // namespace
2145
2146 #ifdef SIGPIPE
2147 ///////////////////////////////////////////////////////////////////////////
2148 // init_unit_test_suite
2149 ///////////////////////////////////////////////////////////////////////////
2150 namespace {
2151 struct Initializer {
2152   Initializer() {
2153     signal(SIGPIPE, SIG_IGN);
2154   }
2155 };
2156 Initializer initializer;
2157 } // anonymous
2158 #endif