2 * Copyright 2014 Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "SSLContext.h"
19 #include <openssl/err.h>
20 #include <openssl/rand.h>
21 #include <openssl/ssl.h>
22 #include <openssl/x509v3.h>
24 #include <folly/Format.h>
25 #include <folly/io/PortableSpinLock.h>
27 // ---------------------------------------------------------------------
28 // SSLContext implementation
29 // ---------------------------------------------------------------------
// Application-side definition of the lock type handed to OpenSSL's
// dynamic-lock callbacks (dyn_create / dyn_lock / dyn_destroy in this file).
// dyn_lock() calls lock->mutex.unlock(), so the struct carries a `mutex`
// member; the member declarations are not present in this listing.
31 struct CRYPTO_dynlock_value {
// Process-wide state shared by every SSLContext instance.
// - initialized_: reset to false by cleanupOpenSSLLocked(); presumably set by
//   initializeOpenSSLLocked() (that assignment is not visible here).
// - mutex_: held by the constructor, the destructor, initializeOpenSSL() and
//   cleanupOpenSSL() around the *Locked() helpers.
37 bool SSLContext::initialized_ = false;
38 std::mutex SSLContext::mutex_;
39 #ifdef OPENSSL_NPN_NEGOTIATED
// SSL ex-data slot where advertisedNextProtocolCallback() remembers which
// weighted protocol list was picked for a connection. It stays -1 until
// initializeOpenSSLLocked() allocates it with SSL_get_ex_new_index().
40 int SSLContext::sNextProtocolsExDataIndex_ = -1;
43 #ifndef SSLCONTEXT_NO_REFCOUNT
// NOTE(review): presumably the number of live SSLContext objects, used to
// decide when process-wide OpenSSL cleanup may run; the increment/decrement
// lines are not visible in this listing.
44 uint64_t SSLContext::count_ = 0;
47 // SSLContext implementation
// Creates the underlying SSL_CTX for this context.
//
// Under mutex_, performs process-wide OpenSSL setup via
// initializeOpenSSLLocked(). The SSLCONTEXT_NO_REFCOUNT-guarded bookkeeping
// is not visible here.
// The SSL_CTX is always built from SSLv23_method(). `version` presumably
// selects which SSL_OP_NO_* option set below is applied; the switch itself
// is not visible here.
//
// Throws std::runtime_error (including queued OpenSSL errors) if
// SSL_CTX_new fails.
48 SSLContext::SSLContext(SSLVersion version) {
50 std::lock_guard<std::mutex> g(mutex_);
51 #ifndef SSLCONTEXT_NO_REFCOUNT
54 initializeOpenSSLLocked();
57 ctx_ = SSL_CTX_new(SSLv23_method());
58 if (ctx_ == nullptr) {
59 throw std::runtime_error("SSL_CTX_new: " + getErrors());
// Disable both SSLv2 and SSLv3 ...
65 opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
// ... or only SSLv2.
68 opt = SSL_OP_NO_SSLv2;
// DCHECK verifies, in debug builds only, that every requested bit is now set.
// NOTE(review): SSL_CTX_set_options returns a long; storing it in an int
// differs from setOptions(), which keeps the result as long.
74 int newOpt = SSL_CTX_set_options(ctx_, opt);
75 DCHECK((newOpt & opt) == opt);
// SSL_MODE_AUTO_RETRY: OpenSSL retries internally instead of surfacing a
// retry condition to the caller after renegotiation traffic.
77 SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
79 checkPeerName_ = false;
// Route the TLS servername (SNI) callback to baseServerNameOpenSSLCallback,
// passing this context as its `data` argument.
81 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
82 SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
83 SSL_CTX_set_tlsext_servername_arg(ctx_, this);
// Tears down this context:
// - releases ctx_ if it was created (the call inside the `if` is not visible
//   here);
// - frees the NPN protocol buffers;
// - unless SSLCONTEXT_NO_REFCOUNT is defined, runs cleanupOpenSSLLocked()
//   while holding mutex_.
// NOTE(review): cleanupOpenSSLLocked() is presumably gated on count_ reaching
// zero; that condition is not visible in this listing.
87 SSLContext::~SSLContext() {
88 if (ctx_ != nullptr) {
93 #ifdef OPENSSL_NPN_NEGOTIATED
94 deleteNextProtocolsStrings();
97 #ifndef SSLCONTEXT_NO_REFCOUNT
99 std::lock_guard<std::mutex> g(mutex_);
101 cleanupOpenSSLLocked();
// Replaces this context's cipher list.
// The string is also kept in providedCiphersString_, which
// switchCiphersIfTLS11() restores on an SSL object as a fallback.
// Throws std::runtime_error (via setCiphersOrThrow) if OpenSSL rejects the list.
// NOTE(review): the string is stored before it is validated, so a rejected
// list stays in providedCiphersString_. Consider assigning it only after
// setCiphersOrThrow() succeeds.
107 void SSLContext::ciphers(const std::string& ciphers) {
108 providedCiphersString_ = ciphers;
109 setCiphersOrThrow(ciphers);
// Applies `ciphers` to ctx_ with SSL_CTX_set_cipher_list.
// Throws std::runtime_error in two cases:
// - OpenSSL has an error queued; the message includes the drained queue.
// - "None of specified ciphers are supported"; presumably raised when
//   rc == 0 (the condition line is not visible here).
// NOTE(review): ERR_peek_error() also reports errors left on this thread's
// queue by earlier, unrelated calls, so a stale error can make this throw.
// Calling ERR_clear_error() before SSL_CTX_set_cipher_list would avoid that.
112 void SSLContext::setCiphersOrThrow(const std::string& ciphers) {
113 int rc = SSL_CTX_set_cipher_list(ctx_, ciphers.c_str());
114 if (ERR_peek_error() != 0) {
115 throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors());
118 throw std::runtime_error("None of specified ciphers are supported");
// Stores this context's default peer-verification policy in verifyPeer_.
// USE_CTX is rejected with CHECK, since the context itself cannot defer to
// "the context".
122 void SSLContext::setVerificationOption(const SSLContext::SSLVerifyPeerEnum&
124 CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX); // don't recurse
125 verifyPeer_ = verifyPeer;
// Translates an SSLVerifyPeerEnum into the SSL_VERIFY_* flags used by OpenSSL:
//   VERIFY                 -> SSL_VERIFY_PEER
//   VERIFY_REQ_CLIENT_CERT -> SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT
//   NO_VERIFY              -> SSL_VERIFY_NONE
// The caller must already have resolved USE_CTX (CHECK-enforced).
// `mode` starts as SSL_VERIFY_NONE.
128 int SSLContext::getVerificationMode(const SSLContext::SSLVerifyPeerEnum&
130 CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX);
131 int mode = SSL_VERIFY_NONE;
133 // case SSLVerifyPeerEnum::USE_CTX: // can't happen
136 case SSLVerifyPeerEnum::VERIFY:
137 mode = SSL_VERIFY_PEER;
140 case SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT:
141 mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
144 case SSLVerifyPeerEnum::NO_VERIFY:
145 mode = SSL_VERIFY_NONE;
// Returns the SSL_VERIFY_* flags for this context's stored policy
// (verifyPeer_). The one-argument overload CHECK-fails if that policy is
// USE_CTX.
154 int SSLContext::getVerificationMode() {
155 return getVerificationMode(verifyPeer_);
// Configures certificate verification directly on ctx_.
// The governing `if` is not visible here; presumably it tests checkPeerCert.
// - Certificate checking on: mode is SSL_VERIFY_PEER |
//   SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE, and
//   checkPeerName_ / peerFixedName_ take the given values.
// - Otherwise: mode is SSL_VERIFY_NONE and name checking is disabled.
// No verify callback is installed (nullptr).
158 void SSLContext::authenticate(bool checkPeerCert, bool checkPeerName,
159 const std::string& peerName) {
162 mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE;
163 checkPeerName_ = checkPeerName;
164 peerFixedName_ = peerName;
166 mode = SSL_VERIFY_NONE;
167 checkPeerName_ = false; // can't check name without cert!
168 peerFixedName_.clear();
170 SSL_CTX_set_verify(ctx_, mode, nullptr);
// Loads a certificate chain for ctx_ from the file at `path`.
// Only the "PEM" format is supported.
// Throws:
// - std::invalid_argument if either argument is null;
// - std::runtime_error for any other format;
// - std::runtime_error if SSL_CTX_use_certificate_chain_file fails.
// NOTE(review): the invalid_argument message says "loadCertificateChain",
// not "loadCertificate".
173 void SSLContext::loadCertificate(const char* path, const char* format) {
174 if (path == nullptr || format == nullptr) {
175 throw std::invalid_argument(
176 "loadCertificateChain: either <path> or <format> is nullptr");
178 if (strcmp(format, "PEM") == 0) {
179 if (SSL_CTX_use_certificate_chain_file(ctx_, path) == 0) {
// Copy errno immediately, before later calls can overwrite it. getErrors()
// reports it when no OpenSSL error is queued.
180 int errnoCopy = errno;
181 std::string reason("SSL_CTX_use_certificate_chain_file: ");
184 reason.append(getErrors(errnoCopy));
185 throw std::runtime_error(reason);
188 throw std::runtime_error("Unsupported certificate format: " + std::string(format));
// Loads the private key for ctx_ from the file at `path`.
// Only "PEM" is supported. If passwordCollector() installed a collector,
// OpenSSL's default password callback on ctx_ (passwordCallback) supplies
// the passphrase for an encrypted key.
// Throws std::invalid_argument on null arguments. Throws std::runtime_error
// for an unsupported format or when SSL_CTX_use_PrivateKey_file fails.
192 void SSLContext::loadPrivateKey(const char* path, const char* format) {
193 if (path == nullptr || format == nullptr) {
194 throw std::invalid_argument(
195 "loadPrivateKey: either <path> or <format> is nullptr");
197 if (strcmp(format, "PEM") == 0) {
198 if (SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM) == 0) {
199 throw std::runtime_error("SSL_CTX_use_PrivateKey_file: " + getErrors());
202 throw std::runtime_error("Unsupported private key format: " + std::string(format));
// Adds the CA certificates in the file at `path` to ctx_'s verification store.
// This calls SSL_CTX_load_verify_locations with no CA directory.
// Throws std::invalid_argument if `path` is null, and std::runtime_error
// (with OpenSSL errors) if loading fails.
206 void SSLContext::loadTrustedCertificates(const char* path) {
207 if (path == nullptr) {
208 throw std::invalid_argument(
209 "loadTrustedCertificates: <path> is nullptr");
211 if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) {
212 throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors());
// Replaces ctx_'s certificate verification store with `store`.
// NOTE(review): per OpenSSL's SSL_CTX_set_cert_store documentation, the
// SSL_CTX takes ownership of `store` and frees any store it previously held.
// Callers must not free `store` themselves, or share it between contexts,
// without an extra reference.
216 void SSLContext::loadTrustedCertificates(X509_STORE* store) {
217 SSL_CTX_set_cert_store(ctx_, store);
// Loads CA names from the file at `path` and installs them on ctx_ with
// SSL_CTX_set_client_CA_list.
// On load failure it logs an error instead of throwing. The rest of that
// branch is not visible; presumably it returns before the set call.
// NOTE(review): unlike the other load* methods, a null `path` is not rejected.
// It is passed to OpenSSL and, on failure, streamed into LOG(ERROR), which
// is undefined for a null char*.
220 void SSLContext::loadClientCAList(const char* path) {
221 auto clientCAs = SSL_load_client_CA_file(path);
222 if (clientCAs == nullptr) {
223 LOG(ERROR) << "Unable to load ca file: " << path;
226 SSL_CTX_set_client_CA_list(ctx_, clientCAs);
// NOTE(review): the body is not visible in this listing. By its name this
// presumably (re)seeds OpenSSL's random number generator; confirm.
229 void SSLContext::randomize() {
// Installs `collector` as the source of private-key passphrases.
// ctx_'s default password callback becomes passwordCallback, with this
// context as its userdata.
// A null collector is logged as invalid and ignored, as the log message
// states; the early return is not visible here.
233 void SSLContext::passwordCollector(std::shared_ptr<PasswordCollector> collector) {
234 if (collector == nullptr) {
235 LOG(ERROR) << "passwordCollector: ignore invalid password collector";
238 collector_ = collector;
239 SSL_CTX_set_default_passwd_cb(ctx_, passwordCallback);
240 SSL_CTX_set_default_passwd_cb_userdata(ctx_, this);
243 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
// Sets the SNI handler (serverNameCb_) that baseServerNameOpenSSLCallback()
// consults. The assignment in the body is not visible in this listing.
245 void SSLContext::setServerNameCallback(const ServerNameCallback& cb) {
// Appends `cb` to the callbacks that baseServerNameOpenSSLCallback() runs, in
// registration order, for each ClientHello before it does SNI handling.
249 void SSLContext::addClientHelloCallback(const ClientHelloCallback& cb) {
250 clientHelloCbs_.push_back(cb);
// OpenSSL servername (SNI) callback registered by the constructor. `data` is
// the owning SSLContext.
// It first runs every registered ClientHello callback. Their failures are
// only logged, per the comment in the loop.
// It then maps the serverNameCb_ result to an SSL_TLSEXT_ERR_* code:
//   SERVER_NAME_FOUND                 -> SSL_TLSEXT_ERR_OK
//   SERVER_NAME_NOT_FOUND             -> SSL_TLSEXT_ERR_NOACK
//   SERVER_NAME_NOT_FOUND_ALERT_FATAL -> SSL_TLSEXT_ERR_ALERT_FATAL, with
//                                        *al = TLS1_AD_UNRECOGNIZED_NAME
// A null context, an unset serverNameCb_, or a result that reaches the final
// return all yield SSL_TLSEXT_ERR_NOACK.
253 int SSLContext::baseServerNameOpenSSLCallback(SSL* ssl, int* al, void* data) {
254 SSLContext* context = (SSLContext*)data;
256 if (context == nullptr) {
257 return SSL_TLSEXT_ERR_NOACK;
260 for (auto& cb : context->clientHelloCbs_) {
261 // Generic callbacks to happen after we receive the Client Hello.
262 // For example, we use one to switch which cipher we use depending
263 // on the user's TLS version. Because the primary purpose of
264 // baseServerNameOpenSSLCallback is for SNI support, and these callbacks
265 // are side-uses, we ignore any possible failures other than just logging
270 if (!context->serverNameCb_) {
271 return SSL_TLSEXT_ERR_NOACK;
274 ServerNameCallbackResult ret = context->serverNameCb_(ssl);
276 case SERVER_NAME_FOUND:
277 return SSL_TLSEXT_ERR_OK;
278 case SERVER_NAME_NOT_FOUND:
279 return SSL_TLSEXT_ERR_NOACK;
280 case SERVER_NAME_NOT_FOUND_ALERT_FATAL:
281 *al = TLS1_AD_UNRECOGNIZED_NAME;
282 return SSL_TLSEXT_ERR_ALERT_FATAL;
287 return SSL_TLSEXT_ERR_NOACK;
// ClientHello-time helper. For clients whose version is above TLS 1.0, it
// replaces the cipher list on this connection's SSL object (not on ctx_)
// with tls11CipherString.
// If OpenSSL rejects that list, it logs a warning and reapplies
// providedCiphersString_ to the SSL object.
// CHECK-fails if tls11CipherString is empty.
// NOTE(review): ERR_peek_error() can also report errors left on this thread's
// queue by earlier calls. Consider ERR_clear_error() before
// SSL_set_cipher_list so a stale error does not trigger the fallback.
290 void SSLContext::switchCiphersIfTLS11(
292 const std::string& tls11CipherString) {
294 CHECK(!tls11CipherString.empty()) << "Shouldn't call if empty alt ciphers";
296 if (TLS1_get_client_version(ssl) <= TLS1_VERSION) {
297 // We only do this for TLS v 1.1 and later
301 // Prefer AES for TLS versions 1.1 and later since these are not
302 // vulnerable to BEAST attacks on AES. Note that we're setting the
303 // cipher list on the SSL object, not the SSL_CTX object, so it will
304 // only last for this request.
305 int rc = SSL_set_cipher_list(ssl, tls11CipherString.c_str());
306 if ((rc == 0) || ERR_peek_error() != 0) {
307 // This shouldn't happen since we checked for this when proxygen
309 LOG(WARNING) << "ssl_cipher: No specified ciphers supported for switch";
310 SSL_set_cipher_list(ssl, providedCiphersString_.c_str());
315 #ifdef OPENSSL_NPN_NEGOTIATED
// Advertises a single NPN protocol list. This is shorthand for a one-item
// weighted list with weight 1.
316 bool SSLContext::setAdvertisedNextProtocols(const std::list<std::string>& protocols) {
317 return setRandomizedAdvertisedNextProtocols({{1, protocols}});
// Replaces the advertised NPN protocol lists with the weighted `items`.
// Existing lists and NPN callbacks are cleared first.
// For each item:
// - its protocols are encoded into one new[]-allocated buffer, each as a
//   one-byte length followed by the protocol bytes;
// - its weight is stored in `probability`, which is later divided by the
//   total weight.
// On success, both NPN callbacks are installed on ctx_ with this context as
// their argument.
// Processing is abandoned for: an empty `items`, an item with no protocols,
// a protocol of 256+ bytes, or a total weight of 0. Those return statements
// are not visible here; presumably they return false.
320 bool SSLContext::setRandomizedAdvertisedNextProtocols(
321 const std::list<NextProtocolsItem>& items) {
322 unsetNextProtocols();
323 if (items.size() == 0) {
326 int total_weight = 0;
327 for (const auto &item : items) {
328 if (item.protocols.size() == 0) {
331 AdvertisedNextProtocolsItem advertised_item;
332 advertised_item.length = 0;
// First pass: size the buffer as one length byte plus the bytes of each
// protocol. A protocol must fit in a one-byte length, hence the 256 limit.
333 for (const auto& proto : item.protocols) {
334 ++advertised_item.length;
335 unsigned protoLength = proto.length();
336 if (protoLength >= 256) {
337 deleteNextProtocolsStrings();
340 advertised_item.length += protoLength;
342 advertised_item.protocols = new unsigned char[advertised_item.length];
// NOTE(review): plain new[] throws std::bad_alloc instead of returning null,
// so this branch cannot be reached.
343 if (!advertised_item.protocols) {
344 throw std::runtime_error("alloc failure");
// Second pass: write the length-prefixed protocol strings into the buffer.
346 unsigned char* dst = advertised_item.protocols;
347 for (auto& proto : item.protocols) {
348 unsigned protoLength = proto.length();
349 *dst++ = (unsigned char)protoLength;
350 memcpy(dst, proto.data(), protoLength);
353 total_weight += item.weight;
354 advertised_item.probability = item.weight;
355 advertisedNextProtocols_.push_back(advertised_item);
357 if (total_weight == 0) {
358 deleteNextProtocolsStrings();
// Normalize the weights so the stored probabilities sum to 1.
361 for (auto &advertised_item : advertisedNextProtocols_) {
362 advertised_item.probability /= total_weight;
364 SSL_CTX_set_next_protos_advertised_cb(
365 ctx_, advertisedNextProtocolCallback, this);
366 SSL_CTX_set_next_proto_select_cb(
367 ctx_, selectNextProtocolCallback, this);
// Frees every advertised protocol buffer and empties advertisedNextProtocols_.
// NOTE(review): the loop copies each item by value. `const auto&` would avoid
// the copy; behavior is the same, since only the pointer is deleted.
371 void SSLContext::deleteNextProtocolsStrings() {
372 for (auto protocols : advertisedNextProtocols_) {
373 delete[] protocols.protocols;
375 advertisedNextProtocols_.clear();
// Drops all advertised protocols and detaches both NPN callbacks
// (advertise and select) from ctx_.
378 void SSLContext::unsetNextProtocols() {
379 deleteNextProtocolsStrings();
380 SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr);
381 SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr);
// Server-side NPN callback. It chooses which encoded protocol list to
// advertise on `ssl`:
// - No context or no lists: the handling lines are not visible here;
//   presumably nothing is advertised.
// - Exactly one list: that list is advertised.
// - Several lists: the choice stored in the SSL's ex-data slot is reused if
//   present. Otherwise one random byte is scaled to [0, 1], the cumulative
//   probabilities are walked to pick a list (the last list is the fallback),
//   and the choice is stored in the slot.
// Always ends by returning SSL_TLSEXT_ERR_OK.
384 int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
385 const unsigned char** out, unsigned int* outlen, void* data) {
386 SSLContext* context = (SSLContext*)data;
387 if (context == nullptr || context->advertisedNextProtocols_.empty()) {
390 } else if (context->advertisedNextProtocols_.size() == 1) {
391 *out = context->advertisedNextProtocols_[0].protocols;
392 *outlen = context->advertisedNextProtocols_[0].length;
394 uintptr_t selected_index = reinterpret_cast<uintptr_t>(SSL_get_ex_data(ssl,
395 sNextProtocolsExDataIndex_));
// The slot holds the chosen index + 1 (see `selected` below), so 0 means no
// list has been chosen yet for this connection.
// NOTE(review): confirm the elided line before the lookup decrements
// selected_index; otherwise this reads one past the chosen list.
396 if (selected_index) {
398 *out = context->advertisedNextProtocols_[selected_index].protocols;
399 *outlen = context->advertisedNextProtocols_[selected_index].length;
// NOTE(review): RAND_bytes()'s return value is ignored. If it fails,
// random_byte is read uninitialized.
401 unsigned char random_byte;
402 RAND_bytes(&random_byte, 1);
403 double random_value = random_byte / 255.0;
405 for (size_t i = 0; i < context->advertisedNextProtocols_.size(); ++i) {
406 sum += context->advertisedNextProtocols_[i].probability;
407 if (sum < random_value &&
408 i + 1 < context->advertisedNextProtocols_.size()) {
411 uintptr_t selected = i + 1;
412 SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void *)selected);
413 *out = context->advertisedNextProtocols_[i].protocols;
414 *outlen = context->advertisedNextProtocols_[i].length;
419 return SSL_TLSEXT_ERR_OK;
// Client-side NPN callback. It picks a protocol from the server's list
// (`server`), using this context's first advertised list as the client
// preference. An empty string is used when no list is configured.
// With more than one configured list it only logs at VLOG(3), because a
// client is expected to choose deterministically.
// Always returns SSL_TLSEXT_ERR_OK; a failed negotiation is only logged.
// NOTE(review): unlike the other OpenSSL callbacks in this file, `data` is
// dereferenced without a null check.
// NOTE(review): both log messages misspell "selectNextProtocolCallback".
422 int SSLContext::selectNextProtocolCallback(
423 SSL* ssl, unsigned char **out, unsigned char *outlen,
424 const unsigned char *server, unsigned int server_len, void *data) {
426 SSLContext* ctx = (SSLContext*)data;
427 if (ctx->advertisedNextProtocols_.size() > 1) {
428 VLOG(3) << "SSLContext::selectNextProcolCallback() "
429 << "client should be deterministic in selecting protocols.";
432 unsigned char *client;
// NOTE(review): with no configured lists the client list is an empty string;
// its length line is not visible, but is presumably 0. OpenSSL releases
// affected by CVE-2024-5535 over-read in SSL_select_next_proto when given an
// empty client list. Consider not calling it in that case.
434 if (ctx->advertisedNextProtocols_.empty()) {
435 client = (unsigned char *) "";
438 client = ctx->advertisedNextProtocols_[0].protocols;
439 client_len = ctx->advertisedNextProtocols_[0].length;
442 int retval = SSL_select_next_proto(out, outlen, server, server_len,
444 if (retval != OPENSSL_NPN_NEGOTIATED) {
445 VLOG(3) << "SSLContext::selectNextProcolCallback() "
446 << "unable to pick a next protocol.";
448 return SSL_TLSEXT_ERR_OK;
450 #endif // OPENSSL_NPN_NEGOTIATED
// Creates a new SSL connection object from ctx_.
// Throws std::runtime_error (with queued OpenSSL errors) if SSL_new fails.
// The return statement is not visible here; presumably the caller owns the
// returned SSL*.
452 SSL* SSLContext::createSSL() const {
453 SSL* ssl = SSL_new(ctx_);
454 if (ssl == nullptr) {
455 throw std::runtime_error("SSL_new: " + getErrors());
461 * Match a name with a pattern. The pattern may include wildcard. A single
462 * wildcard "*" can match up to one component in the domain name.
464 * @param host Host name, typically the name of the remote host
465 * @param pattern Name retrieved from certificate
466 * @param size Size of "pattern"
467 * @return True, if "host" matches "pattern". False otherwise.
// How the visible loop works:
// - Characters are compared case-insensitively (toupper on both sides).
// - A '*' in `pattern` advances through `host` up to, but not past, the
//   next '.' or the end of `host`.
// - The final test requires all `size` pattern characters and all of `host`
//   to be consumed.
// NOTE(review): toupper() receives plain char. On signed-char platforms,
// non-ASCII bytes are negative, which is undefined behavior for toupper;
// cast to unsigned char first.
469 bool SSLContext::matchName(const char* host, const char* pattern, int size) {
472 while (i < size && host[j] != '\0') {
473 if (toupper(pattern[i]) == toupper(host[j])) {
478 if (pattern[i] == '*') {
479 while (host[j] != '.' && host[j] != '\0') {
487 if (i == size && host[j] == '\0') {
// OpenSSL default password callback, installed by passwordCollector().
// It asks the context's PasswordCollector for the passphrase, passing the
// buffer `size`, and copies the result into `password`.
// Not visible here, presumably: returning 0 when there is no context or
// collector, clamping `length` to `size`, and returning the copied length.
// NOTE(review): strncpy does not NUL-terminate when length equals the buffer
// size. The passphrase also remains in userPassword's heap storage
// un-wiped until the string is destroyed.
493 int SSLContext::passwordCallback(char* password,
497 SSLContext* context = (SSLContext*)data;
498 if (context == nullptr || context->passwordCollector() == nullptr) {
501 std::string userPassword;
502 // call user defined password collector to get password
503 context->passwordCollector()->getPassword(userPassword, size);
504 int length = userPassword.size();
508 strncpy(password, userPassword.c_str(), length);
// SSLLock: one lock slot whose backing primitive depends on lockType. It is
// a mutex (the constructor default, LOCK_MUTEX), a PortableSpinLock
// (LOCK_SPINLOCK), or nothing (LOCK_NONE).
514 SSLContext::SSLLockType inLockType = SSLContext::LOCK_MUTEX) :
515 lockType(inLockType) {
// Acquire path: take the primitive selected by lockType.
519 if (lockType == SSLContext::LOCK_MUTEX) {
521 } else if (lockType == SSLContext::LOCK_SPINLOCK) {
524 // lockType == LOCK_NONE, no-op
// Release path: mirror of the acquire path above.
528 if (lockType == SSLContext::LOCK_MUTEX) {
530 } else if (lockType == SSLContext::LOCK_SPINLOCK) {
533 // lockType == LOCK_NONE, no-op
536 SSLContext::SSLLockType lockType;
537 folly::io::PortableSpinLock spinLock{};
541 // Statics are unsafe in environments that call exit().
542 // If one thread calls exit() while another thread
543 // references a member of SSLContext, bad things can happen.
544 // SSLContext runs in such environments.
545 // Instead of declaring a static member we "new" the static
546 // member so that it won't be destructed on exit().
// Both objects below are intentionally leaked for that reason.
547 static std::map<int, SSLContext::SSLLockType>* lockTypesInst =
548 new std::map<int, SSLContext::SSLLockType>();
550 static std::unique_ptr<SSLLock[]>* locksInst =
551 new std::unique_ptr<SSLLock[]>();
// Accessor for the leaked lock array (allocated in initializeOpenSSLLocked()).
553 static std::unique_ptr<SSLLock[]>& locks() {
// Accessor for the leaked per-index lock-type overrides (see setSSLLockTypes()).
557 static std::map<int, SSLContext::SSLLockType>& lockTypes() {
558 return *lockTypesInst;
// OpenSSL static locking callback, installed by initializeOpenSSLLocked().
// A CRYPTO_LOCK bit in `mode` selects the acquire path; otherwise the lock
// is released. The rest of the body is not visible here; presumably it acts
// on locks()[n].
561 static void callbackLocking(int mode, int n, const char*, int) {
562 if (mode & CRYPTO_LOCK) {
// OpenSSL thread-id callback, installed by initializeOpenSSLLocked().
// The visible branch returns the calling thread's Mach thread port via
// pthread_mach_thread_np, which is Apple-only. Other platform branches are
// not visible here.
569 static unsigned long callbackThreadID() {
570 return static_cast<unsigned long>(
572 pthread_mach_thread_np(pthread_self())
// OpenSSL dynamic-lock create callback: heap-allocates a new lock.
579 static CRYPTO_dynlock_value* dyn_create(const char*, int) {
580 return new CRYPTO_dynlock_value;
// OpenSSL dynamic-lock lock/unlock callback. It locks the lock's mutex when
// `mode` contains CRYPTO_LOCK and unlocks it otherwise. A null lock is
// ignored.
583 static void dyn_lock(int mode,
584 struct CRYPTO_dynlock_value* lock,
586 if (lock != nullptr) {
587 if (mode & CRYPTO_LOCK) {
590 lock->mutex.unlock();
// OpenSSL dynamic-lock destroy callback. It presumably deletes the lock
// allocated by dyn_create; the body is not visible here.
595 static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) {
// Overrides the lock type for specific OpenSSL lock indices.
// The map is only read when initializeOpenSSLLocked() allocates locks(), so
// it must be set before OpenSSL is initialized to take effect.
// NOTE(review): this writes the shared lockTypes() map without holding mutex_.
599 void SSLContext::setSSLLockTypes(std::map<int, SSLLockType> inLockTypes) {
600 lockTypes() = inLockTypes;
// Public entry point: runs initializeOpenSSLLocked() while holding mutex_.
603 void SSLContext::initializeOpenSSL() {
604 std::lock_guard<std::mutex> g(mutex_);
605 initializeOpenSSLLocked();
// Process-wide OpenSSL setup; the caller must hold mutex_.
// Lines not visible here presumably skip the work when already initialized
// and perform library initialization.
// Visible steps:
// - load the error strings;
// - allocate one SSLLock per OpenSSL static lock (CRYPTO_num_locks()) and
//   apply the lockTypes() overrides;
// - install the thread-id, static-lock and dynamic-lock callbacks;
// - with NPN, allocate the SSL ex-data index used by
//   advertisedNextProtocolCallback().
608 void SSLContext::initializeOpenSSLLocked() {
613 SSL_load_error_strings();
614 ERR_load_crypto_strings();
616 locks().reset(new SSLLock[::CRYPTO_num_locks()]);
// NOTE(review): it.first is not range-checked against CRYPTO_num_locks().
// A negative or too-large key passed to setSSLLockTypes() writes outside
// the array.
617 for (auto it: lockTypes()) {
618 locks()[it.first].lockType = it.second;
620 CRYPTO_set_id_callback(callbackThreadID);
621 CRYPTO_set_locking_callback(callbackLocking);
623 CRYPTO_set_dynlock_create_callback(dyn_create);
624 CRYPTO_set_dynlock_lock_callback(dyn_lock);
625 CRYPTO_set_dynlock_destroy_callback(dyn_destroy);
627 #ifdef OPENSSL_NPN_NEGOTIATED
628 sNextProtocolsExDataIndex_ = SSL_get_ex_new_index(0,
629 (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
// Public entry point: runs cleanupOpenSSLLocked() while holding mutex_.
634 void SSLContext::cleanupOpenSSL() {
635 std::lock_guard<std::mutex> g(mutex_);
636 cleanupOpenSSLLocked();
// Process-wide OpenSSL teardown; the caller must hold mutex_.
// Visible steps: detach the thread-id, static-lock and dynamic-lock
// callbacks, free ex_data bookkeeping, and reset initialized_. Other
// teardown lines in the body are not visible here.
// NOTE(review): nothing visible here releases locks() or resets
// sNextProtocolsExDataIndex_; confirm against the elided lines.
639 void SSLContext::cleanupOpenSSLLocked() {
644 CRYPTO_set_id_callback(nullptr);
645 CRYPTO_set_locking_callback(nullptr);
646 CRYPTO_set_dynlock_create_callback(nullptr);
647 CRYPTO_set_dynlock_lock_callback(nullptr);
648 CRYPTO_set_dynlock_destroy_callback(nullptr);
649 CRYPTO_cleanup_all_ex_data();
654 initialized_ = false;
// ORs `options` (SSL_OP_* bits) into ctx_'s options.
// Throws std::runtime_error if any requested bit is not set afterwards.
657 void SSLContext::setOptions(long options) {
658 long newOpt = SSL_CTX_set_options(ctx_, options);
659 if ((newOpt & options) != options) {
660 throw std::runtime_error("SSL_CTX_set_options failed");
// Drains this thread's OpenSSL error queue (ERR_get_error until 0) into one
// string, with one entry per error. The separator line is not visible here.
// Each entry is the library's reason string, or "SSL error # <code>" when
// there is none.
// If the queue was empty, returns "error code: <errnoCopy>".
// NOTE(review): snprintf already reserves room for the terminating NUL, so
// `sizeof(message) - 1` only wastes one byte.
664 std::string SSLContext::getErrors(int errnoCopy) {
666 unsigned long errorCode;
670 while ((errorCode = ERR_get_error()) != 0) {
671 if (!errors.empty()) {
674 const char* reason = ERR_reason_error_string(errorCode);
675 if (reason == nullptr) {
676 snprintf(message, sizeof(message) - 1, "SSL error # %lu", errorCode);
681 if (errors.empty()) {
682 errors = "error code: " + folly::to<std::string>(errnoCopy);
// Streams a PasswordCollector as its describe() text.
688 operator<<(std::ostream& os, const PasswordCollector& collector) {
689 os << collector.describe();