2 * Copyright 2017 Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "SSLContext.h"
19 #include <folly/Format.h>
20 #include <folly/Memory.h>
21 #include <folly/Random.h>
22 #include <folly/SharedMutex.h>
23 #include <folly/SpinLock.h>
24 #include <folly/ThreadId.h>
25 #include <folly/ssl/Init.h>
27 // ---------------------------------------------------------------------
28 // SSLContext implementation
29 // ---------------------------------------------------------------------
32 // For OpenSSL portability API
33 using namespace folly::ssl;
35 // SSLContext implementation
36 SSLContext::SSLContext(SSLVersion version) {
39 ctx_ = SSL_CTX_new(SSLv23_method());
40 if (ctx_ == nullptr) {
41 throw std::runtime_error("SSL_CTX_new: " + getErrors());
47 opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
50 opt = SSL_OP_NO_SSLv2;
56 int newOpt = SSL_CTX_set_options(ctx_, opt);
57 DCHECK((newOpt & opt) == opt);
59 SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
61 checkPeerName_ = false;
63 SSL_CTX_set_options(ctx_, SSL_OP_NO_COMPRESSION);
65 #if FOLLY_OPENSSL_HAS_SNI
66 SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
67 SSL_CTX_set_tlsext_servername_arg(ctx_, this);
71 SSLContext::~SSLContext() {
72 if (ctx_ != nullptr) {
77 #ifdef OPENSSL_NPN_NEGOTIATED
78 deleteNextProtocolsStrings();
82 void SSLContext::ciphers(const std::string& ciphers) {
83 setCiphersOrThrow(ciphers);
86 void SSLContext::setCipherList(const std::vector<std::string>& ciphers) {
87 if (ciphers.size() == 0) {
90 std::string opensslCipherList;
91 join(":", ciphers, opensslCipherList);
92 setCiphersOrThrow(opensslCipherList);
95 void SSLContext::setSignatureAlgorithms(
96 const std::vector<std::string>& sigalgs) {
97 if (sigalgs.size() == 0) {
100 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL
101 std::string opensslSigAlgsList;
102 join(":", sigalgs, opensslSigAlgsList);
103 int rc = SSL_CTX_set1_sigalgs_list(ctx_, opensslSigAlgsList.c_str());
105 throw std::runtime_error("SSL_CTX_set1_sigalgs_list " + getErrors());
110 void SSLContext::setClientECCurvesList(
111 const std::vector<std::string>& ecCurves) {
112 if (ecCurves.size() == 0) {
115 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL
116 std::string ecCurvesList;
117 join(":", ecCurves, ecCurvesList);
118 int rc = SSL_CTX_set1_curves_list(ctx_, ecCurvesList.c_str());
120 throw std::runtime_error("SSL_CTX_set1_curves_list " + getErrors());
125 void SSLContext::setServerECCurve(const std::string& curveName) {
126 #if OPENSSL_VERSION_NUMBER >= 0x0090800fL && !defined(OPENSSL_NO_ECDH)
127 EC_KEY* ecdh = nullptr;
131 * Elliptic-Curve Diffie-Hellman parameters are either "named curves"
132 * from RFC 4492 section 5.1.1, or explicitly described curves over
133 * binary fields. OpenSSL only supports the "named curves", which provide
134 * maximum interoperability.
137 nid = OBJ_sn2nid(curveName.c_str());
139 LOG(FATAL) << "Unknown curve name:" << curveName.c_str();
141 ecdh = EC_KEY_new_by_curve_name(nid);
142 if (ecdh == nullptr) {
143 LOG(FATAL) << "Unable to create curve:" << curveName.c_str();
146 SSL_CTX_set_tmp_ecdh(ctx_, ecdh);
149 throw std::runtime_error("Elliptic curve encryption not allowed");
153 void SSLContext::setX509VerifyParam(
154 const ssl::X509VerifyParam& x509VerifyParam) {
155 if (!x509VerifyParam) {
158 if (SSL_CTX_set1_param(ctx_, x509VerifyParam.get()) != 1) {
159 throw std::runtime_error("SSL_CTX_set1_param " + getErrors());
163 void SSLContext::setCiphersOrThrow(const std::string& ciphers) {
164 int rc = SSL_CTX_set_cipher_list(ctx_, ciphers.c_str());
166 throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors());
168 providedCiphersString_ = ciphers;
171 void SSLContext::setVerificationOption(const SSLContext::SSLVerifyPeerEnum&
173 CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX); // dont recurse
174 verifyPeer_ = verifyPeer;
177 int SSLContext::getVerificationMode(const SSLContext::SSLVerifyPeerEnum&
179 CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX);
180 int mode = SSL_VERIFY_NONE;
182 // case SSLVerifyPeerEnum::USE_CTX: // can't happen
185 case SSLVerifyPeerEnum::VERIFY:
186 mode = SSL_VERIFY_PEER;
189 case SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT:
190 mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
193 case SSLVerifyPeerEnum::NO_VERIFY:
194 mode = SSL_VERIFY_NONE;
203 int SSLContext::getVerificationMode() {
204 return getVerificationMode(verifyPeer_);
207 void SSLContext::authenticate(bool checkPeerCert, bool checkPeerName,
208 const std::string& peerName) {
211 mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT |
212 SSL_VERIFY_CLIENT_ONCE;
213 checkPeerName_ = checkPeerName;
214 peerFixedName_ = peerName;
216 mode = SSL_VERIFY_NONE;
217 checkPeerName_ = false; // can't check name without cert!
218 peerFixedName_.clear();
220 SSL_CTX_set_verify(ctx_, mode, nullptr);
223 void SSLContext::loadCertificate(const char* path, const char* format) {
224 if (path == nullptr || format == nullptr) {
225 throw std::invalid_argument(
226 "loadCertificateChain: either <path> or <format> is nullptr");
228 if (strcmp(format, "PEM") == 0) {
229 if (SSL_CTX_use_certificate_chain_file(ctx_, path) == 0) {
230 int errnoCopy = errno;
231 std::string reason("SSL_CTX_use_certificate_chain_file: ");
234 reason.append(getErrors(errnoCopy));
235 throw std::runtime_error(reason);
238 throw std::runtime_error(
239 "Unsupported certificate format: " + std::string(format));
243 void SSLContext::loadCertificateFromBufferPEM(folly::StringPiece cert) {
244 if (cert.data() == nullptr) {
245 throw std::invalid_argument("loadCertificate: <cert> is nullptr");
248 ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
249 if (bio == nullptr) {
250 throw std::runtime_error("BIO_new: " + getErrors());
253 int written = BIO_write(bio.get(), cert.data(), int(cert.size()));
254 if (written <= 0 || static_cast<unsigned>(written) != cert.size()) {
255 throw std::runtime_error("BIO_write: " + getErrors());
258 ssl::X509UniquePtr x509(
259 PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
260 if (x509 == nullptr) {
261 throw std::runtime_error("PEM_read_bio_X509: " + getErrors());
264 if (SSL_CTX_use_certificate(ctx_, x509.get()) == 0) {
265 throw std::runtime_error("SSL_CTX_use_certificate: " + getErrors());
269 void SSLContext::loadPrivateKey(const char* path, const char* format) {
270 if (path == nullptr || format == nullptr) {
271 throw std::invalid_argument(
272 "loadPrivateKey: either <path> or <format> is nullptr");
274 if (strcmp(format, "PEM") == 0) {
275 if (SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM) == 0) {
276 throw std::runtime_error("SSL_CTX_use_PrivateKey_file: " + getErrors());
279 throw std::runtime_error(
280 "Unsupported private key format: " + std::string(format));
284 void SSLContext::loadPrivateKeyFromBufferPEM(folly::StringPiece pkey) {
285 if (pkey.data() == nullptr) {
286 throw std::invalid_argument("loadPrivateKey: <pkey> is nullptr");
289 ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
290 if (bio == nullptr) {
291 throw std::runtime_error("BIO_new: " + getErrors());
294 int written = BIO_write(bio.get(), pkey.data(), int(pkey.size()));
295 if (written <= 0 || static_cast<unsigned>(written) != pkey.size()) {
296 throw std::runtime_error("BIO_write: " + getErrors());
299 ssl::EvpPkeyUniquePtr key(
300 PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr));
301 if (key == nullptr) {
302 throw std::runtime_error("PEM_read_bio_PrivateKey: " + getErrors());
305 if (SSL_CTX_use_PrivateKey(ctx_, key.get()) == 0) {
306 throw std::runtime_error("SSL_CTX_use_PrivateKey: " + getErrors());
310 void SSLContext::loadTrustedCertificates(const char* path) {
311 if (path == nullptr) {
312 throw std::invalid_argument("loadTrustedCertificates: <path> is nullptr");
314 if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) {
315 throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors());
320 void SSLContext::loadTrustedCertificates(X509_STORE* store) {
321 SSL_CTX_set_cert_store(ctx_, store);
324 void SSLContext::loadClientCAList(const char* path) {
325 auto clientCAs = SSL_load_client_CA_file(path);
326 if (clientCAs == nullptr) {
327 LOG(ERROR) << "Unable to load ca file: " << path;
330 SSL_CTX_set_client_CA_list(ctx_, clientCAs);
333 void SSLContext::passwordCollector(
334 std::shared_ptr<PasswordCollector> collector) {
335 if (collector == nullptr) {
336 LOG(ERROR) << "passwordCollector: ignore invalid password collector";
339 collector_ = collector;
340 SSL_CTX_set_default_passwd_cb(ctx_, passwordCallback);
341 SSL_CTX_set_default_passwd_cb_userdata(ctx_, this);
344 #if FOLLY_OPENSSL_HAS_SNI
346 void SSLContext::setServerNameCallback(const ServerNameCallback& cb) {
350 void SSLContext::addClientHelloCallback(const ClientHelloCallback& cb) {
351 clientHelloCbs_.push_back(cb);
354 int SSLContext::baseServerNameOpenSSLCallback(SSL* ssl, int* al, void* data) {
355 SSLContext* context = (SSLContext*)data;
357 if (context == nullptr) {
358 return SSL_TLSEXT_ERR_NOACK;
361 for (auto& cb : context->clientHelloCbs_) {
362 // Generic callbacks to happen after we receive the Client Hello.
363 // For example, we use one to switch which cipher we use depending
364 // on the user's TLS version. Because the primary purpose of
365 // baseServerNameOpenSSLCallback is for SNI support, and these callbacks
366 // are side-uses, we ignore any possible failures other than just logging
371 if (!context->serverNameCb_) {
372 return SSL_TLSEXT_ERR_NOACK;
375 ServerNameCallbackResult ret = context->serverNameCb_(ssl);
377 case SERVER_NAME_FOUND:
378 return SSL_TLSEXT_ERR_OK;
379 case SERVER_NAME_NOT_FOUND:
380 return SSL_TLSEXT_ERR_NOACK;
381 case SERVER_NAME_NOT_FOUND_ALERT_FATAL:
382 *al = TLS1_AD_UNRECOGNIZED_NAME;
383 return SSL_TLSEXT_ERR_ALERT_FATAL;
388 return SSL_TLSEXT_ERR_NOACK;
391 void SSLContext::switchCiphersIfTLS11(
393 const std::string& tls11CipherString,
394 const std::vector<std::pair<std::string, int>>& tls11AltCipherlist) {
395 CHECK(!(tls11CipherString.empty() && tls11AltCipherlist.empty()))
396 << "Shouldn't call if empty ciphers / alt ciphers";
398 if (TLS1_get_client_version(ssl) <= TLS1_VERSION) {
399 // We only do this for TLS v 1.1 and later
403 const std::string* ciphers = &tls11CipherString;
404 if (!tls11AltCipherlist.empty()) {
405 if (!cipherListPicker_) {
406 std::vector<int> weights;
408 tls11AltCipherlist.begin(),
409 tls11AltCipherlist.end(),
410 [&](const std::pair<std::string, int>& e) {
411 weights.push_back(e.second);
413 cipherListPicker_.reset(
414 new std::discrete_distribution<int>(weights.begin(), weights.end()));
416 auto rng = ThreadLocalPRNG();
417 auto index = (*cipherListPicker_)(rng);
418 if ((size_t)index >= tls11AltCipherlist.size()) {
419 LOG(ERROR) << "Trying to pick alt TLS11 cipher index " << index
420 << ", but tls11AltCipherlist is of length "
421 << tls11AltCipherlist.size();
423 ciphers = &tls11AltCipherlist[size_t(index)].first;
427 // Prefer AES for TLS versions 1.1 and later since these are not
428 // vulnerable to BEAST attacks on AES. Note that we're setting the
429 // cipher list on the SSL object, not the SSL_CTX object, so it will
430 // only last for this request.
431 int rc = SSL_set_cipher_list(ssl, ciphers->c_str());
432 if ((rc == 0) || ERR_peek_error() != 0) {
433 // This shouldn't happen since we checked for this when proxygen
435 LOG(WARNING) << "ssl_cipher: No specified ciphers supported for switch";
436 SSL_set_cipher_list(ssl, providedCiphersString_.c_str());
439 #endif // FOLLY_OPENSSL_HAS_SNI
441 #if FOLLY_OPENSSL_HAS_ALPN
442 int SSLContext::alpnSelectCallback(SSL* /* ssl */,
443 const unsigned char** out,
444 unsigned char* outlen,
445 const unsigned char* in,
448 SSLContext* context = (SSLContext*)data;
450 if (context->advertisedNextProtocols_.empty()) {
454 auto i = context->pickNextProtocols();
455 const auto& item = context->advertisedNextProtocols_[i];
456 if (SSL_select_next_proto((unsigned char**)out,
461 inlen) != OPENSSL_NPN_NEGOTIATED) {
462 return SSL_TLSEXT_ERR_NOACK;
465 return SSL_TLSEXT_ERR_OK;
467 #endif // FOLLY_OPENSSL_HAS_ALPN
469 #ifdef OPENSSL_NPN_NEGOTIATED
471 bool SSLContext::setAdvertisedNextProtocols(
472 const std::list<std::string>& protocols, NextProtocolType protocolType) {
473 return setRandomizedAdvertisedNextProtocols({{1, protocols}}, protocolType);
476 bool SSLContext::setRandomizedAdvertisedNextProtocols(
477 const std::list<NextProtocolsItem>& items, NextProtocolType protocolType) {
478 unsetNextProtocols();
479 if (items.size() == 0) {
482 int total_weight = 0;
483 for (const auto &item : items) {
484 if (item.protocols.size() == 0) {
487 AdvertisedNextProtocolsItem advertised_item;
488 advertised_item.length = 0;
489 for (const auto& proto : item.protocols) {
490 ++advertised_item.length;
491 auto protoLength = proto.length();
492 if (protoLength >= 256) {
493 deleteNextProtocolsStrings();
496 advertised_item.length += unsigned(protoLength);
498 advertised_item.protocols = new unsigned char[advertised_item.length];
499 if (!advertised_item.protocols) {
500 throw std::runtime_error("alloc failure");
502 unsigned char* dst = advertised_item.protocols;
503 for (auto& proto : item.protocols) {
504 uint8_t protoLength = uint8_t(proto.length());
505 *dst++ = (unsigned char)protoLength;
506 memcpy(dst, proto.data(), protoLength);
509 total_weight += item.weight;
510 advertisedNextProtocols_.push_back(advertised_item);
511 advertisedNextProtocolWeights_.push_back(item.weight);
513 if (total_weight == 0) {
514 deleteNextProtocolsStrings();
517 nextProtocolDistribution_ =
518 std::discrete_distribution<>(advertisedNextProtocolWeights_.begin(),
519 advertisedNextProtocolWeights_.end());
520 if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) {
521 SSL_CTX_set_next_protos_advertised_cb(
522 ctx_, advertisedNextProtocolCallback, this);
523 SSL_CTX_set_next_proto_select_cb(ctx_, selectNextProtocolCallback, this);
525 #if FOLLY_OPENSSL_HAS_ALPN
526 if ((uint8_t)protocolType & (uint8_t)NextProtocolType::ALPN) {
527 SSL_CTX_set_alpn_select_cb(ctx_, alpnSelectCallback, this);
528 // Client cannot really use randomized alpn
529 SSL_CTX_set_alpn_protos(ctx_,
530 advertisedNextProtocols_[0].protocols,
531 advertisedNextProtocols_[0].length);
537 void SSLContext::deleteNextProtocolsStrings() {
538 for (auto protocols : advertisedNextProtocols_) {
539 delete[] protocols.protocols;
541 advertisedNextProtocols_.clear();
542 advertisedNextProtocolWeights_.clear();
545 void SSLContext::unsetNextProtocols() {
546 deleteNextProtocolsStrings();
547 SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr);
548 SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr);
549 #if FOLLY_OPENSSL_HAS_ALPN
550 SSL_CTX_set_alpn_select_cb(ctx_, nullptr, nullptr);
551 SSL_CTX_set_alpn_protos(ctx_, nullptr, 0);
555 size_t SSLContext::pickNextProtocols() {
556 CHECK(!advertisedNextProtocols_.empty()) << "Failed to pickNextProtocols";
557 auto rng = ThreadLocalPRNG();
558 return size_t(nextProtocolDistribution_(rng));
561 int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
562 const unsigned char** out, unsigned int* outlen, void* data) {
563 static int nextProtocolsExDataIndex = SSL_get_ex_new_index(
564 0, (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
566 SSLContext* context = (SSLContext*)data;
567 if (context == nullptr || context->advertisedNextProtocols_.empty()) {
570 } else if (context->advertisedNextProtocols_.size() == 1) {
571 *out = context->advertisedNextProtocols_[0].protocols;
572 *outlen = context->advertisedNextProtocols_[0].length;
574 uintptr_t selected_index = reinterpret_cast<uintptr_t>(
575 SSL_get_ex_data(ssl, nextProtocolsExDataIndex));
576 if (selected_index) {
578 *out = context->advertisedNextProtocols_[selected_index].protocols;
579 *outlen = context->advertisedNextProtocols_[selected_index].length;
581 auto i = context->pickNextProtocols();
582 uintptr_t selected = i + 1;
583 SSL_set_ex_data(ssl, nextProtocolsExDataIndex, (void*)selected);
584 *out = context->advertisedNextProtocols_[i].protocols;
585 *outlen = context->advertisedNextProtocols_[i].length;
588 return SSL_TLSEXT_ERR_OK;
591 int SSLContext::selectNextProtocolCallback(SSL* ssl,
593 unsigned char* outlen,
594 const unsigned char* server,
595 unsigned int server_len,
597 (void)ssl; // Make -Wunused-parameters happy
598 SSLContext* ctx = (SSLContext*)data;
599 if (ctx->advertisedNextProtocols_.size() > 1) {
600 VLOG(3) << "SSLContext::selectNextProcolCallback() "
601 << "client should be deterministic in selecting protocols.";
604 unsigned char* client = nullptr;
605 unsigned int client_len = 0;
606 bool filtered = false;
607 auto cpf = ctx->getClientProtocolFilterCallback();
609 filtered = (*cpf)(&client, &client_len, server, server_len);
613 if (ctx->advertisedNextProtocols_.empty()) {
614 client = (unsigned char *) "";
617 client = ctx->advertisedNextProtocols_[0].protocols;
618 client_len = ctx->advertisedNextProtocols_[0].length;
622 int retval = SSL_select_next_proto(out, outlen, server, server_len,
624 if (retval != OPENSSL_NPN_NEGOTIATED) {
625 VLOG(3) << "SSLContext::selectNextProcolCallback() "
626 << "unable to pick a next protocol.";
628 return SSL_TLSEXT_ERR_OK;
630 #endif // OPENSSL_NPN_NEGOTIATED
632 SSL* SSLContext::createSSL() const {
633 SSL* ssl = SSL_new(ctx_);
634 if (ssl == nullptr) {
635 throw std::runtime_error("SSL_new: " + getErrors());
640 void SSLContext::setSessionCacheContext(const std::string& context) {
641 SSL_CTX_set_session_id_context(
643 reinterpret_cast<const unsigned char*>(context.data()),
644 std::min<unsigned int>(
645 static_cast<unsigned int>(context.length()),
646 SSL_MAX_SSL_SESSION_ID_LENGTH));
650 * Match a name with a pattern. The pattern may include wildcard. A single
651 * wildcard "*" can match up to one component in the domain name.
653 * @param host Host name, typically the name of the remote host
654 * @param pattern Name retrieved from certificate
655 * @param size Size of "pattern"
656 * @return True, if "host" matches "pattern". False otherwise.
658 bool SSLContext::matchName(const char* host, const char* pattern, int size) {
661 while (i < size && host[j] != '\0') {
662 if (toupper(pattern[i]) == toupper(host[j])) {
667 if (pattern[i] == '*') {
668 while (host[j] != '.' && host[j] != '\0') {
676 if (i == size && host[j] == '\0') {
682 int SSLContext::passwordCallback(char* password,
686 SSLContext* context = (SSLContext*)data;
687 if (context == nullptr || context->passwordCollector() == nullptr) {
690 std::string userPassword;
691 // call user defined password collector to get password
692 context->passwordCollector()->getPassword(userPassword, size);
693 auto length = int(userPassword.size());
697 strncpy(password, userPassword.c_str(), size_t(length));
701 void SSLContext::setSSLLockTypes(std::map<int, LockType> inLockTypes) {
702 folly::ssl::setLockTypes(inLockTypes);
#if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH)
/**
 * Turns on SSL_MODE_HANDSHAKE_CUTTHROUGH for this context. Only compiled
 * when the OpenSSL build defines that mode.
 */
void SSLContext::enableFalseStart() {
  const long cutThrough = SSL_MODE_HANDSHAKE_CUTTHROUGH;
  SSL_CTX_set_mode(ctx_, cutThrough);
}
#endif
711 void SSLContext::initializeOpenSSL() {
715 void SSLContext::setOptions(long options) {
716 long newOpt = SSL_CTX_set_options(ctx_, options);
717 if ((newOpt & options) != options) {
718 throw std::runtime_error("SSL_CTX_set_options failed");
722 std::string SSLContext::getErrors(int errnoCopy) {
724 unsigned long errorCode;
728 while ((errorCode = ERR_get_error()) != 0) {
729 if (!errors.empty()) {
732 const char* reason = ERR_reason_error_string(errorCode);
733 if (reason == nullptr) {
734 snprintf(message, sizeof(message) - 1, "SSL error # %lu", errorCode);
739 if (errors.empty()) {
740 errors = "error code: " + folly::to<std::string>(errnoCopy);
746 operator<<(std::ostream& os, const PasswordCollector& collector) {
747 os << collector.describe();