From aaf30bd67128e34c42939a7fd4a5e53724e6289a Mon Sep 17 00:00:00 2001 From: Dave Watson Date: Wed, 1 Oct 2014 10:51:40 -0700 Subject: [PATCH] SSLContext Summary: Move SSLContext to folly. Test Plan: It builds. Reviewed By: dcsommer@fb.com Subscribers: jdperlow, atlas2-eng@, wormhole-diffs@, bwester, trunkagent, doug, ps, bmatheny, ssl-diffs@, alikhtarov, njormrod, mshneer, folly-diffs@, andrewcox, alandau, jsedgwick, fugalh FB internal diff: D1631924 Signature: t1:1631924:1414616562:9a67dbf20f00eb8fbcb35880efcb94c0fae07dcc --- folly/Makefile.am | 2 + folly/io/async/SSLContext.cpp | 643 ++++++++++++++++++++++++++++++++++ folly/io/async/SSLContext.h | 453 ++++++++++++++++++++++++ 3 files changed, 1098 insertions(+) create mode 100644 folly/io/async/SSLContext.cpp create mode 100644 folly/io/async/SSLContext.h diff --git a/folly/Makefile.am b/folly/Makefile.am index 76ad6a30..dde8c285 100644 --- a/folly/Makefile.am +++ b/folly/Makefile.am @@ -143,6 +143,7 @@ nobase_follyinclude_HEADERS = \ io/async/NotificationQueue.h \ io/async/HHWheelTimer.h \ io/async/Request.h \ + io/async/SSLContext.h \ io/async/TimeoutManager.h \ json.h \ Lazy.h \ @@ -269,6 +270,7 @@ libfolly_la_SOURCES = \ io/async/EventBaseManager.cpp \ io/async/EventHandler.cpp \ io/async/Request.cpp \ + io/async/SSLContext.cpp \ io/async/HHWheelTimer.cpp \ json.cpp \ detail/MemoryIdler.cpp \ diff --git a/folly/io/async/SSLContext.cpp b/folly/io/async/SSLContext.cpp new file mode 100644 index 00000000..c50e2571 --- /dev/null +++ b/folly/io/async/SSLContext.cpp @@ -0,0 +1,643 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "SSLContext.h" + +#include +#include +#include +#include + +#include +#include +#include + +// --------------------------------------------------------------------- +// SSLContext implementation +// --------------------------------------------------------------------- + +struct CRYPTO_dynlock_value { + std::mutex mutex; +}; + +namespace folly { + +uint64_t SSLContext::count_ = 0; +std::mutex SSLContext::mutex_; +#ifdef OPENSSL_NPN_NEGOTIATED +int SSLContext::sNextProtocolsExDataIndex_ = -1; + +#endif +// SSLContext implementation +SSLContext::SSLContext(SSLVersion version) { + { + std::lock_guard g(mutex_); + if (!count_++) { + initializeOpenSSL(); + randomize(); +#ifdef OPENSSL_NPN_NEGOTIATED + sNextProtocolsExDataIndex_ = SSL_get_ex_new_index(0, + (void*)"Advertised next protocol index", nullptr, nullptr, nullptr); +#endif + } + } + + ctx_ = SSL_CTX_new(SSLv23_method()); + if (ctx_ == nullptr) { + throw std::runtime_error("SSL_CTX_new: " + getErrors()); + } + + int opt = 0; + switch (version) { + case TLSv1: + opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3; + break; + case SSLv3: + opt = SSL_OP_NO_SSLv2; + break; + default: + // do nothing + break; + } + int newOpt = SSL_CTX_set_options(ctx_, opt); + DCHECK((newOpt & opt) == opt); + + SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY); + + checkPeerName_ = false; + +#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT) + SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback); + SSL_CTX_set_tlsext_servername_arg(ctx_, this); +#endif +} + +SSLContext::~SSLContext() { + if (ctx_ != nullptr) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } + +#ifdef OPENSSL_NPN_NEGOTIATED + deleteNextProtocolsStrings(); +#endif + + std::lock_guard g(mutex_); + if (!--count_) { + cleanupOpenSSL(); + } +} + +void SSLContext::ciphers(const std::string& ciphers) { + providedCiphersString_ = ciphers; + setCiphersOrThrow(ciphers); +} + +void SSLContext::setCiphersOrThrow(const std::string& ciphers) { + int rc = SSL_CTX_set_cipher_list(ctx_, ciphers.c_str()); + if (ERR_peek_error() != 0) { + throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors()); + } + if (rc == 0) { + throw std::runtime_error("None of specified ciphers are supported"); + } +} + +void SSLContext::setVerificationOption(const SSLContext::SSLVerifyPeerEnum& + verifyPeer) { + CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX); // dont recurse + verifyPeer_ = verifyPeer; +} + +int SSLContext::getVerificationMode(const SSLContext::SSLVerifyPeerEnum& + verifyPeer) { + CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX); + int mode = SSL_VERIFY_NONE; + switch(verifyPeer) { + // case SSLVerifyPeerEnum::USE_CTX: // can't happen + // break; + + case SSLVerifyPeerEnum::VERIFY: + mode = SSL_VERIFY_PEER; + break; + + case SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT: + mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT; + break; + + case SSLVerifyPeerEnum::NO_VERIFY: + mode = SSL_VERIFY_NONE; + break; + + default: + break; + } + return mode; +} + +int SSLContext::getVerificationMode() { + return getVerificationMode(verifyPeer_); +} + +void SSLContext::authenticate(bool checkPeerCert, bool checkPeerName, + const std::string& peerName) { + int mode; + if (checkPeerCert) { + mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE; + checkPeerName_ = checkPeerName; + peerFixedName_ = peerName; + } else { + mode = SSL_VERIFY_NONE; + checkPeerName_ = false; // can't check name without cert! + peerFixedName_.clear(); + } + SSL_CTX_set_verify(ctx_, mode, nullptr); +} + +void SSLContext::loadCertificate(const char* path, const char* format) { + if (path == nullptr || format == nullptr) { + throw std::invalid_argument( + "loadCertificateChain: either or is nullptr"); + } + if (strcmp(format, "PEM") == 0) { + if (SSL_CTX_use_certificate_chain_file(ctx_, path) == 0) { + int errnoCopy = errno; + std::string reason("SSL_CTX_use_certificate_chain_file: "); + reason.append(path); + reason.append(": "); + reason.append(getErrors(errnoCopy)); + throw std::runtime_error(reason); + } + } else { + throw std::runtime_error("Unsupported certificate format: " + std::string(format)); + } +} + +void SSLContext::loadPrivateKey(const char* path, const char* format) { + if (path == nullptr || format == nullptr) { + throw std::invalid_argument( + "loadPrivateKey: either or is nullptr"); + } + if (strcmp(format, "PEM") == 0) { + if (SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM) == 0) { + throw std::runtime_error("SSL_CTX_use_PrivateKey_file: " + getErrors()); + } + } else { + throw std::runtime_error("Unsupported private key format: " + std::string(format)); + } +} + +void SSLContext::loadTrustedCertificates(const char* path) { + if (path == nullptr) { + throw std::invalid_argument( + "loadTrustedCertificates: is nullptr"); + } + if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) { + throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors()); + } +} + +void SSLContext::loadTrustedCertificates(X509_STORE* store) { + SSL_CTX_set_cert_store(ctx_, store); +} + +void SSLContext::loadClientCAList(const char* path) { + auto clientCAs = SSL_load_client_CA_file(path); + if (clientCAs == nullptr) { + LOG(ERROR) << "Unable to load ca file: " << path; + return; + } + SSL_CTX_set_client_CA_list(ctx_, clientCAs); +} + +void SSLContext::randomize() { + RAND_poll(); +} + +void SSLContext::passwordCollector(std::shared_ptr collector) { + if (collector == nullptr) { + LOG(ERROR) << "passwordCollector: ignore invalid password collector"; + return; + } + collector_ = collector; + SSL_CTX_set_default_passwd_cb(ctx_, passwordCallback); + SSL_CTX_set_default_passwd_cb_userdata(ctx_, this); +} + +#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT) + +void SSLContext::setServerNameCallback(const ServerNameCallback& cb) { + serverNameCb_ = cb; +} + +void SSLContext::addClientHelloCallback(const ClientHelloCallback& cb) { + clientHelloCbs_.push_back(cb); +} + +int SSLContext::baseServerNameOpenSSLCallback(SSL* ssl, int* al, void* data) { + SSLContext* context = (SSLContext*)data; + + if (context == nullptr) { + return SSL_TLSEXT_ERR_NOACK; + } + + for (auto& cb : context->clientHelloCbs_) { + // Generic callbacks to happen after we receive the Client Hello. + // For example, we use one to switch which cipher we use depending + // on the user's TLS version. Because the primary purpose of + // baseServerNameOpenSSLCallback is for SNI support, and these callbacks + // are side-uses, we ignore any possible failures other than just logging + // them. + cb(ssl); + } + + if (!context->serverNameCb_) { + return SSL_TLSEXT_ERR_NOACK; + } + + ServerNameCallbackResult ret = context->serverNameCb_(ssl); + switch (ret) { + case SERVER_NAME_FOUND: + return SSL_TLSEXT_ERR_OK; + case SERVER_NAME_NOT_FOUND: + return SSL_TLSEXT_ERR_NOACK; + case SERVER_NAME_NOT_FOUND_ALERT_FATAL: + *al = TLS1_AD_UNRECOGNIZED_NAME; + return SSL_TLSEXT_ERR_ALERT_FATAL; + default: + CHECK(false); + } + + return SSL_TLSEXT_ERR_NOACK; +} + +void SSLContext::switchCiphersIfTLS11( + SSL* ssl, + const std::string& tls11CipherString) { + + CHECK(!tls11CipherString.empty()) << "Shouldn't call if empty alt ciphers"; + + if (TLS1_get_client_version(ssl) <= TLS1_VERSION) { + // We only do this for TLS v 1.1 and later + return; + } + + // Prefer AES for TLS versions 1.1 and later since these are not + // vulnerable to BEAST attacks on AES. Note that we're setting the + // cipher list on the SSL object, not the SSL_CTX object, so it will + // only last for this request. + int rc = SSL_set_cipher_list(ssl, tls11CipherString.c_str()); + if ((rc == 0) || ERR_peek_error() != 0) { + // This shouldn't happen since we checked for this when proxygen + // started up. + LOG(WARNING) << "ssl_cipher: No specified ciphers supported for switch"; + SSL_set_cipher_list(ssl, providedCiphersString_.c_str()); + } +} +#endif + +#ifdef OPENSSL_NPN_NEGOTIATED +bool SSLContext::setAdvertisedNextProtocols(const std::list& protocols) { + return setRandomizedAdvertisedNextProtocols({{1, protocols}}); +} + +bool SSLContext::setRandomizedAdvertisedNextProtocols( + const std::list& items) { + unsetNextProtocols(); + if (items.size() == 0) { + return false; + } + int total_weight = 0; + for (const auto &item : items) { + if (item.protocols.size() == 0) { + continue; + } + AdvertisedNextProtocolsItem advertised_item; + advertised_item.length = 0; + for (const auto& proto : item.protocols) { + ++advertised_item.length; + unsigned protoLength = proto.length(); + if (protoLength >= 256) { + deleteNextProtocolsStrings(); + return false; + } + advertised_item.length += protoLength; + } + advertised_item.protocols = new unsigned char[advertised_item.length]; + if (!advertised_item.protocols) { + throw std::runtime_error("alloc failure"); + } + unsigned char* dst = advertised_item.protocols; + for (auto& proto : item.protocols) { + unsigned protoLength = proto.length(); + *dst++ = (unsigned char)protoLength; + memcpy(dst, proto.data(), protoLength); + dst += protoLength; + } + total_weight += item.weight; + advertised_item.probability = item.weight; + advertisedNextProtocols_.push_back(advertised_item); + } + if (total_weight == 0) { + deleteNextProtocolsStrings(); + return false; + } + for (auto &advertised_item : advertisedNextProtocols_) { + advertised_item.probability /= total_weight; + } + SSL_CTX_set_next_protos_advertised_cb( + ctx_, advertisedNextProtocolCallback, this); + SSL_CTX_set_next_proto_select_cb( + ctx_, selectNextProtocolCallback, this); + return true; +} + +void SSLContext::deleteNextProtocolsStrings() { + for (auto protocols : advertisedNextProtocols_) { + delete[] protocols.protocols; + } + advertisedNextProtocols_.clear(); +} + +void SSLContext::unsetNextProtocols() { + deleteNextProtocolsStrings(); + SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr); + SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr); +} + +int SSLContext::advertisedNextProtocolCallback(SSL* ssl, + const unsigned char** out, unsigned int* outlen, void* data) { + SSLContext* context = (SSLContext*)data; + if (context == nullptr || context->advertisedNextProtocols_.empty()) { + *out = nullptr; + *outlen = 0; + } else if (context->advertisedNextProtocols_.size() == 1) { + *out = context->advertisedNextProtocols_[0].protocols; + *outlen = context->advertisedNextProtocols_[0].length; + } else { + uintptr_t selected_index = reinterpret_cast(SSL_get_ex_data(ssl, + sNextProtocolsExDataIndex_)); + if (selected_index) { + --selected_index; + *out = context->advertisedNextProtocols_[selected_index].protocols; + *outlen = context->advertisedNextProtocols_[selected_index].length; + } else { + unsigned char random_byte; + RAND_bytes(&random_byte, 1); + double random_value = random_byte / 255.0; + double sum = 0; + for (size_t i = 0; i < context->advertisedNextProtocols_.size(); ++i) { + sum += context->advertisedNextProtocols_[i].probability; + if (sum < random_value && + i + 1 < context->advertisedNextProtocols_.size()) { + continue; + } + uintptr_t selected = i + 1; + SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void *)selected); + *out = context->advertisedNextProtocols_[i].protocols; + *outlen = context->advertisedNextProtocols_[i].length; + break; + } + } + } + return SSL_TLSEXT_ERR_OK; +} + +int SSLContext::selectNextProtocolCallback( + SSL* ssl, unsigned char **out, unsigned char *outlen, + const unsigned char *server, unsigned int server_len, void *data) { + + SSLContext* ctx = (SSLContext*)data; + if (ctx->advertisedNextProtocols_.size() > 1) { + VLOG(3) << "SSLContext::selectNextProcolCallback() " + << "client should be deterministic in selecting protocols."; + } + + unsigned char *client; + int client_len; + if (ctx->advertisedNextProtocols_.empty()) { + client = (unsigned char *) ""; + client_len = 0; + } else { + client = ctx->advertisedNextProtocols_[0].protocols; + client_len = ctx->advertisedNextProtocols_[0].length; + } + + int retval = SSL_select_next_proto(out, outlen, server, server_len, + client, client_len); + if (retval != OPENSSL_NPN_NEGOTIATED) { + VLOG(3) << "SSLContext::selectNextProcolCallback() " + << "unable to pick a next protocol."; + } + return SSL_TLSEXT_ERR_OK; +} +#endif // OPENSSL_NPN_NEGOTIATED + +SSL* SSLContext::createSSL() const { + SSL* ssl = SSL_new(ctx_); + if (ssl == nullptr) { + throw std::runtime_error("SSL_new: " + getErrors()); + } + return ssl; +} + +/** + * Match a name with a pattern. The pattern may include wildcard. A single + * wildcard "*" can match up to one component in the domain name. + * + * @param host Host name, typically the name of the remote host + * @param pattern Name retrieved from certificate + * @param size Size of "pattern" + * @return True, if "host" matches "pattern". False otherwise. + */ +bool SSLContext::matchName(const char* host, const char* pattern, int size) { + bool match = false; + int i = 0, j = 0; + while (i < size && host[j] != '\0') { + if (toupper(pattern[i]) == toupper(host[j])) { + i++; + j++; + continue; + } + if (pattern[i] == '*') { + while (host[j] != '.' && host[j] != '\0') { + j++; + } + i++; + continue; + } + break; + } + if (i == size && host[j] == '\0') { + match = true; + } + return match; +} + +int SSLContext::passwordCallback(char* password, + int size, + int, + void* data) { + SSLContext* context = (SSLContext*)data; + if (context == nullptr || context->passwordCollector() == nullptr) { + return 0; + } + std::string userPassword; + // call user defined password collector to get password + context->passwordCollector()->getPassword(userPassword, size); + int length = userPassword.size(); + if (length > size) { + length = size; + } + strncpy(password, userPassword.c_str(), length); + return length; +} + +struct SSLLock { + explicit SSLLock( + SSLContext::SSLLockType inLockType = SSLContext::LOCK_MUTEX) : + lockType(inLockType) { + } + + void lock() { + if (lockType == SSLContext::LOCK_MUTEX) { + mutex.lock(); + } else if (lockType == SSLContext::LOCK_SPINLOCK) { + spinLock.lock(); + } + // lockType == LOCK_NONE, no-op + } + + void unlock() { + if (lockType == SSLContext::LOCK_MUTEX) { + mutex.unlock(); + } else if (lockType == SSLContext::LOCK_SPINLOCK) { + spinLock.unlock(); + } + // lockType == LOCK_NONE, no-op + } + + SSLContext::SSLLockType lockType; + folly::io::PortableSpinLock spinLock{}; + std::mutex mutex; +}; + +static std::map lockTypes; +static std::unique_ptr locks; + +static void callbackLocking(int mode, int n, const char*, int) { + if (mode & CRYPTO_LOCK) { + locks[n].lock(); + } else { + locks[n].unlock(); + } +} + +static unsigned long callbackThreadID() { + return static_cast(pthread_self()); +} + +static CRYPTO_dynlock_value* dyn_create(const char*, int) { + return new CRYPTO_dynlock_value; +} + +static void dyn_lock(int mode, + struct CRYPTO_dynlock_value* lock, + const char*, int) { + if (lock != nullptr) { + if (mode & CRYPTO_LOCK) { + lock->mutex.lock(); + } else { + lock->mutex.unlock(); + } + } +} + +static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) { + delete lock; +} + +void SSLContext::setSSLLockTypes(std::map inLockTypes) { + lockTypes = inLockTypes; +} + +void SSLContext::initializeOpenSSL() { + SSL_library_init(); + SSL_load_error_strings(); + ERR_load_crypto_strings(); + // static locking + locks.reset(new SSLLock[::CRYPTO_num_locks()]); + for (auto it: lockTypes) { + locks[it.first].lockType = it.second; + } + CRYPTO_set_id_callback(callbackThreadID); + CRYPTO_set_locking_callback(callbackLocking); + // dynamic locking + CRYPTO_set_dynlock_create_callback(dyn_create); + CRYPTO_set_dynlock_lock_callback(dyn_lock); + CRYPTO_set_dynlock_destroy_callback(dyn_destroy); +} + +void SSLContext::cleanupOpenSSL() { + CRYPTO_set_id_callback(nullptr); + CRYPTO_set_locking_callback(nullptr); + CRYPTO_set_dynlock_create_callback(nullptr); + CRYPTO_set_dynlock_lock_callback(nullptr); + CRYPTO_set_dynlock_destroy_callback(nullptr); + CRYPTO_cleanup_all_ex_data(); + ERR_free_strings(); + EVP_cleanup(); + ERR_remove_state(0); + locks.reset(); +} + +void SSLContext::setOptions(long options) { + long newOpt = SSL_CTX_set_options(ctx_, options); + if ((newOpt & options) != options) { + throw std::runtime_error("SSL_CTX_set_options failed"); + } +} + +std::string SSLContext::getErrors(int errnoCopy) { + std::string errors; + unsigned long errorCode; + char message[256]; + + errors.reserve(512); + while ((errorCode = ERR_get_error()) != 0) { + if (!errors.empty()) { + errors += "; "; + } + const char* reason = ERR_reason_error_string(errorCode); + if (reason == nullptr) { + snprintf(message, sizeof(message) - 1, "SSL error # %lu", errorCode); + reason = message; + } + errors += reason; + } + if (errors.empty()) { + errors = "error code: " + folly::to(errnoCopy); + } + return errors; +} + +std::ostream& +operator<<(std::ostream& os, const PasswordCollector& collector) { + os << collector.describe(); + return os; +} + +} // folly diff --git a/folly/io/async/SSLContext.h b/folly/io/async/SSLContext.h new file mode 100644 index 00000000..5819f68f --- /dev/null +++ b/folly/io/async/SSLContext.h @@ -0,0 +1,453 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +namespace folly { + +/** + * Override the default password collector. + */ +class PasswordCollector { + public: + virtual ~PasswordCollector() {} + /** + * Interface for customizing how to collect private key password. + * + * By default, OpenSSL prints a prompt on screen and request for password + * while loading private key. To implement a custom password collector, + * implement this interface and register it with TSSLSocketFactory. + * + * @param password Pass collected password back to OpenSSL + * @param size Maximum length of password including nullptr character + */ + virtual void getPassword(std::string& password, int size) = 0; + + /** + * Return a description of this collector for logging purposes + */ + virtual std::string describe() const = 0; +}; + +/** + * Wrap OpenSSL SSL_CTX into a class. + */ +class SSLContext { + public: + + enum SSLVersion { + SSLv2, + SSLv3, + TLSv1 + }; + + enum SSLVerifyPeerEnum{ + USE_CTX, + VERIFY, + VERIFY_REQ_CLIENT_CERT, + NO_VERIFY + }; + + struct NextProtocolsItem { + int weight; + std::list protocols; + }; + + struct AdvertisedNextProtocolsItem { + unsigned char *protocols; + unsigned length; + double probability; + }; + + /** + * Convenience function to call getErrors() with the current errno value. + * + * Make sure that you only call this when there was no intervening operation + * since the last OpenSSL error that may have changed the current errno value. + */ + static std::string getErrors() { + return getErrors(errno); + } + + /** + * Constructor. + * + * @param version The lowest or oldest SSL version to support. + */ + explicit SSLContext(SSLVersion version = TLSv1); + virtual ~SSLContext(); + + /** + * Set default ciphers to be used in SSL handshake process. + * + * @param ciphers A list of ciphers to use for TLSv1.0 + */ + virtual void ciphers(const std::string& ciphers); + + /** + * Low-level method that attempts to set the provided ciphers on the + * SSL_CTX object, and throws if something goes wrong. + */ + virtual void setCiphersOrThrow(const std::string& ciphers); + + /** + * Method to set verification option in the context object. + * + * @param verifyPeer SSLVerifyPeerEnum indicating the verification + * method to use. + */ + virtual void setVerificationOption(const SSLVerifyPeerEnum& verifyPeer); + + /** + * Method to check if peer verfication is set. + * + * @return true if peer verification is required. + * + */ + virtual bool needsPeerVerification() { + return (verifyPeer_ == SSLVerifyPeerEnum::VERIFY || + verifyPeer_ == SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT); + } + + /** + * Method to fetch Verification mode for a SSLVerifyPeerEnum. + * verifyPeer cannot be SSLVerifyPeerEnum::USE_CTX since there is no + * context. + * + * @param verifyPeer SSLVerifyPeerEnum for which the flags need to + * to be returned + * + * @return mode flags that can be used with SSL_set_verify + */ + static int getVerificationMode(const SSLVerifyPeerEnum& verifyPeer); + + /** + * Method to fetch Verification mode determined by the options + * set using setVerificationOption. + * + * @return mode flags that can be used with SSL_set_verify + */ + virtual int getVerificationMode(); + + /** + * Enable/Disable authentication. Peer name validation can only be done + * if checkPeerCert is true. + * + * @param checkPeerCert If true, require peer to present valid certificate + * @param checkPeerName If true, validate that the certificate common name + * or alternate name(s) of peer matches the hostname + * used to connect. + * @param peerName If non-empty, validate that the certificate common + * name of peer matches the given string (altername + * name(s) are not used in this case). + */ + virtual void authenticate(bool checkPeerCert, bool checkPeerName, + const std::string& peerName = std::string()); + /** + * Load server certificate. + * + * @param path Path to the certificate file + * @param format Certificate file format + */ + virtual void loadCertificate(const char* path, const char* format = "PEM"); + /** + * Load private key. + * + * @param path Path to the private key file + * @param format Private key file format + */ + virtual void loadPrivateKey(const char* path, const char* format = "PEM"); + /** + * Load trusted certificates from specified file. + * + * @param path Path to trusted certificate file + */ + virtual void loadTrustedCertificates(const char* path); + /** + * Load trusted certificates from specified X509 certificate store. + * + * @param store X509 certificate store. + */ + virtual void loadTrustedCertificates(X509_STORE* store); + /** + * Load a client CA list for validating clients + */ + virtual void loadClientCAList(const char* path); + /** + * Default randomize method. + */ + virtual void randomize(); + /** + * Override default OpenSSL password collector. + * + * @param collector Instance of user defined password collector + */ + virtual void passwordCollector(std::shared_ptr collector); + /** + * Obtain password collector. + * + * @return User defined password collector + */ + virtual std::shared_ptr passwordCollector() { + return collector_; + } +#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT) + /** + * Provide SNI support + */ + enum ServerNameCallbackResult { + SERVER_NAME_FOUND, + SERVER_NAME_NOT_FOUND, + SERVER_NAME_NOT_FOUND_ALERT_FATAL, + }; + /** + * Callback function from openssl to give the application a + * chance to check the tlsext_hostname just right after parsing + * the Client Hello or Server Hello message. + * + * It is for the server to switch the SSL to another SSL_CTX + * to continue the handshake. (i.e. Server Name Indication, SNI, in RFC6066). + * + * If the ServerNameCallback returns: + * SERVER_NAME_FOUND: + * server: Send a tlsext_hostname in the Server Hello + * client: No-effect + * SERVER_NAME_NOT_FOUND: + * server: Does not send a tlsext_hostname in Server Hello + * and continue the handshake. + * client: No-effect + * SERVER_NAME_NOT_FOUND_ALERT_FATAL: + * server and client: Send fatal TLS1_AD_UNRECOGNIZED_NAME alert to + * the peer. + * + * Quote from RFC 6066: + * "... + * If the server understood the ClientHello extension but + * does not recognize the server name, the server SHOULD take one of two + * actions: either abort the handshake by sending a fatal-level + * unrecognized_name(112) alert or continue the handshake. It is NOT + * RECOMMENDED to send a warning-level unrecognized_name(112) alert, + * because the client's behavior in response to warning-level alerts is + * unpredictable. + * ..." + */ + + /** + * Set the ServerNameCallback + */ + typedef std::function ServerNameCallback; + virtual void setServerNameCallback(const ServerNameCallback& cb); + + /** + * Generic callbacks that are run after we get the Client Hello (right + * before we run the ServerNameCallback) + */ + typedef std::function ClientHelloCallback; + virtual void addClientHelloCallback(const ClientHelloCallback& cb); +#endif + + /** + * Create an SSL object from this context. + */ + SSL* createSSL() const; + + /** + * Set the options on the SSL_CTX object. + */ + void setOptions(long options); + +#ifdef OPENSSL_NPN_NEGOTIATED + /** + * Set the list of protocols that this SSL context supports. In server + * mode, this is the list of protocols that will be advertised for Next + * Protocol Negotiation (NPN). In client mode, the first protocol + * advertised by the server that is also on this list is + * chosen. Invoking this function with a list of length zero causes NPN + * to be disabled. + * + * @param protocols List of protocol names. This method makes a copy, + * so the caller needn't keep the list in scope after + * the call completes. The list must have at least + * one element to enable NPN. Each element must have + * a string length < 256. + * @return true if NPN has been activated. False if NPN is disabled. + */ + bool setAdvertisedNextProtocols(const std::list& protocols); + /** + * Set weighted list of lists of protocols that this SSL context supports. + * In server mode, each element of the list contains a list of protocols that + * could be advertised for Next Protocol Negotiation (NPN). The list of + * protocols that will be advertised to a client is selected randomly, based + * on weights of elements. Client mode doesn't support randomized NPN, so + * this list should contain only 1 element. The first protocol advertised + * by the server that is also on the list of protocols of this element is + * chosen. Invoking this function with a list of length zero causes NPN + * to be disabled. + * + * @param items List of NextProtocolsItems, Each item contains a list of + * protocol names and weight. After the call of this fucntion + * each non-empty list of protocols will be advertised with + * probability weight/sum_of_weights. This method makes a copy, + * so the caller needn't keep the list in scope after the call + * completes. The list must have at least one element with + * non-zero weight and non-empty protocols list to enable NPN. + * Each name of the protocol must have a string length < 256. + * @return true if NPN has been activated. False if NPN is disabled. + */ + bool setRandomizedAdvertisedNextProtocols( + const std::list& items); + + /** + * Disables NPN on this SSL context. + */ + void unsetNextProtocols(); + void deleteNextProtocolsStrings(); +#endif // OPENSSL_NPN_NEGOTIATED + + /** + * Gets the underlying SSL_CTX for advanced usage + */ + SSL_CTX *getSSLCtx() const { + return ctx_; + } + + enum SSLLockType { + LOCK_MUTEX, + LOCK_SPINLOCK, + LOCK_NONE + }; + + /** + * Set preferences for how to treat locks in OpenSSL. This must be + * called before the instantiation of any SSLContext objects, otherwise + * the defaults will be used. + * + * OpenSSL has a lock for each module rather than for each object or + * data that needs locking. Some locks protect only refcounts, and + * might be better as spinlocks rather than mutexes. Other locks + * may be totally unnecessary if the objects being protected are not + * shared between threads in the application. + * + * By default, all locks are initialized as mutexes. OpenSSL's lock usage + * may change from version to version and you should know what you are doing + * before disabling any locks entirely. + * + * Example: if you don't share SSL sessions between threads in your + * application, you may be able to do this + * + * setSSLLockTypes({{CRYPTO_LOCK_SSL_SESSION, SSLContext::LOCK_NONE}}) + */ + static void setSSLLockTypes(std::map lockTypes); + + /** + * Examine OpenSSL's error stack, and return a string description of the + * errors. + * + * This operation removes the errors from OpenSSL's error stack. + */ + static std::string getErrors(int errnoCopy); + + /** + * We want to vary which cipher we'll use based on the client's TLS version. + */ + void switchCiphersIfTLS11( + SSL* ssl, + const std::string& tls11CipherString + ); + + bool checkPeerName() { return checkPeerName_; } + std::string peerFixedName() { return peerFixedName_; } + + /** + * Helper to match a hostname versus a pattern. + */ + static bool matchName(const char* host, const char* pattern, int size); + + protected: + SSL_CTX* ctx_; + + private: + SSLVerifyPeerEnum verifyPeer_{SSLVerifyPeerEnum::NO_VERIFY}; + + bool checkPeerName_; + std::string peerFixedName_; + std::shared_ptr collector_; +#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT) + ServerNameCallback serverNameCb_; + std::vector clientHelloCbs_; +#endif + + static std::mutex mutex_; + static uint64_t count_; + +#ifdef OPENSSL_NPN_NEGOTIATED + /** + * Wire-format list of advertised protocols for use in NPN. + */ + std::vector advertisedNextProtocols_; + static int sNextProtocolsExDataIndex_; + + static int advertisedNextProtocolCallback(SSL* ssl, + const unsigned char** out, unsigned int* outlen, void* data); + static int selectNextProtocolCallback( + SSL* ssl, unsigned char **out, unsigned char *outlen, + const unsigned char *server, unsigned int server_len, void *args); +#endif // OPENSSL_NPN_NEGOTIATED + + static int passwordCallback(char* password, int size, int, void* data); + + static void initializeOpenSSL(); + static void cleanupOpenSSL(); + + +#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT) + /** + * The function that will be called directly from openssl + * in order for the application to get the tlsext_hostname just after + * parsing the Client Hello or Server Hello message. It will then call + * the serverNameCb_ function object. Hence, it is sort of a + * wrapper/proxy between serverNameCb_ and openssl. + * + * The openssl's primary intention is for SNI support, but we also use it + * generically for performing logic after the Client Hello comes in. + */ + static int baseServerNameOpenSSLCallback( + SSL* ssl, + int* al /* alert (return value) */, + void* data + ); +#endif + + std::string providedCiphersString_; +}; + +typedef std::shared_ptr SSLContextPtr; + +std::ostream& operator<<(std::ostream& os, const folly::PasswordCollector& collector); + +} // folly -- 2.34.1