2 * Copyright 2016 Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "SSLContext.h"
19 #include <openssl/err.h>
20 #include <openssl/rand.h>
21 #include <openssl/ssl.h>
22 #include <openssl/x509v3.h>
24 #include <folly/Format.h>
25 #include <folly/Memory.h>
26 #include <folly/SpinLock.h>
28 // ---------------------------------------------------------------------
29 // SSLContext implementation
30 // ---------------------------------------------------------------------
// Concrete definition of the dynamic-lock handle type that OpenSSL only
// forward-declares. dyn_lock() locks and unlocks the `mutex` member.
struct CRYPTO_dynlock_value {
  std::mutex mutex;
};
38 bool SSLContext::initialized_ = false;
namespace {

// Returns the one mutex that serialises the OpenSSL initialise / cleanup /
// markInitialized paths. It is a function-local static, so it is constructed
// on first use.
std::mutex& initMutex() {
  static std::mutex initLock;
  return initLock;
}

} // anonymous namespace
#ifdef OPENSSL_NPN_NEGOTIATED
// ex_data slot index used by advertisedNextProtocolCallback(); stays -1 until
// initializeOpenSSLLocked() obtains a real index from SSL_get_ex_new_index().
int SSLContext::sNextProtocolsExDataIndex_{-1};
#endif
53 // SSLContext implementation
54 SSLContext::SSLContext(SSLVersion version) {
56 std::lock_guard<std::mutex> g(initMutex());
57 initializeOpenSSLLocked();
60 ctx_ = SSL_CTX_new(SSLv23_method());
61 if (ctx_ == nullptr) {
62 throw std::runtime_error("SSL_CTX_new: " + getErrors());
68 opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
71 opt = SSL_OP_NO_SSLv2;
77 int newOpt = SSL_CTX_set_options(ctx_, opt);
78 DCHECK((newOpt & opt) == opt);
80 SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
82 checkPeerName_ = false;
84 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
85 SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
86 SSL_CTX_set_tlsext_servername_arg(ctx_, this);
89 #ifdef OPENSSL_NPN_NEGOTIATED
90 Random::seed(nextProtocolPicker_);
94 SSLContext::~SSLContext() {
95 if (ctx_ != nullptr) {
100 #ifdef OPENSSL_NPN_NEGOTIATED
101 deleteNextProtocolsStrings();
105 void SSLContext::ciphers(const std::string& ciphers) {
106 providedCiphersString_ = ciphers;
107 setCiphersOrThrow(ciphers);
110 void SSLContext::setCiphersOrThrow(const std::string& ciphers) {
111 int rc = SSL_CTX_set_cipher_list(ctx_, ciphers.c_str());
112 if (ERR_peek_error() != 0) {
113 throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors());
116 throw std::runtime_error("None of specified ciphers are supported");
120 void SSLContext::setVerificationOption(const SSLContext::SSLVerifyPeerEnum&
122 CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX); // dont recurse
123 verifyPeer_ = verifyPeer;
126 int SSLContext::getVerificationMode(const SSLContext::SSLVerifyPeerEnum&
128 CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX);
129 int mode = SSL_VERIFY_NONE;
131 // case SSLVerifyPeerEnum::USE_CTX: // can't happen
134 case SSLVerifyPeerEnum::VERIFY:
135 mode = SSL_VERIFY_PEER;
138 case SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT:
139 mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
142 case SSLVerifyPeerEnum::NO_VERIFY:
143 mode = SSL_VERIFY_NONE;
152 int SSLContext::getVerificationMode() {
153 return getVerificationMode(verifyPeer_);
156 void SSLContext::authenticate(bool checkPeerCert, bool checkPeerName,
157 const std::string& peerName) {
160 mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE;
161 checkPeerName_ = checkPeerName;
162 peerFixedName_ = peerName;
164 mode = SSL_VERIFY_NONE;
165 checkPeerName_ = false; // can't check name without cert!
166 peerFixedName_.clear();
168 SSL_CTX_set_verify(ctx_, mode, nullptr);
171 void SSLContext::loadCertificate(const char* path, const char* format) {
172 if (path == nullptr || format == nullptr) {
173 throw std::invalid_argument(
174 "loadCertificateChain: either <path> or <format> is nullptr");
176 if (strcmp(format, "PEM") == 0) {
177 if (SSL_CTX_use_certificate_chain_file(ctx_, path) == 0) {
178 int errnoCopy = errno;
179 std::string reason("SSL_CTX_use_certificate_chain_file: ");
182 reason.append(getErrors(errnoCopy));
183 throw std::runtime_error(reason);
186 throw std::runtime_error("Unsupported certificate format: " + std::string(format));
190 void SSLContext::loadCertificateFromBufferPEM(folly::StringPiece cert) {
191 if (cert.data() == nullptr) {
192 throw std::invalid_argument("loadCertificate: <cert> is nullptr");
195 ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
196 if (bio == nullptr) {
197 throw std::runtime_error("BIO_new: " + getErrors());
200 int written = BIO_write(bio.get(), cert.data(), cert.size());
201 if (written <= 0 || static_cast<unsigned>(written) != cert.size()) {
202 throw std::runtime_error("BIO_write: " + getErrors());
205 ssl::X509UniquePtr x509(
206 PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
207 if (x509 == nullptr) {
208 throw std::runtime_error("PEM_read_bio_X509: " + getErrors());
211 if (SSL_CTX_use_certificate(ctx_, x509.get()) == 0) {
212 throw std::runtime_error("SSL_CTX_use_certificate: " + getErrors());
216 void SSLContext::loadPrivateKey(const char* path, const char* format) {
217 if (path == nullptr || format == nullptr) {
218 throw std::invalid_argument(
219 "loadPrivateKey: either <path> or <format> is nullptr");
221 if (strcmp(format, "PEM") == 0) {
222 if (SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM) == 0) {
223 throw std::runtime_error("SSL_CTX_use_PrivateKey_file: " + getErrors());
226 throw std::runtime_error("Unsupported private key format: " + std::string(format));
230 void SSLContext::loadPrivateKeyFromBufferPEM(folly::StringPiece pkey) {
231 if (pkey.data() == nullptr) {
232 throw std::invalid_argument("loadPrivateKey: <pkey> is nullptr");
235 ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
236 if (bio == nullptr) {
237 throw std::runtime_error("BIO_new: " + getErrors());
240 int written = BIO_write(bio.get(), pkey.data(), pkey.size());
241 if (written <= 0 || static_cast<unsigned>(written) != pkey.size()) {
242 throw std::runtime_error("BIO_write: " + getErrors());
245 ssl::EvpPkeyUniquePtr key(
246 PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr));
247 if (key == nullptr) {
248 throw std::runtime_error("PEM_read_bio_PrivateKey: " + getErrors());
251 if (SSL_CTX_use_PrivateKey(ctx_, key.get()) == 0) {
252 throw std::runtime_error("SSL_CTX_use_PrivateKey: " + getErrors());
256 void SSLContext::loadTrustedCertificates(const char* path) {
257 if (path == nullptr) {
258 throw std::invalid_argument("loadTrustedCertificates: <path> is nullptr");
260 if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) {
261 throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors());
265 void SSLContext::loadTrustedCertificates(X509_STORE* store) {
266 SSL_CTX_set_cert_store(ctx_, store);
269 void SSLContext::loadClientCAList(const char* path) {
270 auto clientCAs = SSL_load_client_CA_file(path);
271 if (clientCAs == nullptr) {
272 LOG(ERROR) << "Unable to load ca file: " << path;
275 SSL_CTX_set_client_CA_list(ctx_, clientCAs);
278 void SSLContext::randomize() {
282 void SSLContext::passwordCollector(std::shared_ptr<PasswordCollector> collector) {
283 if (collector == nullptr) {
284 LOG(ERROR) << "passwordCollector: ignore invalid password collector";
287 collector_ = collector;
288 SSL_CTX_set_default_passwd_cb(ctx_, passwordCallback);
289 SSL_CTX_set_default_passwd_cb_userdata(ctx_, this);
292 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
294 void SSLContext::setServerNameCallback(const ServerNameCallback& cb) {
298 void SSLContext::addClientHelloCallback(const ClientHelloCallback& cb) {
299 clientHelloCbs_.push_back(cb);
302 int SSLContext::baseServerNameOpenSSLCallback(SSL* ssl, int* al, void* data) {
303 SSLContext* context = (SSLContext*)data;
305 if (context == nullptr) {
306 return SSL_TLSEXT_ERR_NOACK;
309 for (auto& cb : context->clientHelloCbs_) {
310 // Generic callbacks to happen after we receive the Client Hello.
311 // For example, we use one to switch which cipher we use depending
312 // on the user's TLS version. Because the primary purpose of
313 // baseServerNameOpenSSLCallback is for SNI support, and these callbacks
314 // are side-uses, we ignore any possible failures other than just logging
319 if (!context->serverNameCb_) {
320 return SSL_TLSEXT_ERR_NOACK;
323 ServerNameCallbackResult ret = context->serverNameCb_(ssl);
325 case SERVER_NAME_FOUND:
326 return SSL_TLSEXT_ERR_OK;
327 case SERVER_NAME_NOT_FOUND:
328 return SSL_TLSEXT_ERR_NOACK;
329 case SERVER_NAME_NOT_FOUND_ALERT_FATAL:
330 *al = TLS1_AD_UNRECOGNIZED_NAME;
331 return SSL_TLSEXT_ERR_ALERT_FATAL;
336 return SSL_TLSEXT_ERR_NOACK;
339 void SSLContext::switchCiphersIfTLS11(
341 const std::string& tls11CipherString) {
343 CHECK(!tls11CipherString.empty()) << "Shouldn't call if empty alt ciphers";
345 if (TLS1_get_client_version(ssl) <= TLS1_VERSION) {
346 // We only do this for TLS v 1.1 and later
350 // Prefer AES for TLS versions 1.1 and later since these are not
351 // vulnerable to BEAST attacks on AES. Note that we're setting the
352 // cipher list on the SSL object, not the SSL_CTX object, so it will
353 // only last for this request.
354 int rc = SSL_set_cipher_list(ssl, tls11CipherString.c_str());
355 if ((rc == 0) || ERR_peek_error() != 0) {
356 // This shouldn't happen since we checked for this when proxygen
358 LOG(WARNING) << "ssl_cipher: No specified ciphers supported for switch";
359 SSL_set_cipher_list(ssl, providedCiphersString_.c_str());
#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
/**
 * ALPN selection callback installed by setRandomizedAdvertisedNextProtocols().
 *
 * Picks one advertised protocol list and lets OpenSSL choose a protocol
 * that is also present in the peer's list (`in`/`inlen`).
 *
 * @return SSL_TLSEXT_ERR_NOACK when no common protocol is negotiated,
 *         SSL_TLSEXT_ERR_OK otherwise (including when nothing is advertised,
 *         in which case *out is nullptr and *outlen is 0).
 */
int SSLContext::alpnSelectCallback(SSL* /* ssl */,
                                   const unsigned char** out,
                                   unsigned char* outlen,
                                   const unsigned char* in,
                                   unsigned int inlen,
                                   void* data) {
  auto* context = static_cast<SSLContext*>(data);
  CHECK(context);
  if (context->advertisedNextProtocols_.empty()) {
    *out = nullptr;
    *outlen = 0;
    return SSL_TLSEXT_ERR_OK;
  }

  const auto chosen = context->pickNextProtocols();
  const auto& item = context->advertisedNextProtocols_[chosen];
  const int negotiated = SSL_select_next_proto(
      const_cast<unsigned char**>(out),
      outlen,
      item.protocols,
      item.length,
      in,
      inlen);
  if (negotiated != OPENSSL_NPN_NEGOTIATED) {
    return SSL_TLSEXT_ERR_NOACK;
  }
  return SSL_TLSEXT_ERR_OK;
}
#endif
392 #ifdef OPENSSL_NPN_NEGOTIATED
394 bool SSLContext::setAdvertisedNextProtocols(
395 const std::list<std::string>& protocols, NextProtocolType protocolType) {
396 return setRandomizedAdvertisedNextProtocols({{1, protocols}}, protocolType);
399 bool SSLContext::setRandomizedAdvertisedNextProtocols(
400 const std::list<NextProtocolsItem>& items, NextProtocolType protocolType) {
401 unsetNextProtocols();
402 if (items.size() == 0) {
405 int total_weight = 0;
406 for (const auto &item : items) {
407 if (item.protocols.size() == 0) {
410 AdvertisedNextProtocolsItem advertised_item;
411 advertised_item.length = 0;
412 for (const auto& proto : item.protocols) {
413 ++advertised_item.length;
414 unsigned protoLength = proto.length();
415 if (protoLength >= 256) {
416 deleteNextProtocolsStrings();
419 advertised_item.length += protoLength;
421 advertised_item.protocols = new unsigned char[advertised_item.length];
422 if (!advertised_item.protocols) {
423 throw std::runtime_error("alloc failure");
425 unsigned char* dst = advertised_item.protocols;
426 for (auto& proto : item.protocols) {
427 unsigned protoLength = proto.length();
428 *dst++ = (unsigned char)protoLength;
429 memcpy(dst, proto.data(), protoLength);
432 total_weight += item.weight;
433 advertisedNextProtocols_.push_back(advertised_item);
434 advertisedNextProtocolWeights_.push_back(item.weight);
436 if (total_weight == 0) {
437 deleteNextProtocolsStrings();
440 nextProtocolDistribution_ =
441 std::discrete_distribution<>(advertisedNextProtocolWeights_.begin(),
442 advertisedNextProtocolWeights_.end());
443 if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) {
444 SSL_CTX_set_next_protos_advertised_cb(
445 ctx_, advertisedNextProtocolCallback, this);
446 SSL_CTX_set_next_proto_select_cb(ctx_, selectNextProtocolCallback, this);
448 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
449 if ((uint8_t)protocolType & (uint8_t)NextProtocolType::ALPN) {
450 SSL_CTX_set_alpn_select_cb(ctx_, alpnSelectCallback, this);
451 // Client cannot really use randomized alpn
452 SSL_CTX_set_alpn_protos(ctx_,
453 advertisedNextProtocols_[0].protocols,
454 advertisedNextProtocols_[0].length);
460 void SSLContext::deleteNextProtocolsStrings() {
461 for (auto protocols : advertisedNextProtocols_) {
462 delete[] protocols.protocols;
464 advertisedNextProtocols_.clear();
465 advertisedNextProtocolWeights_.clear();
468 void SSLContext::unsetNextProtocols() {
469 deleteNextProtocolsStrings();
470 SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr);
471 SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr);
472 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
473 SSL_CTX_set_alpn_select_cb(ctx_, nullptr, nullptr);
474 SSL_CTX_set_alpn_protos(ctx_, nullptr, 0);
478 size_t SSLContext::pickNextProtocols() {
479 CHECK(!advertisedNextProtocols_.empty()) << "Failed to pickNextProtocols";
480 return nextProtocolDistribution_(nextProtocolPicker_);
483 int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
484 const unsigned char** out, unsigned int* outlen, void* data) {
485 SSLContext* context = (SSLContext*)data;
486 if (context == nullptr || context->advertisedNextProtocols_.empty()) {
489 } else if (context->advertisedNextProtocols_.size() == 1) {
490 *out = context->advertisedNextProtocols_[0].protocols;
491 *outlen = context->advertisedNextProtocols_[0].length;
493 uintptr_t selected_index = reinterpret_cast<uintptr_t>(SSL_get_ex_data(ssl,
494 sNextProtocolsExDataIndex_));
495 if (selected_index) {
497 *out = context->advertisedNextProtocols_[selected_index].protocols;
498 *outlen = context->advertisedNextProtocols_[selected_index].length;
500 auto i = context->pickNextProtocols();
501 uintptr_t selected = i + 1;
502 SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void*)selected);
503 *out = context->advertisedNextProtocols_[i].protocols;
504 *outlen = context->advertisedNextProtocols_[i].length;
507 return SSL_TLSEXT_ERR_OK;
510 int SSLContext::selectNextProtocolCallback(SSL* ssl,
512 unsigned char* outlen,
513 const unsigned char* server,
514 unsigned int server_len,
516 (void)ssl; // Make -Wunused-parameters happy
517 SSLContext* ctx = (SSLContext*)data;
518 if (ctx->advertisedNextProtocols_.size() > 1) {
519 VLOG(3) << "SSLContext::selectNextProcolCallback() "
520 << "client should be deterministic in selecting protocols.";
523 unsigned char *client;
524 unsigned int client_len;
525 bool filtered = false;
526 auto cpf = ctx->getClientProtocolFilterCallback();
528 filtered = (*cpf)(&client, &client_len, server, server_len);
532 if (ctx->advertisedNextProtocols_.empty()) {
533 client = (unsigned char *) "";
536 client = ctx->advertisedNextProtocols_[0].protocols;
537 client_len = ctx->advertisedNextProtocols_[0].length;
541 int retval = SSL_select_next_proto(out, outlen, server, server_len,
543 if (retval != OPENSSL_NPN_NEGOTIATED) {
544 VLOG(3) << "SSLContext::selectNextProcolCallback() "
545 << "unable to pick a next protocol.";
547 return SSL_TLSEXT_ERR_OK;
549 #endif // OPENSSL_NPN_NEGOTIATED
551 SSL* SSLContext::createSSL() const {
552 SSL* ssl = SSL_new(ctx_);
553 if (ssl == nullptr) {
554 throw std::runtime_error("SSL_new: " + getErrors());
560 * Match a name with a pattern. The pattern may include wildcard. A single
561 * wildcard "*" can match up to one component in the domain name.
563 * @param host Host name, typically the name of the remote host
564 * @param pattern Name retrieved from certificate
565 * @param size Size of "pattern"
566 * @return True, if "host" matches "pattern". False otherwise.
568 bool SSLContext::matchName(const char* host, const char* pattern, int size) {
571 while (i < size && host[j] != '\0') {
572 if (toupper(pattern[i]) == toupper(host[j])) {
577 if (pattern[i] == '*') {
578 while (host[j] != '.' && host[j] != '\0') {
586 if (i == size && host[j] == '\0') {
592 int SSLContext::passwordCallback(char* password,
596 SSLContext* context = (SSLContext*)data;
597 if (context == nullptr || context->passwordCollector() == nullptr) {
600 std::string userPassword;
601 // call user defined password collector to get password
602 context->passwordCollector()->getPassword(userPassword, size);
603 int length = userPassword.size();
607 strncpy(password, userPassword.c_str(), length);
613 SSLContext::SSLLockType inLockType = SSLContext::LOCK_MUTEX) :
614 lockType(inLockType) {
618 if (lockType == SSLContext::LOCK_MUTEX) {
620 } else if (lockType == SSLContext::LOCK_SPINLOCK) {
623 // lockType == LOCK_NONE, no-op
627 if (lockType == SSLContext::LOCK_MUTEX) {
629 } else if (lockType == SSLContext::LOCK_SPINLOCK) {
632 // lockType == LOCK_NONE, no-op
635 SSLContext::SSLLockType lockType;
636 folly::SpinLock spinLock{};
640 // Statics are unsafe in environments that call exit().
641 // If one thread calls exit() while another thread is
642 // references a member of SSLContext, bad things can happen.
643 // SSLContext runs in such environments.
644 // Instead of declaring a static member we "new" the static
645 // member so that it won't be destructed on exit().
646 static std::unique_ptr<SSLLock[]>& locks() {
647 static auto locksInst = new std::unique_ptr<SSLLock[]>();
651 static std::map<int, SSLContext::SSLLockType>& lockTypes() {
652 static auto lockTypesInst = new std::map<int, SSLContext::SSLLockType>();
653 return *lockTypesInst;
656 static void callbackLocking(int mode, int n, const char*, int) {
657 if (mode & CRYPTO_LOCK) {
// Thread-id callback for OpenSSL: returns an integer identifying the calling
// thread.
static unsigned long callbackThreadID() {
#ifdef __APPLE__
  return static_cast<unsigned long>(pthread_mach_thread_np(pthread_self()));
#else
  return static_cast<unsigned long>(pthread_self());
#endif
}
674 static CRYPTO_dynlock_value* dyn_create(const char*, int) {
675 return new CRYPTO_dynlock_value;
678 static void dyn_lock(int mode,
679 struct CRYPTO_dynlock_value* lock,
681 if (lock != nullptr) {
682 if (mode & CRYPTO_LOCK) {
685 lock->mutex.unlock();
// Dynamic-lock "destroy" callback for OpenSSL: frees a handle created by
// dyn_create().
static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) {
  delete lock;
}
694 void SSLContext::setSSLLockTypes(std::map<int, SSLLockType> inLockTypes) {
695 lockTypes() = inLockTypes;
#if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH)
/**
 * Turns on SSL_MODE_HANDSHAKE_CUTTHROUGH for this context. Only compiled
 * when the OpenSSL build defines that mode.
 */
void SSLContext::enableFalseStart() {
  SSL_CTX_set_mode(ctx_, SSL_MODE_HANDSHAKE_CUTTHROUGH);
}
#endif
704 void SSLContext::markInitialized() {
705 std::lock_guard<std::mutex> g(initMutex());
709 void SSLContext::initializeOpenSSL() {
710 std::lock_guard<std::mutex> g(initMutex());
711 initializeOpenSSLLocked();
714 void SSLContext::initializeOpenSSLLocked() {
719 SSL_load_error_strings();
720 ERR_load_crypto_strings();
722 locks().reset(new SSLLock[::CRYPTO_num_locks()]);
723 for (auto it: lockTypes()) {
724 locks()[it.first].lockType = it.second;
726 CRYPTO_set_id_callback(callbackThreadID);
727 CRYPTO_set_locking_callback(callbackLocking);
729 CRYPTO_set_dynlock_create_callback(dyn_create);
730 CRYPTO_set_dynlock_lock_callback(dyn_lock);
731 CRYPTO_set_dynlock_destroy_callback(dyn_destroy);
733 #ifdef OPENSSL_NPN_NEGOTIATED
734 sNextProtocolsExDataIndex_ = SSL_get_ex_new_index(0,
735 (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
740 void SSLContext::cleanupOpenSSL() {
741 std::lock_guard<std::mutex> g(initMutex());
742 cleanupOpenSSLLocked();
745 void SSLContext::cleanupOpenSSLLocked() {
750 CRYPTO_set_id_callback(nullptr);
751 CRYPTO_set_locking_callback(nullptr);
752 CRYPTO_set_dynlock_create_callback(nullptr);
753 CRYPTO_set_dynlock_lock_callback(nullptr);
754 CRYPTO_set_dynlock_destroy_callback(nullptr);
755 CRYPTO_cleanup_all_ex_data();
760 initialized_ = false;
763 void SSLContext::setOptions(long options) {
764 long newOpt = SSL_CTX_set_options(ctx_, options);
765 if ((newOpt & options) != options) {
766 throw std::runtime_error("SSL_CTX_set_options failed");
770 std::string SSLContext::getErrors(int errnoCopy) {
772 unsigned long errorCode;
776 while ((errorCode = ERR_get_error()) != 0) {
777 if (!errors.empty()) {
780 const char* reason = ERR_reason_error_string(errorCode);
781 if (reason == nullptr) {
782 snprintf(message, sizeof(message) - 1, "SSL error # %lu", errorCode);
787 if (errors.empty()) {
788 errors = "error code: " + folly::to<std::string>(errnoCopy);
794 operator<<(std::ostream& os, const PasswordCollector& collector) {
795 os << collector.describe();