2 * Copyright 2015 Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "SSLContext.h"
19 #include <openssl/err.h>
20 #include <openssl/rand.h>
21 #include <openssl/ssl.h>
22 #include <openssl/x509v3.h>
24 #include <folly/Format.h>
25 #include <folly/Memory.h>
26 #include <folly/SpinLock.h>
28 // ---------------------------------------------------------------------
29 // SSLContext implementation
30 // ---------------------------------------------------------------------
// Concrete definition of OpenSSL's opaque dynamic-lock type. The dyn_lock()
// callback in this file locks/unlocks its `mutex` member.
32 struct CRYPTO_dynlock_value {
// Static flag recording whether this class has set up OpenSSL's global
// state; cleanupOpenSSLLocked() resets it to false.
38 bool SSLContext::initialized_ = false;
// Returns the mutex that serializes OpenSSL global init/cleanup. It is taken
// by the SSLContext constructor, initializeOpenSSL(), cleanupOpenSSL() and
// markInitialized().
42 std::mutex& initMutex() {
// Frees a BIO and treats a failed BIO_free() (return value != 1) as fatal.
47 inline void BIO_free_fb(BIO* bio) { CHECK_EQ(1, BIO_free(bio)); }
// std::unique_ptr deleters for the OpenSSL objects created while loading
// certificates and private keys from in-memory PEM buffers.
48 using BIO_deleter = folly::static_function_deleter<BIO, &BIO_free_fb>;
49 using X509_deleter = folly::static_function_deleter<X509, &X509_free>;
50 using EVP_PKEY_deleter =
51 folly::static_function_deleter<EVP_PKEY, &EVP_PKEY_free>;
53 } // anonymous namespace
55 #ifdef OPENSSL_NPN_NEGOTIATED
// SSL ex_data slot in which advertisedNextProtocolCallback() caches the
// protocol list chosen for a connection. It stays -1 until
// initializeOpenSSLLocked() allocates it with SSL_get_ex_new_index().
56 int SSLContext::sNextProtocolsExDataIndex_ = -1;
59 // SSLContext implementation
// Builds the SSL_CTX backing this context:
// - initializes OpenSSL's process-wide state while holding initMutex();
// - creates the SSL_CTX from SSLv23_method(), throwing std::runtime_error
//   (with OpenSSL's queued errors) if SSL_CTX_new fails;
// - disables older protocol versions via SSL_OP_NO_* options chosen from
//   `version` (the selecting switch is presumably on the lines between);
// - sets SSL_MODE_AUTO_RETRY and turns peer-name checking off;
// - registers the SNI callback with `this` as its argument when TLS
//   extensions are available;
// - seeds the RNG used by pickNextProtocols() when NPN is compiled in.
60 SSLContext::SSLContext(SSLVersion version) {
62 std::lock_guard<std::mutex> g(initMutex());
63 initializeOpenSSLLocked();
66 ctx_ = SSL_CTX_new(SSLv23_method());
67 if (ctx_ == nullptr) {
68 throw std::runtime_error("SSL_CTX_new: " + getErrors());
74 opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
77 opt = SSL_OP_NO_SSLv2;
// NOTE(review): SSL_CTX_set_options returns `long` (setOptions() stores it
// in a long), so narrowing it to int can drop high option bits. Also, this
// check is only a DCHECK, whereas setOptions() throws on the same condition.
83 int newOpt = SSL_CTX_set_options(ctx_, opt);
84 DCHECK((newOpt & opt) == opt);
86 SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
88 checkPeerName_ = false;
90 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
91 SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
92 SSL_CTX_set_tlsext_servername_arg(ctx_, this);
95 #ifdef OPENSSL_NPN_NEGOTIATED
96 Random::seed(nextProtocolPicker_);
// Releases this context's resources. The `ctx_ != nullptr` branch presumably
// frees the SSL_CTX. When NPN is compiled in, the advertised protocol
// buffers are freed via deleteNextProtocolsStrings().
100 SSLContext::~SSLContext() {
101 if (ctx_ != nullptr) {
106 #ifdef OPENSSL_NPN_NEGOTIATED
107 deleteNextProtocolsStrings();
111 void SSLContext::ciphers(const std::string& ciphers) {
112 providedCiphersString_ = ciphers;
113 setCiphersOrThrow(ciphers);
// Applies `ciphers` to the SSL_CTX.
// @throws std::runtime_error if OpenSSL has any error queued afterwards, or
//         when none of the listed ciphers are supported.
116 void SSLContext::setCiphersOrThrow(const std::string& ciphers) {
117 int rc = SSL_CTX_set_cipher_list(ctx_, ciphers.c_str());
// NOTE(review): ERR_peek_error() also sees stale errors left on this thread's
// queue by earlier, unrelated OpenSSL calls.
118 if (ERR_peek_error() != 0) {
119 throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors());
// Presumably guarded by `rc == 0`: no cipher in the list was usable.
122 throw std::runtime_error("None of specified ciphers are supported");
// Stores the context-wide peer-verification policy. USE_CTX is rejected with
// a CHECK because storing it would make the context defer to itself.
126 void SSLContext::setVerificationOption(const SSLContext::SSLVerifyPeerEnum&
128 CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX); // dont recurse
129 verifyPeer_ = verifyPeer;
// Translates a peer-verification policy into OpenSSL SSL_VERIFY_* flags:
//   VERIFY                 -> SSL_VERIFY_PEER
//   VERIFY_REQ_CLIENT_CERT -> SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT
//   NO_VERIFY              -> SSL_VERIFY_NONE
// USE_CTX must be resolved by the caller first; passing it CHECK-fails.
132 int SSLContext::getVerificationMode(const SSLContext::SSLVerifyPeerEnum&
134 CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX);
135 int mode = SSL_VERIFY_NONE;
137 // case SSLVerifyPeerEnum::USE_CTX: // can't happen
140 case SSLVerifyPeerEnum::VERIFY:
141 mode = SSL_VERIFY_PEER;
144 case SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT:
145 mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
148 case SSLVerifyPeerEnum::NO_VERIFY:
149 mode = SSL_VERIFY_NONE;
158 int SSLContext::getVerificationMode() {
159 return getVerificationMode(verifyPeer_);
// Configures certificate verification directly on the SSL_CTX. No verify
// callback is installed (nullptr).
// - Verifying branch (its condition, presumably `checkPeerCert`, is on an
//   unshown line): sets SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT |
//   SSL_VERIFY_CLIENT_ONCE and records the peer-name settings.
// - Otherwise: sets SSL_VERIFY_NONE and disables name checking.
// @param checkPeerCert whether the peer certificate must be verified.
// @param checkPeerName whether to check the peer name; only honored when
//                      certificates are verified.
// @param peerName      expected peer name; cleared when not verifying.
162 void SSLContext::authenticate(bool checkPeerCert, bool checkPeerName,
163 const std::string& peerName) {
166 mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE;
167 checkPeerName_ = checkPeerName;
168 peerFixedName_ = peerName;
170 mode = SSL_VERIFY_NONE;
171 checkPeerName_ = false; // can't check name without cert!
172 peerFixedName_.clear();
174 SSL_CTX_set_verify(ctx_, mode, nullptr);
// Loads the certificate chain stored at `path` into this context. Only the
// "PEM" format is supported.
// @throws std::invalid_argument if `path` or `format` is null.
// @throws std::runtime_error if OpenSSL cannot load the file, or if `format`
//         is not "PEM". On load failure the message comes from
//         getErrors(errnoCopy), which reports errno when OpenSSL queued no
//         error.
// NOTE(review): the null-argument message says "loadCertificateChain" even
// though this is loadCertificate(); left unchanged in case callers match it.
177 void SSLContext::loadCertificate(const char* path, const char* format) {
178 if (path == nullptr || format == nullptr) {
179 throw std::invalid_argument(
180 "loadCertificateChain: either <path> or <format> is nullptr");
182 if (strcmp(format, "PEM") == 0) {
183 if (SSL_CTX_use_certificate_chain_file(ctx_, path) == 0) {
// Snapshot errno before any further call can overwrite it.
184 int errnoCopy = errno;
185 std::string reason("SSL_CTX_use_certificate_chain_file: ");
188 reason.append(getErrors(errnoCopy));
189 throw std::runtime_error(reason);
192 throw std::runtime_error("Unsupported certificate format: " + std::string(format));
196 void SSLContext::loadCertificateFromBufferPEM(folly::StringPiece cert) {
197 if (cert.data() == nullptr) {
198 throw std::invalid_argument("loadCertificate: <cert> is nullptr");
201 std::unique_ptr<BIO, BIO_deleter> bio(BIO_new(BIO_s_mem()));
202 if (bio == nullptr) {
203 throw std::runtime_error("BIO_new: " + getErrors());
206 int written = BIO_write(bio.get(), cert.data(), cert.size());
207 if (written <= 0 || static_cast<unsigned>(written) != cert.size()) {
208 throw std::runtime_error("BIO_write: " + getErrors());
211 std::unique_ptr<X509, X509_deleter> x509(
212 PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
213 if (x509 == nullptr) {
214 throw std::runtime_error("PEM_read_bio_X509: " + getErrors());
217 if (SSL_CTX_use_certificate(ctx_, x509.get()) == 0) {
218 throw std::runtime_error("SSL_CTX_use_certificate: " + getErrors());
222 void SSLContext::loadPrivateKey(const char* path, const char* format) {
223 if (path == nullptr || format == nullptr) {
224 throw std::invalid_argument(
225 "loadPrivateKey: either <path> or <format> is nullptr");
227 if (strcmp(format, "PEM") == 0) {
228 if (SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM) == 0) {
229 throw std::runtime_error("SSL_CTX_use_PrivateKey_file: " + getErrors());
232 throw std::runtime_error("Unsupported private key format: " + std::string(format));
236 void SSLContext::loadPrivateKeyFromBufferPEM(folly::StringPiece pkey) {
237 if (pkey.data() == nullptr) {
238 throw std::invalid_argument("loadPrivateKey: <pkey> is nullptr");
241 std::unique_ptr<BIO, BIO_deleter> bio(BIO_new(BIO_s_mem()));
242 if (bio == nullptr) {
243 throw std::runtime_error("BIO_new: " + getErrors());
246 int written = BIO_write(bio.get(), pkey.data(), pkey.size());
247 if (written <= 0 || static_cast<unsigned>(written) != pkey.size()) {
248 throw std::runtime_error("BIO_write: " + getErrors());
251 std::unique_ptr<EVP_PKEY, EVP_PKEY_deleter> key(
252 PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr));
253 if (key == nullptr) {
254 throw std::runtime_error("PEM_read_bio_PrivateKey: " + getErrors());
257 if (SSL_CTX_use_PrivateKey(ctx_, key.get()) == 0) {
258 throw std::runtime_error("SSL_CTX_use_PrivateKey: " + getErrors());
262 void SSLContext::loadTrustedCertificates(const char* path) {
263 if (path == nullptr) {
264 throw std::invalid_argument("loadTrustedCertificates: <path> is nullptr");
266 if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) {
267 throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors());
271 void SSLContext::loadTrustedCertificates(X509_STORE* store) {
272 SSL_CTX_set_cert_store(ctx_, store);
// Loads CA names from the file at `path` (SSL_load_client_CA_file) and sets
// them as this context's client CA list. A load failure is only logged; the
// line after the LOG presumably returns, so a bad file is not fatal.
// NOTE(review): unlike loadTrustedCertificates(), `path` is not checked for
// nullptr, and it is streamed into the log message on failure.
275 void SSLContext::loadClientCAList(const char* path) {
276 auto clientCAs = SSL_load_client_CA_file(path);
277 if (clientCAs == nullptr) {
278 LOG(ERROR) << "Unable to load ca file: " << path;
281 SSL_CTX_set_client_CA_list(ctx_, clientCAs);
// NOTE(review): body not shown; given the <openssl/rand.h> include, this
// presumably (re)seeds OpenSSL's random number generator. Confirm.
284 void SSLContext::randomize() {
// Installs `collector` as this context's password source. It stores the
// collector and registers passwordCallback() as OpenSSL's default password
// callback, with `this` as userdata. A null collector is logged as an error
// and presumably ignored (the early exit is on an unshown line).
288 void SSLContext::passwordCollector(std::shared_ptr<PasswordCollector> collector) {
289 if (collector == nullptr) {
290 LOG(ERROR) << "passwordCollector: ignore invalid password collector";
293 collector_ = collector;
294 SSL_CTX_set_default_passwd_cb(ctx_, passwordCallback);
295 SSL_CTX_set_default_passwd_cb_userdata(ctx_, this);
298 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
// Sets the SNI handler that baseServerNameOpenSSLCallback() invokes through
// serverNameCb_ (presumably assigned here; the body is not shown).
300 void SSLContext::setServerNameCallback(const ServerNameCallback& cb) {
304 void SSLContext::addClientHelloCallback(const ClientHelloCallback& cb) {
305 clientHelloCbs_.push_back(cb);
// OpenSSL servername (SNI) callback, registered by the constructor with
// `this` as `data`. It iterates the registered client-hello hooks (their
// invocation is on unshown lines), then delegates to serverNameCb_ and maps
// the result to an SSL_TLSEXT_ERR_* code:
//   null context / no serverNameCb_ / SERVER_NAME_NOT_FOUND -> NOACK
//   SERVER_NAME_FOUND                                        -> OK
//   SERVER_NAME_NOT_FOUND_ALERT_FATAL -> ALERT_FATAL, with *al set to
//                                        TLS1_AD_UNRECOGNIZED_NAME
// Any other result presumably reaches the final NOACK.
308 int SSLContext::baseServerNameOpenSSLCallback(SSL* ssl, int* al, void* data) {
309 SSLContext* context = (SSLContext*)data;
311 if (context == nullptr) {
312 return SSL_TLSEXT_ERR_NOACK;
315 for (auto& cb : context->clientHelloCbs_) {
316 // Generic callbacks to happen after we receive the Client Hello.
317 // For example, we use one to switch which cipher we use depending
318 // on the user's TLS version. Because the primary purpose of
319 // baseServerNameOpenSSLCallback is for SNI support, and these callbacks
320 // are side-uses, we ignore any possible failures other than just logging
325 if (!context->serverNameCb_) {
326 return SSL_TLSEXT_ERR_NOACK;
329 ServerNameCallbackResult ret = context->serverNameCb_(ssl);
331 case SERVER_NAME_FOUND:
332 return SSL_TLSEXT_ERR_OK;
333 case SERVER_NAME_NOT_FOUND:
334 return SSL_TLSEXT_ERR_NOACK;
335 case SERVER_NAME_NOT_FOUND_ALERT_FATAL:
336 *al = TLS1_AD_UNRECOGNIZED_NAME;
337 return SSL_TLSEXT_ERR_ALERT_FATAL;
342 return SSL_TLSEXT_ERR_NOACK;
// Client-hello hook. When the client's version is above TLS 1.0, it replaces
// this connection's cipher list (not the context's) with tls11CipherString.
// If OpenSSL rejects that list, it logs a warning and restores
// providedCiphersString_, the list recorded by ciphers().
// CHECK-fails if tls11CipherString is empty.
// NOTE(review): the fallback SSL_set_cipher_list() result is not checked, and
// ERR_peek_error() leaves the error queued for later callers.
345 void SSLContext::switchCiphersIfTLS11(
347 const std::string& tls11CipherString) {
349 CHECK(!tls11CipherString.empty()) << "Shouldn't call if empty alt ciphers";
351 if (TLS1_get_client_version(ssl) <= TLS1_VERSION) {
352 // We only do this for TLS v 1.1 and later
356 // Prefer AES for TLS versions 1.1 and later since these are not
357 // vulnerable to BEAST attacks on AES. Note that we're setting the
358 // cipher list on the SSL object, not the SSL_CTX object, so it will
359 // only last for this request.
360 int rc = SSL_set_cipher_list(ssl, tls11CipherString.c_str());
361 if ((rc == 0) || ERR_peek_error() != 0) {
362 // This shouldn't happen since we checked for this when proxygen
364 LOG(WARNING) << "ssl_cipher: No specified ciphers supported for switch";
365 SSL_set_cipher_list(ssl, providedCiphersString_.c_str());
370 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
// Server-side ALPN selection callback, registered by
// setRandomizedAdvertisedNextProtocols() with `this` as `data`. It picks one
// weighted protocol list via pickNextProtocols(), and SSL_select_next_proto
// matches it against the client's list in `in`. Returns NOACK when no
// protocol is negotiated, OK otherwise.
// NOTE(review): unlike baseServerNameOpenSSLCallback(), no null check on
// `context` is visible before it is dereferenced. Confirm line 378.
371 int SSLContext::alpnSelectCallback(SSL* ssl,
372 const unsigned char** out,
373 unsigned char* outlen,
374 const unsigned char* in,
377 SSLContext* context = (SSLContext*)data;
379 if (context->advertisedNextProtocols_.empty()) {
383 auto i = context->pickNextProtocols();
384 const auto& item = context->advertisedNextProtocols_[i];
385 if (SSL_select_next_proto((unsigned char**)out,
390 inlen) != OPENSSL_NPN_NEGOTIATED) {
391 return SSL_TLSEXT_ERR_NOACK;
394 return SSL_TLSEXT_ERR_OK;
398 #ifdef OPENSSL_NPN_NEGOTIATED
400 bool SSLContext::setAdvertisedNextProtocols(
401 const std::list<std::string>& protocols, NextProtocolType protocolType) {
402 return setRandomizedAdvertisedNextProtocols({{1, protocols}}, protocolType);
// Replaces the advertised NPN/ALPN protocols with `items`. The server picks
// one item per connection with probability proportional to its weight (see
// pickNextProtocols()).
// Each item is serialized into a new[]-allocated buffer of length-prefixed
// strings: one length byte followed by the protocol bytes. The function
// returns early (return values are on unshown lines) when `items` is empty,
// when an item has no protocols, when a protocol is 256+ bytes long, or when
// the total weight is zero.
// Callbacks are registered for NPN and/or ALPN according to `protocolType`.
405 bool SSLContext::setRandomizedAdvertisedNextProtocols(
406 const std::list<NextProtocolsItem>& items, NextProtocolType protocolType) {
407 unsetNextProtocols();
408 if (items.size() == 0) {
411 int total_weight = 0;
412 for (const auto &item : items) {
413 if (item.protocols.size() == 0) {
416 AdvertisedNextProtocolsItem advertised_item;
417 advertised_item.length = 0;
// First pass: compute the serialized length (1 length byte + bytes each).
418 for (const auto& proto : item.protocols) {
419 ++advertised_item.length;
420 unsigned protoLength = proto.length();
// A protocol's length must fit in its single length-prefix byte.
421 if (protoLength >= 256) {
422 deleteNextProtocolsStrings();
425 advertised_item.length += protoLength;
// NOTE(review): plain new[] throws std::bad_alloc instead of returning
// nullptr, so the null check below is dead code.
427 advertised_item.protocols = new unsigned char[advertised_item.length];
428 if (!advertised_item.protocols) {
429 throw std::runtime_error("alloc failure");
// Second pass: write each length byte followed by the protocol bytes.
// NOTE(review): the line advancing `dst` past the copied bytes is not shown;
// confirm that `dst += protoLength` follows the memcpy.
431 unsigned char* dst = advertised_item.protocols;
432 for (auto& proto : item.protocols) {
433 unsigned protoLength = proto.length();
434 *dst++ = (unsigned char)protoLength;
435 memcpy(dst, proto.data(), protoLength);
438 total_weight += item.weight;
439 advertisedNextProtocols_.push_back(advertised_item);
440 advertisedNextProtocolWeights_.push_back(item.weight);
442 if (total_weight == 0) {
443 deleteNextProtocolsStrings();
446 nextProtocolDistribution_ =
447 std::discrete_distribution<>(advertisedNextProtocolWeights_.begin(),
448 advertisedNextProtocolWeights_.end());
449 if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) {
450 SSL_CTX_set_next_protos_advertised_cb(
451 ctx_, advertisedNextProtocolCallback, this);
452 SSL_CTX_set_next_proto_select_cb(ctx_, selectNextProtocolCallback, this);
454 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
455 if ((uint8_t)protocolType & (uint8_t)NextProtocolType::ALPN) {
456 SSL_CTX_set_alpn_select_cb(ctx_, alpnSelectCallback, this);
457 // Client cannot really use randomized alpn
458 SSL_CTX_set_alpn_protos(ctx_,
459 advertisedNextProtocols_[0].protocols,
460 advertisedNextProtocols_[0].length);
466 void SSLContext::deleteNextProtocolsStrings() {
467 for (auto protocols : advertisedNextProtocols_) {
468 delete[] protocols.protocols;
470 advertisedNextProtocols_.clear();
471 advertisedNextProtocolWeights_.clear();
// Removes all advertised NPN/ALPN protocols. It frees the stored lists and
// unregisters the NPN advertise/select callbacks. When ALPN is compiled in,
// it also unregisters the ALPN select callback and clears the ALPN protocol
// list set on the context.
474 void SSLContext::unsetNextProtocols() {
475 deleteNextProtocolsStrings();
476 SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr);
477 SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr);
478 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
479 SSL_CTX_set_alpn_select_cb(ctx_, nullptr, nullptr);
480 SSL_CTX_set_alpn_protos(ctx_, nullptr, 0);
484 size_t SSLContext::pickNextProtocols() {
485 CHECK(!advertisedNextProtocols_.empty()) << "Failed to pickNextProtocols";
486 return nextProtocolDistribution_(nextProtocolPicker_);
// Server-side NPN advertise callback, registered with `this` as `data`.
// - No context or nothing advertised: handled on unshown lines.
// - One list: always advertise it.
// - Several lists: the first call for a connection picks one with
//   pickNextProtocols() and caches it in the SSL's ex_data slot as index+1,
//   so a null/0 slot means "not picked yet". Later calls reuse it.
// Returns SSL_TLSEXT_ERR_OK at the end.
489 int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
490 const unsigned char** out, unsigned int* outlen, void* data) {
491 SSLContext* context = (SSLContext*)data;
492 if (context == nullptr || context->advertisedNextProtocols_.empty()) {
495 } else if (context->advertisedNextProtocols_.size() == 1) {
496 *out = context->advertisedNextProtocols_[0].protocols;
497 *outlen = context->advertisedNextProtocols_[0].length;
499 uintptr_t selected_index = reinterpret_cast<uintptr_t>(SSL_get_ex_data(ssl,
500 sNextProtocolsExDataIndex_));
501 if (selected_index) {
// NOTE(review): the stored value is index+1, so it must be decremented before
// indexing (presumably `--selected_index;` on unshown line 502). Otherwise
// this reads the wrong item, or one past the end.
503 *out = context->advertisedNextProtocols_[selected_index].protocols;
504 *outlen = context->advertisedNextProtocols_[selected_index].length;
506 auto i = context->pickNextProtocols();
507 uintptr_t selected = i + 1;
508 SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void*)selected);
509 *out = context->advertisedNextProtocols_[i].protocols;
510 *outlen = context->advertisedNextProtocols_[i].length;
513 return SSL_TLSEXT_ERR_OK;
516 #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH) && \
517 FOLLY_SSLCONTEXT_USE_TLS_FALSE_START
// Allow-list of cipher-suite ids for which TLS False Start (handshake
// cut-through) may be enabled. The constructor sorts the ids with qsort so
// canUseFalseStartWithCipher() can binary-search them with bsearch.
// NOTE(review): the list includes static-ECDH suites (TLS1_CK_ECDH_ECDSA_*),
// which are not forward-secret. False Start is normally limited to
// forward-secret key exchange; confirm these entries belong here.
518 SSLContext::SSLFalseStartChecker::SSLFalseStartChecker() :
520 TLS1_CK_DHE_DSS_WITH_AES_128_SHA,
521 TLS1_CK_DHE_RSA_WITH_AES_128_SHA,
522 TLS1_CK_DHE_DSS_WITH_AES_256_SHA,
523 TLS1_CK_DHE_RSA_WITH_AES_256_SHA,
524 TLS1_CK_DHE_DSS_WITH_AES_128_SHA256,
525 TLS1_CK_DHE_RSA_WITH_AES_128_SHA256,
526 TLS1_CK_DHE_DSS_WITH_AES_256_SHA256,
527 TLS1_CK_DHE_RSA_WITH_AES_256_SHA256,
528 TLS1_CK_DHE_RSA_WITH_AES_128_GCM_SHA256,
529 TLS1_CK_DHE_RSA_WITH_AES_256_GCM_SHA384,
530 TLS1_CK_DHE_DSS_WITH_AES_128_GCM_SHA256,
531 TLS1_CK_DHE_DSS_WITH_AES_256_GCM_SHA384,
532 TLS1_CK_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
533 TLS1_CK_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
534 TLS1_CK_ECDHE_RSA_WITH_AES_128_CBC_SHA,
535 TLS1_CK_ECDHE_RSA_WITH_AES_256_CBC_SHA,
536 TLS1_CK_ECDHE_ECDSA_WITH_AES_128_SHA256,
537 TLS1_CK_ECDHE_ECDSA_WITH_AES_256_SHA384,
538 TLS1_CK_ECDH_ECDSA_WITH_AES_128_SHA256,
539 TLS1_CK_ECDH_ECDSA_WITH_AES_256_SHA384,
540 TLS1_CK_ECDHE_RSA_WITH_AES_128_SHA256,
541 TLS1_CK_ECDHE_RSA_WITH_AES_256_SHA384,
542 TLS1_CK_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
543 TLS1_CK_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
544 TLS1_CK_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
545 TLS1_CK_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
// Element count and element size, kept for the qsort/bsearch calls.
547 length_ = sizeof(ciphers_)/sizeof(ciphers_[0]);
548 width_ = sizeof(ciphers_[0]);
549 qsort(ciphers_, length_, width_, compare_ulong);
// True when `cipher`'s id is in the sorted allow-list (presumably true when
// bsearch returns non-null; the return line is not shown).
552 bool SSLContext::SSLFalseStartChecker::canUseFalseStartWithCipher(
553 const SSL_CIPHER *cipher) {
554 unsigned long cid = cipher->id;
556 (unsigned long*)bsearch(&cid, ciphers_, length_, width_, compare_ulong);
// qsort/bsearch comparator for unsigned long values. It presumably returns
// -1 / 1 / 0 for less / greater / equal (return lines not shown).
561 SSLContext::SSLFalseStartChecker::compare_ulong(const void *x, const void *y) {
562 if (*(unsigned long *)x < *(unsigned long *)y) {
565 if (*(unsigned long *)x > *(unsigned long *)y) {
// Context-level convenience that forwards to the checker member.
571 bool SSLContext::canUseFalseStartWithCipher(const SSL_CIPHER *cipher) {
572 return falseStartChecker_.canUseFalseStartWithCipher(cipher);
// Client-side NPN select callback, registered with `this` as `data`.
// - An optional client protocol filter may supply the candidate list and
//   set `filtered`.
// - Otherwise the first advertised list is used, or "" when none exists.
//   Multiple advertised lists are logged, because a client should be
//   deterministic.
// SSL_select_next_proto then picks the protocol from `server`. The function
// returns SSL_TLSEXT_ERR_OK even when nothing is negotiated.
// When False Start support is compiled in, it enables handshake cut-through
// for allow-listed ciphers, presumably only on successful negotiation (the
// branch line is not shown).
// NOTE(review): `ctx` is not null-checked before use. Also, `ssl->s3` reaches
// into SSL internals, which are opaque in newer OpenSSL releases; confirm
// this compiles against the OpenSSL version in use.
576 int SSLContext::selectNextProtocolCallback(
577 SSL* ssl, unsigned char **out, unsigned char *outlen,
578 const unsigned char *server, unsigned int server_len, void *data) {
580 SSLContext* ctx = (SSLContext*)data;
581 if (ctx->advertisedNextProtocols_.size() > 1) {
582 VLOG(3) << "SSLContext::selectNextProcolCallback() "
583 << "client should be deterministic in selecting protocols.";
586 unsigned char *client;
587 unsigned int client_len;
588 bool filtered = false;
589 auto cpf = ctx->getClientProtocolFilterCallback();
591 filtered = (*cpf)(&client, &client_len, server, server_len);
595 if (ctx->advertisedNextProtocols_.empty()) {
596 client = (unsigned char *) "";
599 client = ctx->advertisedNextProtocols_[0].protocols;
600 client_len = ctx->advertisedNextProtocols_[0].length;
604 int retval = SSL_select_next_proto(out, outlen, server, server_len,
606 if (retval != OPENSSL_NPN_NEGOTIATED) {
607 VLOG(3) << "SSLContext::selectNextProcolCallback() "
608 << "unable to pick a next protocol.";
609 #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH) && \
610 FOLLY_SSLCONTEXT_USE_TLS_FALSE_START
612 const SSL_CIPHER *cipher = ssl->s3->tmp.new_cipher;
613 if (cipher && ctx->canUseFalseStartWithCipher(cipher)) {
614 SSL_set_mode(ssl, SSL_MODE_HANDSHAKE_CUTTHROUGH);
618 return SSL_TLSEXT_ERR_OK;
620 #endif // OPENSSL_NPN_NEGOTIATED
// Creates a new SSL connection object from this context.
// @throws std::runtime_error (with OpenSSL's queued errors) if SSL_new fails.
// Presumably returns the new SSL for the caller to own and free (the return
// line is not shown).
622 SSL* SSLContext::createSSL() const {
623 SSL* ssl = SSL_new(ctx_);
624 if (ssl == nullptr) {
625 throw std::runtime_error("SSL_new: " + getErrors());
631 * Match a name with a pattern. The pattern may include wildcard. A single
632 * wildcard "*" can match up to one component in the domain name.
 * Characters are compared case-insensitively (via toupper).
634 * @param host Host name, typically the name of the remote host
635 * @param pattern Name retrieved from certificate
636 * @param size Size of "pattern"
637 * @return True, if "host" matches "pattern". False otherwise.
639 bool SSLContext::matchName(const char* host, const char* pattern, int size) {
// i walks `pattern` (bounded by `size`, so it need not be NUL-terminated);
// j walks the NUL-terminated `host`.
642 while (i < size && host[j] != '\0') {
643 if (toupper(pattern[i]) == toupper(host[j])) {
// A '*' consumes host characters up to, but not including, the next '.',
// so it never spans more than one label.
648 if (pattern[i] == '*') {
649 while (host[j] != '.' && host[j] != '\0') {
// A match requires both strings to be fully consumed.
657 if (i == size && host[j] == '\0') {
// OpenSSL password callback installed by passwordCollector(), with the
// SSLContext as `data`. It asks the collector for a password and copies it
// into OpenSSL's `password` buffer, whose capacity is `size`. A missing
// context or collector is handled on unshown lines.
// NOTE(review): lines 675-677 presumably clamp `length` to `size`. strncpy
// with exactly `length` bytes does not NUL-terminate, so confirm that the
// return value is the copied length.
663 int SSLContext::passwordCallback(char* password,
667 SSLContext* context = (SSLContext*)data;
668 if (context == nullptr || context->passwordCollector() == nullptr) {
671 std::string userPassword;
672 // call user defined password collector to get password
673 context->passwordCollector()->getPassword(userPassword, size);
674 int length = userPassword.size();
678 strncpy(password, userPassword.c_str(), length);
// Lock wrapper (its type header is not shown) whose behavior is chosen by
// `lockType`, which defaults to LOCK_MUTEX: a mutex, a folly::SpinLock, or a
// no-op for LOCK_NONE. The two if-chains below presumably belong to its
// lock() and unlock() methods.
684 SSLContext::SSLLockType inLockType = SSLContext::LOCK_MUTEX) :
685 lockType(inLockType) {
689 if (lockType == SSLContext::LOCK_MUTEX) {
691 } else if (lockType == SSLContext::LOCK_SPINLOCK) {
694 // lockType == LOCK_NONE, no-op
698 if (lockType == SSLContext::LOCK_MUTEX) {
700 } else if (lockType == SSLContext::LOCK_SPINLOCK) {
703 // lockType == LOCK_NONE, no-op
706 SSLContext::SSLLockType lockType;
707 folly::SpinLock spinLock{};
711 // Statics are unsafe in environments that call exit().
712 // If one thread calls exit() while another thread
713 // references a member of SSLContext, bad things can happen.
714 // SSLContext runs in such environments.
715 // Instead of declaring a static member we "new" the static
716 // member so that it won't be destructed on exit().
// Holder for the per-lock-id SSLLock array that initializeOpenSSLLocked()
// allocates. It is intentionally leaked (see above).
717 static std::unique_ptr<SSLLock[]>& locks() {
718 static auto locksInst = new std::unique_ptr<SSLLock[]>();
// Lock id -> lock type map filled by setSSLLockTypes() and applied when the
// lock array is allocated. Also intentionally leaked.
722 static std::map<int, SSLContext::SSLLockType>& lockTypes() {
723 static auto lockTypesInst = new std::map<int, SSLContext::SSLLockType>();
724 return *lockTypesInst;
// OpenSSL static-locking callback, installed by initializeOpenSSLLocked().
// It locks or unlocks lock `n` depending on the CRYPTO_LOCK bit in `mode`,
// presumably using locks()[n] (body lines not shown).
727 static void callbackLocking(int mode, int n, const char*, int) {
728 if (mode & CRYPTO_LOCK) {
// OpenSSL thread-id callback. The line shown uses the Mach thread port,
// presumably inside an Apple-only preprocessor branch that is not shown.
735 static unsigned long callbackThreadID() {
736 return static_cast<unsigned long>(
738 pthread_mach_thread_np(pthread_self())
// Dynamic-lock callbacks: create allocates a new lock, lock/unlock toggles
// its mutex (a null lock is ignored), and destroy presumably deletes it.
745 static CRYPTO_dynlock_value* dyn_create(const char*, int) {
746 return new CRYPTO_dynlock_value;
749 static void dyn_lock(int mode,
750 struct CRYPTO_dynlock_value* lock,
752 if (lock != nullptr) {
753 if (mode & CRYPTO_LOCK) {
756 lock->mutex.unlock();
761 static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) {
765 void SSLContext::setSSLLockTypes(std::map<int, SSLLockType> inLockTypes) {
766 lockTypes() = inLockTypes;
// Marks OpenSSL as already initialized, under initMutex(). This presumably
// sets initialized_ so that initializeOpenSSLLocked() skips its setup (the
// body is not shown; confirm).
769 void SSLContext::markInitialized() {
770 std::lock_guard<std::mutex> g(initMutex());
774 void SSLContext::initializeOpenSSL() {
775 std::lock_guard<std::mutex> g(initMutex());
776 initializeOpenSSLLocked();
// One-time OpenSSL global setup. The caller must hold initMutex(); every
// visible caller does. Lines 780-783 are not shown; they presumably return
// early when already initialized and start library init. It then:
// - loads error strings;
// - allocates one SSLLock per CRYPTO_num_locks() id and applies the types
//   configured with setSSLLockTypes();
// - installs the thread-id, static-locking and dynamic-locking callbacks;
// - allocates the ex_data index used for NPN (when NPN is compiled in).
779 void SSLContext::initializeOpenSSLLocked() {
784 SSL_load_error_strings();
785 ERR_load_crypto_strings();
787 locks().reset(new SSLLock[::CRYPTO_num_locks()]);
// NOTE(review): ids from setSSLLockTypes() are not bounds-checked against
// CRYPTO_num_locks(). A negative or too-large id writes out of bounds here.
788 for (auto it: lockTypes()) {
789 locks()[it.first].lockType = it.second;
791 CRYPTO_set_id_callback(callbackThreadID);
792 CRYPTO_set_locking_callback(callbackLocking);
794 CRYPTO_set_dynlock_create_callback(dyn_create);
795 CRYPTO_set_dynlock_lock_callback(dyn_lock);
796 CRYPTO_set_dynlock_destroy_callback(dyn_destroy);
798 #ifdef OPENSSL_NPN_NEGOTIATED
799 sNextProtocolsExDataIndex_ = SSL_get_ex_new_index(0,
800 (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
805 void SSLContext::cleanupOpenSSL() {
806 std::lock_guard<std::mutex> g(initMutex());
807 cleanupOpenSSLLocked();
// Undoes initializeOpenSSLLocked(). The caller must hold initMutex(). It
// clears the thread-id, locking and dynlock callbacks, frees all ex_data,
// and resets initialized_. Lines 811-814 and 821-824 are not shown; they
// presumably return early when not initialized and perform further library
// cleanup.
// NOTE(review): sNextProtocolsExDataIndex_ is not visibly reset here.
810 void SSLContext::cleanupOpenSSLLocked() {
815 CRYPTO_set_id_callback(nullptr);
816 CRYPTO_set_locking_callback(nullptr);
817 CRYPTO_set_dynlock_create_callback(nullptr);
818 CRYPTO_set_dynlock_lock_callback(nullptr);
819 CRYPTO_set_dynlock_destroy_callback(nullptr);
820 CRYPTO_cleanup_all_ex_data();
825 initialized_ = false;
828 void SSLContext::setOptions(long options) {
829 long newOpt = SSL_CTX_set_options(ctx_, options);
830 if ((newOpt & options) != options) {
831 throw std::runtime_error("SSL_CTX_set_options failed");
// Drains this thread's OpenSSL error queue into one human-readable string,
// joining multiple entries with a separator added on unshown lines.
// Codes without a reason string are rendered as "SSL error # <code>".
// If the queue held nothing, it returns "error code: <errnoCopy>" instead.
835 std::string SSLContext::getErrors(int errnoCopy) {
837 unsigned long errorCode;
841 while ((errorCode = ERR_get_error()) != 0) {
842 if (!errors.empty()) {
845 const char* reason = ERR_reason_error_string(errorCode);
846 if (reason == nullptr) {
847 snprintf(message, sizeof(message) - 1, "SSL error # %lu", errorCode);
852 if (errors.empty()) {
853 errors = "error code: " + folly::to<std::string>(errnoCopy);
// Stream insertion for PasswordCollector: writes collector.describe() to
// `os`. The return-type line and the return statement are not shown.
859 operator<<(std::ostream& os, const PasswordCollector& collector) {
860 os << collector.describe();
// Called with a certificate-verification context. It recovers the SSL
// connection stored on `ctx`, then that connection's socket fd, and fills
// `addrStorage`/`addrLen` with the peer address via getpeername().
// A missing fd and a getpeername() failure are both logged, and presumably
// return false. An address larger than sockaddr_storage CHECK-fails.
// NOTE(review): the recovered `ssl` pointer is not visibly null-checked
// before SSL_get_fd().
864 bool OpenSSLUtils::getPeerAddressFromX509StoreCtx(X509_STORE_CTX* ctx,
865 sockaddr_storage* addrStorage,
866 socklen_t* addrLen) {
867 // Grab the ssl idx and then the ssl object so that we can get the peer
868 // name to compare against the ips in the subjectAltName
869 auto sslIdx = SSL_get_ex_data_X509_STORE_CTX_idx();
871 reinterpret_cast<SSL*>(X509_STORE_CTX_get_ex_data(ctx, sslIdx));
872 int fd = SSL_get_fd(ssl);
874 LOG(ERROR) << "Inexplicably couldn't get fd from SSL";
878 *addrLen = sizeof(*addrStorage);
879 if (getpeername(fd, reinterpret_cast<sockaddr*>(addrStorage), addrLen) != 0) {
880 PLOG(ERROR) << "Unable to get peer name";
883 CHECK(*addrLen <= sizeof(*addrStorage));
// Checks whether any IP-address entry in `cert`'s subjectAltName equals the
// IPv4/IPv6 address in `addr`. Only GEN_IPADD entries are compared; other
// SAN types are ignored.
// - A certificate with no SAN logs a warning and returns early.
// - A non-IP address family is fatal (LOG(FATAL)).
// - When nothing matches, a warning is logged. The return values for match
//   and no-match are on unshown lines.
// NOTE(review): the pop_free of altNames is presumably inside a scope-exit
// guard (line 894 not shown). If it ran inline, the loop below would use
// freed memory. Confirm.
887 bool OpenSSLUtils::validatePeerCertNames(X509* cert,
888 const sockaddr* addr,
890 // Try to extract the names within the SAN extension from the certificate
892 reinterpret_cast<STACK_OF(GENERAL_NAME)*>(
893 X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr));
895 if (altNames != nullptr) {
896 sk_GENERAL_NAME_pop_free(altNames, GENERAL_NAME_free);
899 if (altNames == nullptr) {
900 LOG(WARNING) << "No subjectAltName provided and we only support ip auth";
904 const sockaddr_in* addr4 = nullptr;
905 const sockaddr_in6* addr6 = nullptr;
906 if (addr != nullptr) {
907 if (addr->sa_family == AF_INET) {
908 addr4 = reinterpret_cast<const sockaddr_in*>(addr);
909 } else if (addr->sa_family == AF_INET6) {
910 addr6 = reinterpret_cast<const sockaddr_in6*>(addr);
912 LOG(FATAL) << "Unsupported sockaddr family: " << addr->sa_family;
917 for (int i = 0; i < sk_GENERAL_NAME_num(altNames); i++) {
918 auto name = sk_GENERAL_NAME_value(altNames, i);
919 if ((addr4 != nullptr || addr6 != nullptr) && name->type == GEN_IPADD) {
920 // Extra const-ness for paranoia
921 unsigned char const * const rawIpStr = name->d.iPAddress->data;
922 int const rawIpLen = name->d.iPAddress->length;
// Raw SAN IPs are 4 bytes (IPv4) or 16 bytes (IPv6), compared byte-wise
// against the socket address of the matching family.
924 if (rawIpLen == 4 && addr4 != nullptr) {
925 if (::memcmp(rawIpStr, &addr4->sin_addr, rawIpLen) == 0) {
928 } else if (rawIpLen == 16 && addr6 != nullptr) {
929 if (::memcmp(rawIpStr, &addr6->sin6_addr, rawIpLen) == 0) {
932 } else if (rawIpLen != 4 && rawIpLen != 16) {
933 LOG(WARNING) << "Unexpected IP length: " << rawIpLen;
938 LOG(WARNING) << "Unable to match client cert against alt name ip";