Fix buck build for SSLContext
[folly.git] / folly / io / async / SSLContext.cpp
1 /*
2  * Copyright 2015 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "SSLContext.h"
18
19 #include <openssl/err.h>
20 #include <openssl/rand.h>
21 #include <openssl/ssl.h>
22 #include <openssl/x509v3.h>
23
24 #include <folly/Format.h>
25 #include <folly/Memory.h>
26 #include <folly/SpinLock.h>
27
28 // ---------------------------------------------------------------------
29 // SSLContext implementation
30 // ---------------------------------------------------------------------
31
32 struct CRYPTO_dynlock_value {
33   std::mutex mutex;
34 };
35
36 namespace folly {
37
38 bool SSLContext::initialized_ = false;
39
40 namespace {
41
42 std::mutex& initMutex() {
43   static std::mutex m;
44   return m;
45 }
46
47 inline void BIO_free_fb(BIO* bio) { CHECK_EQ(1, BIO_free(bio)); }
48 using BIO_deleter = folly::static_function_deleter<BIO, &BIO_free_fb>;
49 using X509_deleter = folly::static_function_deleter<X509, &X509_free>;
50 using EVP_PKEY_deleter =
51     folly::static_function_deleter<EVP_PKEY, &EVP_PKEY_free>;
52
53 } // anonymous namespace
54
55 #ifdef OPENSSL_NPN_NEGOTIATED
56 int SSLContext::sNextProtocolsExDataIndex_ = -1;
57 #endif
58
59 // SSLContext implementation
60 SSLContext::SSLContext(SSLVersion version) {
61   {
62     std::lock_guard<std::mutex> g(initMutex());
63     initializeOpenSSLLocked();
64   }
65
66   ctx_ = SSL_CTX_new(SSLv23_method());
67   if (ctx_ == nullptr) {
68     throw std::runtime_error("SSL_CTX_new: " + getErrors());
69   }
70
71   int opt = 0;
72   switch (version) {
73     case TLSv1:
74       opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
75       break;
76     case SSLv3:
77       opt = SSL_OP_NO_SSLv2;
78       break;
79     default:
80       // do nothing
81       break;
82   }
83   int newOpt = SSL_CTX_set_options(ctx_, opt);
84   DCHECK((newOpt & opt) == opt);
85
86   SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
87
88   checkPeerName_ = false;
89
90 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
91   SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
92   SSL_CTX_set_tlsext_servername_arg(ctx_, this);
93 #endif
94
95 #ifdef OPENSSL_NPN_NEGOTIATED
96   Random::seed(nextProtocolPicker_);
97 #endif
98 }
99
100 SSLContext::~SSLContext() {
101   if (ctx_ != nullptr) {
102     SSL_CTX_free(ctx_);
103     ctx_ = nullptr;
104   }
105
106 #ifdef OPENSSL_NPN_NEGOTIATED
107   deleteNextProtocolsStrings();
108 #endif
109 }
110
111 void SSLContext::ciphers(const std::string& ciphers) {
112   providedCiphersString_ = ciphers;
113   setCiphersOrThrow(ciphers);
114 }
115
116 void SSLContext::setCiphersOrThrow(const std::string& ciphers) {
117   int rc = SSL_CTX_set_cipher_list(ctx_, ciphers.c_str());
118   if (ERR_peek_error() != 0) {
119     throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors());
120   }
121   if (rc == 0) {
122     throw std::runtime_error("None of specified ciphers are supported");
123   }
124 }
125
126 void SSLContext::setVerificationOption(const SSLContext::SSLVerifyPeerEnum&
127     verifyPeer) {
128   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX); // dont recurse
129   verifyPeer_ = verifyPeer;
130 }
131
132 int SSLContext::getVerificationMode(const SSLContext::SSLVerifyPeerEnum&
133     verifyPeer) {
134   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX);
135   int mode = SSL_VERIFY_NONE;
136   switch(verifyPeer) {
137     // case SSLVerifyPeerEnum::USE_CTX: // can't happen
138     // break;
139
140     case SSLVerifyPeerEnum::VERIFY:
141       mode = SSL_VERIFY_PEER;
142       break;
143
144     case SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT:
145       mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
146       break;
147
148     case SSLVerifyPeerEnum::NO_VERIFY:
149       mode = SSL_VERIFY_NONE;
150       break;
151
152     default:
153       break;
154   }
155   return mode;
156 }
157
158 int SSLContext::getVerificationMode() {
159   return getVerificationMode(verifyPeer_);
160 }
161
162 void SSLContext::authenticate(bool checkPeerCert, bool checkPeerName,
163                               const std::string& peerName) {
164   int mode;
165   if (checkPeerCert) {
166     mode  = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE;
167     checkPeerName_ = checkPeerName;
168     peerFixedName_ = peerName;
169   } else {
170     mode = SSL_VERIFY_NONE;
171     checkPeerName_ = false; // can't check name without cert!
172     peerFixedName_.clear();
173   }
174   SSL_CTX_set_verify(ctx_, mode, nullptr);
175 }
176
177 void SSLContext::loadCertificate(const char* path, const char* format) {
178   if (path == nullptr || format == nullptr) {
179     throw std::invalid_argument(
180          "loadCertificateChain: either <path> or <format> is nullptr");
181   }
182   if (strcmp(format, "PEM") == 0) {
183     if (SSL_CTX_use_certificate_chain_file(ctx_, path) == 0) {
184       int errnoCopy = errno;
185       std::string reason("SSL_CTX_use_certificate_chain_file: ");
186       reason.append(path);
187       reason.append(": ");
188       reason.append(getErrors(errnoCopy));
189       throw std::runtime_error(reason);
190     }
191   } else {
192     throw std::runtime_error("Unsupported certificate format: " + std::string(format));
193   }
194 }
195
196 void SSLContext::loadCertificateFromBufferPEM(folly::StringPiece cert) {
197   if (cert.data() == nullptr) {
198     throw std::invalid_argument("loadCertificate: <cert> is nullptr");
199   }
200
201   std::unique_ptr<BIO, BIO_deleter> bio(BIO_new(BIO_s_mem()));
202   if (bio == nullptr) {
203     throw std::runtime_error("BIO_new: " + getErrors());
204   }
205
206   int written = BIO_write(bio.get(), cert.data(), cert.size());
207   if (written <= 0 || static_cast<unsigned>(written) != cert.size()) {
208     throw std::runtime_error("BIO_write: " + getErrors());
209   }
210
211   std::unique_ptr<X509, X509_deleter> x509(
212       PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
213   if (x509 == nullptr) {
214     throw std::runtime_error("PEM_read_bio_X509: " + getErrors());
215   }
216
217   if (SSL_CTX_use_certificate(ctx_, x509.get()) == 0) {
218     throw std::runtime_error("SSL_CTX_use_certificate: " + getErrors());
219   }
220 }
221
222 void SSLContext::loadPrivateKey(const char* path, const char* format) {
223   if (path == nullptr || format == nullptr) {
224     throw std::invalid_argument(
225         "loadPrivateKey: either <path> or <format> is nullptr");
226   }
227   if (strcmp(format, "PEM") == 0) {
228     if (SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM) == 0) {
229       throw std::runtime_error("SSL_CTX_use_PrivateKey_file: " + getErrors());
230     }
231   } else {
232     throw std::runtime_error("Unsupported private key format: " + std::string(format));
233   }
234 }
235
236 void SSLContext::loadPrivateKeyFromBufferPEM(folly::StringPiece pkey) {
237   if (pkey.data() == nullptr) {
238     throw std::invalid_argument("loadPrivateKey: <pkey> is nullptr");
239   }
240
241   std::unique_ptr<BIO, BIO_deleter> bio(BIO_new(BIO_s_mem()));
242   if (bio == nullptr) {
243     throw std::runtime_error("BIO_new: " + getErrors());
244   }
245
246   int written = BIO_write(bio.get(), pkey.data(), pkey.size());
247   if (written <= 0 || static_cast<unsigned>(written) != pkey.size()) {
248     throw std::runtime_error("BIO_write: " + getErrors());
249   }
250
251   std::unique_ptr<EVP_PKEY, EVP_PKEY_deleter> key(
252       PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr));
253   if (key == nullptr) {
254     throw std::runtime_error("PEM_read_bio_PrivateKey: " + getErrors());
255   }
256
257   if (SSL_CTX_use_PrivateKey(ctx_, key.get()) == 0) {
258     throw std::runtime_error("SSL_CTX_use_PrivateKey: " + getErrors());
259   }
260 }
261
262 void SSLContext::loadTrustedCertificates(const char* path) {
263   if (path == nullptr) {
264     throw std::invalid_argument("loadTrustedCertificates: <path> is nullptr");
265   }
266   if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) {
267     throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors());
268   }
269 }
270
271 void SSLContext::loadTrustedCertificates(X509_STORE* store) {
272   SSL_CTX_set_cert_store(ctx_, store);
273 }
274
275 void SSLContext::loadClientCAList(const char* path) {
276   auto clientCAs = SSL_load_client_CA_file(path);
277   if (clientCAs == nullptr) {
278     LOG(ERROR) << "Unable to load ca file: " << path;
279     return;
280   }
281   SSL_CTX_set_client_CA_list(ctx_, clientCAs);
282 }
283
284 void SSLContext::randomize() {
285   RAND_poll();
286 }
287
288 void SSLContext::passwordCollector(std::shared_ptr<PasswordCollector> collector) {
289   if (collector == nullptr) {
290     LOG(ERROR) << "passwordCollector: ignore invalid password collector";
291     return;
292   }
293   collector_ = collector;
294   SSL_CTX_set_default_passwd_cb(ctx_, passwordCallback);
295   SSL_CTX_set_default_passwd_cb_userdata(ctx_, this);
296 }
297
298 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
299
300 void SSLContext::setServerNameCallback(const ServerNameCallback& cb) {
301   serverNameCb_ = cb;
302 }
303
304 void SSLContext::addClientHelloCallback(const ClientHelloCallback& cb) {
305   clientHelloCbs_.push_back(cb);
306 }
307
308 int SSLContext::baseServerNameOpenSSLCallback(SSL* ssl, int* al, void* data) {
309   SSLContext* context = (SSLContext*)data;
310
311   if (context == nullptr) {
312     return SSL_TLSEXT_ERR_NOACK;
313   }
314
315   for (auto& cb : context->clientHelloCbs_) {
316     // Generic callbacks to happen after we receive the Client Hello.
317     // For example, we use one to switch which cipher we use depending
318     // on the user's TLS version.  Because the primary purpose of
319     // baseServerNameOpenSSLCallback is for SNI support, and these callbacks
320     // are side-uses, we ignore any possible failures other than just logging
321     // them.
322     cb(ssl);
323   }
324
325   if (!context->serverNameCb_) {
326     return SSL_TLSEXT_ERR_NOACK;
327   }
328
329   ServerNameCallbackResult ret = context->serverNameCb_(ssl);
330   switch (ret) {
331     case SERVER_NAME_FOUND:
332       return SSL_TLSEXT_ERR_OK;
333     case SERVER_NAME_NOT_FOUND:
334       return SSL_TLSEXT_ERR_NOACK;
335     case SERVER_NAME_NOT_FOUND_ALERT_FATAL:
336       *al = TLS1_AD_UNRECOGNIZED_NAME;
337       return SSL_TLSEXT_ERR_ALERT_FATAL;
338     default:
339       CHECK(false);
340   }
341
342   return SSL_TLSEXT_ERR_NOACK;
343 }
344
345 void SSLContext::switchCiphersIfTLS11(
346     SSL* ssl,
347     const std::string& tls11CipherString) {
348
349   CHECK(!tls11CipherString.empty()) << "Shouldn't call if empty alt ciphers";
350
351   if (TLS1_get_client_version(ssl) <= TLS1_VERSION) {
352     // We only do this for TLS v 1.1 and later
353     return;
354   }
355
356   // Prefer AES for TLS versions 1.1 and later since these are not
357   // vulnerable to BEAST attacks on AES.  Note that we're setting the
358   // cipher list on the SSL object, not the SSL_CTX object, so it will
359   // only last for this request.
360   int rc = SSL_set_cipher_list(ssl, tls11CipherString.c_str());
361   if ((rc == 0) || ERR_peek_error() != 0) {
362     // This shouldn't happen since we checked for this when proxygen
363     // started up.
364     LOG(WARNING) << "ssl_cipher: No specified ciphers supported for switch";
365     SSL_set_cipher_list(ssl, providedCiphersString_.c_str());
366   }
367 }
368 #endif
369
370 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
371 int SSLContext::alpnSelectCallback(SSL* ssl,
372                                    const unsigned char** out,
373                                    unsigned char* outlen,
374                                    const unsigned char* in,
375                                    unsigned int inlen,
376                                    void* data) {
377   SSLContext* context = (SSLContext*)data;
378   CHECK(context);
379   if (context->advertisedNextProtocols_.empty()) {
380     *out = nullptr;
381     *outlen = 0;
382   } else {
383     auto i = context->pickNextProtocols();
384     const auto& item = context->advertisedNextProtocols_[i];
385     if (SSL_select_next_proto((unsigned char**)out,
386                               outlen,
387                               item.protocols,
388                               item.length,
389                               in,
390                               inlen) != OPENSSL_NPN_NEGOTIATED) {
391       return SSL_TLSEXT_ERR_NOACK;
392     }
393   }
394   return SSL_TLSEXT_ERR_OK;
395 }
396 #endif
397
398 #ifdef OPENSSL_NPN_NEGOTIATED
399
400 bool SSLContext::setAdvertisedNextProtocols(
401     const std::list<std::string>& protocols, NextProtocolType protocolType) {
402   return setRandomizedAdvertisedNextProtocols({{1, protocols}}, protocolType);
403 }
404
405 bool SSLContext::setRandomizedAdvertisedNextProtocols(
406     const std::list<NextProtocolsItem>& items, NextProtocolType protocolType) {
407   unsetNextProtocols();
408   if (items.size() == 0) {
409     return false;
410   }
411   int total_weight = 0;
412   for (const auto &item : items) {
413     if (item.protocols.size() == 0) {
414       continue;
415     }
416     AdvertisedNextProtocolsItem advertised_item;
417     advertised_item.length = 0;
418     for (const auto& proto : item.protocols) {
419       ++advertised_item.length;
420       unsigned protoLength = proto.length();
421       if (protoLength >= 256) {
422         deleteNextProtocolsStrings();
423         return false;
424       }
425       advertised_item.length += protoLength;
426     }
427     advertised_item.protocols = new unsigned char[advertised_item.length];
428     if (!advertised_item.protocols) {
429       throw std::runtime_error("alloc failure");
430     }
431     unsigned char* dst = advertised_item.protocols;
432     for (auto& proto : item.protocols) {
433       unsigned protoLength = proto.length();
434       *dst++ = (unsigned char)protoLength;
435       memcpy(dst, proto.data(), protoLength);
436       dst += protoLength;
437     }
438     total_weight += item.weight;
439     advertisedNextProtocols_.push_back(advertised_item);
440     advertisedNextProtocolWeights_.push_back(item.weight);
441   }
442   if (total_weight == 0) {
443     deleteNextProtocolsStrings();
444     return false;
445   }
446   nextProtocolDistribution_ =
447       std::discrete_distribution<>(advertisedNextProtocolWeights_.begin(),
448                                    advertisedNextProtocolWeights_.end());
449   if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) {
450     SSL_CTX_set_next_protos_advertised_cb(
451         ctx_, advertisedNextProtocolCallback, this);
452     SSL_CTX_set_next_proto_select_cb(ctx_, selectNextProtocolCallback, this);
453   }
454 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
455   if ((uint8_t)protocolType & (uint8_t)NextProtocolType::ALPN) {
456     SSL_CTX_set_alpn_select_cb(ctx_, alpnSelectCallback, this);
457     // Client cannot really use randomized alpn
458     SSL_CTX_set_alpn_protos(ctx_,
459                             advertisedNextProtocols_[0].protocols,
460                             advertisedNextProtocols_[0].length);
461   }
462 #endif
463   return true;
464 }
465
466 void SSLContext::deleteNextProtocolsStrings() {
467   for (auto protocols : advertisedNextProtocols_) {
468     delete[] protocols.protocols;
469   }
470   advertisedNextProtocols_.clear();
471   advertisedNextProtocolWeights_.clear();
472 }
473
474 void SSLContext::unsetNextProtocols() {
475   deleteNextProtocolsStrings();
476   SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr);
477   SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr);
478 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
479   SSL_CTX_set_alpn_select_cb(ctx_, nullptr, nullptr);
480   SSL_CTX_set_alpn_protos(ctx_, nullptr, 0);
481 #endif
482 }
483
484 size_t SSLContext::pickNextProtocols() {
485   CHECK(!advertisedNextProtocols_.empty()) << "Failed to pickNextProtocols";
486   return nextProtocolDistribution_(nextProtocolPicker_);
487 }
488
489 int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
490       const unsigned char** out, unsigned int* outlen, void* data) {
491   SSLContext* context = (SSLContext*)data;
492   if (context == nullptr || context->advertisedNextProtocols_.empty()) {
493     *out = nullptr;
494     *outlen = 0;
495   } else if (context->advertisedNextProtocols_.size() == 1) {
496     *out = context->advertisedNextProtocols_[0].protocols;
497     *outlen = context->advertisedNextProtocols_[0].length;
498   } else {
499     uintptr_t selected_index = reinterpret_cast<uintptr_t>(SSL_get_ex_data(ssl,
500           sNextProtocolsExDataIndex_));
501     if (selected_index) {
502       --selected_index;
503       *out = context->advertisedNextProtocols_[selected_index].protocols;
504       *outlen = context->advertisedNextProtocols_[selected_index].length;
505     } else {
506       auto i = context->pickNextProtocols();
507       uintptr_t selected = i + 1;
508       SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void*)selected);
509       *out = context->advertisedNextProtocols_[i].protocols;
510       *outlen = context->advertisedNextProtocols_[i].length;
511     }
512   }
513   return SSL_TLSEXT_ERR_OK;
514 }
515
516 #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH) && \
517   FOLLY_SSLCONTEXT_USE_TLS_FALSE_START
518 SSLContext::SSLFalseStartChecker::SSLFalseStartChecker() :
519   ciphers_{
520     TLS1_CK_DHE_DSS_WITH_AES_128_SHA,
521     TLS1_CK_DHE_RSA_WITH_AES_128_SHA,
522     TLS1_CK_DHE_DSS_WITH_AES_256_SHA,
523     TLS1_CK_DHE_RSA_WITH_AES_256_SHA,
524     TLS1_CK_DHE_DSS_WITH_AES_128_SHA256,
525     TLS1_CK_DHE_RSA_WITH_AES_128_SHA256,
526     TLS1_CK_DHE_DSS_WITH_AES_256_SHA256,
527     TLS1_CK_DHE_RSA_WITH_AES_256_SHA256,
528     TLS1_CK_DHE_RSA_WITH_AES_128_GCM_SHA256,
529     TLS1_CK_DHE_RSA_WITH_AES_256_GCM_SHA384,
530     TLS1_CK_DHE_DSS_WITH_AES_128_GCM_SHA256,
531     TLS1_CK_DHE_DSS_WITH_AES_256_GCM_SHA384,
532     TLS1_CK_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
533     TLS1_CK_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
534     TLS1_CK_ECDHE_RSA_WITH_AES_128_CBC_SHA,
535     TLS1_CK_ECDHE_RSA_WITH_AES_256_CBC_SHA,
536     TLS1_CK_ECDHE_ECDSA_WITH_AES_128_SHA256,
537     TLS1_CK_ECDHE_ECDSA_WITH_AES_256_SHA384,
538     TLS1_CK_ECDH_ECDSA_WITH_AES_128_SHA256,
539     TLS1_CK_ECDH_ECDSA_WITH_AES_256_SHA384,
540     TLS1_CK_ECDHE_RSA_WITH_AES_128_SHA256,
541     TLS1_CK_ECDHE_RSA_WITH_AES_256_SHA384,
542     TLS1_CK_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
543     TLS1_CK_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
544     TLS1_CK_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
545     TLS1_CK_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
546   } {
547   length_ = sizeof(ciphers_)/sizeof(ciphers_[0]);
548   width_ = sizeof(ciphers_[0]);
549   qsort(ciphers_, length_, width_, compare_ulong);
550 }
551
552 bool SSLContext::SSLFalseStartChecker::canUseFalseStartWithCipher(
553   const SSL_CIPHER *cipher) {
554   unsigned long cid = cipher->id;
555   unsigned long *r =
556     (unsigned long*)bsearch(&cid, ciphers_, length_, width_, compare_ulong);
557   return r != nullptr;
558 }
559
560 int
561 SSLContext::SSLFalseStartChecker::compare_ulong(const void *x, const void *y) {
562   if (*(unsigned long *)x < *(unsigned long *)y) {
563     return -1;
564   }
565   if (*(unsigned long *)x > *(unsigned long *)y) {
566     return 1;
567   }
568   return 0;
569 };
570
571 bool SSLContext::canUseFalseStartWithCipher(const SSL_CIPHER *cipher) {
572   return falseStartChecker_.canUseFalseStartWithCipher(cipher);
573 }
574 #endif
575
576 int SSLContext::selectNextProtocolCallback(
577   SSL* ssl, unsigned char **out, unsigned char *outlen,
578   const unsigned char *server, unsigned int server_len, void *data) {
579
580   SSLContext* ctx = (SSLContext*)data;
581   if (ctx->advertisedNextProtocols_.size() > 1) {
582     VLOG(3) << "SSLContext::selectNextProcolCallback() "
583             << "client should be deterministic in selecting protocols.";
584   }
585
586   unsigned char *client;
587   unsigned int client_len;
588   bool filtered = false;
589   auto cpf = ctx->getClientProtocolFilterCallback();
590   if (cpf) {
591     filtered = (*cpf)(&client, &client_len, server, server_len);
592   }
593
594   if (!filtered) {
595     if (ctx->advertisedNextProtocols_.empty()) {
596       client = (unsigned char *) "";
597       client_len = 0;
598     } else {
599       client = ctx->advertisedNextProtocols_[0].protocols;
600       client_len = ctx->advertisedNextProtocols_[0].length;
601     }
602   }
603
604   int retval = SSL_select_next_proto(out, outlen, server, server_len,
605                                      client, client_len);
606   if (retval != OPENSSL_NPN_NEGOTIATED) {
607     VLOG(3) << "SSLContext::selectNextProcolCallback() "
608             << "unable to pick a next protocol.";
609 #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH) && \
610   FOLLY_SSLCONTEXT_USE_TLS_FALSE_START
611   } else {
612     const SSL_CIPHER *cipher = ssl->s3->tmp.new_cipher;
613     if (cipher && ctx->canUseFalseStartWithCipher(cipher)) {
614       SSL_set_mode(ssl, SSL_MODE_HANDSHAKE_CUTTHROUGH);
615     }
616 #endif
617   }
618   return SSL_TLSEXT_ERR_OK;
619 }
620 #endif // OPENSSL_NPN_NEGOTIATED
621
622 SSL* SSLContext::createSSL() const {
623   SSL* ssl = SSL_new(ctx_);
624   if (ssl == nullptr) {
625     throw std::runtime_error("SSL_new: " + getErrors());
626   }
627   return ssl;
628 }
629
630 /**
631  * Match a name with a pattern. The pattern may include wildcard. A single
632  * wildcard "*" can match up to one component in the domain name.
633  *
634  * @param  host    Host name, typically the name of the remote host
635  * @param  pattern Name retrieved from certificate
636  * @param  size    Size of "pattern"
637  * @return True, if "host" matches "pattern". False otherwise.
638  */
639 bool SSLContext::matchName(const char* host, const char* pattern, int size) {
640   bool match = false;
641   int i = 0, j = 0;
642   while (i < size && host[j] != '\0') {
643     if (toupper(pattern[i]) == toupper(host[j])) {
644       i++;
645       j++;
646       continue;
647     }
648     if (pattern[i] == '*') {
649       while (host[j] != '.' && host[j] != '\0') {
650         j++;
651       }
652       i++;
653       continue;
654     }
655     break;
656   }
657   if (i == size && host[j] == '\0') {
658     match = true;
659   }
660   return match;
661 }
662
663 int SSLContext::passwordCallback(char* password,
664                                  int size,
665                                  int,
666                                  void* data) {
667   SSLContext* context = (SSLContext*)data;
668   if (context == nullptr || context->passwordCollector() == nullptr) {
669     return 0;
670   }
671   std::string userPassword;
672   // call user defined password collector to get password
673   context->passwordCollector()->getPassword(userPassword, size);
674   int length = userPassword.size();
675   if (length > size) {
676     length = size;
677   }
678   strncpy(password, userPassword.c_str(), length);
679   return length;
680 }
681
682 struct SSLLock {
683   explicit SSLLock(
684     SSLContext::SSLLockType inLockType = SSLContext::LOCK_MUTEX) :
685       lockType(inLockType) {
686   }
687
688   void lock() {
689     if (lockType == SSLContext::LOCK_MUTEX) {
690       mutex.lock();
691     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
692       spinLock.lock();
693     }
694     // lockType == LOCK_NONE, no-op
695   }
696
697   void unlock() {
698     if (lockType == SSLContext::LOCK_MUTEX) {
699       mutex.unlock();
700     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
701       spinLock.unlock();
702     }
703     // lockType == LOCK_NONE, no-op
704   }
705
706   SSLContext::SSLLockType lockType;
707   folly::SpinLock spinLock{};
708   std::mutex mutex;
709 };
710
711 // Statics are unsafe in environments that call exit().
712 // If one thread calls exit() while another thread is
713 // references a member of SSLContext, bad things can happen.
714 // SSLContext runs in such environments.
715 // Instead of declaring a static member we "new" the static
716 // member so that it won't be destructed on exit().
717 static std::unique_ptr<SSLLock[]>& locks() {
718   static auto locksInst = new std::unique_ptr<SSLLock[]>();
719   return *locksInst;
720 }
721
722 static std::map<int, SSLContext::SSLLockType>& lockTypes() {
723   static auto lockTypesInst = new std::map<int, SSLContext::SSLLockType>();
724   return *lockTypesInst;
725 }
726
727 static void callbackLocking(int mode, int n, const char*, int) {
728   if (mode & CRYPTO_LOCK) {
729     locks()[n].lock();
730   } else {
731     locks()[n].unlock();
732   }
733 }
734
735 static unsigned long callbackThreadID() {
736   return static_cast<unsigned long>(
737 #ifdef __APPLE__
738     pthread_mach_thread_np(pthread_self())
739 #else
740     pthread_self()
741 #endif
742   );
743 }
744
745 static CRYPTO_dynlock_value* dyn_create(const char*, int) {
746   return new CRYPTO_dynlock_value;
747 }
748
749 static void dyn_lock(int mode,
750                      struct CRYPTO_dynlock_value* lock,
751                      const char*, int) {
752   if (lock != nullptr) {
753     if (mode & CRYPTO_LOCK) {
754       lock->mutex.lock();
755     } else {
756       lock->mutex.unlock();
757     }
758   }
759 }
760
761 static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) {
762   delete lock;
763 }
764
765 void SSLContext::setSSLLockTypes(std::map<int, SSLLockType> inLockTypes) {
766   lockTypes() = inLockTypes;
767 }
768
769 void SSLContext::markInitialized() {
770   std::lock_guard<std::mutex> g(initMutex());
771   initialized_ = true;
772 }
773
774 void SSLContext::initializeOpenSSL() {
775   std::lock_guard<std::mutex> g(initMutex());
776   initializeOpenSSLLocked();
777 }
778
779 void SSLContext::initializeOpenSSLLocked() {
780   if (initialized_) {
781     return;
782   }
783   SSL_library_init();
784   SSL_load_error_strings();
785   ERR_load_crypto_strings();
786   // static locking
787   locks().reset(new SSLLock[::CRYPTO_num_locks()]);
788   for (auto it: lockTypes()) {
789     locks()[it.first].lockType = it.second;
790   }
791   CRYPTO_set_id_callback(callbackThreadID);
792   CRYPTO_set_locking_callback(callbackLocking);
793   // dynamic locking
794   CRYPTO_set_dynlock_create_callback(dyn_create);
795   CRYPTO_set_dynlock_lock_callback(dyn_lock);
796   CRYPTO_set_dynlock_destroy_callback(dyn_destroy);
797   randomize();
798 #ifdef OPENSSL_NPN_NEGOTIATED
799   sNextProtocolsExDataIndex_ = SSL_get_ex_new_index(0,
800       (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
801 #endif
802   initialized_ = true;
803 }
804
805 void SSLContext::cleanupOpenSSL() {
806   std::lock_guard<std::mutex> g(initMutex());
807   cleanupOpenSSLLocked();
808 }
809
810 void SSLContext::cleanupOpenSSLLocked() {
811   if (!initialized_) {
812     return;
813   }
814
815   CRYPTO_set_id_callback(nullptr);
816   CRYPTO_set_locking_callback(nullptr);
817   CRYPTO_set_dynlock_create_callback(nullptr);
818   CRYPTO_set_dynlock_lock_callback(nullptr);
819   CRYPTO_set_dynlock_destroy_callback(nullptr);
820   CRYPTO_cleanup_all_ex_data();
821   ERR_free_strings();
822   EVP_cleanup();
823   ERR_remove_state(0);
824   locks().reset();
825   initialized_ = false;
826 }
827
828 void SSLContext::setOptions(long options) {
829   long newOpt = SSL_CTX_set_options(ctx_, options);
830   if ((newOpt & options) != options) {
831     throw std::runtime_error("SSL_CTX_set_options failed");
832   }
833 }
834
835 std::string SSLContext::getErrors(int errnoCopy) {
836   std::string errors;
837   unsigned long  errorCode;
838   char   message[256];
839
840   errors.reserve(512);
841   while ((errorCode = ERR_get_error()) != 0) {
842     if (!errors.empty()) {
843       errors += "; ";
844     }
845     const char* reason = ERR_reason_error_string(errorCode);
846     if (reason == nullptr) {
847       snprintf(message, sizeof(message) - 1, "SSL error # %lu", errorCode);
848       reason = message;
849     }
850     errors += reason;
851   }
852   if (errors.empty()) {
853     errors = "error code: " + folly::to<std::string>(errnoCopy);
854   }
855   return errors;
856 }
857
858 std::ostream&
859 operator<<(std::ostream& os, const PasswordCollector& collector) {
860   os << collector.describe();
861   return os;
862 }
863
864 bool OpenSSLUtils::getPeerAddressFromX509StoreCtx(X509_STORE_CTX* ctx,
865                                                   sockaddr_storage* addrStorage,
866                                                   socklen_t* addrLen) {
867   // Grab the ssl idx and then the ssl object so that we can get the peer
868   // name to compare against the ips in the subjectAltName
869   auto sslIdx = SSL_get_ex_data_X509_STORE_CTX_idx();
870   auto ssl =
871     reinterpret_cast<SSL*>(X509_STORE_CTX_get_ex_data(ctx, sslIdx));
872   int fd = SSL_get_fd(ssl);
873   if (fd < 0) {
874     LOG(ERROR) << "Inexplicably couldn't get fd from SSL";
875     return false;
876   }
877
878   *addrLen = sizeof(*addrStorage);
879   if (getpeername(fd, reinterpret_cast<sockaddr*>(addrStorage), addrLen) != 0) {
880     PLOG(ERROR) << "Unable to get peer name";
881     return false;
882   }
883   CHECK(*addrLen <= sizeof(*addrStorage));
884   return true;
885 }
886
887 bool OpenSSLUtils::validatePeerCertNames(X509* cert,
888                                          const sockaddr* addr,
889                                          socklen_t addrLen) {
890   // Try to extract the names within the SAN extension from the certificate
891   auto altNames =
892     reinterpret_cast<STACK_OF(GENERAL_NAME)*>(
893         X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr));
894   SCOPE_EXIT {
895     if (altNames != nullptr) {
896       sk_GENERAL_NAME_pop_free(altNames, GENERAL_NAME_free);
897     }
898   };
899   if (altNames == nullptr) {
900     LOG(WARNING) << "No subjectAltName provided and we only support ip auth";
901     return false;
902   }
903
904   const sockaddr_in* addr4 = nullptr;
905   const sockaddr_in6* addr6 = nullptr;
906   if (addr != nullptr) {
907     if (addr->sa_family == AF_INET) {
908       addr4 = reinterpret_cast<const sockaddr_in*>(addr);
909     } else if (addr->sa_family == AF_INET6) {
910       addr6 = reinterpret_cast<const sockaddr_in6*>(addr);
911     } else {
912       LOG(FATAL) << "Unsupported sockaddr family: " << addr->sa_family;
913     }
914   }
915
916
917   for (int i = 0; i < sk_GENERAL_NAME_num(altNames); i++) {
918     auto name = sk_GENERAL_NAME_value(altNames, i);
919     if ((addr4 != nullptr || addr6 != nullptr) && name->type == GEN_IPADD) {
920       // Extra const-ness for paranoia
921       unsigned char const * const rawIpStr = name->d.iPAddress->data;
922       int const rawIpLen = name->d.iPAddress->length;
923
924       if (rawIpLen == 4 && addr4 != nullptr) {
925         if (::memcmp(rawIpStr, &addr4->sin_addr, rawIpLen) == 0) {
926           return true;
927         }
928       } else if (rawIpLen == 16 && addr6 != nullptr) {
929         if (::memcmp(rawIpStr, &addr6->sin6_addr, rawIpLen) == 0) {
930           return true;
931         }
932       } else if (rawIpLen != 4 && rawIpLen != 16) {
933         LOG(WARNING) << "Unexpected IP length: " << rawIpLen;
934       }
935     }
936   }
937
938   LOG(WARNING) << "Unable to match client cert against alt name ip";
939   return false;
940 }
941
942
943 } // folly