SSLContext
authorDave Watson <davejwatson@fb.com>
Wed, 1 Oct 2014 17:51:40 +0000 (10:51 -0700)
committerPavlo Kushnir <pavlo@fb.com>
Sat, 8 Nov 2014 02:16:12 +0000 (18:16 -0800)
Summary:
Move SSLContext to folly.

Test Plan: It builds.

Reviewed By: dcsommer@fb.com

Subscribers: jdperlow, atlas2-eng@, wormhole-diffs@, bwester, trunkagent, doug, ps, bmatheny, ssl-diffs@, alikhtarov, njormrod, mshneer, folly-diffs@, andrewcox, alandau, jsedgwick, fugalh

FB internal diff: D1631924

Signature: t1:1631924:1414616562:9a67dbf20f00eb8fbcb35880efcb94c0fae07dcc

folly/Makefile.am
folly/io/async/SSLContext.cpp [new file with mode: 0644]
folly/io/async/SSLContext.h [new file with mode: 0644]

index 76ad6a30ea02aad1809f2cd16b402b2c1cb1f065..dde8c28533ba694735a57b6105533aa0ca8ab916 100644 (file)
@@ -143,6 +143,7 @@ nobase_follyinclude_HEADERS = \
        io/async/NotificationQueue.h \
        io/async/HHWheelTimer.h \
        io/async/Request.h \
+       io/async/SSLContext.h \
        io/async/TimeoutManager.h \
        json.h \
        Lazy.h \
@@ -269,6 +270,7 @@ libfolly_la_SOURCES = \
        io/async/EventBaseManager.cpp \
        io/async/EventHandler.cpp \
        io/async/Request.cpp \
+       io/async/SSLContext.cpp \
        io/async/HHWheelTimer.cpp \
        json.cpp \
        detail/MemoryIdler.cpp \
diff --git a/folly/io/async/SSLContext.cpp b/folly/io/async/SSLContext.cpp
new file mode 100644 (file)
index 0000000..c50e257
--- /dev/null
@@ -0,0 +1,643 @@
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "SSLContext.h"
+
+#include <openssl/err.h>
+#include <openssl/rand.h>
+#include <openssl/ssl.h>
+#include <openssl/x509v3.h>
+
+#include <folly/SmallLocks.h>
+#include <folly/Format.h>
+#include <folly/io/PortableSpinLock.h>
+
+// ---------------------------------------------------------------------
+// SSLContext implementation
+// ---------------------------------------------------------------------
+
+struct CRYPTO_dynlock_value {
+  std::mutex mutex;
+};
+
+namespace folly {
+
+uint64_t SSLContext::count_ = 0;
+std::mutex    SSLContext::mutex_;
+#ifdef OPENSSL_NPN_NEGOTIATED
+int SSLContext::sNextProtocolsExDataIndex_ = -1;
+
+#endif
+// SSLContext implementation
+SSLContext::SSLContext(SSLVersion version) {
+  {
+    std::lock_guard<std::mutex> g(mutex_);
+    if (!count_++) {
+      initializeOpenSSL();
+      randomize();
+#ifdef OPENSSL_NPN_NEGOTIATED
+      sNextProtocolsExDataIndex_ = SSL_get_ex_new_index(0,
+          (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
+#endif
+    }
+  }
+
+  ctx_ = SSL_CTX_new(SSLv23_method());
+  if (ctx_ == nullptr) {
+    throw std::runtime_error("SSL_CTX_new: " + getErrors());
+  }
+
+  int opt = 0;
+  switch (version) {
+    case TLSv1:
+      opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
+      break;
+    case SSLv3:
+      opt = SSL_OP_NO_SSLv2;
+      break;
+    default:
+      // do nothing
+      break;
+  }
+  int newOpt = SSL_CTX_set_options(ctx_, opt);
+  DCHECK((newOpt & opt) == opt);
+
+  SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
+
+  checkPeerName_ = false;
+
+#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
+  SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
+  SSL_CTX_set_tlsext_servername_arg(ctx_, this);
+#endif
+}
+
+SSLContext::~SSLContext() {
+  if (ctx_ != nullptr) {
+    SSL_CTX_free(ctx_);
+    ctx_ = nullptr;
+  }
+
+#ifdef OPENSSL_NPN_NEGOTIATED
+  deleteNextProtocolsStrings();
+#endif
+
+  std::lock_guard<std::mutex> g(mutex_);
+  if (!--count_) {
+    cleanupOpenSSL();
+  }
+}
+
+void SSLContext::ciphers(const std::string& ciphers) {
+  providedCiphersString_ = ciphers;
+  setCiphersOrThrow(ciphers);
+}
+
+void SSLContext::setCiphersOrThrow(const std::string& ciphers) {
+  int rc = SSL_CTX_set_cipher_list(ctx_, ciphers.c_str());
+  if (ERR_peek_error() != 0) {
+    throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors());
+  }
+  if (rc == 0) {
+    throw std::runtime_error("None of specified ciphers are supported");
+  }
+}
+
+void SSLContext::setVerificationOption(const SSLContext::SSLVerifyPeerEnum&
+    verifyPeer) {
+  CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX); // dont recurse
+  verifyPeer_ = verifyPeer;
+}
+
+int SSLContext::getVerificationMode(const SSLContext::SSLVerifyPeerEnum&
+    verifyPeer) {
+  CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX);
+  int mode = SSL_VERIFY_NONE;
+  switch(verifyPeer) {
+    // case SSLVerifyPeerEnum::USE_CTX: // can't happen
+    // break;
+
+    case SSLVerifyPeerEnum::VERIFY:
+      mode = SSL_VERIFY_PEER;
+      break;
+
+    case SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT:
+      mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
+      break;
+
+    case SSLVerifyPeerEnum::NO_VERIFY:
+      mode = SSL_VERIFY_NONE;
+      break;
+
+    default:
+      break;
+  }
+  return mode;
+}
+
+int SSLContext::getVerificationMode() {
+  return getVerificationMode(verifyPeer_);
+}
+
+void SSLContext::authenticate(bool checkPeerCert, bool checkPeerName,
+                              const std::string& peerName) {
+  int mode;
+  if (checkPeerCert) {
+    mode  = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE;
+    checkPeerName_ = checkPeerName;
+    peerFixedName_ = peerName;
+  } else {
+    mode = SSL_VERIFY_NONE;
+    checkPeerName_ = false; // can't check name without cert!
+    peerFixedName_.clear();
+  }
+  SSL_CTX_set_verify(ctx_, mode, nullptr);
+}
+
+void SSLContext::loadCertificate(const char* path, const char* format) {
+  if (path == nullptr || format == nullptr) {
+    throw std::invalid_argument(
+         "loadCertificateChain: either <path> or <format> is nullptr");
+  }
+  if (strcmp(format, "PEM") == 0) {
+    if (SSL_CTX_use_certificate_chain_file(ctx_, path) == 0) {
+      int errnoCopy = errno;
+      std::string reason("SSL_CTX_use_certificate_chain_file: ");
+      reason.append(path);
+      reason.append(": ");
+      reason.append(getErrors(errnoCopy));
+      throw std::runtime_error(reason);
+    }
+  } else {
+    throw std::runtime_error("Unsupported certificate format: " + std::string(format));
+  }
+}
+
+void SSLContext::loadPrivateKey(const char* path, const char* format) {
+  if (path == nullptr || format == nullptr) {
+    throw std::invalid_argument(
+         "loadPrivateKey: either <path> or <format> is nullptr");
+  }
+  if (strcmp(format, "PEM") == 0) {
+    if (SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM) == 0) {
+      throw std::runtime_error("SSL_CTX_use_PrivateKey_file: " + getErrors());
+    }
+  } else {
+    throw std::runtime_error("Unsupported private key format: " + std::string(format));
+  }
+}
+
+void SSLContext::loadTrustedCertificates(const char* path) {
+  if (path == nullptr) {
+    throw std::invalid_argument(
+         "loadTrustedCertificates: <path> is nullptr");
+  }
+  if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) {
+    throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors());
+  }
+}
+
+void SSLContext::loadTrustedCertificates(X509_STORE* store) {
+  SSL_CTX_set_cert_store(ctx_, store);
+}
+
+void SSLContext::loadClientCAList(const char* path) {
+  auto clientCAs = SSL_load_client_CA_file(path);
+  if (clientCAs == nullptr) {
+    LOG(ERROR) << "Unable to load ca file: " << path;
+    return;
+  }
+  SSL_CTX_set_client_CA_list(ctx_, clientCAs);
+}
+
+void SSLContext::randomize() {
+  RAND_poll();
+}
+
+void SSLContext::passwordCollector(std::shared_ptr<PasswordCollector> collector) {
+  if (collector == nullptr) {
+    LOG(ERROR) << "passwordCollector: ignore invalid password collector";
+    return;
+  }
+  collector_ = collector;
+  SSL_CTX_set_default_passwd_cb(ctx_, passwordCallback);
+  SSL_CTX_set_default_passwd_cb_userdata(ctx_, this);
+}
+
+#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
+
+void SSLContext::setServerNameCallback(const ServerNameCallback& cb) {
+  serverNameCb_ = cb;
+}
+
+void SSLContext::addClientHelloCallback(const ClientHelloCallback& cb) {
+  clientHelloCbs_.push_back(cb);
+}
+
+int SSLContext::baseServerNameOpenSSLCallback(SSL* ssl, int* al, void* data) {
+  SSLContext* context = (SSLContext*)data;
+
+  if (context == nullptr) {
+    return SSL_TLSEXT_ERR_NOACK;
+  }
+
+  for (auto& cb : context->clientHelloCbs_) {
+    // Generic callbacks to happen after we receive the Client Hello.
+    // For example, we use one to switch which cipher we use depending
+    // on the user's TLS version.  Because the primary purpose of
+    // baseServerNameOpenSSLCallback is for SNI support, and these callbacks
+    // are side-uses, we ignore any possible failures other than just logging
+    // them.
+    cb(ssl);
+  }
+
+  if (!context->serverNameCb_) {
+    return SSL_TLSEXT_ERR_NOACK;
+  }
+
+  ServerNameCallbackResult ret = context->serverNameCb_(ssl);
+  switch (ret) {
+    case SERVER_NAME_FOUND:
+      return SSL_TLSEXT_ERR_OK;
+    case SERVER_NAME_NOT_FOUND:
+      return SSL_TLSEXT_ERR_NOACK;
+    case SERVER_NAME_NOT_FOUND_ALERT_FATAL:
+      *al = TLS1_AD_UNRECOGNIZED_NAME;
+      return SSL_TLSEXT_ERR_ALERT_FATAL;
+    default:
+      CHECK(false);
+  }
+
+  return SSL_TLSEXT_ERR_NOACK;
+}
+
+void SSLContext::switchCiphersIfTLS11(
+    SSL* ssl,
+    const std::string& tls11CipherString) {
+
+  CHECK(!tls11CipherString.empty()) << "Shouldn't call if empty alt ciphers";
+
+  if (TLS1_get_client_version(ssl) <= TLS1_VERSION) {
+    // We only do this for TLS v 1.1 and later
+    return;
+  }
+
+  // Prefer AES for TLS versions 1.1 and later since these are not
+  // vulnerable to BEAST attacks on AES.  Note that we're setting the
+  // cipher list on the SSL object, not the SSL_CTX object, so it will
+  // only last for this request.
+  int rc = SSL_set_cipher_list(ssl, tls11CipherString.c_str());
+  if ((rc == 0) || ERR_peek_error() != 0) {
+    // This shouldn't happen since we checked for this when proxygen
+    // started up.
+    LOG(WARNING) << "ssl_cipher: No specified ciphers supported for switch";
+    SSL_set_cipher_list(ssl, providedCiphersString_.c_str());
+  }
+}
+#endif
+
+#ifdef OPENSSL_NPN_NEGOTIATED
+bool SSLContext::setAdvertisedNextProtocols(const std::list<std::string>& protocols) {
+  return setRandomizedAdvertisedNextProtocols({{1, protocols}});
+}
+
+bool SSLContext::setRandomizedAdvertisedNextProtocols(
+    const std::list<NextProtocolsItem>& items) {
+  unsetNextProtocols();
+  if (items.size() == 0) {
+    return false;
+  }
+  int total_weight = 0;
+  for (const auto &item : items) {
+    if (item.protocols.size() == 0) {
+      continue;
+    }
+    AdvertisedNextProtocolsItem advertised_item;
+    advertised_item.length = 0;
+    for (const auto& proto : item.protocols) {
+      ++advertised_item.length;
+      unsigned protoLength = proto.length();
+      if (protoLength >= 256) {
+        deleteNextProtocolsStrings();
+        return false;
+      }
+      advertised_item.length += protoLength;
+    }
+    advertised_item.protocols = new unsigned char[advertised_item.length];
+    if (!advertised_item.protocols) {
+      throw std::runtime_error("alloc failure");
+    }
+    unsigned char* dst = advertised_item.protocols;
+    for (auto& proto : item.protocols) {
+      unsigned protoLength = proto.length();
+      *dst++ = (unsigned char)protoLength;
+      memcpy(dst, proto.data(), protoLength);
+      dst += protoLength;
+    }
+    total_weight += item.weight;
+    advertised_item.probability = item.weight;
+    advertisedNextProtocols_.push_back(advertised_item);
+  }
+  if (total_weight == 0) {
+    deleteNextProtocolsStrings();
+    return false;
+  }
+  for (auto &advertised_item : advertisedNextProtocols_) {
+    advertised_item.probability /= total_weight;
+  }
+  SSL_CTX_set_next_protos_advertised_cb(
+    ctx_, advertisedNextProtocolCallback, this);
+  SSL_CTX_set_next_proto_select_cb(
+    ctx_, selectNextProtocolCallback, this);
+  return true;
+}
+
+void SSLContext::deleteNextProtocolsStrings() {
+  for (auto protocols : advertisedNextProtocols_) {
+    delete[] protocols.protocols;
+  }
+  advertisedNextProtocols_.clear();
+}
+
+void SSLContext::unsetNextProtocols() {
+  deleteNextProtocolsStrings();
+  SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr);
+  SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr);
+}
+
+int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
+      const unsigned char** out, unsigned int* outlen, void* data) {
+  SSLContext* context = (SSLContext*)data;
+  if (context == nullptr || context->advertisedNextProtocols_.empty()) {
+    *out = nullptr;
+    *outlen = 0;
+  } else if (context->advertisedNextProtocols_.size() == 1) {
+    *out = context->advertisedNextProtocols_[0].protocols;
+    *outlen = context->advertisedNextProtocols_[0].length;
+  } else {
+    uintptr_t selected_index = reinterpret_cast<uintptr_t>(SSL_get_ex_data(ssl,
+          sNextProtocolsExDataIndex_));
+    if (selected_index) {
+      --selected_index;
+      *out = context->advertisedNextProtocols_[selected_index].protocols;
+      *outlen = context->advertisedNextProtocols_[selected_index].length;
+    } else {
+      unsigned char random_byte;
+      RAND_bytes(&random_byte, 1);
+      double random_value = random_byte / 255.0;
+      double sum = 0;
+      for (size_t i = 0; i < context->advertisedNextProtocols_.size(); ++i) {
+        sum += context->advertisedNextProtocols_[i].probability;
+        if (sum < random_value &&
+            i + 1 < context->advertisedNextProtocols_.size()) {
+          continue;
+        }
+        uintptr_t selected = i + 1;
+        SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void *)selected);
+        *out = context->advertisedNextProtocols_[i].protocols;
+        *outlen = context->advertisedNextProtocols_[i].length;
+        break;
+      }
+    }
+  }
+  return SSL_TLSEXT_ERR_OK;
+}
+
+int SSLContext::selectNextProtocolCallback(
+  SSL* ssl, unsigned char **out, unsigned char *outlen,
+  const unsigned char *server, unsigned int server_len, void *data) {
+
+  SSLContext* ctx = (SSLContext*)data;
+  if (ctx->advertisedNextProtocols_.size() > 1) {
+    VLOG(3) << "SSLContext::selectNextProcolCallback() "
+            << "client should be deterministic in selecting protocols.";
+  }
+
+  unsigned char *client;
+  int client_len;
+  if (ctx->advertisedNextProtocols_.empty()) {
+    client = (unsigned char *) "";
+    client_len = 0;
+  } else {
+    client = ctx->advertisedNextProtocols_[0].protocols;
+    client_len = ctx->advertisedNextProtocols_[0].length;
+  }
+
+  int retval = SSL_select_next_proto(out, outlen, server, server_len,
+                                     client, client_len);
+  if (retval != OPENSSL_NPN_NEGOTIATED) {
+    VLOG(3) << "SSLContext::selectNextProcolCallback() "
+            << "unable to pick a next protocol.";
+  }
+  return SSL_TLSEXT_ERR_OK;
+}
+#endif // OPENSSL_NPN_NEGOTIATED
+
+SSL* SSLContext::createSSL() const {
+  SSL* ssl = SSL_new(ctx_);
+  if (ssl == nullptr) {
+    throw std::runtime_error("SSL_new: " + getErrors());
+  }
+  return ssl;
+}
+
+/**
+ * Match a name with a pattern. The pattern may include wildcard. A single
+ * wildcard "*" can match up to one component in the domain name.
+ *
+ * @param  host    Host name, typically the name of the remote host
+ * @param  pattern Name retrieved from certificate
+ * @param  size    Size of "pattern"
+ * @return True, if "host" matches "pattern". False otherwise.
+ */
+bool SSLContext::matchName(const char* host, const char* pattern, int size) {
+  bool match = false;
+  int i = 0, j = 0;
+  while (i < size && host[j] != '\0') {
+    if (toupper(pattern[i]) == toupper(host[j])) {
+      i++;
+      j++;
+      continue;
+    }
+    if (pattern[i] == '*') {
+      while (host[j] != '.' && host[j] != '\0') {
+        j++;
+      }
+      i++;
+      continue;
+    }
+    break;
+  }
+  if (i == size && host[j] == '\0') {
+    match = true;
+  }
+  return match;
+}
+
+int SSLContext::passwordCallback(char* password,
+                                 int size,
+                                 int,
+                                 void* data) {
+  SSLContext* context = (SSLContext*)data;
+  if (context == nullptr || context->passwordCollector() == nullptr) {
+    return 0;
+  }
+  std::string userPassword;
+  // call user defined password collector to get password
+  context->passwordCollector()->getPassword(userPassword, size);
+  int length = userPassword.size();
+  if (length > size) {
+    length = size;
+  }
+  strncpy(password, userPassword.c_str(), length);
+  return length;
+}
+
+struct SSLLock {
+  explicit SSLLock(
+    SSLContext::SSLLockType inLockType = SSLContext::LOCK_MUTEX) :
+      lockType(inLockType) {
+  }
+
+  void lock() {
+    if (lockType == SSLContext::LOCK_MUTEX) {
+      mutex.lock();
+    } else if (lockType == SSLContext::LOCK_SPINLOCK) {
+      spinLock.lock();
+    }
+    // lockType == LOCK_NONE, no-op
+  }
+
+  void unlock() {
+    if (lockType == SSLContext::LOCK_MUTEX) {
+      mutex.unlock();
+    } else if (lockType == SSLContext::LOCK_SPINLOCK) {
+      spinLock.unlock();
+    }
+    // lockType == LOCK_NONE, no-op
+  }
+
+  SSLContext::SSLLockType lockType;
+  folly::io::PortableSpinLock spinLock{};
+  std::mutex mutex;
+};
+
+static std::map<int, SSLContext::SSLLockType> lockTypes;
+static std::unique_ptr<SSLLock[]> locks;
+
+static void callbackLocking(int mode, int n, const char*, int) {
+  if (mode & CRYPTO_LOCK) {
+    locks[n].lock();
+  } else {
+    locks[n].unlock();
+  }
+}
+
+static unsigned long callbackThreadID() {
+  return static_cast<unsigned long>(pthread_self());
+}
+
+static CRYPTO_dynlock_value* dyn_create(const char*, int) {
+  return new CRYPTO_dynlock_value;
+}
+
+static void dyn_lock(int mode,
+                     struct CRYPTO_dynlock_value* lock,
+                     const char*, int) {
+  if (lock != nullptr) {
+    if (mode & CRYPTO_LOCK) {
+      lock->mutex.lock();
+    } else {
+      lock->mutex.unlock();
+    }
+  }
+}
+
+static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) {
+  delete lock;
+}
+
+void SSLContext::setSSLLockTypes(std::map<int, SSLLockType> inLockTypes) {
+  lockTypes = inLockTypes;
+}
+
+void SSLContext::initializeOpenSSL() {
+  SSL_library_init();
+  SSL_load_error_strings();
+  ERR_load_crypto_strings();
+  // static locking
+  locks.reset(new SSLLock[::CRYPTO_num_locks()]);
+  for (auto it: lockTypes) {
+    locks[it.first].lockType = it.second;
+  }
+  CRYPTO_set_id_callback(callbackThreadID);
+  CRYPTO_set_locking_callback(callbackLocking);
+  // dynamic locking
+  CRYPTO_set_dynlock_create_callback(dyn_create);
+  CRYPTO_set_dynlock_lock_callback(dyn_lock);
+  CRYPTO_set_dynlock_destroy_callback(dyn_destroy);
+}
+
+void SSLContext::cleanupOpenSSL() {
+  CRYPTO_set_id_callback(nullptr);
+  CRYPTO_set_locking_callback(nullptr);
+  CRYPTO_set_dynlock_create_callback(nullptr);
+  CRYPTO_set_dynlock_lock_callback(nullptr);
+  CRYPTO_set_dynlock_destroy_callback(nullptr);
+  CRYPTO_cleanup_all_ex_data();
+  ERR_free_strings();
+  EVP_cleanup();
+  ERR_remove_state(0);
+  locks.reset();
+}
+
+void SSLContext::setOptions(long options) {
+  long newOpt = SSL_CTX_set_options(ctx_, options);
+  if ((newOpt & options) != options) {
+    throw std::runtime_error("SSL_CTX_set_options failed");
+  }
+}
+
+std::string SSLContext::getErrors(int errnoCopy) {
+  std::string errors;
+  unsigned long  errorCode;
+  char   message[256];
+
+  errors.reserve(512);
+  while ((errorCode = ERR_get_error()) != 0) {
+    if (!errors.empty()) {
+      errors += "; ";
+    }
+    const char* reason = ERR_reason_error_string(errorCode);
+    if (reason == nullptr) {
+      snprintf(message, sizeof(message) - 1, "SSL error # %lu", errorCode);
+      reason = message;
+    }
+    errors += reason;
+  }
+  if (errors.empty()) {
+    errors = "error code: " + folly::to<std::string>(errnoCopy);
+  }
+  return errors;
+}
+
+std::ostream&
+operator<<(std::ostream& os, const PasswordCollector& collector) {
+  os << collector.describe();
+  return os;
+}
+
+} // folly
diff --git a/folly/io/async/SSLContext.h b/folly/io/async/SSLContext.h
new file mode 100644 (file)
index 0000000..5819f68
--- /dev/null
@@ -0,0 +1,453 @@
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <mutex>
+#include <list>
+#include <map>
+#include <vector>
+#include <memory>
+#include <string>
+
+#include <openssl/ssl.h>
+#include <openssl/tls1.h>
+
+#include <glog/logging.h>
+
+namespace folly {
+
+/**
+ * Override the default password collector.
+ */
+class PasswordCollector {
+ public:
+  virtual ~PasswordCollector() {}
+  /**
+   * Interface for customizing how to collect private key password.
+   *
+   * By default, OpenSSL prints a prompt on screen and request for password
+   * while loading private key. To implement a custom password collector,
+   * implement this interface and register it with TSSLSocketFactory.
+   *
+   * @param password Pass collected password back to OpenSSL
+   * @param size     Maximum length of password including nullptr character
+   */
+  virtual void getPassword(std::string& password, int size) = 0;
+
+  /**
+   * Return a description of this collector for logging purposes
+   */
+  virtual std::string describe() const = 0;
+};
+
+/**
+ * Wrap OpenSSL SSL_CTX into a class.
+ */
+class SSLContext {
+ public:
+
+  enum SSLVersion {
+     SSLv2,
+     SSLv3,
+     TLSv1
+  };
+
+  enum SSLVerifyPeerEnum{
+    USE_CTX,
+    VERIFY,
+    VERIFY_REQ_CLIENT_CERT,
+    NO_VERIFY
+  };
+
+  struct NextProtocolsItem {
+    int weight;
+    std::list<std::string> protocols;
+  };
+
+  struct AdvertisedNextProtocolsItem {
+    unsigned char *protocols;
+    unsigned length;
+    double probability;
+  };
+
+  /**
+   * Convenience function to call getErrors() with the current errno value.
+   *
+   * Make sure that you only call this when there was no intervening operation
+   * since the last OpenSSL error that may have changed the current errno value.
+   */
+  static std::string getErrors() {
+    return getErrors(errno);
+  }
+
+  /**
+   * Constructor.
+   *
+   * @param version The lowest or oldest SSL version to support.
+   */
+  explicit SSLContext(SSLVersion version = TLSv1);
+  virtual ~SSLContext();
+
+  /**
+   * Set default ciphers to be used in SSL handshake process.
+   *
+   * @param ciphers A list of ciphers to use for TLSv1.0
+   */
+  virtual void ciphers(const std::string& ciphers);
+
+  /**
+   * Low-level method that attempts to set the provided ciphers on the
+   * SSL_CTX object, and throws if something goes wrong.
+   */
+  virtual void setCiphersOrThrow(const std::string& ciphers);
+
+  /**
+   * Method to set verification option in the context object.
+   *
+   * @param verifyPeer SSLVerifyPeerEnum indicating the verification
+   *                       method to use.
+   */
+  virtual void setVerificationOption(const SSLVerifyPeerEnum& verifyPeer);
+
+  /**
+   * Method to check if peer verfication is set.
+   *
+   * @return true if peer verification is required.
+   *
+   */
+  virtual bool needsPeerVerification() {
+    return (verifyPeer_ == SSLVerifyPeerEnum::VERIFY ||
+              verifyPeer_ == SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
+  }
+
+  /**
+   * Method to fetch Verification mode for a SSLVerifyPeerEnum.
+   * verifyPeer cannot be SSLVerifyPeerEnum::USE_CTX since there is no
+   * context.
+   *
+   * @param verifyPeer SSLVerifyPeerEnum for which the flags need to
+   *                  to be returned
+   *
+   * @return mode flags that can be used with SSL_set_verify
+   */
+  static int getVerificationMode(const SSLVerifyPeerEnum& verifyPeer);
+
+  /**
+   * Method to fetch Verification mode determined by the options
+   * set using setVerificationOption.
+   *
+   * @return mode flags that can be used with SSL_set_verify
+   */
+  virtual int getVerificationMode();
+
+  /**
+   * Enable/Disable authentication. Peer name validation can only be done
+   * if checkPeerCert is true.
+   *
+   * @param checkPeerCert If true, require peer to present valid certificate
+   * @param checkPeerName If true, validate that the certificate common name
+   *                      or alternate name(s) of peer matches the hostname
+   *                      used to connect.
+   * @param peerName      If non-empty, validate that the certificate common
+   *                      name of peer matches the given string (altername
+   *                      name(s) are not used in this case).
+   */
+  virtual void authenticate(bool checkPeerCert, bool checkPeerName,
+                            const std::string& peerName = std::string());
+  /**
+   * Load server certificate.
+   *
+   * @param path   Path to the certificate file
+   * @param format Certificate file format
+   */
+  virtual void loadCertificate(const char* path, const char* format = "PEM");
+  /**
+   * Load private key.
+   *
+   * @param path   Path to the private key file
+   * @param format Private key file format
+   */
+  virtual void loadPrivateKey(const char* path, const char* format = "PEM");
+  /**
+   * Load trusted certificates from specified file.
+   *
+   * @param path Path to trusted certificate file
+   */
+  virtual void loadTrustedCertificates(const char* path);
+  /**
+   * Load trusted certificates from specified X509 certificate store.
+   *
+   * @param store X509 certificate store.
+   */
+  virtual void loadTrustedCertificates(X509_STORE* store);
+  /**
+   * Load a client CA list for validating clients
+   */
+  virtual void loadClientCAList(const char* path);
+  /**
+   * Default randomize method.
+   */
+  virtual void randomize();
+  /**
+   * Override default OpenSSL password collector.
+   *
+   * @param collector Instance of user defined password collector
+   */
+  virtual void passwordCollector(std::shared_ptr<PasswordCollector> collector);
+  /**
+   * Obtain password collector.
+   *
+   * @return User defined password collector
+   */
+  virtual std::shared_ptr<PasswordCollector> passwordCollector() {
+    return collector_;
+  }
+#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
+  /**
+   * Provide SNI support
+   */
+  enum ServerNameCallbackResult {
+    SERVER_NAME_FOUND,
+    SERVER_NAME_NOT_FOUND,
+    SERVER_NAME_NOT_FOUND_ALERT_FATAL,
+  };
+  /**
+   * Callback function from openssl to give the application a
+   * chance to check the tlsext_hostname just right after parsing
+   * the Client Hello or Server Hello message.
+   *
+   * It is for the server to switch the SSL to another SSL_CTX
+   * to continue the handshake. (i.e. Server Name Indication, SNI, in RFC6066).
+   *
+   * If the ServerNameCallback returns:
+   * SERVER_NAME_FOUND:
+   *    server: Send a tlsext_hostname in the Server Hello
+   *    client: No-effect
+   * SERVER_NAME_NOT_FOUND:
+   *    server: Does not send a tlsext_hostname in Server Hello
+   *            and continue the handshake.
+   *    client: No-effect
+   * SERVER_NAME_NOT_FOUND_ALERT_FATAL:
+   *    server and client: Send fatal TLS1_AD_UNRECOGNIZED_NAME alert to
+   *                       the peer.
+   *
+   * Quote from RFC 6066:
+   * "...
+   * If the server understood the ClientHello extension but
+   * does not recognize the server name, the server SHOULD take one of two
+   * actions: either abort the handshake by sending a fatal-level
+   * unrecognized_name(112) alert or continue the handshake.  It is NOT
+   * RECOMMENDED to send a warning-level unrecognized_name(112) alert,
+   * because the client's behavior in response to warning-level alerts is
+   * unpredictable.
+   * ..."
+   */
+
+  /**
+   * Set the ServerNameCallback
+   */
+  typedef std::function<ServerNameCallbackResult(SSL* ssl)> ServerNameCallback;
+  virtual void setServerNameCallback(const ServerNameCallback& cb);
+
+  /**
+   * Generic callbacks that are run after we get the Client Hello (right
+   * before we run the ServerNameCallback)
+   */
+  typedef std::function<void(SSL* ssl)> ClientHelloCallback;
+  virtual void addClientHelloCallback(const ClientHelloCallback& cb);
+#endif
+
+  /**
+   * Create an SSL object from this context.
+   */
+  SSL* createSSL() const;
+
+  /**
+   * Set the options on the SSL_CTX object.
+   */
+  void setOptions(long options);
+
+#ifdef OPENSSL_NPN_NEGOTIATED
+  /**
+   * Set the list of protocols that this SSL context supports. In server
+   * mode, this is the list of protocols that will be advertised for Next
+   * Protocol Negotiation (NPN). In client mode, the first protocol
+   * advertised by the server that is also on this list is
+   * chosen. Invoking this function with a list of length zero causes NPN
+   * to be disabled.
+   *
+   * @param protocols   List of protocol names. This method makes a copy,
+   *                    so the caller needn't keep the list in scope after
+   *                    the call completes. The list must have at least
+   *                    one element to enable NPN. Each element must have
+   *                    a string length < 256.
+   * @return true if NPN has been activated. False if NPN is disabled.
+   */
+  bool setAdvertisedNextProtocols(const std::list<std::string>& protocols);
+  /**
+   * Set weighted list of lists of protocols that this SSL context supports.
+   * In server mode, each element of the list contains a list of protocols that
+   * could be advertised for Next Protocol Negotiation (NPN). The list of
+   * protocols that will be advertised to a client is selected randomly, based
+   * on weights of elements. Client mode doesn't support randomized NPN, so
+   * this list should contain only 1 element. The first protocol advertised
+   * by the server that is also on the list of protocols of this element is
+   * chosen. Invoking this function with a list of length zero causes NPN
+   * to be disabled.
+   *
+   * @param items  List of NextProtocolsItems, Each item contains a list of
+   *               protocol names and weight. After the call of this fucntion
+   *               each non-empty list of protocols will be advertised with
+   *               probability weight/sum_of_weights. This method makes a copy,
+   *               so the caller needn't keep the list in scope after the call
+   *               completes. The list must have at least one element with
+   *               non-zero weight and non-empty protocols list to enable NPN.
+   *               Each name of the protocol must have a string length < 256.
+   * @return true if NPN has been activated. False if NPN is disabled.
+   */
+  bool setRandomizedAdvertisedNextProtocols(
+      const std::list<NextProtocolsItem>& items);
+
+  /**
+   * Disables NPN on this SSL context.
+   */
+  void unsetNextProtocols();
+  void deleteNextProtocolsStrings();
+#endif // OPENSSL_NPN_NEGOTIATED
+
+  /**
+   * Gets the underlying SSL_CTX for advanced usage
+   */
+  SSL_CTX *getSSLCtx() const {
+    return ctx_;
+  }
+
+  enum SSLLockType {
+    LOCK_MUTEX,
+    LOCK_SPINLOCK,
+    LOCK_NONE
+  };
+
+  /**
+   * Set preferences for how to treat locks in OpenSSL.  This must be
+   * called before the instantiation of any SSLContext objects, otherwise
+   * the defaults will be used.
+   *
+   * OpenSSL has a lock for each module rather than for each object or
+   * data that needs locking.  Some locks protect only refcounts, and
+   * might be better as spinlocks rather than mutexes.  Other locks
+   * may be totally unnecessary if the objects being protected are not
+   * shared between threads in the application.
+   *
+   * By default, all locks are initialized as mutexes.  OpenSSL's lock usage
+   * may change from version to version and you should know what you are doing
+   * before disabling any locks entirely.
+   *
+   * Example: if you don't share SSL sessions between threads in your
+   * application, you may be able to do this
+   *
+   * setSSLLockTypes({{CRYPTO_LOCK_SSL_SESSION, SSLContext::LOCK_NONE}})
+   */
+  static void setSSLLockTypes(std::map<int, SSLLockType> lockTypes);
+
+  /**
+   * Examine OpenSSL's error stack, and return a string description of the
+   * errors.
+   *
+   * This operation removes the errors from OpenSSL's error stack.
+   */
+  static std::string getErrors(int errnoCopy);
+
+  /**
+   * We want to vary which cipher we'll use based on the client's TLS version.
+   */
+  void switchCiphersIfTLS11(
+    SSL* ssl,
+    const std::string& tls11CipherString
+  );
+
+  bool checkPeerName() { return checkPeerName_; }
+  std::string peerFixedName() { return peerFixedName_; }
+
+  /**
+   * Helper to match a hostname versus a pattern.
+   */
+  static bool matchName(const char* host, const char* pattern, int size);
+
+ protected:
+  SSL_CTX* ctx_;
+
+ private:
+  SSLVerifyPeerEnum verifyPeer_{SSLVerifyPeerEnum::NO_VERIFY};
+
+  bool checkPeerName_;
+  std::string peerFixedName_;
+  std::shared_ptr<PasswordCollector> collector_;
+#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
+  ServerNameCallback serverNameCb_;
+  std::vector<ClientHelloCallback> clientHelloCbs_;
+#endif
+
+  static std::mutex mutex_;
+  static uint64_t count_;
+
+#ifdef OPENSSL_NPN_NEGOTIATED
+  /**
+   * Wire-format list of advertised protocols for use in NPN.
+   */
+  std::vector<AdvertisedNextProtocolsItem> advertisedNextProtocols_;
+  static int sNextProtocolsExDataIndex_;
+
+  static int advertisedNextProtocolCallback(SSL* ssl,
+      const unsigned char** out, unsigned int* outlen, void* data);
+  static int selectNextProtocolCallback(
+    SSL* ssl, unsigned char **out, unsigned char *outlen,
+    const unsigned char *server, unsigned int server_len, void *args);
+#endif // OPENSSL_NPN_NEGOTIATED
+
+  static int passwordCallback(char* password, int size, int, void* data);
+
+  static void initializeOpenSSL();
+  static void cleanupOpenSSL();
+
+
+#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
+  /**
+   * The function that will be called directly from openssl
+   * in order for the application to get the tlsext_hostname just after
+   * parsing the Client Hello or Server Hello message. It will then call
+   * the serverNameCb_ function object. Hence, it is sort of a
+   * wrapper/proxy between serverNameCb_ and openssl.
+   *
+   * The openssl's primary intention is for SNI support, but we also use it
+   * generically for performing logic after the Client Hello comes in.
+   */
+  static int baseServerNameOpenSSLCallback(
+    SSL* ssl,
+    int* al /* alert (return value) */,
+    void* data
+  );
+#endif
+
+  std::string providedCiphersString_;
+};
+
+typedef std::shared_ptr<SSLContext> SSLContextPtr;
+
+std::ostream& operator<<(std::ostream& os, const folly::PasswordCollector& collector);
+
+} // folly