/*
 * Copyright 2016 Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #include "SSLContext.h"
19 #include <openssl/err.h>
20 #include <openssl/rand.h>
21 #include <openssl/ssl.h>
22 #include <openssl/x509v3.h>
24 #include <folly/Format.h>
25 #include <folly/Memory.h>
26 #include <folly/SpinLock.h>
27 #include <folly/io/async/OpenSSLPtrTypes.h>
// ---------------------------------------------------------------------
// SSLContext implementation
// ---------------------------------------------------------------------
33 struct CRYPTO_dynlock_value {
39 bool SSLContext::initialized_ = false;
43 std::mutex& initMutex() {
48 inline void BIO_free_fb(BIO* bio) { CHECK_EQ(1, BIO_free(bio)); }
49 using BIO_deleter = folly::static_function_deleter<BIO, &BIO_free_fb>;
51 } // anonymous namespace
53 #ifdef OPENSSL_NPN_NEGOTIATED
54 int SSLContext::sNextProtocolsExDataIndex_ = -1;
57 // SSLContext implementation
58 SSLContext::SSLContext(SSLVersion version) {
60 std::lock_guard<std::mutex> g(initMutex());
61 initializeOpenSSLLocked();
64 ctx_ = SSL_CTX_new(SSLv23_method());
65 if (ctx_ == nullptr) {
66 throw std::runtime_error("SSL_CTX_new: " + getErrors());
72 opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
75 opt = SSL_OP_NO_SSLv2;
81 int newOpt = SSL_CTX_set_options(ctx_, opt);
82 DCHECK((newOpt & opt) == opt);
84 SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
86 checkPeerName_ = false;
88 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
89 SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
90 SSL_CTX_set_tlsext_servername_arg(ctx_, this);
93 #ifdef OPENSSL_NPN_NEGOTIATED
94 Random::seed(nextProtocolPicker_);
98 SSLContext::~SSLContext() {
99 if (ctx_ != nullptr) {
104 #ifdef OPENSSL_NPN_NEGOTIATED
105 deleteNextProtocolsStrings();
109 void SSLContext::ciphers(const std::string& ciphers) {
110 providedCiphersString_ = ciphers;
111 setCiphersOrThrow(ciphers);
114 void SSLContext::setCiphersOrThrow(const std::string& ciphers) {
115 int rc = SSL_CTX_set_cipher_list(ctx_, ciphers.c_str());
116 if (ERR_peek_error() != 0) {
117 throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors());
120 throw std::runtime_error("None of specified ciphers are supported");
124 void SSLContext::setVerificationOption(const SSLContext::SSLVerifyPeerEnum&
126 CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX); // dont recurse
127 verifyPeer_ = verifyPeer;
130 int SSLContext::getVerificationMode(const SSLContext::SSLVerifyPeerEnum&
132 CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX);
133 int mode = SSL_VERIFY_NONE;
135 // case SSLVerifyPeerEnum::USE_CTX: // can't happen
138 case SSLVerifyPeerEnum::VERIFY:
139 mode = SSL_VERIFY_PEER;
142 case SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT:
143 mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
146 case SSLVerifyPeerEnum::NO_VERIFY:
147 mode = SSL_VERIFY_NONE;
156 int SSLContext::getVerificationMode() {
157 return getVerificationMode(verifyPeer_);
160 void SSLContext::authenticate(bool checkPeerCert, bool checkPeerName,
161 const std::string& peerName) {
164 mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE;
165 checkPeerName_ = checkPeerName;
166 peerFixedName_ = peerName;
168 mode = SSL_VERIFY_NONE;
169 checkPeerName_ = false; // can't check name without cert!
170 peerFixedName_.clear();
172 SSL_CTX_set_verify(ctx_, mode, nullptr);
175 void SSLContext::loadCertificate(const char* path, const char* format) {
176 if (path == nullptr || format == nullptr) {
177 throw std::invalid_argument(
178 "loadCertificateChain: either <path> or <format> is nullptr");
180 if (strcmp(format, "PEM") == 0) {
181 if (SSL_CTX_use_certificate_chain_file(ctx_, path) == 0) {
182 int errnoCopy = errno;
183 std::string reason("SSL_CTX_use_certificate_chain_file: ");
186 reason.append(getErrors(errnoCopy));
187 throw std::runtime_error(reason);
190 throw std::runtime_error("Unsupported certificate format: " + std::string(format));
194 void SSLContext::loadCertificateFromBufferPEM(folly::StringPiece cert) {
195 if (cert.data() == nullptr) {
196 throw std::invalid_argument("loadCertificate: <cert> is nullptr");
199 std::unique_ptr<BIO, BIO_deleter> bio(BIO_new(BIO_s_mem()));
200 if (bio == nullptr) {
201 throw std::runtime_error("BIO_new: " + getErrors());
204 int written = BIO_write(bio.get(), cert.data(), cert.size());
205 if (written <= 0 || static_cast<unsigned>(written) != cert.size()) {
206 throw std::runtime_error("BIO_write: " + getErrors());
209 X509_UniquePtr x509(PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
210 if (x509 == nullptr) {
211 throw std::runtime_error("PEM_read_bio_X509: " + getErrors());
214 if (SSL_CTX_use_certificate(ctx_, x509.get()) == 0) {
215 throw std::runtime_error("SSL_CTX_use_certificate: " + getErrors());
219 void SSLContext::loadPrivateKey(const char* path, const char* format) {
220 if (path == nullptr || format == nullptr) {
221 throw std::invalid_argument(
222 "loadPrivateKey: either <path> or <format> is nullptr");
224 if (strcmp(format, "PEM") == 0) {
225 if (SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM) == 0) {
226 throw std::runtime_error("SSL_CTX_use_PrivateKey_file: " + getErrors());
229 throw std::runtime_error("Unsupported private key format: " + std::string(format));
233 void SSLContext::loadPrivateKeyFromBufferPEM(folly::StringPiece pkey) {
234 if (pkey.data() == nullptr) {
235 throw std::invalid_argument("loadPrivateKey: <pkey> is nullptr");
238 std::unique_ptr<BIO, BIO_deleter> bio(BIO_new(BIO_s_mem()));
239 if (bio == nullptr) {
240 throw std::runtime_error("BIO_new: " + getErrors());
243 int written = BIO_write(bio.get(), pkey.data(), pkey.size());
244 if (written <= 0 || static_cast<unsigned>(written) != pkey.size()) {
245 throw std::runtime_error("BIO_write: " + getErrors());
248 EVP_PKEY_UniquePtr key(
249 PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr));
250 if (key == nullptr) {
251 throw std::runtime_error("PEM_read_bio_PrivateKey: " + getErrors());
254 if (SSL_CTX_use_PrivateKey(ctx_, key.get()) == 0) {
255 throw std::runtime_error("SSL_CTX_use_PrivateKey: " + getErrors());
259 void SSLContext::loadTrustedCertificates(const char* path) {
260 if (path == nullptr) {
261 throw std::invalid_argument("loadTrustedCertificates: <path> is nullptr");
263 if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) {
264 throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors());
268 void SSLContext::loadTrustedCertificates(X509_STORE* store) {
269 SSL_CTX_set_cert_store(ctx_, store);
272 void SSLContext::loadClientCAList(const char* path) {
273 auto clientCAs = SSL_load_client_CA_file(path);
274 if (clientCAs == nullptr) {
275 LOG(ERROR) << "Unable to load ca file: " << path;
278 SSL_CTX_set_client_CA_list(ctx_, clientCAs);
281 void SSLContext::randomize() {
285 void SSLContext::passwordCollector(std::shared_ptr<PasswordCollector> collector) {
286 if (collector == nullptr) {
287 LOG(ERROR) << "passwordCollector: ignore invalid password collector";
290 collector_ = collector;
291 SSL_CTX_set_default_passwd_cb(ctx_, passwordCallback);
292 SSL_CTX_set_default_passwd_cb_userdata(ctx_, this);
295 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
297 void SSLContext::setServerNameCallback(const ServerNameCallback& cb) {
301 void SSLContext::addClientHelloCallback(const ClientHelloCallback& cb) {
302 clientHelloCbs_.push_back(cb);
305 int SSLContext::baseServerNameOpenSSLCallback(SSL* ssl, int* al, void* data) {
306 SSLContext* context = (SSLContext*)data;
308 if (context == nullptr) {
309 return SSL_TLSEXT_ERR_NOACK;
312 for (auto& cb : context->clientHelloCbs_) {
313 // Generic callbacks to happen after we receive the Client Hello.
314 // For example, we use one to switch which cipher we use depending
315 // on the user's TLS version. Because the primary purpose of
316 // baseServerNameOpenSSLCallback is for SNI support, and these callbacks
317 // are side-uses, we ignore any possible failures other than just logging
322 if (!context->serverNameCb_) {
323 return SSL_TLSEXT_ERR_NOACK;
326 ServerNameCallbackResult ret = context->serverNameCb_(ssl);
328 case SERVER_NAME_FOUND:
329 return SSL_TLSEXT_ERR_OK;
330 case SERVER_NAME_NOT_FOUND:
331 return SSL_TLSEXT_ERR_NOACK;
332 case SERVER_NAME_NOT_FOUND_ALERT_FATAL:
333 *al = TLS1_AD_UNRECOGNIZED_NAME;
334 return SSL_TLSEXT_ERR_ALERT_FATAL;
339 return SSL_TLSEXT_ERR_NOACK;
342 void SSLContext::switchCiphersIfTLS11(
344 const std::string& tls11CipherString) {
346 CHECK(!tls11CipherString.empty()) << "Shouldn't call if empty alt ciphers";
348 if (TLS1_get_client_version(ssl) <= TLS1_VERSION) {
349 // We only do this for TLS v 1.1 and later
353 // Prefer AES for TLS versions 1.1 and later since these are not
354 // vulnerable to BEAST attacks on AES. Note that we're setting the
355 // cipher list on the SSL object, not the SSL_CTX object, so it will
356 // only last for this request.
357 int rc = SSL_set_cipher_list(ssl, tls11CipherString.c_str());
358 if ((rc == 0) || ERR_peek_error() != 0) {
359 // This shouldn't happen since we checked for this when proxygen
361 LOG(WARNING) << "ssl_cipher: No specified ciphers supported for switch";
362 SSL_set_cipher_list(ssl, providedCiphersString_.c_str());
#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
/**
 * Server-side ALPN selection callback; `data` is the owning SSLContext.
 *
 * Picks one advertised protocol list (weighted at random) and selects the
 * first of its protocols that the client also offered.
 *
 * @return SSL_TLSEXT_ERR_OK when a protocol was selected. Returns
 *         SSL_TLSEXT_ERR_NOACK when nothing is configured or there is no
 *         overlap, so the handshake continues without ALPN.
 */
int SSLContext::alpnSelectCallback(SSL* /* ssl */,
                                   const unsigned char** out,
                                   unsigned char* outlen,
                                   const unsigned char* in,
                                   unsigned int inlen,
                                   void* data) {
  SSLContext* context = (SSLContext*)data;
  CHECK(context);
  if (context->advertisedNextProtocols_.empty()) {
    // Returning OK here would ask OpenSSL to record a zero-length
    // selection, which it treats as an internal error. NOACK is the
    // documented "no protocol selected" answer.
    *out = nullptr;
    *outlen = 0;
    return SSL_TLSEXT_ERR_NOACK;
  }

  auto i = context->pickNextProtocols();
  const auto& item = context->advertisedNextProtocols_[i];
  if (SSL_select_next_proto((unsigned char**)out,
                            outlen,
                            item.protocols,
                            item.length,
                            in,
                            inlen) != OPENSSL_NPN_NEGOTIATED) {
    return SSL_TLSEXT_ERR_NOACK;
  }
  return SSL_TLSEXT_ERR_OK;
}
#endif
395 #ifdef OPENSSL_NPN_NEGOTIATED
397 bool SSLContext::setAdvertisedNextProtocols(
398 const std::list<std::string>& protocols, NextProtocolType protocolType) {
399 return setRandomizedAdvertisedNextProtocols({{1, protocols}}, protocolType);
402 bool SSLContext::setRandomizedAdvertisedNextProtocols(
403 const std::list<NextProtocolsItem>& items, NextProtocolType protocolType) {
404 unsetNextProtocols();
405 if (items.size() == 0) {
408 int total_weight = 0;
409 for (const auto &item : items) {
410 if (item.protocols.size() == 0) {
413 AdvertisedNextProtocolsItem advertised_item;
414 advertised_item.length = 0;
415 for (const auto& proto : item.protocols) {
416 ++advertised_item.length;
417 unsigned protoLength = proto.length();
418 if (protoLength >= 256) {
419 deleteNextProtocolsStrings();
422 advertised_item.length += protoLength;
424 advertised_item.protocols = new unsigned char[advertised_item.length];
425 if (!advertised_item.protocols) {
426 throw std::runtime_error("alloc failure");
428 unsigned char* dst = advertised_item.protocols;
429 for (auto& proto : item.protocols) {
430 unsigned protoLength = proto.length();
431 *dst++ = (unsigned char)protoLength;
432 memcpy(dst, proto.data(), protoLength);
435 total_weight += item.weight;
436 advertisedNextProtocols_.push_back(advertised_item);
437 advertisedNextProtocolWeights_.push_back(item.weight);
439 if (total_weight == 0) {
440 deleteNextProtocolsStrings();
443 nextProtocolDistribution_ =
444 std::discrete_distribution<>(advertisedNextProtocolWeights_.begin(),
445 advertisedNextProtocolWeights_.end());
446 if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) {
447 SSL_CTX_set_next_protos_advertised_cb(
448 ctx_, advertisedNextProtocolCallback, this);
449 SSL_CTX_set_next_proto_select_cb(ctx_, selectNextProtocolCallback, this);
451 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
452 if ((uint8_t)protocolType & (uint8_t)NextProtocolType::ALPN) {
453 SSL_CTX_set_alpn_select_cb(ctx_, alpnSelectCallback, this);
454 // Client cannot really use randomized alpn
455 SSL_CTX_set_alpn_protos(ctx_,
456 advertisedNextProtocols_[0].protocols,
457 advertisedNextProtocols_[0].length);
463 void SSLContext::deleteNextProtocolsStrings() {
464 for (auto protocols : advertisedNextProtocols_) {
465 delete[] protocols.protocols;
467 advertisedNextProtocols_.clear();
468 advertisedNextProtocolWeights_.clear();
471 void SSLContext::unsetNextProtocols() {
472 deleteNextProtocolsStrings();
473 SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr);
474 SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr);
475 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
476 SSL_CTX_set_alpn_select_cb(ctx_, nullptr, nullptr);
477 SSL_CTX_set_alpn_protos(ctx_, nullptr, 0);
481 size_t SSLContext::pickNextProtocols() {
482 CHECK(!advertisedNextProtocols_.empty()) << "Failed to pickNextProtocols";
483 return nextProtocolDistribution_(nextProtocolPicker_);
486 int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
487 const unsigned char** out, unsigned int* outlen, void* data) {
488 SSLContext* context = (SSLContext*)data;
489 if (context == nullptr || context->advertisedNextProtocols_.empty()) {
492 } else if (context->advertisedNextProtocols_.size() == 1) {
493 *out = context->advertisedNextProtocols_[0].protocols;
494 *outlen = context->advertisedNextProtocols_[0].length;
496 uintptr_t selected_index = reinterpret_cast<uintptr_t>(SSL_get_ex_data(ssl,
497 sNextProtocolsExDataIndex_));
498 if (selected_index) {
500 *out = context->advertisedNextProtocols_[selected_index].protocols;
501 *outlen = context->advertisedNextProtocols_[selected_index].length;
503 auto i = context->pickNextProtocols();
504 uintptr_t selected = i + 1;
505 SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void*)selected);
506 *out = context->advertisedNextProtocols_[i].protocols;
507 *outlen = context->advertisedNextProtocols_[i].length;
510 return SSL_TLSEXT_ERR_OK;
#if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH) && \
  FOLLY_SSLCONTEXT_USE_TLS_FALSE_START
/**
 * Builds the sorted table of cipher ids eligible for TLS False Start (the
 * DHE and ECDHE AES suites listed below). The table is then searched with
 * bsearch().
 */
SSLContext::SSLFalseStartChecker::SSLFalseStartChecker() :
  ciphers_{
    TLS1_CK_DHE_DSS_WITH_AES_128_SHA,
    TLS1_CK_DHE_RSA_WITH_AES_128_SHA,
    TLS1_CK_DHE_DSS_WITH_AES_256_SHA,
    TLS1_CK_DHE_RSA_WITH_AES_256_SHA,
    TLS1_CK_DHE_DSS_WITH_AES_128_SHA256,
    TLS1_CK_DHE_RSA_WITH_AES_128_SHA256,
    TLS1_CK_DHE_DSS_WITH_AES_256_SHA256,
    TLS1_CK_DHE_RSA_WITH_AES_256_SHA256,
    TLS1_CK_DHE_RSA_WITH_AES_128_GCM_SHA256,
    TLS1_CK_DHE_RSA_WITH_AES_256_GCM_SHA384,
    TLS1_CK_DHE_DSS_WITH_AES_128_GCM_SHA256,
    TLS1_CK_DHE_DSS_WITH_AES_256_GCM_SHA384,
    TLS1_CK_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
    TLS1_CK_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
    TLS1_CK_ECDHE_RSA_WITH_AES_128_CBC_SHA,
    TLS1_CK_ECDHE_RSA_WITH_AES_256_CBC_SHA,
    TLS1_CK_ECDHE_ECDSA_WITH_AES_128_SHA256,
    TLS1_CK_ECDHE_ECDSA_WITH_AES_256_SHA384,
    TLS1_CK_ECDH_ECDSA_WITH_AES_128_SHA256,
    TLS1_CK_ECDH_ECDSA_WITH_AES_256_SHA384,
    TLS1_CK_ECDHE_RSA_WITH_AES_128_SHA256,
    TLS1_CK_ECDHE_RSA_WITH_AES_256_SHA384,
    TLS1_CK_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
    TLS1_CK_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
    TLS1_CK_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
    TLS1_CK_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
  } {
  width_ = sizeof(ciphers_[0]);
  length_ = sizeof(ciphers_) / width_;
  qsort(ciphers_, length_, width_, compare_ulong);
}

/** True if the cipher's id is present in the sorted eligibility table. */
bool SSLContext::SSLFalseStartChecker::canUseFalseStartWithCipher(
  const SSL_CIPHER *cipher) {
  unsigned long cid = cipher->id;
  return bsearch(&cid, ciphers_, length_, width_, compare_ulong) != nullptr;
}

/** qsort/bsearch comparator for unsigned long values: -1, 0 or 1. */
int
SSLContext::SSLFalseStartChecker::compare_ulong(const void *x, const void *y) {
  const unsigned long lhs = *static_cast<const unsigned long*>(x);
  const unsigned long rhs = *static_cast<const unsigned long*>(y);
  if (lhs < rhs) {
    return -1;
  }
  return lhs > rhs ? 1 : 0;
}

/** Delegates to this context's false-start checker. */
bool SSLContext::canUseFalseStartWithCipher(const SSL_CIPHER *cipher) {
  return falseStartChecker_.canUseFalseStartWithCipher(cipher);
}
#endif
573 int SSLContext::selectNextProtocolCallback(SSL* ssl,
575 unsigned char* outlen,
576 const unsigned char* server,
577 unsigned int server_len,
579 (void)ssl; // Make -Wunused-parameters happy
580 SSLContext* ctx = (SSLContext*)data;
581 if (ctx->advertisedNextProtocols_.size() > 1) {
582 VLOG(3) << "SSLContext::selectNextProcolCallback() "
583 << "client should be deterministic in selecting protocols.";
586 unsigned char *client;
587 unsigned int client_len;
588 bool filtered = false;
589 auto cpf = ctx->getClientProtocolFilterCallback();
591 filtered = (*cpf)(&client, &client_len, server, server_len);
595 if (ctx->advertisedNextProtocols_.empty()) {
596 client = (unsigned char *) "";
599 client = ctx->advertisedNextProtocols_[0].protocols;
600 client_len = ctx->advertisedNextProtocols_[0].length;
604 int retval = SSL_select_next_proto(out, outlen, server, server_len,
606 if (retval != OPENSSL_NPN_NEGOTIATED) {
607 VLOG(3) << "SSLContext::selectNextProcolCallback() "
608 << "unable to pick a next protocol.";
609 #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH) && \
610 FOLLY_SSLCONTEXT_USE_TLS_FALSE_START
612 const SSL_CIPHER *cipher = ssl->s3->tmp.new_cipher;
613 if (cipher && ctx->canUseFalseStartWithCipher(cipher)) {
614 SSL_set_mode(ssl, SSL_MODE_HANDSHAKE_CUTTHROUGH);
618 return SSL_TLSEXT_ERR_OK;
620 #endif // OPENSSL_NPN_NEGOTIATED
622 SSL* SSLContext::createSSL() const {
623 SSL* ssl = SSL_new(ctx_);
624 if (ssl == nullptr) {
625 throw std::runtime_error("SSL_new: " + getErrors());
631 * Match a name with a pattern. The pattern may include wildcard. A single
632 * wildcard "*" can match up to one component in the domain name.
634 * @param host Host name, typically the name of the remote host
635 * @param pattern Name retrieved from certificate
636 * @param size Size of "pattern"
637 * @return True, if "host" matches "pattern". False otherwise.
639 bool SSLContext::matchName(const char* host, const char* pattern, int size) {
642 while (i < size && host[j] != '\0') {
643 if (toupper(pattern[i]) == toupper(host[j])) {
648 if (pattern[i] == '*') {
649 while (host[j] != '.' && host[j] != '\0') {
657 if (i == size && host[j] == '\0') {
663 int SSLContext::passwordCallback(char* password,
667 SSLContext* context = (SSLContext*)data;
668 if (context == nullptr || context->passwordCollector() == nullptr) {
671 std::string userPassword;
672 // call user defined password collector to get password
673 context->passwordCollector()->getPassword(userPassword, size);
674 int length = userPassword.size();
678 strncpy(password, userPassword.c_str(), length);
684 SSLContext::SSLLockType inLockType = SSLContext::LOCK_MUTEX) :
685 lockType(inLockType) {
689 if (lockType == SSLContext::LOCK_MUTEX) {
691 } else if (lockType == SSLContext::LOCK_SPINLOCK) {
694 // lockType == LOCK_NONE, no-op
698 if (lockType == SSLContext::LOCK_MUTEX) {
700 } else if (lockType == SSLContext::LOCK_SPINLOCK) {
703 // lockType == LOCK_NONE, no-op
706 SSLContext::SSLLockType lockType;
707 folly::SpinLock spinLock{};
711 // Statics are unsafe in environments that call exit().
712 // If one thread calls exit() while another thread is
713 // references a member of SSLContext, bad things can happen.
714 // SSLContext runs in such environments.
715 // Instead of declaring a static member we "new" the static
716 // member so that it won't be destructed on exit().
717 static std::unique_ptr<SSLLock[]>& locks() {
718 static auto locksInst = new std::unique_ptr<SSLLock[]>();
722 static std::map<int, SSLContext::SSLLockType>& lockTypes() {
723 static auto lockTypesInst = new std::map<int, SSLContext::SSLLockType>();
724 return *lockTypesInst;
727 static void callbackLocking(int mode, int n, const char*, int) {
728 if (mode & CRYPTO_LOCK) {
735 static unsigned long callbackThreadID() {
736 return static_cast<unsigned long>(
738 pthread_mach_thread_np(pthread_self())
745 static CRYPTO_dynlock_value* dyn_create(const char*, int) {
746 return new CRYPTO_dynlock_value;
749 static void dyn_lock(int mode,
750 struct CRYPTO_dynlock_value* lock,
752 if (lock != nullptr) {
753 if (mode & CRYPTO_LOCK) {
756 lock->mutex.unlock();
761 static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) {
765 void SSLContext::setSSLLockTypes(std::map<int, SSLLockType> inLockTypes) {
766 lockTypes() = inLockTypes;
769 void SSLContext::markInitialized() {
770 std::lock_guard<std::mutex> g(initMutex());
774 void SSLContext::initializeOpenSSL() {
775 std::lock_guard<std::mutex> g(initMutex());
776 initializeOpenSSLLocked();
779 void SSLContext::initializeOpenSSLLocked() {
784 SSL_load_error_strings();
785 ERR_load_crypto_strings();
787 locks().reset(new SSLLock[::CRYPTO_num_locks()]);
788 for (auto it: lockTypes()) {
789 locks()[it.first].lockType = it.second;
791 CRYPTO_set_id_callback(callbackThreadID);
792 CRYPTO_set_locking_callback(callbackLocking);
794 CRYPTO_set_dynlock_create_callback(dyn_create);
795 CRYPTO_set_dynlock_lock_callback(dyn_lock);
796 CRYPTO_set_dynlock_destroy_callback(dyn_destroy);
798 #ifdef OPENSSL_NPN_NEGOTIATED
799 sNextProtocolsExDataIndex_ = SSL_get_ex_new_index(0,
800 (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
805 void SSLContext::cleanupOpenSSL() {
806 std::lock_guard<std::mutex> g(initMutex());
807 cleanupOpenSSLLocked();
810 void SSLContext::cleanupOpenSSLLocked() {
815 CRYPTO_set_id_callback(nullptr);
816 CRYPTO_set_locking_callback(nullptr);
817 CRYPTO_set_dynlock_create_callback(nullptr);
818 CRYPTO_set_dynlock_lock_callback(nullptr);
819 CRYPTO_set_dynlock_destroy_callback(nullptr);
820 CRYPTO_cleanup_all_ex_data();
825 initialized_ = false;
828 void SSLContext::setOptions(long options) {
829 long newOpt = SSL_CTX_set_options(ctx_, options);
830 if ((newOpt & options) != options) {
831 throw std::runtime_error("SSL_CTX_set_options failed");
835 std::string SSLContext::getErrors(int errnoCopy) {
837 unsigned long errorCode;
841 while ((errorCode = ERR_get_error()) != 0) {
842 if (!errors.empty()) {
845 const char* reason = ERR_reason_error_string(errorCode);
846 if (reason == nullptr) {
847 snprintf(message, sizeof(message) - 1, "SSL error # %lu", errorCode);
852 if (errors.empty()) {
853 errors = "error code: " + folly::to<std::string>(errnoCopy);
859 operator<<(std::ostream& os, const PasswordCollector& collector) {
860 os << collector.describe();
864 bool OpenSSLUtils::getPeerAddressFromX509StoreCtx(X509_STORE_CTX* ctx,
865 sockaddr_storage* addrStorage,
866 socklen_t* addrLen) {
867 // Grab the ssl idx and then the ssl object so that we can get the peer
868 // name to compare against the ips in the subjectAltName
869 auto sslIdx = SSL_get_ex_data_X509_STORE_CTX_idx();
871 reinterpret_cast<SSL*>(X509_STORE_CTX_get_ex_data(ctx, sslIdx));
872 int fd = SSL_get_fd(ssl);
874 LOG(ERROR) << "Inexplicably couldn't get fd from SSL";
878 *addrLen = sizeof(*addrStorage);
879 if (getpeername(fd, reinterpret_cast<sockaddr*>(addrStorage), addrLen) != 0) {
880 PLOG(ERROR) << "Unable to get peer name";
883 CHECK(*addrLen <= sizeof(*addrStorage));
887 bool OpenSSLUtils::validatePeerCertNames(X509* cert,
888 const sockaddr* addr,
889 socklen_t /* addrLen */) {
890 // Try to extract the names within the SAN extension from the certificate
892 reinterpret_cast<STACK_OF(GENERAL_NAME)*>(
893 X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr));
895 if (altNames != nullptr) {
896 sk_GENERAL_NAME_pop_free(altNames, GENERAL_NAME_free);
899 if (altNames == nullptr) {
900 LOG(WARNING) << "No subjectAltName provided and we only support ip auth";
904 const sockaddr_in* addr4 = nullptr;
905 const sockaddr_in6* addr6 = nullptr;
906 if (addr != nullptr) {
907 if (addr->sa_family == AF_INET) {
908 addr4 = reinterpret_cast<const sockaddr_in*>(addr);
909 } else if (addr->sa_family == AF_INET6) {
910 addr6 = reinterpret_cast<const sockaddr_in6*>(addr);
912 LOG(FATAL) << "Unsupported sockaddr family: " << addr->sa_family;
917 for (int i = 0; i < sk_GENERAL_NAME_num(altNames); i++) {
918 auto name = sk_GENERAL_NAME_value(altNames, i);
919 if ((addr4 != nullptr || addr6 != nullptr) && name->type == GEN_IPADD) {
920 // Extra const-ness for paranoia
921 unsigned char const * const rawIpStr = name->d.iPAddress->data;
922 int const rawIpLen = name->d.iPAddress->length;
924 if (rawIpLen == 4 && addr4 != nullptr) {
925 if (::memcmp(rawIpStr, &addr4->sin_addr, rawIpLen) == 0) {
928 } else if (rawIpLen == 16 && addr6 != nullptr) {
929 if (::memcmp(rawIpStr, &addr6->sin6_addr, rawIpLen) == 0) {
932 } else if (rawIpLen != 4 && rawIpLen != 16) {
933 LOG(WARNING) << "Unexpected IP length: " << rawIpLen;
938 LOG(WARNING) << "Unable to match client cert against alt name ip";