Move Acceptor to wangle
authorDave Watson <davejwatson@fb.com>
Thu, 23 Oct 2014 17:26:08 +0000 (10:26 -0700)
committerDave Watson <davejwatson@fb.com>
Wed, 19 Nov 2014 20:52:27 +0000 (12:52 -0800)
Summary:
Initial pass at moving acceptor to wangle.  Involves moving most of the config stuff from proxygen/lib/services, and *all* of the ssl stuff from proxygen/lib/ssl.

Only minor changes:
* Acceptor can be overriden to use thrift socket types, so I don't have to change TTransportException everywhere just yet
* proxygen::Exception to std::runtime_exception in a few spots - looks like it is entirely bad config exceptions, so it should be okay
* Just used std::chrono directly instead of stuff in Time.h (which is just typedefs and simple helpers)

Test Plan:
used in D1539327

fbconfig -r proxygen/httpserver; fbmake runtests

Probably other projects are broken, will iterate to fix

None of the failling tests look related

Reviewed By: dcsommer@fb.com

Subscribers: oleksandr, netego-diffs@, hphp-diffs@, ps, trunkagent, doug, fugalh, alandau, bmatheny, njormrod, mshneer, folly-diffs@

FB internal diff: D1638358

Tasks: 5002353

Signature: t1:1638358:1414526683:87a405e3c24711078707c00b62a50b0e960bf126

32 files changed:
folly/Makefile.am
folly/experimental/wangle/acceptor/Acceptor.cpp [new file with mode: 0644]
folly/experimental/wangle/acceptor/Acceptor.h [new file with mode: 0644]
folly/experimental/wangle/acceptor/ConnectionCounter.h [new file with mode: 0644]
folly/experimental/wangle/acceptor/DomainNameMisc.h [new file with mode: 0644]
folly/experimental/wangle/acceptor/LoadShedConfiguration.cpp [new file with mode: 0644]
folly/experimental/wangle/acceptor/LoadShedConfiguration.h [new file with mode: 0644]
folly/experimental/wangle/acceptor/NetworkAddress.h [new file with mode: 0644]
folly/experimental/wangle/acceptor/ServerSocketConfig.h [new file with mode: 0644]
folly/experimental/wangle/acceptor/SocketOptions.cpp [new file with mode: 0644]
folly/experimental/wangle/acceptor/SocketOptions.h [new file with mode: 0644]
folly/experimental/wangle/acceptor/TransportInfo.cpp [new file with mode: 0644]
folly/experimental/wangle/acceptor/TransportInfo.h [new file with mode: 0644]
folly/experimental/wangle/ssl/ClientHelloExtStats.h [new file with mode: 0644]
folly/experimental/wangle/ssl/DHParam.h [new file with mode: 0644]
folly/experimental/wangle/ssl/PasswordInFile.cpp [new file with mode: 0644]
folly/experimental/wangle/ssl/PasswordInFile.h [new file with mode: 0644]
folly/experimental/wangle/ssl/SSLCacheOptions.h [new file with mode: 0644]
folly/experimental/wangle/ssl/SSLCacheProvider.h [new file with mode: 0644]
folly/experimental/wangle/ssl/SSLContextConfig.h [new file with mode: 0644]
folly/experimental/wangle/ssl/SSLContextManager.cpp [new file with mode: 0644]
folly/experimental/wangle/ssl/SSLContextManager.h [new file with mode: 0644]
folly/experimental/wangle/ssl/SSLSessionCacheManager.cpp [new file with mode: 0644]
folly/experimental/wangle/ssl/SSLSessionCacheManager.h [new file with mode: 0644]
folly/experimental/wangle/ssl/SSLStats.h [new file with mode: 0644]
folly/experimental/wangle/ssl/SSLUtil.cpp [new file with mode: 0644]
folly/experimental/wangle/ssl/SSLUtil.h [new file with mode: 0644]
folly/experimental/wangle/ssl/TLSTicketKeyManager.cpp [new file with mode: 0644]
folly/experimental/wangle/ssl/TLSTicketKeyManager.h [new file with mode: 0644]
folly/experimental/wangle/ssl/TLSTicketKeySeeds.h [new file with mode: 0644]
folly/experimental/wangle/ssl/test/SSLCacheTest.cpp [new file with mode: 0644]
folly/experimental/wangle/ssl/test/SSLContextManagerTest.cpp [new file with mode: 0644]

index a95e558c9473a2dc48aed6efe268bd92689eccf9..66f62e9842050dc7bf1b7b40e82077bc5556bb59 100644 (file)
@@ -88,6 +88,26 @@ nobase_follyinclude_HEADERS = \
        experimental/wangle/rx/types.h \
        experimental/wangle/ConnectionManager.h \
        experimental/wangle/ManagedConnection.h \
+       experimental/wangle/acceptor/Acceptor.h \
+       experimental/wangle/acceptor/ConnectionCounter.h \
+       experimental/wangle/acceptor/SocketOptions.h \
+       experimental/wangle/acceptor/DomainNameMisc.h \
+       experimental/wangle/acceptor/LoadShedConfiguration.h \
+       experimental/wangle/acceptor/NetworkAddress.h \
+       experimental/wangle/acceptor/ServerSocketConfig.h \
+       experimental/wangle/acceptor/TransportInfo.h \
+       experimental/wangle/ssl/ClientHelloExtStats.h \
+       experimental/wangle/ssl/DHParam.h \
+       experimental/wangle/ssl/PasswordInFile.h \
+       experimental/wangle/ssl/SSLCacheOptions.h \
+       experimental/wangle/ssl/SSLCacheProvider.h \
+       experimental/wangle/ssl/SSLContextConfig.h \
+       experimental/wangle/ssl/SSLContextManager.h \
+       experimental/wangle/ssl/SSLSessionCacheManager.h \
+       experimental/wangle/ssl/SSLStats.h \
+       experimental/wangle/ssl/SSLUtil.h \
+       experimental/wangle/ssl/TLSTicketKeyManager.h \
+       experimental/wangle/ssl/TLSTicketKeySeeds.h \
        FBString.h \
        FBVector.h \
        File.h \
@@ -301,7 +321,16 @@ libfolly_la_SOURCES = \
        experimental/wangle/concurrent/IOThreadPoolExecutor.cpp \
        experimental/wangle/concurrent/ThreadPoolExecutor.cpp \
        experimental/wangle/ConnectionManager.cpp \
-       experimental/wangle/ManagedConnection.cpp
+       experimental/wangle/ManagedConnection.cpp \
+       experimental/wangle/acceptor/Acceptor.cpp \
+       experimental/wangle/acceptor/SocketOptions.cpp \
+       experimental/wangle/acceptor/LoadShedConfiguration.cpp \
+       experimental/wangle/acceptor/TransportInfo.cpp \
+       experimental/wangle/ssl/PasswordInFile.cpp \
+       experimental/wangle/ssl/SSLContextManager.cpp \
+       experimental/wangle/ssl/SSLSessionCacheManager.cpp \
+       experimental/wangle/ssl/SSLUtil.cpp \
+       experimental/wangle/ssl/TLSTicketKeyManager.cpp
 
 if HAVE_LINUX
 nobase_follyinclude_HEADERS += \
diff --git a/folly/experimental/wangle/acceptor/Acceptor.cpp b/folly/experimental/wangle/acceptor/Acceptor.cpp
new file mode 100644 (file)
index 0000000..534e6f4
--- /dev/null
@@ -0,0 +1,435 @@
+/*
+ *  Copyright (c) 2014, Facebook, Inc.
+ *  All rights reserved.
+ *
+ *  This source code is licensed under the BSD-style license found in the
+ *  LICENSE file in the root directory of this source tree. An additional grant
+ *  of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#include <folly/experimental/wangle/acceptor/Acceptor.h>
+
+#include <folly/experimental/wangle/ManagedConnection.h>
+#include <folly/experimental/wangle/ssl/SSLContextManager.h>
+
+#include <boost/cast.hpp>
+#include <fcntl.h>
+#include <folly/ScopeGuard.h>
+#include <folly/experimental/wangle/ManagedConnection.h>
+#include <folly/io/async/EventBase.h>
+#include <fstream>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <folly/io/async/AsyncSSLSocket.h>
+#include <folly/io/async/AsyncSocket.h>
+#include <folly/io/async/EventBase.h>
+#include <unistd.h>
+
+using folly::wangle::ConnectionManager;
+using folly::wangle::ManagedConnection;
+using std::chrono::microseconds;
+using std::chrono::milliseconds;
+using std::filebuf;
+using std::ifstream;
+using std::ios;
+using std::shared_ptr;
+using std::string;
+
+namespace folly {
+
+#ifndef NO_LIB_GFLAGS
+DEFINE_int32(shutdown_idle_grace_ms, 5000, "milliseconds to wait before "
+             "closing idle conns");
+#else
+const int32_t FLAGS_shutdown_idle_grace_ms = 5000;
+#endif
+
+static const std::string empty_string;
+std::atomic<uint64_t> Acceptor::totalNumPendingSSLConns_{0};
+
+/**
+ * Lightweight wrapper class to keep track of a newly
+ * accepted connection during SSL handshaking.
+ */
+class AcceptorHandshakeHelper :
+      public AsyncSSLSocket::HandshakeCB,
+      public ManagedConnection {
+ public:
+  AcceptorHandshakeHelper(AsyncSSLSocket::UniquePtr socket,
+                          Acceptor* acceptor,
+                          const SocketAddress& clientAddr,
+                          std::chrono::steady_clock::time_point acceptTime)
+    : socket_(std::move(socket)), acceptor_(acceptor),
+      acceptTime_(acceptTime), clientAddr_(clientAddr) {
+    acceptor_->downstreamConnectionManager_->addConnection(this, true);
+    if(acceptor_->parseClientHello_)  {
+      socket_->enableClientHelloParsing();
+    }
+    socket_->sslAccept(this);
+  }
+
+  virtual void timeoutExpired() noexcept {
+    VLOG(4) << "SSL handshake timeout expired";
+    sslError_ = SSLErrorEnum::TIMEOUT;
+    dropConnection();
+  }
+  virtual void describe(std::ostream& os) const {
+    os << "pending handshake on " << clientAddr_;
+  }
+  virtual bool isBusy() const {
+    return true;
+  }
+  virtual void notifyPendingShutdown() {}
+  virtual void closeWhenIdle() {}
+
+  virtual void dropConnection() {
+    VLOG(10) << "Dropping in progress handshake for " << clientAddr_;
+    socket_->closeNow();
+  }
+  virtual void dumpConnectionState(uint8_t loglevel) {
+  }
+
+ private:
+  // AsyncSSLSocket::HandshakeCallback API
+  virtual void handshakeSuc(AsyncSSLSocket* sock) noexcept {
+
+    const unsigned char* nextProto = nullptr;
+    unsigned nextProtoLength = 0;
+    sock->getSelectedNextProtocol(&nextProto, &nextProtoLength);
+    if (VLOG_IS_ON(3)) {
+      if (nextProto) {
+        VLOG(3) << "Client selected next protocol " <<
+            string((const char*)nextProto, nextProtoLength);
+      } else {
+        VLOG(3) << "Client did not select a next protocol";
+      }
+    }
+
+    // fill in SSL-related fields from TransportInfo
+    // the other fields like RTT are filled in the Acceptor
+    TransportInfo tinfo;
+    tinfo.ssl = true;
+    tinfo.acceptTime = acceptTime_;
+    tinfo.sslSetupTime = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - acceptTime_);
+    tinfo.sslSetupBytesRead = sock->getRawBytesReceived();
+    tinfo.sslSetupBytesWritten = sock->getRawBytesWritten();
+    tinfo.sslServerName = sock->getSSLServerName();
+    tinfo.sslCipher = sock->getNegotiatedCipherName();
+    tinfo.sslVersion = sock->getSSLVersion();
+    tinfo.sslCertSize = sock->getSSLCertSize();
+    tinfo.sslResume = SSLUtil::getResumeState(sock);
+    sock->getSSLClientCiphers(tinfo.sslClientCiphers);
+    sock->getSSLServerCiphers(tinfo.sslServerCiphers);
+    tinfo.sslClientComprMethods = sock->getSSLClientComprMethods();
+    tinfo.sslClientExts = sock->getSSLClientExts();
+    tinfo.sslNextProtocol.assign(
+        reinterpret_cast<const char*>(nextProto),
+        nextProtoLength);
+
+    acceptor_->updateSSLStats(sock, tinfo.sslSetupTime, SSLErrorEnum::NO_ERROR);
+    acceptor_->downstreamConnectionManager_->removeConnection(this);
+    acceptor_->sslConnectionReady(std::move(socket_), clientAddr_,
+        nextProto ? string((const char*)nextProto, nextProtoLength) :
+                                  empty_string, tinfo);
+    delete this;
+  }
+
+  virtual void handshakeErr(AsyncSSLSocket* sock,
+                            const AsyncSocketException& ex) noexcept {
+    auto elapsedTime = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - acceptTime_);
+    VLOG(3) << "SSL handshake error after " << elapsedTime.count() <<
+        " ms; " << sock->getRawBytesReceived() << " bytes received & " <<
+        sock->getRawBytesWritten() << " bytes sent: " <<
+        ex.what();
+    acceptor_->updateSSLStats(sock, elapsedTime, sslError_);
+    acceptor_->sslConnectionError();
+    delete this;
+  }
+
+  AsyncSSLSocket::UniquePtr socket_;
+  Acceptor* acceptor_;
+  std::chrono::steady_clock::time_point acceptTime_;
+  SocketAddress clientAddr_;
+  SSLErrorEnum sslError_{SSLErrorEnum::NO_ERROR};
+};
+
+Acceptor::Acceptor(const ServerSocketConfig& accConfig) :
+  accConfig_(accConfig),
+  socketOptions_(accConfig.getSocketOptions()) {
+}
+
+void
+Acceptor::init(AsyncServerSocket* serverSocket,
+               EventBase* eventBase) {
+  CHECK(nullptr == this->base_);
+
+  if (accConfig_.isSSL()) {
+    if (!sslCtxManager_) {
+      sslCtxManager_ = folly::make_unique<SSLContextManager>(
+        eventBase,
+        "vip_" + getName(),
+        accConfig_.strictSSL, nullptr);
+    }
+    for (const auto& sslCtxConfig : accConfig_.sslContextConfigs) {
+      sslCtxManager_->addSSLContextConfig(
+        sslCtxConfig,
+        accConfig_.sslCacheOptions,
+        &accConfig_.initialTicketSeeds,
+        accConfig_.bindAddress,
+        cacheProvider_);
+      parseClientHello_ |= sslCtxConfig.clientHelloParsingEnabled;
+    }
+
+    CHECK(sslCtxManager_->getDefaultSSLCtx());
+  }
+
+  base_ = eventBase;
+  state_ = State::kRunning;
+  downstreamConnectionManager_ = ConnectionManager::makeUnique(
+    eventBase, accConfig_.connectionIdleTimeout, this);
+
+  serverSocket->addAcceptCallback(this, eventBase);
+  // SO_KEEPALIVE is the only setting that is inherited by accepted
+  // connections so only apply this setting
+  for (const auto& option: socketOptions_) {
+    if (option.first.level == SOL_SOCKET &&
+        option.first.optname == SO_KEEPALIVE && option.second == 1) {
+      serverSocket->setKeepAliveEnabled(true);
+      break;
+    }
+  }
+}
+
+Acceptor::~Acceptor(void) {
+}
+
+void Acceptor::addSSLContextConfig(const SSLContextConfig& sslCtxConfig) {
+  sslCtxManager_->addSSLContextConfig(sslCtxConfig,
+                                      accConfig_.sslCacheOptions,
+                                      &accConfig_.initialTicketSeeds,
+                                      accConfig_.bindAddress,
+                                      cacheProvider_);
+}
+
+void
+Acceptor::drainAllConnections() {
+  if (downstreamConnectionManager_) {
+    downstreamConnectionManager_->initiateGracefulShutdown(
+      std::chrono::milliseconds(FLAGS_shutdown_idle_grace_ms));
+  }
+}
+
+void Acceptor::setLoadShedConfig(const LoadShedConfiguration& from,
+                       IConnectionCounter* counter) {
+  loadShedConfig_ = from;
+  connectionCounter_ = counter;
+}
+
+bool Acceptor::canAccept(const SocketAddress& address) {
+  if (!connectionCounter_) {
+    return true;
+  }
+
+  uint64_t maxConnections = connectionCounter_->getMaxConnections();
+  if (maxConnections == 0) {
+    return true;
+  }
+
+  uint64_t currentConnections = connectionCounter_->getNumConnections();
+  if (currentConnections < maxConnections) {
+    return true;
+  }
+
+  if (loadShedConfig_.isWhitelisted(address)) {
+    return true;
+  }
+
+  // Take care of comparing connection count against max connections across
+  // all acceptors. Expensive since a lock must be taken to get the counter.
+  auto connectionCountForLoadShedding = getConnectionCountForLoadShedding();
+  if (connectionCountForLoadShedding < loadShedConfig_.getMaxConnections()) {
+    return true;
+  }
+
+  VLOG(4) << address.describe() << " not whitelisted";
+  return false;
+}
+
+void
+Acceptor::connectionAccepted(
+    int fd, const SocketAddress& clientAddr) noexcept {
+  if (!canAccept(clientAddr)) {
+    close(fd);
+    return;
+  }
+  auto acceptTime = std::chrono::steady_clock::now();
+  for (const auto& opt: socketOptions_) {
+    opt.first.apply(fd, opt.second);
+  }
+
+  onDoneAcceptingConnection(fd, clientAddr, acceptTime);
+}
+
+void Acceptor::onDoneAcceptingConnection(
+    int fd,
+    const SocketAddress& clientAddr,
+    std::chrono::steady_clock::time_point acceptTime) noexcept {
+  processEstablishedConnection(fd, clientAddr, acceptTime);
+}
+
+void
+Acceptor::processEstablishedConnection(
+    int fd,
+    const SocketAddress& clientAddr,
+    std::chrono::steady_clock::time_point acceptTime) noexcept {
+  if (accConfig_.isSSL()) {
+    CHECK(sslCtxManager_);
+    AsyncSSLSocket::UniquePtr sslSock(
+      makeNewAsyncSSLSocket(
+        sslCtxManager_->getDefaultSSLCtx(), base_, fd));
+    ++numPendingSSLConns_;
+    ++totalNumPendingSSLConns_;
+    if (totalNumPendingSSLConns_ > accConfig_.maxConcurrentSSLHandshakes) {
+      VLOG(2) << "dropped SSL handshake on " << accConfig_.name <<
+        " too many handshakes in progress";
+      updateSSLStats(sslSock.get(), std::chrono::milliseconds(0),
+                     SSLErrorEnum::DROPPED);
+      sslConnectionError();
+      return;
+    }
+    new AcceptorHandshakeHelper(
+      std::move(sslSock), this, clientAddr, acceptTime);
+  } else {
+    TransportInfo tinfo;
+    tinfo.ssl = false;
+    tinfo.acceptTime = acceptTime;
+    AsyncSocket::UniquePtr sock(makeNewAsyncSocket(base_, fd));
+    connectionReady(std::move(sock), clientAddr, empty_string, tinfo);
+  }
+}
+
+void
+Acceptor::connectionReady(
+    AsyncSocket::UniquePtr sock,
+    const SocketAddress& clientAddr,
+    const string& nextProtocolName,
+    TransportInfo& tinfo) {
+  // Limit the number of reads from the socket per poll loop iteration,
+  // both to keep memory usage under control and to prevent one fast-
+  // writing client from starving other connections.
+  sock->setMaxReadsPerEvent(16);
+  tinfo.initWithSocket(sock.get());
+  onNewConnection(std::move(sock), &clientAddr, nextProtocolName, tinfo);
+}
+
+void
+Acceptor::sslConnectionReady(AsyncSocket::UniquePtr sock,
+                             const SocketAddress& clientAddr,
+                             const string& nextProtocol,
+                             TransportInfo& tinfo) {
+  CHECK(numPendingSSLConns_ > 0);
+  connectionReady(std::move(sock), clientAddr, nextProtocol, tinfo);
+  --numPendingSSLConns_;
+  --totalNumPendingSSLConns_;
+  if (state_ == State::kDraining) {
+    checkDrained();
+  }
+}
+
+void
+Acceptor::sslConnectionError() {
+  CHECK(numPendingSSLConns_ > 0);
+  --numPendingSSLConns_;
+  --totalNumPendingSSLConns_;
+  if (state_ == State::kDraining) {
+    checkDrained();
+  }
+}
+
+void
+Acceptor::acceptError(const std::exception& ex) noexcept {
+  // An error occurred.
+  // The most likely error is out of FDs.  AsyncServerSocket will back off
+  // briefly if we are out of FDs, then continue accepting later.
+  // Just log a message here.
+  LOG(ERROR) << "error accepting on acceptor socket: " << ex.what();
+}
+
+void
+Acceptor::acceptStopped() noexcept {
+  VLOG(3) << "Acceptor " << this << " acceptStopped()";
+  // Drain the open client connections
+  drainAllConnections();
+
+  // If we haven't yet finished draining, begin doing so by marking ourselves
+  // as in the draining state. We must be sure to hit checkDrained() here, as
+  // if we're completely idle, we can should consider ourself drained
+  // immediately (as there is no outstanding work to complete to cause us to
+  // re-evaluate this).
+  if (state_ != State::kDone) {
+    state_ = State::kDraining;
+    checkDrained();
+  }
+}
+
+void
+Acceptor::onEmpty(const ConnectionManager& cm) {
+  VLOG(3) << "Acceptor=" << this << " onEmpty()";
+  if (state_ == State::kDraining) {
+    checkDrained();
+  }
+}
+
+void
+Acceptor::checkDrained() {
+  CHECK(state_ == State::kDraining);
+  if (forceShutdownInProgress_ ||
+      (downstreamConnectionManager_->getNumConnections() != 0) ||
+      (numPendingSSLConns_ != 0)) {
+    return;
+  }
+
+  VLOG(2) << "All connections drained from Acceptor=" << this << " in thread "
+          << base_;
+
+  downstreamConnectionManager_.reset();
+
+  state_ = State::kDone;
+
+  onConnectionsDrained();
+}
+
+milliseconds
+Acceptor::getConnTimeout() const {
+  return accConfig_.connectionIdleTimeout;
+}
+
+void Acceptor::addConnection(ManagedConnection* conn) {
+  // Add the socket to the timeout manager so that it can be cleaned
+  // up after being left idle for a long time.
+  downstreamConnectionManager_->addConnection(conn, true);
+}
+
+void
+Acceptor::forceStop() {
+  base_->runInEventBaseThread([&] { dropAllConnections(); });
+}
+
+void
+Acceptor::dropAllConnections() {
+  if (downstreamConnectionManager_) {
+    LOG(INFO) << "Dropping all connections from Acceptor=" << this <<
+      " in thread " << base_;
+    assert(base_->isInEventBaseThread());
+    forceShutdownInProgress_ = true;
+    downstreamConnectionManager_->dropAllConnections();
+    CHECK(downstreamConnectionManager_->getNumConnections() == 0);
+    downstreamConnectionManager_.reset();
+  }
+  CHECK(numPendingSSLConns_ == 0);
+
+  state_ = State::kDone;
+  onConnectionsDrained();
+}
+
+} // namespace
diff --git a/folly/experimental/wangle/acceptor/Acceptor.h b/folly/experimental/wangle/acceptor/Acceptor.h
new file mode 100644 (file)
index 0000000..6959701
--- /dev/null
@@ -0,0 +1,338 @@
+/*
+ *  Copyright (c) 2014, Facebook, Inc.
+ *  All rights reserved.
+ *
+ *  This source code is licensed under the BSD-style license found in the
+ *  LICENSE file in the root directory of this source tree. An additional grant
+ *  of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+#include "folly/experimental/wangle/acceptor/ServerSocketConfig.h"
+#include "folly/experimental/wangle/acceptor/ConnectionCounter.h"
+#include <folly/experimental/wangle/ConnectionManager.h>
+#include "folly/experimental/wangle/acceptor/LoadShedConfiguration.h"
+#include "folly/experimental/wangle/ssl/SSLCacheProvider.h"
+#include "folly/experimental/wangle/acceptor/TransportInfo.h"
+
+#include <chrono>
+#include <event.h>
+#include <folly/io/async/AsyncSSLSocket.h>
+#include <folly/io/async/AsyncServerSocket.h>
+
+namespace folly { namespace wangle {
+class ManagedConnection;
+}}
+
+namespace folly {
+
+class SocketAddress;
+class SSLContext;
+class AsyncTransport;
+class SSLContextManager;
+
+/**
+ * An abstract acceptor for TCP-based network services.
+ *
+ * There is one acceptor object per thread for each listening socket.  When a
+ * new connection arrives on the listening socket, it is accepted by one of the
+ * acceptor objects.  From that point on the connection will be processed by
+ * that acceptor's thread.
+ *
+ * The acceptor will call the abstract onNewConnection() method to create
+ * a new ManagedConnection object for each accepted socket.  The acceptor
+ * also tracks all outstanding connections that it has accepted.
+ */
+class Acceptor :
+  public folly::AsyncServerSocket::AcceptCallback,
+  public folly::wangle::ConnectionManager::Callback {
+ public:
+
+  enum class State : uint32_t {
+    kInit,  // not yet started
+    kRunning, // processing requests normally
+    kDraining, // processing outstanding conns, but not accepting new ones
+    kDone,  // no longer accepting, and all connections finished
+  };
+
+  explicit Acceptor(const ServerSocketConfig& accConfig);
+  virtual ~Acceptor();
+
+  /**
+   * Supply an SSL cache provider
+   * @note Call this before init()
+   */
+  virtual void setSSLCacheProvider(
+      const std::shared_ptr<SSLCacheProvider>& cacheProvider) {
+    cacheProvider_ = cacheProvider;
+  }
+
+  /**
+   * Initialize the Acceptor to run in the specified EventBase
+   * thread, receiving connections from the specified AsyncServerSocket.
+   *
+   * This method will be called from the AsyncServerSocket's primary thread,
+   * not the specified EventBase thread.
+   */
+  virtual void init(AsyncServerSocket* serverSocket,
+                    EventBase* eventBase);
+
+  /**
+   * Dynamically add a new SSLContextConfig
+   */
+  void addSSLContextConfig(const SSLContextConfig& sslCtxConfig);
+
+  SSLContextManager* getSSLContextManager() const {
+    return sslCtxManager_.get();
+  }
+
+  /**
+   * Return the number of outstanding connections in this service instance.
+   */
+  uint32_t getNumConnections() const {
+    return downstreamConnectionManager_ ?
+        downstreamConnectionManager_->getNumConnections() : 0;
+  }
+
+  /**
+   * Access the Acceptor's event base.
+   */
+  EventBase* getEventBase() { return base_; }
+
+  /**
+   * Access the Acceptor's downstream (client-side) ConnectionManager
+   */
+  virtual folly::wangle::ConnectionManager* getConnectionManager() {
+    return downstreamConnectionManager_.get();
+  }
+
+  /**
+   * Invoked when a new ManagedConnection is created.
+   *
+   * This allows the Acceptor to track the outstanding connections,
+   * for tracking timeouts and for ensuring that all connections have been
+   * drained on shutdown.
+   */
+  void addConnection(folly::wangle::ManagedConnection* connection);
+
+  /**
+   * Get this acceptor's current state.
+   */
+  State getState() const {
+    return state_;
+  }
+
+  /**
+   * Get the current connection timeout.
+   */
+  std::chrono::milliseconds getConnTimeout() const;
+
+  /**
+   * Returns the name of this VIP.
+   *
+   * Will return an empty string if no name has been configured.
+   */
+  const std::string& getName() const {
+    return accConfig_.name;
+  }
+
+  /**
+   * Force the acceptor to drop all connections and stop processing.
+   *
+   * This function may be called from any thread.  The acceptor will not
+   * necessarily stop before this function returns: the stop will be scheduled
+   * to run in the acceptor's thread.
+   */
+  virtual void forceStop();
+
+  bool isSSL() const { return accConfig_.isSSL(); }
+
+  const ServerSocketConfig& getConfig() const { return accConfig_; }
+
+  static uint64_t getTotalNumPendingSSLConns() {
+    return totalNumPendingSSLConns_.load();
+  }
+
+  /**
+   * Called right when the TCP connection has been accepted, before processing
+   * the first HTTP bytes (HTTP) or the SSL handshake (HTTPS)
+   */
+  virtual void onDoneAcceptingConnection(
+    int fd,
+    const SocketAddress& clientAddr,
+    std::chrono::steady_clock::time_point acceptTime
+  ) noexcept;
+
+  /**
+   * Begins either processing HTTP bytes (HTTP) or the SSL handshake (HTTPS)
+   */
+  void processEstablishedConnection(
+    int fd,
+    const SocketAddress& clientAddr,
+    std::chrono::steady_clock::time_point acceptTime
+  ) noexcept;
+
+ protected:
+  friend class AcceptorHandshakeHelper;
+
+  /**
+   * Our event loop.
+   *
+   * Probably needs to be used to pass to a ManagedConnection
+   * implementation. Also visible in case a subclass wishes to do additional
+   * things w/ the event loop (e.g. in attach()).
+   */
+  EventBase* base_{nullptr};
+
+  virtual uint64_t getConnectionCountForLoadShedding(void) const { return 0; }
+
+  /**
+   * Hook for subclasses to drop newly accepted connections prior
+   * to handshaking.
+   */
+  virtual bool canAccept(const folly::SocketAddress&);
+
+  /**
+   * Invoked when a new connection is created. This is where application starts
+   * processing a new downstream connection.
+   *
+   * NOTE: Application should add the new connection to
+   *       downstreamConnectionManager so that it can be garbage collected after
+   *       certain period of idleness.
+   *
+   * @param sock              the socket connected to the client
+   * @param address           the address of the client
+   * @param nextProtocolName  the name of the L6 or L7 protocol to be
+   *                            spoken on the connection, if known (e.g.,
+   *                            from TLS NPN during secure connection setup),
+   *                            or an empty string if unknown
+   */
+  virtual void onNewConnection(
+      AsyncSocket::UniquePtr sock,
+      const folly::SocketAddress* address,
+      const std::string& nextProtocolName,
+      const TransportInfo& tinfo) = 0;
+
+  virtual AsyncSocket::UniquePtr makeNewAsyncSocket(EventBase* base, int fd) {
+    return AsyncSocket::UniquePtr(new AsyncSocket(base, fd));
+  }
+
+  virtual AsyncSSLSocket::UniquePtr makeNewAsyncSSLSocket(
+    const std::shared_ptr<SSLContext>& ctx, EventBase* base, int fd) {
+    return AsyncSSLSocket::UniquePtr(new AsyncSSLSocket(ctx, base, fd));
+  }
+
+  /**
+   * Hook for subclasses to record stats about SSL connection establishment.
+   */
+  virtual void updateSSLStats(
+      const AsyncSSLSocket* sock,
+      std::chrono::milliseconds acceptLatency,
+      SSLErrorEnum error) noexcept {}
+
+  /**
+   * Drop all connections.
+   *
+   * forceStop() schedules dropAllConnections() to be called in the acceptor's
+   * thread.
+   */
+  void dropAllConnections();
+
+  /**
+   * Drains all open connections of their outstanding transactions. When
+   * a connection's transaction count reaches zero, the connection closes.
+   */
+  void drainAllConnections();
+
+  /**
+   * onConnectionsDrained() will be called once all connections have been
+   * drained while the acceptor is stopping.
+   *
+   * Subclasses can override this method to perform any subclass-specific
+   * cleanup.
+   */
+  virtual void onConnectionsDrained() {}
+
+  // AsyncServerSocket::AcceptCallback methods
+  void connectionAccepted(int fd,
+      const folly::SocketAddress& clientAddr)
+      noexcept;
+  void acceptError(const std::exception& ex) noexcept;
+  void acceptStopped() noexcept;
+
+  // ConnectionManager::Callback methods
+  void onEmpty(const folly::wangle::ConnectionManager& cm);
+  void onConnectionAdded(const folly::wangle::ConnectionManager& cm) {}
+  void onConnectionRemoved(const folly::wangle::ConnectionManager& cm) {}
+
+  /**
+   * Process a connection that is to ready to receive L7 traffic.
+   * This method is called immediately upon accept for plaintext
+   * connections and upon completion of SSL handshaking or resumption
+   * for SSL connections.
+   */
+   void connectionReady(
+      AsyncSocket::UniquePtr sock,
+      const folly::SocketAddress& clientAddr,
+      const std::string& nextProtocolName,
+      TransportInfo& tinfo);
+
+  const LoadShedConfiguration& getLoadShedConfiguration() const {
+    return loadShedConfig_;
+  }
+
+ protected:
+  const ServerSocketConfig accConfig_;
+  void setLoadShedConfig(const LoadShedConfiguration& from,
+                         IConnectionCounter* counter);
+
+  /**
+   * Socket options to apply to the client socket
+   */
+  AsyncSocket::OptionMap socketOptions_;
+
+  std::unique_ptr<SSLContextManager> sslCtxManager_;
+
+  /**
+   * Whether we want to enable client hello parsing in the handshake helper
+   * to get list of supported client ciphers.
+   */
+  bool parseClientHello_{false};
+
+  folly::wangle::ConnectionManager::UniquePtr downstreamConnectionManager_;
+
+ private:
+
+  // Forbidden copy constructor and assignment opererator
+  Acceptor(Acceptor const &) = delete;
+  Acceptor& operator=(Acceptor const &) = delete;
+
+  /**
+   * Wrapper for connectionReady() that decrements the count of
+   * pending SSL connections.
+   */
+  void sslConnectionReady(AsyncSocket::UniquePtr sock,
+      const folly::SocketAddress& clientAddr,
+      const std::string& nextProtocol,
+      TransportInfo& tinfo);
+
+  /**
+   * Notification callback for SSL handshake failures.
+   */
+  void sslConnectionError();
+
+  void checkDrained();
+
+  State state_{State::kInit};
+  uint64_t numPendingSSLConns_{0};
+
+  static std::atomic<uint64_t> totalNumPendingSSLConns_;
+
+  bool forceShutdownInProgress_{false};
+  LoadShedConfiguration loadShedConfig_;
+  IConnectionCounter* connectionCounter_{nullptr};
+  std::shared_ptr<SSLCacheProvider> cacheProvider_;
+};
+
+} // namespace
diff --git a/folly/experimental/wangle/acceptor/ConnectionCounter.h b/folly/experimental/wangle/acceptor/ConnectionCounter.h
new file mode 100644 (file)
index 0000000..bf891bb
--- /dev/null
@@ -0,0 +1,54 @@
+/*
+ *  Copyright (c) 2014, Facebook, Inc.
+ *  All rights reserved.
+ *
+ *  This source code is licensed under the BSD-style license found in the
+ *  LICENSE file in the root directory of this source tree. An additional grant
+ *  of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+namespace folly {
+
+class IConnectionCounter {
+ public:
+  virtual uint64_t getNumConnections() const = 0;
+
+  /**
+   * Get the maximum number of non-whitelisted client-side connections
+   * across all Acceptors managed by this. A value
+   * of zero means "unlimited."
+   */
+  virtual uint64_t getMaxConnections() const = 0;
+
+  /**
+   * Increment the count of client-side connections.
+   */
+  virtual void onConnectionAdded() = 0;
+
+  /**
+   * Decrement the count of client-side connections.
+   */
+  virtual void onConnectionRemoved() = 0;
+  virtual ~IConnectionCounter() {}
+};
+
+class SimpleConnectionCounter: public IConnectionCounter {
+ public:
+  uint64_t getNumConnections() const override { return numConnections_; }
+  uint64_t getMaxConnections() const override { return maxConnections_; }
+  void setMaxConnections(uint64_t maxConnections) {
+    maxConnections_ = maxConnections;
+  }
+
+  void onConnectionAdded() override { numConnections_++; }
+  void onConnectionRemoved() override { numConnections_--; }
+  virtual ~SimpleConnectionCounter() {}
+
+ protected:
+  uint64_t maxConnections_{0};
+  uint64_t numConnections_{0};
+};
+
+}
diff --git a/folly/experimental/wangle/acceptor/DomainNameMisc.h b/folly/experimental/wangle/acceptor/DomainNameMisc.h
new file mode 100644 (file)
index 0000000..41c4c74
--- /dev/null
@@ -0,0 +1,71 @@
+/*
+ *  Copyright (c) 2014, Facebook, Inc.
+ *  All rights reserved.
+ *
+ *  This source code is licensed under the BSD-style license found in the
+ *  LICENSE file in the root directory of this source tree. An additional grant
+ *  of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+#include <string>
+
+namespace folly {
+
+struct dn_char_traits : public std::char_traits<char> {
+  static bool eq(char c1, char c2) {
+    return ::tolower(c1) == ::tolower(c2);
+  }
+
+  static bool ne(char c1, char c2) {
+    return ::tolower(c1) != ::tolower(c2);
+  }
+
+  static bool lt(char c1, char c2) {
+    return ::tolower(c1) < ::tolower(c2);
+  }
+
+  static int compare(const char* s1, const char* s2, size_t n) {
+    while (n--) {
+      if(::tolower(*s1) < ::tolower(*s2) ) {
+        return -1;
+      }
+      if(::tolower(*s1) > ::tolower(*s2) ) {
+        return 1;
+      }
+      ++s1;
+      ++s2;
+    }
+    return 0;
+  }
+
+  static const char* find(const char* s, size_t n, char a) {
+    char la = ::tolower(a);
+    while (n--) {
+      if(::tolower(*s) == la) {
+        return s;
+      } else {
+        ++s;
+      }
+    }
+    return nullptr;
+  }
+};
+
+// Case insensitive string
+typedef std::basic_string<char, dn_char_traits> DNString;
+
+struct DNStringHash : public std::hash<std::string> {
+  size_t operator()(const DNString& s) const noexcept {
+    size_t h = static_cast<size_t>(0xc70f6907UL);
+    const char* d = s.data();
+    for (size_t i = 0; i < s.length(); ++i) {
+      char a = ::tolower(*d++);
+      h = std::_Hash_impl::hash(&a, sizeof(a), h);
+    }
+    return h;
+  }
+};
+
+} // namespace
diff --git a/folly/experimental/wangle/acceptor/LoadShedConfiguration.cpp b/folly/experimental/wangle/acceptor/LoadShedConfiguration.cpp
new file mode 100644 (file)
index 0000000..e08e71b
--- /dev/null
@@ -0,0 +1,43 @@
+/*
+ *  Copyright (c) 2014, Facebook, Inc.
+ *  All rights reserved.
+ *
+ *  This source code is licensed under the BSD-style license found in the
+ *  LICENSE file in the root directory of this source tree. An additional grant
+ *  of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#include <folly/experimental/wangle/acceptor/LoadShedConfiguration.h>
+
+#include <folly/Conv.h>
+#include <openssl/ssl.h>
+
+using std::string;
+
+namespace folly {
+
+void LoadShedConfiguration::addWhitelistAddr(folly::StringPiece input) {
+  auto addr = input.str();
+  size_t separator = addr.find_first_of('/');
+  if (separator == string::npos) {
+    whitelistAddrs_.insert(SocketAddress(addr, 0));
+  } else {
+    unsigned prefixLen = folly::to<unsigned>(addr.substr(separator + 1));
+    addr.erase(separator);
+    whitelistNetworks_.insert(NetworkAddress(SocketAddress(addr, 0), prefixLen));
+  }
+}
+
+bool LoadShedConfiguration::isWhitelisted(const SocketAddress& address) const {
+  if (whitelistAddrs_.find(address) != whitelistAddrs_.end()) {
+    return true;
+  }
+  for (auto& network : whitelistNetworks_) {
+    if (network.contains(address)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+}
diff --git a/folly/experimental/wangle/acceptor/LoadShedConfiguration.h b/folly/experimental/wangle/acceptor/LoadShedConfiguration.h
new file mode 100644 (file)
index 0000000..3c70b6e
--- /dev/null
@@ -0,0 +1,107 @@
+/*
+ *  Copyright (c) 2014, Facebook, Inc.
+ *  All rights reserved.
+ *
+ *  This source code is licensed under the BSD-style license found in the
+ *  LICENSE file in the root directory of this source tree. An additional grant
+ *  of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+#include <chrono>
+#include <folly/Range.h>
+#include <folly/SocketAddress.h>
+#include <glog/logging.h>
+#include <list>
+#include <set>
+#include <string>
+
+#include <folly/experimental/wangle/acceptor/NetworkAddress.h>
+
+namespace folly {
+
+/**
+ * Class that holds an LoadShed configuration for a service
+ */
+class LoadShedConfiguration {
+ public:
+
+  // Comparison function for SocketAddress that disregards the port
+  struct AddressOnlyCompare {
+    bool operator()(
+     const SocketAddress& addr1,
+     const SocketAddress& addr2) const {
+      return addr1.getIPAddress() < addr2.getIPAddress();
+    }
+  };
+
+  typedef std::set<SocketAddress, AddressOnlyCompare> AddressSet;
+  typedef std::set<NetworkAddress> NetworkSet;
+
+  LoadShedConfiguration() {}
+
+  ~LoadShedConfiguration() {}
+
+  void addWhitelistAddr(folly::StringPiece);
+
+  /**
+   * Set/get the set of IPs that should be whitelisted through even when we're
+   * trying to shed load.
+   */
+  void setWhitelistAddrs(const AddressSet& addrs) { whitelistAddrs_ = addrs; }
+  const AddressSet& getWhitelistAddrs() const { return whitelistAddrs_; }
+
+  /**
+   * Set/get the set of networks that should be whitelisted through even
+   * when we're trying to shed load.
+   */
+  void setWhitelistNetworks(const NetworkSet& networks) {
+    whitelistNetworks_ = networks;
+  }
+  const NetworkSet& getWhitelistNetworks() const { return whitelistNetworks_; }
+
+  /**
+   * Set/get the maximum number of downstream connections across all VIPs.
+   */
+  void setMaxConnections(uint64_t maxConns) { maxConnections_ = maxConns; }
+  uint64_t getMaxConnections() const { return maxConnections_; }
+
+  /**
+   * Set/get the maximum cpu usage.
+   */
+  void setMaxMemUsage(double max) {
+    CHECK(max >= 0);
+    CHECK(max <= 1);
+    maxMemUsage_ = max;
+  }
+  double getMaxMemUsage() const { return maxMemUsage_; }
+
+  /**
+   * Set/get the maximum memory usage.
+   */
+  void setMaxCpuUsage(double max) {
+    CHECK(max >= 0);
+    CHECK(max <= 1);
+    maxCpuUsage_ = max;
+  }
+  double getMaxCpuUsage() const { return maxCpuUsage_; }
+
+  void setLoadUpdatePeriod(std::chrono::milliseconds period) {
+    period_ = period;
+  }
+  std::chrono::milliseconds getLoadUpdatePeriod() const { return period_; }
+
+  bool isWhitelisted(const SocketAddress& addr) const;
+
+ private:
+
+  AddressSet whitelistAddrs_;
+  NetworkSet whitelistNetworks_;
+  uint64_t maxConnections_{0};
+  double maxMemUsage_;
+  double maxCpuUsage_;
+  std::chrono::milliseconds period_;
+};
+
+}
diff --git a/folly/experimental/wangle/acceptor/NetworkAddress.h b/folly/experimental/wangle/acceptor/NetworkAddress.h
new file mode 100644 (file)
index 0000000..3698037
--- /dev/null
@@ -0,0 +1,60 @@
+/*
+ *  Copyright (c) 2014, Facebook, Inc.
+ *  All rights reserved.
+ *
+ *  This source code is licensed under the BSD-style license found in the
+ *  LICENSE file in the root directory of this source tree. An additional grant
+ *  of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+#include <folly/SocketAddress.h>
+
+namespace folly {
+
+/**
+ * A simple wrapper around SocketAddress that represents
+ * a network in CIDR notation
+ */
+class NetworkAddress {
+public:
+  /**
+   * Create a NetworkAddress for an addr/prefixLen
+   * @param addr         IPv4 or IPv6 address of the network
+   * @param prefixLen    Prefix length, in bits
+   */
+  NetworkAddress(const folly::SocketAddress& addr,
+      unsigned prefixLen):
+    addr_(addr), prefixLen_(prefixLen) {}
+
+  /** Get the network address */
+  const folly::SocketAddress& getAddress() const {
+    return addr_;
+  }
+
+  /** Get the prefix length in bits */
+  unsigned getPrefixLength() const { return prefixLen_; }
+
+  /** Check whether a given address lies within the network */
+  bool contains(const folly::SocketAddress& addr) const {
+    return addr_.prefixMatch(addr, prefixLen_);
+  }
+
+  /** Comparison operator to enable use in ordered collections */
+  bool operator<(const NetworkAddress& other) const {
+    if (addr_ <  other.addr_) {
+      return true;
+    } else if (other.addr_ < addr_) {
+      return false;
+    } else {
+      return (prefixLen_ < other.prefixLen_);
+    }
+  }
+
+private:
+  folly::SocketAddress addr_;
+  unsigned prefixLen_;
+};
+
+} // namespace
diff --git a/folly/experimental/wangle/acceptor/ServerSocketConfig.h b/folly/experimental/wangle/acceptor/ServerSocketConfig.h
new file mode 100644 (file)
index 0000000..14dc42c
--- /dev/null
@@ -0,0 +1,128 @@
+/*
+ *  Copyright (c) 2014, Facebook, Inc.
+ *  All rights reserved.
+ *
+ *  This source code is licensed under the BSD-style license found in the
+ *  LICENSE file in the root directory of this source tree. An additional grant
+ *  of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+#include <folly/experimental/wangle/ssl/SSLCacheOptions.h>
+#include <folly/experimental/wangle/ssl/SSLContextConfig.h>
+#include <folly/experimental/wangle/ssl/TLSTicketKeySeeds.h>
+#include <folly/experimental/wangle/ssl/SSLUtil.h>
+#include <folly/experimental/wangle/acceptor/SocketOptions.h>
+
+#include <boost/optional.hpp>
+#include <chrono>
+#include <fcntl.h>
+#include <folly/Random.h>
+#include <folly/SocketAddress.h>
+#include <folly/String.h>
+#include <folly/io/async/SSLContext.h>
+#include <list>
+#include <string>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <folly/io/async/AsyncSocket.h>
+#include <folly/io/async/SSLContext.h>
+#include <folly/SocketAddress.h>
+
+namespace folly {
+
+/**
+ * Configuration for a single Acceptor.
+ *
+ * This configures not only accept behavior, but also some types of SSL
+ * behavior that may make sense to configure on a per-VIP basis (e.g. which
+ * cert(s) we use, etc).
+ */
+struct ServerSocketConfig {
+  ServerSocketConfig() {
+    // generate a single random current seed
+    uint8_t seed[32];
+    folly::Random::secureRandom(seed, sizeof(seed));
+    initialTicketSeeds.currentSeeds.push_back(
+      SSLUtil::hexlify(std::string((char *)seed, sizeof(seed))));
+  }
+
+  bool isSSL() const { return !(sslContextConfigs.empty()); }
+
+  /**
+   * Set/get the socket options to apply on all downstream connections.
+   */
+  void setSocketOptions(
+    const AsyncSocket::OptionMap& opts) {
+    socketOptions_ = filterIPSocketOptions(opts, bindAddress.getFamily());
+  }
+  AsyncSocket::OptionMap&
+  getSocketOptions() {
+    return socketOptions_;
+  }
+  const AsyncSocket::OptionMap&
+  getSocketOptions() const {
+    return socketOptions_;
+  }
+
+  bool hasExternalPrivateKey() const {
+    for (const auto& cfg : sslContextConfigs) {
+      if (!cfg.isLocalPrivateKey) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  /**
+   * The name of this acceptor; used for stats/reporting purposes.
+   */
+  std::string name;
+
+  /**
+   * The depth of the accept queue backlog.
+   */
+  uint32_t acceptBacklog{1024};
+
+  /**
+   * The number of milliseconds a connection can be idle before we close it.
+   */
+  std::chrono::milliseconds connectionIdleTimeout{600000};
+
+  /**
+   * The address to bind to.
+   */
+  SocketAddress bindAddress;
+
+  /**
+   * Options for controlling the SSL cache.
+   */
+  SSLCacheOptions sslCacheOptions{std::chrono::seconds(600), 20480, 200};
+
+  /**
+   * The initial TLS ticket seeds.
+   */
+  TLSTicketKeySeeds initialTicketSeeds;
+
+  /**
+   * The configs for all the SSL_CTX for use by this Acceptor.
+   */
+  std::vector<SSLContextConfig> sslContextConfigs;
+
+  /**
+   * Determines if the Acceptor does strict checking when loading the SSL
+   * contexts.
+   */
+  bool strictSSL{true};
+
+  /**
+   * Maximum number of concurrent pending SSL handshakes
+   */
+  uint32_t maxConcurrentSSLHandshakes{30720};
+
+ private:
+  AsyncSocket::OptionMap socketOptions_;
+};
+
+} // folly
diff --git a/folly/experimental/wangle/acceptor/SocketOptions.cpp b/folly/experimental/wangle/acceptor/SocketOptions.cpp
new file mode 100644 (file)
index 0000000..c7c82b8
--- /dev/null
@@ -0,0 +1,38 @@
+/*
+ *  Copyright (c) 2014, Facebook, Inc.
+ *  All rights reserved.
+ *
+ *  This source code is licensed under the BSD-style license found in the
+ *  LICENSE file in the root directory of this source tree. An additional grant
+ *  of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#include <folly/experimental/wangle/acceptor/SocketOptions.h>
+
+#include <netinet/tcp.h>
+#include <sys/socket.h>
+
+namespace folly {
+
+AsyncSocket::OptionMap filterIPSocketOptions(
+  const AsyncSocket::OptionMap& allOptions,
+  const int addrFamily) {
+  AsyncSocket::OptionMap opts;
+  int exclude;
+  if (addrFamily == AF_INET) {
+    exclude = IPPROTO_IPV6;
+  } else if (addrFamily == AF_INET6) {
+    exclude = IPPROTO_IP;
+  } else {
+    LOG(FATAL) << "Address family " << addrFamily << " was not IPv4 or IPv6";
+    return opts;
+  }
+  for (const auto& opt: allOptions) {
+    if (opt.first.level != exclude) {
+      opts[opt.first] = opt.second;
+    }
+  }
+  return opts;
+}
+
+}
diff --git a/folly/experimental/wangle/acceptor/SocketOptions.h b/folly/experimental/wangle/acceptor/SocketOptions.h
new file mode 100644 (file)
index 0000000..37ba371
--- /dev/null
@@ -0,0 +1,24 @@
+/*
+ *  Copyright (c) 2014, Facebook, Inc.
+ *  All rights reserved.
+ *
+ *  This source code is licensed under the BSD-style license found in the
+ *  LICENSE file in the root directory of this source tree. An additional grant
+ *  of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+#include <folly/io/async/AsyncSocket.h>
+
+namespace folly {
+
+/**
+ * Returns a copy of the socket options excluding options with the given
+ * level.
+ */
+AsyncSocket::OptionMap filterIPSocketOptions(
+  const AsyncSocket::OptionMap& allOptions,
+  const int addrFamily);
+
+}
diff --git a/folly/experimental/wangle/acceptor/TransportInfo.cpp b/folly/experimental/wangle/acceptor/TransportInfo.cpp
new file mode 100644 (file)
index 0000000..02de719
--- /dev/null
@@ -0,0 +1,65 @@
+/*
+ *  Copyright (c) 2014, Facebook, Inc.
+ *  All rights reserved.
+ *
+ *  This source code is licensed under the BSD-style license found in the
+ *  LICENSE file in the root directory of this source tree. An additional grant
+ *  of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#include <folly/experimental/wangle/acceptor/TransportInfo.h>
+
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <folly/io/async/AsyncSocket.h>
+
+using std::chrono::microseconds;
+using std::map;
+using std::string;
+
+namespace folly {
+
+bool TransportInfo::initWithSocket(const AsyncSocket* sock) {
+#if defined(__linux__) || defined(__FreeBSD__)
+  if (!TransportInfo::readTcpInfo(&tcpinfo, sock)) {
+    tcpinfoErrno = errno;
+    return false;
+  }
+  rtt = microseconds(tcpinfo.tcpi_rtt);
+  validTcpinfo = true;
+#else
+  tcpinfoErrno = EINVAL;
+  rtt = microseconds(-1);
+#endif
+  return true;
+}
+
+int64_t TransportInfo::readRTT(const AsyncSocket* sock) {
+#if defined(__linux__) || defined(__FreeBSD__)
+  struct tcp_info tcpinfo;
+  if (!TransportInfo::readTcpInfo(&tcpinfo, sock)) {
+    return -1;
+  }
+  return tcpinfo.tcpi_rtt;
+#else
+  return -1;
+#endif
+}
+
+#if defined(__linux__) || defined(__FreeBSD__)
+bool TransportInfo::readTcpInfo(struct tcp_info* tcpinfo,
+                                const AsyncSocket* sock) {
+  socklen_t len = sizeof(struct tcp_info);
+  if (!sock) {
+    return false;
+  }
+  if (getsockopt(sock->getFd(), IPPROTO_TCP,
+                 TCP_INFO, (void*) tcpinfo, &len) < 0) {
+    VLOG(4) << "Error calling getsockopt(): " << strerror(errno);
+    return false;
+  }
+  return true;
+}
+#endif
+
+} // folly
diff --git a/folly/experimental/wangle/acceptor/TransportInfo.h b/folly/experimental/wangle/acceptor/TransportInfo.h
new file mode 100644 (file)
index 0000000..069adb7
--- /dev/null
@@ -0,0 +1,292 @@
+/*
+ *  Copyright (c) 2014, Facebook, Inc.
+ *  All rights reserved.
+ *
+ *  This source code is licensed under the BSD-style license found in the
+ *  LICENSE file in the root directory of this source tree. An additional grant
+ *  of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+#include <folly/experimental/wangle/ssl/SSLUtil.h>
+
+#include <chrono>
+#include <netinet/tcp.h>
+#include <string>
+
+namespace folly {
+class AsyncSocket;
+
+/**
+ * A structure that encapsulates byte counters related to the HTTP headers.
+ */
+struct HTTPHeaderSize {
+  /**
+   * The number of bytes used to represent the header after compression or
+   * before decompression. If header compression is not supported, the value
+   * is set to 0.
+   */
+  uint32_t compressed{0};
+
+  /**
+   * The number of bytes used to represent the serialized header before
+   * compression or after decompression, in plain-text format.
+   */
+  uint32_t uncompressed{0};
+};
+
+struct TransportInfo {
+  /*
+   * timestamp of when the connection handshake was completed
+   */
+  std::chrono::steady_clock::time_point acceptTime{};
+
+  /*
+   * connection RTT (Round-Trip Time)
+   */
+  std::chrono::microseconds rtt{0};
+
+#if defined(__linux__) || defined(__FreeBSD__)
+  /*
+   * TCP information as fetched from getsockopt(2)
+   */
+  tcp_info tcpinfo {
+#if __GLIBC__ >= 2 && __GLIBC_MINOR__ >= 17
+    0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 // 32
+#else
+    0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 // 29
+#endif  // __GLIBC__ >= 2 && __GLIBC_MINOR__ >= 17
+  };
+#endif  // defined(__linux__) || defined(__FreeBSD__)
+
+  /*
+   * time for setting the connection, from the moment in was accepted until it
+   * is established.
+   */
+  std::chrono::milliseconds setupTime{0};
+
+  /*
+   * time for setting up the SSL connection or SSL handshake
+   */
+  std::chrono::milliseconds sslSetupTime{0};
+
+  /*
+   * The name of the SSL ciphersuite used by the transaction's
+   * transport.  Returns null if the transport is not SSL.
+   */
+  const char* sslCipher{nullptr};
+
+  /*
+   * The SSL server name used by the transaction's
+   * transport.  Returns null if the transport is not SSL.
+   */
+  const char* sslServerName{nullptr};
+
+  /*
+   * list of ciphers sent by the client
+   */
+  std::string sslClientCiphers{};
+
+  /*
+   * list of compression methods sent by the client
+   */
+  std::string sslClientComprMethods{};
+
+  /*
+   * list of TLS extensions sent by the client
+   */
+  std::string sslClientExts{};
+
+  /*
+   * hash of all the SSL parameters sent by the client
+   */
+  std::string sslSignature{};
+
+  /*
+   * list of ciphers supported by the server
+   */
+  std::string sslServerCiphers{};
+
+  /*
+   * guessed "(os) (browser)" based on SSL Signature
+   */
+  std::string guessedUserAgent{};
+
+  /**
+   * The result of SSL NPN negotiation.
+   */
+  std::string sslNextProtocol{};
+
+  /*
+   * total number of bytes sent over the connection
+   */
+  int64_t totalBytes{0};
+
+  /**
+   * header bytes read
+   */
+  HTTPHeaderSize ingressHeader;
+
+  /*
+   * header bytes written
+   */
+  HTTPHeaderSize egressHeader;
+
+  /*
+   * Here is how the timeToXXXByte variables are planned out:
+   * 1. All timeToXXXByte variables are measuring the ByteEvent from reqStart_
+   * 2. You can get the timing between two ByteEvents by calculating their
+   *    differences. For example:
+   *    timeToLastBodyByteAck - timeToFirstByte
+   *    => Total time to deliver the body
+   * 3. The calculation in point (2) is typically done outside acceptor
+   *
+   * Future plan:
+   * We should log the timestamps (TimePoints) and allow
+   * the consumer to calculate the latency whatever it
+   * wants instead of calculating them in wangle, for the sake of flexibility.
+   * For example:
+   * 1. TimePoint reqStartTimestamp;
+   * 2. TimePoint firstHeaderByteSentTimestamp;
+   * 3. TimePoint firstBodyByteTimestamp;
+   * 3. TimePoint lastBodyByteTimestamp;
+   * 4. TimePoint lastBodyByteAckTimestamp;
+   */
+
+  /*
+   * time to first header byte written to the kernel send buffer
+   * NOTE: It is not 100% accurate since TAsyncSocket does not do
+   * do callback on partial write.
+   */
+  int32_t timeToFirstHeaderByte{-1};
+
+  /*
+   * time to first body byte written to the kernel send buffer
+   */
+  int32_t timeToFirstByte{-1};
+
+  /*
+   * time to last body byte written to the kernel send buffer
+   */
+  int32_t timeToLastByte{-1};
+
+  /*
+   * time to TCP Ack received for the last written body byte
+   */
+  int32_t timeToLastBodyByteAck{-1};
+
+  /*
+   * time it took the client to ACK the last byte, from the moment when the
+   * kernel sent the last byte to the client and until it received the ACK
+   * for that byte
+   */
+  int32_t lastByteAckLatency{-1};
+
+  /*
+   * time spent inside wangle
+   */
+  int32_t proxyLatency{-1};
+
+  /*
+   * time between connection accepted and client message headers completed
+   */
+  int32_t clientLatency{-1};
+
+  /*
+   * latency for communication with the server
+   */
+  int32_t serverLatency{-1};
+
+  /*
+   * time used to get a usable connection.
+   */
+  int32_t connectLatency{-1};
+
+  /*
+   * body bytes written
+   */
+  uint32_t egressBodySize{0};
+
+  /*
+   * value of errno in case of getsockopt() error
+   */
+  int tcpinfoErrno{0};
+
+  /*
+   * bytes read & written during SSL Setup
+   */
+  uint32_t sslSetupBytesWritten{0};
+  uint32_t sslSetupBytesRead{0};
+
+  /**
+   * SSL error detail
+   */
+  uint32_t sslError{0};
+
+  /**
+   * body bytes read
+   */
+  uint32_t ingressBodySize{0};
+
+  /*
+   * The SSL version used by the transaction's transport, in
+   * OpenSSL's format: 4 bits for the major version, followed by 4 bits
+   * for the minor version.  Returns zero for non-SSL.
+   */
+  uint16_t sslVersion{0};
+
+  /*
+   * The SSL certificate size.
+   */
+  uint16_t sslCertSize{0};
+
+  /**
+   * response status code
+   */
+  uint16_t statusCode{0};
+
+  /*
+   * The SSL mode for the transaction's transport: new session,
+   * resumed session, or neither (non-SSL).
+   */
+  SSLResumeEnum sslResume{SSLResumeEnum::NA};
+
+  /*
+   * true if the tcpinfo was successfully read from the kernel
+   */
+  bool validTcpinfo{false};
+
+  /*
+   * true if the connection is SSL, false otherwise
+   */
+  bool ssl{false};
+
+  /*
+   * get the RTT value in milliseconds
+   */
+  std::chrono::milliseconds getRttMs() const {
+    return std::chrono::duration_cast<std::chrono::milliseconds>(rtt);
+  }
+
+  /*
+   * initialize the fields related with tcp_info
+   */
+  bool initWithSocket(const AsyncSocket* sock);
+
+  /*
+   * Get the kernel's estimate of round-trip time (RTT) to the transport's peer
+   * in microseconds. Returns -1 on error.
+   */
+  static int64_t readRTT(const AsyncSocket* sock);
+
+#if defined(__linux__) || defined(__FreeBSD__)
+  /*
+   * perform the getsockopt(2) syscall to fetch TCP info for a given socket
+   */
+  static bool readTcpInfo(struct tcp_info* tcpinfo,
+                          const AsyncSocket* sock);
+#endif
+};
+
+} // folly
diff --git a/folly/experimental/wangle/ssl/ClientHelloExtStats.h b/folly/experimental/wangle/ssl/ClientHelloExtStats.h
new file mode 100644 (file)
index 0000000..a95ee0c
--- /dev/null
@@ -0,0 +1,24 @@
+/*
+ *  Copyright (c) 2014, Facebook, Inc.
+ *  All rights reserved.
+ *
+ *  This source code is licensed under the BSD-style license found in the
+ *  LICENSE file in the root directory of this source tree. An additional grant
+ *  of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+namespace folly {
+
+class ClientHelloExtStats {
+ public:
+  virtual ~ClientHelloExtStats() noexcept {}
+
+  // client hello
+  virtual void recordAbsentHostname() noexcept = 0;
+  virtual void recordMatch() noexcept = 0;
+  virtual void recordNotMatch() noexcept = 0;
+};
+
+}
diff --git a/folly/experimental/wangle/ssl/DHParam.h b/folly/experimental/wangle/ssl/DHParam.h
new file mode 100644 (file)
index 0000000..561d569
--- /dev/null
@@ -0,0 +1,53 @@
+/*
+ *  Copyright (c) 2014, Facebook, Inc.
+ *  All rights reserved.
+ *
+ *  This source code is licensed under the BSD-style license found in the
+ *  LICENSE file in the root directory of this source tree. An additional grant
+ *  of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+#include <openssl/dh.h>
+
+// The following was auto-generated by
+//  openssl dhparam -C 2048
+DH *get_dh2048()
+        {
+        static unsigned char dh2048_p[]={
+                0xF8,0x87,0xA5,0x15,0x98,0x35,0x20,0x1E,0xF5,0x81,0xE5,0x95,
+                0x1B,0xE4,0x54,0xEA,0x53,0xF5,0xE7,0x26,0x30,0x03,0x06,0x79,
+                0x3C,0xC1,0x0B,0xAD,0x3B,0x59,0x3C,0x61,0x13,0x03,0x7B,0x02,
+                0x70,0xDE,0xC1,0x20,0x11,0x9E,0x94,0x13,0x50,0xF7,0x62,0xFC,
+                0x99,0x0D,0xC1,0x12,0x6E,0x03,0x95,0xA3,0x57,0xC7,0x3C,0xB8,
+                0x6B,0x40,0x56,0x65,0x70,0xFB,0x7A,0xE9,0x02,0xEC,0xD2,0xB6,
+                0x54,0xD7,0x34,0xAD,0x3D,0x9E,0x11,0x61,0x53,0xBE,0xEA,0xB8,
+                0x17,0x48,0xA8,0xDC,0x70,0xAE,0x65,0x99,0x3F,0x82,0x4C,0xFF,
+                0x6A,0xC9,0xFA,0xB1,0xFA,0xE4,0x4F,0x5D,0xA4,0x05,0xC2,0x8E,
+                0x55,0xC0,0xB1,0x1D,0xCC,0x17,0xF3,0xFA,0x65,0xD8,0x6B,0x09,
+                0x13,0x01,0x2A,0x39,0xF1,0x86,0x73,0xE3,0x7A,0xC8,0xDB,0x7D,
+                0xDA,0x1C,0xA1,0x2D,0xBA,0x2C,0x00,0x6B,0x2C,0x55,0x28,0x2B,
+                0xD5,0xF5,0x3C,0x9F,0x50,0xA7,0xB7,0x28,0x9F,0x22,0xD5,0x3A,
+                0xC4,0x53,0x01,0xC9,0xF3,0x69,0xB1,0x8D,0x01,0x36,0xF8,0xA8,
+                0x89,0xCA,0x2E,0x72,0xBC,0x36,0x3A,0x42,0xC1,0x06,0xD6,0x0E,
+                0xCB,0x4D,0x5C,0x1F,0xE4,0xA1,0x17,0xBF,0x55,0x64,0x1B,0xB4,
+                0x52,0xEC,0x15,0xED,0x32,0xB1,0x81,0x07,0xC9,0x71,0x25,0xF9,
+                0x4D,0x48,0x3D,0x18,0xF4,0x12,0x09,0x32,0xC4,0x0B,0x7A,0x4E,
+                0x83,0xC3,0x10,0x90,0x51,0x2E,0xBE,0x87,0xF9,0xDE,0xB4,0xE6,
+                0x3C,0x29,0xB5,0x32,0x01,0x9D,0x95,0x04,0xBD,0x42,0x89,0xFD,
+                0x21,0xEB,0xE9,0x88,0x5A,0x27,0xBB,0x31,0xC4,0x26,0x99,0xAB,
+                0x8C,0xA1,0x76,0xDB,
+                };
+        static unsigned char dh2048_g[]={
+                0x02,
+                };
+        DH *dh;
+
+        if ((dh=DH_new()) == nullptr) return(nullptr);
+        dh->p=BN_bin2bn(dh2048_p,(int)sizeof(dh2048_p),nullptr);
+        dh->g=BN_bin2bn(dh2048_g,(int)sizeof(dh2048_g),nullptr);
+        if ((dh->p == nullptr) || (dh->g == nullptr))
+                { DH_free(dh); return(nullptr); }
+        return(dh);
+        }
diff --git a/folly/experimental/wangle/ssl/PasswordInFile.cpp b/folly/experimental/wangle/ssl/PasswordInFile.cpp
new file mode 100644 (file)
index 0000000..c876c39
--- /dev/null
@@ -0,0 +1,31 @@
+/*
+ *  Copyright (c) 2014, Facebook, Inc.
+ *  All rights reserved.
+ *
+ *  This source code is licensed under the BSD-style license found in the
+ *  LICENSE file in the root directory of this source tree. An additional grant
+ *  of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#include <folly/experimental/wangle/ssl/PasswordInFile.h>
+
+#include <folly/FileUtil.h>
+
+using namespace std;
+
+namespace folly {
+
+PasswordInFile::PasswordInFile(const string& file)
+    : fileName_(file) {
+  folly::readFile(file.c_str(), password_);
+  auto p = password_.find('\0');
+  if (p != std::string::npos) {
+    password_.erase(p);
+  }
+}
+
+PasswordInFile::~PasswordInFile() {
+  OPENSSL_cleanse((char *)password_.data(), password_.length());
+}
+
+}
diff --git a/folly/experimental/wangle/ssl/PasswordInFile.h b/folly/experimental/wangle/ssl/PasswordInFile.h
new file mode 100644 (file)
index 0000000..b0a0922
--- /dev/null
@@ -0,0 +1,38 @@
+/*
+ *  Copyright (c) 2014, Facebook, Inc.
+ *  All rights reserved.
+ *
+ *  This source code is licensed under the BSD-style license found in the
+ *  LICENSE file in the root directory of this source tree. An additional grant
+ *  of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+#include <folly/io/async/SSLContext.h> // PasswordCollector
+
+namespace folly {
+
+class PasswordInFile: public folly::PasswordCollector {
+ public:
+  explicit PasswordInFile(const std::string& file);
+  ~PasswordInFile();
+
+  void getPassword(std::string& password, int size) override {
+    password = password_;
+  }
+
+  const char* getPasswordStr() const {
+    return password_.c_str();
+  }
+
+  std::string describe() const override {
+    return fileName_;
+  }
+
+ protected:
+  std::string fileName_;
+  std::string password_;
+};
+
+}
diff --git a/folly/experimental/wangle/ssl/SSLCacheOptions.h b/folly/experimental/wangle/ssl/SSLCacheOptions.h
new file mode 100644 (file)
index 0000000..5617537
--- /dev/null
@@ -0,0 +1,23 @@
+/*
+ *  Copyright (c) 2014, Facebook, Inc.
+ *  All rights reserved.
+ *
+ *  This source code is licensed under the BSD-style license found in the
+ *  LICENSE file in the root directory of this source tree. An additional grant
+ *  of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+#include <chrono>
+#include <cstdint>
+
+namespace folly {
+
+struct SSLCacheOptions {
+  std::chrono::seconds sslCacheTimeout;
+  uint64_t maxSSLCacheSize;
+  uint64_t sslCacheFlushSize;
+};
+
+}
diff --git a/folly/experimental/wangle/ssl/SSLCacheProvider.h b/folly/experimental/wangle/ssl/SSLCacheProvider.h
new file mode 100644 (file)
index 0000000..feecca4
--- /dev/null
@@ -0,0 +1,69 @@
+/*
+ *  Copyright (c) 2014, Facebook, Inc.
+ *  All rights reserved.
+ *
+ *  This source code is licensed under the BSD-style license found in the
+ *  LICENSE file in the root directory of this source tree. An additional grant
+ *  of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+#include <folly/io/async/AsyncSSLSocket.h>
+
+namespace folly {
+
+class SSLSessionCacheManager;
+
+/**
+ * Interface to be implemented by providers of external session caches
+ */
+class SSLCacheProvider {
+public:
+  /**
+   * Context saved during an external cache request that is used to
+   * resume the waiting client.
+   */
+  typedef struct {
+    std::string sessionId;
+    SSL_SESSION* session;
+    SSLSessionCacheManager* manager;
+    AsyncSSLSocket* sslSocket;
+    std::unique_ptr<
+      folly::DelayedDestruction::DestructorGuard> guard;
+  } CacheContext;
+
+  virtual ~SSLCacheProvider() {}
+
+  /**
+   * Store a session in the external cache.
+   * @param sessionId   Identifier that can be used later to fetch the
+   *                      session with getAsync()
+   * @param value       Serialized session to store
+   * @param expiration  Relative expiration time: seconds from now
+   * @return true if the storing of the session is initiated successfully
+   *         (though not necessarily completed; the completion may
+   *         happen either before or after this method returns), or
+   *         false if the storing cannot be initiated due to an error.
+   */
+  virtual bool setAsync(const std::string& sessionId,
+                        const std::string& value,
+                        std::chrono::seconds expiration) = 0;
+
+  /**
+   * Retrieve a session from the external cache. When done, call
+   * the cache manager's onGetSuccess() or onGetFailure() callback.
+   * @param sessionId   Session ID to fetch
+   * @param context     Data to pass back to the SSLSessionCacheManager
+   *                      in the completion callback
+   * @return true if the lookup of the session is initiated successfully
+   *         (though not necessarily completed; the completion may
+   *         happen either before or after this method returns), or
+   *         false if the lookup cannot be initiated due to an error.
+   */
+  virtual bool getAsync(const std::string& sessionId,
+                        CacheContext* context) = 0;
+
+};
+
+}
diff --git a/folly/experimental/wangle/ssl/SSLContextConfig.h b/folly/experimental/wangle/ssl/SSLContextConfig.h
new file mode 100644 (file)
index 0000000..bd3f804
--- /dev/null
@@ -0,0 +1,95 @@
+/*
+ *  Copyright (c) 2014, Facebook, Inc.
+ *  All rights reserved.
+ *
+ *  This source code is licensed under the BSD-style license found in the
+ *  LICENSE file in the root directory of this source tree. An additional grant
+ *  of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+#include <string>
+#include <folly/io/async/SSLContext.h>
+#include <vector>
+
+/**
+ * SSLContextConfig helps to describe the configs/options for
+ * a SSL_CTX. For example:
+ *
+ *   1. Filename of X509, private key and its password.
+ *   2. ciphers list
+ *   3. NPN list
+ *   4. Is session cache enabled?
+ *   5. Is it the default X509 in SNI operation?
+ *   6. .... and a few more
+ */
+namespace folly {
+
+struct SSLContextConfig {
+  SSLContextConfig() {}
+  ~SSLContextConfig() {}
+
+  struct CertificateInfo {
+    std::string certPath;
+    std::string keyPath;
+    std::string passwordPath;
+  };
+
+  /**
+   * Helpers to set/add a certificate
+   */
+  void setCertificate(const std::string& certPath,
+                      const std::string& keyPath,
+                      const std::string& passwordPath) {
+    certificates.clear();
+    addCertificate(certPath, keyPath, passwordPath);
+  }
+
+  void addCertificate(const std::string& certPath,
+                      const std::string& keyPath,
+                      const std::string& passwordPath) {
+    certificates.emplace_back(CertificateInfo{certPath, keyPath, passwordPath});
+  }
+
+  /**
+   * Set the optional list of protocols to advertise via TLS
+   * Next Protocol Negotiation. An empty list means NPN is not enabled.
+   */
+  void setNextProtocols(const std::list<std::string>& inNextProtocols) {
+    nextProtocols.clear();
+    nextProtocols.push_back({1, inNextProtocols});
+  }
+
+  typedef std::function<bool(char const* server_name)> SNINoMatchFn;
+
+  std::vector<CertificateInfo> certificates;
+  folly::SSLContext::SSLVersion sslVersion{
+    folly::SSLContext::TLSv1};
+  bool sessionCacheEnabled{true};
+  bool sessionTicketEnabled{true};
+  bool clientHelloParsingEnabled{false};
+  std::string sslCiphers{
+    "ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:"
+    "ECDHE-ECDSA-AES128-SHA:ECDHE-ECDSA-AES256-SHA:ECDHE-RSA-AES128-GCM-SHA256:"
+    "ECDHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-AES128-SHA:ECDHE-RSA-AES256-SHA:"
+    "AES128-GCM-SHA256:AES256-GCM-SHA384:AES128-SHA:AES256-SHA:"
+    "ECDHE-ECDSA-RC4-SHA:ECDHE-RSA-RC4-SHA:RC4-SHA:RC4-MD5:"
+    "ECDHE-RSA-DES-CBC3-SHA:DES-CBC3-SHA"};
+  std::string eccCurveName;
+  // Ciphers to negotiate if TLS version >= 1.1
+  std::string tls11Ciphers{""};
+  // Weighted lists of NPN strings to advertise
+  std::list<folly::SSLContext::NextProtocolsItem>
+      nextProtocols;
+  bool isLocalPrivateKey{true};
+  // Should this SSLContextConfig be the default for SNI purposes
+  bool isDefault{false};
+  // Callback function to invoke when there are no matching certificates
+  // (will only be invoked once)
+  SNINoMatchFn sniNoMatchFn;
+  // File containing trusted CA's to validate client certificates
+  std::string clientCAFile;
+};
+
+}
diff --git a/folly/experimental/wangle/ssl/SSLContextManager.cpp b/folly/experimental/wangle/ssl/SSLContextManager.cpp
new file mode 100644 (file)
index 0000000..eb9f126
--- /dev/null
@@ -0,0 +1,651 @@
+/*
+ *  Copyright (c) 2014, Facebook, Inc.
+ *  All rights reserved.
+ *
+ *  This source code is licensed under the BSD-style license found in the
+ *  LICENSE file in the root directory of this source tree. An additional grant
+ *  of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#include <folly/experimental/wangle/ssl/SSLContextManager.h>
+
+#include <folly/experimental/wangle/ssl/ClientHelloExtStats.h>
+#include <folly/experimental/wangle/ssl/DHParam.h>
+#include <folly/experimental/wangle/ssl/PasswordInFile.h>
+#include <folly/experimental/wangle/ssl/SSLCacheOptions.h>
+#include <folly/experimental/wangle/ssl/SSLSessionCacheManager.h>
+#include <folly/experimental/wangle/ssl/SSLUtil.h>
+#include <folly/experimental/wangle/ssl/TLSTicketKeyManager.h>
+#include <folly/experimental/wangle/ssl/TLSTicketKeySeeds.h>
+
+#include <folly/Conv.h>
+#include <folly/ScopeGuard.h>
+#include <folly/String.h>
+#include <functional>
+#include <openssl/asn1.h>
+#include <openssl/ssl.h>
+#include <string>
+#include <folly/io/async/EventBase.h>
+
+#define OPENSSL_MISSING_FEATURE(name) \
+do { \
+  throw std::runtime_error("missing " #name " support in openssl");  \
+} while(0)
+
+
+using std::string;
+using std::shared_ptr;
+
+/**
+ * SSLContextManager helps to create and manage all SSL_CTX,
+ * SSLSessionCacheManager and TLSTicketManager for a listening
+ * VIP:PORT. (Note, in SNI, a listening VIP:PORT can have >1 SSL_CTX(s)).
+ *
+ * Other responsibilities:
+ * 1. It also handles the SSL_CTX selection after getting the tlsext_hostname
+ *    in the client hello message.
+ *
+ * Usage:
+ * 1. Each listening VIP:PORT serving SSL should have one SSLContextManager.
+ *    It maps to Acceptor in the wangle vocabulary.
+ *
+ * 2. Create a SSLContextConfig object (e.g. by parsing the JSON config).
+ *
+ * 3. Call SSLContextManager::addSSLContextConfig() which will
+ *    then create and configure the SSL_CTX
+ *
+ * Note: Each Acceptor, with SSL support, should have one SSLContextManager to
+ * manage all SSL_CTX for the VIP:PORT.
+ */
+
+namespace folly {
+
+namespace {
+
+X509* getX509(SSL_CTX* ctx) {
+  SSL* ssl = SSL_new(ctx);
+  SSL_set_connect_state(ssl);
+  X509* x509 = SSL_get_certificate(ssl);
+  CRYPTO_add(&x509->references, 1, CRYPTO_LOCK_X509);
+  SSL_free(ssl);
+  return x509;
+}
+
+void set_key_from_curve(SSL_CTX* ctx, const std::string& curveName) {
+#if OPENSSL_VERSION_NUMBER >= 0x0090800fL
+#ifndef OPENSSL_NO_ECDH
+  EC_KEY* ecdh = nullptr;
+  int nid;
+
+  /*
+   * Elliptic-Curve Diffie-Hellman parameters are either "named curves"
+   * from RFC 4492 section 5.1.1, or explicitly described curves over
+   * binary fields. OpenSSL only supports the "named curves", which provide
+   * maximum interoperability.
+   */
+
+  nid = OBJ_sn2nid(curveName.c_str());
+  if (nid == 0) {
+    LOG(FATAL) << "Unknown curve name:" << curveName.c_str();
+    return;
+  }
+  ecdh = EC_KEY_new_by_curve_name(nid);
+  if (ecdh == nullptr) {
+    LOG(FATAL) << "Unable to create curve:" << curveName.c_str();
+    return;
+  }
+
+  SSL_CTX_set_tmp_ecdh(ctx, ecdh);
+  EC_KEY_free(ecdh);
+#endif
+#endif
+}
+
+// Helper to create TLSTicketKeyManger and aware of the needed openssl
+// version/feature.
+std::unique_ptr<TLSTicketKeyManager> createTicketManagerHelper(
+  std::shared_ptr<folly::SSLContext> ctx,
+  const TLSTicketKeySeeds* ticketSeeds,
+  const SSLContextConfig& ctxConfig,
+  SSLStats* stats) {
+
+  std::unique_ptr<TLSTicketKeyManager> ticketManager;
+#ifdef SSL_CTRL_SET_TLSEXT_TICKET_KEY_CB
+  if (ticketSeeds && ctxConfig.sessionTicketEnabled) {
+    ticketManager = folly::make_unique<TLSTicketKeyManager>(ctx.get(), stats);
+    ticketManager->setTLSTicketKeySeeds(
+      ticketSeeds->oldSeeds,
+      ticketSeeds->currentSeeds,
+      ticketSeeds->newSeeds);
+  } else {
+    ctx->setOptions(SSL_OP_NO_TICKET);
+  }
+#else
+  if (ticketSeeds && ctxConfig.sessionTicketEnabled) {
+    OPENSSL_MISSING_FEATURE(TLSTicket);
+  }
+#endif
+  return ticketManager;
+}
+
+std::string flattenList(const std::list<std::string>& list) {
+  std::string s;
+  bool first = true;
+  for (auto& item : list) {
+    if (first) {
+      first = false;
+    } else {
+      s.append(", ");
+    }
+    s.append(item);
+  }
+  return s;
+}
+
+}
+
+SSLContextManager::~SSLContextManager() {}
+
+SSLContextManager::SSLContextManager(
+  EventBase* eventBase,
+  const std::string& vipName,
+  bool strict,
+  SSLStats* stats) :
+    stats_(stats),
+    eventBase_(eventBase),
+    strict_(strict) {
+}
+
+void SSLContextManager::addSSLContextConfig(
+  const SSLContextConfig& ctxConfig,
+  const SSLCacheOptions& cacheOptions,
+  const TLSTicketKeySeeds* ticketSeeds,
+  const folly::SocketAddress& vipAddress,
+  const std::shared_ptr<SSLCacheProvider>& externalCache) {
+
+  unsigned numCerts = 0;
+  std::string commonName;
+  std::string lastCertPath;
+  std::unique_ptr<std::list<std::string>> subjectAltName;
+  auto sslCtx = std::make_shared<SSLContext>(ctxConfig.sslVersion);
+  for (const auto& cert : ctxConfig.certificates) {
+    try {
+      sslCtx->loadCertificate(cert.certPath.c_str());
+    } catch (const std::exception& ex) {
+      // The exception isn't very useful without the certificate path name,
+      // so throw a new exception that includes the path to the certificate.
+      string msg = folly::to<string>("error loading SSL certificate ",
+                                     cert.certPath, ": ",
+                                     folly::exceptionStr(ex));
+      LOG(ERROR) << msg;
+      throw std::runtime_error(msg);
+    }
+
+    // Verify that the Common Name and (if present) Subject Alternative Names
+    // are the same for all the certs specified for the SSL context.
+    numCerts++;
+    X509* x509 = getX509(sslCtx->getSSLCtx());
+    auto guard = folly::makeGuard([x509] { X509_free(x509); });
+    auto cn = SSLUtil::getCommonName(x509);
+    if (!cn) {
+      throw std::runtime_error(folly::to<string>("Cannot get CN for X509 ",
+                                                 cert.certPath));
+    }
+    auto altName = SSLUtil::getSubjectAltName(x509);
+    VLOG(2) << "cert " << cert.certPath << " CN: " << *cn;
+    if (altName) {
+      altName->sort();
+      VLOG(2) << "cert " << cert.certPath << " SAN: " << flattenList(*altName);
+    } else {
+      VLOG(2) << "cert " << cert.certPath << " SAN: " << "{none}";
+    }
+    if (numCerts == 1) {
+      commonName = *cn;
+      subjectAltName = std::move(altName);
+    } else {
+      if (commonName != *cn) {
+        throw std::runtime_error(folly::to<string>("X509 ", cert.certPath,
+                                          " does not have same CN as ",
+                                          lastCertPath));
+      }
+      if (altName == nullptr) {
+        if (subjectAltName != nullptr) {
+          throw std::runtime_error(folly::to<string>("X509 ", cert.certPath,
+                                            " does not have same SAN as ",
+                                            lastCertPath));
+        }
+      } else {
+        if ((subjectAltName == nullptr) || (*altName != *subjectAltName)) {
+          throw std::runtime_error(folly::to<string>("X509 ", cert.certPath,
+                                            " does not have same SAN as ",
+                                            lastCertPath));
+        }
+      }
+    }
+    lastCertPath = cert.certPath;
+
+    // TODO t4438250 - Add ECDSA support to the crypto_ssl offload server
+    //                 so we can avoid storing the ECDSA private key in the
+    //                 address space of the Internet-facing process.  For
+    //                 now, if cert name includes "-EC" to denote elliptic
+    //                 curve, we load its private key even if the server as
+    //                 a whole has been configured for async crypto.
+    if (ctxConfig.isLocalPrivateKey ||
+        (cert.certPath.find("-EC") != std::string::npos)) {
+      // The private key lives in the same process
+
+      // This needs to be called before loadPrivateKey().
+      if (!cert.passwordPath.empty()) {
+        auto sslPassword = std::make_shared<PasswordInFile>(cert.passwordPath);
+        sslCtx->passwordCollector(sslPassword);
+      }
+
+      try {
+        sslCtx->loadPrivateKey(cert.keyPath.c_str());
+      } catch (const std::exception& ex) {
+        // Throw an error that includes the key path, so the user can tell
+        // which key had a problem.
+        string msg = folly::to<string>("error loading private SSL key ",
+                                       cert.keyPath, ": ",
+                                       folly::exceptionStr(ex));
+        LOG(ERROR) << msg;
+        throw std::runtime_error(msg);
+      }
+    }
+  }
+  if (!ctxConfig.isLocalPrivateKey) {
+    enableAsyncCrypto(sslCtx);
+  }
+
+  // Let the server pick the highest performing cipher from among the client's
+  // choices.
+  //
+  // Let's use a unique private key for all DH key exchanges.
+  //
+  // Because some old implementations choke on empty fragments, most SSL
+  // applications disable them (it's part of SSL_OP_ALL).  This
+  // will improve performance and decrease write buffer fragmentation.
+  sslCtx->setOptions(SSL_OP_CIPHER_SERVER_PREFERENCE |
+    SSL_OP_SINGLE_DH_USE |
+    SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS);
+
+  // Configure SSL ciphers list
+  if (!ctxConfig.tls11Ciphers.empty()) {
+    // FIXME: create a dummy SSL_CTX for cipher testing purpose? It can
+    //        remove the ordering dependency
+
+    // Test to see if the specified TLS1.1 ciphers are valid.  Note that
+    // these will be overwritten by the ciphers() call below.
+    sslCtx->setCiphersOrThrow(ctxConfig.tls11Ciphers);
+  }
+
+  // Important that we do this *after* checking the TLS1.1 ciphers above,
+  // since we test their validity by actually setting them.
+  sslCtx->ciphers(ctxConfig.sslCiphers);
+
+  // Use a fix DH param
+  DH* dh = get_dh2048();
+  SSL_CTX_set_tmp_dh(sslCtx->getSSLCtx(), dh);
+  DH_free(dh);
+
+  const string& curve = ctxConfig.eccCurveName;
+  if (!curve.empty()) {
+    set_key_from_curve(sslCtx->getSSLCtx(), curve);
+  }
+
+  if (!ctxConfig.clientCAFile.empty()) {
+    try {
+      sslCtx->setVerificationOption(SSLContext::VERIFY_REQ_CLIENT_CERT);
+      sslCtx->loadTrustedCertificates(ctxConfig.clientCAFile.c_str());
+      sslCtx->loadClientCAList(ctxConfig.clientCAFile.c_str());
+    } catch (const std::exception& ex) {
+      string msg = folly::to<string>("error loading client CA",
+                                     ctxConfig.clientCAFile, ": ",
+                                     folly::exceptionStr(ex));
+      LOG(ERROR) << msg;
+      throw std::runtime_error(msg);
+    }
+  }
+
+  // - start - SSL session cache config
+  // the internal cache never does what we want (per-thread-per-vip).
+  // Disable it.  SSLSessionCacheManager will set it appropriately.
+  SSL_CTX_set_session_cache_mode(sslCtx->getSSLCtx(), SSL_SESS_CACHE_OFF);
+  SSL_CTX_set_timeout(sslCtx->getSSLCtx(),
+                      cacheOptions.sslCacheTimeout.count());
+  std::unique_ptr<SSLSessionCacheManager> sessionCacheManager;
+  if (ctxConfig.sessionCacheEnabled &&
+      cacheOptions.maxSSLCacheSize > 0 &&
+      cacheOptions.sslCacheFlushSize > 0) {
+    sessionCacheManager =
+      folly::make_unique<SSLSessionCacheManager>(
+        cacheOptions.maxSSLCacheSize,
+        cacheOptions.sslCacheFlushSize,
+        sslCtx.get(),
+        vipAddress,
+        commonName,
+        eventBase_,
+        stats_,
+        externalCache);
+  }
+  // - end - SSL session cache config
+
+  std::unique_ptr<TLSTicketKeyManager> ticketManager =
+    createTicketManagerHelper(sslCtx, ticketSeeds, ctxConfig, stats_);
+
+  // finalize sslCtx setup by the individual features supported by openssl
+  ctxSetupByOpensslFeature(sslCtx, ctxConfig);
+
+  try {
+    insert(sslCtx,
+           std::move(sessionCacheManager),
+           std::move(ticketManager),
+           ctxConfig.isDefault);
+  } catch (const std::exception& ex) {
+    string msg = folly::to<string>("Error adding certificate : ",
+                                   folly::exceptionStr(ex));
+    LOG(ERROR) << msg;
+    throw std::runtime_error(msg);
+  }
+
+}
+
+#ifdef PROXYGEN_HAVE_SERVERNAMECALLBACK
+SSLContext::ServerNameCallbackResult
+SSLContextManager::serverNameCallback(SSL* ssl) {
+  shared_ptr<SSLContext> ctx;
+
+  const char* sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
+  if (!sn) {
+    VLOG(6) << "Server Name (tlsext_hostname) is missing";
+    if (clientHelloTLSExtStats_) {
+      clientHelloTLSExtStats_->recordAbsentHostname();
+    }
+    return SSLContext::SERVER_NAME_NOT_FOUND;
+  }
+  size_t snLen = strlen(sn);
+  VLOG(6) << "Server Name (SNI TLS extension): '" << sn << "' ";
+
+  // FIXME: This code breaks the abstraction. Suggestion?
+  AsyncSSLSocket* sslSocket = AsyncSSLSocket::getFromSSL(ssl);
+  CHECK(sslSocket);
+
+  DNString dnstr(sn, snLen);
+
+  uint32_t count = 0;
+  do {
+    // Try exact match first
+    ctx = getSSLCtx(dnstr);
+    if (ctx) {
+      sslSocket->switchServerSSLContext(ctx);
+      if (clientHelloTLSExtStats_) {
+        clientHelloTLSExtStats_->recordMatch();
+      }
+      return SSLContext::SERVER_NAME_FOUND;
+    }
+
+    ctx = getSSLCtxBySuffix(dnstr);
+    if (ctx) {
+      sslSocket->switchServerSSLContext(ctx);
+      if (clientHelloTLSExtStats_) {
+        clientHelloTLSExtStats_->recordMatch();
+      }
+      return SSLContext::SERVER_NAME_FOUND;
+    }
+
+    // Give the noMatchFn one chance to add the correct cert
+  }
+  while (count++ == 0 && noMatchFn_ && noMatchFn_(sn));
+
+  VLOG(6) << folly::stringPrintf("Cannot find a SSL_CTX for \"%s\"", sn);
+
+  if (clientHelloTLSExtStats_) {
+    clientHelloTLSExtStats_->recordNotMatch();
+  }
+  return SSLContext::SERVER_NAME_NOT_FOUND;
+}
+#endif
+
+// Consolidate all SSL_CTX setup which depends on openssl version/feature
+void
+SSLContextManager::ctxSetupByOpensslFeature(
+  shared_ptr<folly::SSLContext> sslCtx,
+  const SSLContextConfig& ctxConfig) {
+  // Disable compression - profiling shows this to be very expensive in
+  // terms of CPU and memory consumption.
+  //
+#ifdef SSL_OP_NO_COMPRESSION
+  sslCtx->setOptions(SSL_OP_NO_COMPRESSION);
+#endif
+
+  // Enable early release of SSL buffers to reduce the memory footprint
+#ifdef SSL_MODE_RELEASE_BUFFERS
+ sslCtx->getSSLCtx()->mode |= SSL_MODE_RELEASE_BUFFERS;
+#endif
+#ifdef SSL_MODE_EARLY_RELEASE_BBIO
+  sslCtx->getSSLCtx()->mode |=  SSL_MODE_EARLY_RELEASE_BBIO;
+#endif
+
+  // This number should (probably) correspond to HTTPSession::kMaxReadSize
+  // For now, this number must also be large enough to accommodate our
+  // largest certificate, because some older clients (IE6/7) require the
+  // cert to be in a single fragment.
+#ifdef SSL_CTRL_SET_MAX_SEND_FRAGMENT
+  SSL_CTX_set_max_send_fragment(sslCtx->getSSLCtx(), 8000);
+#endif
+
+  // Specify cipher(s) to be used for TLS1.1 client
+  if (!ctxConfig.tls11Ciphers.empty()) {
+#ifdef PROXYGEN_HAVE_SERVERNAMECALLBACK
+    // Specified TLS1.1 ciphers are valid
+    sslCtx->addClientHelloCallback(
+      std::bind(
+        &SSLContext::switchCiphersIfTLS11,
+        sslCtx.get(),
+        std::placeholders::_1,
+        ctxConfig.tls11Ciphers
+      )
+    );
+#else
+    OPENSSL_MISSING_FEATURE(SNI);
+#endif
+  }
+
+  // NPN (Next Protocol Negotiation)
+  if (!ctxConfig.nextProtocols.empty()) {
+#ifdef OPENSSL_NPN_NEGOTIATED
+    sslCtx->setRandomizedAdvertisedNextProtocols(ctxConfig.nextProtocols);
+#else
+    OPENSSL_MISSING_FEATURE(NPN);
+#endif
+  }
+
+  // SNI
+#ifdef PROXYGEN_HAVE_SERVERNAMECALLBACK
+  noMatchFn_ = ctxConfig.sniNoMatchFn;
+  if (ctxConfig.isDefault) {
+    if (defaultCtx_) {
+      throw std::runtime_error(">1 X509 is set as default");
+    }
+
+    defaultCtx_ = sslCtx;
+    defaultCtx_->setServerNameCallback(
+      std::bind(&SSLContextManager::serverNameCallback, this,
+                std::placeholders::_1));
+  }
+#else
+  if (ctxs_.size() > 1) {
+    OPENSSL_MISSING_FEATURE(SNI);
+  }
+#endif
+}
+
+void
+SSLContextManager::insert(shared_ptr<SSLContext> sslCtx,
+                          std::unique_ptr<SSLSessionCacheManager> smanager,
+                          std::unique_ptr<TLSTicketKeyManager> tmanager,
+                          bool defaultFallback) {
+  X509* x509 = getX509(sslCtx->getSSLCtx());
+  auto guard = folly::makeGuard([x509] { X509_free(x509); });
+  auto cn = SSLUtil::getCommonName(x509);
+  if (!cn) {
+    throw std::runtime_error("Cannot get CN");
+  }
+
+  /**
+   * Some notes from RFC 2818. Only for future quick references in case of bugs
+   *
+   * RFC 2818 section 3.1:
+   * "......
+   * If a subjectAltName extension of type dNSName is present, that MUST
+   * be used as the identity. Otherwise, the (most specific) Common Name
+   * field in the Subject field of the certificate MUST be used. Although
+   * the use of the Common Name is existing practice, it is deprecated and
+   * Certification Authorities are encouraged to use the dNSName instead.
+   * ......
+   * In some cases, the URI is specified as an IP address rather than a
+   * hostname. In this case, the iPAddress subjectAltName must be present
+   * in the certificate and must exactly match the IP in the URI.
+   * ......"
+   */
+
+  // Not sure if we ever get this kind of X509...
+  // If we do, assume '*' is always in the CN and ignore all subject alternative
+  // names.
+  if (cn->length() == 1 && (*cn)[0] == '*') {
+    if (!defaultFallback) {
+      throw std::runtime_error("STAR X509 is not the default");
+    }
+    ctxs_.emplace_back(sslCtx);
+    sessionCacheManagers_.emplace_back(std::move(smanager));
+    ticketManagers_.emplace_back(std::move(tmanager));
+    return;
+  }
+
+  // Insert by CN
+  insertSSLCtxByDomainName(cn->c_str(), cn->length(), sslCtx);
+
+  // Insert by subject alternative name(s)
+  auto altNames = SSLUtil::getSubjectAltName(x509);
+  if (altNames) {
+    for (auto& name : *altNames) {
+      insertSSLCtxByDomainName(name.c_str(), name.length(), sslCtx);
+    }
+  }
+
+  ctxs_.emplace_back(sslCtx);
+  sessionCacheManagers_.emplace_back(std::move(smanager));
+  ticketManagers_.emplace_back(std::move(tmanager));
+}
+
+void
+SSLContextManager::insertSSLCtxByDomainName(const char* dn, size_t len,
+                                            shared_ptr<SSLContext> sslCtx) {
+  try {
+    insertSSLCtxByDomainNameImpl(dn, len, sslCtx);
+  } catch (const std::runtime_error& ex) {
+    if (strict_) {
+      throw ex;
+    } else {
+      LOG(ERROR) << ex.what() << " DN=" << dn;
+    }
+  }
+}
+void
+SSLContextManager::insertSSLCtxByDomainNameImpl(const char* dn, size_t len,
+                                                shared_ptr<SSLContext> sslCtx)
+{
+  VLOG(4) <<
+    folly::stringPrintf("Adding CN/Subject-alternative-name \"%s\" for "
+                        "SNI search", dn);
+
+  // Only support wildcard domains which are prefixed exactly by "*." .
+  // "*" appearing at other locations is not accepted.
+
+  if (len > 2 && dn[0] == '*') {
+    if (dn[1] == '.') {
+      // skip the first '*'
+      dn++;
+      len--;
+    } else {
+      throw std::runtime_error(
+        "Invalid wildcard CN/subject-alternative-name \"" + std::string(dn) + "\" "
+        "(only allow character \".\" after \"*\"");
+    }
+  }
+
+  if (len == 1 && *dn == '.') {
+    throw std::runtime_error("X509 has only '.' in the CN or subject alternative name "
+                    "(after removing any preceding '*')");
+  }
+
+  if (strchr(dn, '*')) {
+    throw std::runtime_error("X509 has '*' in the the CN or subject alternative name "
+                    "(after removing any preceding '*')");
+  }
+
+  DNString dnstr(dn, len);
+  const auto v = dnMap_.find(dnstr);
+  if (v == dnMap_.end()) {
+    dnMap_.emplace(dnstr, sslCtx);
+  } else if (v->second == sslCtx) {
+    VLOG(6)<< "Duplicate CN or subject alternative name found in the same X509."
+      "  Ignore the later name.";
+  } else {
+    throw std::runtime_error("Duplicate CN or subject alternative name found: \"" +
+                             std::string(dnstr.c_str()) + "\"");
+  }
+}
+
+shared_ptr<SSLContext>
+SSLContextManager::getSSLCtxBySuffix(const DNString& dnstr) const
+{
+  size_t dot;
+
+  if ((dot = dnstr.find_first_of(".")) != DNString::npos) {
+    DNString suffixDNStr(dnstr, dot);
+    const auto v = dnMap_.find(suffixDNStr);
+    if (v != dnMap_.end()) {
+      VLOG(6) << folly::stringPrintf("\"%s\" is a willcard match to \"%s\"",
+                                     dnstr.c_str(), suffixDNStr.c_str());
+      return v->second;
+    }
+  }
+
+  VLOG(6) << folly::stringPrintf("\"%s\" is not a wildcard match",
+                                 dnstr.c_str());
+  return shared_ptr<SSLContext>();
+}
+
+shared_ptr<SSLContext>
+SSLContextManager::getSSLCtx(const DNString& dnstr) const
+{
+  const auto v = dnMap_.find(dnstr);
+  if (v == dnMap_.end()) {
+    VLOG(6) << folly::stringPrintf("\"%s\" is not an exact match",
+                                   dnstr.c_str());
+    return shared_ptr<SSLContext>();
+  } else {
+    VLOG(6) << folly::stringPrintf("\"%s\" is an exact match", dnstr.c_str());
+    return v->second;
+  }
+}
+
+shared_ptr<SSLContext>
+SSLContextManager::getDefaultSSLCtx() const {
+  return defaultCtx_;
+}
+
+void
+SSLContextManager::reloadTLSTicketKeys(
+  const std::vector<std::string>& oldSeeds,
+  const std::vector<std::string>& currentSeeds,
+  const std::vector<std::string>& newSeeds) {
+#ifdef SSL_CTRL_SET_TLSEXT_TICKET_KEY_CB
+  for (auto& tmgr: ticketManagers_) {
+    tmgr->setTLSTicketKeySeeds(oldSeeds, currentSeeds, newSeeds);
+  }
+#endif
+}
+
+} // namespace
diff --git a/folly/experimental/wangle/ssl/SSLContextManager.h b/folly/experimental/wangle/ssl/SSLContextManager.h
new file mode 100644 (file)
index 0000000..2650649
--- /dev/null
@@ -0,0 +1,182 @@
+/*
+ *  Copyright (c) 2014, Facebook, Inc.
+ *  All rights reserved.
+ *
+ *  This source code is licensed under the BSD-style license found in the
+ *  LICENSE file in the root directory of this source tree. An additional grant
+ *  of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+#include <folly/io/async/EventBase.h>
+#include <folly/io/async/SSLContext.h>
+
+#include <glog/logging.h>
+#include <list>
+#include <memory>
+#include <folly/experimental/wangle/ssl/SSLContextConfig.h>
+#include <folly/experimental/wangle/ssl/SSLSessionCacheManager.h>
+#include <folly/experimental/wangle/ssl/TLSTicketKeySeeds.h>
+#include <folly/experimental/wangle/acceptor/DomainNameMisc.h>
+#include <vector>
+
+namespace folly {
+
+class SocketAddress;
+class SSLContext;
+class ClientHelloExtStats;
+class SSLCacheOptions;
+class SSLStats;
+class TLSTicketKeyManager;
+class TLSTicketKeySeeds;
+
+class SSLContextManager {
+ public:
+
+  explicit SSLContextManager(EventBase* eventBase,
+                             const std::string& vipName, bool strict,
+                             SSLStats* stats);
+  virtual ~SSLContextManager();
+
+  /**
+   * Add a new X509 to SSLContextManager.  The details of a X509
+   * is passed as a SSLContextConfig object.
+   *
+   * @param ctxConfig     Details of a X509, its private key, password, etc.
+   * @param cacheOptions  Options for how to do session caching.
+   * @param ticketSeeds   If non-null, the initial ticket key seeds to use.
+   * @param vipAddress    Which VIP are the X509(s) used for? It is only for
+   *                      for user friendly log message
+   * @param externalCache Optional external provider for the session cache;
+   *                      may be null
+   */
+  void addSSLContextConfig(
+    const SSLContextConfig& ctxConfig,
+    const SSLCacheOptions& cacheOptions,
+    const TLSTicketKeySeeds* ticketSeeds,
+    const folly::SocketAddress& vipAddress,
+    const std::shared_ptr<SSLCacheProvider> &externalCache);
+
+  /**
+   * Get the default SSL_CTX for a VIP
+   */
+  std::shared_ptr<SSLContext>
+    getDefaultSSLCtx() const;
+
+  /**
+   * Search by the _one_ level up subdomain
+   */
+  std::shared_ptr<SSLContext>
+    getSSLCtxBySuffix(const DNString& dnstr) const;
+
+  /**
+   * Search by the full-string domain name
+   */
+  std::shared_ptr<SSLContext>
+    getSSLCtx(const DNString& dnstr) const;
+
+  /**
+   * Insert a SSLContext by domain name.
+   */
+  void insertSSLCtxByDomainName(
+    const char* dn,
+    size_t len,
+    std::shared_ptr<SSLContext> sslCtx);
+
+  void insertSSLCtxByDomainNameImpl(
+    const char* dn,
+    size_t len,
+    std::shared_ptr<SSLContext> sslCtx);
+
+  void reloadTLSTicketKeys(const std::vector<std::string>& oldSeeds,
+                           const std::vector<std::string>& currentSeeds,
+                           const std::vector<std::string>& newSeeds);
+
+  /**
+   * SSLContextManager only collects SNI stats now
+   */
+
+  void setClientHelloExtStats(ClientHelloExtStats* stats) {
+    clientHelloTLSExtStats_ = stats;
+  }
+
+ protected:
+  virtual void enableAsyncCrypto(
+    const std::shared_ptr<SSLContext>& sslCtx) {
+    LOG(FATAL) << "Unsupported in base SSLContextManager";
+  }
+  SSLStats* stats_{nullptr};
+
+ private:
+  SSLContextManager(const SSLContextManager&) = delete;
+
+  void ctxSetupByOpensslFeature(
+    std::shared_ptr<SSLContext> sslCtx,
+    const SSLContextConfig& ctxConfig);
+
+  /**
+   * Callback function from openssl to find the right X509 to
+   * use during SSL handshake
+   */
+#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && \
+    !defined(OPENSSL_NO_TLSEXT) && \
+    defined(SSL_CTRL_SET_TLSEXT_SERVERNAME_CB)
+# define PROXYGEN_HAVE_SERVERNAMECALLBACK
+  SSLContext::ServerNameCallbackResult
+    serverNameCallback(SSL* ssl);
+#endif
+
+  /**
+   * The following functions help to maintain the data structure for
+   * domain name matching in SNI.  Some notes:
+   *
+   * 1. It is a best match.
+   *
+   * 2. It allows wildcard CN and wildcard subject alternative name in a X509.
+   *    The wildcard name must be _prefixed_ by '*.'.  It errors out whenever
+   *    it sees '*' in any other locations.
+   *
+   * 3. It uses one std::unordered_map<DomainName, SSL_CTX> object to
+   *    do this.  For wildcard name like "*.facebook.com", ".facebook.com"
+   *    is used as the key.
+   *
+   * 4. After getting tlsext_hostname from the client hello message, it
+   *    will do a full string search first and then try one level up to
+   *    match any wildcard name (if any) in the X509.
+   *    [Note, browser also only looks one level up when matching the requesting
+   *     domain name with the wildcard name in the server X509].
+   */
+
+  void insert(
+    std::shared_ptr<SSLContext> sslCtx,
+    std::unique_ptr<SSLSessionCacheManager> cmanager,
+    std::unique_ptr<TLSTicketKeyManager> tManager,
+    bool defaultFallback);
+
+  /**
+   * Container to own the SSLContext, SSLSessionCacheManager and
+   * TLSTicketKeyManager.
+   */
+  std::vector<std::shared_ptr<SSLContext>> ctxs_;
+  std::vector<std::unique_ptr<SSLSessionCacheManager>>
+    sessionCacheManagers_;
+  std::vector<std::unique_ptr<TLSTicketKeyManager>> ticketManagers_;
+
+  std::shared_ptr<SSLContext> defaultCtx_;
+
+  /**
+   * Container to store the (DomainName -> SSL_CTX) mapping
+   */
+  std::unordered_map<
+    DNString,
+    std::shared_ptr<SSLContext>,
+    DNStringHash> dnMap_;
+
+  EventBase* eventBase_;
+  ClientHelloExtStats* clientHelloTLSExtStats_{nullptr};
+  SSLContextConfig::SNINoMatchFn noMatchFn_;
+  bool strict_{true};
+};
+
+} // namespace
diff --git a/folly/experimental/wangle/ssl/SSLSessionCacheManager.cpp b/folly/experimental/wangle/ssl/SSLSessionCacheManager.cpp
new file mode 100644 (file)
index 0000000..fc339a1
--- /dev/null
@@ -0,0 +1,350 @@
+/*
+ *  Copyright (c) 2014, Facebook, Inc.
+ *  All rights reserved.
+ *
+ *  This source code is licensed under the BSD-style license found in the
+ *  LICENSE file in the root directory of this source tree. An additional grant
+ *  of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#include <folly/experimental/wangle/ssl/SSLSessionCacheManager.h>
+
+#include <folly/experimental/wangle/ssl/SSLCacheProvider.h>
+#include <folly/experimental/wangle/ssl/SSLStats.h>
+#include <folly/experimental/wangle/ssl/SSLUtil.h>
+
+#include <folly/io/async/EventBase.h>
+
+using std::string;
+using std::shared_ptr;
+
+namespace {
+
+const uint32_t NUM_CACHE_BUCKETS = 16;
+
+// We use the default ID generator which fills the maximum ID length
+// for the protocol.  16 bytes for SSLv2 or 32 for SSLv3+
+const int MIN_SESSION_ID_LENGTH = 16;
+
+}
+
+#ifndef NO_LIB_GFLAGS
+DEFINE_bool(dcache_unit_test, false, "All VIPs share one session cache");
+#else
+const bool FLAGS_dcache_unit_test = false;
+#endif
+
+namespace folly {
+
+
+int SSLSessionCacheManager::sExDataIndex_ = -1;
+shared_ptr<ShardedLocalSSLSessionCache> SSLSessionCacheManager::sCache_;
+std::mutex SSLSessionCacheManager::sCacheLock_;
+
+LocalSSLSessionCache::LocalSSLSessionCache(uint32_t maxCacheSize,
+                                           uint32_t cacheCullSize)
+    : sessionCache(maxCacheSize, cacheCullSize) {
+  sessionCache.setPruneHook(std::bind(
+                              &LocalSSLSessionCache::pruneSessionCallback,
+                              this, std::placeholders::_1,
+                              std::placeholders::_2));
+}
+
+void LocalSSLSessionCache::pruneSessionCallback(const string& sessionId,
+                                                SSL_SESSION* session) {
+  VLOG(4) << "Free SSL session from local cache; id="
+          << SSLUtil::hexlify(sessionId);
+  SSL_SESSION_free(session);
+  ++removedSessions_;
+}
+
+
+// SSLSessionCacheManager implementation
+
+SSLSessionCacheManager::SSLSessionCacheManager(
+  uint32_t maxCacheSize,
+  uint32_t cacheCullSize,
+  SSLContext* ctx,
+  const folly::SocketAddress& sockaddr,
+  const string& context,
+  EventBase* eventBase,
+  SSLStats* stats,
+  const std::shared_ptr<SSLCacheProvider>& externalCache):
+    ctx_(ctx),
+    stats_(stats),
+    externalCache_(externalCache) {
+
+  SSL_CTX* sslCtx = ctx->getSSLCtx();
+
+  SSLUtil::getSSLCtxExIndex(&sExDataIndex_);
+
+  SSL_CTX_set_ex_data(sslCtx, sExDataIndex_, this);
+  SSL_CTX_sess_set_new_cb(sslCtx, SSLSessionCacheManager::newSessionCallback);
+  SSL_CTX_sess_set_get_cb(sslCtx, SSLSessionCacheManager::getSessionCallback);
+  SSL_CTX_sess_set_remove_cb(sslCtx,
+                             SSLSessionCacheManager::removeSessionCallback);
+  if (!FLAGS_dcache_unit_test && !context.empty()) {
+    // Use the passed in context
+    SSL_CTX_set_session_id_context(sslCtx, (const uint8_t *)context.data(),
+                                   std::min((int)context.length(),
+                                            SSL_MAX_SSL_SESSION_ID_LENGTH));
+  }
+
+  SSL_CTX_set_session_cache_mode(sslCtx, SSL_SESS_CACHE_NO_INTERNAL
+                                 | SSL_SESS_CACHE_SERVER);
+
+  localCache_ = SSLSessionCacheManager::getLocalCache(maxCacheSize,
+                                                      cacheCullSize);
+
+  VLOG(2) << "On VipID=" << sockaddr.describe() << " context=" << context;
+}
+
+SSLSessionCacheManager::~SSLSessionCacheManager() {
+}
+
+void SSLSessionCacheManager::shutdown() {
+  std::lock_guard<std::mutex> g(sCacheLock_);
+  sCache_.reset();
+}
+
+shared_ptr<ShardedLocalSSLSessionCache> SSLSessionCacheManager::getLocalCache(
+  uint32_t maxCacheSize,
+  uint32_t cacheCullSize) {
+
+  std::lock_guard<std::mutex> g(sCacheLock_);
+  if (!sCache_) {
+    sCache_.reset(new ShardedLocalSSLSessionCache(NUM_CACHE_BUCKETS,
+                                                  maxCacheSize,
+                                                  cacheCullSize));
+  }
+  return sCache_;
+}
+
+int SSLSessionCacheManager::newSessionCallback(SSL* ssl, SSL_SESSION* session) {
+  SSLSessionCacheManager* manager = nullptr;
+  SSL_CTX* ctx = SSL_get_SSL_CTX(ssl);
+  manager = (SSLSessionCacheManager *)SSL_CTX_get_ex_data(ctx, sExDataIndex_);
+
+  if (manager == nullptr) {
+    LOG(FATAL) << "Null SSLSessionCacheManager in callback";
+    return -1;
+  }
+  return manager->newSession(ssl, session);
+}
+
+
+int SSLSessionCacheManager::newSession(SSL* ssl, SSL_SESSION* session) {
+  string sessionId((char*)session->session_id, session->session_id_length);
+  VLOG(4) << "New SSL session; id=" << SSLUtil::hexlify(sessionId);
+
+  if (stats_) {
+    stats_->recordSSLSession(true /* new session */, false, false);
+  }
+
+  localCache_->storeSession(sessionId, session, stats_);
+
+  if (externalCache_) {
+    VLOG(4) << "New SSL session: send session to external cache; id=" <<
+      SSLUtil::hexlify(sessionId);
+    storeCacheRecord(sessionId, session);
+  }
+
+  return 1;
+}
+
+void SSLSessionCacheManager::removeSessionCallback(SSL_CTX* ctx,
+                                                   SSL_SESSION* session) {
+  SSLSessionCacheManager* manager = nullptr;
+  manager = (SSLSessionCacheManager *)SSL_CTX_get_ex_data(ctx, sExDataIndex_);
+
+  if (manager == nullptr) {
+    LOG(FATAL) << "Null SSLSessionCacheManager in callback";
+    return;
+  }
+  return manager->removeSession(ctx, session);
+}
+
+void SSLSessionCacheManager::removeSession(SSL_CTX* ctx,
+                                           SSL_SESSION* session) {
+  string sessionId((char*)session->session_id, session->session_id_length);
+
+  // This hook is only called from SSL when the internal session cache needs to
+  // flush sessions.  Since we run with the internal cache disabled, this should
+  // never be called
+  VLOG(3) << "Remove SSL session; id=" << SSLUtil::hexlify(sessionId);
+
+  localCache_->removeSession(sessionId);
+
+  if (stats_) {
+    stats_->recordSSLSessionRemove();
+  }
+}
+
+SSL_SESSION* SSLSessionCacheManager::getSessionCallback(SSL* ssl,
+                                                        unsigned char* sess_id,
+                                                        int id_len,
+                                                        int* copyflag) {
+  SSLSessionCacheManager* manager = nullptr;
+  SSL_CTX* ctx = SSL_get_SSL_CTX(ssl);
+  manager = (SSLSessionCacheManager *)SSL_CTX_get_ex_data(ctx, sExDataIndex_);
+
+  if (manager == nullptr) {
+    LOG(FATAL) << "Null SSLSessionCacheManager in callback";
+    return nullptr;
+  }
+  return manager->getSession(ssl, sess_id, id_len, copyflag);
+}
+
+SSL_SESSION* SSLSessionCacheManager::getSession(SSL* ssl,
+                                                unsigned char* session_id,
+                                                int id_len,
+                                                int* copyflag) {
+  VLOG(7) << "SSL get session callback";
+  SSL_SESSION* session = nullptr;
+  bool foreign = false;
+  char const* missReason = nullptr;
+
+  if (id_len < MIN_SESSION_ID_LENGTH) {
+    // We didn't generate this session so it's going to be a miss.
+    // This doesn't get logged or counted in the stats.
+    return nullptr;
+  }
+  string sessionId((char*)session_id, id_len);
+
+  AsyncSSLSocket* sslSocket = AsyncSSLSocket::getFromSSL(ssl);
+
+  assert(sslSocket != nullptr);
+
+  // look it up in the local cache first
+  session = localCache_->lookupSession(sessionId);
+#ifdef SSL_SESSION_CB_WOULD_BLOCK
+  if (session == nullptr && externalCache_) {
+    // external cache might have the session
+    foreign = true;
+    if (!SSL_want_sess_cache_lookup(ssl)) {
+      missReason = "reason: No async cache support;";
+    } else {
+      PendingLookupMap::iterator pit = pendingLookups_.find(sessionId);
+      if (pit == pendingLookups_.end()) {
+        auto result = pendingLookups_.emplace(sessionId, PendingLookup());
+        // initiate fetch
+        VLOG(4) << "Get SSL session [Pending]: Initiate Fetch; fd=" <<
+          sslSocket->getFd() << " id=" << SSLUtil::hexlify(sessionId);
+        if (lookupCacheRecord(sessionId, sslSocket)) {
+          // response is pending
+          *copyflag = SSL_SESSION_CB_WOULD_BLOCK;
+          return nullptr;
+        } else {
+          missReason = "reason: failed to send lookup request;";
+          pendingLookups_.erase(result.first);
+        }
+      } else {
+        // A lookup was already initiated from this thread
+        if (pit->second.request_in_progress) {
+          // Someone else initiated the request, attach
+          VLOG(4) << "Get SSL session [Pending]: Request in progess: attach; "
+            "fd=" << sslSocket->getFd() << " id=" <<
+            SSLUtil::hexlify(sessionId);
+          std::unique_ptr<DelayedDestruction::DestructorGuard> dg(
+            new DelayedDestruction::DestructorGuard(sslSocket));
+          pit->second.waiters.push_back(
+            std::make_pair(sslSocket, std::move(dg)));
+          *copyflag = SSL_SESSION_CB_WOULD_BLOCK;
+          return nullptr;
+        }
+        // request is complete
+        session = pit->second.session; // nullptr if our friend didn't have it
+        if (session != nullptr) {
+          CRYPTO_add(&session->references, 1, CRYPTO_LOCK_SSL_SESSION);
+        }
+      }
+    }
+  }
+#endif
+
+  bool hit = (session != nullptr);
+  if (stats_) {
+    stats_->recordSSLSession(false, hit, foreign);
+  }
+  if (hit) {
+    sslSocket->setSessionIDResumed(true);
+  }
+
+  VLOG(4) << "Get SSL session [" <<
+    ((hit) ? "Hit" : "Miss") << "]: " <<
+    ((foreign) ? "external" : "local") << " cache; " <<
+    ((missReason != nullptr) ? missReason : "") << "fd=" <<
+    sslSocket->getFd() << " id=" << SSLUtil::hexlify(sessionId);
+
+  // We already bumped the refcount
+  *copyflag = 0;
+
+  return session;
+}
+
+bool SSLSessionCacheManager::storeCacheRecord(const string& sessionId,
+                                              SSL_SESSION* session) {
+  std::string sessionString;
+  uint32_t sessionLen = i2d_SSL_SESSION(session, nullptr);
+  sessionString.resize(sessionLen);
+  uint8_t* cp = (uint8_t *)sessionString.data();
+  i2d_SSL_SESSION(session, &cp);
+  size_t expiration = SSL_CTX_get_timeout(ctx_->getSSLCtx());
+  return externalCache_->setAsync(sessionId, sessionString,
+                                  std::chrono::seconds(expiration));
+}
+
+bool SSLSessionCacheManager::lookupCacheRecord(const string& sessionId,
+                                               AsyncSSLSocket* sslSocket) {
+  auto cacheCtx = new SSLCacheProvider::CacheContext();
+  cacheCtx->sessionId = sessionId;
+  cacheCtx->session = nullptr;
+  cacheCtx->sslSocket = sslSocket;
+  cacheCtx->guard.reset(
+      new DelayedDestruction::DestructorGuard(cacheCtx->sslSocket));
+  cacheCtx->manager = this;
+  bool res = externalCache_->getAsync(sessionId, cacheCtx);
+  if (!res) {
+    delete cacheCtx;
+  }
+  return res;
+}
+
+void SSLSessionCacheManager::restartSSLAccept(
+    const SSLCacheProvider::CacheContext* cacheCtx) {
+  PendingLookupMap::iterator pit = pendingLookups_.find(cacheCtx->sessionId);
+  CHECK(pit != pendingLookups_.end());
+  pit->second.request_in_progress = false;
+  pit->second.session = cacheCtx->session;
+  VLOG(7) << "Restart SSL accept";
+  cacheCtx->sslSocket->restartSSLAccept();
+  for (const auto& attachedLookup: pit->second.waiters) {
+    // Wake up anyone else who was waiting for this session
+    VLOG(4) << "Restart SSL accept (waiters) for fd=" <<
+      attachedLookup.first->getFd();
+    attachedLookup.first->restartSSLAccept();
+  }
+  pendingLookups_.erase(pit);
+}
+
+void SSLSessionCacheManager::onGetSuccess(
+    SSLCacheProvider::CacheContext* cacheCtx,
+    const std::string& value) {
+  const uint8_t* cp = (uint8_t*)value.data();
+  cacheCtx->session = d2i_SSL_SESSION(nullptr, &cp, value.length());
+  restartSSLAccept(cacheCtx);
+
+  /* Insert in the LRU after restarting all clients.  The stats logic
+   * in getSession would treat this as a local hit otherwise.
+   */
+  localCache_->storeSession(cacheCtx->sessionId, cacheCtx->session, stats_);
+  delete cacheCtx;
+}
+
+void SSLSessionCacheManager::onGetFailure(
+    SSLCacheProvider::CacheContext* cacheCtx) {
+  restartSSLAccept(cacheCtx);
+  delete cacheCtx;
+}
+
+} // namespace
diff --git a/folly/experimental/wangle/ssl/SSLSessionCacheManager.h b/folly/experimental/wangle/ssl/SSLSessionCacheManager.h
new file mode 100644 (file)
index 0000000..f9c9e5d
--- /dev/null
@@ -0,0 +1,292 @@
+/*
+ *  Copyright (c) 2014, Facebook, Inc.
+ *  All rights reserved.
+ *
+ *  This source code is licensed under the BSD-style license found in the
+ *  LICENSE file in the root directory of this source tree. An additional grant
+ *  of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+#include <folly/experimental/wangle/ssl/SSLCacheProvider.h>
+#include <folly/experimental/wangle/ssl/SSLStats.h>
+
+#include <folly/EvictingCacheMap.h>
+#include <mutex>
+#include <folly/io/async/AsyncSSLSocket.h>
+
+namespace folly {
+
+class SSLStats;
+
+/**
+ * Basic SSL session cache map: Maps session id -> session
+ */
+typedef folly::EvictingCacheMap<std::string, SSL_SESSION*> SSLSessionCacheMap;
+
+/**
+ * Holds an SSLSessionCacheMap and associated lock
+ */
+class LocalSSLSessionCache: private boost::noncopyable {
+ public:
+  LocalSSLSessionCache(uint32_t maxCacheSize, uint32_t cacheCullSize);
+
+  ~LocalSSLSessionCache() {
+    std::lock_guard<std::mutex> g(lock);
+    // EvictingCacheMap dtor doesn't free values
+    sessionCache.clear();
+  }
+
+  SSLSessionCacheMap sessionCache;
+  std::mutex lock;
+  uint32_t removedSessions_{0};
+
+ private:
+
+  void pruneSessionCallback(const std::string& sessionId,
+                            SSL_SESSION* session);
+};
+
+/**
+ * A sharded LRU for SSL sessions.  The sharding is inteneded to reduce
+ * contention for the LRU locks.  Assuming uniform distribution, two workers
+ * will contend for the same lock with probability 1 / n_buckets^2.
+ */
+class ShardedLocalSSLSessionCache : private boost::noncopyable {
+ public:
+  ShardedLocalSSLSessionCache(uint32_t n_buckets, uint32_t maxCacheSize,
+                              uint32_t cacheCullSize) {
+    CHECK(n_buckets > 0);
+    maxCacheSize = (uint32_t)(((double)maxCacheSize) / n_buckets);
+    cacheCullSize = (uint32_t)(((double)cacheCullSize) / n_buckets);
+    if (maxCacheSize == 0) {
+      maxCacheSize = 1;
+    }
+    if (cacheCullSize == 0) {
+      cacheCullSize = 1;
+    }
+    for (uint32_t i = 0; i < n_buckets; i++) {
+      caches_.push_back(
+        std::unique_ptr<LocalSSLSessionCache>(
+          new LocalSSLSessionCache(maxCacheSize, cacheCullSize)));
+    }
+  }
+
+  SSL_SESSION* lookupSession(const std::string& sessionId) {
+    size_t bucket = hash(sessionId);
+    SSL_SESSION* session = nullptr;
+    std::lock_guard<std::mutex> g(caches_[bucket]->lock);
+
+    auto itr = caches_[bucket]->sessionCache.find(sessionId);
+    if (itr != caches_[bucket]->sessionCache.end()) {
+      session = itr->second;
+    }
+
+    if (session) {
+      CRYPTO_add(&session->references, 1, CRYPTO_LOCK_SSL_SESSION);
+    }
+    return session;
+  }
+
+  void storeSession(const std::string& sessionId, SSL_SESSION* session,
+                    SSLStats* stats) {
+    size_t bucket = hash(sessionId);
+    SSL_SESSION* oldSession = nullptr;
+    std::lock_guard<std::mutex> g(caches_[bucket]->lock);
+
+    auto itr = caches_[bucket]->sessionCache.find(sessionId);
+    if (itr != caches_[bucket]->sessionCache.end()) {
+      oldSession = itr->second;
+    }
+
+    if (oldSession) {
+      // LRUCacheMap doesn't free on overwrite, so 2x the work for us
+      // This can happen in race conditions
+      SSL_SESSION_free(oldSession);
+    }
+    caches_[bucket]->removedSessions_ = 0;
+    caches_[bucket]->sessionCache.set(sessionId, session, true);
+    if (stats) {
+      stats->recordSSLSessionFree(caches_[bucket]->removedSessions_);
+    }
+  }
+
+  void removeSession(const std::string& sessionId) {
+    size_t bucket = hash(sessionId);
+    std::lock_guard<std::mutex> g(caches_[bucket]->lock);
+    caches_[bucket]->sessionCache.erase(sessionId);
+  }
+
+ private:
+
+  /* SSL session IDs are 32 bytes of random data, hash based on first 16 bits */
+  size_t hash(const std::string& key) {
+    CHECK(key.length() >= 2);
+    return (key[0] << 8 | key[1]) % caches_.size();
+  }
+
+  std::vector< std::unique_ptr<LocalSSLSessionCache> > caches_;
+};
+
+/* A socket/DestructorGuard pair */
+typedef std::pair<AsyncSSLSocket *,
+                  std::unique_ptr<DelayedDestruction::DestructorGuard>>
+  AttachedLookup;
+
+/**
+ * PendingLookup structure
+ *
+ * Keeps track of clients waiting for an SSL session to be retrieved from
+ * the external cache provider.
+ */
+struct PendingLookup {
+  bool request_in_progress;
+  SSL_SESSION* session;
+  std::list<AttachedLookup> waiters;
+
+  PendingLookup() {
+    request_in_progress = true;
+    session = nullptr;
+  }
+};
+
+/* Maps SSL session id to a PendingLookup structure */
+typedef std::map<std::string, PendingLookup> PendingLookupMap;
+
+/**
+ * SSLSessionCacheManager handles all stateful session caching.  There is an
+ * instance of this object per SSL VIP per thread, with a 1:1 correlation with
+ * SSL_CTX.  The cache can work locally or in concert with an external cache
+ * to share sessions across instances.
+ *
+ * There is a single in memory session cache shared by all VIPs.  The cache is
+ * split into N buckets (currently 16) with a separate lock per bucket.  The
+ * VIP ID is hashed and stored as part of the session to handle the
+ * (very unlikely) case of session ID collision.
+ *
+ * When a new SSL session is created, it is added to the LRU cache and
+ * sent to the external cache to be stored.  The external cache
+ * expiration is equal to the SSL session's expiration.
+ *
+ * When a resume request is received, SSLSessionCacheManager first looks in the
+ * local LRU cache for the VIP.  If there is a miss there, an asynchronous
+ * request for this session is dispatched to the external cache.  When the
+ * external cache query returns, the LRU cache is updated if the session was
+ * found, and the SSL_accept call is resumed.
+ *
+ * If additional resume requests for the same session ID arrive in the same
+ * thread while the request is pending, the 2nd - Nth callers attach to the
+ * original external cache requests and are resumed when it comes back.  No
+ * attempt is made to coalesce external cache requests for the same session
+ * ID in different worker threads.  Previous work did this, but the
+ * complexity was deemed to outweigh the potential savings.
+ *
+ */
+class SSLSessionCacheManager : private boost::noncopyable {
+ public:
+  /**
+   * Constructor.  SSL session related callbacks will be set on the underlying
+   * SSL_CTX.  vipId is assumed to a unique string identifying the VIP and must
+   * be the same on all servers that wish to share sessions via the same
+   * external cache.
+   */
+  SSLSessionCacheManager(
+    uint32_t maxCacheSize,
+    uint32_t cacheCullSize,
+    SSLContext* ctx,
+    const folly::SocketAddress& sockaddr,
+    const std::string& context,
+    EventBase* eventBase,
+    SSLStats* stats,
+    const std::shared_ptr<SSLCacheProvider>& externalCache);
+
+  virtual ~SSLSessionCacheManager();
+
+  /**
+   * Call this on shutdown to release the global instance of the
+   * ShardedLocalSSLSessionCache.
+   */
+  static void shutdown();
+
+  /**
+   * Callback for ExternalCache to call when an async get succeeds
+   * @param context  The context that was passed to the async get request
+   * @param value    Serialized session
+   */
+  void onGetSuccess(SSLCacheProvider::CacheContext* context,
+                    const std::string& value);
+
+  /**
+   * Callback for ExternalCache to call when an async get fails, either
+   * because the requested session is not in the external cache or because
+   * of an error.
+   * @param context  The context that was passed to the async get request
+   */
+  void onGetFailure(SSLCacheProvider::CacheContext* context);
+
+ private:
+
+  SSLContext* ctx_;
+  std::shared_ptr<ShardedLocalSSLSessionCache> localCache_;
+  PendingLookupMap pendingLookups_;
+  SSLStats* stats_{nullptr};
+  std::shared_ptr<SSLCacheProvider> externalCache_;
+
+  /**
+   * Invoked by openssl when a new SSL session is created
+   */
+  int newSession(SSL* ssl, SSL_SESSION* session);
+
+  /**
+   * Invoked by openssl when an SSL session is ejected from its internal cache.
+   * This can't be invoked in the current implementation because SSL's internal
+   * caching is disabled.
+   */
+  void removeSession(SSL_CTX* ctx, SSL_SESSION* session);
+
+  /**
+   * Invoked by openssl when a client requests a stateful session resumption.
+   * Triggers a lookup in our local cache and potentially an asynchronous
+   * request to an external cache.
+   */
+  SSL_SESSION* getSession(SSL* ssl, unsigned char* session_id,
+                          int id_len, int* copyflag);
+
+  /**
+   * Store a new session record in the external cache
+   */
+  bool storeCacheRecord(const std::string& sessionId, SSL_SESSION* session);
+
+  /**
+   * Lookup a session in the external cache for the specified SSL socket.
+   */
+  bool lookupCacheRecord(const std::string& sessionId,
+                         AsyncSSLSocket* sslSock);
+
+  /**
+   * Restart all clients waiting for the answer to an external cache query
+   */
+  void restartSSLAccept(const SSLCacheProvider::CacheContext* cacheCtx);
+
+  /**
+   * Get or create the LRU cache for the given VIP ID
+   */
+  static std::shared_ptr<ShardedLocalSSLSessionCache> getLocalCache(
+    uint32_t maxCacheSize, uint32_t cacheCullSize);
+
+  /**
+   * static functions registered as callbacks to openssl via
+   * SSL_CTX_sess_set_new/get/remove_cb
+   */
+  static int newSessionCallback(SSL* ssl, SSL_SESSION* session);
+  static void removeSessionCallback(SSL_CTX* ctx, SSL_SESSION* session);
+  static SSL_SESSION* getSessionCallback(SSL* ssl, unsigned char* session_id,
+                                         int id_len, int* copyflag);
+
+  static int32_t sExDataIndex_;
+  static std::shared_ptr<ShardedLocalSSLSessionCache> sCache_;
+  static std::mutex sCacheLock_;
+};
+
+}
diff --git a/folly/experimental/wangle/ssl/SSLStats.h b/folly/experimental/wangle/ssl/SSLStats.h
new file mode 100644 (file)
index 0000000..761a843
--- /dev/null
@@ -0,0 +1,42 @@
+/*
+ *  Copyright (c) 2014, Facebook, Inc.
+ *  All rights reserved.
+ *
+ *  This source code is licensed under the BSD-style license found in the
+ *  LICENSE file in the root directory of this source tree. An additional grant
+ *  of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+namespace folly {
+
+class SSLStats {
+ public:
+  virtual ~SSLStats() noexcept {}
+
+  // downstream
+  virtual void recordSSLAcceptLatency(int64_t latency) noexcept = 0;
+  virtual void recordTLSTicket(bool ticketNew, bool ticketHit) noexcept = 0;
+  virtual void recordSSLSession(bool sessionNew, bool sessionHit, bool foreign)
+    noexcept = 0;
+  virtual void recordSSLSessionRemove() noexcept = 0;
+  virtual void recordSSLSessionFree(uint32_t freed) noexcept = 0;
+  virtual void recordSSLSessionSetError(uint32_t err) noexcept = 0;
+  virtual void recordSSLSessionGetError(uint32_t err) noexcept = 0;
+  virtual void recordClientRenegotiation() noexcept = 0;
+
+  // upstream
+  virtual void recordSSLUpstreamConnection(bool handshake) noexcept = 0;
+  virtual void recordSSLUpstreamConnectionError(bool verifyError) noexcept = 0;
+  virtual void recordCryptoSSLExternalAttempt() noexcept = 0;
+  virtual void recordCryptoSSLExternalConnAlreadyClosed() noexcept = 0;
+  virtual void recordCryptoSSLExternalApplicationException() noexcept = 0;
+  virtual void recordCryptoSSLExternalSuccess() noexcept = 0;
+  virtual void recordCryptoSSLExternalDuration(uint64_t duration) noexcept = 0;
+  virtual void recordCryptoSSLLocalAttempt() noexcept = 0;
+  virtual void recordCryptoSSLLocalSuccess() noexcept = 0;
+
+};
+
+}
diff --git a/folly/experimental/wangle/ssl/SSLUtil.cpp b/folly/experimental/wangle/ssl/SSLUtil.cpp
new file mode 100644 (file)
index 0000000..5557d1e
--- /dev/null
@@ -0,0 +1,76 @@
+/*
+ *  Copyright (c) 2014, Facebook, Inc.
+ *  All rights reserved.
+ *
+ *  This source code is licensed under the BSD-style license found in the
+ *  LICENSE file in the root directory of this source tree. An additional grant
+ *  of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#include <folly/experimental/wangle/ssl/SSLUtil.h>
+
+#include <folly/Memory.h>
+
+#if OPENSSL_VERSION_NUMBER >= 0x1000105fL
+#define OPENSSL_GE_101 1
+#include <openssl/asn1.h>
+#include <openssl/x509v3.h>
+#else
+#undef OPENSSL_GE_101
+#endif
+
+namespace folly {
+
+std::mutex SSLUtil::sIndexLock_;
+
+std::unique_ptr<std::string> SSLUtil::getCommonName(const X509* cert) {
+  X509_NAME* subject = X509_get_subject_name((X509*)cert);
+  if (!subject) {
+    return nullptr;
+  }
+  char cn[ub_common_name + 1];
+  int res = X509_NAME_get_text_by_NID(subject, NID_commonName,
+                                      cn, ub_common_name);
+  if (res <= 0) {
+    return nullptr;
+  } else {
+    cn[ub_common_name] = '\0';
+    return folly::make_unique<std::string>(cn);
+  }
+}
+
+std::unique_ptr<std::list<std::string>> SSLUtil::getSubjectAltName(
+    const X509* cert) {
+#ifdef OPENSSL_GE_101
+  auto nameList = folly::make_unique<std::list<std::string>>();
+  GENERAL_NAMES* names = (GENERAL_NAMES*)X509_get_ext_d2i(
+      (X509*)cert, NID_subject_alt_name, nullptr, nullptr);
+  if (names) {
+    auto guard = folly::makeGuard([names] { GENERAL_NAMES_free(names); });
+    size_t count = sk_GENERAL_NAME_num(names);
+    CHECK(count < std::numeric_limits<int>::max());
+    for (int i = 0; i < (int)count; ++i) {
+      GENERAL_NAME* generalName = sk_GENERAL_NAME_value(names, i);
+      if (generalName->type == GEN_DNS) {
+        ASN1_STRING* s = generalName->d.dNSName;
+        const char* name = (const char*)ASN1_STRING_data(s);
+        // I can't find any docs on what a negative return value here
+        // would mean, so I'm going to ignore it.
+        auto len = ASN1_STRING_length(s);
+        DCHECK(len >= 0);
+        if (size_t(len) != strlen(name)) {
+          // Null byte(s) in the name; return an error rather than depending on
+          // the caller to safely handle this case.
+          return nullptr;
+        }
+        nameList->emplace_back(name);
+      }
+    }
+  }
+  return nameList;
+#else
+  return nullptr;
+#endif
+}
+
+}
diff --git a/folly/experimental/wangle/ssl/SSLUtil.h b/folly/experimental/wangle/ssl/SSLUtil.h
new file mode 100644 (file)
index 0000000..20a17a9
--- /dev/null
@@ -0,0 +1,102 @@
+/*
+ *  Copyright (c) 2014, Facebook, Inc.
+ *  All rights reserved.
+ *
+ *  This source code is licensed under the BSD-style license found in the
+ *  LICENSE file in the root directory of this source tree. An additional grant
+ *  of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+#include <folly/String.h>
+#include <mutex>
+#include <folly/io/async/AsyncSSLSocket.h>
+
+namespace folly {
+
+/**
+ * SSL session establish/resume status
+ *
+ * changing these values will break logging pipelines
+ */
+enum class SSLResumeEnum : uint8_t {
+  HANDSHAKE = 0,
+  RESUME_SESSION_ID = 1,
+  RESUME_TICKET = 3,
+  NA = 2
+};
+
+enum class SSLErrorEnum {
+  NO_ERROR,
+  TIMEOUT,
+  DROPPED
+};
+
+class SSLUtil {
+ private:
+  static std::mutex sIndexLock_;
+
+ public:
+  /**
+   * Ensures only one caller will allocate an ex_data index for a given static
+   * or global.
+   */
+  static void getSSLCtxExIndex(int* pindex) {
+    std::lock_guard<std::mutex> g(sIndexLock_);
+    if (*pindex < 0) {
+      *pindex = SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
+    }
+  }
+
+  static void getRSAExIndex(int* pindex) {
+    std::lock_guard<std::mutex> g(sIndexLock_);
+    if (*pindex < 0) {
+      *pindex = RSA_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
+    }
+  }
+
+  static inline std::string hexlify(const std::string& binary) {
+    std::string hex;
+    folly::hexlify<std::string, std::string>(binary, hex);
+
+    return hex;
+  }
+
+  static inline const std::string& hexlify(const std::string& binary,
+                                           std::string& hex) {
+    folly::hexlify<std::string, std::string>(binary, hex);
+
+    return hex;
+  }
+
+  /**
+   * Return the SSL resume type for the given socket.
+   */
+  static inline SSLResumeEnum getResumeState(
+    AsyncSSLSocket* sslSocket) {
+    return sslSocket->getSSLSessionReused() ?
+      (sslSocket->sessionIDResumed() ?
+        SSLResumeEnum::RESUME_SESSION_ID :
+        SSLResumeEnum::RESUME_TICKET) :
+      SSLResumeEnum::HANDSHAKE;
+  }
+
+  /**
+   * Get the Common Name from an X.509 certificate
+   * @param cert  certificate to inspect
+   * @return  common name, or null if an error occurs
+   */
+  static std::unique_ptr<std::string> getCommonName(const X509* cert);
+
+  /**
+   * Get the Subject Alternative Name value(s) from an X.509 certificate
+   * @param cert  certificate to inspect
+   * @return  set of zero or more alternative names, or null if
+   *            an error occurs
+   */
+  static std::unique_ptr<std::list<std::string>> getSubjectAltName(
+      const X509* cert);
+};
+
+}
diff --git a/folly/experimental/wangle/ssl/TLSTicketKeyManager.cpp b/folly/experimental/wangle/ssl/TLSTicketKeyManager.cpp
new file mode 100644 (file)
index 0000000..c02153a
--- /dev/null
@@ -0,0 +1,305 @@
+/*
+ *  Copyright (c) 2014, Facebook, Inc.
+ *  All rights reserved.
+ *
+ *  This source code is licensed under the BSD-style license found in the
+ *  LICENSE file in the root directory of this source tree. An additional grant
+ *  of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#include <folly/experimental/wangle/ssl/TLSTicketKeyManager.h>
+
+#include <folly/experimental/wangle/ssl/SSLStats.h>
+#include <folly/experimental/wangle/ssl/SSLUtil.h>
+
+#include <folly/String.h>
+#include <openssl/aes.h>
+#include <openssl/rand.h>
+#include <openssl/ssl.h>
+#include <folly/io/async/AsyncTimeout.h>
+
+#ifdef SSL_CTRL_SET_TLSEXT_TICKET_KEY_CB
+using std::string;
+
+namespace {
+
+const int kTLSTicketKeyNameLen = 4;
+const int kTLSTicketKeySaltLen = 12;
+
+}
+
+namespace folly {
+
+
+// TLSTicketKeyManager Implementation
+int32_t TLSTicketKeyManager::sExDataIndex_ = -1;
+
+TLSTicketKeyManager::TLSTicketKeyManager(SSLContext* ctx, SSLStats* stats)
+  : ctx_(ctx),
+    randState_(0),
+    stats_(stats) {
+  SSLUtil::getSSLCtxExIndex(&sExDataIndex_);
+  SSL_CTX_set_ex_data(ctx_->getSSLCtx(), sExDataIndex_, this);
+}
+
+TLSTicketKeyManager::~TLSTicketKeyManager() {
+}
+
+int
+TLSTicketKeyManager::callback(SSL* ssl, unsigned char* keyName,
+                              unsigned char* iv,
+                              EVP_CIPHER_CTX* cipherCtx,
+                              HMAC_CTX* hmacCtx, int encrypt) {
+  TLSTicketKeyManager* manager = nullptr;
+  SSL_CTX* ctx = SSL_get_SSL_CTX(ssl);
+  manager = (TLSTicketKeyManager *)SSL_CTX_get_ex_data(ctx, sExDataIndex_);
+
+  if (manager == nullptr) {
+    LOG(FATAL) << "Null TLSTicketKeyManager in callback" ;
+    return -1;
+  }
+  return manager->processTicket(ssl, keyName, iv, cipherCtx, hmacCtx, encrypt);
+}
+
+int
+TLSTicketKeyManager::processTicket(SSL* ssl, unsigned char* keyName,
+                                   unsigned char* iv,
+                                   EVP_CIPHER_CTX* cipherCtx,
+                                   HMAC_CTX* hmacCtx, int encrypt) {
+  uint8_t salt[kTLSTicketKeySaltLen];
+  uint8_t* saltptr = nullptr;
+  uint8_t output[SHA256_DIGEST_LENGTH];
+  uint8_t* hmacKey = nullptr;
+  uint8_t* aesKey = nullptr;
+  TLSTicketKeySource* key = nullptr;
+  int result = 0;
+
+  if (encrypt) {
+    key = findEncryptionKey();
+    if (key == nullptr) {
+      // no keys available to encrypt
+      VLOG(2) << "No TLS ticket key found";
+      return -1;
+    }
+    VLOG(4) << "Encrypting new ticket with key name=" <<
+      SSLUtil::hexlify(key->keyName_);
+
+    // Get a random salt and write out key name
+    RAND_pseudo_bytes(salt, (int)sizeof(salt));
+    memcpy(keyName, key->keyName_.data(), kTLSTicketKeyNameLen);
+    memcpy(keyName + kTLSTicketKeyNameLen, salt, kTLSTicketKeySaltLen);
+
+    // Create the unique keys by hashing with the salt
+    makeUniqueKeys(key->keySource_, sizeof(key->keySource_), salt, output);
+    // This relies on the fact that SHA256 has 32 bytes of output
+    // and that AES-128 keys are 16 bytes
+    hmacKey = output;
+    aesKey = output + SHA256_DIGEST_LENGTH / 2;
+
+    // Initialize iv and cipher/mac CTX
+    RAND_pseudo_bytes(iv, AES_BLOCK_SIZE);
+    HMAC_Init_ex(hmacCtx, hmacKey, SHA256_DIGEST_LENGTH / 2,
+                 EVP_sha256(), nullptr);
+    EVP_EncryptInit_ex(cipherCtx, EVP_aes_128_cbc(), nullptr, aesKey, iv);
+
+    result = 1;
+  } else {
+    key = findDecryptionKey(keyName);
+    if (key == nullptr) {
+      // no ticket found for decryption - will issue a new ticket
+      if (VLOG_IS_ON(4)) {
+        string skeyName((char *)keyName, kTLSTicketKeyNameLen);
+        VLOG(4) << "Can't find ticket key with name=" <<
+          SSLUtil::hexlify(skeyName)<< ", will generate new ticket";
+      }
+
+      result = 0;
+    } else {
+      VLOG(4) << "Decrypting ticket with key name=" <<
+        SSLUtil::hexlify(key->keyName_);
+
+      // Reconstruct the unique key via the salt
+      saltptr = keyName + kTLSTicketKeyNameLen;
+      makeUniqueKeys(key->keySource_, sizeof(key->keySource_), saltptr, output);
+      hmacKey = output;
+      aesKey = output + SHA256_DIGEST_LENGTH / 2;
+
+      // Initialize cipher/mac CTX
+      HMAC_Init_ex(hmacCtx, hmacKey, SHA256_DIGEST_LENGTH / 2,
+                   EVP_sha256(), nullptr);
+      EVP_DecryptInit_ex(cipherCtx, EVP_aes_128_cbc(), nullptr, aesKey, iv);
+
+      result = 1;
+    }
+  }
+  // result records whether a ticket key was found to decrypt this ticket,
+  // not wether the session was re-used.
+  if (stats_) {
+    stats_->recordTLSTicket(encrypt, result);
+  }
+
+  return result;
+}
+
+bool
+TLSTicketKeyManager::setTLSTicketKeySeeds(
+    const std::vector<std::string>& oldSeeds,
+    const std::vector<std::string>& currentSeeds,
+    const std::vector<std::string>& newSeeds) {
+
+  bool result = true;
+
+  activeKeys_.clear();
+  ticketKeys_.clear();
+  ticketSeeds_.clear();
+  const std::vector<string> *seedList = &oldSeeds;
+  for (uint32_t i = 0; i < 3; i++) {
+    TLSTicketSeedType type = (TLSTicketSeedType)i;
+    if (type == SEED_CURRENT) {
+      seedList = &currentSeeds;
+    } else if (type == SEED_NEW) {
+      seedList = &newSeeds;
+    }
+
+    for (const auto& seedInput: *seedList) {
+      TLSTicketSeed* seed = insertSeed(seedInput, type);
+      if (seed == nullptr) {
+        result = false;
+        continue;
+      }
+      insertNewKey(seed, 1, nullptr);
+    }
+  }
+  if (!result) {
+    VLOG(2) << "One or more seeds failed to decode";
+  }
+
+  if (ticketKeys_.size() == 0 || activeKeys_.size() == 0) {
+    LOG(WARNING) << "No keys configured, falling back to default";
+    SSL_CTX_set_tlsext_ticket_key_cb(ctx_->getSSLCtx(), nullptr);
+    return false;
+  }
+  SSL_CTX_set_tlsext_ticket_key_cb(ctx_->getSSLCtx(),
+                                   TLSTicketKeyManager::callback);
+
+  return true;
+}
+
+string
+TLSTicketKeyManager::makeKeyName(TLSTicketSeed* seed, uint32_t n,
+                                 unsigned char* nameBuf) {
+  SHA256_CTX ctx;
+
+  SHA256_Init(&ctx);
+  SHA256_Update(&ctx, seed->seedName_, sizeof(seed->seedName_));
+  SHA256_Update(&ctx, &n, sizeof(n));
+  SHA256_Final(nameBuf, &ctx);
+  return string((char *)nameBuf, kTLSTicketKeyNameLen);
+}
+
+TLSTicketKeyManager::TLSTicketKeySource*
+TLSTicketKeyManager::insertNewKey(TLSTicketSeed* seed, uint32_t hashCount,
+                                  TLSTicketKeySource* prevKey) {
+  unsigned char nameBuf[SHA256_DIGEST_LENGTH];
+  std::unique_ptr<TLSTicketKeySource> newKey(new TLSTicketKeySource);
+
+  // This function supports hash chaining but it is not currently used.
+
+  if (prevKey != nullptr) {
+    hashNth(prevKey->keySource_, sizeof(prevKey->keySource_),
+            newKey->keySource_, 1);
+  } else {
+    // can't go backwards or the current is missing, start from the beginning
+    hashNth((unsigned char *)seed->seed_.data(), seed->seed_.length(),
+            newKey->keySource_, hashCount);
+  }
+
+  newKey->hashCount_ = hashCount;
+  newKey->keyName_ = makeKeyName(seed, hashCount, nameBuf);
+  newKey->type_ = seed->type_;
+  auto it = ticketKeys_.insert(std::make_pair(newKey->keyName_,
+        std::move(newKey)));
+
+  auto key = it.first->second.get();
+  if (key->type_ == SEED_CURRENT) {
+    activeKeys_.push_back(key);
+  }
+  VLOG(4) << "Adding key for " << hashCount << " type=" <<
+    (uint32_t)key->type_ << " Name=" << SSLUtil::hexlify(key->keyName_);
+
+  return key;
+}
+
+void
+TLSTicketKeyManager::hashNth(const unsigned char* input, size_t input_len,
+                             unsigned char* output, uint32_t n) {
+  assert(n > 0);
+  for (uint32_t i = 0; i < n; i++) {
+    SHA256(input, input_len, output);
+    input = output;
+    input_len = SHA256_DIGEST_LENGTH;
+  }
+}
+
+TLSTicketKeyManager::TLSTicketSeed *
+TLSTicketKeyManager::insertSeed(const string& seedInput,
+                                TLSTicketSeedType type) {
+  TLSTicketSeed* seed = nullptr;
+  string seedOutput;
+
+  if (!folly::unhexlify<string, string>(seedInput, seedOutput)) {
+    LOG(WARNING) << "Failed to decode seed type=" << (uint32_t)type <<
+      " seed=" << seedInput;
+    return seed;
+  }
+
+  seed = new TLSTicketSeed();
+  seed->seed_ = seedOutput;
+  seed->type_ = type;
+  SHA256((unsigned char *)seedOutput.data(), seedOutput.length(),
+         seed->seedName_);
+  ticketSeeds_.push_back(std::unique_ptr<TLSTicketSeed>(seed));
+
+  return seed;
+}
+
+TLSTicketKeyManager::TLSTicketKeySource *
+TLSTicketKeyManager::findEncryptionKey() {
+  TLSTicketKeySource* result = nullptr;
+  // call to rand here is a bit hokey since it's not cryptographically
+  // random, and is predictably seeded with 0.  However, activeKeys_
+  // is probably not going to have very many keys in it, and most
+  // likely only 1.
+  size_t numKeys = activeKeys_.size();
+  if (numKeys > 0) {
+    result = activeKeys_[rand_r(&randState_) % numKeys];
+  }
+  return result;
+}
+
+TLSTicketKeyManager::TLSTicketKeySource *
+TLSTicketKeyManager::findDecryptionKey(unsigned char* keyName) {
+  string name((char *)keyName, kTLSTicketKeyNameLen);
+  TLSTicketKeySource* key = nullptr;
+  TLSTicketKeyMap::iterator mapit = ticketKeys_.find(name);
+  if (mapit != ticketKeys_.end()) {
+    key = mapit->second.get();
+  }
+  return key;
+}
+
+void
+TLSTicketKeyManager::makeUniqueKeys(unsigned char* parentKey,
+                                    size_t keyLen,
+                                    unsigned char* salt,
+                                    unsigned char* output) {
+  SHA256_CTX hash_ctx;
+
+  SHA256_Init(&hash_ctx);
+  SHA256_Update(&hash_ctx, parentKey, keyLen);
+  SHA256_Update(&hash_ctx, salt, kTLSTicketKeySaltLen);
+  SHA256_Final(output, &hash_ctx);
+}
+
+} // namespace
+#endif
diff --git a/folly/experimental/wangle/ssl/TLSTicketKeyManager.h b/folly/experimental/wangle/ssl/TLSTicketKeyManager.h
new file mode 100644 (file)
index 0000000..4000c13
--- /dev/null
@@ -0,0 +1,198 @@
+/*
+ *  Copyright (c) 2014, Facebook, Inc.
+ *  All rights reserved.
+ *
+ *  This source code is licensed under the BSD-style license found in the
+ *  LICENSE file in the root directory of this source tree. An additional grant
+ *  of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+#include <folly/io/async/SSLContext.h>
+#include <folly/io/async/EventBase.h>
+
+namespace folly {
+
+#ifndef SSL_CTRL_SET_TLSEXT_TICKET_KEY_CB
+class TLSTicketKeyManager {};
+#else
+class SSLStats;
+/**
+ * The TLSTicketKeyManager handles TLS ticket key encryption and decryption in
+ * a way that facilitates sharing the ticket keys across a range of servers.
+ * Hash chaining is employed to achieve frequent key rotation with minimal
+ * configuration change.  The scheme is as follows:
+ *
+ * The manager is supplied with three lists of seeds (old, current and new).
+ * The config should be updated with new seeds periodically (e.g., daily).
+ * 3 config changes are recommended to achieve the smoothest seed rotation
+ * eg:
+ *     1. Introduce new seed in the push prior to rotation
+ *     2. Rotation push
+ *     3. Remove old seeds in the push following rotation
+ *
+ * Multiple seeds are supported but only a single seed is required.
+ *
+ * Generating encryption keys from the seed works as follows.  For a given
+ * seed, hash forward N times where N is currently the constant 1.
+ * This is the base key.  The name of the base key is the first 4
+ * bytes of hash(hash(seed), N).  This is copied into the first 4 bytes of the
+ * TLS ticket key name field.
+ *
+ * For each new ticket encryption, the manager generates a random 12 byte salt.
+ * Hash the salt and the base key together to form the encryption key for
+ * that ticket.  The salt is included in the ticket's 'key name' field so it
+ * can be used to derive the decryption key.  The salt is copied into the second
+ * 8 bytes of the TLS ticket key name field.
+ *
+ * A key is valid for decryption for the lifetime of the instance.
+ * Sessions will be valid for less time than that, which results in an extra
+ * symmetric decryption to discover the session is expired.
+ *
+ * A TLSTicketKeyManager should be used in only one thread, and should have
+ * a 1:1 relationship with the SSLContext provided.
+ *
+ */
+class TLSTicketKeyManager : private boost::noncopyable {
+ public:
+
+  explicit TLSTicketKeyManager(folly::SSLContext* ctx,
+                               SSLStats* stats);
+
+  virtual ~TLSTicketKeyManager();
+
+  /**
+   * SSL callback to set up encryption/decryption context for a TLS Ticket Key.
+   *
+   * This will be supplied to the SSL library via
+   * SSL_CTX_set_tlsext_ticket_key_cb.
+   */
+  static int callback(SSL* ssl, unsigned char* keyName,
+                      unsigned char* iv,
+                      EVP_CIPHER_CTX* cipherCtx,
+                      HMAC_CTX* hmacCtx, int encrypt);
+
+  /**
+   * Initialize the manager with three sets of seeds.  There must be at least
+   * one current seed, or the manager will revert to the default SSL behavior.
+   *
+   * @param oldSeeds Seeds previously used which can still decrypt.
+   * @param currentSeeds Seeds to use for new ticket encryptions.
+   * @param newSeeds Seeds which will be used soon, can be used to decrypt
+   *                 in case some servers in the cluster have already rotated.
+   */
+  bool setTLSTicketKeySeeds(const std::vector<std::string>& oldSeeds,
+                            const std::vector<std::string>& currentSeeds,
+                            const std::vector<std::string>& newSeeds);
+
+ private:
+  enum TLSTicketSeedType {
+    SEED_OLD = 0,
+    SEED_CURRENT,
+    SEED_NEW
+  };
+
+  /* The seeds supplied by the configuration */
+  struct TLSTicketSeed {
+    std::string seed_;
+    TLSTicketSeedType type_;
+    unsigned char seedName_[SHA256_DIGEST_LENGTH];
+  };
+
+  struct TLSTicketKeySource {
+    int32_t hashCount_;
+    std::string keyName_;
+    TLSTicketSeedType type_;
+    unsigned char keySource_[SHA256_DIGEST_LENGTH];
+  };
+
+  /**
+   * Method to setup encryption/decryption context for a TLS Ticket Key
+   *
+   * OpenSSL documentation is thin on the return value semantics.
+   *
+   * For encrypt=1, return < 0 on error, >= 0 for successfully initialized
+   * For encrypt=0, return < 0 on error, 0 on key not found
+   *                 1 on key found, 2 renew_ticket
+   *
+   * renew_ticket means a new ticket will be issued.  We could return this value
+   * when receiving a ticket encrypted with a key derived from an OLD seed.
+   * However, session_timeout seconds after deploying with a seed
+   * rotated from  CURRENT -> OLD, there will be no valid tickets outstanding
+   * encrypted with the old key.  This grace period means no unnecessary
+   * handshakes will be performed.  If the seed is believed compromised, it
+   * should NOT be configured as an OLD seed.
+   */
+  int processTicket(SSL* ssl, unsigned char* keyName,
+                    unsigned char* iv,
+                    EVP_CIPHER_CTX* cipherCtx,
+                    HMAC_CTX* hmacCtx, int encrypt);
+
+  // Creates the name for the nth key generated from seed
+  std::string makeKeyName(TLSTicketSeed* seed, uint32_t n,
+                          unsigned char* nameBuf);
+
+  /**
+   * Creates the key hashCount hashes from the given seed and inserts it in
+   * ticketKeys.  A naked pointer to the key is returned for additional
+   * processing if needed.
+   */
+  TLSTicketKeySource* insertNewKey(TLSTicketSeed* seed, uint32_t hashCount,
+                                   TLSTicketKeySource* prevKeySource);
+
+  /**
+   * hashes input N times placing result in output, which must be at least
+   * SHA256_DIGEST_LENGTH long.
+   */
+  void hashNth(const unsigned char* input, size_t input_len,
+               unsigned char* output, uint32_t n);
+
+  /**
+   * Adds the given seed to the manager
+   */
+  TLSTicketSeed* insertSeed(const std::string& seedInput,
+                            TLSTicketSeedType type);
+
+  /**
+   * Locate a key for encrypting a new ticket
+   */
+  TLSTicketKeySource* findEncryptionKey();
+
+  /**
+   * Locate a key for decrypting a ticket with the given keyName
+   */
+  TLSTicketKeySource* findDecryptionKey(unsigned char* keyName);
+
+  /**
+   * Derive a unique key from the parent key and the salt via hashing
+   */
+  void makeUniqueKeys(unsigned char* parentKey, size_t keyLen,
+                      unsigned char* salt, unsigned char* output);
+
+  /**
+   * For standalone decryption utility
+   */
+  friend int decrypt_fb_ticket(folly::TLSTicketKeyManager* manager,
+                               const std::string& testTicket,
+                               SSL_SESSION **psess);
+
+  typedef std::vector<std::unique_ptr<TLSTicketSeed>> TLSTicketSeedList;
+  typedef std::map<std::string, std::unique_ptr<TLSTicketKeySource> >
+    TLSTicketKeyMap;
+  typedef std::vector<TLSTicketKeySource *> TLSActiveKeyList;
+
+  TLSTicketSeedList ticketSeeds_;
+  // All key sources that can be used for decryption
+  TLSTicketKeyMap ticketKeys_;
+  // Key sources that can be used for encryption
+  TLSActiveKeyList activeKeys_;
+
+  folly::SSLContext* ctx_;
+  uint32_t randState_;
+  SSLStats* stats_{nullptr};
+
+  static int32_t sExDataIndex_;
+};
+#endif
+}
diff --git a/folly/experimental/wangle/ssl/TLSTicketKeySeeds.h b/folly/experimental/wangle/ssl/TLSTicketKeySeeds.h
new file mode 100644 (file)
index 0000000..c40ae58
--- /dev/null
@@ -0,0 +1,20 @@
+/*
+ *  Copyright (c) 2014, Facebook, Inc.
+ *  All rights reserved.
+ *
+ *  This source code is licensed under the BSD-style license found in the
+ *  LICENSE file in the root directory of this source tree. An additional grant
+ *  of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+namespace folly {
+
+struct TLSTicketKeySeeds {
+  std::vector<std::string> oldSeeds;
+  std::vector<std::string> currentSeeds;
+  std::vector<std::string> newSeeds;
+};
+
+}
diff --git a/folly/experimental/wangle/ssl/test/SSLCacheTest.cpp b/folly/experimental/wangle/ssl/test/SSLCacheTest.cpp
new file mode 100644 (file)
index 0000000..2433cfc
--- /dev/null
@@ -0,0 +1,272 @@
+/*
+ *  Copyright (c) 2014, Facebook, Inc.
+ *  All rights reserved.
+ *
+ *  This source code is licensed under the BSD-style license found in the
+ *  LICENSE file in the root directory of this source tree. An additional grant
+ *  of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#include <folly/Portability.h>
+#include <folly/io/async/EventBase.h>
+#include <gflags/gflags.h>
+#include <iostream>
+#include <thread>
+#include <folly/io/async/AsyncSSLSocket.h>
+#include <folly/io/async/AsyncSocket.h>
+#include <vector>
+
+using namespace std;
+using namespace folly;
+
+DEFINE_int32(clients, 1, "Number of simulated SSL clients");
+DEFINE_int32(threads, 1, "Number of threads to spread clients across");
+DEFINE_int32(requests, 2, "Total number of requests per client");
+DEFINE_int32(port, 9423, "Server port");
+DEFINE_bool(sticky, false, "A given client sends all reqs to one "
+            "(random) server");
+DEFINE_bool(global, false, "All clients in a thread use the same SSL session");
+DEFINE_bool(handshakes, false, "Force 100% handshakes");
+
+string f_servers[10];
+int f_num_servers = 0;
+int tnum = 0;
+
+class ClientRunner {
+ public:
+
+  ClientRunner(): reqs(0), hits(0), miss(0), num(tnum++) {}
+  void run();
+
+  int reqs;
+  int hits;
+  int miss;
+  int num;
+};
+
+class SSLCacheClient : public AsyncSocket::ConnectCallback,
+                       public AsyncSSLSocket::HandshakeCB
+{
+private:
+  EventBase* eventBase_;
+  int currReq_;
+  int serverIdx_;
+  AsyncSocket* socket_;
+  AsyncSSLSocket* sslSocket_;
+  SSL_SESSION* session_;
+  SSL_SESSION **pSess_;
+  std::shared_ptr<SSLContext> ctx_;
+  ClientRunner* cr_;
+
+public:
+  SSLCacheClient(EventBase* eventBase, SSL_SESSION **pSess, ClientRunner* cr);
+  ~SSLCacheClient() {
+    if (session_ && !FLAGS_global)
+      SSL_SESSION_free(session_);
+    if (socket_ != nullptr) {
+      if (sslSocket_ != nullptr) {
+        sslSocket_->destroy();
+        sslSocket_ = nullptr;
+      }
+      socket_->destroy();
+      socket_ = nullptr;
+    }
+  };
+
+  void start();
+
+  virtual void connectSuccess() noexcept;
+
+  virtual void connectErr(const AsyncSocketException& ex)
+    noexcept ;
+
+  virtual void handshakeSuc(AsyncSSLSocket* sock) noexcept;
+
+  virtual void handshakeErr(
+    AsyncSSLSocket* sock,
+    const AsyncSocketException& ex) noexcept;
+
+};
+
+int
+main(int argc, char* argv[])
+{
+  gflags::SetUsageMessage(std::string("\n\n"
+"usage: sslcachetest [options] -c <clients> -t <threads> servers\n"
+));
+  gflags::ParseCommandLineFlags(&argc, &argv, true);
+  int reqs = 0;
+  int hits = 0;
+  int miss = 0;
+  struct timeval start;
+  struct timeval end;
+  struct timeval result;
+
+  srand((unsigned int)time(nullptr));
+
+  for (int i = 1; i < argc; i++) {
+    f_servers[f_num_servers++] = argv[i];
+  }
+  if (f_num_servers == 0) {
+    cout << "require at least one server\n";
+    return 1;
+  }
+
+  gettimeofday(&start, nullptr);
+  if (FLAGS_threads == 1) {
+    ClientRunner r;
+    r.run();
+    gettimeofday(&end, nullptr);
+    reqs = r.reqs;
+    hits = r.hits;
+    miss = r.miss;
+  }
+  else {
+    std::vector<ClientRunner> clients;
+    std::vector<std::thread> threads;
+    for (int t = 0; t < FLAGS_threads; t++) {
+      threads.emplace_back([&] {
+          clients[t].run();
+        });
+    }
+    for (auto& thr: threads) {
+      thr.join();
+    }
+    gettimeofday(&end, nullptr);
+
+    for (const auto& client: clients) {
+      reqs += client.reqs;
+      hits += client.hits;
+      miss += client.miss;
+    }
+  }
+
+  timersub(&end, &start, &result);
+
+  cout << "Requests: " << reqs << endl;
+  cout << "Handshakes: " << miss << endl;
+  cout << "Resumes: " << hits << endl;
+  cout << "Runtime(ms): " << result.tv_sec << "." << result.tv_usec / 1000 <<
+    endl;
+
+  cout << "ops/sec: " << (reqs * 1.0) /
+    ((double)result.tv_sec * 1.0 + (double)result.tv_usec / 1000000.0) << endl;
+
+  return 0;
+}
+
+void
+ClientRunner::run()
+{
+  EventBase eb;
+  std::list<SSLCacheClient *> clients;
+  SSL_SESSION* session = nullptr;
+
+  for (int i = 0; i < FLAGS_clients; i++) {
+    SSLCacheClient* c = new SSLCacheClient(&eb, &session, this);
+    c->start();
+    clients.push_back(c);
+  }
+
+  eb.loop();
+
+  for (auto it = clients.begin(); it != clients.end(); it++) {
+    delete* it;
+  }
+
+  reqs += hits + miss;
+}
+
+SSLCacheClient::SSLCacheClient(EventBase* eb,
+                               SSL_SESSION **pSess,
+                               ClientRunner* cr)
+    : eventBase_(eb),
+      currReq_(0),
+      serverIdx_(0),
+      socket_(nullptr),
+      sslSocket_(nullptr),
+      session_(nullptr),
+      pSess_(pSess),
+      cr_(cr)
+{
+  ctx_.reset(new SSLContext());
+  ctx_->setOptions(SSL_OP_NO_TICKET);
+}
+
+void
+SSLCacheClient::start()
+{
+  if (currReq_ >= FLAGS_requests) {
+    cout << "+";
+    return;
+  }
+
+  if (currReq_ == 0 || !FLAGS_sticky) {
+    serverIdx_ = rand() % f_num_servers;
+  }
+  if (socket_ != nullptr) {
+    if (sslSocket_ != nullptr) {
+      sslSocket_->destroy();
+      sslSocket_ = nullptr;
+    }
+    socket_->destroy();
+    socket_ = nullptr;
+  }
+  socket_ = new AsyncSocket(eventBase_);
+  socket_->connect(this, f_servers[serverIdx_], (uint16_t)FLAGS_port);
+}
+
+void
+SSLCacheClient::connectSuccess() noexcept
+{
+  sslSocket_ = new AsyncSSLSocket(ctx_, eventBase_, socket_->detachFd(),
+                                   false);
+
+  if (!FLAGS_handshakes) {
+    if (session_ != nullptr)
+      sslSocket_->setSSLSession(session_);
+    else if (FLAGS_global && pSess_ != nullptr)
+      sslSocket_->setSSLSession(*pSess_);
+  }
+  sslSocket_->sslConn(this);
+}
+
+void
+SSLCacheClient::connectErr(const AsyncSocketException& ex)
+  noexcept
+{
+  cout << "connectError: " << ex.what() << endl;
+}
+
+void
+SSLCacheClient::handshakeSuc(AsyncSSLSocket* socket) noexcept
+{
+  if (sslSocket_->getSSLSessionReused()) {
+    cr_->hits++;
+  } else {
+    cr_->miss++;
+    if (session_ != nullptr) {
+      SSL_SESSION_free(session_);
+    }
+    session_ = sslSocket_->getSSLSession();
+    if (FLAGS_global && pSess_ != nullptr && *pSess_ == nullptr) {
+      *pSess_ = session_;
+    }
+  }
+  if ( ((cr_->hits + cr_->miss) % 100) == ((100 / FLAGS_threads) * cr_->num)) {
+    cout << ".";
+    cout.flush();
+  }
+  sslSocket_->closeNow();
+  currReq_++;
+  this->start();
+}
+
+void
+SSLCacheClient::handshakeErr(
+  AsyncSSLSocket* sock,
+  const AsyncSocketException& ex)
+  noexcept
+{
+  cout << "handshakeError: " << ex.what() << endl;
+}
diff --git a/folly/experimental/wangle/ssl/test/SSLContextManagerTest.cpp b/folly/experimental/wangle/ssl/test/SSLContextManagerTest.cpp
new file mode 100644 (file)
index 0000000..6e5815c
--- /dev/null
@@ -0,0 +1,87 @@
+/*
+ *  Copyright (c) 2014, Facebook, Inc.
+ *  All rights reserved.
+ *
+ *  This source code is licensed under the BSD-style license found in the
+ *  LICENSE file in the root directory of this source tree. An additional grant
+ *  of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#include <folly/io/async/EventBase.h>
+#include <folly/io/async/SSLContext.h>
+#include <glog/logging.h>
+#include <gtest/gtest.h>
+#include <folly/experimental/wangle/ssl/SSLContextManager.h>
+#include <folly/experimental/wangle/acceptor/DomainNameMisc.h>
+
+using std::shared_ptr;
+
+namespace folly {
+
+TEST(SSLContextManagerTest, Test1)
+{
+  EventBase eventBase;
+  SSLContextManager sslCtxMgr(&eventBase, "vip_ssl_context_manager_test_",
+                              true, nullptr);
+  auto www_facebook_com_ctx = std::make_shared<SSLContext>();
+  auto start_facebook_com_ctx = std::make_shared<SSLContext>();
+  auto start_abc_facebook_com_ctx = std::make_shared<SSLContext>();
+
+  sslCtxMgr.insertSSLCtxByDomainName(
+    "www.facebook.com",
+    strlen("www.facebook.com"),
+    www_facebook_com_ctx);
+  sslCtxMgr.insertSSLCtxByDomainName(
+    "www.facebook.com",
+    strlen("www.facebook.com"),
+    www_facebook_com_ctx);
+  try {
+    sslCtxMgr.insertSSLCtxByDomainName(
+      "www.facebook.com",
+      strlen("www.facebook.com"),
+      std::make_shared<SSLContext>());
+  } catch (const std::exception& ex) {
+  }
+  sslCtxMgr.insertSSLCtxByDomainName(
+    "*.facebook.com",
+    strlen("*.facebook.com"),
+    start_facebook_com_ctx);
+  sslCtxMgr.insertSSLCtxByDomainName(
+    "*.abc.facebook.com",
+    strlen("*.abc.facebook.com"),
+    start_abc_facebook_com_ctx);
+  try {
+    sslCtxMgr.insertSSLCtxByDomainName(
+      "*.abc.facebook.com",
+      strlen("*.abc.facebook.com"),
+      std::make_shared<SSLContext>());
+    FAIL();
+  } catch (const std::exception& ex) {
+  }
+
+  shared_ptr<SSLContext> retCtx;
+  retCtx = sslCtxMgr.getSSLCtx(DNString("www.facebook.com"));
+  EXPECT_EQ(retCtx, www_facebook_com_ctx);
+  retCtx = sslCtxMgr.getSSLCtx(DNString("WWW.facebook.com"));
+  EXPECT_EQ(retCtx, www_facebook_com_ctx);
+  EXPECT_FALSE(sslCtxMgr.getSSLCtx(DNString("xyz.facebook.com")));
+
+  retCtx = sslCtxMgr.getSSLCtxBySuffix(DNString("xyz.facebook.com"));
+  EXPECT_EQ(retCtx, start_facebook_com_ctx);
+  retCtx = sslCtxMgr.getSSLCtxBySuffix(DNString("XYZ.facebook.com"));
+  EXPECT_EQ(retCtx, start_facebook_com_ctx);
+
+  retCtx = sslCtxMgr.getSSLCtxBySuffix(DNString("www.abc.facebook.com"));
+  EXPECT_EQ(retCtx, start_abc_facebook_com_ctx);
+
+  // ensure "facebook.com" does not match "*.facebook.com"
+  EXPECT_FALSE(sslCtxMgr.getSSLCtxBySuffix(DNString("facebook.com")));
+  // ensure "Xfacebook.com" does not match "*.facebook.com"
+  EXPECT_FALSE(sslCtxMgr.getSSLCtxBySuffix(DNString("Xfacebook.com")));
+  // ensure wildcard name only matches one domain up
+  EXPECT_FALSE(sslCtxMgr.getSSLCtxBySuffix(DNString("abc.xyz.facebook.com")));
+
+  eventBase.loop(); // Clean up events before SSLContextManager is destructed
+}
+
+}