--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "SSLContext.h"
+
+#include <openssl/err.h>
+#include <openssl/rand.h>
+#include <openssl/ssl.h>
+#include <openssl/x509v3.h>
+
+#include <folly/SmallLocks.h>
+#include <folly/Format.h>
+#include <folly/io/PortableSpinLock.h>
+
+// ---------------------------------------------------------------------
+// SSLContext implementation
+// ---------------------------------------------------------------------
+
+struct CRYPTO_dynlock_value {
+ std::mutex mutex;
+};
+
+namespace folly {
+
+uint64_t SSLContext::count_ = 0;
+std::mutex SSLContext::mutex_;
+#ifdef OPENSSL_NPN_NEGOTIATED
+int SSLContext::sNextProtocolsExDataIndex_ = -1;
+
+#endif
+// SSLContext implementation
+SSLContext::SSLContext(SSLVersion version) {
+ {
+ std::lock_guard<std::mutex> g(mutex_);
+ if (!count_++) {
+ initializeOpenSSL();
+ randomize();
+#ifdef OPENSSL_NPN_NEGOTIATED
+ sNextProtocolsExDataIndex_ = SSL_get_ex_new_index(0,
+ (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
+#endif
+ }
+ }
+
+ ctx_ = SSL_CTX_new(SSLv23_method());
+ if (ctx_ == nullptr) {
+ throw std::runtime_error("SSL_CTX_new: " + getErrors());
+ }
+
+ int opt = 0;
+ switch (version) {
+ case TLSv1:
+ opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
+ break;
+ case SSLv3:
+ opt = SSL_OP_NO_SSLv2;
+ break;
+ default:
+ // do nothing
+ break;
+ }
+ int newOpt = SSL_CTX_set_options(ctx_, opt);
+ DCHECK((newOpt & opt) == opt);
+
+ SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
+
+ checkPeerName_ = false;
+
+#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
+ SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
+ SSL_CTX_set_tlsext_servername_arg(ctx_, this);
+#endif
+}
+
+SSLContext::~SSLContext() {
+ if (ctx_ != nullptr) {
+ SSL_CTX_free(ctx_);
+ ctx_ = nullptr;
+ }
+
+#ifdef OPENSSL_NPN_NEGOTIATED
+ deleteNextProtocolsStrings();
+#endif
+
+ std::lock_guard<std::mutex> g(mutex_);
+ if (!--count_) {
+ cleanupOpenSSL();
+ }
+}
+
+void SSLContext::ciphers(const std::string& ciphers) {
+ providedCiphersString_ = ciphers;
+ setCiphersOrThrow(ciphers);
+}
+
+void SSLContext::setCiphersOrThrow(const std::string& ciphers) {
+ int rc = SSL_CTX_set_cipher_list(ctx_, ciphers.c_str());
+ if (ERR_peek_error() != 0) {
+ throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors());
+ }
+ if (rc == 0) {
+ throw std::runtime_error("None of specified ciphers are supported");
+ }
+}
+
+void SSLContext::setVerificationOption(const SSLContext::SSLVerifyPeerEnum&
+ verifyPeer) {
+ CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX); // dont recurse
+ verifyPeer_ = verifyPeer;
+}
+
+int SSLContext::getVerificationMode(const SSLContext::SSLVerifyPeerEnum&
+ verifyPeer) {
+ CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX);
+ int mode = SSL_VERIFY_NONE;
+ switch(verifyPeer) {
+ // case SSLVerifyPeerEnum::USE_CTX: // can't happen
+ // break;
+
+ case SSLVerifyPeerEnum::VERIFY:
+ mode = SSL_VERIFY_PEER;
+ break;
+
+ case SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT:
+ mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
+ break;
+
+ case SSLVerifyPeerEnum::NO_VERIFY:
+ mode = SSL_VERIFY_NONE;
+ break;
+
+ default:
+ break;
+ }
+ return mode;
+}
+
+int SSLContext::getVerificationMode() {
+ return getVerificationMode(verifyPeer_);
+}
+
+void SSLContext::authenticate(bool checkPeerCert, bool checkPeerName,
+ const std::string& peerName) {
+ int mode;
+ if (checkPeerCert) {
+ mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE;
+ checkPeerName_ = checkPeerName;
+ peerFixedName_ = peerName;
+ } else {
+ mode = SSL_VERIFY_NONE;
+ checkPeerName_ = false; // can't check name without cert!
+ peerFixedName_.clear();
+ }
+ SSL_CTX_set_verify(ctx_, mode, nullptr);
+}
+
+void SSLContext::loadCertificate(const char* path, const char* format) {
+ if (path == nullptr || format == nullptr) {
+ throw std::invalid_argument(
+ "loadCertificateChain: either <path> or <format> is nullptr");
+ }
+ if (strcmp(format, "PEM") == 0) {
+ if (SSL_CTX_use_certificate_chain_file(ctx_, path) == 0) {
+ int errnoCopy = errno;
+ std::string reason("SSL_CTX_use_certificate_chain_file: ");
+ reason.append(path);
+ reason.append(": ");
+ reason.append(getErrors(errnoCopy));
+ throw std::runtime_error(reason);
+ }
+ } else {
+ throw std::runtime_error("Unsupported certificate format: " + std::string(format));
+ }
+}
+
+void SSLContext::loadPrivateKey(const char* path, const char* format) {
+ if (path == nullptr || format == nullptr) {
+ throw std::invalid_argument(
+ "loadPrivateKey: either <path> or <format> is nullptr");
+ }
+ if (strcmp(format, "PEM") == 0) {
+ if (SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM) == 0) {
+ throw std::runtime_error("SSL_CTX_use_PrivateKey_file: " + getErrors());
+ }
+ } else {
+ throw std::runtime_error("Unsupported private key format: " + std::string(format));
+ }
+}
+
+void SSLContext::loadTrustedCertificates(const char* path) {
+ if (path == nullptr) {
+ throw std::invalid_argument(
+ "loadTrustedCertificates: <path> is nullptr");
+ }
+ if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) {
+ throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors());
+ }
+}
+
+void SSLContext::loadTrustedCertificates(X509_STORE* store) {
+ SSL_CTX_set_cert_store(ctx_, store);
+}
+
+void SSLContext::loadClientCAList(const char* path) {
+ auto clientCAs = SSL_load_client_CA_file(path);
+ if (clientCAs == nullptr) {
+ LOG(ERROR) << "Unable to load ca file: " << path;
+ return;
+ }
+ SSL_CTX_set_client_CA_list(ctx_, clientCAs);
+}
+
+void SSLContext::randomize() {
+ RAND_poll();
+}
+
+void SSLContext::passwordCollector(std::shared_ptr<PasswordCollector> collector) {
+ if (collector == nullptr) {
+ LOG(ERROR) << "passwordCollector: ignore invalid password collector";
+ return;
+ }
+ collector_ = collector;
+ SSL_CTX_set_default_passwd_cb(ctx_, passwordCallback);
+ SSL_CTX_set_default_passwd_cb_userdata(ctx_, this);
+}
+
+#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
+
+void SSLContext::setServerNameCallback(const ServerNameCallback& cb) {
+ serverNameCb_ = cb;
+}
+
+void SSLContext::addClientHelloCallback(const ClientHelloCallback& cb) {
+ clientHelloCbs_.push_back(cb);
+}
+
+int SSLContext::baseServerNameOpenSSLCallback(SSL* ssl, int* al, void* data) {
+ SSLContext* context = (SSLContext*)data;
+
+ if (context == nullptr) {
+ return SSL_TLSEXT_ERR_NOACK;
+ }
+
+ for (auto& cb : context->clientHelloCbs_) {
+ // Generic callbacks to happen after we receive the Client Hello.
+ // For example, we use one to switch which cipher we use depending
+ // on the user's TLS version. Because the primary purpose of
+ // baseServerNameOpenSSLCallback is for SNI support, and these callbacks
+ // are side-uses, we ignore any possible failures other than just logging
+ // them.
+ cb(ssl);
+ }
+
+ if (!context->serverNameCb_) {
+ return SSL_TLSEXT_ERR_NOACK;
+ }
+
+ ServerNameCallbackResult ret = context->serverNameCb_(ssl);
+ switch (ret) {
+ case SERVER_NAME_FOUND:
+ return SSL_TLSEXT_ERR_OK;
+ case SERVER_NAME_NOT_FOUND:
+ return SSL_TLSEXT_ERR_NOACK;
+ case SERVER_NAME_NOT_FOUND_ALERT_FATAL:
+ *al = TLS1_AD_UNRECOGNIZED_NAME;
+ return SSL_TLSEXT_ERR_ALERT_FATAL;
+ default:
+ CHECK(false);
+ }
+
+ return SSL_TLSEXT_ERR_NOACK;
+}
+
+void SSLContext::switchCiphersIfTLS11(
+ SSL* ssl,
+ const std::string& tls11CipherString) {
+
+ CHECK(!tls11CipherString.empty()) << "Shouldn't call if empty alt ciphers";
+
+ if (TLS1_get_client_version(ssl) <= TLS1_VERSION) {
+ // We only do this for TLS v 1.1 and later
+ return;
+ }
+
+ // Prefer AES for TLS versions 1.1 and later since these are not
+ // vulnerable to BEAST attacks on AES. Note that we're setting the
+ // cipher list on the SSL object, not the SSL_CTX object, so it will
+ // only last for this request.
+ int rc = SSL_set_cipher_list(ssl, tls11CipherString.c_str());
+ if ((rc == 0) || ERR_peek_error() != 0) {
+ // This shouldn't happen since we checked for this when proxygen
+ // started up.
+ LOG(WARNING) << "ssl_cipher: No specified ciphers supported for switch";
+ SSL_set_cipher_list(ssl, providedCiphersString_.c_str());
+ }
+}
+#endif
+
+#ifdef OPENSSL_NPN_NEGOTIATED
+bool SSLContext::setAdvertisedNextProtocols(const std::list<std::string>& protocols) {
+ return setRandomizedAdvertisedNextProtocols({{1, protocols}});
+}
+
+bool SSLContext::setRandomizedAdvertisedNextProtocols(
+ const std::list<NextProtocolsItem>& items) {
+ unsetNextProtocols();
+ if (items.size() == 0) {
+ return false;
+ }
+ int total_weight = 0;
+ for (const auto &item : items) {
+ if (item.protocols.size() == 0) {
+ continue;
+ }
+ AdvertisedNextProtocolsItem advertised_item;
+ advertised_item.length = 0;
+ for (const auto& proto : item.protocols) {
+ ++advertised_item.length;
+ unsigned protoLength = proto.length();
+ if (protoLength >= 256) {
+ deleteNextProtocolsStrings();
+ return false;
+ }
+ advertised_item.length += protoLength;
+ }
+ advertised_item.protocols = new unsigned char[advertised_item.length];
+ if (!advertised_item.protocols) {
+ throw std::runtime_error("alloc failure");
+ }
+ unsigned char* dst = advertised_item.protocols;
+ for (auto& proto : item.protocols) {
+ unsigned protoLength = proto.length();
+ *dst++ = (unsigned char)protoLength;
+ memcpy(dst, proto.data(), protoLength);
+ dst += protoLength;
+ }
+ total_weight += item.weight;
+ advertised_item.probability = item.weight;
+ advertisedNextProtocols_.push_back(advertised_item);
+ }
+ if (total_weight == 0) {
+ deleteNextProtocolsStrings();
+ return false;
+ }
+ for (auto &advertised_item : advertisedNextProtocols_) {
+ advertised_item.probability /= total_weight;
+ }
+ SSL_CTX_set_next_protos_advertised_cb(
+ ctx_, advertisedNextProtocolCallback, this);
+ SSL_CTX_set_next_proto_select_cb(
+ ctx_, selectNextProtocolCallback, this);
+ return true;
+}
+
+void SSLContext::deleteNextProtocolsStrings() {
+ for (auto protocols : advertisedNextProtocols_) {
+ delete[] protocols.protocols;
+ }
+ advertisedNextProtocols_.clear();
+}
+
+void SSLContext::unsetNextProtocols() {
+ deleteNextProtocolsStrings();
+ SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr);
+ SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr);
+}
+
+int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
+ const unsigned char** out, unsigned int* outlen, void* data) {
+ SSLContext* context = (SSLContext*)data;
+ if (context == nullptr || context->advertisedNextProtocols_.empty()) {
+ *out = nullptr;
+ *outlen = 0;
+ } else if (context->advertisedNextProtocols_.size() == 1) {
+ *out = context->advertisedNextProtocols_[0].protocols;
+ *outlen = context->advertisedNextProtocols_[0].length;
+ } else {
+ uintptr_t selected_index = reinterpret_cast<uintptr_t>(SSL_get_ex_data(ssl,
+ sNextProtocolsExDataIndex_));
+ if (selected_index) {
+ --selected_index;
+ *out = context->advertisedNextProtocols_[selected_index].protocols;
+ *outlen = context->advertisedNextProtocols_[selected_index].length;
+ } else {
+ unsigned char random_byte;
+ RAND_bytes(&random_byte, 1);
+ double random_value = random_byte / 255.0;
+ double sum = 0;
+ for (size_t i = 0; i < context->advertisedNextProtocols_.size(); ++i) {
+ sum += context->advertisedNextProtocols_[i].probability;
+ if (sum < random_value &&
+ i + 1 < context->advertisedNextProtocols_.size()) {
+ continue;
+ }
+ uintptr_t selected = i + 1;
+ SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void *)selected);
+ *out = context->advertisedNextProtocols_[i].protocols;
+ *outlen = context->advertisedNextProtocols_[i].length;
+ break;
+ }
+ }
+ }
+ return SSL_TLSEXT_ERR_OK;
+}
+
+int SSLContext::selectNextProtocolCallback(
+ SSL* ssl, unsigned char **out, unsigned char *outlen,
+ const unsigned char *server, unsigned int server_len, void *data) {
+
+ SSLContext* ctx = (SSLContext*)data;
+ if (ctx->advertisedNextProtocols_.size() > 1) {
+ VLOG(3) << "SSLContext::selectNextProcolCallback() "
+ << "client should be deterministic in selecting protocols.";
+ }
+
+ unsigned char *client;
+ int client_len;
+ if (ctx->advertisedNextProtocols_.empty()) {
+ client = (unsigned char *) "";
+ client_len = 0;
+ } else {
+ client = ctx->advertisedNextProtocols_[0].protocols;
+ client_len = ctx->advertisedNextProtocols_[0].length;
+ }
+
+ int retval = SSL_select_next_proto(out, outlen, server, server_len,
+ client, client_len);
+ if (retval != OPENSSL_NPN_NEGOTIATED) {
+ VLOG(3) << "SSLContext::selectNextProcolCallback() "
+ << "unable to pick a next protocol.";
+ }
+ return SSL_TLSEXT_ERR_OK;
+}
+#endif // OPENSSL_NPN_NEGOTIATED
+
+SSL* SSLContext::createSSL() const {
+ SSL* ssl = SSL_new(ctx_);
+ if (ssl == nullptr) {
+ throw std::runtime_error("SSL_new: " + getErrors());
+ }
+ return ssl;
+}
+
+/**
+ * Match a name with a pattern. The pattern may include wildcard. A single
+ * wildcard "*" can match up to one component in the domain name.
+ *
+ * @param host Host name, typically the name of the remote host
+ * @param pattern Name retrieved from certificate
+ * @param size Size of "pattern"
+ * @return True, if "host" matches "pattern". False otherwise.
+ */
+bool SSLContext::matchName(const char* host, const char* pattern, int size) {
+ bool match = false;
+ int i = 0, j = 0;
+ while (i < size && host[j] != '\0') {
+ if (toupper(pattern[i]) == toupper(host[j])) {
+ i++;
+ j++;
+ continue;
+ }
+ if (pattern[i] == '*') {
+ while (host[j] != '.' && host[j] != '\0') {
+ j++;
+ }
+ i++;
+ continue;
+ }
+ break;
+ }
+ if (i == size && host[j] == '\0') {
+ match = true;
+ }
+ return match;
+}
+
+int SSLContext::passwordCallback(char* password,
+ int size,
+ int,
+ void* data) {
+ SSLContext* context = (SSLContext*)data;
+ if (context == nullptr || context->passwordCollector() == nullptr) {
+ return 0;
+ }
+ std::string userPassword;
+ // call user defined password collector to get password
+ context->passwordCollector()->getPassword(userPassword, size);
+ int length = userPassword.size();
+ if (length > size) {
+ length = size;
+ }
+ strncpy(password, userPassword.c_str(), length);
+ return length;
+}
+
+struct SSLLock {
+ explicit SSLLock(
+ SSLContext::SSLLockType inLockType = SSLContext::LOCK_MUTEX) :
+ lockType(inLockType) {
+ }
+
+ void lock() {
+ if (lockType == SSLContext::LOCK_MUTEX) {
+ mutex.lock();
+ } else if (lockType == SSLContext::LOCK_SPINLOCK) {
+ spinLock.lock();
+ }
+ // lockType == LOCK_NONE, no-op
+ }
+
+ void unlock() {
+ if (lockType == SSLContext::LOCK_MUTEX) {
+ mutex.unlock();
+ } else if (lockType == SSLContext::LOCK_SPINLOCK) {
+ spinLock.unlock();
+ }
+ // lockType == LOCK_NONE, no-op
+ }
+
+ SSLContext::SSLLockType lockType;
+ folly::io::PortableSpinLock spinLock{};
+ std::mutex mutex;
+};
+
+static std::map<int, SSLContext::SSLLockType> lockTypes;
+static std::unique_ptr<SSLLock[]> locks;
+
+static void callbackLocking(int mode, int n, const char*, int) {
+ if (mode & CRYPTO_LOCK) {
+ locks[n].lock();
+ } else {
+ locks[n].unlock();
+ }
+}
+
+static unsigned long callbackThreadID() {
+ return static_cast<unsigned long>(pthread_self());
+}
+
+static CRYPTO_dynlock_value* dyn_create(const char*, int) {
+ return new CRYPTO_dynlock_value;
+}
+
+static void dyn_lock(int mode,
+ struct CRYPTO_dynlock_value* lock,
+ const char*, int) {
+ if (lock != nullptr) {
+ if (mode & CRYPTO_LOCK) {
+ lock->mutex.lock();
+ } else {
+ lock->mutex.unlock();
+ }
+ }
+}
+
+static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) {
+ delete lock;
+}
+
+void SSLContext::setSSLLockTypes(std::map<int, SSLLockType> inLockTypes) {
+ lockTypes = inLockTypes;
+}
+
+void SSLContext::initializeOpenSSL() {
+ SSL_library_init();
+ SSL_load_error_strings();
+ ERR_load_crypto_strings();
+ // static locking
+ locks.reset(new SSLLock[::CRYPTO_num_locks()]);
+ for (auto it: lockTypes) {
+ locks[it.first].lockType = it.second;
+ }
+ CRYPTO_set_id_callback(callbackThreadID);
+ CRYPTO_set_locking_callback(callbackLocking);
+ // dynamic locking
+ CRYPTO_set_dynlock_create_callback(dyn_create);
+ CRYPTO_set_dynlock_lock_callback(dyn_lock);
+ CRYPTO_set_dynlock_destroy_callback(dyn_destroy);
+}
+
+void SSLContext::cleanupOpenSSL() {
+ CRYPTO_set_id_callback(nullptr);
+ CRYPTO_set_locking_callback(nullptr);
+ CRYPTO_set_dynlock_create_callback(nullptr);
+ CRYPTO_set_dynlock_lock_callback(nullptr);
+ CRYPTO_set_dynlock_destroy_callback(nullptr);
+ CRYPTO_cleanup_all_ex_data();
+ ERR_free_strings();
+ EVP_cleanup();
+ ERR_remove_state(0);
+ locks.reset();
+}
+
+void SSLContext::setOptions(long options) {
+ long newOpt = SSL_CTX_set_options(ctx_, options);
+ if ((newOpt & options) != options) {
+ throw std::runtime_error("SSL_CTX_set_options failed");
+ }
+}
+
+std::string SSLContext::getErrors(int errnoCopy) {
+ std::string errors;
+ unsigned long errorCode;
+ char message[256];
+
+ errors.reserve(512);
+ while ((errorCode = ERR_get_error()) != 0) {
+ if (!errors.empty()) {
+ errors += "; ";
+ }
+ const char* reason = ERR_reason_error_string(errorCode);
+ if (reason == nullptr) {
+ snprintf(message, sizeof(message) - 1, "SSL error # %lu", errorCode);
+ reason = message;
+ }
+ errors += reason;
+ }
+ if (errors.empty()) {
+ errors = "error code: " + folly::to<std::string>(errnoCopy);
+ }
+ return errors;
+}
+
+std::ostream&
+operator<<(std::ostream& os, const PasswordCollector& collector) {
+ os << collector.describe();
+ return os;
+}
+
+} // folly
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <mutex>
+#include <list>
+#include <map>
+#include <vector>
+#include <memory>
+#include <string>
+
+#include <openssl/ssl.h>
+#include <openssl/tls1.h>
+
+#include <glog/logging.h>
+
+namespace folly {
+
+/**
+ * Override the default password collector.
+ */
+class PasswordCollector {
+ public:
+ virtual ~PasswordCollector() {}
+ /**
+ * Interface for customizing how to collect private key password.
+ *
+ * By default, OpenSSL prints a prompt on screen and request for password
+ * while loading private key. To implement a custom password collector,
+ * implement this interface and register it with TSSLSocketFactory.
+ *
+ * @param password Pass collected password back to OpenSSL
+ * @param size Maximum length of password including nullptr character
+ */
+ virtual void getPassword(std::string& password, int size) = 0;
+
+ /**
+ * Return a description of this collector for logging purposes
+ */
+ virtual std::string describe() const = 0;
+};
+
+/**
+ * Wrap OpenSSL SSL_CTX into a class.
+ */
+class SSLContext {
+ public:
+
+ enum SSLVersion {
+ SSLv2,
+ SSLv3,
+ TLSv1
+ };
+
+ enum SSLVerifyPeerEnum{
+ USE_CTX,
+ VERIFY,
+ VERIFY_REQ_CLIENT_CERT,
+ NO_VERIFY
+ };
+
+ struct NextProtocolsItem {
+ int weight;
+ std::list<std::string> protocols;
+ };
+
+ struct AdvertisedNextProtocolsItem {
+ unsigned char *protocols;
+ unsigned length;
+ double probability;
+ };
+
+ /**
+ * Convenience function to call getErrors() with the current errno value.
+ *
+ * Make sure that you only call this when there was no intervening operation
+ * since the last OpenSSL error that may have changed the current errno value.
+ */
+ static std::string getErrors() {
+ return getErrors(errno);
+ }
+
+ /**
+ * Constructor.
+ *
+ * @param version The lowest or oldest SSL version to support.
+ */
+ explicit SSLContext(SSLVersion version = TLSv1);
+ virtual ~SSLContext();
+
+ /**
+ * Set default ciphers to be used in SSL handshake process.
+ *
+ * @param ciphers A list of ciphers to use for TLSv1.0
+ */
+ virtual void ciphers(const std::string& ciphers);
+
+ /**
+ * Low-level method that attempts to set the provided ciphers on the
+ * SSL_CTX object, and throws if something goes wrong.
+ */
+ virtual void setCiphersOrThrow(const std::string& ciphers);
+
+ /**
+ * Method to set verification option in the context object.
+ *
+ * @param verifyPeer SSLVerifyPeerEnum indicating the verification
+ * method to use.
+ */
+ virtual void setVerificationOption(const SSLVerifyPeerEnum& verifyPeer);
+
+ /**
+ * Method to check if peer verfication is set.
+ *
+ * @return true if peer verification is required.
+ *
+ */
+ virtual bool needsPeerVerification() {
+ return (verifyPeer_ == SSLVerifyPeerEnum::VERIFY ||
+ verifyPeer_ == SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
+ }
+
+ /**
+ * Method to fetch Verification mode for a SSLVerifyPeerEnum.
+ * verifyPeer cannot be SSLVerifyPeerEnum::USE_CTX since there is no
+ * context.
+ *
+ * @param verifyPeer SSLVerifyPeerEnum for which the flags need to
+ * to be returned
+ *
+ * @return mode flags that can be used with SSL_set_verify
+ */
+ static int getVerificationMode(const SSLVerifyPeerEnum& verifyPeer);
+
+ /**
+ * Method to fetch Verification mode determined by the options
+ * set using setVerificationOption.
+ *
+ * @return mode flags that can be used with SSL_set_verify
+ */
+ virtual int getVerificationMode();
+
+ /**
+ * Enable/Disable authentication. Peer name validation can only be done
+ * if checkPeerCert is true.
+ *
+ * @param checkPeerCert If true, require peer to present valid certificate
+ * @param checkPeerName If true, validate that the certificate common name
+ * or alternate name(s) of peer matches the hostname
+ * used to connect.
+ * @param peerName If non-empty, validate that the certificate common
+ * name of peer matches the given string (altername
+ * name(s) are not used in this case).
+ */
+ virtual void authenticate(bool checkPeerCert, bool checkPeerName,
+ const std::string& peerName = std::string());
+ /**
+ * Load server certificate.
+ *
+ * @param path Path to the certificate file
+ * @param format Certificate file format
+ */
+ virtual void loadCertificate(const char* path, const char* format = "PEM");
+ /**
+ * Load private key.
+ *
+ * @param path Path to the private key file
+ * @param format Private key file format
+ */
+ virtual void loadPrivateKey(const char* path, const char* format = "PEM");
+ /**
+ * Load trusted certificates from specified file.
+ *
+ * @param path Path to trusted certificate file
+ */
+ virtual void loadTrustedCertificates(const char* path);
+ /**
+ * Load trusted certificates from specified X509 certificate store.
+ *
+ * @param store X509 certificate store.
+ */
+ virtual void loadTrustedCertificates(X509_STORE* store);
+ /**
+ * Load a client CA list for validating clients
+ */
+ virtual void loadClientCAList(const char* path);
+ /**
+ * Default randomize method.
+ */
+ virtual void randomize();
+ /**
+ * Override default OpenSSL password collector.
+ *
+ * @param collector Instance of user defined password collector
+ */
+ virtual void passwordCollector(std::shared_ptr<PasswordCollector> collector);
+ /**
+ * Obtain password collector.
+ *
+ * @return User defined password collector
+ */
+ virtual std::shared_ptr<PasswordCollector> passwordCollector() {
+ return collector_;
+ }
+#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
+ /**
+ * Provide SNI support
+ */
+ enum ServerNameCallbackResult {
+ SERVER_NAME_FOUND,
+ SERVER_NAME_NOT_FOUND,
+ SERVER_NAME_NOT_FOUND_ALERT_FATAL,
+ };
+ /**
+ * Callback function from openssl to give the application a
+ * chance to check the tlsext_hostname just right after parsing
+ * the Client Hello or Server Hello message.
+ *
+ * It is for the server to switch the SSL to another SSL_CTX
+ * to continue the handshake. (i.e. Server Name Indication, SNI, in RFC6066).
+ *
+ * If the ServerNameCallback returns:
+ * SERVER_NAME_FOUND:
+ * server: Send a tlsext_hostname in the Server Hello
+ * client: No-effect
+ * SERVER_NAME_NOT_FOUND:
+ * server: Does not send a tlsext_hostname in Server Hello
+ * and continue the handshake.
+ * client: No-effect
+ * SERVER_NAME_NOT_FOUND_ALERT_FATAL:
+ * server and client: Send fatal TLS1_AD_UNRECOGNIZED_NAME alert to
+ * the peer.
+ *
+ * Quote from RFC 6066:
+ * "...
+ * If the server understood the ClientHello extension but
+ * does not recognize the server name, the server SHOULD take one of two
+ * actions: either abort the handshake by sending a fatal-level
+ * unrecognized_name(112) alert or continue the handshake. It is NOT
+ * RECOMMENDED to send a warning-level unrecognized_name(112) alert,
+ * because the client's behavior in response to warning-level alerts is
+ * unpredictable.
+ * ..."
+ */
+
+ /**
+ * Set the ServerNameCallback
+ */
+ typedef std::function<ServerNameCallbackResult(SSL* ssl)> ServerNameCallback;
+ virtual void setServerNameCallback(const ServerNameCallback& cb);
+
+ /**
+ * Generic callbacks that are run after we get the Client Hello (right
+ * before we run the ServerNameCallback)
+ */
+ typedef std::function<void(SSL* ssl)> ClientHelloCallback;
+ virtual void addClientHelloCallback(const ClientHelloCallback& cb);
+#endif
+
+ /**
+ * Create an SSL object from this context.
+ */
+ SSL* createSSL() const;
+
+ /**
+ * Set the options on the SSL_CTX object.
+ */
+ void setOptions(long options);
+
+#ifdef OPENSSL_NPN_NEGOTIATED
+ /**
+ * Set the list of protocols that this SSL context supports. In server
+ * mode, this is the list of protocols that will be advertised for Next
+ * Protocol Negotiation (NPN). In client mode, the first protocol
+ * advertised by the server that is also on this list is
+ * chosen. Invoking this function with a list of length zero causes NPN
+ * to be disabled.
+ *
+ * @param protocols List of protocol names. This method makes a copy,
+ * so the caller needn't keep the list in scope after
+ * the call completes. The list must have at least
+ * one element to enable NPN. Each element must have
+ * a string length < 256.
+ * @return true if NPN has been activated. False if NPN is disabled.
+ */
+ bool setAdvertisedNextProtocols(const std::list<std::string>& protocols);
+ /**
+ * Set weighted list of lists of protocols that this SSL context supports.
+ * In server mode, each element of the list contains a list of protocols that
+ * could be advertised for Next Protocol Negotiation (NPN). The list of
+ * protocols that will be advertised to a client is selected randomly, based
+ * on weights of elements. Client mode doesn't support randomized NPN, so
+ * this list should contain only 1 element. The first protocol advertised
+ * by the server that is also on the list of protocols of this element is
+ * chosen. Invoking this function with a list of length zero causes NPN
+ * to be disabled.
+ *
+ * @param items List of NextProtocolsItems, Each item contains a list of
+ * protocol names and weight. After the call of this fucntion
+ * each non-empty list of protocols will be advertised with
+ * probability weight/sum_of_weights. This method makes a copy,
+ * so the caller needn't keep the list in scope after the call
+ * completes. The list must have at least one element with
+ * non-zero weight and non-empty protocols list to enable NPN.
+ * Each name of the protocol must have a string length < 256.
+ * @return true if NPN has been activated. False if NPN is disabled.
+ */
+ bool setRandomizedAdvertisedNextProtocols(
+ const std::list<NextProtocolsItem>& items);
+
+ /**
+ * Disables NPN on this SSL context.
+ */
+ void unsetNextProtocols();
+ void deleteNextProtocolsStrings();
+#endif // OPENSSL_NPN_NEGOTIATED
+
+ /**
+ * Gets the underlying SSL_CTX for advanced usage
+ */
+ SSL_CTX *getSSLCtx() const {
+ return ctx_;
+ }
+
+ enum SSLLockType {
+ LOCK_MUTEX,
+ LOCK_SPINLOCK,
+ LOCK_NONE
+ };
+
+ /**
+ * Set preferences for how to treat locks in OpenSSL. This must be
+ * called before the instantiation of any SSLContext objects, otherwise
+ * the defaults will be used.
+ *
+ * OpenSSL has a lock for each module rather than for each object or
+ * data that needs locking. Some locks protect only refcounts, and
+ * might be better as spinlocks rather than mutexes. Other locks
+ * may be totally unnecessary if the objects being protected are not
+ * shared between threads in the application.
+ *
+ * By default, all locks are initialized as mutexes. OpenSSL's lock usage
+ * may change from version to version and you should know what you are doing
+ * before disabling any locks entirely.
+ *
+ * Example: if you don't share SSL sessions between threads in your
+ * application, you may be able to do this
+ *
+ * setSSLLockTypes({{CRYPTO_LOCK_SSL_SESSION, SSLContext::LOCK_NONE}})
+ */
+ static void setSSLLockTypes(std::map<int, SSLLockType> lockTypes);
+
+ /**
+ * Examine OpenSSL's error stack, and return a string description of the
+ * errors.
+ *
+ * This operation removes the errors from OpenSSL's error stack.
+ */
+ static std::string getErrors(int errnoCopy);
+
+ /**
+ * We want to vary which cipher we'll use based on the client's TLS version.
+ */
+ void switchCiphersIfTLS11(
+ SSL* ssl,
+ const std::string& tls11CipherString
+ );
+
+ bool checkPeerName() { return checkPeerName_; }
+ std::string peerFixedName() { return peerFixedName_; }
+
+ /**
+ * Helper to match a hostname versus a pattern.
+ */
+ static bool matchName(const char* host, const char* pattern, int size);
+
+ protected:
+ SSL_CTX* ctx_;
+
+ private:
+ SSLVerifyPeerEnum verifyPeer_{SSLVerifyPeerEnum::NO_VERIFY};
+
+ bool checkPeerName_;
+ std::string peerFixedName_;
+ std::shared_ptr<PasswordCollector> collector_;
+#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
+ ServerNameCallback serverNameCb_;
+ std::vector<ClientHelloCallback> clientHelloCbs_;
+#endif
+
+ static std::mutex mutex_;
+ static uint64_t count_;
+
+#ifdef OPENSSL_NPN_NEGOTIATED
+ /**
+ * Wire-format list of advertised protocols for use in NPN.
+ */
+ std::vector<AdvertisedNextProtocolsItem> advertisedNextProtocols_;
+ static int sNextProtocolsExDataIndex_;
+
+ static int advertisedNextProtocolCallback(SSL* ssl,
+ const unsigned char** out, unsigned int* outlen, void* data);
+ static int selectNextProtocolCallback(
+ SSL* ssl, unsigned char **out, unsigned char *outlen,
+ const unsigned char *server, unsigned int server_len, void *args);
+#endif // OPENSSL_NPN_NEGOTIATED
+
+ static int passwordCallback(char* password, int size, int, void* data);
+
+ static void initializeOpenSSL();
+ static void cleanupOpenSSL();
+
+
+#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
+ /**
+ * The function that will be called directly from openssl
+ * in order for the application to get the tlsext_hostname just after
+ * parsing the Client Hello or Server Hello message. It will then call
+ * the serverNameCb_ function object. Hence, it is sort of a
+ * wrapper/proxy between serverNameCb_ and openssl.
+ *
+ * The openssl's primary intention is for SNI support, but we also use it
+ * generically for performing logic after the Client Hello comes in.
+ */
+ static int baseServerNameOpenSSLCallback(
+ SSL* ssl,
+ int* al /* alert (return value) */,
+ void* data
+ );
+#endif
+
+ std::string providedCiphersString_;
+};
+
+typedef std::shared_ptr<SSLContext> SSLContextPtr;
+
+std::ostream& operator<<(std::ostream& os, const folly::PasswordCollector& collector);
+
+} // folly